From eaac1e8b248b3d84fcc2e8658e26091db78587ad Mon Sep 17 00:00:00 2001 From: SimonWilkinson Date: Sat, 5 Jan 2019 15:47:00 +0000 Subject: [PATCH] Rework DNSServer to be more robust (#5573) * DNSServer: Handle examplewww.com correctly Just replacing 'www.' with the empty string when we assign the domainname will remove all occurrences of 'www.', not just those at the start of the string. Change this to a startsWith check so that only "www." at the beginning of the string is removed. * DNSServer: Rework request handling Rewrite the request handling in the DNSServer code to address the following issues: Compatibility with EDNS #1: RFC6891 says that "Responders that choose not to implement the protocol extensions defined in this document MUST respond with a return code (RCODE) of FORMERR to messages containing an OPT record in the additional section and MUST NOT include an OPT record in the response" If we have any additional records in the request, then we need to return a FORMERR, and not whatever custom error code the user may have set. Compatibility with EDNS #2: If we're returning an error, we need to explicitly zero all of the record counters. In the existing code, if there is an additional record present in the request, we return an ARCOUNT of 1 in the response, despite including no additional records in the payload. Don't answer non-A requests If we receive an AAAA request (or any other non-A record) requests, we shouldn't respond to it with an A record. Don't answer non-IN requests If we receive a request for a non-IN type, don't answer it (it's unlikely that we'd see this in the real world) Don't read off the end of malformed packets If a packet claims to have a query, but then doesn't include one, or includes a query with malformed labels, don't read off the end of the allocated data structure. * DNSServer: Clarify and tidy writing the answer record Modify the code used to write the answer record back to the server so that it is clearer that we are writing network byte order 16-bit quantities, and to clarify what's happening with the pointer used at the start of the answer. --- libraries/DNSServer/src/DNSServer.cpp | 276 ++++++++++++++++---------- libraries/DNSServer/src/DNSServer.h | 24 ++- 2 files changed, 188 insertions(+), 112 deletions(-) diff --git a/libraries/DNSServer/src/DNSServer.cpp b/libraries/DNSServer/src/DNSServer.cpp index 5e1382bcd1..31d2c90792 100644 --- a/libraries/DNSServer/src/DNSServer.cpp +++ b/libraries/DNSServer/src/DNSServer.cpp @@ -1,6 +1,7 @@ #include "DNSServer.h" #include #include +#include #ifdef DEBUG_ESP_PORT #define DEBUG_OUTPUT DEBUG_ESP_PORT @@ -8,6 +9,8 @@ #define DEBUG_OUTPUT Serial #endif +#define DNS_HEADER_SIZE sizeof(DNSHeader) + DNSServer::DNSServer() { _ttl = lwip_htonl(60); @@ -46,149 +49,208 @@ void DNSServer::stop() void DNSServer::downcaseAndRemoveWwwPrefix(String &domainName) { domainName.toLowerCase(); - domainName.replace("www.", ""); + if (domainName.startsWith("www.")) + domainName.remove(0, 4); } -void DNSServer::processNextRequest() +void DNSServer::respondToRequest(uint8_t *buffer, size_t length) { - size_t packetSize = _udp.parsePacket(); + DNSHeader *dnsHeader; + uint8_t *query, *start; + const char *matchString; + size_t remaining, labelLength, queryLength; + uint16_t qtype, qclass; + + dnsHeader = (DNSHeader *)buffer; - if (packetSize >= sizeof(DNSHeader)) - { - uint8_t* buffer = reinterpret_cast(malloc(packetSize)); - if (buffer == NULL) return; + // Must be a query for us to do anything with it + if (dnsHeader->QR != DNS_QR_QUERY) + return; - _udp.read(buffer, packetSize); + // If operation is anything other than query, we don't do it + if (dnsHeader->OPCode != DNS_OPCODE_QUERY) + return replyWithError(dnsHeader, DNSReplyCode::NotImplemented); + + // Only support requests containing single queries - everything else + // is badly defined + if (dnsHeader->QDCount != lwip_htons(1)) + return replyWithError(dnsHeader, DNSReplyCode::FormError); + + // We must return a FormError in the case of a non-zero ARCount to + // be minimally compatible with EDNS resolvers + if (dnsHeader->ANCount != 0 || dnsHeader->NSCount != 0 + || dnsHeader->ARCount != 0) + return replyWithError(dnsHeader, DNSReplyCode::FormError); + + // Even if we're not going to use the query, we need to parse it + // so we can check the address type that's being queried + + query = start = buffer + DNS_HEADER_SIZE; + remaining = length - DNS_HEADER_SIZE; + while (remaining != 0 && *start != 0) { + labelLength = *start; + if (labelLength + 1 > remaining) + return replyWithError(dnsHeader, DNSReplyCode::FormError); + remaining -= (labelLength + 1); + start += (labelLength + 1); + } - DNSHeader* dnsHeader = reinterpret_cast(buffer); + // 1 octet labelLength, 2 octet qtype, 2 octet qclass + if (remaining < 5) + return replyWithError(dnsHeader, DNSReplyCode::FormError); - if (dnsHeader->QR == DNS_QR_QUERY && - dnsHeader->OPCode == DNS_OPCODE_QUERY && - requestIncludesOnlyOneQuestion(dnsHeader) && - (_domainName == "*" || getDomainNameWithoutWwwPrefix(buffer, packetSize) == _domainName) - ) - { - replyWithIP(buffer, packetSize); - } - else if (dnsHeader->QR == DNS_QR_QUERY) - { - replyWithCustomCode(buffer, packetSize); + start += 1; // Skip the 0 length label that we found above + + memcpy(&qtype, start, sizeof(qtype)); + start += 2; + memcpy(&qclass, start, sizeof(qclass)); + start += 2; + + queryLength = start - query; + + if (qclass != lwip_htons(DNS_QCLASS_ANY) + && qclass != lwip_htons(DNS_QCLASS_IN)) + return replyWithError(dnsHeader, DNSReplyCode::NonExistentDomain, + query, queryLength); + + if (qtype != lwip_htons(DNS_QTYPE_A) + && qtype != lwip_htons(DNS_QTYPE_ANY)) + return replyWithError(dnsHeader, DNSReplyCode::NonExistentDomain, + query, queryLength); + + // If we have no domain name configured, just return an error + if (_domainName == "") + return replyWithError(dnsHeader, _errorReplyCode, + query, queryLength); + + // If we're running with a wildcard we can just return a result now + if (_domainName == "*") + return replyWithIP(dnsHeader, query, queryLength); + + matchString = _domainName.c_str(); + + start = query; + + // If there's a leading 'www', skip it + if (*start == 3 && strncasecmp("www", (char *) start + 1, 3) == 0) + start += 4; + + while (*start != 0) { + labelLength = *start; + start += 1; + while (labelLength > 0) { + if (tolower(*start) != *matchString) + return replyWithError(dnsHeader, _errorReplyCode, + query, queryLength); + ++start; + ++matchString; + --labelLength; } + if (*start == 0 && *matchString == '\0') + return replyWithIP(dnsHeader, query, queryLength); - free(buffer); + if (*matchString != '.') + return replyWithError(dnsHeader, _errorReplyCode, + query, queryLength); + ++matchString; } -} -bool DNSServer::requestIncludesOnlyOneQuestion(const DNSHeader* dnsHeader) -{ - return lwip_ntohs(dnsHeader->QDCount) == 1 && - dnsHeader->ANCount == 0 && - dnsHeader->NSCount == 0 && - dnsHeader->ARCount == 0; + return replyWithError(dnsHeader, _errorReplyCode, + query, queryLength); } -String DNSServer::getDomainNameWithoutWwwPrefix(const uint8_t* buffer, size_t packetSize) +void DNSServer::processNextRequest() { - String parsedDomainName; - - const uint8_t* pos = buffer + sizeof(DNSHeader); - const uint8_t* end = buffer + packetSize; - - // to minimize reallocations due to concats below - // we reserve enough space that a median or average domain - // name size cold be easily contained without a reallocation - // - max size would be 253, in 2013, average is 11 and max was 42 - // - parsedDomainName.reserve(32); - - uint8_t labelLength = *pos; - - while (true) - { - if (labelLength == 0) - { - // no more labels - downcaseAndRemoveWwwPrefix(parsedDomainName); - return parsedDomainName; - } + size_t currentPacketSize; - // append next label - for (int i = 0; i < labelLength && pos < end; i++) - { - pos++; - parsedDomainName += static_cast(*pos); - } + currentPacketSize = _udp.parsePacket(); + if (currentPacketSize == 0) + return; - if (pos >= end) - { - // malformed packet, return an empty domain name - parsedDomainName = ""; - return parsedDomainName; - } - else - { - // next label - pos++; - labelLength = *pos; - - // if there is another label, add delimiter - if (labelLength != 0) - { - parsedDomainName += "."; - } - } - } + // The DNS RFC requires that DNS packets be less than 512 bytes in size, + // so just discard them if they are larger + if (currentPacketSize > MAX_DNS_PACKETSIZE) + return; + + // If the packet size is smaller than the DNS header, then someone is + // messing with us + if (currentPacketSize < DNS_HEADER_SIZE) + return; + + std::unique_ptr buffer(new (std::nothrow) uint8_t[currentPacketSize]); + + if (buffer == NULL) + return; + + _udp.read(buffer.get(), currentPacketSize); + respondToRequest(buffer.get(), currentPacketSize); +} + +void DNSServer::writeNBOShort(uint16_t value) +{ + _udp.write((unsigned char *)&value, 2); } -void DNSServer::replyWithIP(uint8_t* buffer, size_t packetSize) +void DNSServer::replyWithIP(DNSHeader *dnsHeader, + unsigned char * query, + size_t queryLength) { - DNSHeader* dnsHeader = reinterpret_cast(buffer); + uint16_t value; dnsHeader->QR = DNS_QR_RESPONSE; - dnsHeader->ANCount = dnsHeader->QDCount; - dnsHeader->QDCount = dnsHeader->QDCount; - //dnsHeader->RA = 1; + dnsHeader->QDCount = lwip_htons(1); + dnsHeader->ANCount = lwip_htons(1); + dnsHeader->NSCount = 0; + dnsHeader->ARCount = 0; _udp.beginPacket(_udp.remoteIP(), _udp.remotePort()); - _udp.write(buffer, packetSize); + _udp.write((unsigned char *) dnsHeader, sizeof(DNSHeader)); + _udp.write(query, queryLength); + + // Rather than restate the name here, we use a pointer to the name contained + // in the query section. Pointers have the top two bits set. + value = 0xC000 | DNS_HEADER_SIZE; + writeNBOShort(lwip_htons(value)); - _udp.write((uint8_t)192); // answer name is a pointer - _udp.write((uint8_t)12); // pointer to offset at 0x00c + // Answer is type A (an IPv4 address) + writeNBOShort(lwip_htons(DNS_QTYPE_A)); - _udp.write((uint8_t)0); // 0x0001 answer is type A query (host address) - _udp.write((uint8_t)1); + // Answer is in the Internet Class + writeNBOShort(lwip_htons(DNS_QCLASS_IN)); - _udp.write((uint8_t)0); //0x0001 answer is class IN (internet address) - _udp.write((uint8_t)1); - + // Output TTL (already NBO) _udp.write((unsigned char*)&_ttl, 4); // Length of RData is 4 bytes (because, in this case, RData is IPv4) - _udp.write((uint8_t)0); - _udp.write((uint8_t)4); + writeNBOShort(lwip_htons(sizeof(_resolvedIP))); _udp.write(_resolvedIP, sizeof(_resolvedIP)); _udp.endPacket(); - - #ifdef DEBUG_ESP_DNS - DEBUG_OUTPUT.printf("DNS responds: %s for %s\n", - IPAddress(_resolvedIP).toString().c_str(), getDomainNameWithoutWwwPrefix(buffer, packetSize).c_str() ); - #endif } -void DNSServer::replyWithCustomCode(uint8_t* buffer, size_t packetSize) +void DNSServer::replyWithError(DNSHeader *dnsHeader, + DNSReplyCode rcode, + unsigned char *query, + size_t queryLength) { - if (packetSize < sizeof(DNSHeader)) - { - return; - } - - DNSHeader* dnsHeader = reinterpret_cast(buffer); - dnsHeader->QR = DNS_QR_RESPONSE; - dnsHeader->RCode = (unsigned char)_errorReplyCode; - dnsHeader->QDCount = 0; + dnsHeader->RCode = (unsigned char) rcode; + if (query) + dnsHeader->QDCount = lwip_htons(1); + else + dnsHeader->QDCount = 0; + dnsHeader->ANCount = 0; + dnsHeader->NSCount = 0; + dnsHeader->ARCount = 0; _udp.beginPacket(_udp.remoteIP(), _udp.remotePort()); - _udp.write(buffer, sizeof(DNSHeader)); + _udp.write((unsigned char *)dnsHeader, sizeof(DNSHeader)); + if (query != NULL) + _udp.write(query, queryLength); _udp.endPacket(); } + +void DNSServer::replyWithError(DNSHeader *dnsHeader, + DNSReplyCode rcode) +{ + replyWithError(dnsHeader, rcode, NULL, 0); +} diff --git a/libraries/DNSServer/src/DNSServer.h b/libraries/DNSServer/src/DNSServer.h index d6e7de444d..0f3ebd7a34 100644 --- a/libraries/DNSServer/src/DNSServer.h +++ b/libraries/DNSServer/src/DNSServer.h @@ -6,7 +6,14 @@ #define DNS_QR_RESPONSE 1 #define DNS_OPCODE_QUERY 0 +#define DNS_QCLASS_IN 1 +#define DNS_QCLASS_ANY 255 + +#define DNS_QTYPE_A 1 +#define DNS_QTYPE_ANY 255 + #define MAX_DNSNAME_LENGTH 253 +#define MAX_DNS_PACKETSIZE 512 enum class DNSReplyCode { @@ -65,9 +72,16 @@ class DNSServer DNSReplyCode _errorReplyCode; void downcaseAndRemoveWwwPrefix(String &domainName); - String getDomainNameWithoutWwwPrefix(const uint8_t* buffer, size_t packetSize); - bool requestIncludesOnlyOneQuestion(const DNSHeader* dnsHeader); - void replyWithIP(uint8_t* buffer, size_t packetSize); - void replyWithCustomCode(uint8_t* buffer, size_t packetSize); + void replyWithIP(DNSHeader *dnsHeader, + unsigned char * query, + size_t queryLength); + void replyWithError(DNSHeader *dnsHeader, + DNSReplyCode rcode, + unsigned char *query, + size_t queryLength); + void replyWithError(DNSHeader *dnsHeader, + DNSReplyCode rcode); + void respondToRequest(uint8_t *buffer, size_t length); + void writeNBOShort(uint16_t value); }; -#endif \ No newline at end of file +#endif