diff --git a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftObjectSizeUtils.java b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftObjectSizeUtils.java index c97473fccab..42afbd151a6 100644 --- a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftObjectSizeUtils.java +++ b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftObjectSizeUtils.java @@ -46,10 +46,10 @@ private ThriftObjectSizeUtils() { public static long getApproximateWriteByteCount(Map>> batchMutateMap) { long approxBytesForKeys = getCollectionSize(batchMutateMap.keySet(), ThriftObjectSizeUtils::getByteBufferSize); - long approxBytesForValues = getCollectionSize(batchMutateMap.values(), currentMap -> - getCollectionSize(currentMap.keySet(), ThriftObjectSizeUtils::getStringSize) + long approxBytesForValues = getCollectionSize(batchMutateMap.values(), + currentMap -> getCollectionSize(currentMap.keySet(), ThriftObjectSizeUtils::getStringSize) + getCollectionSize(currentMap.values(), - mutations -> getCollectionSize(mutations, ThriftObjectSizeUtils::getMutationSize))); + mutations -> getCollectionSize(mutations, ThriftObjectSizeUtils::getMutationSize))); return approxBytesForKeys + approxBytesForValues; } diff --git a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighers.java b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighers.java index 2f31cf247cf..c9a018d9e69 100644 --- a/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighers.java +++ b/atlasdb-cassandra/src/main/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighers.java @@ -49,22 +49,27 @@ public final class ThriftQueryWeighers { private ThriftQueryWeighers() { } public static final QosClient.QueryWeigher>> MULTIGET_SLICE = - readWeigher(ThriftObjectSizeUtils::getApproximateReadByteCount); + readWeigher(ThriftObjectSizeUtils::getApproximateReadByteCount, Map::size); public static final QosClient.QueryWeigher> GET_RANGE_SLICES = - readWeigher(ThriftObjectSizeUtils::getApproximateReadByteCount); + readWeigher(ThriftObjectSizeUtils::getApproximateReadByteCount, List::size); public static final QosClient.QueryWeigher GET = - readWeigher(ThriftObjectSizeUtils::getColumnOrSuperColumnSize); + readWeigher(ThriftObjectSizeUtils::getColumnOrSuperColumnSize, ignored -> 1); public static final QosClient.QueryWeigher EXECUTE_CQL3_QUERY = - readWeigher(ThriftObjectSizeUtils::getCqlResultSize); - - public static final QosClient.QueryWeigher batchMutate(Map>> mutationMap) { - return writeWeigher(() -> ThriftObjectSizeUtils.getApproximateWriteByteCount(mutationMap)); + // TODO(nziebart): we need to inspect the schema to see how many rows there are - a CQL row is NOT a + // partition. rows here will depend on the type of query executed in CqlExecutor: either (column, ts) pairs, + // or (key, column, ts) triplets + readWeigher(ThriftObjectSizeUtils::getCqlResultSize, ignored -> 1); + + public static QosClient.QueryWeigher batchMutate( + Map>> mutationMap) { + long numRows = mutationMap.size(); + return writeWeigher(numRows, () -> ThriftObjectSizeUtils.getApproximateWriteByteCount(mutationMap)); } - public static QosClient.QueryWeigher readWeigher(Function bytesRead) { + public static QosClient.QueryWeigher readWeigher(Function bytesRead, Function numRows) { return new QosClient.QueryWeigher() { @Override public QueryWeight estimate() { @@ -75,14 +80,14 @@ public QueryWeight estimate() { public QueryWeight weigh(T result, long timeTakenNanos) { return ImmutableQueryWeight.builder() .numBytes(safeGetNumBytesOrDefault(() -> bytesRead.apply(result))) - .timeTakenNanos((int)timeTakenNanos) - .numDistinctRows(1) + .timeTakenNanos(timeTakenNanos) + .numDistinctRows(numRows.apply(result)) .build(); } }; } - public static QosClient.QueryWeigher writeWeigher(Supplier bytesWritten) { + public static QosClient.QueryWeigher writeWeigher(long numRows, Supplier bytesWritten) { Supplier weight = Suppliers.memoize(() -> safeGetNumBytesOrDefault(bytesWritten))::get; return new QosClient.QueryWeigher() { @@ -91,6 +96,7 @@ public QueryWeight estimate() { return ImmutableQueryWeight.builder() .from(DEFAULT_ESTIMATED_WEIGHT) .numBytes(weight.get()) + .numDistinctRows(numRows) .build(); } @@ -98,7 +104,7 @@ public QueryWeight estimate() { public QueryWeight weigh(T result, long timeTakenNanos) { return ImmutableQueryWeight.builder() .from(estimate()) - .timeTakenNanos((int)timeTakenNanos) + .timeTakenNanos(timeTakenNanos) .build(); } }; diff --git a/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighersTest.java b/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighersTest.java new file mode 100644 index 00000000000..c19d01e508d --- /dev/null +++ b/atlasdb-cassandra/src/test/java/com/palantir/atlasdb/keyvalue/cassandra/qos/ThriftQueryWeighersTest.java @@ -0,0 +1,94 @@ +/* + * Copyright 2017 Palantir Technologies, Inc. All rights reserved. + * + * Licensed under the BSD-3 License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.palantir.atlasdb.keyvalue.cassandra.qos; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; + +import org.apache.cassandra.thrift.ColumnOrSuperColumn; +import org.apache.cassandra.thrift.CqlResult; +import org.apache.cassandra.thrift.KeySlice; +import org.apache.cassandra.thrift.Mutation; +import org.junit.Test; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class ThriftQueryWeighersTest { + + private static final ByteBuffer BYTES1 = ByteBuffer.allocate(3); + private static final ByteBuffer BYTES2 = ByteBuffer.allocate(7); + private static final ColumnOrSuperColumn COLUMN = new ColumnOrSuperColumn(); + private static final KeySlice KEY_SLICE = new KeySlice(); + private static final Mutation MUTATION = new Mutation(); + + private static final long UNIMPORTANT_ARG = 123L; + + @Test + public void multigetSliceWeigherReturnsCorrectNumRows() { + Map> result = ImmutableMap.of( + BYTES1, ImmutableList.of(COLUMN, COLUMN), + BYTES2, ImmutableList.of(COLUMN)); + + long actualNumRows = ThriftQueryWeighers.MULTIGET_SLICE.weigh(result, UNIMPORTANT_ARG).numDistinctRows(); + + assertThat(actualNumRows).isEqualTo(2); + } + + @Test + public void rangeSlicesWeigherReturnsCorrectNumRows() { + List result = ImmutableList.of(KEY_SLICE, KEY_SLICE, KEY_SLICE); + + long actualNumRows = ThriftQueryWeighers.GET_RANGE_SLICES.weigh(result, UNIMPORTANT_ARG).numDistinctRows(); + + assertThat(actualNumRows).isEqualTo(3); + } + + @Test + public void getWeigherReturnsCorrectNumRows() { + long actualNumRows = ThriftQueryWeighers.GET.weigh(COLUMN, UNIMPORTANT_ARG).numDistinctRows(); + + assertThat(actualNumRows).isEqualTo(1); + } + + @Test + public void executeCql3QueryWeigherReturnsOneRowAlways() { + long actualNumRows = ThriftQueryWeighers.EXECUTE_CQL3_QUERY.weigh(new CqlResult(), + UNIMPORTANT_ARG).numDistinctRows(); + + assertThat(actualNumRows).isEqualTo(1); + } + + @Test + public void batchMutateWeigherReturnsCorrectNumRows() { + Map>> mutations = ImmutableMap.of( + BYTES1, ImmutableMap.of( + "table1", ImmutableList.of(MUTATION, MUTATION), + "table2", ImmutableList.of(MUTATION)), + BYTES2, ImmutableMap.of( + "table1", ImmutableList.of(MUTATION))); + + long actualNumRows = ThriftQueryWeighers.batchMutate(mutations).weigh(null, UNIMPORTANT_ARG) + .numDistinctRows(); + + assertThat(actualNumRows).isEqualTo(2); + } + +}