313 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			313 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
#include <netinet/tcp.h>
 | 
						|
#include <log/LogUtils.h>
 | 
						|
#include <misc/memtracker.h>
 | 
						|
#include "crypt.h"
 | 
						|
#define DEFINE_HELPER
 | 
						|
#include "LicenseRequest.h"
 | 
						|
#include "License.h"
 | 
						|
#include <csignal>
 | 
						|
 | 
						|
using namespace std;
 | 
						|
using namespace std::chrono;
 | 
						|
using namespace ts;
 | 
						|
using namespace license;
 | 
						|
 | 
						|
/* Turns on the memtrack allocation bookkeeping and the verbose close diagnostics in this file. */
#define DEBUG_LICENSE_CLIENT

/* Reports a connection setup error for the current request; only usable inside LicenceRequest members (uses `this`). */
#define CERR(msg) LICENSE_FERR(this, CouldNotConnectException, msg)
 | 
						|
 | 
						|
 | 
						|
/**
 * Creates a license request bound to the given server address.
 * No I/O is performed here; the socket is opened later by beginRequest()
 * (triggered through requestInfo()).
 *
 * @param license    request payload; must be non-null and carry `info`
 * @param remoteAddr IPv4 address of the license server, copied into the request
 */
LicenceRequest::LicenceRequest(const std::shared_ptr<LicenseRequestData> & license, const sockaddr_in& remoteAddr) : data(license) {
#ifdef DEBUG_LICENSE_CLIENT
	memtrack::allocated<LicenceRequest>(this);
#endif
	memcpy(&this->remote_address, &remoteAddr, sizeof(remoteAddr));

	/* test the pointer first: dereferencing an empty shared_ptr inside the assert would itself be UB */
	assert(license && license->info);
}
 | 
						|
 | 
						|
/**
 * Destroys the request: aborts any active connection, waits for an
 * asynchronous close (if closeConnection() delegated one to a helper thread)
 * and releases the shared response future.
 */
LicenceRequest::~LicenceRequest() {
#ifdef DEBUG_LICENSE_CLIENT
	memtrack::freed<LicenceRequest>(this);
#endif
	this->abortRequest();

	/* the helper thread runs closeConnection() on this object, so it has to finish before we are gone */
	if(auto* pending_close = this->closeThread; pending_close != nullptr) {
		pending_close->join();
		delete pending_close;
		this->closeThread = nullptr;
	}

	delete this->currentFuture;
	this->currentFuture = nullptr;
}
 | 
						|
 | 
						|
/**
 * Returns the future for this request's response, starting the request on
 * the first call. Later calls hand out the already created future and do not
 * open another connection.
 */
threads::Future<std::shared_ptr<LicenseRequestResponse>> LicenceRequest::requestInfo() {
	using response_future_t = threads::Future<std::shared_ptr<LicenseRequestResponse>>;

	{
		lock_guard guard(this->lock);
		if(this->currentFuture)
			return *this->currentFuture; /* request already started */

		this->currentFuture = new response_future_t();
	}

	/* connect outside of the guard above; beginRequest() acquires the lock itself */
	this->beginRequest();
	return *this->currentFuture;
}
 | 
						|
 | 
						|
 | 
						|
//Basic IO
 | 
						|
void LicenceRequest::handleEventWrite(int fd, short event, void* ptrClient) {
 | 
						|
    auto* client = static_cast<LicenceRequest *>(ptrClient);
 | 
						|
 | 
						|
    buffer::RawBuffer* buffer = nullptr;
 | 
						|
    {
 | 
						|
	    lock_guard lock(client->lock);
 | 
						|
        if((event & EV_TIMEOUT) > 0) { //Connect timeout
 | 
						|
	        LICENSE_FERR(client, ConnectionException, "Connect timeout");
 | 
						|
        	return;
 | 
						|
        }
 | 
						|
 | 
						|
	    if(client->state == protocol::CONNECTING){
 | 
						|
		    client->handleConnected();
 | 
						|
	    }
 | 
						|
 | 
						|
	    if(client->state == protocol::UNCONNECTED || !client->event_write)
 | 
						|
		    return;
 | 
						|
 | 
						|
        buffer = TAILQ_FIRST(&client->writeQueue);
 | 
						|
        if(!buffer) return;
 | 
						|
 | 
						|
        auto writtenBytes = send(fd, &buffer->buffer[buffer->index], buffer->length - buffer->index, MSG_NOSIGNAL | MSG_DONTWAIT);
 | 
						|
        buffer->index += writtenBytes;
 | 
						|
 | 
						|
        if(buffer->index >= buffer->length) {
 | 
						|
            TAILQ_REMOVE(&client->writeQueue, buffer, tail);
 | 
						|
            delete buffer;
 | 
						|
        }
 | 
						|
        if(!TAILQ_EMPTY(&client->writeQueue))
 | 
						|
            event_add(client->event_write, nullptr);
 | 
						|
    }
 | 
						|
}
 | 
						|
 | 
						|
/**
 * Serializes a packet into a fresh buffer (raw header followed by the payload,
 * XOR-encrypted when a crypt key is set) and appends it to the write queue,
 * arming the write event so the event loop flushes it.
 * Packets are dropped (logged when verbose) once the connection is closed or closing.
 */
void LicenceRequest::sendPacket(const protocol::packet& packet) {
	const bool closed = this->state == protocol::UNCONNECTED || this->state == protocol::DISCONNECTING;
	if(closed) {
		if(this->verbose)
			logError("Tried to send a packet to an unconnected remote!");
		return;
	}

	packet.prepare();

	const auto header_length = sizeof(packet.header);
	const auto payload_length = packet.data.length();

	auto* frame = new buffer::RawBuffer(payload_length + header_length);
	auto* payload = &frame->buffer[header_length];
	memcpy(frame->buffer, &packet.header, header_length);
	memcpy(payload, packet.data.data(), payload_length);

	/* only the payload is encrypted, the header stays readable */
	if(!this->cryptKey.empty())
		xorBuffer(payload, payload_length, this->cryptKey.data(), this->cryptKey.length());

	lock_guard guard(this->lock);
	TAILQ_INSERT_TAIL(&this->writeQueue, frame, tail);
	if(this->event_write)
		event_add(this->event_write, nullptr);
}
 | 
						|
 | 
						|
void LicenceRequest::handleEventRead(int fd, short, void* ptrClient) {
 | 
						|
    auto* client = static_cast<LicenceRequest *>(ptrClient);
 | 
						|
 | 
						|
	auto buffer = std::unique_ptr<void, decltype(free)*>{malloc(1024), free};
 | 
						|
    sockaddr_in remoteAddr{};
 | 
						|
    socklen_t remoteAddrSize = sizeof(remoteAddr);
 | 
						|
 | 
						|
    auto read = recvfrom(fd, buffer.get(), 1024, MSG_NOSIGNAL | MSG_DONTWAIT, reinterpret_cast<sockaddr *>(&remoteAddr), &remoteAddrSize);
 | 
						|
 | 
						|
    if(read < 0){
 | 
						|
        if(errno == EWOULDBLOCK) return;
 | 
						|
	    if(client->event_read)
 | 
						|
		    event_del_noblock(client->event_read);
 | 
						|
        LICENSE_FERR(client, ConnectionException, "Invalid read: " + string(strerror(errno)) + "/" + to_string(errno));
 | 
						|
	    return;
 | 
						|
    } else if(read == 0) {
 | 
						|
    	if(client->event_read)
 | 
						|
    	    event_del_noblock(client->event_read);
 | 
						|
	    LICENSE_FERR(client, ConnectionException, "IO error (" + to_string(errno) + "): " + string(strerror(errno)));
 | 
						|
	    return;
 | 
						|
    }
 | 
						|
 | 
						|
    client->handleMessage(string((char*) buffer.get(), read));
 | 
						|
}
 | 
						|
 | 
						|
 | 
						|
static int enabled = 1;
 | 
						|
static int disabled = 0;
 | 
						|
void LicenceRequest::beginRequest() {
 | 
						|
	lock_guard lock(this->lock);
 | 
						|
	TAILQ_INIT(&this->writeQueue);
 | 
						|
 | 
						|
    this->file_descriptor = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0);
 | 
						|
    if(this->file_descriptor < 0) CERR("Socket setup failed");
 | 
						|
 | 
						|
	signal(SIGPIPE, SIG_IGN);
 | 
						|
 | 
						|
	auto state = ::connect(this->file_descriptor, reinterpret_cast<const sockaddr *>(&this->remote_address), sizeof(this->remote_address));
 | 
						|
    if(state < 0 && errno != EINPROGRESS) CERR("connect() failed (" + string(strerror(errno)) + ")");
 | 
						|
 | 
						|
    if(setsockopt(this->file_descriptor, SOL_SOCKET, SO_REUSEADDR, &enabled, sizeof(enabled)) < 0) CERR("could not set reuse addr");
 | 
						|
    if(setsockopt(this->file_descriptor, IPPROTO_TCP, TCP_CORK, &disabled, sizeof(disabled)) < 0) CERR("could not set no push");
 | 
						|
 | 
						|
    if(fcntl(this->file_descriptor, F_SETFD, fcntl(this->file_descriptor, F_GETFL, 0)  | FD_CLOEXEC | O_NONBLOCK) < 0) CERR("Failed to set FD_CLOEXEC and O_NONBLOCK");
 | 
						|
 | 
						|
    this->event_base = event_base_new();
 | 
						|
    this->event_read = event_new(this->event_base, this->file_descriptor, EV_READ | EV_PERSIST, LicenceRequest::handleEventRead, this);
 | 
						|
    this->event_write = event_new(this->event_base, this->file_descriptor, EV_WRITE, LicenceRequest::handleEventWrite, this);
 | 
						|
 | 
						|
	this->state = protocol::CONNECTING; //First set connected, then we could enable the event loop
 | 
						|
 | 
						|
    event_dispatch = std::thread([&]() {
 | 
						|
        signal(SIGPIPE, SIG_IGN);
 | 
						|
 | 
						|
	    { /* now we could start listening */
 | 
						|
	    	lock_guard _lock(this->lock);
 | 
						|
	    	if(!this->event_read || !this->event_write) return;
 | 
						|
 | 
						|
		    event_add(this->event_read, nullptr);
 | 
						|
		    timeval connect_timeout{5, 0};
 | 
						|
		    event_add(this->event_write, &connect_timeout);
 | 
						|
	    }
 | 
						|
 | 
						|
	    event_base_dispatch(this->event_base);
 | 
						|
    });
 | 
						|
}
 | 
						|
 | 
						|
/**
 * Marks the socket as connected (state HANDSCAKE) and queues the client
 * handshake packet: three magic bytes 0xC0 0xFF 0xEE followed by the
 * protocol version byte.
 */
void LicenceRequest::handleConnected() {
	this->state = protocol::HANDSCAKE;

	const char handshake[] = {
		static_cast<char>(0xC0),
		static_cast<char>(0xFF),
		static_cast<char>(0xEE),
		static_cast<char>(static_cast<uint8_t>(LICENSE_PROT_VERSION)) /* truncated to one byte */
	};
	this->sendPacket(protocol::packet{protocol::PACKET_CLIENT_HANDSHAKE, string(handshake, sizeof(handshake))}); //Initialise packet
}
 | 
						|
 | 
						|
void LicenceRequest::handleMessage(const std::string& message) {
 | 
						|
	this->buffer += message;
 | 
						|
    if(this->buffer.length() < sizeof(protocol::packet::header))
 | 
						|
    	return;
 | 
						|
 | 
						|
    protocol::packet packet{protocol::PACKET_DISCONNECT, ""};
 | 
						|
    memcpy(&packet.header,  this->buffer.data(), sizeof(protocol::packet::header));
 | 
						|
    if(packet.header.length <= this->buffer.length() - sizeof(protocol::packet::header)) {
 | 
						|
	    packet.data = this->buffer.substr(sizeof(protocol::packet::header), packet.header.length);
 | 
						|
	    this->buffer = this->buffer.substr(sizeof(protocol::packet::header) + packet.header.length);
 | 
						|
    } else {
 | 
						|
    	return;
 | 
						|
    }
 | 
						|
 | 
						|
    if(!this->cryptKey.empty()) {
 | 
						|
        xorBuffer((char*) packet.data.data(), packet.data.length(), this->cryptKey.data(), this->cryptKey.length());
 | 
						|
    }
 | 
						|
 | 
						|
	if(packet.header.packetId == protocol::PACKET_SERVER_HANDSHAKE) {
 | 
						|
        this->handlePacketHandshake(packet.data);
 | 
						|
	} else if(packet.header.packetId == protocol::PACKET_DISCONNECT) {
 | 
						|
        this->handlePacketDisconnect(packet.data);
 | 
						|
	} else if(packet.header.packetId == protocol::PACKET_SERVER_VALIDATION_RESPONSE) {
 | 
						|
        this->handlePacketLicenseInfo(packet.data);
 | 
						|
	} else if(packet.header.packetId == protocol::PACKET_SERVER_PROPERTY_ADJUSTMENT) {
 | 
						|
        this->handlePacketInfoAdjustment(packet.data);
 | 
						|
    } else
 | 
						|
		LICENSE_FERR(this, ConnectionException, "Invalid packet id (" + to_string(packet.header.packetId) + ")");
 | 
						|
 | 
						|
	if(!this->buffer.empty() && this->state != protocol::DISCONNECTING && this->state != protocol::UNCONNECTED)
 | 
						|
		this->handleMessage("");
 | 
						|
}
 | 
						|
 | 
						|
void LicenceRequest::disconnect(const std::string& message) {
 | 
						|
	if(this->state != protocol::UNCONNECTED && this->state != protocol::DISCONNECTING)
 | 
						|
		this->sendPacket({protocol::PACKET_DISCONNECT, message});
 | 
						|
	this->closeConnection();
 | 
						|
	//TODO flush?
 | 
						|
}
 | 
						|
 | 
						|
void LicenceRequest::closeConnection() {
 | 
						|
	event *event_read, *event_write;
 | 
						|
    {
 | 
						|
        lock_guard lock(this->lock);
 | 
						|
        if(this->state == protocol::UNCONNECTED) return;
 | 
						|
 | 
						|
        if(this->event_dispatch.get_id() == this_thread::get_id()) { //We could not close in the same thread as we read/write (we're joining it later)
 | 
						|
            if(this->state == protocol::DISCONNECTING) return;
 | 
						|
 | 
						|
            this->state = protocol::DISCONNECTING;
 | 
						|
            this->closeThread = new threads::Thread(THREAD_SAVE_OPERATIONS, [&]() { this->closeConnection(); });
 | 
						|
 | 
						|
#ifdef DEBUG_LICENSE_CLIENT
 | 
						|
	        if(this->verbose) {
 | 
						|
		        debugMessage("Running close in a new thread");
 | 
						|
		        this->closeThread->name("License request close");
 | 
						|
	        }
 | 
						|
#endif
 | 
						|
            return;
 | 
						|
        }
 | 
						|
        this->state = protocol::UNCONNECTED;
 | 
						|
 | 
						|
	    event_read = this->event_read;
 | 
						|
	    event_write = this->event_write;
 | 
						|
 | 
						|
	    this->event_write = nullptr;
 | 
						|
	    this->event_read = nullptr;
 | 
						|
    }
 | 
						|
 | 
						|
    if(event_read) {
 | 
						|
        event_del_block(event_read);
 | 
						|
        event_free(event_read);
 | 
						|
    }
 | 
						|
 | 
						|
    if(event_write) {
 | 
						|
	    event_del_block(event_write);
 | 
						|
	    event_free(event_write);
 | 
						|
    }
 | 
						|
 | 
						|
    /* close before base shutdown (else epoll hangup) */
 | 
						|
	if(this->file_descriptor > 0) {
 | 
						|
		shutdown(this->file_descriptor, SHUT_RDWR);
 | 
						|
		close(this->file_descriptor);
 | 
						|
	}
 | 
						|
	this->file_descriptor = 0;
 | 
						|
 | 
						|
    {
 | 
						|
	    lock_guard lock(this->lock);
 | 
						|
        ts::buffer::RawBuffer* buffer;
 | 
						|
        while ((buffer = TAILQ_FIRST(&this->writeQueue))) {
 | 
						|
            TAILQ_REMOVE(&this->writeQueue, buffer, tail);
 | 
						|
            delete buffer;
 | 
						|
        }
 | 
						|
    }
 | 
						|
 | 
						|
	{
 | 
						|
		if(this->event_base) {
 | 
						|
			timeval seconds{1, 0};
 | 
						|
			event_base_loopexit(this->event_base, &seconds);
 | 
						|
			event_base_loopexit(this->event_base, nullptr);
 | 
						|
		}
 | 
						|
 | 
						|
		if(this->event_dispatch.joinable()) {
 | 
						|
			this->event_dispatch.join();
 | 
						|
		}
 | 
						|
 | 
						|
		if(this->event_base) {
 | 
						|
			event_base_free(this->event_base);
 | 
						|
			this->event_base = nullptr;
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
#ifdef DEBUG_LICENSE_CLIENT
 | 
						|
	if(this->verbose)
 | 
						|
	    debugMessage("Executing close done");
 | 
						|
#endif
 | 
						|
}
 | 
						|
 | 
						|
/**
 * Aborts the request by closing the connection.
 *
 * NOTE(review): `timeout` is currently ignored; the connection is closed immediately.
 */
void LicenceRequest::abortRequest(const std::chrono::system_clock::time_point &timeout) {
	static_cast<void>(timeout);
	this->closeConnection();
}
 |