diff --git a/native/dns/CMakeLists.txt b/native/dns/CMakeLists.txt index 0f6ec25..54421cf 100644 --- a/native/dns/CMakeLists.txt +++ b/native/dns/CMakeLists.txt @@ -1,6 +1,6 @@ set(MODULE_NAME "teaclient_dns") -set(SOURCE_FILES ${SOURCE_FILES} src/resolver.cpp src/types.cpp src/response.cpp utils.cpp) +set(SOURCE_FILES ${SOURCE_FILES} src/resolver.cpp src/types.cpp src/response.cpp utils.cpp src/resolver_linux.cpp) find_package(Libevent REQUIRED) include_directories(${LIBEVENT_INCLUDE_DIRS}) diff --git a/native/dns/src/resolver.cpp b/native/dns/src/resolver.cpp index f72c218..306649a 100644 --- a/native/dns/src/resolver.cpp +++ b/native/dns/src/resolver.cpp @@ -23,130 +23,12 @@ using namespace std; using namespace tc::dns; -Resolver::Resolver() { - -} +Resolver::Resolver() { } Resolver::~Resolver() { this->finalize(); } -bool Resolver::initialize(std::string &error, bool hosts, bool resolv) { - if(this->event.loop_active) - this->finalize(); - - this->event.loop_active = true; - this->event.base = event_base_new(); - if(!this->event.base) { - error = "failed to allcoate event base"; - return false; - } - this->event.loop = std::thread(std::bind(&Resolver::event_loop_runner, this)); - - this->ub_ctx = ub_ctx_create_event(this->event.base); - if(!this->ub_ctx) { - this->finalize(); - error = "failed to create ub context"; - return false; - } - - /* Add /etc/hosts */ - auto err = !hosts ? 0 : ub_ctx_hosts((struct ub_ctx*) this->ub_ctx, nullptr); - if(err != 0) { - cerr << "Failed to add hosts file: " << ub_strerror(err) << endl; - } - - /* Add resolv.conf */ - err = !resolv ? 0 : ub_ctx_resolvconf((struct ub_ctx*) this->ub_ctx, nullptr); - if(err != 0) { - cerr << "Failed to add hosts file: " << ub_strerror(err) << endl; - } - - return true; -} - -void Resolver::finalize() { - this->event.loop_active = false; - if(this->event.base) { - auto ret = event_base_loopexit(this->event.base, nullptr); - if(ret != 0) { - cerr << "Failed to exit event base loop: " << ret << endl; - } - } - - { - this->event.condition.notify_one(); - if(this->event.loop.joinable()) - this->event.loop.join(); - } - - { - unique_lock lock(this->request_lock); - auto dns_list = std::move(this->dns_requests); - auto tsdns_list = std::move(this->tsdns_requests); - - for(auto entry : dns_list) { - ub_cancel(this->ub_ctx, entry->ub_id); - - entry->callback(ResultState::ABORT, 0, nullptr); - this->destroy_dns_request(entry); - } - - for(auto entry : tsdns_list) { - entry->callback(ResultState::ABORT, 0, ""); - this->destroy_tsdns_request(entry); - } - lock.unlock(); - } - - ub_ctx_delete((struct ub_ctx*) this->ub_ctx); - this->ub_ctx = nullptr; - - if(this->event.base) { - event_base_free(this->event.base); - this->event.base = nullptr; - } -} - -void Resolver::event_loop_runner() { - while(true) { - { - unique_lock lock{this->event.lock}; - if(!this->event.loop_active) - break; - - this->event.condition.wait(lock); - if(!this->event.loop_active) - break; - } - - event_base_loop(this->event.base, 0); - } -} - -//Call only within the event loop! -void Resolver::destroy_dns_request(Resolver::dns_request *request) { - assert(this_thread::get_id() == this->event.loop.get_id() || !this->event.loop_active); - - { - lock_guard lock{this->request_lock}; - this->dns_requests.erase(std::find(this->dns_requests.begin(), this->dns_requests.end(), request), this->dns_requests.end()); - } - - if(request->register_event) { - event_del_noblock(request->register_event); - event_free(request->register_event); - request->register_event = nullptr; - } - - if(request->timeout_event) { - event_del_noblock(request->timeout_event); - event_free(request->timeout_event); - request->timeout_event = nullptr; - } - delete request; -} - void Resolver::destroy_tsdns_request(Resolver::tsdns_request *request) { assert(this_thread::get_id() == this->event.loop.get_id() || !this->event.loop_active); @@ -184,322 +66,6 @@ void Resolver::destroy_tsdns_request(Resolver::tsdns_request *request) { delete request; } -//--------------- DNS -void Resolver::resolve_dns(const char *name, const rrtype::value &rrtype, const rrclass::value &rrclass, const std::chrono::microseconds& timeout, const dns_callback_t& callback) { - if(!this->event.loop_active) { - callback(ResultState::INITIALISATION_FAILED, 3, nullptr); - return; - } - - auto request = new dns_request{}; - request->resolver = this; - - request->callback = callback; - request->host = name; - request->rrtype = rrtype; - request->rrclass = rrclass; - - request->timeout_event = evtimer_new(this->event.base, [](evutil_socket_t, short, void *_request) { - auto request = static_cast(_request); - request->resolver->evtimer_dns_callback(request); - }, request); - - request->register_event = evuser_new(this->event.base, [](evutil_socket_t, short, void *_request) { - auto request = static_cast(_request); - - auto errc = ub_resolve_event(request->resolver->ub_ctx, request->host.c_str(), (int) request->rrtype, (int) request->rrclass, (void*) request, [](void* _request, int a, void* b, int c, int d, char* e) { - auto request = static_cast(_request); - request->resolver->ub_callback(request, a, b, c, d, e); - }, &request->ub_id); - - if(errc != 0) { - request->callback(ResultState::INITIALISATION_FAILED, errc, nullptr); - request->resolver->destroy_dns_request(request); - } - }, request); - - if(!request->timeout_event || !request->register_event) { - callback(ResultState::INITIALISATION_FAILED, 2, nullptr); - - if(request->timeout_event) - event_free(request->timeout_event); - - if(request->register_event) - event_free(request->register_event); - - delete request; - return; - } - - /* - * Lock here all requests so the event loop cant already delete the request - */ - unique_lock rlock{this->request_lock}; - - { - auto errc = event_add(request->timeout_event, nullptr); - //TODO: Check for error - - evuser_trigger(request->register_event); - } - - { - auto seconds = chrono::floor(timeout); - auto microseconds = chrono::ceil(timeout - seconds); - - timeval tv{seconds.count(), microseconds.count()}; - auto errc = event_add(request->timeout_event, &tv); - - //TODO: Check for error - } - - this->dns_requests.push_back(request); - rlock.unlock(); - - /* Activate the event loop */ - this->event.condition.notify_one(); -} - -void Resolver::evtimer_dns_callback(tc::dns::Resolver::dns_request *request) { - if(request->ub_id > 0) { - auto errc = ub_cancel(this->ub_ctx, request->ub_id); - if(errc != 0) { - cerr << "Failed to cancel DNS request " << request->ub_id << " after timeout (" << errc << "/" << ub_strerror(errc) << ")!" << endl; - } - } - - request->callback(ResultState::DNS_TIMEOUT, 0, nullptr); - this->destroy_dns_request(request); -} - -void Resolver::ub_callback(dns_request* request, int rcode, void *packet, int packet_length, int sec, char *why_bogus) { - if(rcode != 0) { - request->callback(ResultState::DNS_FAIL, rcode, nullptr); - } else { - auto callback = request->callback; - auto data = std::unique_ptr(new DNSResponse{(uint8_t) sec, why_bogus, packet, (size_t) packet_length}); - callback(ResultState::SUCCESS, 0, std::move(data)); - } - - this->destroy_dns_request(request); -} - -thread_local std::vector visited_links; -std::string DNSResponseBuffer::parse_dns_dn(std::string &error, size_t &index, bool allow_compression) { - if(allow_compression) { - visited_links.clear(); - visited_links.reserve(8); - - if(std::find(visited_links.begin(), visited_links.end(), index) != visited_links.end()) { - error = "circular link detected"; - return ""; - } - visited_links.push_back(index); - } - - error.clear(); - - string result; - result.reserve(256); //Max length is 253 - - while(true) { - if(index + 1 > this->length) { - error = "truncated data (missing code)"; - goto exit; - } - - auto code = this->buffer[index++]; - if(code == 0) break; - - if((code >> 6U) == 3) { - if(!allow_compression) { - error = "found link, but links are not allowed"; - goto exit; - } - - auto lower_addr = this->buffer[index++]; - if(index + 1 > this->length) { - error = "truncated data (missing lower link address)"; - goto exit; - } - - size_t addr = ((code & 0x3FU) << 8U) | lower_addr; - if(addr >= this->length) { - error = "invalid link address"; - goto exit; - } - auto tail = this->parse_dns_dn(error, addr, true); - if(!error.empty()) - goto exit; - - if(!result.empty()) - result += "." + tail; - else - result = tail; - break; - } else { - if(code > 63) { - error = "max domain label length is 63 characters"; - goto exit; - } - - if(!result.empty()) - result += "."; - - if(index + code >= this->length) { - error = "truncated data (domain label)"; - goto exit; - } - - result.append((const char*) (this->buffer + index), code); - index += code; - } - } - - exit: - if(allow_compression) visited_links.pop_back(); - return result; -} - -DNSResponseBuffer::~DNSResponseBuffer() { - ::free(this->buffer); -} - -DNSResponse::DNSResponse(uint8_t secure_state, const char* bogus, void *packet, size_t size) { - this->bogus = bogus ? std::string{bogus} : std::string{"packet is secure"}; - this->secure_state = secure_state; - - this->packet = make_shared(); - this->packet->buffer = (uint8_t*) malloc(size); - this->packet->length = size; - - memcpy(this->packet->buffer, packet, size); -} - -response::DNSHeader DNSResponse::header() const { - return response::DNSHeader{this}; -} - -bool DNSResponse::parse(std::string &error) { - if(this->is_parsed) { - error = this->parse_error; - return error.empty(); - } - error.clear(); - this->is_parsed = true; - - auto header = this->header(); - size_t index = 12; /* 12 bits for the header */ - - { - auto count = header.query_count(); - this->parsed_queries.reserve(count); - - for(size_t idx = 0; idx < count; idx++) { - auto dn = this->packet->parse_dns_dn(error, index, true); - if(!error.empty()) { - error = "failed to parse query " + to_string(idx) + " dn: " + error; // NOLINT(performance-inefficient-string-concatenation) - goto error_exit; - } - - if(index + 4 > this->packet_length()) { - error = "truncated data for query " + to_string(index); - goto error_exit; - } - - auto type = (rrtype::value) ntohs(*(uint16_t*) (this->packet->buffer + index)); - index += 2; - - auto klass = (rrclass::value) ntohs(*(uint16_t*) (this->packet->buffer + index)); - index += 2; - - this->parsed_queries.emplace_back(new response::DNSQuery{dn, type, klass}); - } - } - - { - auto count = header.answer_count(); - this->parsed_answers.reserve(count); - - for(size_t idx = 0; idx < count; idx++) { - this->parsed_answers.push_back(this->parse_rr(error, index, true)); - if(!error.empty()) { - error = "failed to parse answer " + to_string(idx) + ": " + error; // NOLINT(performance-inefficient-string-concatenation) - goto error_exit; - } - } - } - - { - auto count = header.authority_count(); - this->parsed_authorities.reserve(count); - - for(size_t idx = 0; idx < count; idx++) { - this->parsed_authorities.push_back(this->parse_rr(error, index, true)); - if(!error.empty()) { - error = "failed to parse authority " + to_string(idx) + ": " + error; // NOLINT(performance-inefficient-string-concatenation) - goto error_exit; - } - } - } - - { - auto count = header.additional_count(); - this->parsed_additionals.reserve(count); - - for(size_t idx = 0; idx < count; idx++) { - this->parsed_additionals.push_back(this->parse_rr(error, index, true)); - if(!error.empty()) { - error = "failed to parse additional " + to_string(idx) + ": " + error; // NOLINT(performance-inefficient-string-concatenation) - goto error_exit; - } - } - } - - return true; - - error_exit: - this->parsed_queries.clear(); - this->parsed_answers.clear(); - this->parsed_authorities.clear(); - this->parsed_additionals.clear(); - return false; -} - -std::shared_ptr DNSResponse::parse_rr(std::string &error, size_t &index, bool allow_compressed) { - auto dn = this->packet->parse_dns_dn(error, index, allow_compressed); - if(!error.empty()) { - error = "failed to parse rr dn: " + error; // NOLINT(performance-inefficient-string-concatenation) - return nullptr; - } - - if(index + 10 > this->packet_length()) { - error = "truncated header"; - return nullptr; - } - - auto type = (rrtype::value) ntohs(*(uint16_t*) (this->packet->buffer + index)); - index += 2; - - auto klass = (rrclass::value) ntohs(*(uint16_t*) (this->packet->buffer + index)); - index += 2; - - auto ttl = ntohl(*(uint32_t*) (this->packet->buffer + index)); - index += 4; - - auto payload_length = ntohs(*(uint16_t*) (this->packet->buffer + index)); - index += 2; - - if(index + payload_length > this->packet_length()) { - error = "truncated body"; - return nullptr; - } - - auto response = std::shared_ptr(new response::DNSResourceRecords{this->packet, index, payload_length, ttl, dn, type, klass}); - index += payload_length; - return response; -} - //---------------------- TSDNS void Resolver::resolve_tsdns(const char *query, const sockaddr_storage& server_address, const std::chrono::microseconds& timeout, const tc::dns::Resolver::tsdns_callback_t &callback) { /* create the socket */ diff --git a/native/dns/src/resolver.h b/native/dns/src/resolver.h index 775a720..2760f8c 100644 --- a/native/dns/src/resolver.h +++ b/native/dns/src/resolver.h @@ -9,8 +9,10 @@ #include #include "./types.h" - +#ifndef WIN32 struct ub_ctx; +#endif + namespace tc::dns { namespace response { class DNSHeader; @@ -18,12 +20,17 @@ namespace tc::dns { class DNSResourceRecords; } - struct DNSResponseBuffer { + struct DNSResponseData { +#ifdef WIN32 + //IMPLEMENT ME! +#else uint8_t* buffer{nullptr}; size_t length{0}; - ~DNSResponseBuffer(); std::string parse_dns_dn(std::string& /* error */, size_t& /* index */, bool /* compression allowed */); +#endif + + ~DNSResponseData(); }; class Resolver; @@ -38,21 +45,23 @@ namespace tc::dns { bool parse(std::string& /* error */); +#ifndef WIN32 [[nodiscard]] inline const std::string why_bogus() const { return this->bogus; } - [[nodiscard]] inline const uint8_t* packet_data() const { return this->packet->buffer; } - [[nodiscard]] inline size_t packet_length() const { return this->packet->length; } + [[nodiscard]] inline const uint8_t* packet_data() const { return this->data->buffer; } + [[nodiscard]] inline size_t packet_length() const { return this->data->length; } [[nodiscard]] inline bool is_secure() const { return this->secure_state > 0; } [[nodiscard]] inline bool is_secure_dnssec() const { return this->secure_state == 2; } [[nodiscard]] response::DNSHeader header() const; - +#endif [[nodiscard]] q_list_t queries() const { return this->parsed_queries; } [[nodiscard]] rr_list_t answers() const { return this->parsed_answers; } [[nodiscard]] rr_list_t authorities() const { return this->parsed_authorities; } [[nodiscard]] rr_list_t additionals() const { return this->parsed_additionals; } private: +#ifndef WIN32 DNSResponse(uint8_t /* secure state */, const char* /* bogus */, void* /* packet */, size_t /* length */); std::shared_ptr parse_rr(std::string& /* error */, size_t& index, bool /* compression allowed dn */); @@ -60,7 +69,10 @@ namespace tc::dns { std::string bogus; uint8_t secure_state{0}; - std::shared_ptr packet{nullptr}; + std::shared_ptr data{nullptr}; +#else + +#endif bool is_parsed{false}; std::string parse_error{}; @@ -101,6 +113,7 @@ namespace tc::dns { void resolve_dns(const char* /* name */, const rrtype::value& /* rrtype */, const rrclass::value& /* rrclass */, const std::chrono::microseconds& /* timeout */, const dns_callback_t& /* callback */); void resolve_tsdns(const char* /* name */, const sockaddr_storage& /* server */, const std::chrono::microseconds& /* timeout */, const tsdns_callback_t& /* callback */); private: +#ifndef WIN32 struct dns_request { Resolver* resolver{nullptr}; @@ -115,6 +128,8 @@ namespace tc::dns { dns_callback_t callback{}; }; + struct ub_ctx* ub_ctx = nullptr; +#endif struct tsdns_request { Resolver* resolver{nullptr}; @@ -139,23 +154,28 @@ namespace tc::dns { event_base* base = nullptr; } event; - struct ub_ctx* ub_ctx = nullptr; std::vector tsdns_requests{}; +#ifndef WIN32 std::vector dns_requests{}; +#endif std::recursive_mutex request_lock{}; /* this is recursive because due to the instance callback resolve_dns could be called recursively */ +#ifndef WIN32 void destroy_dns_request(dns_request*); +#endif void destroy_tsdns_request(tsdns_request*); void event_loop_runner(); - void evtimer_tsdns_callback(tsdns_request* /* request */); +#ifndef WIN32 void evtimer_dns_callback(dns_request* /* request */); + void ub_callback(dns_request* /* request */, int /* rcode */, void* /* packet */, int /* packet_len */, int /* sec */, char* /* why_bogus */); +#endif + void evtimer_tsdns_callback(tsdns_request* /* request */); void event_tsdns_write(tsdns_request* /* request */); void event_tsdns_read(tsdns_request* /* request */); - void ub_callback(dns_request* /* request */, int /* rcode */, void* /* packet */, int /* packet_len */, int /* sec */, char* /* why_bogus */); }; } \ No newline at end of file diff --git a/native/dns/src/resolver_linux.cpp b/native/dns/src/resolver_linux.cpp new file mode 100644 index 0000000..99cdefa --- /dev/null +++ b/native/dns/src/resolver_linux.cpp @@ -0,0 +1,449 @@ +#include "./resolver.h" +#include "./response.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include /* for TSDNS */ +#include + +using namespace std; +using namespace tc::dns; + +bool Resolver::initialize(std::string &error, bool hosts, bool resolv) { + if(this->event.loop_active) + this->finalize(); + + this->event.loop_active = true; + this->event.base = event_base_new(); + if(!this->event.base) { + error = "failed to allcoate event base"; + return false; + } + this->event.loop = std::thread(std::bind(&Resolver::event_loop_runner, this)); + + this->ub_ctx = ub_ctx_create_event(this->event.base); + if(!this->ub_ctx) { + this->finalize(); + error = "failed to create ub context"; + return false; + } + + /* Add /etc/hosts */ + auto err = !hosts ? 0 : ub_ctx_hosts((struct ub_ctx*) this->ub_ctx, nullptr); + if(err != 0) { + cerr << "Failed to add hosts file: " << ub_strerror(err) << endl; + } + + /* Add resolv.conf */ + err = !resolv ? 0 : ub_ctx_resolvconf((struct ub_ctx*) this->ub_ctx, nullptr); + if(err != 0) { + cerr << "Failed to add hosts file: " << ub_strerror(err) << endl; + } + + return true; +} + +void Resolver::finalize() { + this->event.loop_active = false; + if(this->event.base) { + auto ret = event_base_loopexit(this->event.base, nullptr); + if(ret != 0) { + cerr << "Failed to exit event base loop: " << ret << endl; + } + } + + { + this->event.condition.notify_one(); + if(this->event.loop.joinable()) + this->event.loop.join(); + } + + { + unique_lock lock(this->request_lock); + auto dns_list = std::move(this->dns_requests); + auto tsdns_list = std::move(this->tsdns_requests); + + for(auto entry : dns_list) { + ub_cancel(this->ub_ctx, entry->ub_id); + + entry->callback(ResultState::ABORT, 0, nullptr); + this->destroy_dns_request(entry); + } + + for(auto entry : tsdns_list) { + entry->callback(ResultState::ABORT, 0, ""); + this->destroy_tsdns_request(entry); + } + lock.unlock(); + } + + ub_ctx_delete((struct ub_ctx*) this->ub_ctx); + this->ub_ctx = nullptr; + + if(this->event.base) { + event_base_free(this->event.base); + this->event.base = nullptr; + } +} + +void Resolver::event_loop_runner() { + while(true) { + { + unique_lock lock{this->event.lock}; + if(!this->event.loop_active) + break; + + this->event.condition.wait(lock); + if(!this->event.loop_active) + break; + } + + event_base_loop(this->event.base, 0); + } +} + +//Call only within the event loop! +void Resolver::destroy_dns_request(Resolver::dns_request *request) { + assert(this_thread::get_id() == this->event.loop.get_id() || !this->event.loop_active); + + { + lock_guard lock{this->request_lock}; + this->dns_requests.erase(std::find(this->dns_requests.begin(), this->dns_requests.end(), request), this->dns_requests.end()); + } + + if(request->register_event) { + event_del_noblock(request->register_event); + event_free(request->register_event); + request->register_event = nullptr; + } + + if(request->timeout_event) { + event_del_noblock(request->timeout_event); + event_free(request->timeout_event); + request->timeout_event = nullptr; + } + delete request; +} + +//--------------- DNS +void Resolver::resolve_dns(const char *name, const rrtype::value &rrtype, const rrclass::value &rrclass, const std::chrono::microseconds& timeout, const dns_callback_t& callback) { + if(!this->event.loop_active) { + callback(ResultState::INITIALISATION_FAILED, 3, nullptr); + return; + } + + auto request = new dns_request{}; + request->resolver = this; + + request->callback = callback; + request->host = name; + request->rrtype = rrtype; + request->rrclass = rrclass; + + request->timeout_event = evtimer_new(this->event.base, [](evutil_socket_t, short, void *_request) { + auto request = static_cast(_request); + request->resolver->evtimer_dns_callback(request); + }, request); + + request->register_event = evuser_new(this->event.base, [](evutil_socket_t, short, void *_request) { + auto request = static_cast(_request); + + auto errc = ub_resolve_event(request->resolver->ub_ctx, request->host.c_str(), (int) request->rrtype, (int) request->rrclass, (void*) request, [](void* _request, int a, void* b, int c, int d, char* e) { + auto request = static_cast(_request); + request->resolver->ub_callback(request, a, b, c, d, e); + }, &request->ub_id); + + if(errc != 0) { + request->callback(ResultState::INITIALISATION_FAILED, errc, nullptr); + request->resolver->destroy_dns_request(request); + } + }, request); + + if(!request->timeout_event || !request->register_event) { + callback(ResultState::INITIALISATION_FAILED, 2, nullptr); + + if(request->timeout_event) + event_free(request->timeout_event); + + if(request->register_event) + event_free(request->register_event); + + delete request; + return; + } + + /* + * Lock here all requests so the event loop cant already delete the request + */ + unique_lock rlock{this->request_lock}; + + { + auto errc = event_add(request->timeout_event, nullptr); + //TODO: Check for error + + evuser_trigger(request->register_event); + } + + { + auto seconds = chrono::floor(timeout); + auto microseconds = chrono::ceil(timeout - seconds); + + timeval tv{seconds.count(), microseconds.count()}; + auto errc = event_add(request->timeout_event, &tv); + + //TODO: Check for error + } + + this->dns_requests.push_back(request); + rlock.unlock(); + + /* Activate the event loop */ + this->event.condition.notify_one(); +} + +void Resolver::evtimer_dns_callback(tc::dns::Resolver::dns_request *request) { + if(request->ub_id > 0) { + auto errc = ub_cancel(this->ub_ctx, request->ub_id); + if(errc != 0) { + cerr << "Failed to cancel DNS request " << request->ub_id << " after timeout (" << errc << "/" << ub_strerror(errc) << ")!" << endl; + } + } + + request->callback(ResultState::DNS_TIMEOUT, 0, nullptr); + this->destroy_dns_request(request); +} + +void Resolver::ub_callback(dns_request* request, int rcode, void *packet, int packet_length, int sec, char *why_bogus) { + if(rcode != 0) { + request->callback(ResultState::DNS_FAIL, rcode, nullptr); + } else { + auto callback = request->callback; + auto data = std::unique_ptr(new DNSResponse{(uint8_t) sec, why_bogus, packet, (size_t) packet_length}); + callback(ResultState::SUCCESS, 0, std::move(data)); + } + + this->destroy_dns_request(request); +} + +thread_local std::vector visited_links; +std::string DNSResponseData::parse_dns_dn(std::string &error, size_t &index, bool allow_compression) { + if(allow_compression) { + visited_links.clear(); + visited_links.reserve(8); + + if(std::find(visited_links.begin(), visited_links.end(), index) != visited_links.end()) { + error = "circular link detected"; + return ""; + } + visited_links.push_back(index); + } + + error.clear(); + + string result; + result.reserve(256); //Max length is 253 + + while(true) { + if(index + 1 > this->length) { + error = "truncated data (missing code)"; + goto exit; + } + + auto code = this->buffer[index++]; + if(code == 0) break; + + if((code >> 6U) == 3) { + if(!allow_compression) { + error = "found link, but links are not allowed"; + goto exit; + } + + auto lower_addr = this->buffer[index++]; + if(index + 1 > this->length) { + error = "truncated data (missing lower link address)"; + goto exit; + } + + size_t addr = ((code & 0x3FU) << 8U) | lower_addr; + if(addr >= this->length) { + error = "invalid link address"; + goto exit; + } + auto tail = this->parse_dns_dn(error, addr, true); + if(!error.empty()) + goto exit; + + if(!result.empty()) + result += "." + tail; + else + result = tail; + break; + } else { + if(code > 63) { + error = "max domain label length is 63 characters"; + goto exit; + } + + if(!result.empty()) + result += "."; + + if(index + code >= this->length) { + error = "truncated data (domain label)"; + goto exit; + } + + result.append((const char*) (this->buffer + index), code); + index += code; + } + } + + exit: + if(allow_compression) visited_links.pop_back(); + return result; +} + +DNSResponseData::~DNSResponseData() { + ::free(this->buffer); +} + +DNSResponse::DNSResponse(uint8_t secure_state, const char* bogus, void *packet, size_t size) { + this->bogus = bogus ? std::string{bogus} : std::string{"packet is secure"}; + this->secure_state = secure_state; + + this->data = make_shared(); + this->data->buffer = (uint8_t*) malloc(size); + this->data->length = size; + + memcpy(this->data->buffer, packet, size); +} + +response::DNSHeader DNSResponse::header() const { + return response::DNSHeader{this}; +} + +bool DNSResponse::parse(std::string &error) { + if(this->is_parsed) { + error = this->parse_error; + return error.empty(); + } + error.clear(); + this->is_parsed = true; + + auto header = this->header(); + size_t index = 12; /* 12 bits for the header */ + + { + auto count = header.query_count(); + this->parsed_queries.reserve(count); + + for(size_t idx = 0; idx < count; idx++) { + auto dn = this->data->parse_dns_dn(error, index, true); + if(!error.empty()) { + error = "failed to parse query " + to_string(idx) + " dn: " + error; // NOLINT(performance-inefficient-string-concatenation) + goto error_exit; + } + + if(index + 4 > this->packet_length()) { + error = "truncated data for query " + to_string(index); + goto error_exit; + } + + auto type = (rrtype::value) ntohs(*(uint16_t*) (this->data->buffer + index)); + index += 2; + + auto klass = (rrclass::value) ntohs(*(uint16_t*) (this->data->buffer + index)); + index += 2; + + this->parsed_queries.emplace_back(new response::DNSQuery{dn, type, klass}); + } + } + + { + auto count = header.answer_count(); + this->parsed_answers.reserve(count); + + for(size_t idx = 0; idx < count; idx++) { + this->parsed_answers.push_back(this->parse_rr(error, index, true)); + if(!error.empty()) { + error = "failed to parse answer " + to_string(idx) + ": " + error; // NOLINT(performance-inefficient-string-concatenation) + goto error_exit; + } + } + } + + { + auto count = header.authority_count(); + this->parsed_authorities.reserve(count); + + for(size_t idx = 0; idx < count; idx++) { + this->parsed_authorities.push_back(this->parse_rr(error, index, true)); + if(!error.empty()) { + error = "failed to parse authority " + to_string(idx) + ": " + error; // NOLINT(performance-inefficient-string-concatenation) + goto error_exit; + } + } + } + + { + auto count = header.additional_count(); + this->parsed_additionals.reserve(count); + + for(size_t idx = 0; idx < count; idx++) { + this->parsed_additionals.push_back(this->parse_rr(error, index, true)); + if(!error.empty()) { + error = "failed to parse additional " + to_string(idx) + ": " + error; // NOLINT(performance-inefficient-string-concatenation) + goto error_exit; + } + } + } + + return true; + + error_exit: + this->parsed_queries.clear(); + this->parsed_answers.clear(); + this->parsed_authorities.clear(); + this->parsed_additionals.clear(); + return false; +} + +std::shared_ptr DNSResponse::parse_rr(std::string &error, size_t &index, bool allow_compressed) { + auto dn = this->data->parse_dns_dn(error, index, allow_compressed); + if(!error.empty()) { + error = "failed to parse rr dn: " + error; // NOLINT(performance-inefficient-string-concatenation) + return nullptr; + } + + if(index + 10 > this->packet_length()) { + error = "truncated header"; + return nullptr; + } + + auto type = (rrtype::value) ntohs(*(uint16_t*) (this->data->buffer + index)); + index += 2; + + auto klass = (rrclass::value) ntohs(*(uint16_t*) (this->data->buffer + index)); + index += 2; + + auto ttl = ntohl(*(uint32_t*) (this->data->buffer + index)); + index += 4; + + auto payload_length = ntohs(*(uint16_t*) (this->data->buffer + index)); + index += 2; + + if(index + payload_length > this->packet_length()) { + error = "truncated body"; + return nullptr; + } + + auto response = std::shared_ptr(new response::DNSResourceRecords{this->data, index, payload_length, ttl, dn, type, klass}); + index += payload_length; + return response; +} diff --git a/native/dns/src/response.cpp b/native/dns/src/response.cpp index 32df875..6a5d531 100644 --- a/native/dns/src/response.cpp +++ b/native/dns/src/response.cpp @@ -28,13 +28,13 @@ uint16_t DNSHeader::field(int index) const { return ((uint16_t*) this->response->packet_data())[index]; } -DNSResourceRecords::DNSResourceRecords(std::shared_ptr packet, size_t payload_offset, size_t length, uint32_t ttl, std::string name, rrtype::value type, rrclass::value klass) +DNSResourceRecords::DNSResourceRecords(std::shared_ptr packet, size_t payload_offset, size_t length, uint32_t ttl, std::string name, rrtype::value type, rrclass::value klass) : offset{payload_offset}, length{length}, ttl{ttl}, name{std::move(name)}, type{type}, klass{klass} { - this->packet = std::move(packet); + this->data = std::move(packet); } const uint8_t* DNSResourceRecords::payload_data() const { - return this->packet->buffer + this->offset; + return this->data->buffer + this->offset; } //---------------- AAAA @@ -46,6 +46,7 @@ in6_addr AAAA::address() { return { .__in6_u = { .__u6_addr32 = { + //Attention unaligned memory access ((uint32_t*) this->handle->payload_data())[0], ((uint32_t*) this->handle->payload_data())[1], ((uint32_t*) this->handle->payload_data())[2], @@ -61,26 +62,26 @@ bool SRV::is_valid() { return false; size_t index = this->handle->payload_offset() + 6; std::string error{}; - this->handle->dns_packet()->parse_dns_dn(error, index, true); + this->handle->dns_data()->parse_dns_dn(error, index, true); return error.empty(); } std::string SRV::target_hostname() { size_t index = this->handle->payload_offset() + 6; std::string error{}; - return this->handle->dns_packet()->parse_dns_dn(error, index, true); + return this->handle->dns_data()->parse_dns_dn(error, index, true); } //---------------- All types with a name bool named_base::is_valid() { size_t index = this->handle->payload_offset(); std::string error{}; - this->handle->dns_packet()->parse_dns_dn(error, index, true); + this->handle->dns_data()->parse_dns_dn(error, index, true); return error.empty(); } std::string named_base::name() { size_t index = this->handle->payload_offset(); std::string error{}; - return this->handle->dns_packet()->parse_dns_dn(error, index, true); + return this->handle->dns_data()->parse_dns_dn(error, index, true); } diff --git a/native/dns/src/response.h b/native/dns/src/response.h index 7bdcd92..01a9586 100644 --- a/native/dns/src/response.h +++ b/native/dns/src/response.h @@ -9,9 +9,10 @@ namespace tc::dns { class DNSResponse; - class DNSResponseBuffer; + class DNSResponseData; namespace response { + #ifndef WIN32 class DNSHeader { friend class tc::dns::DNSResponse; public: @@ -36,6 +37,7 @@ namespace tc::dns { explicit DNSHeader(const DNSResponse* response) : response{response} {} const DNSResponse* response{nullptr}; }; + #endif class DNSQuery { friend class tc::dns::DNSResponse; @@ -54,16 +56,30 @@ namespace tc::dns { class DNSResourceRecords { friend class tc::dns::DNSResponse; public: - [[nodiscard]] inline std::string qname() const { return this->name; } - [[nodiscard]] inline rrtype::value atype() const { return this->type; } - [[nodiscard]] inline rrclass::value aclass() const { return this->klass; } - [[nodiscard]] inline uint16_t attl() const { return this->ttl; } + [[nodiscard]] inline std::string qname() const { + return this->name; + } + [[nodiscard]] inline rrtype::value atype() const { + return this->type; + } + [[nodiscard]] inline rrclass::value aclass() const { + return this->klass; + } + [[nodiscard]] inline uint16_t attl() const { + return this->ttl; + } + #ifndef WIN32 [[nodiscard]] const uint8_t* payload_data() const; [[nodiscard]] inline size_t payload_length() const { return this->length; } [[nodiscard]] inline size_t payload_offset() const { return this->offset; } + #else + [[nodiscard]] inline PDNS_RECORDA native_record() const { return this->nrecord; } + #endif - [[nodiscard]] inline std::shared_ptr dns_packet() const { return this->packet; } + [[nodiscard]] inline std::shared_ptr dns_data() const { + return this->data; + } template [[nodiscard]] inline T parse() const { @@ -72,9 +88,15 @@ namespace tc::dns { return T{this}; } private: - DNSResourceRecords(std::shared_ptr, size_t, size_t, uint32_t, std::string , rrtype::value, rrclass::value); + std::shared_ptr data{nullptr}; + + #ifdef WIN32 + DNSResourceRecords(std::shared_ptr, PDNS_RECORDA); + + PDNS_RECORDA nrecord{nullptr}; + #else + DNSResourceRecords(std::shared_ptr, size_t, size_t, uint32_t, std::string , rrtype::value, rrclass::value); - std::shared_ptr packet{nullptr}; size_t offset{0}; size_t length{0}; @@ -83,6 +105,7 @@ namespace tc::dns { std::string name; rrtype::value type; rrclass::value klass; + #endif }; namespace rrparser {