Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added exception catches for the OkHTTPClient header vulnerability #3682

Merged
merged 15 commits into from
Aug 2, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@
import datadog.communication.monitor.Monitoring;
import datadog.communication.monitor.Recording;
import datadog.trace.util.Strings;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
Expand All @@ -16,12 +23,6 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class DDAgentFeaturesDiscovery implements DroppingPolicy {

Expand Down Expand Up @@ -70,13 +71,13 @@ public DDAgentFeaturesDiscovery(
boolean enableV05Traces,
boolean metricsEnabled) {
this.client = client;
this.agentBaseUrl = agentUrl;
agentBaseUrl = agentUrl;
this.metricsEnabled = metricsEnabled;
this.traceEndpoints =
traceEndpoints =
enableV05Traces
? new String[] {V5_ENDPOINT, V4_ENDPOINT, V3_ENDPOINT}
: new String[] {V4_ENDPOINT, V3_ENDPOINT};
this.discoveryTimer = monitoring.newTimer("trace.agent.discovery.time");
? new String[]{V5_ENDPOINT, V4_ENDPOINT, V3_ENDPOINT}
: new String[]{V4_ENDPOINT, V3_ENDPOINT};
discoveryTimer = monitoring.newTimer("trace.agent.discovery.time");
}

private void reset() {
Expand All @@ -97,12 +98,14 @@ public void discover() {
try (Recording recording = discoveryTimer.start()) {
boolean fallback = true;
try (Response response =
client
.newCall(new Request.Builder().url(agentBaseUrl.resolve("info").url()).build())
.execute()) {
client
.newCall(new Request.Builder().url(agentBaseUrl.resolve("info").url()).build())
.execute()) {
if (response.isSuccessful()) {
fallback = !processInfoResponse(response.body().string());
}
} catch (IllegalArgumentException e) {
nayeem-kamal marked this conversation as resolved.
Show resolved Hide resolved
nayeem-kamal marked this conversation as resolved.
Show resolved Hide resolved
illegalArgumentErrorQueryingEndpoint("info", getClass());
} catch (Throwable error) {
errorQueryingEndpoint("info", error);
}
Expand All @@ -119,7 +122,7 @@ public void discover() {
traceEndpoint = probeTracesEndpoint(traceEndpoints);
} else if (state == null || state.isEmpty()) {
// Still need to probe so that state is correctly assigned
probeTracesEndpoint(new String[] {traceEndpoint});
probeTracesEndpoint(new String[]{traceEndpoint});
}
}

Expand All @@ -136,25 +139,26 @@ public void discover() {
private String probeTracesEndpoint(String[] endpoints) {
for (String candidate : endpoints) {
try (Response response =
client
.newCall(
new Request.Builder()
.put(OkHttpUtils.msgpackRequestBodyOf(Collections.<ByteBuffer>emptyList()))
.url(agentBaseUrl.resolve(candidate))
.build())
.execute()) {
client
.newCall(
new Request.Builder()
.put(OkHttpUtils.msgpackRequestBodyOf(Collections.<ByteBuffer>emptyList()))
.url(agentBaseUrl.resolve(candidate))
.build())
.execute()) {
if (response.code() != 404) {
state = response.header(DATADOG_AGENT_STATE);
return candidate;
}
} catch (IOException e) {
errorQueryingEndpoint(candidate, e);
} catch (IllegalArgumentException e) {
nayeem-kamal marked this conversation as resolved.
Show resolved Hide resolved
illegalArgumentErrorQueryingEndpoint(candidate, getClass());
}
}
return V3_ENDPOINT;
}

@SuppressWarnings("unchecked")
private boolean processInfoResponse(String response) {
try {
Map<String, Object> map = RESPONSE_ADAPTER.fromJson(response);
Expand Down Expand Up @@ -209,7 +213,7 @@ private boolean processInfoResponse(String response) {
supportsDropping =
null != canDrop
&& ("true".equalsIgnoreCase(String.valueOf(canDrop))
|| Boolean.TRUE.equals(canDrop));
|| Boolean.TRUE.equals(canDrop));
}
try {
state = Strings.sha256(response);
Expand All @@ -223,7 +227,6 @@ private boolean processInfoResponse(String response) {
return false;
}

@SuppressWarnings("unchecked")
private static void discoverStatsDPort(final Map<String, Object> info) {
try {
Map<String, ?> config = (Map<String, ?>) info.get("config");
nayeem-kamal marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -274,6 +277,10 @@ private void errorQueryingEndpoint(String endpoint, Throwable t) {
log.debug("Error querying {} at {}", endpoint, agentBaseUrl, t);
}

private void illegalArgumentErrorQueryingEndpoint(String endpoint, Class c) {
log.debug("Error querying {} at {} from {}", endpoint, agentBaseUrl, c.getName());
}

public String state() {
return state;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,13 @@
import datadog.communication.ddagent.SharedCommunicationObjects;
import datadog.communication.http.OkHttpUtils;
import datadog.trace.util.AgentThreadFactory;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
Expand All @@ -17,12 +24,6 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CountDownLatch;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class FleetServiceImpl implements FleetService {

Expand All @@ -41,7 +42,7 @@ public class FleetServiceImpl implements FleetService {

public FleetServiceImpl(SharedCommunicationObjects sco, AgentThreadFactory agentThreadFactory) {
this.sco = sco;
this.thread = agentThreadFactory.newThread(new AgentConfigPollingRunnable());
thread = agentThreadFactory.newThread(new AgentConfigPollingRunnable());
}

@Override
Expand All @@ -59,12 +60,12 @@ public FleetService.FleetSubscription subscribe(

@Override
public void close() throws IOException {
this.thread.interrupt();
thread.interrupt();
try {
this.thread.join(5000);
thread.join(5000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.warn("Interrupted waiting for thread " + this.thread.getName() + "to join");
log.warn("Interrupted waiting for thread " + thread.getName() + "to join");
}
}

Expand All @@ -88,8 +89,8 @@ private class AgentConfigPollingRunnable implements Runnable {

@Override
public void run() {
this.okHttpClient = sco.okHttpClient;
this.httpUrl =
okHttpClient = sco.okHttpClient;
httpUrl =
sco.agentUrl.newBuilder().addPathSegment("v0.6").addPathSegment("config").build();

if (testingLatch != null) {
Expand Down Expand Up @@ -128,36 +129,41 @@ private boolean mainLoopIteration() throws InterruptedException {
}

private boolean fetchConfig(FleetSubscriptionImpl sub) {
Request request = OkHttpUtils.prepareRequest(httpUrl, sub.headers).get().build();
Response response;
try {
response = okHttpClient.newCall(request).execute();
} catch (IOException e) {
log.warn("IOException on HTTP class to fleet service", e);
return false;
}

if (response.code() == 200) {
byte[] body;
Request request = OkHttpUtils.prepareRequest(httpUrl, sub.headers).get().build();
Response response;
try {
body = consumeBody(response);
response = okHttpClient.newCall(request).execute();
} catch (IOException e) {
log.warn("IOException when reading fleet service response");
log.warn("IOException on HTTP class to fleet service", e);
return false;
}

digest.reset();
byte[] hash = digest.digest(body);
if (Arrays.equals(hash, sub.lastHash)) {
return true;
}
if (response.code() == 200) {
byte[] body;
try {
body = consumeBody(response);
} catch (IOException e) {
log.warn("IOException when reading fleet service response");
return false;
}

sub.lastHash = hash;
sub.listener.onNewConfiguration(new ByteArrayInputStream(body));
digest.reset();
byte[] hash = digest.digest(body);
if (Arrays.equals(hash, sub.lastHash)) {
return true;
}

return true;
} else {
log.warn("FleetService: agent responded with code " + response.code());
sub.lastHash = hash;
sub.listener.onNewConfiguration(new ByteArrayInputStream(body));

return true;
} else {
log.warn("FleetService: agent responded with code " + response.code());
return false;
}
} catch (IllegalArgumentException e) {
nayeem-kamal marked this conversation as resolved.
Show resolved Hide resolved
log.warn("Illegal argument exception in {}: fetchConfig()", getClass().getName());
return false;
}
}
Expand All @@ -181,7 +187,7 @@ private void failureWait() {
waitSeconds =
BACKOFF_INITIAL
* Math.pow(
BACKOFF_BASE, Math.min((double) consecutiveFailures - 1, BACKOFF_MAX_EXPONENT));
BACKOFF_BASE, Math.min((double) consecutiveFailures - 1, BACKOFF_MAX_EXPONENT));
if (testingLatch != null && testingLatch.getCount() > 0) {
waitSeconds = 0;
}
Expand Down Expand Up @@ -223,7 +229,7 @@ private class FleetSubscriptionImpl implements FleetService.FleetSubscription {

private FleetSubscriptionImpl(Product product, ConfigurationListener listener) {
this.product = product;
this.headers = Collections.singletonMap(CONFIG_PRODUCT_HEADER, product.name());
headers = Collections.singletonMap(CONFIG_PRODUCT_HEADER, product.name());
devinsba marked this conversation as resolved.
Show resolved Hide resolved
this.listener = listener;
}

Expand Down
Loading