#include "common.h"

// include the PolarSSL library
#pragma comment(lib,"polarssl.lib")

DWORD packet_find_tlv_buf(PUCHAR payload, DWORD payloadLength, DWORD index,
		TlvType type, Tlv *tlv);

/*
 * Node of the singly-linked list that maps a request identifier to the
 * completion handler registered for it.
 */
typedef struct _PacketCompletionRoutineEntry
{
	LPCSTR                               requestId;
	PacketRequestCompletion              handler;
	struct _PacketCompletionRoutineEntry *next;
} PacketCompletionRoutineEntry;

// Head of the list of registered completion routines (newest entry first)
PacketCompletionRoutineEntry *packetCompletionRoutineList = NULL;

/************
 * Core API *
 ************/

/*
 * Transmit a single string to the remote connection with instructions to
 * print it to the screen or whatever medium has been established.
 *
 * remote - connection the request is sent over
 * fmt    - printf-style format string followed by its arguments; output
 *          longer than the 8192-byte local buffer is truncated
 *
 * Returns ERROR_SUCCESS on success, ERROR_NOT_ENOUGH_MEMORY if the request
 * packet could not be created, or the error reported while adding the string
 * or transmitting the packet.
 */
DWORD send_core_console_write(Remote *remote, LPCSTR fmt, ...)
{
	Packet *request = NULL;
	CHAR buf[8192];
	va_list ap;
	DWORD res;

	do
	{
		va_start(ap, fmt);
		_vsnprintf(buf, sizeof(buf) - 1, fmt, ap);
		va_end(ap);

		// _vsnprintf writes no terminator when the output is truncated, so
		// terminate explicitly before the buffer is treated as a string
		buf[sizeof(buf) - 1] = 0;

		// Create a message with the 'core_console_write' method
		if (!(request = packet_create(PACKET_TLV_TYPE_REQUEST,
				"core_console_write")))
		{
			res = ERROR_NOT_ENOUGH_MEMORY;
			break;
		}

		// Add the string to print
		if ((res = packet_add_tlv_string(request, TLV_TYPE_STRING, buf))
				!= NO_ERROR)
			break;

		res = packet_transmit(remote, request, NULL);

	} while (0);

	// Cleanup on failure (packet_transmit destroys the packet on success)
	if (res != ERROR_SUCCESS)
	{
		if (request)
			packet_destroy(request);
	}

	return res;
}

/*******************
 * Packet Routines *
 *******************/

/*
 * Create a packet of a given type (request/response) and method.
*/ Packet *packet_create(PacketTlvType type, LPCSTR method) { Packet *packet = NULL; BOOL success = FALSE; do { if (!(packet = (Packet *)malloc(sizeof(Packet)))) break; memset(packet, 0, sizeof(packet)); // Initialize the header length and message type packet->header.length = htonl(sizeof(TlvHeader)); packet->header.type = htonl((DWORD)type); // Initialize the payload to be blank packet->payload = NULL; packet->payloadLength = 0; // Add the method TLV if provided if (method) { if (packet_add_tlv_string(packet, TLV_TYPE_METHOD, method) != ERROR_SUCCESS) break; } success = TRUE; } while (0); // Clean up the packet on failure if ((!success) && (packet)) { packet_destroy(packet); packet = NULL; } return packet; } /* * Create a response packet from a request, referencing the requestors * message identifier. */ Packet *packet_create_response(Packet *request) { Packet *response = NULL; Tlv method, requestId; BOOL success = FALSE; PacketTlvType responseType; if (packet_get_type(request) == PACKET_TLV_TYPE_PLAIN_REQUEST) responseType = PACKET_TLV_TYPE_PLAIN_RESPONSE; else responseType = PACKET_TLV_TYPE_RESPONSE; do { // Get the request TLV's method if (packet_get_tlv_string(request, TLV_TYPE_METHOD, &method) != ERROR_SUCCESS) break; // Try to allocate a response packet if (!(response = packet_create(responseType, (PCHAR)method.buffer))) break; // Get the request TLV's request identifier if (packet_get_tlv_string(request, TLV_TYPE_REQUEST_ID, &requestId) != ERROR_SUCCESS) break; // Add the request identifier to the packet packet_add_tlv_string(response, TLV_TYPE_REQUEST_ID, (PCHAR)requestId.buffer); success = TRUE; } while (0); // Cleanup on failure if (!success) { if (response) packet_destroy(response); response = NULL; } return response; } /* * Destroy the packet context and the payload buffer */ VOID packet_destroy(Packet *packet) { if (packet->payload) free(packet->payload); free(packet); } /* * Add a TLV as a string, including the null terminator. 
*/ DWORD packet_add_tlv_string(Packet *packet, TlvType type, LPCSTR str) { return packet_add_tlv_raw(packet, type, (PUCHAR)str, strlen(str) + 1); } /* * Add a TLV as a string, including the null terminator. */ DWORD packet_add_tlv_uint(Packet *packet, TlvType type, UINT val) { val = htonl(val); return packet_add_tlv_raw(packet, type, (PUCHAR)&val, sizeof(val)); } /* * Add a TLV as a bool. */ DWORD packet_add_tlv_bool(Packet *packet, TlvType type, BOOL val) { return packet_add_tlv_raw(packet, type, (PUCHAR)&val, 1); } /* * Add a TLV group. A TLV group is a TLV that contains multiple sub-TLVs */ DWORD packet_add_tlv_group(Packet *packet, TlvType type, Tlv *entries, DWORD numEntries) { DWORD totalSize = 0, offset = 0, index = 0, res = ERROR_SUCCESS; PCHAR buffer = NULL; // Calculate the total TLV size. for (index = 0; index < numEntries; index++) totalSize += entries[index].header.length + sizeof(TlvHeader); do { // Allocate storage for the complete buffer if (!(buffer = (PCHAR)malloc(totalSize))) { res = ERROR_NOT_ENOUGH_MEMORY; break; } // Copy the memory into the new buffer for (index = 0; index < numEntries; index++) { TlvHeader rawHeader; // Convert byte order for storage rawHeader.length = htonl(entries[index].header.length + sizeof(TlvHeader)); rawHeader.type = htonl((DWORD)entries[index].header.type); // Copy the TLV header & payload memcpy(buffer + offset, &rawHeader, sizeof(TlvHeader)); memcpy(buffer + offset + sizeof(TlvHeader), entries[index].buffer, entries[index].header.length); // Update the offset into the buffer offset += entries[index].header.length + sizeof(TlvHeader); } // Now add the TLV group with its contents populated res = packet_add_tlv_raw(packet, type, buffer, totalSize); } while (0); // Free the temporary buffer if (buffer) free(buffer); return res; } /* * Add an array of TLVs */ DWORD packet_add_tlvs(Packet *packet, Tlv *entries, DWORD numEntries) { DWORD index; for (index = 0; index < numEntries; index++) packet_add_tlv_raw(packet, 
entries[index].header.type, entries[index].buffer, entries[index].header.length); return ERROR_SUCCESS; } /* * Add an arbitrary TLV */ DWORD packet_add_tlv_raw(Packet *packet, TlvType type, LPVOID buf, DWORD length) { DWORD headerLength = sizeof(TlvHeader); DWORD realLength = length + headerLength; DWORD newPayloadLength = packet->payloadLength + realLength; PUCHAR newPayload = NULL; // Allocate/Reallocate the packet's payload if (packet->payload) newPayload = (PUCHAR)realloc(packet->payload, newPayloadLength); else newPayload = (PUCHAR)malloc(newPayloadLength); if (!newPayload) return ERROR_NOT_ENOUGH_MEMORY; // Populate the new TLV ((LPDWORD)(newPayload + packet->payloadLength))[0] = htonl(realLength); ((LPDWORD)(newPayload + packet->payloadLength))[1] = htonl((DWORD)type); memcpy(newPayload + packet->payloadLength + headerLength, buf, length); // Update the header length and payload length packet->header.length = htonl(ntohl(packet->header.length) + realLength); packet->payload = newPayload; packet->payloadLength = newPayloadLength; return ERROR_SUCCESS; } /* * Checks to see if a tlv is null terminated */ DWORD packet_is_tlv_null_terminated(Packet *packet, Tlv *tlv) { if ((tlv->header.length) && (tlv->buffer[tlv->header.length - 1] != 0)) return ERROR_NOT_FOUND; return ERROR_SUCCESS; } /* * Get the type of the packet */ PacketTlvType packet_get_type(Packet *packet) { return (PacketTlvType)ntohl(packet->header.type); } TlvMetaType packet_get_tlv_meta(Packet *packet, Tlv *tlv) { return TLV_META_TYPE_MASK(tlv->header.type); } /* * Get the TLV of the given type */ DWORD packet_get_tlv(Packet *packet, TlvType type, Tlv *tlv) { return packet_enum_tlv(packet, 0, type, tlv); } /* * Get a TLV as a string */ DWORD packet_get_tlv_string(Packet *packet, TlvType type, Tlv *tlv) { DWORD res; if ((res = packet_get_tlv(packet, type, tlv)) == ERROR_SUCCESS) res = packet_is_tlv_null_terminated(packet, tlv); return res; } /* * Enumerate a TLV group (a TLV that consists other 
multiple sub-TLVs) and * finds the first match of a given type, if it exists. */ DWORD packet_get_tlv_group_entry(Packet *packet, Tlv *group, TlvType type, Tlv *entry) { return packet_find_tlv_buf(group->buffer, group->header.length, 0, type, entry); } /* * Enumerate a TLV, optionally of a specified typed. */ DWORD packet_enum_tlv(Packet *packet, DWORD index, TlvType type, Tlv *tlv) { return packet_find_tlv_buf(packet->payload, packet->payloadLength, index, type, tlv); } /* * Get the value of a string TLV */ PCHAR packet_get_tlv_value_string(Packet *packet, TlvType type) { Tlv stringTlv; PCHAR string = NULL; if (packet_get_tlv_string(packet, type, &stringTlv) == ERROR_SUCCESS) string = (PCHAR)stringTlv.buffer; return string; } /* * Get the value of a UINT TLV */ UINT packet_get_tlv_value_uint(Packet *packet, TlvType type) { Tlv uintTlv; if ((packet_get_tlv(packet, type, &uintTlv) != ERROR_SUCCESS) || (uintTlv.header.length < sizeof(DWORD))) return 0; return ntohl(*(LPDWORD)uintTlv.buffer); } /* * Get the value of a bool TLV */ BOOL packet_get_tlv_value_bool(Packet *packet, TlvType type) { Tlv boolTlv; BOOL val = FALSE; if (packet_get_tlv(packet, type, &boolTlv) == ERROR_SUCCESS) val = (BOOL)(*(PCHAR)boolTlv.buffer); return val; } /* * Add an exception to a packet */ DWORD packet_add_exception(Packet *packet, DWORD code, PCHAR fmt, ...) { DWORD codeNbo = htonl(code); char buf[8192]; Tlv entries[2]; va_list ap; // Ensure null termination buf[sizeof(buf) - 1] = 0; va_start(ap, fmt); _vsnprintf(buf, sizeof(buf) - 1, fmt, ap); va_end(ap); // Populate the TLV group array entries[0].header.type = TLV_TYPE_EXCEPTION_CODE; entries[0].header.length = 4; entries[0].buffer = (PUCHAR)&codeNbo; entries[1].header.type = TLV_TYPE_EXCEPTION_STRING; entries[1].header.length = strlen(buf) + 1; entries[1].buffer = buf; // Add the TLV group, or try to at least. 
return packet_add_tlv_group(packet, TLV_TYPE_EXCEPTION, entries, 2); } /* * Get the result code from the packet */ DWORD packet_get_result(Packet *packet) { return packet_get_tlv_value_uint(packet, TLV_TYPE_RESULT); } /* * Enumerate TLV entries in a buffer until hitting a given index (optionally * for a given type as well). */ DWORD packet_find_tlv_buf(PUCHAR payload, DWORD payloadLength, DWORD index, TlvType type, Tlv *tlv) { DWORD currentIndex = 0; DWORD offset = 0, length = 0; BOOL found = FALSE; PUCHAR current; memset(tlv, 0, sizeof(Tlv)); do { // Enumerate the TLV's for (current = payload, length = 0; !found && current; offset += length, current += length) { TlvHeader *header = (TlvHeader *)current; if ((current + sizeof(TlvHeader) > payload + payloadLength) || (current < payload)) break; // TLV's length length = ntohl(header->length); // Matching type? if (((TlvType)ntohl(header->type) != type) && (type != TLV_TYPE_ANY)) continue; // Matching index? if (currentIndex != index) { currentIndex++; continue; } if ((current + length > payload + payloadLength) || (current < payload)) break; tlv->header.type = ntohl(header->type); tlv->header.length = ntohl(header->length) - sizeof(TlvHeader); tlv->buffer = payload + offset + sizeof(TlvHeader); found = TRUE; } } while (0); return (found) ? 
ERROR_SUCCESS : ERROR_NOT_FOUND; } /*********************** * Completion Routines * ***********************/ /* * Add a completion routine for a given request identifier */ DWORD packet_add_completion_handler(LPCSTR requestId, PacketRequestCompletion *completion) { PacketCompletionRoutineEntry *entry; DWORD res = ERROR_SUCCESS; do { // Allocate the entry if (!(entry = (PacketCompletionRoutineEntry *)malloc( sizeof(PacketCompletionRoutineEntry)))) { res = ERROR_NOT_ENOUGH_MEMORY; break; } // Copy the completion routine information memcpy(&entry->handler, completion, sizeof(PacketRequestCompletion)); // Copy the request identifier if (!(entry->requestId = strdup(requestId))) { res = ERROR_NOT_ENOUGH_MEMORY; free(entry); break; } // Add the entry to the list entry->next = packetCompletionRoutineList; packetCompletionRoutineList = entry; } while (0); return res; } /* * Call the register completion handler(s) for the given request identifier. */ DWORD packet_call_completion_handlers(Remote *remote, Packet *response, LPCSTR requestId) { PacketCompletionRoutineEntry *current; DWORD result = packet_get_result(response); DWORD matches = 0; Tlv methodTlv; LPCSTR method = NULL; // Get the method associated with this packet if (packet_get_tlv_string(response, TLV_TYPE_METHOD, &methodTlv) == ERROR_SUCCESS) method = (LPCSTR)methodTlv.buffer; // Enumerate the completion routine list for (current = packetCompletionRoutineList; current; current = current->next) { // Does the request id of the completion entry match the packet's request // id? if (strcmp(requestId, current->requestId)) continue; // Call the completion routine current->handler.routine(remote, response, current->handler.context, method, result); // Increment the number of matched handlers matches++; } if (matches) packet_remove_completion_handler(requestId); return (matches > 0) ? 
ERROR_SUCCESS : ERROR_NOT_FOUND; } /* * Remove one or more completion handlers for the given request identifier */ DWORD packet_remove_completion_handler(LPCSTR requestId) { PacketCompletionRoutineEntry *current, *next, *prev; // Enumerate the list, removing entries that match for (current = packetCompletionRoutineList, next = NULL, prev = NULL; current; prev = current, current = next) { next = current->next; if (strcmp(requestId, current->requestId)) continue; // Remove the entry from the list if (prev) prev->next = next; else packetCompletionRoutineList = next; // Deallocate it free((PCHAR)current->requestId); free(current); } return ERROR_SUCCESS; } /* * Transmit and destroy a packet */ DWORD packet_transmit(Remote *remote, Packet *packet, PacketRequestCompletion *completion) { CryptoContext *crypto; Tlv requestId; DWORD res; // If the packet does not already have a request identifier, create // one for it if (packet_get_tlv_string(packet, TLV_TYPE_REQUEST_ID, &requestId) != ERROR_SUCCESS) { DWORD index; CHAR rid[32]; rid[sizeof(rid) - 1] = 0; for (index = 0; index < sizeof(rid) - 1; index++) rid[index] = (rand() % 0x5e) + 0x21; packet_add_tlv_string(packet, TLV_TYPE_REQUEST_ID, rid); } do { // If a completion routine was supplied and the packet has a request // identifier, insert the completion routine into the list if ((completion) && (packet_get_tlv_string(packet, TLV_TYPE_REQUEST_ID, &requestId) == ERROR_SUCCESS)) packet_add_completion_handler((LPCSTR)requestId.buffer, completion); // If the endpoint has a cipher established and this is not a plaintext // packet, we encrypt if ((crypto = remote_get_cipher(remote)) && (packet_get_type(packet) != PACKET_TLV_TYPE_PLAIN_REQUEST) && (packet_get_type(packet) != PACKET_TLV_TYPE_PLAIN_RESPONSE)) { ULONG origPayloadLength = packet->payloadLength; PUCHAR origPayload = packet->payload; // Encrypt if ((res = crypto->handlers.encrypt(crypto, packet->payload, packet->payloadLength, &packet->payload, 
&packet->payloadLength)) != ERROR_SUCCESS) { SetLastError(res); break; } // Destroy the original payload as we no longer need it free(origPayload); // Update the header length packet->header.length = htonl(packet->payloadLength + sizeof(TlvHeader)); } // Transmit the packet's header (length, type) if (ssl_write(&remote->ssl, (LPCSTR)&packet->header, sizeof(packet->header)) == SOCKET_ERROR) break; // Transmit the packet's payload if (ssl_write(&remote->ssl, packet->payload, packet->payloadLength) == SOCKET_ERROR) break; // Destroy the packet packet_destroy(packet); SetLastError(ERROR_SUCCESS); } while (0); res = GetLastError(); return res; } /* * Transmits a response with nothing other than a result code in it */ DWORD packet_transmit_empty_response(Remote *remote, Packet *packet, DWORD res) { Packet *response = packet_create_response(packet); if (!response) return ERROR_NOT_ENOUGH_MEMORY; // Add the result code packet_add_tlv_uint(response, TLV_TYPE_RESULT, res); // Transmit the response return packet_transmit(remote, response, NULL); } /* * Receive a new packet */ DWORD packet_receive(Remote *remote, Packet **packet) { DWORD headerBytes = 0, payloadBytesLeft = 0, res; CryptoContext *crypto = NULL; Packet *localPacket = NULL; TlvHeader header; LONG bytesRead; BOOL inHeader = TRUE; PUCHAR payload = NULL; ULONG payloadLength; do { // Read the packet length while (inHeader) { if ((bytesRead = ssl_read(&remote->ssl, ((PUCHAR)&header + headerBytes), sizeof(TlvHeader) - headerBytes, 0)) <= 0) { if(bytesRead == POLARSSL_ERR_NET_TRY_AGAIN) continue; if (!bytesRead) SetLastError(ERROR_NOT_FOUND); break; } headerBytes += bytesRead; if (headerBytes != sizeof(TlvHeader)) continue; else inHeader = FALSE; } if (bytesRead != sizeof(TlvHeader)) break; // Initialize the header header.length = header.length; header.type = header.type; payloadLength = ntohl(header.length) - sizeof(TlvHeader); payloadBytesLeft = payloadLength; // Allocate the payload if (!(payload = 
(PUCHAR)malloc(payloadLength))) { SetLastError(ERROR_NOT_ENOUGH_MEMORY); break; } // Read the payload while (payloadBytesLeft > 0) { if ((bytesRead = ssl_read(&remote->ssl, payload + payloadLength - payloadBytesLeft, payloadBytesLeft, 0)) <= 0) { if(bytesRead == POLARSSL_ERR_NET_TRY_AGAIN) continue; if (GetLastError() == WSAEWOULDBLOCK) continue; if (!bytesRead) SetLastError(ERROR_NOT_FOUND); break; } payloadBytesLeft -= bytesRead; } // Didn't finish? if (payloadBytesLeft) break; // Allocate a packet structure if (!(localPacket = (Packet *)malloc(sizeof(Packet)))) { SetLastError(ERROR_NOT_ENOUGH_MEMORY); break; } // If the connection has an established cipher and this packet is not // plaintext, decrypt if ((crypto = remote_get_cipher(remote)) && (packet_get_type(localPacket) != PACKET_TLV_TYPE_PLAIN_REQUEST) && (packet_get_type(localPacket) != PACKET_TLV_TYPE_PLAIN_RESPONSE)) { ULONG origPayloadLength = payloadLength; PUCHAR origPayload = payload; // Decrypt if ((res = crypto->handlers.decrypt(crypto, payload, payloadLength, &payload, &payloadLength)) != ERROR_SUCCESS) { SetLastError(res); break; } // We no longer need the encrypted payload free(origPayload); } localPacket->header.length = header.length; localPacket->header.type = header.type; localPacket->payload = payload; localPacket->payloadLength = payloadLength; *packet = localPacket; SetLastError(ERROR_SUCCESS); } while (0); res = GetLastError(); // Cleanup on failure if (res != ERROR_SUCCESS) { if (payload) free(payload); if (localPacket) free(localPacket); } return res; }