#include #include #include "QueryClient.h" #include #include "src/InstanceHandler.h" #include #include #include using namespace std; using namespace std::chrono; using namespace ts; using namespace ts::server; using namespace ts::server::query; #if defined(TCP_CORK) && !defined(TCP_NOPUSH) #define TCP_NOPUSH TCP_CORK #endif //#define DEBUG_TRAFFIC NetworkBuffer* NetworkBuffer::allocate(size_t length) { auto result = (NetworkBuffer*) malloc(length + sizeof(NetworkBuffer)); new (result) NetworkBuffer{}; result->length = length; result->ref_count++; return result; } NetworkBuffer* NetworkBuffer::ref() { this->ref_count++; return this; } void NetworkBuffer::unref() { if(--this->ref_count == 0) { this->NetworkBuffer::~NetworkBuffer(); ::free(this); } } QueryClient::QueryClient(QueryServer* handle, int sockfd) : ConnectedClient(handle->sql, nullptr), handle(handle), client_file_descriptor(sockfd) { memtrack::allocated(this); int enabled = 1; int disabled = 0; setsockopt(sockfd, SOL_SOCKET, SO_KEEPALIVE, &enabled, sizeof(enabled)); if(setsockopt(sockfd, IPPROTO_TCP, TCP_NOPUSH, &disabled, sizeof disabled) < 0) { logError(this->getServerId(), "[Query] Could not disable nopush for {} ({}/{})", CLIENT_STR_LOG_PREFIX_(this), errno, strerror(errno)); } if(setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, &enabled, sizeof enabled) < 0) { logError(this->getServerId(), "[Query] Could not disable no delay for {} ({}/{})", CLIENT_STR_LOG_PREFIX_(this), errno, strerror(errno)); } this->state = ConnectionState::CONNECTED; connectedTimestamp = system_clock::now(); this->resetEventMask(); } void QueryClient::initialize_weak_reference(const std::shared_ptr &self) { ConnectedClient::initialize_weak_reference(self); this->command_queue = std::make_unique( serverInstance->server_command_executor(), std::make_unique(dynamic_pointer_cast(self)) ); this->event_read = event_new(this->handle->event_io_loop, this->client_file_descriptor, EV_READ | EV_PERSIST, [](int a, short b, void* c){ ((QueryClient *) c)->handle_event_read(a, b, c); }, this); this->event_write = event_new(this->handle->event_io_loop, this->client_file_descriptor, EV_WRITE, [](int a, short b, void* c){ ((QueryClient *) c)->handle_event_write(a, b, c); }, this); } QueryClient::~QueryClient() { if(this->line_buffer) { free(this->line_buffer); this->line_buffer = nullptr; } this->ssl_handler.finalize(); while(this->write_buffer_head) { auto buffer = std::exchange(this->write_buffer_head, this->write_buffer_head->next_buffer); buffer->unref(); } this->write_buffer_tail = nullptr; memtrack::freed(this); } void QueryClient::preInitialize() { this->properties()[property::CLIENT_TYPE] = ClientType::CLIENT_QUERY; this->properties()[property::CLIENT_TYPE_EXACT] = ClientType::CLIENT_QUERY; this->properties()[property::CLIENT_UNIQUE_IDENTIFIER] = "UnknownQuery"; this->properties()[property::CLIENT_NICKNAME] = string() + "ServerQuery#" + this->getLoggingPeerIp() + "/" + to_string(this->getPeerPort()); DatabaseHelper::assignDatabaseId(this->sql, this->getServerId(), this->ref()); if(ts::config::query::sslMode == 0) { this->connectionType = ConnectionType::PLAIN; this->postInitialize(); } } void QueryClient::postInitialize() { lock_guard command_lock(this->command_lock); this->connectTimestamp = system_clock::now(); this->properties()[property::CLIENT_LASTCONNECTED] = duration_cast(this->connectTimestamp.time_since_epoch()).count(); if(ts::config::query::sslMode == 1 && this->connectionType != ConnectionType::SSL_ENCRYPTED) { command_result error{error::failed_connection_initialisation, "Please use a SSL encryption!"}; this->notifyError(error); error.release_data(); this->disconnect("Please us a SSL encryption for more security.\nThe server denies also all other connections!"); return; } send_message(config::query::motd); assert(this->handle); if(this->handle->ip_whitelist) { this->whitelisted = this->handle->ip_whitelist->contains(this->remote_address); } else { this->whitelisted = false; } if(!this->whitelisted && this->handle->ip_blacklist) { assert(this->handle->ip_blacklist); if(this->handle->ip_blacklist->contains(this->remote_address)) { Command cmd("error"); auto err = findError("client_login_not_permitted"); cmd["id"] = err.errorId; cmd["msg"] = err.message; cmd["extra_msg"] = "You're not permitted to use the query interface! (Your blacklisted)"; this->sendCommand(cmd); this->disconnect("blacklisted"); return;; } } debugMessage(LOG_QUERY, "Got new query client from {}. Whitelisted: {}", this->getLoggingPeerIp(), this->whitelisted); if(!this->whitelisted) { std::lock_guard connect_lock{this->handle->connected_clients_mutex}; auto address = this->getPeerIp(); if(this->handle->client_connect_bans.count(address) > 0) { auto ban = this->handle->client_connect_bans[address]; Command cmd("error"); auto err = findError("server_connect_banned"); cmd["id"] = err.errorId; cmd["msg"] = err.message; cmd["extra_msg"] = "you may retry in " + to_string(duration_cast(ban - system_clock::now()).count()) + " seconds"; this->sendCommand(cmd); this->close_connection(std::chrono::system_clock::now() + std::chrono::seconds{1}); } } this->update_client_needed_permissions(); } void QueryClient::send_message(const std::string_view& message) { if(this->state == ConnectionState::DISCONNECTED || !this->handle) { return; } if(this->connectionType == ConnectionType::PLAIN) { this->enqueue_write_buffer(message); } else if(this->connectionType == ConnectionType::SSL_ENCRYPTED) { this->ssl_handler.send(pipes::buffer_view{(void*) message.data(), message.length()}); } else { logCritical(LOG_GENERAL, "Invalid query connection type to write to!"); } } bool QueryClient::disconnect(const std::string &reason) { if(!reason.empty()) { ts::command_builder notify{"disconnect"}; notify.put_unchecked(0, "reason", reason); this->sendCommand(notify, false); } return this->close_connection(system_clock::now() + seconds(1)); } bool QueryClient::close_connection(const std::chrono::system_clock::time_point& flush_timeout_) { this->flush_timeout = flush_timeout_; bool should_flush = std::chrono::system_clock::now() < flush_timeout; { std::lock_guard network_lock{this->network_mutex}; if(this->event_read) { event_del_noblock(this->event_read); } if(!should_flush && this->event_write) { event_del_noblock(this->event_write); } } if(should_flush) { this->handle->enqueue_query_disconnect(dynamic_pointer_cast(this->ref())); } else { this->handle->enqueue_query_connection_close(dynamic_pointer_cast(this->ref())); } return true; } void QueryClient::execute_final_disconnect() { assert(!this->server); { std::unique_lock network_lock{this->network_mutex}; auto event_read_ = std::exchange(this->event_read, nullptr); auto event_write_ = std::exchange(this->event_write, nullptr); network_lock.unlock(); if(event_read_) { event_del_block(event_read_); event_free(event_read_); } if(event_write_) { event_del_block(event_write_); event_free(event_write_); } } if(this->client_file_descriptor > 0) { if(shutdown(this->client_file_descriptor, SHUT_RDWR) < 0) { debugMessage(LOG_QUERY, "Could not shutdown query client socket! {} ({})", errno, strerror(errno)); } if(close(this->client_file_descriptor) < 0) { debugMessage(LOG_QUERY, "Failed to close the query client socket! {} ({})", errno, strerror(errno)); } this->client_file_descriptor = -1; } } void QueryClient::enqueue_write_buffer(const std::string_view &message) { auto buffer = NetworkBuffer::allocate(message.length()); memcpy(buffer->data(), message.data(), message.length()); { std::lock_guard buffer_lock{this->network_mutex}; if(this->event_write) { *this->write_buffer_tail = buffer; this->write_buffer_tail = &buffer->next_buffer; event_add(this->event_write, nullptr); return; } } /* We don't have a network write event. Drop the buffer. */ buffer->unref(); } void QueryClient::handle_event_write(int fd, short, void *) { NetworkBuffer* write_buffer{nullptr}; { std::lock_guard buffer_lock{this->network_mutex}; if(this->write_buffer_head) { write_buffer = this->write_buffer_head->ref(); } } while(write_buffer) { auto length = send(fd, (const char*) write_buffer->data() + write_buffer->bytes_written, write_buffer->length - write_buffer->bytes_written, MSG_NOSIGNAL); if(length == -1) { write_buffer->unref(); if (errno == EINTR || errno == EAGAIN) { std::lock_guard event_lock{this->network_mutex}; if(this->event_write) { event_add(this->event_write, nullptr); } } else { logError(LOG_QUERY, "{} Failed to send message ({}/{}). Closing connection.", CLIENT_STR_LOG_PREFIX, errno, strerror(errno)); this->close_connection(std::chrono::system_clock::time_point{}); { std::unique_lock event_lock{this->network_mutex}; auto event_write_ = std::exchange(this->event_write, nullptr); event_lock.unlock(); if(event_write_) { event_del_noblock(event_write_); event_free(event_write_); } } /* the "this" ptr might be dangling now since we can't join the write event any more! */ } return; } write_buffer->bytes_written += length; assert(write_buffer->bytes_written <= write_buffer->length); if(write_buffer->bytes_written == write_buffer->length) { /* * Even though we might free the buffer here we could still use the pointer for comparison. * If the buffer is still the head buffer it should not have been deallocated since * the queue itself holds a reference. */ write_buffer->unref(); /* Buffer must be freed, but we don't want do that while holding the lock */ NetworkBuffer* cleanup_buffer{nullptr}; /* Buffer finished, sending next one */ { std::lock_guard buffer_lock{this->network_mutex}; if(this->write_buffer_head == write_buffer) { /* Buffer successfully send. Sending the next one. */ cleanup_buffer = this->write_buffer_head; this->write_buffer_head = this->write_buffer_head->next_buffer; if(this->write_buffer_head) { /* we've a next buffer */ write_buffer = this->write_buffer_head->ref(); } else { assert(this->write_buffer_tail == &write_buffer->next_buffer); write_buffer = nullptr; this->write_buffer_tail = &this->write_buffer_head; } } else if(this->write_buffer_head) { /* Our buffer got dropped (who knows why). Just send the next one. */ write_buffer = this->write_buffer_head->ref(); } else { /* * Nothing more to write. */ write_buffer = nullptr; } } if(cleanup_buffer) { cleanup_buffer->unref(); } } } /* This state should only be reached when no more messages are pending to write */ assert(!write_buffer); if(this->state == ConnectionState::DISCONNECTING) { this->handle->enqueue_query_connection_close(dynamic_pointer_cast(this->ref())); } } void QueryClient::handle_event_read(int fd, short, void *) { static const size_t kReadBufferLength = 1024 * 8; uint8_t buffer[kReadBufferLength]; auto length = read(fd, buffer, kReadBufferLength); if(length <= 0){ if(errno == EINTR || errno == EAGAIN) { /* Nothing to read */ return; } if(length == 0 && errno == 0) { logMessage(LOG_QUERY, "{} Connection closed. Client disconnected.", CLIENT_STR_LOG_PREFIX); } else { logMessage(LOG_QUERY, "{} Failed to received message ({}/{}). Closing connection.", CLIENT_STR_LOG_PREFIX, errno, strerror(errno)); } this->close_connection(std::chrono::system_clock::time_point{}); { std::unique_lock network_lock{this->network_mutex}; auto event_read_ = std::exchange(this->event_read, nullptr); network_lock.unlock(); if(event_read_) { event_del_noblock(event_read_); event_free(event_read_); } } /* the "this" ptr might be dangling now since we can't join the read event any more! */ return; } this->handle_message_read(std::string_view{(const char *) buffer, (size_t) length}); } inline bool is_ssl_handshake_header(const std::string_view& buffer) { if(buffer.length() < 0x05) return false; //Header too small! if(buffer[0] != 0x16) return false; //recordType=handshake if(buffer[1] < 1 || buffer[1] > 3) return false; //SSL version if(buffer[2] < 1 || buffer[2] > 3) return false; //TLS version return true; } void QueryClient::handle_message_read(const std::string_view &message) { if(this->state >= ConnectionState::DISCONNECTING) { /* We don't need to handle any messages. */ return; } switch (this->connectionType) { case ConnectionType::PLAIN: this->handle_decoded_message(message); break; case ConnectionType::SSL_ENCRYPTED: this->ssl_handler.process_incoming_data(pipes::buffer_view{message.data(), message.length()}); break; case ConnectionType::UNKNOWN: { if(config::query::sslMode != 0 && is_ssl_handshake_header(message)) { this->initializeSSL(); /* * - Content * \x16 * -Version (1) * \x03 \x00 * - length (2) * \x00 \x04 * * - Header * \x00 -> hello request (3) * \x05 -> length (4) */ this->ssl_handler.process_incoming_data(pipes::buffer_view{message.data(), message.length()}); } else { this->connectionType = ConnectionType::PLAIN; this->postInitialize(); this->handle_decoded_message(message); } } } } inline size_t line_buffer_size(size_t target_size) { return target_size; } void QueryClient::handle_decoded_message(const std::string_view &message) { if(this->line_buffer_length + message.length() > this->line_buffer_capacity) { this->line_buffer_capacity = line_buffer_size(this->line_buffer_length + message.length()); auto new_buffer = (char*) malloc(this->line_buffer_capacity); memcpy(new_buffer, this->line_buffer, this->line_buffer_length); free(this->line_buffer); this->line_buffer = new_buffer; } memcpy(this->line_buffer + this->line_buffer_length, message.data(), message.length()); this->line_buffer_length += message.length(); /* * Now we're analyzing the line buffer. * Note: Telnet commands will be executed as empty (idle commands) */ size_t command_start_index{0}, command_end_index, command_start_next; for(; this->line_buffer_scan_offset < this->line_buffer_length; this->line_buffer_scan_offset++) { if(this->line_buffer[this->line_buffer_scan_offset] == '\n') { command_end_index = this->line_buffer_scan_offset; command_start_next = this->line_buffer_scan_offset + 1; } else if((uint8_t) this->line_buffer[this->line_buffer_scan_offset] == 255) { if(this->line_buffer_scan_offset + 3 > this->line_buffer_length) { /* We don't have enough space to fill the telnet command so we use that as the new scan offset */ command_end_index = this->line_buffer_scan_offset; command_start_next = this->line_buffer_scan_offset; if(command_start_next == command_end_index) { /* We've no prepended data so we're waiting for the tcp command. Loop finished. */ break; } } else { command_end_index = this->line_buffer_scan_offset; command_start_next = this->line_buffer_scan_offset + 3; logTrace(LOG_QUERY, "[{}:{}] Received telnet command, code = {}, option = {}", this->getLoggingPeerIp(), this->getPeerPort(), (uint8_t) this->line_buffer[this->line_buffer_scan_offset + 1], (uint8_t) this->line_buffer[this->line_buffer_scan_offset + 2] ); } } else { continue; } /* No need to check for the upper bounds since there will be \n or 255 before the end of the line */ while(this->line_buffer[command_start_index] == '\r') { command_start_index++; } while(command_end_index > command_start_index + 1 && this->line_buffer[command_end_index - 1] == '\r') { command_end_index--; } std::string_view command_view{this->line_buffer + command_start_index, command_end_index - command_start_index}; this->command_queue->enqueue_command_string(command_view); command_start_index = command_start_next; if(this->line_buffer_scan_offset + 1 < command_start_next) { this->line_buffer_scan_offset = command_start_next - 1; } } if(command_start_index > 0) { if(command_start_index == this->line_buffer_length) { this->line_buffer_length = 0; this->line_buffer_scan_offset = 0; } else { assert(this->line_buffer_length > command_start_index); assert(this->line_buffer_scan_offset > command_start_index); memcpy(this->line_buffer, this->line_buffer + command_start_index, this->line_buffer_length - command_start_index); this->line_buffer_length -= command_start_index; this->line_buffer_scan_offset -= command_start_index; } } if(this->line_buffer_length > ts::config::query::max_line_buffer) { this->line_buffer_length = 0; /* Buffer will be truncated later */ logWarning(LOG_QUERY, "[{}] Client exceeded max query line buffer size. Disconnecting client."); this->disconnect("line buffer length exceeded"); } /* Shrink if possible */ if(this->line_buffer_capacity > 8 * 1024 && this->line_buffer_length < 8 * 1024) { this->line_buffer_capacity = 8 * 1024; auto new_buffer = (char*) malloc(this->line_buffer_capacity); memcpy(new_buffer, this->line_buffer, this->line_buffer_length); free(this->line_buffer); this->line_buffer = new_buffer; } } void QueryClient::initializeSSL() { this->connectionType = ConnectionType::SSL_ENCRYPTED; this->ssl_handler.direct_process(pipes::PROCESS_DIRECTION_OUT, true); this->ssl_handler.direct_process(pipes::PROCESS_DIRECTION_IN, true); this->ssl_handler.callback_data([&](const pipes::buffer_view& buffer) { this->handle_decoded_message(std::string_view{buffer.data_ptr(), buffer.length()}); }); this->ssl_handler.callback_write([&](const pipes::buffer_view& buffer) { this->enqueue_write_buffer(std::string_view{buffer.data_ptr(), buffer.length()}); }); this->ssl_handler.callback_initialized = std::bind(&QueryClient::postInitialize, this); this->ssl_handler.callback_error([&](int code, const std::string& message) { if(code == PERROR_SSL_ACCEPT) { this->disconnect("invalid accept"); } else if(code == PERROR_SSL_TIMEOUT) this->disconnect("invalid accept (timeout)"); else logError(LOG_QUERY, "Got unknown ssl error ({} | {})", code, message); }); { auto context = serverInstance->sslManager()->getQueryContext(); auto options = make_shared(); options->type = pipes::SSL::SERVER; options->context_method = TLS_method(); options->default_keypair({context->privateKey, context->certificate}); if(!this->ssl_handler.initialize(options)) { logError(LOG_QUERY, "[{}] Failed to setup ssl!", CLIENT_STR_LOG_PREFIX); } } } void QueryClient::sendCommand(const ts::Command &command, bool) { auto cmd = command.build(); send_message(cmd + config::query::newlineCharacter); logTrace(LOG_QUERY, "Send command {}", cmd); } void QueryClient::sendCommand(const ts::command_builder &command, bool) { send_message(command.build() + config::query::newlineCharacter); logTrace(LOG_QUERY, "Send command {}", command.build()); } void QueryClient::tick_server(const std::chrono::system_clock::time_point &time) { ConnectedClient::tick_server(time); } /* FIXME: TODO: Forbit this while beeing in finalDisconnect! */ void QueryClient::tick_query() { if(this->idleTimestamp.time_since_epoch().count() > 0 && system_clock::now() - this->idleTimestamp > minutes(5)){ debugMessage(LOG_QUERY, "Dropping client " + this->getLoggingPeerIp() + "|" + this->getDisplayName() + ". (Timeout)"); this->close_connection(system_clock::now() + seconds(1)); } if(this->connectionType == ConnectionType::UNKNOWN && system_clock::now() - milliseconds(500) > connectedTimestamp) { this->connectionType = ConnectionType::PLAIN; this->postInitialize(); } if(this->flush_timeout.time_since_epoch().count() > 0 && std::chrono::system_clock::now() > this->flush_timeout) { this->handle->enqueue_query_connection_close(dynamic_pointer_cast(this->ref())); } } bool QueryClient::ignoresFlood() { return this->whitelisted || ConnectedClient::ignoresFlood(); } void QueryClient::disconnect_from_virtual_server(const std::string& reason) { std::lock_guard command_lock{this->command_lock}; auto old_server = std::exchange(this->server, nullptr); if(old_server) { { std::unique_lock tree_lock(old_server->channel_tree_lock); if(this->currentChannel) { old_server->client_move(this->ref(), nullptr, nullptr, "", ViewReasonId::VREASON_USER_ACTION, false, tree_lock); } old_server->unregisterClient(this->ref(), reason, tree_lock); } { std::lock_guard channel_lock{this->channel_lock}; this->channels->reset(); this->currentChannel = nullptr; } old_server->groups->disableCache(this->getClientDatabaseId()); this->loadDataForCurrentServer(); } serverInstance->getGroupManager()->enableCache(this->getClientDatabaseId()); }