Skip to content

Commit

Permalink
OPIK-415 Compute traces cost based on token usage
Browse files Browse the repository at this point in the history
  • Loading branch information
Borys Tkachenko committed Nov 22, 2024
1 parent 66f0985 commit dd0b37f
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import jakarta.validation.constraints.Pattern;
import lombok.Builder;

import java.math.BigDecimal;
import java.time.Instant;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -42,7 +43,9 @@ public record Trace(
@JsonView({Trace.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy,
@JsonView({Trace.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy,
@JsonView({
Trace.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) List<FeedbackScore> feedbackScores){
Trace.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) List<FeedbackScore> feedbackScores,
@JsonView({
Trace.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) BigDecimal totalEstimatedCost){

public record TracePage(
@JsonView(Trace.View.Public.class) int page,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ AND id in (
""";

// Version label for the cost estimation (presumably stored alongside computed costs — usage is outside this excerpt; confirm).
private static final String ESTIMATED_COST_VERSION = "1.0";
// Zero cost written with eight decimal places.
// NOTE(review): this is a diff view without +/- markers — the two declarations below appear to be the
// before (private) and after (public) versions of the same constant; the trace DAO in this commit
// statically imports SpanDAO.ZERO_COST, which matches the widened visibility.
private static final BigDecimal ZERO_COST = new BigDecimal("0.00000000");
public static final BigDecimal ZERO_COST = new BigDecimal("0.00000000");

private final @NonNull ConnectionFactory connectionFactory;
private final @NonNull FeedbackScoreDAO feedbackScoreDAO;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import reactor.core.publisher.Mono;
import reactor.core.publisher.SignalType;

import java.math.BigDecimal;
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
Expand All @@ -43,6 +44,7 @@
import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToFlux;
import static com.comet.opik.domain.AsyncContextUtils.bindWorkspaceIdToMono;
import static com.comet.opik.domain.FeedbackScoreDAO.EntityType;
import static com.comet.opik.domain.SpanDAO.ZERO_COST;
import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.Segment;
import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.endSegment;
import static com.comet.opik.infrastructure.instrumentation.InstrumentAsyncUtils.startSegment;
Expand Down Expand Up @@ -250,7 +252,8 @@ INSERT INTO traces (
private static final String SELECT_BY_ID = """
SELECT
t.*,
sumMap(s.usage) as usage
sumMap(s.usage) as usage,
sum(s.total_estimated_cost) as total_estimated_cost
FROM (
SELECT
*
Expand All @@ -263,7 +266,8 @@ INSERT INTO traces (
LEFT JOIN (
SELECT
trace_id,
usage
usage,
total_estimated_cost
FROM spans
WHERE workspace_id = :workspace_id
AND trace_id = :id
Expand All @@ -279,7 +283,8 @@ LEFT JOIN (
private static final String SELECT_BY_PROJECT_ID = """
SELECT
t.*,
sumMap(s.usage) as usage
sumMap(s.usage) as usage,
sum(s.total_estimated_cost) as total_estimated_cost
FROM (
SELECT
id,
Expand Down Expand Up @@ -324,7 +329,8 @@ AND id in (
LEFT JOIN (
SELECT
trace_id,
usage
usage,
total_estimated_cost
FROM spans
WHERE workspace_id = :workspace_id
AND project_id = :project_id
Expand Down Expand Up @@ -747,6 +753,9 @@ private Publisher<Trace> mapToDto(Result result) {
.filter(it -> !it.isEmpty())
.orElse(null))
.usage(row.get("usage", Map.class))
.totalEstimatedCost(row.get("total_estimated_cost", BigDecimal.class).equals(ZERO_COST)
? null
: row.get("total_estimated_cost", BigDecimal.class))
.createdAt(row.get("created_at", Instant.class))
.lastUpdatedAt(row.get("last_updated_at", Instant.class))
.createdBy(row.get("created_by", String.class))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ public static BigDecimal textGenerationCost(ModelPrice modelPrice, Map<String, I
}

/**
 * Returns a zero cost without reading either argument.
 *
 * @param modelPrice not used by this method
 * @param usage not used by this method
 * @return a {@link BigDecimal} equal to zero
 */
public static BigDecimal defaultCost(ModelPrice modelPrice, Map<String, Integer> usage) {
// NOTE(review): diff view without +/- markers — the first return below appears to be the removed line
// and the second its replacement; only one of them exists in either version of the file.
return new BigDecimal("0");
return BigDecimal.ZERO;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class TracesResourceTest {
// URL templates for the private traces and spans endpoints; the single %s placeholder is filled in
// by callers outside this excerpt (presumably with the server base URL — confirm).
public static final String URL_TEMPLATE = "%s/v1/private/traces";
private static final String URL_TEMPLATE_SPANS = "%s/v1/private/spans";
// Field names of Trace, presumably excluded from field-by-field comparison in assertions — confirm at
// the usage sites.
// NOTE(review): diff view without +/- markers — the two continuation lines below appear to be the
// before and after versions; the later one additionally lists "totalEstimatedCost".
private static final String[] IGNORED_FIELDS_TRACES = {"projectId", "projectName", "createdAt",
"lastUpdatedAt", "feedbackScores", "createdBy", "lastUpdatedBy"};
"lastUpdatedAt", "feedbackScores", "createdBy", "lastUpdatedBy", "totalEstimatedCost"};
// Span ignore-list is shared with SpansResourceTest.
private static final String[] IGNORED_FIELDS_SPANS = SpansResourceTest.IGNORED_FIELDS;
// Field names of feedback scores; presumably used the same way as the lists above — confirm.
private static final String[] IGNORED_FIELDS_SCORES = {"createdAt", "lastUpdatedAt", "createdBy", "lastUpdatedBy"};

Expand Down Expand Up @@ -3232,6 +3232,44 @@ void getTraceWithUsage() {
getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE);
}

/**
 * Creates a trace, attaches spans that all share the same usage and model, then fetches the trace
 * and asserts its {@code totalEstimatedCost}.
 *
 * @param spanExpectedCost expected cost of a single span, or {@code null} when no cost is expected
 * @param model model name set on every created span; may be {@code null}
 */
@ParameterizedTest
@MethodSource
void getTraceWithCost(BigDecimal spanExpectedCost, String model) {
// Expected trace cost is five times the per-span cost (null stays null).
// NOTE(review): the factor 5 presumably matches the number of spans manufacturePojoList produces — confirm.
BigDecimal traceExpectedCost = spanExpectedCost == null ? null : spanExpectedCost.multiply(BigDecimal.valueOf(5));
var projectName = RandomStringUtils.randomAlphanumeric(10);
var trace = factory.manufacturePojo(Trace.class)
.toBuilder()
// The generated id is cleared; the id returned by create(...) below is used from then on.
.id(null)
.projectName(projectName)
// Trace usage is written as the per-span usage values multiplied by 5.
.usage(Map.of("completion_tokens", 200 * 5L, "prompt_tokens", 300 * 5L, "total_tokens", 4 * 5L))
.feedbackScores(null)
.build();
var id = create(trace, API_KEY, TEST_WORKSPACE);

// Every generated span is pointed at the new trace and given identical usage and the model under test.
var spans = PodamFactoryUtils.manufacturePojoList(factory, Span.class).stream()
.map(spanInStream -> spanInStream.toBuilder()
.projectName(projectName)
.traceId(id)
.usage(Map.of("completion_tokens", 200, "prompt_tokens", 300, "total_tokens", 4))
.model(model)
.build())
.collect(Collectors.toList());

batchCreateSpansAndAssert(spans, API_KEY, TEST_WORKSPACE);

var projectId = getProjectId(projectName, TEST_WORKSPACE, API_KEY);
trace = trace.toBuilder().id(id).build();
// getAndAssert returns the fetched trace; the cost is asserted separately on that result.
Trace createdTrace = getAndAssert(trace, projectId, API_KEY, TEST_WORKSPACE);
assertThat(createdTrace.totalEstimatedCost()).isEqualTo(traceExpectedCost);
}

/**
 * Argument source for the parameterized test of the same name.
 * Each entry is (expected per-span cost, model name).
 *
 * @return three cases: a priced model with an expected cost, an unrecognized model name with no
 *         expected cost, and a null model with no expected cost
 */
static Stream<Arguments> getTraceWithCost() {
return Stream.of(
// Expected per-span cost for the usage the test assigns to each span (200 completion / 300 prompt tokens).
Arguments.of(new BigDecimal("0.00070000"), "gpt-3.5-turbo-1106"),
// No cost expected for these two cases.
Arguments.of(null, "unknown-model"),
Arguments.of(null, null));
}

@Test
void getTraceWithoutUsage() {
var apiKey = UUID.randomUUID().toString();
Expand Down

0 comments on commit dd0b37f

Please sign in to comment.