#include "common.h"
#include "pivot_tree.h"

// A single node of the binary search tree that maps GUIDs to pivot contexts.
typedef struct _PivotNode
{
	// Raw GUID bytes used as the key; nodes are ordered by memcmp() over this array.
	BYTE guid[sizeof(GUID)];
	// Context pointer supplied when the node was added. The tree hands it back
	// from find/remove but never frees it itself.
	PivotContext* ctx;

	// Subtree of nodes whose GUID compares lower than this node's GUID.
	struct _PivotNode* left;
	// Subtree of nodes whose GUID compares greater than (or equal to) this node's GUID.
	struct _PivotNode* right;
} PivotNode;

#ifdef DEBUGTRACE
// Appends a textual dump of the subtree rooted at `node` to the heap string `*buffer`.
//
// buffer - in/out pointer to a NUL-terminated heap string (or NULL to start a new one).
//          It is grown with realloc, so the caller must free() the final value.
// node   - subtree to render; a NULL node is rendered as "<prefix>NULL".
// prefix - text written at the start of the line for this node; children get a longer prefix.
//
// If memory cannot be allocated the dump is simply cut short; whatever was already
// in `*buffer` stays valid and NUL-terminated.
void pivot_tree_to_string(char** buffer, PivotNode* node, char* prefix)
{
	size_t prefixLen = strlen(prefix);
	size_t curLen = *buffer ? strlen(*buffer) : 0;
	// each line is the prefix size, plus the guid size plus a null and a \n and the two pointers
	size_t lineLen = prefixLen + 32 + 2 + (sizeof(LPVOID) * 2 + 8) * 2;

	// Grow through a temporary so the existing text is not lost if realloc fails.
	char* grown = (char*)realloc(*buffer, curLen + 1 + lineLen);
	if (grown == NULL)
	{
		return;
	}
	*buffer = grown;

	if (node == NULL)
	{
		// Write at the length measured before the realloc: when the buffer started out
		// as NULL the new memory is uninitialised and must not be passed to strlen().
		sprintf(grown + curLen, "%sNULL\n", prefix);
		return;
	}

	PUCHAR h = node->guid;
	// The line fits by construction: lineLen bytes plus the terminator were reserved above.
	sprintf(grown + curLen, "%s%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X (%p) (%p)\n\x00",
		prefix,
		h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7], h[8], h[9], h[10], h[11], h[12], h[13], h[14], h[15],
		node->left, node->right);

	// Children are indented by blanking out the parent's prefix and appending "  | ".
	// The prefix grows with every level, so it is sized dynamically rather than
	// squeezed into a fixed stack array.
	static const char branch[] = "  | ";
	char* childPrefix = (char*)malloc(prefixLen + sizeof(branch));
	if (childPrefix == NULL)
	{
		return;
	}
	memset(childPrefix, ' ', prefixLen);
	memcpy(childPrefix + prefixLen, branch, sizeof(branch));

	// print the right hand side first, as it seems to make sense when viewing the content
	pivot_tree_to_string(buffer, node->right, childPrefix);
	pivot_tree_to_string(buffer, node->left, childPrefix);

	free(childPrefix);
}

// Renders the entire tree into a temporary string and writes it to the debug log.
// Nothing is logged when no string was produced.
void dbgprint_pivot_tree(PivotTree* tree)
{
	char* rendered = NULL;

	pivot_tree_to_string(&rendered, tree->head, "  ");
	if (rendered == NULL)
	{
		return;
	}

	dprintf("[PIVOTTREE] contents:\n%s", rendered);
	free(rendered);
}
#endif

// Allocates a new, zero-filled PivotTree.
// Returns the new tree, or NULL when the allocation fails.
PivotTree* pivot_tree_create()
{
	PivotTree* tree = (PivotTree*)calloc(1, sizeof(*tree));
	return tree;
}

// Links `node` into the subtree rooted at `parent`.
//
// Walks down from `parent`, going left when the new GUID compares lower than the
// current node's GUID and right otherwise, until an empty child slot is found.
// Always returns ERROR_SUCCESS.
DWORD pivot_tree_add_node(PivotNode* parent, PivotNode* node)
{
	PivotNode* current = parent;

	for (;;)
	{
		int order = memcmp(node->guid, current->guid, sizeof(current->guid));

		if (order < 0)
		{
			if (current->left == NULL)
			{
				dprintf("[PIVOTTREE] Adding node to left");
				current->left = node;
				return ERROR_SUCCESS;
			}

			dprintf("[PIVOTTREE] Adding node to left subtree");
			current = current->left;
		}
		else
		{
			if (current->right == NULL)
			{
				dprintf("[PIVOTTREE] Adding node to right");
				current->right = node;
				return ERROR_SUCCESS;
			}

			dprintf("[PIVOTTREE] Adding node to right subtree");
			current = current->right;
		}
	}
}

// Adds a GUID -> context mapping to the tree.
//
// tree - tree to insert into.
// guid - pointer to the GUID bytes; sizeof(GUID) bytes are copied into the new node.
// ctx  - context pointer stored alongside the GUID (stored as-is, not copied).
//
// Returns ERROR_SUCCESS when the node was linked into the tree, or
// ERROR_NOT_ENOUGH_MEMORY when the node could not be allocated (the tree is
// left untouched in that case).
DWORD pivot_tree_add(PivotTree* tree, LPBYTE guid, PivotContext* ctx)
{
#ifdef DEBUGTRACE
	PUCHAR h = (PUCHAR)&guid[0];
	dprintf("[PIVOTTREE] Adding GUID: %02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X",
		h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7], h[8], h[9], h[10], h[11], h[12], h[13], h[14], h[15]);
#endif

	// calloc zero-fills the node, so both child pointers start out NULL.
	PivotNode* node = (PivotNode*)calloc(1, sizeof(PivotNode));
	if (node == NULL)
	{
		// Report the failure instead of writing through a NULL pointer below.
		dprintf("[PIVOTTREE] Failed to allocate node");
		return ERROR_NOT_ENOUGH_MEMORY;
	}

	memcpy(node->guid, guid, sizeof(node->guid));
	node->ctx = ctx;

	if (tree->head == NULL)
	{
		tree->head = node;
		return ERROR_SUCCESS;
	}

	return pivot_tree_add_node(tree->head, node);
}

// Returns the right-most node of the subtree rooted at `node`, i.e. the node
// reached by following `right` links until there are none left.
// Returns NULL when `node` itself is NULL.
PivotNode* pivot_tree_largest_node(PivotNode* node)
{
	PivotNode* rightmost = node;

	while (rightmost != NULL && rightmost->right != NULL)
	{
		rightmost = rightmost->right;
	}

	return rightmost;
}

// Detaches and frees the node referenced by `*link` (a parent's left or right
// child slot, which must not be NULL) and returns the context it carried.
//
// If the removed node has a left subtree, that subtree takes its place and the
// removed node's right subtree is hung off the left subtree's right-most node.
// Otherwise the right subtree (possibly NULL) takes its place directly.
static PivotContext* pivot_tree_unlink_child(PivotNode** link)
{
	PivotNode* remove = *link;
	PivotNode* left = remove->left;
	PivotNode* largest = pivot_tree_largest_node(left);

	if (largest != NULL)
	{
		largest->right = remove->right;
		*link = left;
	}
	else
	{
		*link = remove->right;
	}

	PivotContext* context = remove->ctx;
	free(remove);
	return context;
}

// Removes the node keyed by `guid` from below `parent`.
//
// parent - node whose descendants are searched; `parent` itself is never removed,
//          so the caller must already have handled the case where it matches.
// guid   - pointer to the GUID bytes to look for.
//
// Returns the context stored in the removed node, or NULL when no descendant
// carries the GUID. The context itself is not freed.
PivotContext* pivot_tree_remove_node(PivotNode* parent, LPBYTE guid)
{
	dprintf("[PIVOTTREE] Trying to remove from %p (%p) (%p)", parent, parent->left, parent->right);
	int cmp = memcmp(guid, parent->guid, sizeof(parent->guid));

	if (cmp < 0 && parent->left != NULL)
	{
		dprintf("[PIVOTTREE] Removing from left subtree");
		// The child is tested from here because unlinking it needs access to the parent's slot.
		int childCmp = memcmp(guid, parent->left->guid, sizeof(parent->guid));
		dprintf("[PIVOTTREE] Left subtree compare: %d", childCmp);
		if (childCmp == 0)
		{
			dprintf("[PIVOTTREE] Removing left child");
			return pivot_tree_unlink_child(&parent->left);
		}

		return pivot_tree_remove_node(parent->left, guid);
	}

	if (cmp > 0 && parent->right != NULL)
	{
		dprintf("[PIVOTTREE] Removing from right subtree");
		int childCmp = memcmp(guid, parent->right->guid, sizeof(parent->guid));
		dprintf("[PIVOTTREE] Right subtree compare: %d", childCmp);
		if (childCmp == 0)
		{
			dprintf("[PIVOTTREE] Removing right child");
			return pivot_tree_unlink_child(&parent->right);
		}

		return pivot_tree_remove_node(parent->right, guid);
	}

	return NULL;
}

// Removes the node keyed by `guid` from the tree.
//
// Returns the context that was stored with the GUID, or NULL when the tree is
// empty or the GUID is not present. The node is freed; the context is not.
PivotContext* pivot_tree_remove(PivotTree* tree, LPBYTE guid)
{
#ifdef DEBUGTRACE
	PUCHAR h = (PUCHAR)&guid[0];
	dprintf("[PIVOTTREE] Removing GUID: %02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X",
		h[0], h[1], h[2], h[3], h[4], h[5], h[6], h[7], h[8], h[9], h[10], h[11], h[12], h[13], h[14], h[15]);
#endif

	PivotNode* head = tree->head;
	if (head == NULL)
	{
		return NULL;
	}

	// Anything other than the head is handled by the parent-aware helper.
	if (memcmp(guid, head->guid, sizeof(head->guid)) != 0)
	{
		dprintf("[PIVOTTREE] Removing non-head node");
		return pivot_tree_remove_node(head, guid);
	}

	dprintf("[PIVOTTREE] Removing head node");

	// The right-most node of the left subtree adopts the head's right subtree,
	// and the left subtree becomes the new head. With no left subtree the
	// right subtree is promoted directly.
	PivotNode* rightmost = pivot_tree_largest_node(head->left);
	if (rightmost == NULL)
	{
		tree->head = head->right;
	}
	else
	{
		rightmost->right = head->right;
		tree->head = head->left;
	}

	PivotContext* removed = head->ctx;
	free(head);
	return removed;
}

// Searches the subtree rooted at `node` for `guid`.
//
// Descends left when the searched GUID compares lower than the current node's
// GUID and right when it compares higher.
// Returns the matching node's context, or NULL when the search runs off the tree.
PivotContext* pivot_tree_find_node(PivotNode* node, LPBYTE guid)
{
	PivotNode* current = node;

	while (current != NULL)
	{
#ifdef DEBUGTRACE
		PUCHAR s = (PUCHAR)&guid[0];
		dprintf("[PIVOTTREE] Saerch GUID: %02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X",
			s[0], s[1], s[2], s[3], s[4], s[5], s[6], s[7], s[8], s[9], s[10], s[11], s[12], s[13], s[14], s[15]);
		PUCHAR c = current->guid;
		dprintf("[PIVOTTREE] Node   GUID: %02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X%02X",
			c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7], c[8], c[9], c[10], c[11], c[12], c[13], c[14], c[15]);
#endif

		int order = memcmp(guid, current->guid, sizeof(current->guid));
		if (order == 0)
		{
			dprintf("[PIVOTTREE] node found");
			return current->ctx;
		}

		if (order < 0)
		{
			dprintf("[PIVOTTREE] Searching left subtree");
			current = current->left;
		}
		else
		{
			dprintf("[PIVOTTREE] Searching right subtree");
			current = current->right;
		}
	}

	dprintf("[PIVOTTREE] Current pivot node is null, bailing out");
	return NULL;
}

// Looks up `guid` in the tree, starting at its head node.
// Returns the associated context, or NULL when the GUID is not present.
PivotContext* pivot_tree_find(PivotTree* tree, LPBYTE guid)
{
	PivotNode* root = tree->head;

	dprintf("[PIVOTTREE] search tree %p, head node %p", tree, root);
	return pivot_tree_find_node(root, guid);
}

// Visits every node of the subtree rooted at `node` in order (left subtree,
// the node itself, then right subtree), invoking `callback` with the node's
// GUID bytes, its context and the caller's `state` pointer.
void pivot_tree_traverse_node(PivotNode* node, PivotTreeTraverseCallback callback, LPVOID state)
{
	PivotNode* current = node;

	// Recurse only into left subtrees; the right spine is walked by the loop.
	while (current != NULL)
	{
		pivot_tree_traverse_node(current->left, callback, state);
		callback(current->guid, current->ctx, state);
		current = current->right;
	}
}

// Runs `callback` over every entry of the tree, in order, passing `state`
// through to each invocation unchanged.
void pivot_tree_traverse(PivotTree* tree, PivotTreeTraverseCallback callback, LPVOID state)
{
	PivotNode* root = tree->head;

	pivot_tree_traverse_node(root, callback, state);
}

// Frees `node` and every node beneath it. Children are released before their
// parent. The contexts referenced by the nodes are not freed.
void pivot_tree_destroy_node(PivotNode* node)
{
	if (node == NULL)
	{
		return;
	}

	pivot_tree_destroy_node(node->left);
	pivot_tree_destroy_node(node->right);
	free(node);
}

// Releases all nodes of the tree and then the tree structure itself.
// `tree` must not be used after this call.
void pivot_tree_destroy(PivotTree* tree)
{
	PivotNode* root = tree->head;

	pivot_tree_destroy_node(root);
	free(tree);
}