diff --git a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighers.java b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighers.java index 44874ed6b86..6e0a36be0db 100644 --- a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighers.java +++ b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighers.java @@ -32,6 +32,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Suppliers; import com.palantir.atlasdb.keyvalue.cassandra.CassandraClient; import com.palantir.atlasdb.qos.ImmutableQueryWeight; @@ -42,7 +43,8 @@ public final class ThriftQueryWeighers { private static final Logger log = LoggerFactory.getLogger(CassandraClient.class); - public static final QueryWeight DEFAULT_ESTIMATED_WEIGHT = ImmutableQueryWeight.builder() + @VisibleForTesting + static final QueryWeight DEFAULT_ESTIMATED_WEIGHT = ImmutableQueryWeight.builder() .numBytes(100) .numDistinctRows(1) .timeTakenNanos(TimeUnit.MILLISECONDS.toNanos(2)) @@ -86,13 +88,21 @@ public QueryWeight estimate() { } @Override - public QueryWeight weigh(T result, long timeTakenNanos) { + public QueryWeight weighSuccess(T result, long timeTakenNanos) { return ImmutableQueryWeight.builder() .numBytes(safeGetNumBytesOrDefault(() -> bytesRead.apply(result))) .timeTakenNanos(timeTakenNanos) .numDistinctRows(numRows.apply(result)) .build(); } + + @Override + public QueryWeight weighFailure(Exception error, long timeTakenNanos) { + return ImmutableQueryWeight.builder() + .from(estimate()) + .timeTakenNanos(timeTakenNanos) + .build(); + } }; } @@ -110,7 +120,15 @@ public QueryWeight estimate() { } @Override - public QueryWeight weigh(T result, long timeTakenNanos) { + public QueryWeight weighSuccess(T result, long timeTakenNanos) { + return ImmutableQueryWeight.builder() + .from(estimate()) + .timeTakenNanos(timeTakenNanos) + .build(); + } + + @Override + public QueryWeight weighFailure(Exception error, long timeTakenNanos) { return ImmutableQueryWeight.builder() .from(estimate()) .timeTakenNanos(timeTakenNanos) diff --git a/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighersTest.java b/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighersTest.java index b26f146263c..2ee01172eb8 100644 --- a/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighersTest.java +++ b/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighersTest.java @@ -32,6 +32,9 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.palantir.atlasdb.qos.ImmutableQueryWeight; +import com.palantir.atlasdb.qos.QosClient; +import com.palantir.atlasdb.qos.QueryWeight; public class ThriftQueryWeighersTest { @@ -42,7 +45,12 @@ public class ThriftQueryWeighersTest { private static final KeySlice KEY_SLICE = new KeySlice(); private static final Mutation MUTATION = new Mutation(); - private static final long UNIMPORTANT_ARG = 123L; + private static final long TIME_TAKEN = 123L; + + private static final QueryWeight DEFAULT_WEIGHT = ImmutableQueryWeight.builder() + .from(ThriftQueryWeighers.DEFAULT_ESTIMATED_WEIGHT) + .timeTakenNanos(TIME_TAKEN) + .build(); @Test public void multigetSliceWeigherReturnsCorrectNumRows() { @@ -50,7 +58,7 @@ public void multigetSliceWeigherReturnsCorrectNumRows() { BYTES1, ImmutableList.of(COLUMN_OR_SUPER, COLUMN_OR_SUPER), BYTES2, ImmutableList.of(COLUMN_OR_SUPER)); - long actualNumRows = ThriftQueryWeighers.MULTIGET_SLICE.weigh(result, UNIMPORTANT_ARG).numDistinctRows(); + long actualNumRows = ThriftQueryWeighers.MULTIGET_SLICE.weighSuccess(result, TIME_TAKEN).numDistinctRows(); assertThat(actualNumRows).isEqualTo(2); } @@ -59,30 +67,30 @@ public void multigetSliceWeigherReturnsCorrectNumRows() { public void rangeSlicesWeigherReturnsCorrectNumRows() { List result = ImmutableList.of(KEY_SLICE, KEY_SLICE, KEY_SLICE); - long actualNumRows = ThriftQueryWeighers.GET_RANGE_SLICES.weigh(result, UNIMPORTANT_ARG).numDistinctRows(); + long actualNumRows = ThriftQueryWeighers.GET_RANGE_SLICES.weighSuccess(result, TIME_TAKEN).numDistinctRows(); assertThat(actualNumRows).isEqualTo(3); } @Test public void getWeigherReturnsCorrectNumRows() { - long actualNumRows = ThriftQueryWeighers.GET.weigh(COLUMN_OR_SUPER, UNIMPORTANT_ARG).numDistinctRows(); + long actualNumRows = ThriftQueryWeighers.GET.weighSuccess(COLUMN_OR_SUPER, TIME_TAKEN).numDistinctRows(); assertThat(actualNumRows).isEqualTo(1); } @Test public void executeCql3QueryWeigherReturnsOneRowAlways() { - long actualNumRows = ThriftQueryWeighers.EXECUTE_CQL3_QUERY.weigh(new CqlResult(), - UNIMPORTANT_ARG).numDistinctRows(); + long actualNumRows = ThriftQueryWeighers.EXECUTE_CQL3_QUERY.weighSuccess(new CqlResult(), + TIME_TAKEN).numDistinctRows(); assertThat(actualNumRows).isEqualTo(1); } @Test public void casQueryWeigherReturnsOneRowAlways() { - long actualNumRows = ThriftQueryWeighers.cas(ImmutableList.of(COLUMN, COLUMN)).weigh(new CASResult(true), - UNIMPORTANT_ARG).numDistinctRows(); + long actualNumRows = ThriftQueryWeighers.cas(ImmutableList.of(COLUMN, COLUMN)).weighSuccess(new CASResult(true), + TIME_TAKEN).numDistinctRows(); assertThat(actualNumRows).isEqualTo(1); } @@ -96,10 +104,68 @@ public void batchMutateWeigherReturnsCorrectNumRows() { BYTES2, ImmutableMap.of( "baz", ImmutableList.of(MUTATION))); - long actualNumRows = ThriftQueryWeighers.batchMutate(mutations).weigh(null, UNIMPORTANT_ARG) + long actualNumRows = ThriftQueryWeighers.batchMutate(mutations).weighSuccess(null, TIME_TAKEN) .numDistinctRows(); assertThat(actualNumRows).isEqualTo(3); } + @Test + public void multigetSliceWeigherReturnsDefaultEstimateForFailure() { + QueryWeight weight = ThriftQueryWeighers.MULTIGET_SLICE.weighFailure(new RuntimeException(), TIME_TAKEN); + + assertThat(weight).isEqualTo(DEFAULT_WEIGHT); + } + + @Test + public void getWeigherReturnsDefaultEstimateForFailure() { + QueryWeight weight = ThriftQueryWeighers.GET.weighFailure(new RuntimeException(), TIME_TAKEN); + + assertThat(weight).isEqualTo(DEFAULT_WEIGHT); + } + + @Test + public void getRangeSlicesWeigherReturnsDefaultEstimateForFailure() { + QueryWeight weight = ThriftQueryWeighers.GET_RANGE_SLICES.weighFailure(new RuntimeException(), TIME_TAKEN); + + assertThat(weight).isEqualTo(DEFAULT_WEIGHT); + } + + @Test + public void batchMutateWeigherReturnsEstimateForFailure() { + Map>> mutations = ImmutableMap.of( + BYTES1, ImmutableMap.of("foo", ImmutableList.of(MUTATION, MUTATION))); + + QosClient.QueryWeigher weigher = ThriftQueryWeighers.batchMutate(mutations); + + QueryWeight expected = ImmutableQueryWeight.builder() + .from(weigher.estimate()) + .timeTakenNanos(TIME_TAKEN) + .build(); + QueryWeight actual = weigher.weighFailure(new RuntimeException(), TIME_TAKEN); + + assertThat(actual).isEqualTo(expected); + } + + @Test + public void casWeigherReturnsEstimateForFailure() { + QosClient.QueryWeigher weigher = ThriftQueryWeighers.cas(ImmutableList.of(COLUMN, COLUMN)); + + QueryWeight expected = ImmutableQueryWeight.builder() + .from(weigher.estimate()) + .timeTakenNanos(TIME_TAKEN) + .build(); + QueryWeight actual = weigher.weighFailure(new RuntimeException(), TIME_TAKEN); + + assertThat(actual).isEqualTo(expected); + } + + @Test + public void cql3QueryWeigherReturnsDefaultEstimateForFailure() { + QueryWeight weight = ThriftQueryWeighers.EXECUTE_CQL3_QUERY.weighFailure(new RuntimeException(), + TIME_TAKEN); + + assertThat(weight).isEqualTo(DEFAULT_WEIGHT); + } + } diff --git a/qos-service-api/src/main/java/com/palantir/atlasdb/qos/QosClient.java b/qos-service-api/src/main/java/com/palantir/atlasdb/qos/QosClient.java index 775448b96a8..3c5be52a3bd 100644 --- a/qos-service-api/src/main/java/com/palantir/atlasdb/qos/QosClient.java +++ b/qos-service-api/src/main/java/com/palantir/atlasdb/qos/QosClient.java @@ -24,7 +24,8 @@ interface Query { interface QueryWeigher { QueryWeight estimate(); - QueryWeight weigh(T result, long timeTakenNanos); + QueryWeight weighSuccess(T result, long timeTakenNanos); + QueryWeight weighFailure(Exception error, long timeTakenNanos); } T executeRead( diff --git a/qos-service-impl/src/main/java/com/palantir/atlasdb/qos/client/AtlasDbQosClient.java b/qos-service-impl/src/main/java/com/palantir/atlasdb/qos/client/AtlasDbQosClient.java index 6732368940f..6e4624c8ee0 100644 --- a/qos-service-impl/src/main/java/com/palantir/atlasdb/qos/client/AtlasDbQosClient.java +++ b/qos-service-impl/src/main/java/com/palantir/atlasdb/qos/client/AtlasDbQosClient.java @@ -15,12 +15,14 @@ */ package com.palantir.atlasdb.qos.client; +import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Stopwatch; import com.google.common.base.Ticker; import com.palantir.atlasdb.qos.QosClient; import com.palantir.atlasdb.qos.QueryWeight; @@ -65,16 +67,20 @@ private T execute( long estimatedNumBytes = weigher.estimate().numBytes(); rateLimiter.consumeWithBackoff(estimatedNumBytes); - // TODO(nziebart): decide what to do if we encounter a timeout exception - long startTimeNanos = ticker.read(); - T result = query.execute(); - long totalTimeNanos = ticker.read() - startTimeNanos; + Stopwatch timer = Stopwatch.createStarted(ticker); - QueryWeight actualWeight = weigher.weigh(result, totalTimeNanos); - weightMetric.accept(actualWeight); - rateLimiter.recordAdjustment(actualWeight.numBytes() - estimatedNumBytes); - - return result; + QueryWeight actualWeight = null; + try { + T result = query.execute(); + actualWeight = weigher.weighSuccess(result, timer.elapsed(TimeUnit.NANOSECONDS)); + return result; + } catch (Exception ex) { + actualWeight = weigher.weighFailure(ex, timer.elapsed(TimeUnit.NANOSECONDS)); + throw ex; + } finally { + weightMetric.accept(actualWeight); + rateLimiter.recordAdjustment(actualWeight.numBytes() - estimatedNumBytes); + } } } diff --git a/qos-service-impl/src/test/java/com/palantir/atlasdb/qos/client/AtlasDbQosClientTest.java b/qos-service-impl/src/test/java/com/palantir/atlasdb/qos/client/AtlasDbQosClientTest.java index fc5d6f94791..2e347874c22 100644 --- a/qos-service-impl/src/test/java/com/palantir/atlasdb/qos/client/AtlasDbQosClientTest.java +++ b/qos-service-impl/src/test/java/com/palantir/atlasdb/qos/client/AtlasDbQosClientTest.java @@ -20,6 +20,7 @@ import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyLong; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; @@ -72,7 +73,9 @@ public void setUp() { when(ticker.read()).thenReturn(START_NANOS).thenReturn(END_NANOS); when(weigher.estimate()).thenReturn(ESTIMATED_WEIGHT); - when(weigher.weigh(any(), anyLong())).thenReturn(ACTUAL_WEIGHT); + when(weigher.weighSuccess(any(), anyLong())).thenReturn(ACTUAL_WEIGHT); + when(weigher.weighFailure(any(), anyLong())).thenReturn(ACTUAL_WEIGHT); + } @Test @@ -96,7 +99,7 @@ public void recordsReadMetrics() throws TestCheckedException { public void passesResultAndTimeToReadWeigher() throws TestCheckedException { qosClient.executeRead(() -> "foo", weigher); - verify(weigher).weigh("foo", TOTAL_NANOS); + verify(weigher).weighSuccess("foo", TOTAL_NANOS); } @Test @@ -116,6 +119,39 @@ public void recordsWriteMetrics() throws TestCheckedException { verifyNoMoreInteractions(metrics); } + @Test + public void recordsReadMetricsOnFailure() throws TestCheckedException { + TestCheckedException error = new TestCheckedException(); + assertThatThrownBy(() -> qosClient.executeRead(() -> { + throw error; + }, weigher)).isInstanceOf(TestCheckedException.class); + + verify(metrics).recordRead(ACTUAL_WEIGHT); + verifyNoMoreInteractions(metrics); + } + + @Test + public void recordsWriteMetricsOnFailure() throws TestCheckedException { + TestCheckedException error = new TestCheckedException(); + assertThatThrownBy(() -> qosClient.executeWrite(() -> { + throw error; + }, weigher)).isInstanceOf(TestCheckedException.class); + + verify(metrics).recordWrite(ACTUAL_WEIGHT); + verifyNoMoreInteractions(metrics); + } + + @Test + public void passesExceptionToWeigherOnFailure() throws TestCheckedException { + TestCheckedException error = new TestCheckedException(); + assertThatThrownBy(() -> qosClient.executeRead(() -> { + throw error; + }, weigher)).isInstanceOf(TestCheckedException.class); + + verify(weigher).weighFailure(error, TOTAL_NANOS); + verify(weigher, never()).weighSuccess(any(), anyLong()); + } + @Test public void propagatesCheckedExceptions() throws TestCheckedException { assertThatThrownBy(() -> qosClient.executeRead(() -> {