Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize MDNS to prevent overflow and endless loop #2333

Merged
merged 5 commits into from
Aug 1, 2016
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 48 additions & 14 deletions libraries/ESP8266mDNS/ESP8266mDNS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,11 +493,11 @@ void MDNSResponder::_parsePacket(){
return;
}

int numAnswers = packetHeader[3];
int numAnswers = packetHeader[3] + packetHeader[5];
// Assume that the PTR answer always comes first and that it is always accompanied by a TXT, SRV, AAAA (optional) and A answer in the same packet.
if (numAnswers < 4) {
#ifdef MDNS_DEBUG_RX
Serial.println("Expected a packet with 4 answers, returning");
Serial.printf("Expected a packet with 4 or more answers, got %u\n", numAnswers);
#endif
_conn->flush();
return;
Expand All @@ -510,11 +510,14 @@ void MDNSResponder::_parsePacket(){
bool serviceMatch = false;
MDNSAnswer *answer;
uint8_t partsCollected = 0;
uint8_t stringsRead = 0;

answerHostName[0] = '\0';

// Clear answer list
if (_newQuery) {
int numAnswers = _getNumAnswers();
for (int n = numAnswers - 1; n >= 0; n--) {
int oldAnswers = _getNumAnswers();
for (int n = oldAnswers - 1; n >= 0; n--) {
answer = _getAnswerFromIdx(n);
os_free(answer->hostname);
os_free(answer);
Expand All @@ -526,21 +529,29 @@ void MDNSResponder::_parsePacket(){

while (numAnswers--) {
// Read name
stringsRead = 0;
do {
tmp8 = _conn_read8();
if (tmp8 & 0xC0) { // Compressed pointer (not supported)
tmp8 = _conn_read8();
break;
}
if (tmp8 == 0x00) { // �nd of name
if (tmp8 == 0x00) { // End of name
break;
}
if(stringsRead > 3){
#ifdef MDNS_DEBUG_RX
Serial.println("failed to read the response name");
#endif
_conn->flush();
return;
}
_conn_readS(serviceName, tmp8);
serviceName[tmp8] = '\0';
#ifdef MDNS_DEBUG_RX
Serial.printf(" %d ", tmp8);
for (int n = 0; n < tmp8; n++) {
Serial.printf("%02x ", serviceName[n]);
Serial.printf("%c", serviceName[n]);
}
Serial.println();
#endif
Expand All @@ -552,23 +563,41 @@ void MDNSResponder::_parsePacket(){
#endif
}
}
stringsRead++;
} while (true);

uint16_t answerType = _conn_read16(); // Read type
uint16_t answerClass = _conn_read16(); // Read class
uint32_t answerTtl = _conn_read32(); // Read ttl
uint16_t answerRdlength = _conn_read16(); // Read rdlength

if(answerRdlength > 255){
if(answerType == MDNS_TYPE_TXT && answerRdlength < 1460){
while(--answerRdlength) _conn->read();
} else {
#ifdef MDNS_DEBUG_RX
Serial.printf("Data len too long! %u\n", answerRdlength);
#endif
_conn->flush();
return;
}
}

#ifdef MDNS_DEBUG_RX
Serial.printf("type: %04x rdlength: %d\n", answerType, answerRdlength);
#endif

if (answerType == MDNS_TYPE_PTR) {
partsCollected |= 0x01;
_conn_readS(hostName, answerRdlength); // Read rdata
if(hostName[answerRdlength-2] & 0xc0){
memcpy(answerHostName, hostName+1, answerRdlength-3);
answerHostName[answerRdlength-3] = '\0';
}
#ifdef MDNS_DEBUG_RX
Serial.printf("PTR %d ", answerRdlength);
for (int n = 0; n < answerRdlength; n++) {
Serial.printf("%02x ", hostName[n]);
Serial.printf("%c", hostName[n]);
}
Serial.println();
#endif
Expand All @@ -578,8 +607,9 @@ void MDNSResponder::_parsePacket(){
partsCollected |= 0x02;
_conn_readS(hostName, answerRdlength); // Read rdata
#ifdef MDNS_DEBUG_RX
Serial.printf("TXT %d ", answerRdlength);
for (int n = 0; n < answerRdlength; n++) {
Serial.printf("%02x ", hostName[n]);
Serial.printf("%c", hostName[n]);
}
Serial.println();
#endif
Expand All @@ -594,14 +624,16 @@ void MDNSResponder::_parsePacket(){
// Read hostname
tmp8 = _conn_read8();
if (tmp8 & 0xC0) { // Compressed pointer (not supported)
#ifdef MDNS_DEBUG_RX
Serial.println("Skipping compressed pointer");
#endif
tmp8 = _conn_read8();
}
else {
_conn_readS(answerHostName, tmp8);
answerHostName[tmp8] = '\0';
#ifdef MDNS_DEBUG_RX
Serial.printf(" %d ", tmp8);
Serial.printf("SRV %d ", tmp8);
for (int n = 0; n < tmp8; n++) {
Serial.printf("%02x ", answerHostName[n]);
}
Expand All @@ -621,7 +653,7 @@ void MDNSResponder::_parsePacket(){
}
else {
#ifdef MDNS_DEBUG_RX
Serial.printf("Ignoring unsupported type %d\n", tmp8);
Serial.printf("Ignoring unsupported type %02x\n", tmp8);
#endif
for (int n = 0; n < answerRdlength; n++)
(void)_conn_read8();
Expand Down Expand Up @@ -654,6 +686,8 @@ void MDNSResponder::_parsePacket(){
}
answer->hostname = (char *)os_malloc(strlen(answerHostName) + 1);
os_strcpy(answer->hostname, answerHostName);
_conn->flush();
return;
}
}

Expand All @@ -663,7 +697,7 @@ void MDNSResponder::_parsePacket(){

// PARSE REQUEST NAME

hostNameLen = _conn_read8();
hostNameLen = _conn_read8() % 255;
_conn_readS(hostName, hostNameLen);
hostName[hostNameLen] = '\0';

Expand All @@ -685,7 +719,7 @@ void MDNSResponder::_parsePacket(){
}

if(!serviceParsed){
serviceNameLen = _conn_read8();
serviceNameLen = _conn_read8() % 255;
_conn_readS(serviceName, serviceNameLen);
serviceName[serviceNameLen] = '\0';

Expand Down Expand Up @@ -718,7 +752,7 @@ void MDNSResponder::_parsePacket(){
}

if(!protoParsed){
protoNameLen = _conn_read8();
protoNameLen = _conn_read8() % 255;
_conn_readS(protoName, protoNameLen);
protoName[protoNameLen] = '\0';
if(protoNameLen == 4 && protoName[0] == '_'){
Expand All @@ -740,7 +774,7 @@ void MDNSResponder::_parsePacket(){

if(!localParsed){
char localName[32];
uint8_t localNameLen = _conn_read8();
uint8_t localNameLen = _conn_read8() % 31;
_conn_readS(localName, localNameLen);
localName[localNameLen] = '\0';
tmp = _conn_read8();
Expand Down