//#define NO_OPEN_SSL /* because we're lazy and dont want to build this lib extra for the TeaClient */
#define FIXEDINT_H_INCLUDED /* else it will be included by ge */

#include "misc/endianness.h"
#include <ed25519/ed25519.h>
#include <ed25519/ge.h>
#include <log/LogUtils.h>
#include "misc/memtracker.h"
#include "misc/digest.h"
#include "CryptionHandler.h"

using namespace std;
using namespace ts;
using namespace ts::connection;
using namespace ts::protocol;


/* Constructor: only registers this instance with the memory tracker
 * (the destructor reports the matching memtrack::freed). */
CryptionHandler::CryptionHandler() { memtrack::allocated<CryptionHandler>(this); }

/* Destructor: only unregisters this instance from the memory tracker
 * (the constructor reports the matching memtrack::allocated). */
CryptionHandler::~CryptionHandler() { memtrack::freed<CryptionHandler>(this); }

void CryptionHandler::reset() {
	this->useDefaultChipherKeyNonce = true;
	this->iv_struct_length = 0;
	memset(this->iv_struct, 0, sizeof(this->iv_struct));
	memcpy(this->current_mac, this->default_mac, sizeof(this->default_mac));

	for(auto& cache : this->cache_key_client)
		cache.generation = 0xFFEF;
	for(auto& cache : this->cache_key_server)
		cache.generation = 0xFFEF;
}

bool CryptionHandler::setupSharedSecret(const std::string& alpha, const std::string& beta, ecc_key *publicKey, ecc_key *ownKey, std::string &error) {
	size_t bufferLength = 128;
	uint8_t* buffer = new uint8_t[bufferLength];
	int err;
	if((err = ecc_shared_secret(ownKey, publicKey, buffer, (unsigned long*) &bufferLength)) != CRYPT_OK){
		delete[] buffer;
		error = "Could not calculate shared secret. Message: " + string(error_to_string(err));
		return false;
	}
	auto result = this->setupSharedSecret(alpha, beta, string((const char*) buffer, bufferLength), error);
	delete[] buffer;
	return result;
}

/* Derives the legacy session iv and MAC from a raw shared secret.
 *
 *   iv_struct   = (alpha[0..10) || beta[0..10)) XOR sha1(sharedKey)   (SHA_DIGEST_LENGTH bytes)
 *   current_mac = first 8 bytes of sha1(iv_struct)
 *
 * On success the handler stops using the default key/nonce.
 *
 * @param alpha      at least 10 bytes; only the first 10 are used
 * @param beta       at least 10 bytes; only the first 10 are used
 * @param sharedKey  raw shared secret
 * @param error      receives a description when false is returned
 * @return false if alpha or beta is shorter than 10 bytes, true otherwise
 */
bool CryptionHandler::setupSharedSecret(const std::string& alpha, const std::string& beta, const std::string &sharedKey, std::string &error) {
	/* The first 10 bytes of each are copied unconditionally below; reject short inputs
	 * instead of reading past the end of the strings. */
	if(alpha.length() < 10 || beta.length() < 10) {
		error = "alpha and beta must be at least 10 bytes long";
		return false;
	}

	auto secret_hash = digest::sha1(sharedKey);

	char ivStruct[SHA_DIGEST_LENGTH];
	memcpy(ivStruct, alpha.data(), 10);
	memcpy(&ivStruct[10], beta.data(), 10);

	for (int index = 0; index < SHA_DIGEST_LENGTH; index++) {
		ivStruct[index] ^= (uint8_t) secret_hash[index];
	}

	{
		lock_guard lock(this->cache_key_lock);
		memcpy(this->iv_struct, ivStruct, SHA_DIGEST_LENGTH);
		this->iv_struct_length = SHA_DIGEST_LENGTH;

		auto iv_hash = digest::sha1(ivStruct, SHA_DIGEST_LENGTH);
		memcpy(this->current_mac, iv_hash.data(), 8);

		this->useDefaultChipherKeyNonce = false;
	}

	return true;
}

/* Field-element negation: h = -f, limb by limb (10 limbs).
 * All limbs of f are read into a scratch array before any limb of h is written,
 * so calling this in place (h aliasing f) is safe. */
void _fe_neg(fe h, const fe f) {
	int32_t negated[10];
	for(int limb = 0; limb < 10; ++limb) {
		const int32_t value = f[limb];
		negated[limb] = -value;
	}
	for(int limb = 0; limb < 10; ++limb)
		h[limb] = negated[limb];
}

inline void keyMul(uint8_t* target_buffer, const uint8_t* publicKey /* compressed */, const uint8_t* privateKey /* uncompressed */, bool negate){
	ge_p3 keyA{};
	ge_p2 result{};

	ge_frombytes_negate_vartime(&keyA, publicKey);
	if(negate) {
		_fe_neg(*(fe*) &keyA.X, *(const fe*) &keyA.X); /* undo negate */
		_fe_neg(*(fe*) &keyA.T, *(const fe*) &keyA.T); /* undo negate */
	}
	ge_scalarmult_vartime(&result, privateKey, &keyA);

	ge_tobytes(target_buffer, &result);
}

bool CryptionHandler::setupSharedSecretNew(const std::string &alpha, const std::string &beta, const char* privateKey /* uncompressed */, const char* publicKey /* compressed */) {
	assert(alpha.length() == 10);
	assert(beta.length() == 54);

	string shared;
	string sharedIv;
	shared.resize(32, '\0');
	sharedIv.resize(64, '\0');
	keyMul((uint8_t*) shared.data(), reinterpret_cast<const uint8_t *>(publicKey), reinterpret_cast<const uint8_t *>(privateKey), true); //Remote key get negated
	sharedIv = digest::sha512(shared);

	auto xor_key = alpha + beta;
	for(int i = 0; i < 64; i++)
		sharedIv[i] ^= xor_key[i];

	{
		lock_guard lock(this->cache_key_lock);
		memcpy(this->iv_struct, sharedIv.data(), 64);
		this->iv_struct_length = 64;

		auto digest_buffer = digest::sha1((char*) this->iv_struct, 64);
		memcpy(this->current_mac, digest_buffer.data(), 8);
		this->useDefaultChipherKeyNonce = false;
	}

	return true;
}

/* Packet-based convenience overload of generate_key_nonce().
 *
 * The direction is taken from the packet's dynamic type: a protocol::ClientPacket is treated
 * as client -> server, anything else as server -> client. The packet type, packet id and
 * generation id are read from the packet and forwarded unchanged.
 */
bool CryptionHandler::generate_key_nonce(protocol::BasicPacket* packet, bool use_default, uint8_t(& key)[16], uint8_t(& nonce)[16]){
	const bool client_to_server = dynamic_cast<protocol::ClientPacket *>(packet) != nullptr;
	const auto packet_type = packet->type().type();
	const auto id = packet->packetId();
	const auto generation_id = packet->generationId();

	return this->generate_key_nonce(client_to_server, packet_type, id, generation_id, use_default, key, nonce);
}

bool CryptionHandler::generate_key_nonce(
		bool to_server, /* its from the client to the server */
		protocol::PacketType type,
		uint16_t packet_id,
		uint16_t generation,
		bool use_default,
		uint8_t (& key)[16],
		uint8_t (& nonce)[16]
) {
	if (this->useDefaultChipherKeyNonce || use_default) {
		memcpy(key, this->default_key, 16);
		memcpy(nonce, this->default_nonce, 16);
		return true;
	}

	auto& key_cache_array = to_server ? this->cache_key_client : this->cache_key_server;
	if(type < 0 || type >= key_cache_array.max_size()) {
		logError(0, "Tried to generate a crypt key with invalid type ({})!", type);
		return false;
	}

	auto& key_cache = key_cache_array[type];
	if(key_cache.generation != generation) {
		const size_t buffer_length = 6 + this->iv_struct_length;
		char* buffer = new char[buffer_length];
		memset(buffer, 0, 6 + this->iv_struct_length);

		if (to_server) {
			buffer[0] = 0x31;
		}  else {
			buffer[0] = 0x30;
		}
		buffer[1] = (char) (type & 0xF);

		le2be32(generation, buffer, 2);
		memcpy(&buffer[6], this->iv_struct, this->iv_struct_length);
		auto key_nonce = digest::sha256(buffer, 6 + this->iv_struct_length);

		memcpy(key_cache.key, key_nonce.data(), 16);
		memcpy(key_cache.nonce, key_nonce.data() + 16, 16);
		key_cache.generation = generation;

		delete[] buffer;
	}

	memcpy(key, key_cache.key, 16);
	memcpy(nonce, key_cache.nonce, 16);

	//Xor the key
	key[0] ^= (uint8_t) ((packet_id >> 8) & 0xFF);
	key[1] ^=(packet_id & 0xFF);

	return true;
}

/* Checks whether a raw client -> server datagram authenticates under the key/nonce derived
 * for (packet_id, generation). The packet itself is not modified.
 *
 * Layout: [0, 8) MAC, [8, 13) header (packet type in the low nibble of byte 12), [13, end) payload.
 *
 * @return true only if libtomcrypt returns CRYPT_OK and the tag verifies. Returns false for packets
 *         shorter than 13 bytes, payloads over 2048 bytes, key derivation failure or a bad tag.
 */
bool CryptionHandler::verify_encryption(const pipes::buffer_view &packet, uint16_t packet_id, uint16_t generation) {
	/* MAC (8) + header (5): packet[12] and packet.view(13) below require at least this many bytes */
	if(packet.length() < 13)
		return false;

	uint8_t key[16], nonce[16];
	if(!generate_key_nonce(true, (protocol::PacketType) (packet[12] & 0xF), packet_id, generation, false, key, nonce))
		return false;

	auto mac = packet.view(0, 8);
	auto header = packet.view(8, 5);
	auto data = packet.view(13);

	const unsigned long target_length = 2048;
	uint8_t target_buffer[2048];
	if(target_length < data.length())
		return false;

	int success = false;
	const int err = eax_decrypt_verify_memory(find_cipher("rijndael"),
	                                (uint8_t *) key, /* the key */
	                                (unsigned long) 16, /* key is 16 bytes */
	                                (uint8_t *) nonce, /* the nonce */
	                                (unsigned long) 16, /* nonce is 16 bytes */
	                                (uint8_t *) header.data_ptr(), /* authenticated header */
	                                (unsigned long) header.length(), /* header length */
	                                (const unsigned char *) data.data_ptr(),
	                                (unsigned long) data.length(),
	                                (unsigned char *) target_buffer,
	                                (unsigned char *) mac.data_ptr(),
	                                (unsigned long) mac.length(),
	                                &success
	);

	return err == CRYPT_OK && success;
}

/* Decrypts and authenticates `packet` in place: AES-128 ("rijndael", 16 byte key) in EAX mode.
 *
 * The packet header is authenticated as associated data, and packet->mac() supplies the expected
 * tag. On success the payload is replaced by the plaintext and the packet is marked unencrypted.
 *
 * @param use_default  forwarded to generate_key_nonce(); true forces the default key/nonce
 * @param error        receives a description when false is returned
 * @return false if no key could be derived, the payload exceeds 2048 bytes, libtomcrypt
 *         fails or the tag does not verify
 */
bool CryptionHandler::decryptPacket(protocol::BasicPacket *packet, std::string &error, bool use_default) {
	auto associated_data = packet->header();
	auto cipher_text = packet->data();

	uint8_t key[16];
	uint8_t nonce[16];
	if(!generate_key_nonce(packet, use_default, key, nonce)) {
		error = "Could not generate key/nonce";
		return false;
	}

	uint8_t plain_text[2048];
	const auto cipher_length = cipher_text.length();
	if(cipher_length > sizeof(plain_text)) {
		error = "buffer too large";
		return false;
	}

	int tag_valid = false;
	const int status = eax_decrypt_verify_memory(
			find_cipher("rijndael"),
			(uint8_t *) key, (unsigned long) sizeof(key),
			(uint8_t *) nonce, (unsigned long) sizeof(nonce),
			(uint8_t *) associated_data.data_ptr(), (unsigned long) associated_data.length(),
			(const unsigned char *) cipher_text.data_ptr(), (unsigned long) cipher_text.length(),
			(unsigned char *) plain_text,
			(unsigned char *) packet->mac().data_ptr(), (unsigned long) packet->mac().length(),
			&tag_valid
	);

	if(status != CRYPT_OK) {
		error = "eax_decrypt_verify_memory(...) returned " + to_string(status) + "/" + error_to_string(status);
		return false;
	}
	if(!tag_valid) {
		error = "memory verify failed!";
		return false;
	}

	packet->data(pipes::buffer_view{plain_text, cipher_length});
	packet->setEncrypted(false);
	return true;
}

/* Encrypts `packet` in place: AES-128 ("rijndael", 16 byte key) in EAX mode.
 *
 * The packet header is authenticated as associated data. On success the payload is replaced
 * by the ciphertext, the 8 byte tag is written to packet->mac(), and the packet is marked encrypted.
 *
 * @param use_default  forwarded to generate_key_nonce(); true forces the default key/nonce
 * @param error        receives a description when false is returned
 * @return false if no key could be derived, the payload exceeds 2048 bytes or libtomcrypt fails
 */
bool CryptionHandler::encryptPacket(protocol::BasicPacket *packet, std::string &error, bool use_default) {
	uint8_t key[16], nonce[16];
	if(!generate_key_nonce(packet, use_default, key, nonce)) {
		error = "Could not generate key/nonce";
		return false;
	}

	auto header = packet->header();
	auto plain = packet->data();
	const size_t length = plain.length();

	const size_t target_length = 2048;
	uint8_t target_buffer[2048];
	if(target_length < length) {
		error = "buffer too large";
		return false;
	}

	/* libtomcrypt takes the tag capacity as an in/out `unsigned long*`. Declaring it as unsigned long,
	 * instead of casting a size_t*, avoids an aliasing violation and stays correct on LLP64 targets. */
	unsigned long tag_length = 8;
	char tag_buffer[8];

	int err;
	if((err = eax_encrypt_authenticate_memory(find_cipher("rijndael"),
	                                          (uint8_t *) key, /* the key */
	                                          (unsigned long) 16, /* key is 16 bytes */
	                                          (uint8_t *) nonce, /* the nonce */
	                                          (unsigned long) 16, /* nonce is 16 bytes */
	                                          (uint8_t *) header.data_ptr(), /* authenticated header */
	                                          (unsigned long) header.length(), /* header length */
	                                          (uint8_t *) plain.data_ptr(), /* The plain text */
	                                          (unsigned long) length, /* Plain text length */
	                                          (uint8_t *) target_buffer, /* The result buffer */
	                                          (uint8_t *) tag_buffer,
	                                          &tag_length
	)) != CRYPT_OK){
		error = "eax_encrypt_authenticate_memory(...) returned " + to_string(err) + "/" + error_to_string(err);
		return false;
	}
	assert(tag_length == 8);

	packet->data(pipes::buffer_view{target_buffer, length});
	packet->mac().write(tag_buffer, tag_length);
	packet->setEncrypted(true);
	return true;
}

/* Incoming-packet hook: waits (polling every 100µs) while `blocked` is set, then decrypts
 * the packet if it is marked encrypted. Unencrypted packets pass through untouched.
 *
 * @return true for unencrypted packets, otherwise the result of decryptPacket()
 */
bool CryptionHandler::progressPacketIn(protocol::BasicPacket* packet, std::string& error, bool use_default) {
	while(blocked)
		this_thread::sleep_for(chrono::microseconds(100));

	if(!packet->isEncrypted())
		return true;

	const bool decrypted = decryptPacket(packet, error, use_default);
	if(decrypted)
		packet->setEncrypted(false);
	return decrypted;
}

/* Outgoing-packet hook: waits (polling every 100µs) while `blocked` is set.
 * Packets flagged PacketFlag::Unencrypted only get the 8 byte current_mac written as their MAC;
 * every other packet is encrypted.
 *
 * @return true for unencrypted packets, otherwise the result of encryptPacket()
 */
bool CryptionHandler::progressPacketOut(protocol::BasicPacket* packet, std::string& error, bool use_default) {
	while(blocked)
		this_thread::sleep_for(chrono::microseconds(100));

	if(!packet->has_flag(PacketFlag::Unencrypted)) {
		const bool encrypted = encryptPacket(packet, error, use_default);
		if(encrypted)
			packet->setEncrypted(true);
		return encrypted;
	}

	packet->mac().write(this->current_mac, 8);
	return true;
}