#include "./resolver.h"
#include "./response.h"

// NOTE(review): the names of the angle-bracket headers were not recoverable from the reviewed copy
// of this file. The list below is reconstructed from the symbols this file uses - confirm it
// against the build before merging.
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <event.h>
#include <unbound.h>
#include <unbound-event.h>

/* for TSDNS */
#include <fcntl.h>
#ifdef WIN32
    #include <ws2tcpip.h>
    #define SOCK_NONBLOCK (0)
    #define MSG_DONTWAIT (0)
#else
    #include <arpa/inet.h>
    #include <netinet/in.h>
    #include <sys/socket.h>
    #include <unistd.h>
#endif

using namespace std;
using namespace tc::dns;

Resolver::Resolver() { }

/* Tears everything down; safe to run on a resolver which has never been initialized. */
Resolver::~Resolver() {
    this->finalize();
}

/**
 * Creates the libevent base, starts the event loop thread and creates the unbound context.
 * An already running resolver is finalized first.
 *
 * @param error  receives a description when false is returned
 * @param hosts  when true the system hosts file is loaded into the unbound context
 * @param resolv when true the system resolv.conf is loaded into the unbound context
 * @return true on success. A failure to load hosts/resolv.conf is only logged and not fatal.
 */
bool Resolver::initialize(std::string &error, bool hosts, bool resolv) {
    if(this->event.loop_active)
        this->finalize();

    this->event.base = event_base_new();
    if(!this->event.base) {
        error = "failed to allocate event base";
        return false;
    }

    /* Only flag the loop as active once there is a base the loop could actually run on */
    this->event.loop_active = true;
    this->event.loop = std::thread(std::bind(&Resolver::event_loop_runner, this));

    this->ub_ctx = ub_ctx_create_event(this->event.base);
    if(!this->ub_ctx) {
        this->finalize();
        error = "failed to create ub context";
        return false;
    }
    auto context = static_cast<struct ub_ctx*>(this->ub_ctx);

    /* Add /etc/hosts */
    if(hosts) {
        const auto err = ub_ctx_hosts(context, nullptr);
        if(err != 0)
            cerr << "Failed to add hosts file: " << ub_strerror(err) << endl;
    }

    /* Add resolv.conf */
    if(resolv) {
        const auto err = ub_ctx_resolvconf(context, nullptr);
        if(err != 0)
            cerr << "Failed to add resolv.conf file: " << ub_strerror(err) << endl;
    }

    return true;
}

/**
 * Stops the event loop thread, aborts every pending request (their callbacks receive
 * ResultState::ABORT) and releases the unbound context and the event base.
 * May be called repeatedly and on a resolver which has never been initialized.
 */
void Resolver::finalize() {
    {
        /*
         * Flip the flag while holding the loop lock: the loop thread checks the flag under this lock
         * right before it starts waiting, so the notify below can not slip in between check and wait.
         */
        lock_guard lock{this->event.lock};
        this->event.loop_active = false;
    }

    if(this->event.base) {
        const auto ret = event_base_loopexit(this->event.base, nullptr);
        if(ret != 0)
            cerr << "Failed to exit event base loop: " << ret << endl;
    }

    this->event.condition.notify_one();
    if(this->event.loop.joinable())
        this->event.loop.join();

    /*
     * Take ownership of the pending requests while holding the lock, but release the lock before
     * destroying them: destroy_*_request() acquires request_lock itself.
     */
    unique_lock request_guard{this->request_lock};
    auto dns_list = std::move(this->dns_requests);
    auto tsdns_list = std::move(this->tsdns_requests);
    this->dns_requests.clear(); /* a moved-from container is only guaranteed to be valid, not empty */
    this->tsdns_requests.clear();
    request_guard.unlock();

    for(auto entry : dns_list) {
        if(this->ub_ctx)
            ub_cancel(this->ub_ctx, entry->ub_id);
        entry->callback(ResultState::ABORT, 0, nullptr);
        this->destroy_dns_request(entry);
    }

    for(auto entry : tsdns_list) {
        entry->callback(ResultState::ABORT, 0, "");
        this->destroy_tsdns_request(entry);
    }

    /* The unbound context references the event base, so it has to be released first */
    if(this->ub_ctx) {
        ub_ctx_delete(static_cast<struct ub_ctx*>(this->ub_ctx));
        this->ub_ctx = nullptr;
    }

    if(this->event.base) {
        event_base_free(this->event.base);
        this->event.base = nullptr;
    }
}

/**
 * Body of the event loop thread: sleeps until it gets notified, then dispatches the event base
 * until event_base_loop() returns, and repeats until the loop gets deactivated.
 *
 * NOTE(review): a notification which arrives after event_base_loop() returned but before the
 * thread waits again looks like it would be lost until the next notify - confirm.
 */
void Resolver::event_loop_runner() {
    while(true) {
        {
            unique_lock lock{this->event.lock};
            if(!this->event.loop_active)
                break;

            this->event.condition.wait(lock);
            if(!this->event.loop_active)
                break;
        }

        event_base_loop(this->event.base, 0);
    }
}

//Call only within the event loop!
void Resolver::destroy_dns_request(Resolver::dns_request *request) { assert(this_thread::get_id() == this->event.loop.get_id() || !this->event.loop_active); { lock_guard lock{this->request_lock}; this->dns_requests.erase(std::find(this->dns_requests.begin(), this->dns_requests.end(), request), this->dns_requests.end()); } if(request->register_event) { event_del_noblock(request->register_event); event_free(request->register_event); request->register_event = nullptr; } if(request->timeout_event) { event_del_noblock(request->timeout_event); event_free(request->timeout_event); request->timeout_event = nullptr; } delete request; } void Resolver::destroy_tsdns_request(Resolver::tsdns_request *request) { assert(this_thread::get_id() == this->event.loop.get_id() || !this->event.loop_active); { lock_guard lock{this->request_lock}; this->tsdns_requests.erase(std::find(this->tsdns_requests.begin(), this->tsdns_requests.end(), request), this->tsdns_requests.end()); } if(request->event_read) { event_del_noblock(request->event_read); event_free(request->event_read); request->event_read = nullptr; } if(request->event_write) { event_del_noblock(request->event_write); event_free(request->event_write); request->event_write = nullptr; } if(request->timeout_event) { event_del_noblock(request->timeout_event); event_free(request->timeout_event); request->timeout_event = nullptr; } if(request->socket > 0) { #ifndef WIN32 ::shutdown(request->socket, SHUT_RDWR); #endif ::close(request->socket); request->socket = 0; } delete request; } //--------------- DNS void Resolver::resolve_dns(const char *name, const rrtype::value &rrtype, const rrclass::value &rrclass, const std::chrono::microseconds& timeout, const dns_callback_t& callback) { if(!this->event.loop_active) { callback(ResultState::INITIALISATION_FAILED, 3, nullptr); return; } auto request = new dns_request{}; request->resolver = this; request->callback = callback; request->host = name; request->rrtype = rrtype; request->rrclass = 
rrclass; request->timeout_event = evtimer_new(this->event.base, [](evutil_socket_t, short, void *_request) { auto request = static_cast(_request); request->resolver->evtimer_dns_callback(request); }, request); request->register_event = evuser_new(this->event.base, [](evutil_socket_t, short, void *_request) { auto request = static_cast(_request); auto errc = ub_resolve_event(request->resolver->ub_ctx, request->host.c_str(), (int) request->rrtype, (int) request->rrclass, (void*) request, [](void* _request, int a, void* b, int c, int d, char* e) { auto request = static_cast(_request); request->resolver->ub_callback(request, a, b, c, d, e); }, &request->ub_id); if(errc != 0) { request->callback(ResultState::INITIALISATION_FAILED, errc, nullptr); request->resolver->destroy_dns_request(request); } }, request); if(!request->timeout_event || !request->register_event) { callback(ResultState::INITIALISATION_FAILED, 2, nullptr); if(request->timeout_event) event_free(request->timeout_event); if(request->register_event) event_free(request->register_event); delete request; return; } /* * Lock here all requests so the event loop cant already delete the request */ unique_lock rlock{this->request_lock}; { auto errc = event_add(request->timeout_event, nullptr); //TODO: Check for error evuser_trigger(request->register_event); } { auto seconds = chrono::floor(timeout); auto microseconds = chrono::ceil(timeout - seconds); timeval tv{seconds.count(), microseconds.count()}; auto errc = event_add(request->timeout_event, &tv); //TODO: Check for error } this->dns_requests.push_back(request); rlock.unlock(); /* Activate the event loop */ this->event.condition.notify_one(); } void Resolver::evtimer_dns_callback(tc::dns::Resolver::dns_request *request) { if(request->ub_id > 0) { auto errc = ub_cancel(this->ub_ctx, request->ub_id); if(errc != 0) { cerr << "Failed to cancel DNS request " << request->ub_id << " after timeout (" << errc << "/" << ub_strerror(errc) << ")!" 
<< endl; } } request->callback(ResultState::DNS_TIMEOUT, 0, nullptr); this->destroy_dns_request(request); } void Resolver::ub_callback(dns_request* request, int rcode, void *packet, int packet_length, int sec, char *why_bogus) { if(rcode != 0) { request->callback(ResultState::DNS_FAIL, rcode, nullptr); } else { auto callback = request->callback; auto data = std::unique_ptr(new DNSResponse{(uint8_t) sec, why_bogus, packet, (size_t) packet_length}); callback(ResultState::SUCCESS, 0, std::move(data)); } this->destroy_dns_request(request); } thread_local std::vector visited_links; std::string DNSResponseBuffer::parse_dns_dn(std::string &error, size_t &index, bool allow_compression) { if(allow_compression) { visited_links.clear(); visited_links.reserve(8); if(std::find(visited_links.begin(), visited_links.end(), index) != visited_links.end()) { error = "circular link detected"; return ""; } visited_links.push_back(index); } error.clear(); string result; result.reserve(256); //Max length is 253 while(true) { if(index + 1 > this->length) { error = "truncated data (missing code)"; goto exit; } auto code = this->buffer[index++]; if(code == 0) break; if((code >> 6U) == 3) { if(!allow_compression) { error = "found link, but links are not allowed"; goto exit; } auto lower_addr = this->buffer[index++]; if(index + 1 > this->length) { error = "truncated data (missing lower link address)"; goto exit; } size_t addr = ((code & 0x3FU) << 8U) | lower_addr; if(addr >= this->length) { error = "invalid link address"; goto exit; } auto tail = this->parse_dns_dn(error, addr, true); if(!error.empty()) goto exit; if(!result.empty()) result += "." 
+ tail; else result = tail; break; } else { if(code > 63) { error = "max domain label length is 63 characters"; goto exit; } if(!result.empty()) result += "."; if(index + code >= this->length) { error = "truncated data (domain label)"; goto exit; } result.append((const char*) (this->buffer + index), code); index += code; } } exit: if(allow_compression) visited_links.pop_back(); return result; } DNSResponseBuffer::~DNSResponseBuffer() { ::free(this->buffer); } DNSResponse::DNSResponse(uint8_t secure_state, const char* bogus, void *packet, size_t size) { this->bogus = bogus ? std::string{bogus} : std::string{"packet is secure"}; this->secure_state = secure_state; this->packet = make_shared(); this->packet->buffer = (uint8_t*) malloc(size); this->packet->length = size; memcpy(this->packet->buffer, packet, size); } response::DNSHeader DNSResponse::header() const { return response::DNSHeader{this}; } bool DNSResponse::parse(std::string &error) { if(this->is_parsed) { error = this->parse_error; return error.empty(); } error.clear(); this->is_parsed = true; auto header = this->header(); size_t index = 12; /* 12 bits for the header */ { auto count = header.query_count(); this->parsed_queries.reserve(count); for(size_t idx = 0; idx < count; idx++) { auto dn = this->packet->parse_dns_dn(error, index, true); if(!error.empty()) { error = "failed to parse query " + to_string(idx) + " dn: " + error; // NOLINT(performance-inefficient-string-concatenation) goto error_exit; } if(index + 4 > this->packet_length()) { error = "truncated data for query " + to_string(index); goto error_exit; } auto type = (rrtype::value) ntohs(*(uint16_t*) (this->packet->buffer + index)); index += 2; auto klass = (rrclass::value) ntohs(*(uint16_t*) (this->packet->buffer + index)); index += 2; this->parsed_queries.emplace_back(new response::DNSQuery{dn, type, klass}); } } { auto count = header.answer_count(); this->parsed_answers.reserve(count); for(size_t idx = 0; idx < count; idx++) { 
this->parsed_answers.push_back(this->parse_rr(error, index, true)); if(!error.empty()) { error = "failed to parse answer " + to_string(idx) + ": " + error; // NOLINT(performance-inefficient-string-concatenation) goto error_exit; } } } { auto count = header.authority_count(); this->parsed_authorities.reserve(count); for(size_t idx = 0; idx < count; idx++) { this->parsed_authorities.push_back(this->parse_rr(error, index, true)); if(!error.empty()) { error = "failed to parse authority " + to_string(idx) + ": " + error; // NOLINT(performance-inefficient-string-concatenation) goto error_exit; } } } { auto count = header.additional_count(); this->parsed_additionals.reserve(count); for(size_t idx = 0; idx < count; idx++) { this->parsed_additionals.push_back(this->parse_rr(error, index, true)); if(!error.empty()) { error = "failed to parse additional " + to_string(idx) + ": " + error; // NOLINT(performance-inefficient-string-concatenation) goto error_exit; } } } return true; error_exit: this->parsed_queries.clear(); this->parsed_answers.clear(); this->parsed_authorities.clear(); this->parsed_additionals.clear(); return false; } std::shared_ptr DNSResponse::parse_rr(std::string &error, size_t &index, bool allow_compressed) { auto dn = this->packet->parse_dns_dn(error, index, allow_compressed); if(!error.empty()) { error = "failed to parse rr dn: " + error; // NOLINT(performance-inefficient-string-concatenation) return nullptr; } if(index + 10 > this->packet_length()) { error = "truncated header"; return nullptr; } auto type = (rrtype::value) ntohs(*(uint16_t*) (this->packet->buffer + index)); index += 2; auto klass = (rrclass::value) ntohs(*(uint16_t*) (this->packet->buffer + index)); index += 2; auto ttl = ntohl(*(uint32_t*) (this->packet->buffer + index)); index += 4; auto payload_length = ntohs(*(uint16_t*) (this->packet->buffer + index)); index += 2; if(index + payload_length > this->packet_length()) { error = "truncated body"; return nullptr; } auto response = 
std::shared_ptr(new response::DNSResourceRecords{this->packet, index, payload_length, ttl, dn, type, klass}); index += payload_length; return response; } //---------------------- TSDNS void Resolver::resolve_tsdns(const char *query, const sockaddr_storage& server_address, const std::chrono::microseconds& timeout, const tc::dns::Resolver::tsdns_callback_t &callback) { /* create the socket */ auto socket = ::socket(server_address.ss_family, SOCK_STREAM | SOCK_NONBLOCK, 0); if(socket <= 0) { callback(ResultState::INITIALISATION_FAILED, -1, "failed to allocate socket: " + to_string(errno) + "/" + strerror(errno)); return; } #ifdef WIN32 u_long enabled = 0; auto non_block_rs = ioctlsocket(this->_socket, FIONBIO, &enabled); if (non_block_rs != NO_ERROR) { ::close(socket); callback(ResultState::INITIALISATION_FAILED, -2, "failed to enable nonblock: " + to_string(errno) + "/" + strerror(errno)); return; } #endif int opt = 1; setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(int)); auto request = new tsdns_request{}; request->resolver = this; request->callback = callback; request->socket = socket; request->timeout_event = evtimer_new(this->event.base, [](evutil_socket_t, short, void *_request) { auto request = static_cast(_request); request->resolver->evtimer_tsdns_callback(request); }, request); request->event_read = event_new(this->event.base, socket, EV_READ | EV_PERSIST, [](evutil_socket_t, short, void *_request){ auto request = static_cast(_request); request->resolver->event_tsdns_read(request); }, request); request->event_write = event_new(this->event.base, socket, EV_WRITE, [](evutil_socket_t, short, void *_request){ auto request = static_cast(_request); request->resolver->event_tsdns_write(request); }, request); if(!request->timeout_event || !request->event_write || !request->event_read) { callback(ResultState::INITIALISATION_FAILED, -3, ""); this->destroy_tsdns_request(request); return; } request->write_buffer = query; request->write_buffer += "\n\r\r\r\n"; 
int result = ::connect(socket, reinterpret_cast (&server_address), sizeof(server_address)); if (result < 0) { #ifdef WIN32 auto error = WSAGetLastError(); if(error != WSAEWOULDBLOCK) { /* * TODO! wchar_t *s = nullptr; FormatMessageW( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, nullptr, error, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPWSTR)&s, 0, nullptr ); fprintf(stdout, "Connect failed with code %d. Error: %ld/%S\n", result, error, s); LocalFree(s); */ callback(ResultState::TSDNS_CONNECTION_FAIL, -1, "Failed to connect"); this->destroy_tsdns_request(request); } #else if(errno != EINPROGRESS) { callback(ResultState::TSDNS_CONNECTION_FAIL, -1, "Failed to connect with code: " + to_string(errno) + "/" + strerror(errno)); this->destroy_tsdns_request(request); return; } #endif } event_add(request->event_write, nullptr); event_add(request->event_read, nullptr); { auto seconds = chrono::floor(timeout); auto microseconds = chrono::ceil(timeout - seconds); timeval tv{seconds.count(), microseconds.count()}; auto errc = event_add(request->timeout_event, &tv); //TODO: Check for error } { lock_guard lock{this->request_lock}; this->tsdns_requests.push_back(request); } /* Activate the event loop */ this->event.condition.notify_one(); } void Resolver::evtimer_tsdns_callback(Resolver::tsdns_request *request) { request->callback(ResultState::DNS_TIMEOUT, 0, ""); this->destroy_tsdns_request(request); } void Resolver::event_tsdns_read(Resolver::tsdns_request *request) { int64_t buffer_length = 1024; char buffer[1024]; buffer_length = recv(request->socket, buffer, (int) buffer_length, MSG_DONTWAIT); if(buffer_length < 0) { #ifdef WIN32 auto error = WSAGetLastError(); if(error != WSAEWOULDBLOCK) return; request->callback(ResultState::TSDNS_CONNECTION_FAIL, -2, "read failed: " + to_string(error)); #else if(errno == EAGAIN) return; request->callback(ResultState::TSDNS_CONNECTION_FAIL, -2, "read failed: " + to_string(errno) + "/" + 
strerror(errno)); #endif this->destroy_tsdns_request(request); return; } else if(buffer_length == 0) { if(request->read_buffer.empty()) { request->callback(ResultState::TSDNS_EMPTY_RESPONSE, 0, ""); } else { request->callback(ResultState::SUCCESS, 0, request->read_buffer); } this->destroy_tsdns_request(request); return; } lock_guard lock{request->buffer_lock}; request->read_buffer.append(buffer, buffer_length); } void Resolver::event_tsdns_write(Resolver::tsdns_request *request) { lock_guard lock{request->buffer_lock}; if(request->write_buffer.empty()) return; auto written = send(request->socket, request->write_buffer.data(), min(request->write_buffer.size(), 1024UL), MSG_DONTWAIT); if(written < 0) { #ifdef WIN32 auto error = WSAGetLastError(); if(error != WSAEWOULDBLOCK) return; request->callback(ResultState::TSDNS_CONNECTION_FAIL, -4, "write failed: " + to_string(error)); #else if(errno == EAGAIN) return; request->callback(ResultState::TSDNS_CONNECTION_FAIL, -4, "write failed: " + to_string(errno) + "/" + strerror(errno)); #endif this->destroy_tsdns_request(request); return; } else if(written == 0) { request->callback(ResultState::TSDNS_CONNECTION_FAIL, -5, "remote peer hang up"); this->destroy_tsdns_request(request); return; } if(written == request->write_buffer.size()) request->write_buffer.clear(); else { request->write_buffer = request->write_buffer.substr(written); event_add(request->event_write, nullptr); } }