Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added exception catches for the OkHTTPClient header vulnerability #3682

Merged
merged 15 commits into from
Aug 2, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ public DDAgentFeaturesDiscovery(
boolean enableV05Traces,
boolean metricsEnabled) {
this.client = client;
this.agentBaseUrl = agentUrl;
agentBaseUrl = agentUrl;
this.metricsEnabled = metricsEnabled;
this.traceEndpoints =
traceEndpoints =
enableV05Traces
? new String[] {V5_ENDPOINT, V4_ENDPOINT, V3_ENDPOINT}
: new String[] {V4_ENDPOINT, V3_ENDPOINT};
this.discoveryTimer = monitoring.newTimer("trace.agent.discovery.time");
discoveryTimer = monitoring.newTimer("trace.agent.discovery.time");
}

private void reset() {
Expand Down Expand Up @@ -154,7 +154,6 @@ private String probeTracesEndpoint(String[] endpoints) {
return V3_ENDPOINT;
}

@SuppressWarnings("unchecked")
private boolean processInfoResponse(String response) {
try {
Map<String, Object> map = RESPONSE_ADAPTER.fromJson(response);
Expand Down Expand Up @@ -223,7 +222,6 @@ private boolean processInfoResponse(String response) {
return false;
}

@SuppressWarnings("unchecked")
private static void discoverStatsDPort(final Map<String, Object> info) {
try {
Map<String, ?> config = (Map<String, ?>) info.get("config");
nayeem-kamal marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class FleetServiceImpl implements FleetService {

public FleetServiceImpl(SharedCommunicationObjects sco, AgentThreadFactory agentThreadFactory) {
this.sco = sco;
this.thread = agentThreadFactory.newThread(new AgentConfigPollingRunnable());
thread = agentThreadFactory.newThread(new AgentConfigPollingRunnable());
}

@Override
Expand All @@ -59,12 +59,12 @@ public FleetService.FleetSubscription subscribe(

@Override
public void close() throws IOException {
this.thread.interrupt();
thread.interrupt();
try {
this.thread.join(5000);
thread.join(5000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
log.warn("Interrupted waiting for thread " + this.thread.getName() + "to join");
log.warn("Interrupted waiting for thread " + thread.getName() + "to join");
}
}

Expand All @@ -88,9 +88,8 @@ private class AgentConfigPollingRunnable implements Runnable {

@Override
public void run() {
this.okHttpClient = sco.okHttpClient;
this.httpUrl =
sco.agentUrl.newBuilder().addPathSegment("v0.6").addPathSegment("config").build();
okHttpClient = sco.okHttpClient;
httpUrl = sco.agentUrl.newBuilder().addPathSegment("v0.6").addPathSegment("config").build();

if (testingLatch != null) {
testingLatch.countDown();
Expand Down Expand Up @@ -128,6 +127,7 @@ private boolean mainLoopIteration() throws InterruptedException {
}

private boolean fetchConfig(FleetSubscriptionImpl sub) {

Request request = OkHttpUtils.prepareRequest(httpUrl, sub.headers).get().build();
Response response;
try {
Expand Down Expand Up @@ -223,7 +223,7 @@ private class FleetSubscriptionImpl implements FleetService.FleetSubscription {

private FleetSubscriptionImpl(Product product, ConfigurationListener listener) {
this.product = product;
this.headers = Collections.singletonMap(CONFIG_PRODUCT_HEADER, product.name());
headers = Collections.singletonMap(CONFIG_PRODUCT_HEADER, product.name());
devinsba marked this conversation as resolved.
Show resolved Hide resolved
this.listener = listener;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,16 @@ private static OkHttpClient buildHttpClient(
public Request authenticate(final Route route, final Response response) {
final String credential =
Credentials.basic(proxyUsername, proxyPassword == null ? "" : proxyPassword);
return response
.request()
.newBuilder()
.header("Proxy-Authorization", credential)
.build();
try {
return response
.request()
.newBuilder()
.header("Proxy-Authorization", credential)
.build();
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException(
nayeem-kamal marked this conversation as resolved.
Show resolved Hide resolved
"IllegalArgumentException at Proxy-Authorization header");
}
}
});
}
Expand All @@ -175,24 +180,28 @@ public Request authenticate(final Route route, final Response response) {
}

public static Request.Builder prepareRequest(final HttpUrl url, Map<String, String> headers) {
final Request.Builder builder =
new Request.Builder()
.url(url)
.addHeader(DATADOG_META_LANG, "java")
.addHeader(DATADOG_META_LANG_VERSION, JAVA_VERSION)
.addHeader(DATADOG_META_LANG_INTERPRETER, JAVA_VM_NAME)
.addHeader(DATADOG_META_LANG_INTERPRETER_VENDOR, JAVA_VM_VENDOR);

final String containerId = ContainerInfo.get().getContainerId();
if (containerId != null) {
builder.addHeader(DATADOG_CONTAINER_ID, containerId);
}
try {
final Request.Builder builder =
new Request.Builder()
.url(url)
.addHeader(DATADOG_META_LANG, "java")
.addHeader(DATADOG_META_LANG_VERSION, JAVA_VERSION)
.addHeader(DATADOG_META_LANG_INTERPRETER, JAVA_VM_NAME)
.addHeader(DATADOG_META_LANG_INTERPRETER_VENDOR, JAVA_VM_VENDOR);

final String containerId = ContainerInfo.get().getContainerId();
if (containerId != null) {
builder.addHeader(DATADOG_CONTAINER_ID, containerId);
}

for (Map.Entry<String, String> e : headers.entrySet()) {
builder.addHeader(e.getKey(), e.getValue());
}
for (Map.Entry<String, String> e : headers.entrySet()) {
builder.addHeader(e.getKey(), e.getValue());
}

return builder;
return builder;
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException();
nayeem-kamal marked this conversation as resolved.
Show resolved Hide resolved
}
}

public static Request.Builder prepareRequest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ public BatchUploader(Config config) {
// see: https://github.com/DataDog/dd-trace-java/pull/1582
clientBuilder.connectionSpecs(Collections.singletonList(ConnectionSpec.CLEARTEXT));
}
client = clientBuilder.build();
try {
client = clientBuilder.build();
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException();
devinsba marked this conversation as resolved.
Show resolved Hide resolved
}
client.dispatcher().setMaxRequests(MAX_RUNNING_REQUESTS);
// We are mainly talking to the same(ish) host so we need to raise this limit
client.dispatcher().setMaxRequestsPerHost(MAX_RUNNING_REQUESTS);
Expand Down Expand Up @@ -151,28 +155,32 @@ private void makeUploadRequest(byte[] json, String tags) throws IOException {
if (!tags.isEmpty()) {
builder.addQueryParameter("ddtags", tags);
}
Request.Builder requestBuilder = new Request.Builder().url(builder.build()).post(body);
if (apiKey != null) {
if (apiKey.isEmpty()) {
log.debug("API key is empty");
try {
Request.Builder requestBuilder = new Request.Builder().url(builder.build()).post(body);
if (apiKey != null) {
if (apiKey.isEmpty()) {
log.debug("API key is empty");
}
if (apiKey.length() != 32) {
log.debug(
"API key length is incorrect (truncated?) expected=32 actual={} API key={}...",
apiKey.length(),
apiKey.substring(0, Math.min(apiKey.length(), 6)));
}
requestBuilder.addHeader(HEADER_DD_API_KEY, apiKey);
nayeem-kamal marked this conversation as resolved.
Show resolved Hide resolved
} else {
log.debug("API key is null");
}
if (apiKey.length() != 32) {
log.debug(
"API key length is incorrect (truncated?) expected=32 actual={} API key={}...",
apiKey.length(),
apiKey.substring(0, Math.min(apiKey.length(), 6)));
if (containerId != null) {
requestBuilder.addHeader(HEADER_DD_CONTAINER_ID, containerId);
}
requestBuilder.addHeader(HEADER_DD_API_KEY, apiKey);
} else {
log.debug("API key is null");
}
if (containerId != null) {
requestBuilder.addHeader(HEADER_DD_CONTAINER_ID, containerId);
Request request = requestBuilder.build();
log.debug("Sending request: {} CT: {}", request, request.body().contentType());
client.newCall(request).enqueue(responseCallback);
inflightRequests.register();
} catch (IllegalArgumentException e) {
throw new IllegalArgumentException();
}
Request request = requestBuilder.build();
log.debug("Sending request: {} CT: {}", request, request.body().contentType());
client.newCall(request).enqueue(responseCallback);
inflightRequests.register();
}

public void shutdown() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ public OkHttpSink(
boolean compressionEnabled,
Map<String, String> headers) {
this.client = client;
this.metricsUrl = HttpUrl.get(agentUrl).resolve(path);
this.listeners = new CopyOnWriteArrayList<>();
metricsUrl = HttpUrl.get(agentUrl).resolve(path);
listeners = new CopyOnWriteArrayList<>();
this.bufferingEnabled = bufferingEnabled;
this.compressionEnabled = compressionEnabled;
this.headers = new HashMap<>(headers);
Expand All @@ -71,6 +71,7 @@ public void accept(int messageCount, ByteBuffer buffer) {
// without copying the buffer, otherwise this needs to be async,
// so need to copy and buffer the request, and let it be executed
// on the main task scheduler as a last resort

if (!bufferingEnabled || lastRequestTime.get() < ASYNC_THRESHOLD_LATENCY) {
send(prepareRequest(metricsUrl, headers).post(makeRequestBody(buffer)).build());
AgentTaskScheduler.Scheduled<OkHttpSink> future = this.future;
Expand All @@ -82,7 +83,7 @@ public void accept(int messageCount, ByteBuffer buffer) {
}
} else {
if (asyncTaskStarted.compareAndSet(false, true)) {
this.future =
future =
AgentTaskScheduler.INSTANCE.scheduleAtFixedRate(
new Sender(enqueuedRequests), this, 1, 1, SECONDS);
}
Expand Down Expand Up @@ -141,7 +142,7 @@ public void onEvent(EventListener.EventType eventType, String message) {

@Override
public void register(EventListener listener) {
this.listeners.add(listener);
listeners.add(listener);
}

private void handleFailure(okhttp3.Response response) throws IOException {
devinsba marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,24 @@ public DDAgentApi(
boolean metricsEnabled) {
this.featuresDiscovery = featuresDiscovery;
this.agentUrl = agentUrl;
this.httpClient = client;
this.sendPayloadTimer = monitoring.newTimer("trace.agent.send.time");
this.agentErrorCounter = monitoring.newCounter("trace.agent.error.counter");
httpClient = client;
sendPayloadTimer = monitoring.newTimer("trace.agent.send.time");
agentErrorCounter = monitoring.newCounter("trace.agent.error.counter");
this.metricsEnabled = metricsEnabled;

this.headers = new HashMap<>();
this.headers.put(DATADOG_CLIENT_COMPUTED_TOP_LEVEL, "true");
this.headers.put(DATADOG_META_TRACER_VERSION, DDTraceCoreInfo.VERSION);
headers = new HashMap<>();
headers.put(DATADOG_CLIENT_COMPUTED_TOP_LEVEL, "true");
headers.put(DATADOG_META_TRACER_VERSION, DDTraceCoreInfo.VERSION);
}

@Override
public void addResponseListener(final RemoteResponseListener listener) {
if (!responseListeners.contains(listener)) {
responseListeners.add(listener);
}
}

@Override
public Response sendSerializedTraces(final Payload payload) {
final int sizeInBytes = payload.sizeInBytes();
String tracesEndpoint = featuresDiscovery.getTraceEndpoint();
Expand All @@ -118,8 +120,8 @@ public Response sendSerializedTraces(final Payload payload) {
metricsEnabled && featuresDiscovery.supportsMetrics() ? "true" : "")
.put(payload.toRequest())
.build();
this.totalTraces += payload.traceCount();
this.receivedTraces += payload.traceCount();
totalTraces += payload.traceCount();
receivedTraces += payload.traceCount();
try (final Recording recording = sendPayloadTimer.start();
final okhttp3.Response response = httpClient.newCall(request).execute()) {
handleAgentChange(response.header(DATADOG_AGENT_STATE));
Expand Down Expand Up @@ -161,7 +163,7 @@ private void handleAgentChange(String state) {

private void countAndLogSuccessfulSend(final int traceCount, final int sizeInBytes) {
// count the successful traces
this.sentTraces += traceCount;
sentTraces += traceCount;

ioLogger.success(createSendLogMessage(traceCount, sizeInBytes, "Success"));
}
Expand All @@ -172,7 +174,7 @@ private void countAndLogFailedSend(
final okhttp3.Response response,
final IOException outer) {
// count the failed traces
this.failedTraces += traceCount;
failedTraces += traceCount;
// these are used to catch and log if there is a failure in debug logging the response body
String agentError = getResponseBody(response);
String sendErrorString =
Expand Down Expand Up @@ -209,13 +211,13 @@ private String createSendLogMessage(
+ ")"
+ " traces to the DD agent."
+ " Total: "
+ this.totalTraces
+ totalTraces
+ ", Received: "
+ this.receivedTraces
+ receivedTraces
+ ", Sent: "
+ this.sentTraces
+ sentTraces
+ ", Failed: "
+ this.failedTraces
+ failedTraces
+ ".";
}
}
devinsba marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ public RemoteApi.Response sendSerializedTraces(Payload payload) {
.addHeader(DD_API_KEY_HEADER, apiKey)
.post(payload.toRequest())
.build();
this.totalTraces += payload.traceCount();
this.receivedTraces += payload.traceCount();
totalTraces += payload.traceCount();
receivedTraces += payload.traceCount();

int httpCode = 0;
IOException lastException = null;
Expand Down Expand Up @@ -172,6 +172,8 @@ public RemoteApi.Response sendSerializedTraces(Payload payload) {
return RemoteApi.Response.success(httpCode);
}
}
} catch (IllegalArgumentException e) {
devinsba marked this conversation as resolved.
Show resolved Hide resolved
throw new IllegalArgumentException();
} catch (final IOException e) {
countAndLogFailedSend(payload.traceCount(), sizeInBytes, null, e);
return RemoteApi.Response.failed(e);
Expand All @@ -180,7 +182,7 @@ public RemoteApi.Response sendSerializedTraces(Payload payload) {

private void countAndLogSuccessfulSend(final int traceCount, final int sizeInBytes) {
// count the successful traces
this.sentTraces += traceCount;
sentTraces += traceCount;

ioLogger.success(createSendLogMessage(traceCount, sizeInBytes, "Success"));
}
Expand All @@ -191,7 +193,7 @@ private void countAndLogFailedSend(
final DDIntakeApi.Response response,
final IOException outer) {
// count the failed traces
this.failedTraces += traceCount;
failedTraces += traceCount;
// these are used to catch and log if there is a failure in debug logging the response body
String intakeError = response != null ? response.body : "";
String sendErrorString =
Expand Down Expand Up @@ -229,13 +231,13 @@ private String createSendLogMessage(
+ ")"
+ " traces to the DD Intake."
+ " Total: "
+ this.totalTraces
+ totalTraces
+ ", Received: "
+ this.receivedTraces
+ receivedTraces
+ ", Sent: "
+ this.sentTraces
+ sentTraces
+ ", Failed: "
+ this.failedTraces
+ failedTraces
+ ".";
}

devinsba marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Loading