#include "common.h" #include "remote.h" #include "packet_encryption.h" typedef struct _CryptProviderParams { const TCHAR* provider; const DWORD type; const DWORD flags; } CryptProviderParams; typedef struct _RsaKey { BLOBHEADER header; DWORD length; BYTE key[1]; } RsaKey; const CryptProviderParams AesProviders[] = { {MS_ENH_RSA_AES_PROV, PROV_RSA_AES, 0}, {MS_ENH_RSA_AES_PROV, PROV_RSA_AES, CRYPT_NEWKEYSET}, {MS_ENH_RSA_AES_PROV_XP, PROV_RSA_AES, 0}, {MS_ENH_RSA_AES_PROV_XP, PROV_RSA_AES, CRYPT_NEWKEYSET} }; DWORD decrypt_packet(Remote* remote, Packet** packet, LPBYTE buffer, DWORD bufferSize) { DWORD result = ERROR_SUCCESS; Packet* localPacket = NULL; HCRYPTKEY dupKey = 0; #ifdef DEBUGTRACE PUCHAR h = buffer; vdprintf("[DEC] Packet header: [0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X]", h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7], h[8], h[9], h[10], h[11], h[12], h[13], h[14], h[15], h[16], h[17], h[18], h[19], h[20], h[21], h[22], h[23], h[24], h[25], h[26], h[27], h[28], h[29], h[30], h[31]); #endif vdprintf("[DEC] Packet buffer size is: %u", bufferSize); do { PacketHeader* header = (PacketHeader*)buffer; // Start by decoding the entire packet xor_bytes(header->xor_key, buffer + sizeof(header->xor_key), bufferSize - sizeof(header->xor_key)); #ifdef DEBUGTRACE h = buffer; vdprintf("[DEC] Packet header: [0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X]", h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7], h[8], h[9], h[10], h[11], h[12], h[13], h[14], h[15], h[16], h[17], h[18], h[19], h[20], h[21], h[22], h[23], h[24], h[25], h[26], h[27], h[28], h[29], h[30], h[31]); #endif // Allocate a packet structure if (!(localPacket = (Packet *)calloc(1, sizeof(Packet)))) { result = ERROR_NOT_ENOUGH_MEMORY; break; } DWORD encFlags = ntohl(header->enc_flags); vdprintf("[DEC] Encryption flags set to %x", encFlags); // Only decrypt if the context was set up correctly if (remote->enc_ctx != NULL && remote->enc_ctx->valid && encFlags != ENC_FLAG_NONE) { vdprintf("[DEC] Context is valid, moving on ... "); LPBYTE payload = buffer + sizeof(PacketHeader); // the first 16 bytes of the payload we're given is the IV LPBYTE iv = payload; vdprintf("[DEC] IV: %02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X", iv[0], iv[1], iv[2], iv[3], iv[4], iv[5], iv[6], iv[7], iv[8], iv[9], iv[10], iv[11], iv[12], iv[13], iv[14], iv[15]); // the rest of the payload bytes contains the actual encrypted data DWORD encryptedSize = ntohl(header->length) - sizeof(TlvHeader) - AES256_BLOCKSIZE; LPBYTE encryptedData = payload + AES256_BLOCKSIZE; vdprintf("[DEC] Encrypted Size: %u (%x)", encryptedSize, encryptedSize); vdprintf("[DEC] Encrypted Size mod AES256_BLOCKSIZE: %u", encryptedSize % AES256_BLOCKSIZE); if (!CryptDuplicateKey(remote->enc_ctx->aes_key, NULL, 0, &dupKey)) { result = GetLastError(); vdprintf("[DEC] Failed to duplicate key: %d (%x)", result, result); break; } DWORD mode = CRYPT_MODE_CBC; if (!CryptSetKeyParam(dupKey, KP_MODE, (const BYTE*)&mode, 0)) { result = GetLastError(); dprintf("[ENC] Failed to set mode to CBC: %d (%x)", result, result); break; } // decrypt! if (!CryptSetKeyParam(remote->enc_ctx->aes_key, KP_IV, iv, 0)) { result = GetLastError(); vdprintf("[DEC] Failed to set IV: %d (%x)", result, result); break; } if (!CryptDecrypt(remote->enc_ctx->aes_key, 0, TRUE, 0, encryptedData, &encryptedSize)) { result = GetLastError(); vdprintf("[DEC] Failed to decrypt: %d (%x)", result, result); break; } // shift the decrypted data back to the start of the packet buffer so that we // can pretend it's a normal packet memmove_s(iv, encryptedSize, encryptedData, encryptedSize); // adjust the header size header->length = htonl(encryptedSize + sizeof(TlvHeader)); // done, the packet parsing can continue as normal now } localPacket->header.length = header->length; localPacket->header.type = header->type; localPacket->payloadLength = ntohl(localPacket->header.length) - sizeof(TlvHeader); vdprintf("[DEC] Actual payload Length: %d", localPacket->payloadLength); vdprintf("[DEC] Header Type: %d", ntohl(localPacket->header.type)); localPacket->payload = malloc(localPacket->payloadLength); if (localPacket->payload == NULL) { vdprintf("[DEC] failed to allocate payload"); result = ERROR_NOT_ENOUGH_MEMORY; break; } vdprintf("[DEC] Local packet payload successfully allocated, copying data"); memcpy_s(localPacket->payload, localPacket->payloadLength, buffer + sizeof(PacketHeader), localPacket->payloadLength); #ifdef DEBUGTRACE h = localPacket->payload; vdprintf("[DEC] TLV 1 length / type: [0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X]", h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7]); DWORD tl = ntohl(((TlvHeader*)h)->length); vdprintf("[DEC] Skipping %u bytes", tl); h += tl; vdprintf("[DEC] TLV 2 length / type: [0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X]", h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7]); #endif vdprintf("[DEC] Writing localpacket %p to packet pointer %p", localPacket, packet); *packet = localPacket; } while (0); if (result != ERROR_SUCCESS) { if (localPacket != NULL) { packet_destroy(localPacket); } } return result; } DWORD encrypt_packet(Remote* remote, Packet* packet, LPBYTE* buffer, LPDWORD bufferSize) { DWORD result = ERROR_SUCCESS; HCRYPTKEY dupKey = 0; vdprintf("[ENC] Preparing for encryption ..."); // create a new XOR key here, because the content will be copied into the final // payload as part of the prepration process rand_xor_key(packet->header.xor_key); // copy the session ID to the header as this will be used later to identify the packet's destination session memcpy_s(packet->header.session_guid, sizeof(packet->header.session_guid), remote->orig_config->session.session_guid, sizeof(remote->orig_config->session.session_guid)); // Only encrypt if the context was set up correctly if (remote->enc_ctx != NULL && remote->enc_ctx->valid) { vdprintf("[ENC] Context is valid, moving on ... "); // only encrypt the packet if encryption has been enabled if (remote->enc_ctx->enabled) { do { vdprintf("[ENC] Context is enabled, doing the AES encryption"); if (!CryptDuplicateKey(remote->enc_ctx->aes_key, NULL, 0, &dupKey)) { result = GetLastError(); vdprintf("[ENC] Failed to duplicate AES key: %d (%x)", result, result); break; } DWORD mode = CRYPT_MODE_CBC; if (!CryptSetKeyParam(dupKey, KP_MODE, (const BYTE*)&mode, 0)) { result = GetLastError(); dprintf("[ENC] Failed to set mode to CBC: %d (%x)", result, result); break; } BYTE iv[AES256_BLOCKSIZE]; if (!CryptGenRandom(remote->enc_ctx->provider, sizeof(iv), iv)) { result = GetLastError(); vdprintf("[ENC] Failed to generate random IV: %d (%x)", result, result); } vdprintf("[ENC] IV: %02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X", iv[0], iv[1], iv[2], iv[3], iv[4], iv[5], iv[6], iv[7], iv[8], iv[9], iv[10], iv[11], iv[12], iv[13], iv[14], iv[15]); if (!CryptSetKeyParam(dupKey, KP_IV, iv, 0)) { result = GetLastError(); vdprintf("[ENC] Failed to set IV: %d (%x)", result, result); break; } vdprintf("[ENC] IV Set successfully"); // mark this packet as an encrypted packet packet->header.enc_flags = htonl(ENC_FLAG_AES256); // Round up DWORD maxEncryptSize = ((packet->payloadLength / AES256_BLOCKSIZE) + 1) * AES256_BLOCKSIZE; // Need to have space for the IV at the start, as well as the packet Header DWORD memSize = maxEncryptSize + sizeof(iv) + sizeof(packet->header); *buffer = (BYTE*)malloc(memSize); BYTE* headerPos = *buffer; BYTE* ivPos = headerPos + sizeof(packet->header); BYTE* payloadPos = ivPos + sizeof(iv); *bufferSize = packet->payloadLength; // prepare the payload memcpy_s(payloadPos, packet->payloadLength, packet->payload, packet->payloadLength); if (!CryptEncrypt(dupKey, 0, TRUE, 0, payloadPos, bufferSize, maxEncryptSize)) { result = GetLastError(); vdprintf("[ENC] Failed to encrypt: %d (%x)", result, result); } else { vdprintf("[ENC] Data encrypted successfully, size is %u", *bufferSize); } // update the length to match the size of the encrypted data with IV and the TlVHeader packet->header.length = ntohl(*bufferSize + sizeof(iv) + sizeof(TlvHeader)); // update the returned total size to include both the IV and header size. *bufferSize += sizeof(iv) + sizeof(packet->header); // write the header and IV to the payload memcpy_s(headerPos, sizeof(packet->header), &packet->header, sizeof(packet->header)); memcpy_s(ivPos, sizeof(iv), iv, sizeof(iv)); } while (0); } else { dprintf("[ENC] Enabling the context"); // if the encryption is valid, then we set the enbaled flag here because // we know that the first packet going out is the response to the negotiation // and from here we want to make sure that the encryption function is on. remote->enc_ctx->enabled = TRUE; } } else { vdprintf("[ENC] No encryption context present"); } // if we don't have a valid buffer at this point, we'll create one and add the packet as per normal if (*buffer == NULL) { *bufferSize = packet->payloadLength + sizeof(packet->header); *buffer = (BYTE*)malloc(*bufferSize); BYTE* headerPos = *buffer; BYTE* payloadPos = headerPos + sizeof(packet->header); // mark this packet as a non-encrypted packet packet->header.enc_flags = htonl(ENC_FLAG_NONE); memcpy_s(headerPos, sizeof(packet->header), &packet->header, sizeof(packet->header)); memcpy_s(payloadPos, packet->payloadLength, packet->payload, packet->payloadLength); } vdprintf("[ENC] Packet buffer size is: %u", *bufferSize); #ifdef DEBUGTRACE LPBYTE h = *buffer; vdprintf("[ENC] Sending header (before XOR): [0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X]", h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7], h[8], h[9], h[10], h[11], h[12], h[13], h[14], h[15], h[16], h[17], h[18], h[19], h[20], h[21], h[22], h[23], h[24], h[25], h[26], h[27], h[28], h[29], h[30], h[31]); #endif // finally XOR obfuscate like we always did before, skippig the xor key itself. xor_bytes(packet->header.xor_key, *buffer + sizeof(packet->header.xor_key), *bufferSize - sizeof(packet->header.xor_key)); vdprintf("[ENC] Packet encoded and ready for transmission"); #ifdef DEBUGTRACE vdprintf("[ENC] Sending header (after XOR): [0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X] [0x%02X 0x%02X 0x%02X 0x%02X]", h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7], h[8], h[9], h[10], h[11], h[12], h[13], h[14], h[15], h[16], h[17], h[18], h[19], h[20], h[21], h[22], h[23], h[24], h[25], h[26], h[27], h[28], h[29], h[30], h[31]); #endif if (dupKey != 0) { CryptDestroyKey(dupKey); } return result; } DWORD public_key_encrypt(CHAR* publicKeyPem, unsigned char* data, DWORD dataLength, unsigned char** encryptedData, DWORD* encryptedDataLength) { DWORD result = ERROR_SUCCESS; LPBYTE pubKeyBin = NULL; CERT_PUBLIC_KEY_INFO* pubKeyInfo = NULL; HCRYPTPROV rsaProv = 0; HCRYPTKEY pubCryptKey = 0; LPBYTE cipherText = NULL; do { if (publicKeyPem == NULL) { result = ERROR_BAD_ARGUMENTS; break; } DWORD binaryRequiredSize = 0; CryptStringToBinaryA(publicKeyPem, 0, CRYPT_STRING_BASE64HEADER, NULL, &binaryRequiredSize, NULL, NULL); dprintf("[ENC] Required size for the binary key is: %u (%x)", binaryRequiredSize, binaryRequiredSize); pubKeyBin = (LPBYTE)malloc(binaryRequiredSize); if (pubKeyBin == NULL) { result = ERROR_OUTOFMEMORY; break; } if (!CryptStringToBinaryA(publicKeyPem, 0, CRYPT_STRING_BASE64HEADER, pubKeyBin, &binaryRequiredSize, NULL, NULL)) { result = GetLastError(); dprintf("[ENC] Failed to convert the given base64 encoded key into bytes: %u (%x)", result, result); break; } DWORD keyRequiredSize = 0; if (!CryptDecodeObjectEx(X509_ASN_ENCODING, X509_PUBLIC_KEY_INFO, pubKeyBin, binaryRequiredSize, CRYPT_ENCODE_ALLOC_FLAG, 0, &pubKeyInfo, &keyRequiredSize)) { result = GetLastError(); dprintf("[ENC] Failed to decode: %u (%x)", result, result); break; } dprintf("[ENC] Key algo: %s", pubKeyInfo->Algorithm.pszObjId); if (!CryptAcquireContext(&rsaProv, NULL, MS_ENHANCED_PROV, PROV_RSA_FULL, CRYPT_VERIFYCONTEXT)) { dprintf("[ENC] Failed to create the RSA provider with CRYPT_VERIFYCONTEXT"); if (!CryptAcquireContext(&rsaProv, NULL, MS_ENHANCED_PROV, PROV_RSA_FULL, CRYPT_NEWKEYSET)) { result = GetLastError(); dprintf("[ENC] Failed to create the RSA provider with CRYPT_NEWKEYSET: %u (%x)", result, result); break; } else { dprintf("[ENC] Created the RSA provider with CRYPT_NEWKEYSET"); } } else { dprintf("[ENC] Created the RSA provider with CRYPT_VERIFYCONTEXT"); } if (!CryptImportPublicKeyInfo(rsaProv, X509_ASN_ENCODING, pubKeyInfo, &pubCryptKey)) { result = GetLastError(); dprintf("[ENC] Failed to import the key: %u (%x)", result, result); break; } DWORD requiredEncSize = dataLength; CryptEncrypt(pubCryptKey, 0, TRUE, 0, NULL, &requiredEncSize, requiredEncSize); dprintf("[ENC] Encrypted data length: %u (%x)", requiredEncSize, requiredEncSize); cipherText = (LPBYTE)calloc(1, requiredEncSize); if (cipherText == NULL) { result = ERROR_OUTOFMEMORY; break; } memcpy_s(cipherText, requiredEncSize, data, dataLength); if (!CryptEncrypt(pubCryptKey, 0, TRUE, 0, cipherText, &dataLength, requiredEncSize)) { result = GetLastError(); dprintf("[ENC] Failed to encrypt: %u (%x)", result, result); } else { dprintf("[ENC] Encryption witih RSA succeded, byteswapping because MS is stupid and does stuff in little endian."); // Given that we are encrypting such a small amount of data, we're going to assume that the size // of the key matches the size of the block of data we've decrypted. for (DWORD i = 0; i < requiredEncSize / 2; ++i) { BYTE b = cipherText[i]; cipherText[i] = cipherText[requiredEncSize - i - 1]; cipherText[requiredEncSize - i - 1] = b; } *encryptedData = cipherText; *encryptedDataLength = requiredEncSize; } } while (0); if (result != ERROR_SUCCESS) { if (cipherText != NULL) { free(cipherText); } } if (pubKeyInfo != NULL) { LocalFree(pubKeyInfo); } if (pubCryptKey != 0) { CryptDestroyKey(pubCryptKey); } if (rsaProv != 0) { CryptReleaseContext(rsaProv, 0); } return result; } DWORD free_encryption_context(Remote* remote) { DWORD result = ERROR_SUCCESS; dprintf("[ENC] Freeing encryption context %p", remote->enc_ctx); if (remote->enc_ctx != NULL) { dprintf("[ENC] Encryption context not null, so ditching AES key %ul", remote->enc_ctx->aes_key); if (remote->enc_ctx->aes_key != 0) { CryptDestroyKey(remote->enc_ctx->aes_key); } dprintf("[ENC] Encryption context not null, so ditching provider"); if (remote->enc_ctx->provider != 0) { CryptReleaseContext(remote->enc_ctx->provider, 0); } dprintf("[ENC] Encryption context not null, so freeing the context"); free(remote->enc_ctx); remote->enc_ctx = NULL; } return result; } DWORD request_negotiate_aes_key(Remote* remote, Packet* packet) { DWORD result = ERROR_SUCCESS; Packet* response = packet_create_response(packet); do { if (remote->enc_ctx != NULL) { free_encryption_context(remote); } remote->enc_ctx = (PacketEncryptionContext*)calloc(1, sizeof(PacketEncryptionContext)); if (remote->enc_ctx == NULL) { dprintf("[ENC] failed to allocate the encryption context"); result = ERROR_OUTOFMEMORY; break; } PacketEncryptionContext* ctx = remote->enc_ctx; for (int i = 0; i < _countof(AesProviders); ++i) { if (!CryptAcquireContext(&ctx->provider, NULL, AesProviders[i].provider, AesProviders[i].type, AesProviders[i].flags)) { result = GetLastError(); dprintf("[ENC] failed to acquire the crypt context %d: %d (%x)", i, result, result); } else { result = ERROR_SUCCESS; ctx->provider_idx = i; dprintf("[ENC] managed to acquire the crypt context %d!", i); break; } } if (result != ERROR_SUCCESS) { break; } ctx->key_data.header.bType = PLAINTEXTKEYBLOB; ctx->key_data.header.bVersion = CUR_BLOB_VERSION; ctx->key_data.header.aiKeyAlg = CALG_AES_256; ctx->key_data.length = sizeof(ctx->key_data.key); if (!CryptGenRandom(ctx->provider, ctx->key_data.length, ctx->key_data.key)) { result = GetLastError(); dprintf("[ENC] failed to generate random key: %d (%x)", result, result); break; } if (!CryptImportKey(ctx->provider, (const BYTE*)&ctx->key_data, sizeof(Aes256Key), 0, 0, &ctx->aes_key)) { result = GetLastError(); dprintf("[ENC] failed to import random key: %d (%x)", result, result); break; } // now we need to encrypt this key data using the public key given CHAR* pubKeyPem = packet_get_tlv_value_string(packet, TLV_TYPE_RSA_PUB_KEY); unsigned char* cipherText = NULL; DWORD cipherTextLength = 0; DWORD pubEncryptResult = public_key_encrypt(pubKeyPem, remote->enc_ctx->key_data.key, remote->enc_ctx->key_data.length, &cipherText, &cipherTextLength); packet_add_tlv_uint(response, TLV_TYPE_SYM_KEY_TYPE, ENC_FLAG_AES256); if (pubEncryptResult == ERROR_SUCCESS && cipherText != NULL) { // encryption succeeded, pass this key back to the call in encrypted form packet_add_tlv_raw(response, TLV_TYPE_ENC_SYM_KEY, cipherText, cipherTextLength); free(cipherText); } else { // no public key was given, so send it back in the raw packet_add_tlv_raw(response, TLV_TYPE_SYM_KEY, remote->enc_ctx->key_data.key, remote->enc_ctx->key_data.length); } ctx->valid = TRUE; } while (0); packet_transmit_response(result, remote, response); remote->enc_ctx->enabled = TRUE; return ERROR_SUCCESS; }