// 526 lines
// 18 KiB
// C++
// Raw Normal View History

#include <algorithm>
#include <src/server/QueryServer.h>
#include "QueryClient.h"
#include <netinet/tcp.h>
#include <src/Configuration.h>
#include <log/LogUtils.h>
#include <misc/memtracker.h>
#include "src/InstanceHandler.h"
#include <pipes/errors.h>
#include <query/command2.h>
#include <misc/std_unique_ptr.h>
using namespace std;
using namespace std::chrono;
using namespace ts;
using namespace ts::server;
#if defined(TCP_CORK) && !defined(TCP_NOPUSH)
#define TCP_NOPUSH TCP_CORK
#endif
//#define DEBUG_TRAFFIC
/*
 * Wraps an accepted query socket.
 * Configures the socket (keep-alive on, TCP_NOPUSH/TCP_CORK off, TCP_NODELAY on), creates the libevent
 * read/write events bound to this client and marks the client as connected.
 * Socket option failures are logged but are not fatal.
 */
QueryClient::QueryClient(QueryServer* handle, int sockfd) : ConnectedClient(handle->sql, nullptr), handle(handle), clientFd(sockfd) {
    memtrack::allocated<QueryClient>(this);

    int enabled = 1;
    if(setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &enabled, sizeof(enabled)) < 0) {
        logError(this->getServerId(), "[Query] Could not enable keep alive for {} ({}/{})", CLIENT_STR_LOG_PREFIX_(this), errno, strerror(errno));
    }

#ifdef TCP_NOPUSH
    /* TCP_NOPUSH is mapped to TCP_CORK at the top of this file on platforms that only provide the latter */
    int disabled = 0;
    if(setsockopt(sockfd, IPPROTO_TCP, TCP_NOPUSH, &disabled, sizeof disabled) < 0) {
        logError(this->getServerId(), "[Query] Could not disable nopush for {} ({}/{})", CLIENT_STR_LOG_PREFIX_(this), errno, strerror(errno));
    }
#endif

    /* TCP_NODELAY is *enabled* here (Nagle off) so short query responses are sent without delay */
    if(setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, &enabled, sizeof enabled) < 0) {
        logError(this->getServerId(), "[Query] Could not enable no delay for {} ({}/{})", CLIENT_STR_LOG_PREFIX_(this), errno, strerror(errno));
    }

    /*
     * libevent callbacks receive this client as user pointer.
     * The read event is persistent; the write event is one-shot and re-armed by the write path via event_add.
     */
    this->readEvent = event_new(this->handle->eventLoop, this->clientFd, EV_READ | EV_PERSIST, [](int fd, short events, void* ptr){
        static_cast<QueryClient*>(ptr)->handleMessageRead(fd, events, ptr);
    }, this);
    this->writeEvent = event_new(this->handle->eventLoop, this->clientFd, EV_WRITE, [](int fd, short events, void* ptr){
        static_cast<QueryClient*>(ptr)->handleMessageWrite(fd, events, ptr);
    }, this);

    this->state = ConnectionState::CONNECTED;
    connectedTimestamp = system_clock::now();
    this->resetEventMask();
}
/*
 * Remembers the shared_ptr that owns this client in `_this`.
 * Other members re-acquire ownership later through `_this.lock()`.
 */
void QueryClient::applySelfLock(const std::shared_ptr<ts::server::QueryClient> &cl) {
    _this = cl;
}
/*
 * Destroys the client: records the release for memory tracking
 * (counterpart of memtrack::allocated in the constructor) and finalizes the SSL pipe.
 */
QueryClient::~QueryClient() {
    memtrack::freed<QueryClient>(this);
    this->ssl_handler.finalize();
}
/*
 * First initialization stage.
 * Tags the client as a query client, assigns a placeholder unique id and an
 * "ServerQuery#<ip>/<port>" nickname, and assigns the database id.
 * With ssl mode 0 there is no SSL detection: the connection is treated as plain
 * immediately and the second stage (postInitialize) runs right away.
 */
void QueryClient::preInitialize() {
    this->properties()[property::CLIENT_TYPE] = ClientType::CLIENT_QUERY;
    this->properties()[property::CLIENT_TYPE_EXACT] = ClientType::CLIENT_QUERY;
    this->properties()[property::CLIENT_UNIQUE_IDENTIFIER] = "UnknownQuery";

    std::string nickname{"ServerQuery#"};
    nickname += this->getLoggingPeerIp();
    nickname += "/";
    nickname += to_string(this->getPeerPort());
    this->properties()[property::CLIENT_NICKNAME] = nickname;

    DatabaseHelper::assignDatabaseId(this->sql, this->getServerId(), _this.lock());

    if(ts::config::query::sslMode != 0)
        return;

    this->connectionType = ConnectionType::PLAIN;
    this->postInitialize();
}
/*
 * Second initialization stage, run once the connection type (PLAIN / SSL) is known.
 * Records the connect time, enforces SSL when ssl mode is 1, sends the MOTD, then rejects
 * blacklisted and banned peers. Only clients that pass these checks get their permissions cached.
 */
void QueryClient::postInitialize() {
    lock_guard<recursive_mutex> lock(this->lock_packet_handle); /* we dont want to handle anything while we're initializing */

    this->connectTimestamp = system_clock::now();
    this->properties()[property::CLIENT_LASTCONNECTED] = duration_cast<seconds>(this->connectTimestamp.time_since_epoch()).count();

    /* ssl mode 1: only encrypted connections are accepted */
    if(ts::config::query::sslMode == 1 && this->connectionType != ConnectionType::SSL_ENCRIPTED) {
        this->notifyError({findError("failed_connection_initialisation"), "Please use a SSL encryption!"});
        this->disconnect("Please use a SSL encryption for more security.\nThe server denies also all other connections!");
        return;
    }

    writeMessage(config::query::motd);

    assert(this->handle);
    if(this->handle->ip_blacklist) {
        if(this->handle->ip_blacklist->contains(this->remote_address)) {
            Command cmd("error");
            auto err = findError("client_login_not_permitted");
            cmd["id"] = err.errorId;
            cmd["msg"] = err.message;
            cmd["extra_msg"] = "You're not permitted to use the query interface! (Your blacklisted)";
            this->sendCommand(cmd);
            this->disconnect("blacklisted");
            return;
        }

        /* NOTE(review): the whitelist is only consulted when a blacklist is configured — confirm this is intended */
        if(this->handle->ip_whitelist)
            this->whitelisted = this->handle->ip_whitelist->contains(this->remote_address);
        else
            this->whitelisted = false;
        debugMessage(LOG_QUERY, "Got new query client from {}. Whitelisted: {}", this->getLoggingPeerIp(), this->whitelisted);
    }

    if(!this->whitelisted) {
        threads::MutexLock login_lock(this->handle->loginLock);
        if(this->handle->queryBann.count(this->getPeerIp()) > 0) {
            auto ban = this->handle->queryBann[this->getPeerIp()];
            Command cmd("error");
            auto err = findError("server_connect_banned");
            cmd["id"] = err.errorId;
            cmd["msg"] = err.message;
            cmd["extra_msg"] = "you may retry in " + to_string(duration_cast<seconds>(ban - system_clock::now()).count()) + " seconds";
            this->sendCommand(cmd);
            this->disconnect("");
            return; /* a banned client is being disconnected; do not continue its initialization */
        }
    }

    this->update_cached_permissions();
}
/*
 * Sends a text message to the client using the transport chosen for this connection.
 * PLAIN goes straight to the raw write queue, SSL_ENCRIPTED goes through the SSL pipe.
 * Messages are silently dropped once the client is disconnected or has no server handle.
 */
void QueryClient::writeMessage(const std::string& message) {
    if(!this->handle || this->state == ConnectionState::DISCONNECTED)
        return;

    switch(this->connectionType) {
        case ConnectionType::PLAIN:
            this->writeRawMessage(message);
            break;
        case ConnectionType::SSL_ENCRIPTED:
            this->ssl_handler.send(pipes::buffer_view{(void*) message.data(), message.length()});
            break;
        default:
            logCritical("Invalid query connection type!");
            break;
    }
}
bool QueryClient::disconnect(const std::string &reason) {
if(!reason.empty()) {
Command cmd("disconnect");
cmd["reason"] = reason;
this->sendCommand(cmd);
}
return this->closeConnection(system_clock::now() + seconds(1));
}
bool QueryClient::closeConnection(const std::chrono::system_clock::time_point& flushTimeout) {
auto ownLock = dynamic_pointer_cast<QueryClient>(_this.lock());
if(!ownLock) return false;
unique_lock<std::recursive_mutex> handleLock(this->lock_packet_handle);
unique_lock<threads::Mutex> lock(this->closeLock);
bool flushing = flushTimeout.time_since_epoch().count() != 0;
if(this->state == ConnectionState::DISCONNECTED || (flushing && this->state == ConnectionState::DISCONNECTING)) return false;
this->state = flushing ? ConnectionState::DISCONNECTING : ConnectionState::DISCONNECTED;
if(this->readEvent) { //Attention dont trigger this within the read thread!
event_del_block(this->readEvent);
event_free(this->readEvent);
this->readEvent = nullptr;
}
if(this->server){
{
unique_lock channel_lock(this->server->channel_tree_lock);
this->server->unregisterClient(_this.lock(), "disconnected", channel_lock);
}
2019-09-14 14:22:16 +02:00
this->server->groups->disableCache(this->getClientDatabaseId());
this->server = nullptr;
}
if(flushing){
this->flushThread = new threads::Thread(THREAD_SAVE_OPERATIONS | THREAD_EXECUTE_LATER, [ownLock, flushTimeout](){
while(ownLock->state == ConnectionState::DISCONNECTING && flushTimeout > system_clock::now()){
{
lock_guard<threads::Mutex> l(ownLock->bufferLock);
if(ownLock->readQueue.empty() && ownLock->writeQueue.empty()) break;
}
usleep(10 * 1000);
}
if(ownLock->state == ConnectionState::DISCONNECTING) ownLock->disconnectFinal();
});
flushThread->name("Flush thread QC").execute();
} else {
threads::MutexLock l1(this->flushThreadLock);
handleLock.unlock();
lock.unlock();
if(this->flushThread){
threads::NegatedMutexLock l(this->closeLock);
this->flushThread->join();
}
disconnectFinal();
}
return true;
}
/*
 * Final, idempotent teardown.
 * Marks the client DISCONNECTED, releases the flush thread object, frees both libevent events,
 * shuts down and closes the socket, unregisters from the virtual server, clears the IO queues
 * and finally unregisters the connection from the query server.
 * A second call only logs an error.
 */
void QueryClient::disconnectFinal() {
    lock_guard<recursive_mutex> lock_tick(this->lock_query_tick);
    lock_guard<recursive_mutex> lock_handle(this->lock_packet_handle);
    threads::MutexLock lock_close(this->closeLock);
    threads::MutexLock lock_buffer(this->bufferLock);

    if(final_disconnected) {
        logError(LOG_QUERY, "Tried to disconnect a client twice!");
        return;
    }
    final_disconnected = true;
    this->state = ConnectionState::DISCONNECTED;

    {
        /* try-lock: closeConnection() may hold flushThreadLock while joining the flush thread */
        threads::MutexTryLock l(this->flushThreadLock);
        if(!!l) {
            if(this->flushThread) {
                this->flushThread->detach();
                delete this->flushThread; //Release the captured this lock
                this->flushThread = nullptr;
            }
        }
    }

    if(this->writeEvent) {
        event_del_block(this->writeEvent);
        event_free(this->writeEvent);
        this->writeEvent = nullptr;
    }
    if(this->readEvent) {
        event_del_block(this->readEvent);
        event_free(this->readEvent);
        this->readEvent = nullptr;
    }

    /* -1 is the "already closed" sentinel; 0 is a valid descriptor and must be closed as well */
    if(this->clientFd >= 0) {
        if(shutdown(this->clientFd, SHUT_RDWR) < 0)
            debugMessage(LOG_QUERY, "Could not shutdown query client socket! {} ({})", errno, strerror(errno));
        if(close(this->clientFd) < 0)
            debugMessage(LOG_QUERY, "Failed to close the query client socket! {} ({})", errno, strerror(errno));
        this->clientFd = -1;
    }

    if(this->server) {
        {
            unique_lock channel_lock(this->server->channel_tree_lock);
            this->server->unregisterClient(_this.lock(), "disconnected", channel_lock);
        }
        this->server->groups->disableCache(this->getClientDatabaseId());
        this->server = nullptr;
    }

    this->readQueue.clear();
    this->writeQueue.clear();

    if(this->handle)
        this->handle->unregisterConnection(dynamic_pointer_cast<QueryClient>(_this.lock()));
}
/*
 * Appends raw bytes to the write queue and arms the write event so handleMessageWrite() sends them.
 * The event is armed while bufferLock is still held. disconnectFinal() frees and clears writeEvent under
 * that same lock, so checking the pointer after releasing the lock could pass a freed event to event_add.
 */
void QueryClient::writeRawMessage(const std::string &message) {
    threads::MutexLock lock(this->bufferLock);
    this->writeQueue.push_back(message);
    if(this->writeEvent) event_add(this->writeEvent, nullptr);
}
void QueryClient::handleMessageWrite(int fd, short, void *) {
auto ownLock = _this.lock();
threads::MutexTryLock lock(this->bufferLock);
if(this->state == ConnectionState::DISCONNECTED) return;
if(!lock) {
if(this->writeEvent) event_add(this->writeEvent, nullptr);
return;
}
int writes = 0;
string buffer;
while(writes < 10 && !this->writeQueue.empty()) {
if(buffer.empty()) {
buffer = std::move(this->writeQueue.front());
this->writeQueue.pop_front();
}
auto length = send(fd, buffer.data(), buffer.length(), MSG_NOSIGNAL);
#ifdef DEBUG_TRAFFIC
debugMessage("Write " + to_string(buffer.length()));
hexDump((void *) buffer.data(), buffer.length());
#endif
if(length == -1) {
if (errno == EINTR || errno == EAGAIN) {
if(this->writeEvent)
event_add(this->writeEvent, nullptr);
return;
}
else {
logError(LOG_QUERY, "{} Failed to write message: {} ({} => {})", CLIENT_STR_LOG_PREFIX, length, errno, strerror(errno));
threads::Thread([=](){ ownLock->closeConnection(); }).detach();
return;
}
} else {
if(buffer.length() == length)
buffer = "";
else
buffer = buffer.substr(length);
}
writes++;
}
if(!buffer.empty())
this->writeQueue.push_front(buffer);
if(!this->writeQueue.empty() && this->writeEvent)
event_add(this->writeEvent, nullptr);
}
/*
 * libevent read callback: reads up to 1024 bytes and pushes them onto the read queue.
 * Processing is then scheduled on the handle's execute pool (tickIOMessageProgress).
 *
 * read() == 0 means the peer closed the connection. errno is not set in that case, so it must not be
 * inspected; a stale EAGAIN would otherwise keep the persistent read event firing forever.
 * read() < 0 with EINTR/EAGAIN/EWOULDBLOCK is transient and ignored. Any other error, like EOF,
 * removes the read event and closes the connection on a separate thread.
 */
void QueryClient::handleMessageRead(int fd, short, void *) {
    auto ownLock = dynamic_pointer_cast<QueryClient>(_this.lock());
    if(!ownLock) {
        logCritical(LOG_QUERY, "Could not get own lock!");
        return;
    }

    string buffer(1024, 0);
    auto length = read(fd, (void*) buffer.data(), buffer.length());
    if(length < 0 && (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK))
        return; /* transient; the read event is persistent and will fire again */

    if(length <= 0) {
        if(length == 0)
            logError(LOG_QUERY, "{} Connection closed by remote peer", CLIENT_STR_LOG_PREFIX);
        else
            logError(LOG_QUERY, "{} Failed to read! Code: {} errno: {} message: {}", CLIENT_STR_LOG_PREFIX, length, errno, strerror(errno));
        event_del_noblock(this->readEvent);
        threads::Thread(THREAD_SAVE_OPERATIONS, [ownLock](){ ownLock->closeConnection(); }).detach();
        return;
    }

    buffer.resize(length);
    {
        threads::MutexLock l(this->bufferLock);
        if(this->state == ConnectionState::DISCONNECTED)
            return;
#ifdef DEBUG_TRAFFIC
        /* log before the move below, otherwise an empty buffer would be dumped */
        debugMessage("Read " + to_string(buffer.length()));
        hexDump((void *) buffer.data(), buffer.length());
#endif
        this->readQueue.push_back(std::move(buffer));
    }

    if(this->handle)
        this->handle->executePool()->execute([ownLock]() {
            int counter = 0;
            while(ownLock->tickIOMessageProgress() && counter++ < 15);
        });
}
bool QueryClient::tickIOMessageProgress() {
lock_guard<recursive_mutex> lock(this->lock_packet_handle);
if(!this->handle || this->state == ConnectionState::DISCONNECTED || this->state == ConnectionState::DISCONNECTING) return false;
string message;
bool next = false;
{
threads::MutexLock l(this->bufferLock);
if(this->readQueue.empty()) return false;
message = std::move(this->readQueue.front());
this->readQueue.pop_front();
next |= this->readQueue.empty();
}
if(this->connectionType == ConnectionType::PLAIN) {
int count = 0;
while(this->handleMessage(pipes::buffer_view{(void*) message.data(), message.length()}) && count++ < 15) message = "";
next |= count == 15;
} else if(this->connectionType == ConnectionType::SSL_ENCRIPTED) {
this->ssl_handler.process_incoming_data(pipes::buffer_view{(void*) message.data(), message.length()});
} else if(this->connectionType == ConnectionType::UNKNOWN) {
if(config::query::sslMode != 0 && pipes::SSL::isSSLHeader(message)) {
this->initializeSSL();
/*
* - Content
* \x16
* -Version (1)
* \x03 \x00
* - length (2)
* \x00 \x04
*
* - Header
* \x00 -> hello request (3)
* \x05 -> length (4)
*/
//this->writeRawMessage(string("\x16\x03\x01\x00\x05\x00\x00\x00\x00\x00", 10));
} else {
this->connectionType = ConnectionType::PLAIN;
this->postInitialize();
}
next = true;
{
threads::MutexLock l(this->bufferLock);
this->readQueue.push_front(std::move(message));
}
}
return next;
}
/* Global instance handler, defined elsewhere; initializeSSL() obtains the query SSL context from it. */
extern InstanceHandler* serverInstance;
void QueryClient::initializeSSL() {
this->connectionType = ConnectionType::SSL_ENCRIPTED;
this->ssl_handler.direct_process(pipes::PROCESS_DIRECTION_OUT, true);
this->ssl_handler.direct_process(pipes::PROCESS_DIRECTION_IN, true);
this->ssl_handler.callback_data(std::bind(&QueryClient::handleMessage, this, placeholders::_1));
this->ssl_handler.callback_write(std::bind(&QueryClient::writeRawMessage, this, placeholders::_1));
this->ssl_handler.callback_initialized = std::bind(&QueryClient::postInitialize, this);
this->ssl_handler.callback_error([&](int code, const std::string& message) {
if(code == PERROR_SSL_ACCEPT) {
this->disconnect("invalid accept");
} else if(code == PERROR_SSL_TIMEOUT)
this->disconnect("invalid accept (timeout)");
else
logError(LOG_QUERY, "Got unknown ssl error ({} | {})", code, message);
});
{
auto context = serverInstance->sslManager()->getQueryContext();
auto options = make_shared<pipes::SSL::Options>();
options->type = pipes::SSL::SERVER;
options->context_method = TLS_method();
options->default_keypair({context->privateKey, context->certificate});
if(!this->ssl_handler.initialize(options)) {
logError(LOG_QUERY, "[{}] Failed to setup ssl!", CLIENT_STR_LOG_PREFIX);
}
}
}
/*
 * Appends `message` to the line buffer and processes at most one complete line from it.
 *
 * A line ends at the FIRST '\n'. A '\r' directly before it ("\r\n") or directly after it ("\n\r")
 * belongs to the terminator. (Searching for "\r\n" first, as before, could merge an earlier
 * '\n'-terminated line into the following command.)
 *
 * Returns true when a line was consumed; call again with an empty view to process further buffered lines.
 * Returns false when the client is disconnected or no complete line is buffered yet.
 * Lines that fail to parse are reported via notifyError and still count as consumed, so later
 * pipelined commands are not left stuck in the buffer.
 */
bool QueryClient::handleMessage(const pipes::buffer_view& message) {
    {
        threads::MutexLock l(this->closeLock);
        if(this->state == ConnectionState::DISCONNECTED)
            return false;
    }
#ifdef DEBUG_TRAFFIC
    debugMessage("Handling message " + to_string(message.length()));
    hexDump((void *) message.data(), message.length());
#endif

    string command;
    {
        this->lineBuffer += message.string();

        const auto newline = this->lineBuffer.find('\n');
        if(newline == string::npos) return false; /* no complete line yet */

        auto line_end = newline;
        if(line_end > 0 && this->lineBuffer[line_end - 1] == '\r') line_end--;    /* "\r\n" */
        auto consumed = newline + 1;
        if(consumed < this->lineBuffer.size() && this->lineBuffer[consumed] == '\r') consumed++; /* "\n\r" */

        command = this->lineBuffer.substr(0, line_end);
        this->lineBuffer.erase(0, consumed);
    }

    if(command.empty() || command.find_first_not_of(' ') == string::npos) { //Empty command
        logTrace(LOG_QUERY, "[{}:{}] Got query idle command (Empty command or spaces)", this->getLoggingPeerIp(), this->getPeerPort());
        CMD_RESET_IDLE; //if idle time over 5 min than connection drop
        return true;
    }

    logTrace(LOG_QUERY, "[{}:{}] Got query command {}", this->getLoggingPeerIp(), this->getPeerPort(), command);

    unique_ptr<Command> cmd;
    try {
        cmd = make_unique<Command>(Command::parse(pipes::buffer_view{(void*) command.data(), command.length()}, true, !ts::config::server::strict_ut8_mode));
    } catch(std::invalid_argument& ex) {
        this->notifyError(CommandResult{findError("parameter_convert"), ex.what()});
        return true; /* the malformed line has been consumed */
    } catch(command_malformed_exception& ex) {
        this->notifyError(CommandResult{findError("parameter_convert"), "invalid character @" + to_string(ex.index())});
        return true;
    } catch(std::exception& ex) {
        this->notifyError(CommandResult{ErrorType::VSError, ex.what()});
        return true;
    }

    try {
        this->handleCommandFull(*cmd);
    } catch(std::exception& ex) {
        this->notifyError(CommandResult{ErrorType::VSError, ex.what()});
    }
    return true;
}
/*
 * Serializes `command`, sends it terminated by the configured query newline sequence,
 * and traces the serialized form (without the terminator). The second parameter is unused.
 */
void QueryClient::sendCommand(const ts::Command &command, bool) {
    const auto serialized = command.build();
    this->writeMessage(serialized + config::query::newlineCharacter);
    logTrace(LOG_QUERY, "Send command {}", serialized);
}
void QueryClient::tick(const std::chrono::system_clock::time_point &time) {
ConnectedClient::tick(time);
}
void QueryClient::queryTick() {
lock_guard<recursive_mutex> lock_tick(this->lock_query_tick);
if(this->idleTimestamp.time_since_epoch().count() > 0 && system_clock::now() - this->idleTimestamp > minutes(5)){
debugMessage(LOG_QUERY, "Dropping client " + this->getLoggingPeerIp() + "|" + this->getDisplayName() + ". (Timeout)");
this->closeConnection(system_clock::now() + seconds(1));
}
if(this->connectionType == ConnectionType::UNKNOWN && system_clock::now() - milliseconds(500) > connectedTimestamp) {
this->connectionType = ConnectionType::PLAIN;
this->postInitialize();
}
}
/* No-op for query clients: nothing is sent and false is always returned. */
bool QueryClient::notifyChannelSubscribed(const deque<shared_ptr<BasicChannel>> &) { return false; }
/* No-op for query clients: nothing is sent and false is always returned. */
bool QueryClient::notifyChannelUnsubscribed(const deque<shared_ptr<BasicChannel>> &) { return false; }
/* Whitelisted query clients are always exempt from flood protection; otherwise the base-class decision applies. */
bool QueryClient::ignoresFlood() {
    if(this->whitelisted)
        return true;
    return ConnectedClient::ignoresFlood();
}