//TFTP server in C++
#include <algorithm>
#include <string.h>
#include <time.h>
#include "TFTPserv.h"

// Size in bytes of the stack buffer run() receives one datagram into.
const int bufferSize = 4096;
// UDP port the server socket is bound to in start() (the standard TFTP port).
const int tftpServPort = 69;

// TFTP opcodes: the first two bytes of every packet, in network byte order.
// 1-5 are defined by RFC 1350; 6 (option acknowledgement, OACK) by RFC 2347.
const unsigned char OpRead = 1;
const unsigned char OpWrite = 2;
const unsigned char OpData = 3;
const unsigned char OpAck = 4;
const unsigned char OpError = 5;
const unsigned char OpOptAck = 6;

// TFTP error codes carried in ERROR (opcode 5) packets.
// 1-7 are defined by RFC 1350; 8 (option negotiation failed) by RFC 2347.
const unsigned char ErrFileNotFound = 1;
const unsigned char ErrAccessViolation = 2;
const unsigned char ErrDiskFull = 3;
const unsigned char ErrIllegalOperation = 4;
const unsigned char ErrUnknownTransferId = 5;
const unsigned char ErrFileExists = 6;
const unsigned char ErrNoSuchUser = 7;
const unsigned char ErrFailedOptNegotiation = 8;

// Creates the TFTP server, passing pointer to C caller
extern "C" void* createTFTPServer(){
	return new TFTPserv();
}
// Destroys the given TFTP server
extern "C" void destroyTFTPServer(void * server){
	delete (TFTPserv *)(server);
}
// Adds a file to the TFTP server, from C caller
extern "C" int addTFTPFile(void * server, char* filename, unsigned int filenamelen, char* file, unsigned int filelen){
	string filenamestr(filename,filenamelen);
	string filestr(file,filelen);
	((TFTPserv*)server)->addFile(filenamestr,filestr);
	return 0;
}
// Starts the TFTP server (in new thread)
extern "C" int startTFTPServer(void * server){
	return ((TFTPserv*)server)->start();
}
// Stops the TFTP server
extern "C" int stopTFTPServer(void * server){
	return ((TFTPserv*)server)->stop();
}

//constructor: the server begins with no files and no transfers, no worker
//thread, and no socket; addFile() hands out file slots starting at 0.
TFTPserv::TFTPserv()
	: fileIndexes(), files(), transfers(),
	  index(0), thread(NULL), shuttingDown(false), smellySock(0)
{
}
//Encodes a 16-bit value as a 2-byte string in network byte order
//(most significant byte first), i.e. the bytes htons() would produce.
string TFTPserv::htonstring(unsigned short input){
	char bytes[2];
	bytes[0] = (char)((input >> 8) & 0xff);
	bytes[1] = (char)(input & 0xff);
	return string(bytes, sizeof(bytes));
}
//Registers an in-memory "file": stores the contents under the next free
//numeric slot and maps the filename to that slot. Re-adding a name points it
//at a new slot; the previous contents stay in `files`.
void TFTPserv::addFile(string filename, string data){
	files[index].swap(data); //data is a by-value copy, so taking its buffer is free
	fileIndexes[filename] = index;
	++index;
}

//Checks whether the outstanding block of a transfer has timed out.
//  - less than transfer["timeout"] seconds since transfer["lastSent"]: nothing to do
//  - timed out with fewer than 3 retries: clears lastSent (run() then resends the
//    current block) and counts the retry
//  - timed out after 3 retries: gives up, removing the transfer from `transfers`
//    and freeing it
//Precondition: `transfer` is an element of `transfers` allocated with new. After
//the give-up path the reference dangles, so the caller must not touch it again.
void TFTPserv::checkRetransmission(map<string,unsigned int> & transfer){
	//Whole seconds of clock(), the same unit run() stores in lastSent
	unsigned int elapsed = clock() / CLOCKS_PER_SEC - transfer["lastSent"];
	if ( elapsed <= transfer["timeout"] )
		return;
	if (transfer["retries"] >= 3){
		//The set only holds the pointer, so the map must be freed here too or every
		//abandoned transfer leaks. Only delete it if it really was ours.
		map<string,unsigned int> *abandoned = &transfer;
		if (transfers.erase(abandoned) != 0)
		{
			delete abandoned;
		}
		return;
	}
	transfer["lastSent"] = 0;
	transfer["retries"] = transfer["retries"] + 1;
}

//Gets a request packet, figures out what to do, and does it.
//  RRQ: if the file exists, registers a new read transfer for `from` and processes
//       any options; otherwise replies with an ERROR "File not found" packet.
//  ACK: advances the matching transfer to its next block, or removes and frees it
//       once the final block has been acknowledged.
//  Anything else, including truncated packets, is ignored.
//  from - sender address; its IP and port identify the transfer
//  buf  - raw datagram payload
void TFTPserv::dispatchRequest(sockaddr_in &from, string buf)
{
	//Every TFTP packet starts with a 2-byte opcode; shorter datagrams are junk
	if (buf.length() < 2)
	{
		return;
	}
	//Decode big-endian fields byte by byte (no unaligned/aliasing pointer casts)
	unsigned short op = (unsigned short)((((unsigned char)buf[0]) << 8) | (unsigned char)buf[1]);

	switch (op)
	{
		case OpRead:
		{
			//RRQ layout: opcode | filename \0 | mode \0 | [option \0 value \0]...
			size_t nameEnd = buf.find('\x00', 2);
			if (nameEnd == string::npos)
			{
				return;
			}

			string fn(buf.substr(2, nameEnd - 2)); //get filename
			if (fileIndexes.count(fn) == 0) //nonexistent file
			{
				//Report it instead of silently dropping the request, which would
				//leave the client waiting for its own timeout
				string errPacket(htonstring(OpError));
				errPacket.append(htonstring(ErrFileNotFound));
				errPacket.append("File not found");
				errPacket.append(1, (char)0);
				sendto(smellySock, errPacket.c_str(), (int)errPacket.length(), 0, (struct sockaddr *)&from, (int)sizeof(sockaddr_in));
				return;
			}

			size_t modeEnd = buf.find('\x00', nameEnd + 1);
			if (modeEnd == string::npos) //invalid packet format
			{
				return;
			}
			//The mode string itself is not used: run() sends the stored bytes as-is

			//New transfer! lastSent == 0 tells run() to send the current block
			map<string,unsigned int> *transfer = new map<string,unsigned int>();
			(*transfer)["type"] = OpRead;
			(*transfer)["fromIp"] = *((unsigned int *)&from.sin_addr);
			(*transfer)["fromPort"] = from.sin_port; //kept in network byte order
			(*transfer)["file"] = fileIndexes[fn];
			(*transfer)["block"] = 1;
			(*transfer)["blksize"] = 512;
			(*transfer)["offset"] = 0;
			(*transfer)["timeout"] = 3;
			(*transfer)["lastSent"] = 0;
			(*transfer)["retries"] = 0;

			//process_options (may override blksize/timeout and reply with an OACK)
			processOptions((struct sockaddr *)&from, sizeof(sockaddr_in), buf, *transfer, (unsigned int)(modeEnd + 1));

			transfers.insert(transfer);
			break;
		}
		case OpAck:
		{
			//ACK layout: opcode | 2-byte block number
			if (buf.length() < 4)
			{
				return;
			}
			unsigned short block = (unsigned short)((((unsigned char)buf[2]) << 8) | (unsigned char)buf[3]);
			unsigned int fromIp = *((unsigned int *)&from.sin_addr);
			map<string,unsigned int> *transfer = NULL;
			//Find transfer
			for (set<map<string, unsigned int> *>::iterator it = transfers.begin(); it != transfers.end(); ++it)
			{
				map<string, unsigned int> & candidate = *(*it);
				//Block numbers are 16-bit on the wire and wrap around; compare modulo
				//2^16 so files longer than 65535 blocks keep matching after the wrap
				if (candidate["fromIp"] == fromIp
					&& candidate["fromPort"] == from.sin_port
					&& (unsigned short)candidate["block"] == block)
				{
					transfer = *it;
				}
			}

			if (transfer == NULL)
			{
				return; //unknown client, or a stale/duplicate ack
			}

			(*transfer)["offset"] = (*transfer)["offset"] + (*transfer)["blksize"];
			(*transfer)["block"] = (*transfer)["block"] + 1;
			(*transfer)["lastSent"] = 0;
			(*transfer)["retries"] = 0;

			//While offset has not moved past the end there is still a block to send
			//(possibly the empty one that terminates a blksize-multiple file)
			if ((*transfer)["offset"] <= files[(*transfer)["file"]].length())
			{
				return; //not complete
			}

			transfers.erase(transfer); // we're done!
			delete transfer;
			break;
		}
	}
}
// Extracts an int option in ascii form; if in range saves it and appends an option ack to replyPacket
void TFTPserv::checkIntOption(const char * optName, int min, int max, string & opt, string & val, map<string, unsigned int> & transfer, string & replyPacket)
{
	if (opt.compare(optName) != 0)
	{
		return;
	}

	//convert ascii to integer value
	int intval = 0;
	for (unsigned int i = 0; i < val.length(); i++)
	{
		if (val[i] >= '0' && val[i] <= '9')
		{
			intval = intval * 10 + val[i] - '0';
		}
	}

	//Validate it
	if (intval > max)
	{
		intval = max;
	}
	if (intval < min)
	{
		intval = min;
	}

	//Save it
	transfer[optName] = intval;
	//append ack
	replyPacket.append(opt).append(1, (char)0).append(val).append(1, (char)0);
}

//Parses all options from received packet
void TFTPserv::processOptions(struct sockaddr * from, unsigned int fromlen, string buf, map<string, unsigned int> & transfer, unsigned int spot)
{
	//Start with optack (two byte)
	string data = htonstring(OpOptAck);

	//Loop over options
	size_t currentSpot = spot;
	while (currentSpot < buf.length() - 4)
	{
		//Get option
		size_t nextSpot = buf.find('\x00', currentSpot);

		if (nextSpot == string::npos)
		{
			return;
		}

		string opt(buf.substr(currentSpot, nextSpot - currentSpot));

		//Get value
		currentSpot = nextSpot + 1;
		nextSpot = buf.find('\x00', currentSpot);

		if (nextSpot == string::npos)
		{
			return;
		}

		string val(buf.substr(currentSpot, nextSpot - currentSpot));
		currentSpot = nextSpot + 1;
		for (string::iterator it = opt.begin(); it < opt.end(); it++)
			*it = tolower(*it);
		checkIntOption("blksize", 8, 65464, opt, val, transfer, data);
		checkIntOption("timeout", 1, 255, opt, val, transfer, data);

		if (opt.compare("tsize") == 0)
		{
			//get length
			size_t flen = files[transfer["file"]].length();

			//convert to ascii
			string strlen;
			while (flen > 0)
			{
				strlen.insert(0, 1, (char)((flen % 10) + '0'));
				flen = flen / 10;
			}

			data.append(opt).append(1, (char)0).append(strlen).append(1, (char)0);
		}
	}
	//Send packet
	sendto(smellySock, data.c_str(), (int)data.length(), 0, from, (int)fromlen);
}

// Asks server to stop
int TFTPserv::stop()
{
	shuttingDown = true;
	DWORD res = 0xffffffff;

	if (thread != NULL)
	{
		res = WaitForSingleObject(thread, 5000);
	}

	thread = NULL;
	return res;
}

// Thread entry point given to CreateThread: the thread parameter is the TFTPserv
// instance; its run() result becomes the thread's exit code.
DWORD WINAPI runTFTPServer(void* server)
{
	TFTPserv * self = (TFTPserv *)server;
	DWORD exitCode = self->run();
	return exitCode;
}

// Starts server
int TFTPserv::start()
{
	//get socket
	smellySock = socket(AF_INET, SOCK_DGRAM, 0);

	if (smellySock == INVALID_SOCKET)
	{
		return -1;
	}

	//Se up server socket address
	struct sockaddr_in server;
	server.sin_family = AF_INET;
	server.sin_port = htons(tftpServPort);
	server.sin_addr.s_addr = 0; //a.k.a. INADDR_ANY

	// Bind address to socket
	if (bind(smellySock, (struct sockaddr *)&server, sizeof(server)) != 0)
	{
		return GetLastError();
	}

	thread = CreateThread(NULL, 0, &runTFTPServer, this, 0, NULL);

	if (thread != NULL)
	{
		return ERROR_SUCCESS;
	}

	return GetLastError();
}

//Internal run method that does all the hard work
int TFTPserv::run()
{
	// Setup timeout
	fd_set recvSet;
	fd_set sendSet;
	int n;

	//Main packet-handling loop
	shuttingDown = false;
	while (shuttingDown == false){
		FD_ZERO(&recvSet);
		FD_SET(smellySock, &recvSet);
		FD_ZERO(&sendSet);

		//Do we need to check for sent items? Let's see
		for (set<map<string, unsigned int> *>::iterator it = transfers.begin(); it != transfers.end(); ++it){
			if ((*(*it))["lastSent"] != 0){
				FD_SET(smellySock, &recvSet);
				break;
			}
		}

		struct timeval tv;
		tv.tv_sec = 1;
		tv.tv_usec = 0;
		n = select((int)(smellySock + 1), &recvSet, &sendSet, NULL, &tv);

		if (n == -1)
		{
			break; //Error
		}

		//Get request
		char receiveBuf[bufferSize];
		struct sockaddr client;
		int clientLength = sizeof(client);
		int receiveSize = 0;

		if (n != 0)
		{
			receiveSize = recvfrom(smellySock, receiveBuf, bufferSize, 0, &client, &clientLength);
		}

		if (receiveSize > 0)
		{
			string data(receiveBuf, receiveSize);
			dispatchRequest(*((sockaddr_in *)&client), data);
		}

		//Now see if we need to transmit/retransmit another block
		for (set<map<string, unsigned int> *>::iterator it = transfers.begin(); it != transfers.end(); ++it)
		{
			map<string, unsigned int> & transfer = *(*it);
			if (transfer["type"] != OpRead)
			{
				continue;
			}

			if (transfer["lastSent"] != 0)
			{
				checkRetransmission(transfer);
			}
			else
			{
				string block = files[transfer["file"]].substr(transfer["offset"], transfer["blksize"]);
				if (block.size() > 0)
				{
					string packet(htonstring(OpData));
					packet.append(htonstring(transfer["block"]));
					packet.append(block);
					//Send packet
					//first get address
					sockaddr_in client;
					memset((void*)&client, 0, sizeof(client));
					client.sin_family = AF_INET;
					*((unsigned int *)&client.sin_addr) = transfer["fromIp"];
					client.sin_port = (unsigned short)transfer["fromPort"];
					//whew. now send
					sendto(smellySock, packet.c_str(), (int)packet.size(), 0, (sockaddr*)&client, (int)sizeof(sockaddr_in));

					transfer["lastSent"] = clock() / CLOCKS_PER_SEC;
				}
			}
		}
	}

#ifdef WIN32
	closesocket(smellySock);
#else
	close(smellySock);
#endif
	return 0;
}