From 3df4eb45b3d19c530928380eedf00a77840bb009 Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Thu, 31 Oct 2024 17:44:52 +0100 Subject: [PATCH 1/9] [OPIK-309] Create prompt endpoint --- .../main/java/com/comet/opik/api/Prompt.java | 53 +++ .../error/EntityAlreadyExistsException.java | 10 +- .../resources/v1/priv/ProjectsResource.java | 2 +- .../api/resources/v1/priv/PromptResource.java | 68 ++++ .../com/comet/opik/domain/DatasetDAO.java | 3 +- .../com/comet/opik/domain/DatasetService.java | 3 +- .../opik/domain/EntityConstraintHandler.java | 33 ++ .../java/com/comet/opik/domain/PromptDAO.java | 25 ++ .../com/comet/opik/domain/PromptService.java | 68 ++++ .../v1/events/DatasetEventListenerTest.java | 1 - .../v1/priv/DatasetExperimentE2ETest.java | 2 +- .../v1/priv/DatasetsResourceTest.java | 3 +- .../resources/v1/priv/PromptResourceTest.java | 318 ++++++++++++++++++ 13 files changed, 582 insertions(+), 7 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java create mode 100644 apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java new file mode 100644 index 000000000..023ab2788 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java @@ -0,0 +1,53 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotBlank; +import jakarta.validation.constraints.Pattern; +import lombok.Builder; + +import java.time.Instant; +import java.util.List; +import java.util.UUID; + +import static com.comet.opik.utils.ValidationUtils.NULL_OR_NOT_BLANK; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record Prompt( + @JsonView( { + Prompt.View.Public.class, Prompt.View.Write.class}) UUID id, + @JsonView({Prompt.View.Public.class, Prompt.View.Write.class}) @NotBlank String name, + @JsonView({Prompt.View.Public.class, + Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description, + @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, + @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, + @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, + @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy){ + + public static class View { + public static class Write { + } + + public static class Public { + } + } + + public record PromptPage( + @JsonView( { + Project.View.Public.class}) int page, + @JsonView({Project.View.Public.class}) int size, + @JsonView({Project.View.Public.class}) long total, + @JsonView({Project.View.Public.class}) List content) + implements + Page{ + + public static Prompt.PromptPage empty(int page) { + return new Prompt.PromptPage(page, 0, 0, List.of()); + } + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/error/EntityAlreadyExistsException.java b/apps/opik-backend/src/main/java/com/comet/opik/api/error/EntityAlreadyExistsException.java index df74e5e10..b8a16bddd 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/error/EntityAlreadyExistsException.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/error/EntityAlreadyExistsException.java @@ -6,6 +6,14 @@ public class EntityAlreadyExistsException extends ClientErrorException { public EntityAlreadyExistsException(ErrorMessage response) { - super(Response.status(Response.Status.CONFLICT).entity(response).build()); + this((Object) response); + } + + public EntityAlreadyExistsException(io.dropwizard.jersey.errors.ErrorMessage response) { + this((Object) response); + } + + private EntityAlreadyExistsException(Object entity) { + super(Response.status(Response.Status.CONFLICT).entity(entity).build()); } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java index 186d52022..ac0a72229 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/ProjectsResource.java @@ -95,7 +95,7 @@ public Response getById(@PathParam("id") UUID id) { } @POST - @Operation(operationId = "createProject", summary = "Create project", description = "Get project", responses = { + @Operation(operationId = "createProject", summary = "Create project", description = "Create project", responses = { @ApiResponse(responseCode = "201", description = "Created", headers = { @Header(name = "Location", required = true, example = "${basePath}/v1/private/projects/{projectId}", schema = @Schema(implementation = String.class))}), @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java new file mode 100644 index 000000000..46c85886c --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java @@ -0,0 +1,68 @@ +package com.comet.opik.api.resources.v1.priv; + +import com.codahale.metrics.annotation.Timed; +import com.comet.opik.api.Prompt; +import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.domain.PromptService; +import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.infrastructure.ratelimit.RateLimited; +import com.fasterxml.jackson.annotation.JsonView; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.headers.Header; +import io.swagger.v3.oas.annotations.media.Content; +import io.swagger.v3.oas.annotations.media.Schema; +import io.swagger.v3.oas.annotations.parameters.RequestBody; +import io.swagger.v3.oas.annotations.responses.ApiResponse; +import jakarta.inject.Inject; +import jakarta.inject.Provider; +import jakarta.validation.Valid; +import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import jakarta.ws.rs.Produces; +import jakarta.ws.rs.core.Context; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.core.UriInfo; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; + +@Path("/v1/private/prompts") +@Produces(MediaType.APPLICATION_JSON) +@Consumes(MediaType.APPLICATION_JSON) +@Timed +@Slf4j +@RequiredArgsConstructor(onConstructor_ = @Inject) +public class PromptResource { + + private final @NonNull Provider requestContext; + private final @NonNull PromptService promptService; + + @POST + @Operation(operationId = "createPrompt", summary = "Create prompt", description = "Create prompt", responses = { + @ApiResponse(responseCode = "201", description = "Created", headers = { + @Header(name = "Location", required = true, example = "${basePath}/v1/private/prompts/{promptId}", schema = @Schema(implementation = String.class))}), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + + }) + @RateLimited + public Response createPrompt( + @RequestBody(content = @Content(schema = @Schema(implementation = Prompt.class))) @JsonView(Prompt.View.Write.class) @Valid Prompt prompt, + @Context UriInfo uriInfo) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Creating prompt with name '{}', on workspace_id '{}'", prompt.name(), workspaceId); + prompt = promptService.prompt(prompt); + log.info("Prompt created with id '{}' name '{}', on workspace_id '{}'", prompt.id(), prompt.name(), + workspaceId); + + var resourceUri = uriInfo.getAbsolutePathBuilder().path("/%s".formatted(prompt.id())).build(); + + return Response.created(resourceUri).build(); + } + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetDAO.java index c687dc72a..d96b85b49 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetDAO.java @@ -81,6 +81,7 @@ List find(@Bind("limit") int limit, Optional findByName(@Bind("workspace_id") String workspaceId, @Bind("name") String name); @SqlBatch("UPDATE datasets SET last_created_experiment_at = :experimentCreatedAt WHERE id = :datasetId AND workspace_id = :workspace_id") - int[] recordExperiments(@Bind("workspace_id") String workspaceId, @BindMethods Collection datasets); + int[] recordExperiments(@Bind("workspace_id") String workspaceId, + @BindMethods Collection datasets); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetService.java index 4002a9c59..48df7716f 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/DatasetService.java @@ -278,7 +278,8 @@ private List enrichDatasetWithAdditionalInformation(List datas return datasets.stream() .map(dataset -> { var resume = experimentSummary.computeIfAbsent(dataset.id(), ExperimentSummary::empty); - var datasetItemSummary = datasetItemSummaryMap.computeIfAbsent(dataset.id(), DatasetItemSummary::empty); + var datasetItemSummary = datasetItemSummaryMap.computeIfAbsent(dataset.id(), + DatasetItemSummary::empty); return dataset.toBuilder() .experimentCount(resume.experimentCount()) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java new file mode 100644 index 000000000..b6991c25d --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java @@ -0,0 +1,33 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.error.EntityAlreadyExistsException; +import org.jdbi.v3.core.statement.UnableToExecuteStatementException; + +import java.sql.SQLIntegrityConstraintViolationException; +import java.util.function.Supplier; + +interface EntityConstraintHandler { + + static EntityConstraintHandler handle(EntityConstraintAction entityAction) { + return () -> entityAction; + } + + interface EntityConstraintAction { + T execute(); + } + + EntityConstraintAction wrappedAction(); + + default T withError(Supplier errorProvider) { + try { + return wrappedAction().execute(); + } catch (UnableToExecuteStatementException e) { + if (e.getCause() instanceof SQLIntegrityConstraintViolationException) { + throw errorProvider.get(); + } else { + throw e; + } + } + } + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java new file mode 100644 index 000000000..6bfaee9d6 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java @@ -0,0 +1,25 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.Prompt; +import com.comet.opik.infrastructure.db.UUIDArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterConstructorMapper; +import org.jdbi.v3.sqlobject.customizer.Bind; +import org.jdbi.v3.sqlobject.customizer.BindMethods; +import org.jdbi.v3.sqlobject.statement.SqlQuery; +import org.jdbi.v3.sqlobject.statement.SqlUpdate; + +import java.util.UUID; + +@RegisterConstructorMapper(Prompt.class) +@RegisterArgumentFactory(UUIDArgumentFactory.class) +interface PromptDAO { + + @SqlUpdate("INSERT INTO prompts (id, name, description, created_by, last_updated_by, workspace_id) " + + "VALUES (:bean.id, :bean.name, :bean.description, :bean.createdBy, :bean.lastUpdatedBy, :workspaceId)") + void save(@Bind("workspaceId") String workspaceId, @BindMethods("bean") Prompt prompt); + + @SqlQuery("SELECT * FROM prompts WHERE id = :id AND workspace_id = :workspaceId") + Prompt findById(@Bind("id") UUID id, @Bind("workspaceId") String workspaceId); + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java new file mode 100644 index 000000000..4916baa41 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java @@ -0,0 +1,68 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.Prompt; +import com.comet.opik.api.error.EntityAlreadyExistsException; +import com.comet.opik.infrastructure.auth.RequestContext; +import com.google.inject.ImplementedBy; +import io.dropwizard.jersey.errors.ErrorMessage; +import jakarta.inject.Inject; +import jakarta.inject.Provider; +import jakarta.inject.Singleton; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import ru.vyarus.guicey.jdbi3.tx.TransactionTemplate; + +import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE; + +@ImplementedBy(PromptServiceImpl.class) +public interface PromptService { + Prompt prompt(Prompt prompt); + +} + +@Singleton +@Slf4j +@RequiredArgsConstructor(onConstructor_ = @Inject) +class PromptServiceImpl implements PromptService { + + public static final String ALREADY_EXISTS = "Prompt id or name already exists"; + private final @NonNull Provider requestContext; + private final @NonNull IdGenerator idGenerator; + private final @NonNull TransactionTemplate transactionTemplate; + + @Override + public Prompt prompt(Prompt prompt) { + + String workspaceId = requestContext.get().getWorkspaceId(); + String userName = requestContext.get().getUserName(); + + var newPrompt = prompt.toBuilder() + .id(prompt.id() == null ? idGenerator.generateId() : prompt.id()) + .createdBy(userName) + .lastUpdatedBy(userName) + .build(); + + IdGenerator.validateVersion(newPrompt.id(), "prompt"); + + return EntityConstraintHandler + .handle(() -> savePrompt(workspaceId, newPrompt)) + .withError(this::newConflict); + } + + private Prompt savePrompt(String workspaceId, Prompt newPrompt) { + return transactionTemplate.inTransaction(WRITE, handle -> { + PromptDAO promptDAO = handle.attach(PromptDAO.class); + + promptDAO.save(workspaceId, newPrompt); + + return promptDAO.findById(newPrompt.id(), workspaceId); + }); + } + + private EntityAlreadyExistsException newConflict() { + log.info(ALREADY_EXISTS); + return new EntityAlreadyExistsException(new ErrorMessage(ALREADY_EXISTS)); + } + +} diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/DatasetEventListenerTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/DatasetEventListenerTest.java index ca55c88d1..73b585a62 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/DatasetEventListenerTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/events/DatasetEventListenerTest.java @@ -51,7 +51,6 @@ class DatasetEventListenerTest { private static final String BASE_RESOURCE_URI = "%s/v1/private/datasets"; private static final String EXPERIMENT_RESOURCE_URI = "%s/v1/private/experiments"; - private static final String API_KEY = UUID.randomUUID().toString(); private static final String USER = UUID.randomUUID().toString(); private static final String WORKSPACE_ID = UUID.randomUUID().toString(); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetExperimentE2ETest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetExperimentE2ETest.java index 8f122ef48..37e907ff6 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetExperimentE2ETest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetExperimentE2ETest.java @@ -47,7 +47,7 @@ import static org.assertj.core.api.Assertions.within; @TestInstance(TestInstance.Lifecycle.PER_CLASS) -@DisplayName("Dataset Event Listener") +@DisplayName("Dataset Experiments E2E Test") class DatasetExperimentE2ETest { private static final String BASE_RESOURCE_URI = "%s/v1/private/datasets"; diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java index e5c2b9224..b3f0910bb 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/DatasetsResourceTest.java @@ -136,7 +136,8 @@ class DatasetsResourceTest { public static final String[] IGNORED_FIELDS_DATA_ITEM = {"createdAt", "lastUpdatedAt", "experimentItems", "createdBy", "lastUpdatedBy"}; public static final String[] DATASET_IGNORED_FIELDS = {"id", "createdAt", "lastUpdatedAt", "createdBy", - "lastUpdatedBy", "experimentCount", "mostRecentExperimentAt", "lastCreatedExperimentAt", "datasetItemsCount"}; + "lastUpdatedBy", "experimentCount", "mostRecentExperimentAt", "lastCreatedExperimentAt", + "datasetItemsCount"}; public static final String API_KEY = UUID.randomUUID().toString(); private static final String USER = UUID.randomUUID().toString(); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java new file mode 100644 index 000000000..a7f67a614 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java @@ -0,0 +1,318 @@ +package com.comet.opik.api.resources.v1.priv; + +import com.comet.opik.api.Project; +import com.comet.opik.api.Prompt; +import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.api.resources.utils.AuthTestUtils; +import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; +import com.comet.opik.api.resources.utils.ClientSupportUtils; +import com.comet.opik.api.resources.utils.MigrationUtils; +import com.comet.opik.api.resources.utils.MySQLContainerUtils; +import com.comet.opik.api.resources.utils.RedisContainerUtils; +import com.comet.opik.api.resources.utils.TestDropwizardAppExtensionUtils; +import com.comet.opik.api.resources.utils.TestUtils; +import com.comet.opik.api.resources.utils.WireMockUtils; +import com.comet.opik.infrastructure.DatabaseAnalyticsFactory; +import com.comet.opik.infrastructure.auth.RequestContext; +import com.comet.opik.podam.PodamFactoryUtils; +import com.github.tomakehurst.wiremock.client.WireMock; +import com.redis.testcontainers.RedisContainer; +import jakarta.ws.rs.client.Entity; +import jakarta.ws.rs.core.HttpHeaders; +import jakarta.ws.rs.core.MediaType; +import org.jdbi.v3.core.Jdbi; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.testcontainers.clickhouse.ClickHouseContainer; +import org.testcontainers.containers.MySQLContainer; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.lifecycle.Startables; +import ru.vyarus.dropwizard.guice.test.ClientSupport; +import ru.vyarus.dropwizard.guice.test.jupiter.ext.TestDropwizardAppExtension; +import uk.co.jemos.podam.api.PodamFactory; + +import java.sql.SQLException; +import java.util.List; +import java.util.UUID; +import java.util.stream.Stream; + +import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME; +import static com.comet.opik.api.resources.utils.MigrationUtils.CLICKHOUSE_CHANGELOG_FILE; +import static com.comet.opik.infrastructure.auth.RequestContext.SESSION_COOKIE; +import static com.comet.opik.infrastructure.auth.RequestContext.WORKSPACE_HEADER; +import static com.comet.opik.infrastructure.auth.TestHttpClientUtils.UNAUTHORIZED_RESPONSE; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.matching; +import static com.github.tomakehurst.wiremock.client.WireMock.matchingJsonPath; +import static com.github.tomakehurst.wiremock.client.WireMock.okJson; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.params.provider.Arguments.arguments; + +@Testcontainers(parallel = true) +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@DisplayName("Prompt Resource Test") +class PromptResourceTest { + + private static final String RESOURCE_PATH = "%s/v1/private/prompts"; + + private static final String API_KEY = UUID.randomUUID().toString(); + private static final String USER = UUID.randomUUID().toString(); + private static final String WORKSPACE_ID = UUID.randomUUID().toString(); + private static final String TEST_WORKSPACE = UUID.randomUUID().toString(); + + private static final RedisContainer REDIS = RedisContainerUtils.newRedisContainer(); + private static final ClickHouseContainer CLICKHOUSE_CONTAINER = ClickHouseContainerUtils.newClickHouseContainer(); + private static final MySQLContainer MYSQL = MySQLContainerUtils.newMySQLContainer(); + + @RegisterExtension + private static final TestDropwizardAppExtension app; + + private static final WireMockUtils.WireMockRuntime wireMock; + + static { + Startables.deepStart(REDIS, CLICKHOUSE_CONTAINER, MYSQL).join(); + wireMock = WireMockUtils.startWireMock(); + + DatabaseAnalyticsFactory databaseAnalyticsFactory = ClickHouseContainerUtils + .newDatabaseAnalyticsFactory(CLICKHOUSE_CONTAINER, DATABASE_NAME); + + app = TestDropwizardAppExtensionUtils.newTestDropwizardAppExtension( + MYSQL.getJdbcUrl(), databaseAnalyticsFactory, wireMock.runtimeInfo(), REDIS.getRedisURI()); + } + + private final PodamFactory factory = PodamFactoryUtils.newPodamFactory(); + + private String baseURI; + private ClientSupport client; + + @BeforeAll + void setUpAll(ClientSupport client, Jdbi jdbi) throws SQLException { + + MigrationUtils.runDbMigration(jdbi, MySQLContainerUtils.migrationParameters()); + + try (var connection = CLICKHOUSE_CONTAINER.createConnection("")) { + MigrationUtils.runDbMigration(connection, CLICKHOUSE_CHANGELOG_FILE, + ClickHouseContainerUtils.migrationParameters()); + } + + this.baseURI = "http://localhost:%d".formatted(client.getPort()); + this.client = client; + + ClientSupportUtils.config(client); + + mockTargetWorkspace(API_KEY, TEST_WORKSPACE, WORKSPACE_ID); + } + + private static void mockTargetWorkspace(String apiKey, String workspaceName, String workspaceId) { + AuthTestUtils.mockTargetWorkspace(wireMock.server(), apiKey, workspaceName, workspaceId, USER); + } + + @AfterAll + void tearDownAll() { + wireMock.server().stop(); + } + + @Nested + @DisplayName("Api Key Authentication:") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class ApiKey { + + private final String fakeApikey = UUID.randomUUID().toString(); + private final String okApikey = UUID.randomUUID().toString(); + + Stream credentials() { + return Stream.of( + arguments(okApikey, true), + arguments(fakeApikey, false), + arguments("", false)); + } + + @BeforeEach + void setUp() { + + wireMock.server().stubFor( + post(urlPathEqualTo("/opik/auth")) + .withHeader(HttpHeaders.AUTHORIZATION, equalTo(fakeApikey)) + .withRequestBody(matchingJsonPath("$.workspaceName", matching(".+"))) + .willReturn(WireMock.unauthorized())); + + wireMock.server().stubFor( + post(urlPathEqualTo("/opik/auth")) + .withHeader(HttpHeaders.AUTHORIZATION, equalTo("")) + .withRequestBody(matchingJsonPath("$.workspaceName", matching(".+"))) + .willReturn(WireMock.unauthorized())); + } + + @ParameterizedTest + @MethodSource("credentials") + @DisplayName("create prompt: when api key is present, then return proper response") + void createProject__whenApiKeyIsPresent__thenReturnProperResponse(String apiKey, boolean success) { + + var project = factory.manufacturePojo(Project.class); + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(okApikey, workspaceName, WORKSPACE_ID); + + try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.entity(project, MediaType.APPLICATION_JSON_TYPE))) { + + if (success) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); + assertThat(actualResponse.hasEntity()).isFalse(); + } else { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(401); + assertThat(actualResponse.hasEntity()).isTrue(); + assertThat(actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class)) + .isEqualTo(UNAUTHORIZED_RESPONSE); + } + } + } + + } + + @Nested + @DisplayName("Session Token Authentication:") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class SessionTokenCookie { + + private final String sessionToken = UUID.randomUUID().toString(); + private final String fakeSessionToken = UUID.randomUUID().toString(); + + Stream credentials() { + return Stream.of( + arguments(sessionToken, true, "OK_" + UUID.randomUUID()), + arguments(fakeSessionToken, false, UUID.randomUUID().toString())); + } + + @BeforeAll + void setUp() { + wireMock.server().stubFor( + post(urlPathEqualTo("/opik/auth-session")) + .withCookie(SESSION_COOKIE, equalTo(sessionToken)) + .withRequestBody(matchingJsonPath("$.workspaceName", matching("OK_.+"))) + .willReturn(okJson(AuthTestUtils.newWorkspaceAuthResponse(USER, WORKSPACE_ID)))); + + wireMock.server().stubFor( + post(urlPathEqualTo("/opik/auth-session")) + .withCookie(SESSION_COOKIE, equalTo(fakeSessionToken)) + .withRequestBody(matchingJsonPath("$.workspaceName", matching(".+"))) + .willReturn(WireMock.unauthorized())); + } + + @ParameterizedTest + @MethodSource("credentials") + @DisplayName("create prompt: when session token is present, then return proper response") + void createProject__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionToken, boolean success, + String workspaceName) { + var project = factory.manufacturePojo(Project.class); + + try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI)).request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .cookie(SESSION_COOKIE, sessionToken) + .header(WORKSPACE_HEADER, workspaceName) + .post(Entity.entity(project, MediaType.APPLICATION_JSON_TYPE))) { + + if (success) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(201); + assertThat(actualResponse.hasEntity()).isFalse(); + } else { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(401); + assertThat(actualResponse.hasEntity()).isTrue(); + assertThat(actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class)) + .isEqualTo(UNAUTHORIZED_RESPONSE); + } + } + } + } + + private UUID createPrompt(Prompt prompt, String apiKey, String workspaceName) { + try (var response = client.target(RESOURCE_PATH.formatted(baseURI)) + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(RequestContext.WORKSPACE_HEADER, workspaceName) + .post(Entity.json(prompt))) { + + assertThat(response.getStatus()).isEqualTo(201); + + return TestUtils.getIdFromLocation(response.getLocation()); + } + } + + @Nested + @DisplayName("Create Prompt") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class CreatePrompt { + + @Test + @DisplayName("Should create prompt") + void shouldCreatePrompt() { + + var prompt = factory.manufacturePojo(Prompt.class); + + var promptId = createPrompt(prompt, API_KEY, TEST_WORKSPACE); + + assertThat(promptId).isNotNull(); + } + + @ParameterizedTest + @MethodSource + @DisplayName("when prompt state is invalid, then return conflict") + void when__promptIsInvalid__thenReturnError(Prompt prompt, int expectedStatusCode, Object expectedBody, + Class expectedResponseClass) { + + try (var response = client.target(RESOURCE_PATH.formatted(baseURI)) + .request() + .header(HttpHeaders.AUTHORIZATION, API_KEY) + .header(RequestContext.WORKSPACE_HEADER, TEST_WORKSPACE) + .post(Entity.json(prompt))) { + + assertThat(response.getStatus()).isEqualTo(expectedStatusCode); + + var actualBody = response.readEntity(expectedResponseClass); + + assertThat(actualBody).isEqualTo(expectedBody); + } + } + + Stream when__promptIsInvalid__thenReturnError() { + Prompt prompt = factory.manufacturePojo(Prompt.class).toBuilder() + .id(UUID.randomUUID()) + .build(); + + Prompt duplicatedPrompt = factory.manufacturePojo(Prompt.class); + createPrompt(duplicatedPrompt, API_KEY, TEST_WORKSPACE); + + return Stream.of( + Arguments.of(prompt, 400, + new ErrorMessage(List.of("prompt id must be a version 7 UUID")), + ErrorMessage.class), + Arguments.of(duplicatedPrompt.toBuilder().name(UUID.randomUUID().toString()).build(), 409, + new io.dropwizard.jersey.errors.ErrorMessage("Prompt id or name already exists"), + io.dropwizard.jersey.errors.ErrorMessage.class), + Arguments.of(duplicatedPrompt.toBuilder().id(factory.manufacturePojo(UUID.class)).build(), 409, + new io.dropwizard.jersey.errors.ErrorMessage("Prompt id or name already exists"), + io.dropwizard.jersey.errors.ErrorMessage.class), + Arguments.of(factory.manufacturePojo(Prompt.class).toBuilder().description("").build(), 422, + new ErrorMessage(List.of("description must not be blank")), + ErrorMessage.class), + Arguments.of(factory.manufacturePojo(Prompt.class).toBuilder().name("").build(), 422, + new ErrorMessage(List.of("name must not be blank")), ErrorMessage.class)); + } + } + +} \ No newline at end of file From 18cc5096d7fd3d1496be904e11cc855a66f0db49 Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Thu, 31 Oct 2024 19:12:26 +0100 Subject: [PATCH 2/9] [OPIK-309] Expose API contracts --- .../comet/opik/api/CreatePromptVersion.java | 17 ++ .../main/java/com/comet/opik/api/Prompt.java | 36 ++- .../com/comet/opik/api/PromptVersion.java | 55 ++++ .../comet/opik/api/PromptVersionRetrieve.java | 13 + .../api/resources/v1/priv/PromptResource.java | 251 ++++++++++++++++++ 5 files changed, 360 insertions(+), 12 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java b/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java new file mode 100644 index 000000000..04064916c --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java @@ -0,0 +1,17 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import jakarta.validation.constraints.NotBlank; +import jakarta.validation.constraints.NotNull; +import lombok.Builder; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record CreatePromptVersion(@JsonView( { + PromptVersion.View.Detail.class}) @NotBlank String name, + @JsonView({PromptVersion.View.Detail.class}) @NotNull PromptVersion version){ +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java index 023ab2788..4d7de1df4 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java @@ -20,14 +20,23 @@ @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) public record Prompt( @JsonView( { - Prompt.View.Public.class, Prompt.View.Write.class}) UUID id, - @JsonView({Prompt.View.Public.class, Prompt.View.Write.class}) @NotBlank String name, + Prompt.View.Public.class, Prompt.View.Write.class, Prompt.View.Detail.class}) UUID id, + @JsonView({Prompt.View.Public.class, Prompt.View.Write.class, Prompt.View.Detail.class}) @NotBlank String name, @JsonView({Prompt.View.Public.class, - Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy){ + Prompt.View.Write.class, + Prompt.View.Detail.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) long versionCount, + @JsonView({ + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) PromptVersion latestVersion){ public static class View { public static class Write { @@ -35,14 +44,17 @@ public static class Write { public static class Public { } - } + public static class Detail { + } + } + @Builder public record PromptPage( @JsonView( { - Project.View.Public.class}) int page, - @JsonView({Project.View.Public.class}) int size, - @JsonView({Project.View.Public.class}) long total, - @JsonView({Project.View.Public.class}) List content) + Prompt.View.Public.class}) int page, + @JsonView({Prompt.View.Public.class}) int size, + @JsonView({Prompt.View.Public.class}) long total, + @JsonView({Prompt.View.Public.class}) List content) implements Page{ diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java new file mode 100644 index 000000000..e235f271f --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java @@ -0,0 +1,55 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotNull; +import lombok.Builder; + +import java.time.Instant; +import java.util.List; +import java.util.Set; +import java.util.UUID; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record PromptVersion( + @JsonView( { + PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(description = "version unique identifier, generated if absent") UUID id, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(description = "version short unique identifier, generated if absent") String commit, + @JsonView({PromptVersion.View.Detail.class}) @NotNull String template, + @JsonView({ + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Set variables, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy){ + + public static class View { + public static class Public { + } + + public static class Detail { + } + } + + @Builder + public record PromptVersionPage( + @JsonView( { + PromptVersion.View.Public.class}) int page, + @JsonView({PromptVersion.View.Public.class}) int size, + @JsonView({PromptVersion.View.Public.class}) long total, + @JsonView({PromptVersion.View.Public.class}) List content) + implements + Page{ + + public static PromptVersion.PromptVersionPage empty(int page) { + return new PromptVersion.PromptVersionPage(page, 0, 0, List.of()); + } + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java new file mode 100644 index 000000000..60ae9a086 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java @@ -0,0 +1,13 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import jakarta.validation.constraints.NotBlank; +import lombok.Builder; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record PromptVersionRetrieve(@NotBlank String name, String commit) { +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java index 46c85886c..daf20d0f8 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java @@ -1,8 +1,12 @@ package com.comet.opik.api.resources.v1.priv; import com.codahale.metrics.annotation.Timed; +import com.comet.opik.api.CreatePromptVersion; import com.comet.opik.api.Prompt; +import com.comet.opik.api.PromptVersion; +import com.comet.opik.api.PromptVersionRetrieve; import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.domain.IdGenerator; import com.comet.opik.domain.PromptService; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.ratelimit.RateLimited; @@ -13,13 +17,21 @@ import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.parameters.RequestBody; import io.swagger.v3.oas.annotations.responses.ApiResponse; +import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.inject.Inject; import jakarta.inject.Provider; import jakarta.validation.Valid; +import jakarta.validation.constraints.Min; import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.DELETE; +import jakarta.ws.rs.DefaultValue; +import jakarta.ws.rs.GET; import jakarta.ws.rs.POST; +import jakarta.ws.rs.PUT; import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; import jakarta.ws.rs.core.Context; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; @@ -28,16 +40,23 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import java.time.Instant; +import java.util.Set; +import java.util.UUID; +import java.util.stream.IntStream; + @Path("/v1/private/prompts") @Produces(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON) @Timed @Slf4j @RequiredArgsConstructor(onConstructor_ = @Inject) +@Tag(name = "Prompts", description = "Prompt resources") public class PromptResource { private final @NonNull Provider requestContext; private final @NonNull PromptService promptService; + private final @NonNull IdGenerator idGenerator; @POST @Operation(operationId = "createPrompt", summary = "Create prompt", description = "Create prompt", responses = { @@ -65,4 +84,236 @@ public Response createPrompt( return Response.created(resourceUri).build(); } + @GET + @Operation(operationId = "getPrompts", summary = "Get prompts", description = "Get prompts", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = Prompt.PromptPage.class))), + }) + @JsonView({Prompt.View.Public.class}) + public Response getPrompts( + @QueryParam("page") @Min(1) @DefaultValue("1") int page, + @QueryParam("size") @Min(1) @DefaultValue("10") int size, + @QueryParam("name") String name) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompts by name '{}' on workspace_id '{}'", name, workspaceId); + var promptPage = Prompt.PromptPage.builder() + .page(page) + .size(5) + .total(5) + .content(IntStream.range(0, 5).mapToObj(i -> generatePrompt()).toList()) + .build(); + log.info("Got prompts by name '{}', count '{}' on workspace_id '{}'", name, promptPage.size(), workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptPage).build(); + } + + private Prompt generatePrompt() { + return Prompt.builder() + .id(idGenerator.generateId()) + .name("Prompt 1") + .description("Description 1") + .createdAt(Instant.now()) + .createdBy("User 1") + .lastUpdatedAt(Instant.now()) + .lastUpdatedBy("User 1") + .latestVersion(generatePromptVersion()) + .build(); + } + + @GET + @Path("{id}") + @Operation(operationId = "getPromptById", summary = "Get prompt by id", description = "Get prompt by id", responses = { + @ApiResponse(responseCode = "200", description = "Prompt resource", content = @Content(schema = @Schema(implementation = Prompt.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @JsonView({Prompt.View.Detail.class}) + public Response getPromptById(@PathParam("id") UUID id) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + Prompt prompt = generatePrompt(); + + log.info("Got prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(prompt).build(); + } + + @PUT + @Path("{id}") + @Operation(operationId = "updatePrompt", summary = "Update prompt", description = "Update prompt", responses = { + @ApiResponse(responseCode = "204", description = "No content"), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + @ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @RateLimited + public Response updatePrompt( + @PathParam("id") UUID id, + @RequestBody(content = @Content(schema = @Schema(implementation = Prompt.class))) @JsonView(Prompt.View.Write.class) @Valid Prompt prompt) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Updating prompt with id '{}' on workspace_id '{}'", id, workspaceId); + + log.info("Updated prompt with id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).build(); + } + + @DELETE + @Path("{id}") + @Operation(operationId = "deletePrompt", summary = "Delete prompt", description = "Delete prompt", responses = { + @ApiResponse(responseCode = "204", description = "No content") + }) + public Response deletePrompt(@PathParam("id") UUID id) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Deleting prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + log.info("Deleted prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).build(); + } + + @POST + @Path("/versions") + @Operation(operationId = "createPromptVersion", summary = "Create prompt version", description = "Create prompt version", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.class))), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))) + }) + @RateLimited + @JsonView({PromptVersion.View.Detail.class}) + public Response createPromptVersion( + @RequestBody(content = @Content(schema = @Schema(implementation = CreatePromptVersion.class))) @JsonView({ + PromptVersion.View.Detail.class}) @Valid CreatePromptVersion promptVersion) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Creating prompt version commit '{}' on workspace_id '{}'", promptVersion.version().commit(), + workspaceId); + + UUID id = idGenerator.generateId(); + log.info("Created prompt version commit '{}' with id '{}' on workspace_id '{}'", + promptVersion.version().commit(), id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED) + .entity(generatePromptVersion(promptVersion, id)) + .build(); + } + + private PromptVersion generatePromptVersion(CreatePromptVersion promptVersion, UUID id) { + return PromptVersion.builder() + .id(id) + .commit(promptVersion.version().commit() == null + ? id.toString().substring(id.toString().length() - 7) + : promptVersion.version().commit()) + .template(promptVersion.version().template()) + .variables( + Set.of("user_message")) + .createdAt(Instant.now()) + .createdBy("User 1") + .build(); + } + + @GET + @Path("/{id}/versions") + @Operation(operationId = "getPromptVersions", summary = "Get prompt versions", description = "Get prompt versions", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.PromptVersionPage.class))), + }) + @JsonView({PromptVersion.View.Public.class}) + public Response getPromptVersions(@PathParam("id") UUID id, + @QueryParam("page") @Min(1) @DefaultValue("1") int page, + @QueryParam("size") @Min(1) @DefaultValue("10") int size) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompt versions by id '{}' on workspace_id '{}'", id, workspaceId); + + PromptVersion.PromptVersionPage promptVersionPage = PromptVersion.PromptVersionPage.builder() + .page(1) + .size(5) + .total(5) + .content(IntStream.range(0, 5).mapToObj(i -> generatePromptVersion()).toList()) + .build(); + + log.info("Got prompt versions by id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptVersionPage).build(); + } + + @GET + @Path("/{id}/versions/{versionId}") + @Operation(operationId = "getPromptVersionById", summary = "Get prompt version by id", description = "Get prompt version by id", responses = { + @ApiResponse(responseCode = "200", description = "Prompt version resource", content = @Content(schema = @Schema(implementation = PromptVersion.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @JsonView({PromptVersion.View.Detail.class}) + public Response getPromptVersionById(@PathParam("id") UUID id, @PathParam("versionId") UUID versionId) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompt id '{}' and version by id '{}' on workspace_id '{}'", id, versionId, workspaceId); + + PromptVersion promptVersion = generatePromptVersion().toBuilder() + .id(versionId) + .commit(versionId.toString().substring(versionId.toString().length() - 7)) + .build(); + + log.info("Got prompt id '{}' and version by id '{}' on workspace_id '{}'", id, versionId, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptVersion).build(); + } + + @POST + @Path("/prompts/versions/retrieve") + @Operation(operationId = "retrievePromptVersion", summary = "Retrieve prompt version", description = "Retrieve prompt version", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.class))), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @JsonView({PromptVersion.View.Detail.class}) + public Response retrievePromptVersion( + @RequestBody(content = @Content(schema = @Schema(implementation = PromptVersionRetrieve.class))) @Valid PromptVersionRetrieve retrieve) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Retrieving prompt name '{}' with commit '{}' on workspace_id '{}'", retrieve.name(), + retrieve.commit(), workspaceId); + + UUID id = idGenerator.generateId(); + + log.info("Retrieved prompt name '{}' with commit '{}' on workspace_id '{}'", retrieve.name(), + retrieve.commit(), workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED) + .entity(generatePromptVersion().toBuilder() + .id(id) + .commit(retrieve.commit() == null + ? id.toString().substring(id.toString().length() - 7) + : retrieve.commit()) + .build()) + .build(); + } + + private PromptVersion generatePromptVersion() { + var id = idGenerator.generateId(); + return PromptVersion.builder() + .id(id) + .commit(id.toString().substring(id.toString().length() - 7)) + .template("Hello %s, My question is ${user_message}".formatted(id)) + .variables( + Set.of("user_message")) + .createdAt(Instant.now()) + .createdBy("User 1") + .build(); + } + } From 3a14f3c441cc5b496ce295c2436aaf5b6c9a606c Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Thu, 31 Oct 2024 19:12:26 +0100 Subject: [PATCH 3/9] [OPIK-309] Expose API contracts --- .../comet/opik/api/CreatePromptVersion.java | 17 ++ .../main/java/com/comet/opik/api/Prompt.java | 37 ++- .../com/comet/opik/api/PromptVersion.java | 55 ++++ .../comet/opik/api/PromptVersionRetrieve.java | 13 + .../api/resources/v1/priv/PromptResource.java | 251 ++++++++++++++++++ 5 files changed, 361 insertions(+), 12 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java b/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java new file mode 100644 index 000000000..04064916c --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java @@ -0,0 +1,17 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import jakarta.validation.constraints.NotBlank; +import jakarta.validation.constraints.NotNull; +import lombok.Builder; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record CreatePromptVersion(@JsonView( { + PromptVersion.View.Detail.class}) @NotBlank String name, + @JsonView({PromptVersion.View.Detail.class}) @NotNull PromptVersion version){ +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java index 023ab2788..151b9688b 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java @@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.PropertyNamingStrategies; import com.fasterxml.jackson.databind.annotation.JsonNaming; import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.annotation.Nullable; import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.Pattern; import lombok.Builder; @@ -20,14 +21,23 @@ @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) public record Prompt( @JsonView( { - Prompt.View.Public.class, Prompt.View.Write.class}) UUID id, - @JsonView({Prompt.View.Public.class, Prompt.View.Write.class}) @NotBlank String name, + Prompt.View.Public.class, Prompt.View.Write.class, Prompt.View.Detail.class}) UUID id, + @JsonView({Prompt.View.Public.class, Prompt.View.Write.class, Prompt.View.Detail.class}) @NotBlank String name, @JsonView({Prompt.View.Public.class, - Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, - @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy){ + Prompt.View.Write.class, + Prompt.View.Detail.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy, + @JsonView({Prompt.View.Public.class, + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Long versionCount, + @JsonView({ + Prompt.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable PromptVersion latestVersion){ public static class View { public static class Write { @@ -35,14 +45,17 @@ public static class Write { public static class Public { } - } + public static class Detail { + } + } + @Builder public record PromptPage( @JsonView( { - Project.View.Public.class}) int page, - @JsonView({Project.View.Public.class}) int size, - @JsonView({Project.View.Public.class}) long total, - @JsonView({Project.View.Public.class}) List content) + Prompt.View.Public.class}) int page, + @JsonView({Prompt.View.Public.class}) int size, + @JsonView({Prompt.View.Public.class}) long total, + @JsonView({Prompt.View.Public.class}) List content) implements Page{ diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java new file mode 100644 index 000000000..e235f271f --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java @@ -0,0 +1,55 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.validation.constraints.NotNull; +import lombok.Builder; + +import java.time.Instant; +import java.util.List; +import java.util.Set; +import java.util.UUID; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record PromptVersion( + @JsonView( { + PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(description = "version unique identifier, generated if absent") UUID id, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(description = "version short unique identifier, generated if absent") String commit, + @JsonView({PromptVersion.View.Detail.class}) @NotNull String template, + @JsonView({ + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Set variables, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy){ + + public static class View { + public static class Public { + } + + public static class Detail { + } + } + + @Builder + public record PromptVersionPage( + @JsonView( { + PromptVersion.View.Public.class}) int page, + @JsonView({PromptVersion.View.Public.class}) int size, + @JsonView({PromptVersion.View.Public.class}) long total, + @JsonView({PromptVersion.View.Public.class}) List content) + implements + Page{ + + public static PromptVersion.PromptVersionPage empty(int page) { + return new PromptVersion.PromptVersionPage(page, 0, 0, List.of()); + } + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java new file mode 100644 index 000000000..60ae9a086 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersionRetrieve.java @@ -0,0 +1,13 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import jakarta.validation.constraints.NotBlank; +import lombok.Builder; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record PromptVersionRetrieve(@NotBlank String name, String commit) { +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java index 46c85886c..daf20d0f8 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java @@ -1,8 +1,12 @@ package com.comet.opik.api.resources.v1.priv; import com.codahale.metrics.annotation.Timed; +import com.comet.opik.api.CreatePromptVersion; import com.comet.opik.api.Prompt; +import com.comet.opik.api.PromptVersion; +import com.comet.opik.api.PromptVersionRetrieve; import com.comet.opik.api.error.ErrorMessage; +import com.comet.opik.domain.IdGenerator; import com.comet.opik.domain.PromptService; import com.comet.opik.infrastructure.auth.RequestContext; import com.comet.opik.infrastructure.ratelimit.RateLimited; @@ -13,13 +17,21 @@ import io.swagger.v3.oas.annotations.media.Schema; import io.swagger.v3.oas.annotations.parameters.RequestBody; import io.swagger.v3.oas.annotations.responses.ApiResponse; +import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.inject.Inject; import jakarta.inject.Provider; import jakarta.validation.Valid; +import jakarta.validation.constraints.Min; import jakarta.ws.rs.Consumes; +import jakarta.ws.rs.DELETE; +import jakarta.ws.rs.DefaultValue; +import jakarta.ws.rs.GET; import jakarta.ws.rs.POST; +import jakarta.ws.rs.PUT; import jakarta.ws.rs.Path; +import jakarta.ws.rs.PathParam; import jakarta.ws.rs.Produces; +import jakarta.ws.rs.QueryParam; import jakarta.ws.rs.core.Context; import jakarta.ws.rs.core.MediaType; import jakarta.ws.rs.core.Response; @@ -28,16 +40,23 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import java.time.Instant; +import java.util.Set; +import java.util.UUID; +import java.util.stream.IntStream; + @Path("/v1/private/prompts") @Produces(MediaType.APPLICATION_JSON) @Consumes(MediaType.APPLICATION_JSON) @Timed @Slf4j @RequiredArgsConstructor(onConstructor_ = @Inject) +@Tag(name = "Prompts", description = "Prompt resources") public class PromptResource { private final @NonNull Provider requestContext; private final @NonNull PromptService promptService; + private final @NonNull IdGenerator idGenerator; @POST @Operation(operationId = "createPrompt", summary = "Create prompt", description = "Create prompt", responses = { @@ -65,4 +84,236 @@ public Response createPrompt( return Response.created(resourceUri).build(); } + @GET + @Operation(operationId = "getPrompts", summary = "Get prompts", description = "Get prompts", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = Prompt.PromptPage.class))), + }) + @JsonView({Prompt.View.Public.class}) + public Response getPrompts( + @QueryParam("page") @Min(1) @DefaultValue("1") int page, + @QueryParam("size") @Min(1) @DefaultValue("10") int size, + @QueryParam("name") String name) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompts by name '{}' on workspace_id '{}'", name, workspaceId); + var promptPage = Prompt.PromptPage.builder() + .page(page) + .size(5) + .total(5) + .content(IntStream.range(0, 5).mapToObj(i -> generatePrompt()).toList()) + .build(); + log.info("Got prompts by name '{}', count '{}' on workspace_id '{}'", name, promptPage.size(), workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptPage).build(); + } + + private Prompt generatePrompt() { + return Prompt.builder() + .id(idGenerator.generateId()) + .name("Prompt 1") + .description("Description 1") + .createdAt(Instant.now()) + .createdBy("User 1") + .lastUpdatedAt(Instant.now()) + .lastUpdatedBy("User 1") + .latestVersion(generatePromptVersion()) + .build(); + } + + @GET + @Path("{id}") + @Operation(operationId = "getPromptById", summary = "Get prompt by id", description = "Get prompt by id", responses = { + @ApiResponse(responseCode = "200", description = "Prompt resource", content = @Content(schema = @Schema(implementation = Prompt.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @JsonView({Prompt.View.Detail.class}) + public Response getPromptById(@PathParam("id") UUID id) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + Prompt prompt = generatePrompt(); + + log.info("Got prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(prompt).build(); + } + + @PUT + @Path("{id}") + @Operation(operationId = "updatePrompt", summary = "Update prompt", description = "Update prompt", responses = { + @ApiResponse(responseCode = "204", description = "No content"), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + @ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @RateLimited + public Response updatePrompt( + @PathParam("id") UUID id, + @RequestBody(content = @Content(schema = @Schema(implementation = Prompt.class))) @JsonView(Prompt.View.Write.class) @Valid Prompt prompt) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Updating prompt with id '{}' on workspace_id '{}'", id, workspaceId); + + log.info("Updated prompt with id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).build(); + } + + @DELETE + @Path("{id}") + @Operation(operationId = "deletePrompt", summary = "Delete prompt", description = "Delete prompt", responses = { + @ApiResponse(responseCode = "204", description = "No content") + }) + public Response deletePrompt(@PathParam("id") UUID id) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Deleting prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + log.info("Deleted prompt by id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).build(); + } + + @POST + @Path("/versions") + @Operation(operationId = "createPromptVersion", summary = "Create prompt version", description = "Create prompt version", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.class))), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "409", description = "Conflict", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))) + }) + @RateLimited + @JsonView({PromptVersion.View.Detail.class}) + public Response createPromptVersion( + @RequestBody(content = @Content(schema = @Schema(implementation = CreatePromptVersion.class))) @JsonView({ + PromptVersion.View.Detail.class}) @Valid CreatePromptVersion promptVersion) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Creating prompt version commit '{}' on workspace_id '{}'", promptVersion.version().commit(), + workspaceId); + + UUID id = idGenerator.generateId(); + log.info("Created prompt version commit '{}' with id '{}' on workspace_id '{}'", + promptVersion.version().commit(), id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED) + .entity(generatePromptVersion(promptVersion, id)) + .build(); + } + + private PromptVersion generatePromptVersion(CreatePromptVersion promptVersion, UUID id) { + return PromptVersion.builder() + .id(id) + .commit(promptVersion.version().commit() == null + ? id.toString().substring(id.toString().length() - 7) + : promptVersion.version().commit()) + .template(promptVersion.version().template()) + .variables( + Set.of("user_message")) + .createdAt(Instant.now()) + .createdBy("User 1") + .build(); + } + + @GET + @Path("/{id}/versions") + @Operation(operationId = "getPromptVersions", summary = "Get prompt versions", description = "Get prompt versions", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.PromptVersionPage.class))), + }) + @JsonView({PromptVersion.View.Public.class}) + public Response getPromptVersions(@PathParam("id") UUID id, + @QueryParam("page") @Min(1) @DefaultValue("1") int page, + @QueryParam("size") @Min(1) @DefaultValue("10") int size) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompt versions by id '{}' on workspace_id '{}'", id, workspaceId); + + PromptVersion.PromptVersionPage promptVersionPage = PromptVersion.PromptVersionPage.builder() + .page(1) + .size(5) + .total(5) + .content(IntStream.range(0, 5).mapToObj(i -> generatePromptVersion()).toList()) + .build(); + + log.info("Got prompt versions by id '{}' on workspace_id '{}'", id, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptVersionPage).build(); + } + + @GET + @Path("/{id}/versions/{versionId}") + @Operation(operationId = "getPromptVersionById", summary = "Get prompt version by id", description = "Get prompt version by id", responses = { + @ApiResponse(responseCode = "200", description = "Prompt version resource", content = @Content(schema = @Schema(implementation = PromptVersion.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @JsonView({PromptVersion.View.Detail.class}) + public Response getPromptVersionById(@PathParam("id") UUID id, @PathParam("versionId") UUID versionId) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Getting prompt id '{}' and version by id '{}' on workspace_id '{}'", id, versionId, workspaceId); + + PromptVersion promptVersion = generatePromptVersion().toBuilder() + .id(versionId) + .commit(versionId.toString().substring(versionId.toString().length() - 7)) + .build(); + + log.info("Got prompt id '{}' and version by id '{}' on workspace_id '{}'", id, versionId, workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptVersion).build(); + } + + @POST + @Path("/prompts/versions/retrieve") + @Operation(operationId = "retrievePromptVersion", summary = "Retrieve prompt version", description = "Retrieve prompt version", responses = { + @ApiResponse(responseCode = "200", description = "OK", content = @Content(schema = @Schema(implementation = PromptVersion.class))), + @ApiResponse(responseCode = "422", description = "Unprocessable Content", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "400", description = "Bad Request", content = @Content(schema = @Schema(implementation = ErrorMessage.class))), + @ApiResponse(responseCode = "404", description = "Not Found", content = @Content(schema = @Schema(implementation = io.dropwizard.jersey.errors.ErrorMessage.class))), + }) + @JsonView({PromptVersion.View.Detail.class}) + public Response retrievePromptVersion( + @RequestBody(content = @Content(schema = @Schema(implementation = PromptVersionRetrieve.class))) @Valid PromptVersionRetrieve retrieve) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + log.info("Retrieving prompt name '{}' with commit '{}' on workspace_id '{}'", retrieve.name(), + retrieve.commit(), workspaceId); + + UUID id = idGenerator.generateId(); + + log.info("Retrieved prompt name '{}' with commit '{}' on workspace_id '{}'", retrieve.name(), + retrieve.commit(), workspaceId); + + return Response.status(Response.Status.NOT_IMPLEMENTED) + .entity(generatePromptVersion().toBuilder() + .id(id) + .commit(retrieve.commit() == null + ? id.toString().substring(id.toString().length() - 7) + : retrieve.commit()) + .build()) + .build(); + } + + private PromptVersion generatePromptVersion() { + var id = idGenerator.generateId(); + return PromptVersion.builder() + .id(id) + .commit(id.toString().substring(id.toString().length() - 7)) + .template("Hello %s, My question is ${user_message}".formatted(id)) + .variables( + Set.of("user_message")) + .createdAt(Instant.now()) + .createdBy("User 1") + .build(); + } + } From ace7442218c9dad968b9fdd4b45435c420b61346 Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Mon, 4 Nov 2024 12:19:54 +0100 Subject: [PATCH 4/9] [OPIK-310] Expose get prompts api --- .../api/resources/v1/priv/PromptResource.java | 11 +- .../java/com/comet/opik/domain/PromptDAO.java | 20 ++ .../com/comet/opik/domain/PromptService.java | 27 ++ .../resources/v1/priv/PromptResourceTest.java | 247 +++++++++++++++++- 4 files changed, 297 insertions(+), 8 deletions(-) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java index daf20d0f8..c1e7ae250 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java @@ -97,15 +97,12 @@ public Response getPrompts( String workspaceId = requestContext.get().getWorkspaceId(); log.info("Getting prompts by name '{}' on workspace_id '{}'", name, workspaceId); - var promptPage = Prompt.PromptPage.builder() - .page(page) - .size(5) - .total(5) - .content(IntStream.range(0, 5).mapToObj(i -> generatePrompt()).toList()) - .build(); + + Prompt.PromptPage promptPage = promptService.find(name, page, size); + log.info("Got prompts by name '{}', count '{}' on workspace_id '{}'", name, promptPage.size(), workspaceId); - return Response.status(Response.Status.NOT_IMPLEMENTED).entity(promptPage).build(); + return Response.ok(promptPage).build(); } private Prompt generatePrompt() { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java index 6bfaee9d6..8214e54e7 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java @@ -4,11 +4,15 @@ import com.comet.opik.infrastructure.db.UUIDArgumentFactory; import org.jdbi.v3.sqlobject.config.RegisterArgumentFactory; import org.jdbi.v3.sqlobject.config.RegisterConstructorMapper; +import org.jdbi.v3.sqlobject.customizer.AllowUnusedBindings; import org.jdbi.v3.sqlobject.customizer.Bind; import org.jdbi.v3.sqlobject.customizer.BindMethods; +import org.jdbi.v3.sqlobject.customizer.Define; import org.jdbi.v3.sqlobject.statement.SqlQuery; import org.jdbi.v3.sqlobject.statement.SqlUpdate; +import org.jdbi.v3.stringtemplate4.UseStringTemplateEngine; +import java.util.List; import java.util.UUID; @RegisterConstructorMapper(Prompt.class) @@ -22,4 +26,20 @@ interface PromptDAO { @SqlQuery("SELECT * FROM prompts WHERE id = :id AND workspace_id = :workspaceId") Prompt findById(@Bind("id") UUID id, @Bind("workspaceId") String workspaceId); + @SqlQuery("SELECT * FROM prompts " + + " WHERE workspace_id = :workspace_Id " + + " AND name like concat('%', :name, '%') " + + " ORDER BY id DESC " + + " LIMIT :limit OFFSET :offset ") + @UseStringTemplateEngine + @AllowUnusedBindings + List find(@Define("name") @Bind("name") String name, @Bind("workspace_Id") String workspaceId, + @Bind("offset") int offset, @Bind("limit") int limit); + + @SqlQuery("SELECT COUNT(id) FROM prompts " + + " WHERE workspace_id = :workspace_Id " + + " AND name like concat('%', :name, '%') ") + @UseStringTemplateEngine + @AllowUnusedBindings + long count(@Define("name") @Bind("name") String name, @Bind("workspace_Id") String workspaceId); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java index 4916baa41..1f49b3e3f 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java @@ -13,12 +13,16 @@ import lombok.extern.slf4j.Slf4j; import ru.vyarus.guicey.jdbi3.tx.TransactionTemplate; +import java.util.List; + +import static com.comet.opik.api.Prompt.PromptPage; import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE; @ImplementedBy(PromptServiceImpl.class) public interface PromptService { Prompt prompt(Prompt prompt); + PromptPage find(String name, int page, int size); } @Singleton @@ -60,6 +64,29 @@ private Prompt savePrompt(String workspaceId, Prompt newPrompt) { }); } + @Override + public PromptPage find(String name, int page, int size) { + + String workspaceId = requestContext.get().getWorkspaceId(); + + return transactionTemplate.inTransaction(handle -> { + PromptDAO promptDAO = handle.attach(PromptDAO.class); + + long total = promptDAO.count(name, workspaceId); + + var offset = (page - 1) * size; + + List content = promptDAO.find(name, workspaceId, offset, size); + + return PromptPage.builder() + .page(page) + .size(content.size()) + .content(content) + .total(total) + .build(); + }); + } + private EntityAlreadyExistsException newConflict() { log.info(ALREADY_EXISTS); return new EntityAlreadyExistsException(new ErrorMessage(ALREADY_EXISTS)); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java index a7f67a614..4e82adc80 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java @@ -18,8 +18,11 @@ import com.github.tomakehurst.wiremock.client.WireMock; import com.redis.testcontainers.RedisContainer; import jakarta.ws.rs.client.Entity; +import jakarta.ws.rs.client.WebTarget; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.recursive.comparison.RecursiveComparisonConfiguration; import org.jdbi.v3.core.Jdbi; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; @@ -41,8 +44,10 @@ import uk.co.jemos.podam.api.PodamFactory; import java.sql.SQLException; +import java.time.Instant; import java.util.List; import java.util.UUID; +import java.util.stream.IntStream; import java.util.stream.Stream; import static com.comet.opik.api.resources.utils.ClickHouseContainerUtils.DATABASE_NAME; @@ -183,6 +188,34 @@ void createProject__whenApiKeyIsPresent__thenReturnProperResponse(String apiKey, } } + @ParameterizedTest + @MethodSource("credentials") + @DisplayName("find prompt: when api key is present, then return proper response") + void findPrompt__whenApiKeyIsPresent__thenReturnProperResponse(String apiKey, boolean success) { + + String workspaceName = UUID.randomUUID().toString(); + + mockTargetWorkspace(okApikey, workspaceName, WORKSPACE_ID); + + try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI)) + .request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(WORKSPACE_HEADER, workspaceName) + .get()) { + + if (success) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); + assertThat(actualResponse.hasEntity()).isTrue(); + } else { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(401); + assertThat(actualResponse.hasEntity()).isTrue(); + assertThat(actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class)) + .isEqualTo(UNAUTHORIZED_RESPONSE); + } + } + } + } @Nested @@ -238,6 +271,30 @@ void createProject__whenSessionTokenIsPresent__thenReturnProperResponse(String s } } } + + @ParameterizedTest + @MethodSource("credentials") + @DisplayName("find prompt: when session token is present, then return proper response") + void findPrompt__whenSessionTokenIsPresent__thenReturnProperResponse(String sessionToken, boolean success, + String workspaceName) { + + try (var actualResponse = client.target(RESOURCE_PATH.formatted(baseURI)).request() + .accept(MediaType.APPLICATION_JSON_TYPE) + .cookie(SESSION_COOKIE, sessionToken) + .header(WORKSPACE_HEADER, workspaceName) + .get()) { + + if (success) { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(200); + assertThat(actualResponse.hasEntity()).isTrue(); + } else { + assertThat(actualResponse.getStatusInfo().getStatusCode()).isEqualTo(401); + assertThat(actualResponse.hasEntity()).isTrue(); + assertThat(actualResponse.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class)) + .isEqualTo(UNAUTHORIZED_RESPONSE); + } + } + } } private UUID createPrompt(Prompt prompt, String apiKey, String workspaceName) { @@ -259,7 +316,7 @@ private UUID createPrompt(Prompt prompt, String apiKey, String workspaceName) { class CreatePrompt { @Test - @DisplayName("Should create prompt") + @DisplayName("Success: should create prompt") void shouldCreatePrompt() { var prompt = factory.manufacturePojo(Prompt.class); @@ -315,4 +372,192 @@ Stream when__promptIsInvalid__thenReturnError() { } } + @Nested + @DisplayName("Find Prompt") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class FindPrompt { + + @Test + @DisplayName("Success: should find prompt") + void shouldFindPrompt() { + + String apiKey = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var prompt = factory.manufacturePojo(Prompt.class).toBuilder() + .lastUpdatedBy(USER) + .createdBy(USER) + .build(); + + createPrompt(prompt, apiKey, workspaceName); + + List expectedPrompts = List.of(prompt); + + findPromptsAndAssertPage(expectedPrompts, apiKey, workspaceName, expectedPrompts.size(), 1, null); + } + + @Test + @DisplayName("when search by name, then return prompt matching name") + void when__searchByName__thenReturnPromptMatchingName() { + + String apiKey = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var prompt = factory.manufacturePojo(Prompt.class).toBuilder() + .lastUpdatedBy(USER) + .createdBy(USER) + .build(); + + createPrompt(prompt, apiKey, workspaceName); + + List expectedPrompts = List.of(prompt); + + findPromptsAndAssertPage(expectedPrompts, apiKey, workspaceName, expectedPrompts.size(), 1, prompt.name()); + } + + @ParameterizedTest + @MethodSource + @DisplayName("when search by partial name, then return prompt matching name") + void when__searchByPartialName__thenReturnPromptMatchingName(String promptName, String partialSearch) { + + String apiKey = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + IntStream.range(0, 4).forEach(i -> { + var prompt = factory.manufacturePojo(Prompt.class).toBuilder() + .lastUpdatedBy(USER) + .createdBy(USER) + .build(); + + createPrompt(prompt, apiKey, workspaceName); + }); + + var prompt = factory.manufacturePojo(Prompt.class).toBuilder() + .name(promptName) + .lastUpdatedBy(USER) + .createdBy(USER) + .build(); + + createPrompt(prompt, apiKey, workspaceName); + + List expectedPrompts = List.of(prompt); + findPromptsAndAssertPage(expectedPrompts, apiKey, workspaceName, expectedPrompts.size(), 1, partialSearch); + } + + Stream when__searchByPartialName__thenReturnPromptMatchingName() { + return Stream.of( + arguments("prompt", "pro"), + arguments("prompt", "pt"), + arguments("prompt", "om")); + } + + @Test + @DisplayName("when fetch prompts, then return prompts sorted by creation time") + void when__fetchPrompts__thenReturnPromptsSortedByCreationTime() { + + String apiKey = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var prompts = PodamFactoryUtils.manufacturePojoList(factory, Prompt.class).stream() + .map(prompt -> prompt.toBuilder() + .lastUpdatedBy(USER) + .createdBy(USER) + .build()) + .toList(); + + prompts.forEach(prompt -> createPrompt(prompt, apiKey, workspaceName)); + + List expectedPrompts = prompts.reversed(); + + findPromptsAndAssertPage(expectedPrompts, apiKey, workspaceName, expectedPrompts.size(), 1, null); + } + + @Test + @DisplayName("when fetch prompts using pagination, then return prompts paginated") + void when__fetchPromptsUsingPagination__thenReturnPromptsPaginated() { + + String apiKey = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var prompts = IntStream.range(0, 20) + .mapToObj(i -> factory.manufacturePojo(Prompt.class).toBuilder() + .lastUpdatedBy(USER) + .createdBy(USER) + .build()) + .toList(); + + prompts.forEach(prompt -> createPrompt(prompt, apiKey, workspaceName)); + + List promptPage1 = prompts.reversed().subList(0, 10); + List promptPage2 = prompts.reversed().subList(10, 20); + + findPromptsAndAssertPage(promptPage1, apiKey, workspaceName, prompts.size(), 1, null); + findPromptsAndAssertPage(promptPage2, apiKey, workspaceName, prompts.size(), 2, null); + } + } + + private void findPromptsAndAssertPage(List expectedPrompts, String apiKey, String workspaceName, + int expectedTotal, int page, String nameSearch) { + + WebTarget target = client.target(RESOURCE_PATH.formatted(baseURI)); + + if (nameSearch != null) { + target = target.queryParam("name", nameSearch); + } + + if (page > 1) { + target = target.queryParam("page", page); + } + + try (var response = target + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(RequestContext.WORKSPACE_HEADER, workspaceName) + .get()) { + + assertThat(response.getStatus()).isEqualTo(200); + + var promptPage = response.readEntity(Prompt.PromptPage.class); + + assertThat(promptPage.total()).isEqualTo(expectedTotal); + assertThat(promptPage.content()).hasSize(expectedPrompts.size()); + assertThat(promptPage.page()).isEqualTo(page); + assertThat(promptPage.size()).isEqualTo(expectedPrompts.size()); + + assertThat(promptPage.content()) + .usingRecursiveComparison( + RecursiveComparisonConfiguration.builder() + .withIgnoredFields("versionCount", "latestVersion") + .withComparatorForType(this::comparatorForCreateAtAndUpdatedAt, Instant.class) + .build()) + .isEqualTo(expectedPrompts); + } + } + + private int comparatorForCreateAtAndUpdatedAt(Instant actual, Instant expected) { + var now = Instant.now(); + + if (actual.isAfter(now) || actual.equals(now)) + return 1; + if (actual.isBefore(expected)) + return -1; + + Assertions.assertThat(actual).isBetween(expected, now); + return 0; + } } \ No newline at end of file From 6a32ec10b87c465ffe62f010186f650be4e315cc Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Mon, 4 Nov 2024 14:40:58 +0100 Subject: [PATCH 5/9] Add logic to create first version when specified --- .../main/java/com/comet/opik/api/Prompt.java | 3 + .../com/comet/opik/api/PromptVersion.java | 58 +++++++++++++ .../comet/opik/domain/CommitGenerator.java | 14 +++ .../opik/domain/EntityConstraintHandler.java | 27 ++++++ .../com/comet/opik/domain/PromptService.java | 46 +++++++++- .../comet/opik/domain/PromptVersionDAO.java | 25 ++++++ ..._increate_prompt_version_commit_length.sql | 6 ++ .../domain/EntityConstraintHandlerTest.java | 85 +++++++++++++++++++ 8 files changed, 262 insertions(+), 2 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/CommitGenerator.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java create mode 100644 apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000005_increate_prompt_version_commit_length.sql create mode 100644 apps/opik-backend/src/test/java/com/comet/opik/domain/EntityConstraintHandlerTest.java diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java index 023ab2788..a726fb4f4 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/Prompt.java @@ -5,6 +5,7 @@ import com.fasterxml.jackson.databind.PropertyNamingStrategies; import com.fasterxml.jackson.databind.annotation.JsonNaming; import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.annotation.Nullable; import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.Pattern; import lombok.Builder; @@ -24,6 +25,8 @@ public record Prompt( @JsonView({Prompt.View.Public.class, Prompt.View.Write.class}) @NotBlank String name, @JsonView({Prompt.View.Public.class, Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String description, + @JsonView({ + Prompt.View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") @Nullable String template, @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy, @JsonView({Prompt.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt, diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java new file mode 100644 index 000000000..cef10c243 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java @@ -0,0 +1,58 @@ +package com.comet.opik.api; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonView; +import com.fasterxml.jackson.databind.PropertyNamingStrategies; +import com.fasterxml.jackson.databind.annotation.JsonNaming; +import io.swagger.v3.oas.annotations.media.Schema; +import jakarta.annotation.Nullable; +import jakarta.validation.constraints.NotNull; +import lombok.Builder; + +import java.time.Instant; +import java.util.List; +import java.util.Set; +import java.util.UUID; + +@Builder(toBuilder = true) +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) +public record PromptVersion( + @JsonView( { + PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(description = "version unique identifier, generated if absent") UUID id, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID promptId, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(description = "version short unique identifier, generated if absent") String commit, + @JsonView({PromptVersion.View.Detail.class}) @NotNull String template, + @JsonView({ + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Set variables, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt, + @JsonView({PromptVersion.View.Public.class, + PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy){ + + public static class View { + public static class Public { + } + + public static class Detail { + } + } + + @Builder + public record PromptVersionPage( + @JsonView( { + PromptVersion.View.Public.class}) int page, + @JsonView({PromptVersion.View.Public.class}) int size, + @JsonView({PromptVersion.View.Public.class}) long total, + @JsonView({PromptVersion.View.Public.class}) List content) + implements + Page{ + + public static PromptVersion.PromptVersionPage empty(int page) { + return new PromptVersion.PromptVersionPage(page, 0, 0, List.of()); + } + } +} \ No newline at end of file diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitGenerator.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitGenerator.java new file mode 100644 index 000000000..d738e967d --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/CommitGenerator.java @@ -0,0 +1,14 @@ +package com.comet.opik.domain; + +import lombok.NonNull; +import lombok.experimental.UtilityClass; + +import java.util.UUID; + +@UtilityClass +class CommitGenerator { + + public String generateCommit(@NonNull UUID id) { + return id.toString().substring(id.toString().length() - 8); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java index b6991c25d..a0919e854 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java @@ -1,13 +1,18 @@ package com.comet.opik.domain; import com.comet.opik.api.error.EntityAlreadyExistsException; +import com.google.common.base.Preconditions; import org.jdbi.v3.core.statement.UnableToExecuteStatementException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.sql.SQLIntegrityConstraintViolationException; import java.util.function.Supplier; interface EntityConstraintHandler { + Logger log = LoggerFactory.getLogger(EntityConstraintHandler.class); + static EntityConstraintHandler handle(EntityConstraintAction entityAction) { return () -> entityAction; } @@ -30,4 +35,26 @@ default T withError(Supplier errorProvider) { } } + default T withRetry(int times, Supplier errorProvider) { + Preconditions.checkArgument(times > 0, "Retry times must be greater than 0"); + + return internalRetry(times, errorProvider); + } + + private T internalRetry(int times, Supplier errorProvider) { + try { + return wrappedAction().execute(); + } catch (UnableToExecuteStatementException e) { + if (e.getCause() instanceof SQLIntegrityConstraintViolationException) { + if (times > 0) { + log.warn("Retrying due to constraint violation, remaining attempts: {}", times); + return internalRetry(times - 1, errorProvider); + } + throw errorProvider.get(); + } else { + throw e; + } + } + } + } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java index 4916baa41..706bf0dab 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java @@ -1,6 +1,7 @@ package com.comet.opik.domain; import com.comet.opik.api.Prompt; +import com.comet.opik.api.PromptVersion; import com.comet.opik.api.error.EntityAlreadyExistsException; import com.comet.opik.infrastructure.auth.RequestContext; import com.google.inject.ImplementedBy; @@ -11,8 +12,11 @@ import lombok.NonNull; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; import ru.vyarus.guicey.jdbi3.tx.TransactionTemplate; +import java.util.UUID; + import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE; @ImplementedBy(PromptServiceImpl.class) @@ -43,11 +47,49 @@ public Prompt prompt(Prompt prompt) { .lastUpdatedBy(userName) .build(); - IdGenerator.validateVersion(newPrompt.id(), "prompt"); + IdGenerator.validateVersion(prompt.id(), "prompt"); - return EntityConstraintHandler + var createdPrompt = EntityConstraintHandler .handle(() -> savePrompt(workspaceId, newPrompt)) .withError(this::newConflict); + + log.info("Prompt created with id '{}' name '{}', on workspace_id '{}'", createdPrompt.id(), + createdPrompt.name(), + workspaceId); + + if (!StringUtils.isEmpty(prompt.template())) { + EntityConstraintHandler + .handle(() -> createPromptVersionFromPromptRequest(prompt, createdPrompt, workspaceId)) + .withRetry(3, this::newConflict); + } + + return createdPrompt; + } + + private PromptVersion createPromptVersionFromPromptRequest(Prompt prompt, Prompt createdPrompt, + String workspaceId) { + log.info("Creating prompt version for prompt id '{}'", createdPrompt.id()); + + var createdVersion = transactionTemplate.inTransaction(WRITE, handle -> { + PromptVersionDAO promptVersionDAO = handle.attach(PromptVersionDAO.class); + + UUID versionId = idGenerator.generateId(); + PromptVersion promptVersion = PromptVersion.builder() + .id(versionId) + .promptId(createdPrompt.id()) + .commit(CommitGenerator.generateCommit(versionId)) + .template(prompt.template()) + .createdBy(createdPrompt.createdBy()) + .build(); + + promptVersionDAO.save(workspaceId, promptVersion); + + return promptVersionDAO.findById(versionId, workspaceId); + }); + + log.info("Created Prompt version for prompt id '{}'", createdPrompt.id()); + + return createdVersion; } private Prompt savePrompt(String workspaceId, Prompt newPrompt) { diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java new file mode 100644 index 000000000..f1f46fb00 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptVersionDAO.java @@ -0,0 +1,25 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.PromptVersion; +import com.comet.opik.infrastructure.db.UUIDArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterArgumentFactory; +import org.jdbi.v3.sqlobject.config.RegisterConstructorMapper; +import org.jdbi.v3.sqlobject.customizer.Bind; +import org.jdbi.v3.sqlobject.customizer.BindMethods; +import org.jdbi.v3.sqlobject.statement.SqlQuery; +import org.jdbi.v3.sqlobject.statement.SqlUpdate; + +import java.util.UUID; + +@RegisterConstructorMapper(PromptVersion.class) +@RegisterArgumentFactory(UUIDArgumentFactory.class) +interface PromptVersionDAO { + + @SqlUpdate("INSERT INTO prompt_versions (id, prompt_id, commit, template, created_by, workspace_id) " + + "VALUES (:bean.id, :bean.promptId, :bean.commit, :bean.template, :bean.createdBy, :workspace_id)") + void save(@Bind("workspace_id") String workspaceId, @BindMethods("bean") PromptVersion prompt); + + @SqlQuery("SELECT * FROM prompt_versions WHERE id = :id AND workspace_id = :workspace_id") + PromptVersion findById(@Bind("id") UUID id, @Bind("workspace_id") String workspaceId); + +} diff --git a/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000005_increate_prompt_version_commit_length.sql b/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000005_increate_prompt_version_commit_length.sql new file mode 100644 index 000000000..021459eb5 --- /dev/null +++ b/apps/opik-backend/src/main/resources/liquibase/db-app-state/migrations/000005_increate_prompt_version_commit_length.sql @@ -0,0 +1,6 @@ +--liquibase formatted sql +--changeset thiagohora:increate_prompt_version_commit_length + +ALTER TABLE prompt_versions MODIFY COLUMN commit VARCHAR(8); + +--rollback ALTER TABLE prompt_versions MODIFY COLUMN commit VARCHAR(7); diff --git a/apps/opik-backend/src/test/java/com/comet/opik/domain/EntityConstraintHandlerTest.java b/apps/opik-backend/src/test/java/com/comet/opik/domain/EntityConstraintHandlerTest.java new file mode 100644 index 000000000..df0b7a1bd --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/domain/EntityConstraintHandlerTest.java @@ -0,0 +1,85 @@ +package com.comet.opik.domain; + +import com.comet.opik.api.error.EntityAlreadyExistsException; +import io.dropwizard.jersey.errors.ErrorMessage; +import org.jdbi.v3.core.statement.StatementContext; +import org.jdbi.v3.core.statement.UnableToExecuteStatementException; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.sql.SQLIntegrityConstraintViolationException; +import java.util.function.Supplier; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class EntityConstraintHandlerTest { + + private static final Supplier ENTITY_ALREADY_EXISTS = () -> new EntityAlreadyExistsException( + new ErrorMessage(409, "Entity already exists")); + + @Test + void testWithError() { + EntityConstraintHandler handler = EntityConstraintHandler.handle(() -> { + fail(); + return null; + }); + + assertThrows(EntityAlreadyExistsException.class, () -> handler.withError(ENTITY_ALREADY_EXISTS)); + } + + private static void fail() { + throw new UnableToExecuteStatementException(new SQLIntegrityConstraintViolationException( + "Duplicate entry '1' for key 'PRIMARY'"), Mockito.mock(StatementContext.class)); + } + + @Test + void testWithRetrySuccess() { + EntityConstraintHandler handler = EntityConstraintHandler.handle(() -> "Success"); + + assertEquals("Success", handler.withRetry(3, ENTITY_ALREADY_EXISTS)); + } + + @Test + void testWithRetryFailure() { + EntityConstraintHandler.EntityConstraintAction action = Mockito + .spy(new EntityConstraintHandler.EntityConstraintAction() { + @Override + public String execute() { + fail(); + return ""; + } + }); + + EntityConstraintHandler handler = EntityConstraintHandler.handle(action); + + assertThrows(EntityAlreadyExistsException.class, () -> handler.withRetry(3, ENTITY_ALREADY_EXISTS)); + Mockito.verify(action, Mockito.times(4)).execute(); + } + + @Test + void testWithRetryExhausted() { + EntityConstraintHandler.EntityConstraintAction action = Mockito + .spy(new EntityConstraintHandler.EntityConstraintAction() { + @Override + public String execute() { + fail(); + return ""; + } + }); + + EntityConstraintHandler handler = EntityConstraintHandler.handle(action); + + assertThrows(EntityAlreadyExistsException.class, () -> handler.withRetry(1, ENTITY_ALREADY_EXISTS)); + Mockito.verify(action, Mockito.times(2)).execute(); + } + + @Test + void testWithRetryNonConstraintViolation() { + EntityConstraintHandler handler = EntityConstraintHandler.handle(() -> { + throw new UnableToExecuteStatementException(new RuntimeException(), Mockito.mock(StatementContext.class)); + }); + + assertThrows(UnableToExecuteStatementException.class, () -> handler.withRetry(3, ENTITY_ALREADY_EXISTS)); + } +} From 2b5c7d0713621b56bbd622965290eec8472bc8c9 Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Mon, 4 Nov 2024 16:19:25 +0100 Subject: [PATCH 6/9] Initial commit --- .../api/resources/v1/priv/PromptResource.java | 23 +--- .../opik/domain/EntityConstraintHandler.java | 12 ++ .../java/com/comet/opik/domain/PromptDAO.java | 3 + .../com/comet/opik/domain/PromptService.java | 127 ++++++++++++++++-- .../resources/v1/priv/PromptResourceTest.java | 48 ++++++- 5 files changed, 179 insertions(+), 34 deletions(-) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java index c1e7ae250..906b494e0 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/resources/v1/priv/PromptResource.java @@ -196,27 +196,12 @@ public Response createPromptVersion( log.info("Creating prompt version commit '{}' on workspace_id '{}'", promptVersion.version().commit(), workspaceId); - UUID id = idGenerator.generateId(); - log.info("Created prompt version commit '{}' with id '{}' on workspace_id '{}'", - promptVersion.version().commit(), id, workspaceId); + var createdVersion = promptService.createPromptVersion(promptVersion); - return Response.status(Response.Status.NOT_IMPLEMENTED) - .entity(generatePromptVersion(promptVersion, id)) - .build(); - } + log.info("Created prompt version commit '{}' with id '{}' on workspace_id '{}'", + promptVersion.version().commit(), createdVersion.id(), workspaceId); - private PromptVersion generatePromptVersion(CreatePromptVersion promptVersion, UUID id) { - return PromptVersion.builder() - .id(id) - .commit(promptVersion.version().commit() == null - ? id.toString().substring(id.toString().length() - 7) - : promptVersion.version().commit()) - .template(promptVersion.version().template()) - .variables( - Set.of("user_message")) - .createdAt(Instant.now()) - .createdBy("User 1") - .build(); + return Response.ok(createdVersion).build(); } @GET diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java index a0919e854..9a79ab9b6 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/EntityConstraintHandler.java @@ -35,6 +35,18 @@ default T withError(Supplier errorProvider) { } } + default T onErrorDo(Supplier errorProvider) { + try { + return wrappedAction().execute(); + } catch (UnableToExecuteStatementException e) { + if (e.getCause() instanceof SQLIntegrityConstraintViolationException) { + return errorProvider.get(); + } else { + throw e; + } + } + } + default T withRetry(int times, Supplier errorProvider) { Preconditions.checkArgument(times > 0, "Retry times must be greater than 0"); diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java index 8214e54e7..e9a9e0084 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java @@ -42,4 +42,7 @@ List find(@Define("name") @Bind("name") String name, @Bind("workspace_Id @UseStringTemplateEngine @AllowUnusedBindings long count(@Define("name") @Bind("name") String name, @Bind("workspace_Id") String workspaceId); + + @SqlQuery("SELECT * FROM prompts WHERE name = :name AND workspace_id = :workspaceId") + Prompt findByName(String name, String workspaceId); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java index 4cda42fc6..c033e26eb 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java @@ -1,5 +1,6 @@ package com.comet.opik.domain; +import com.comet.opik.api.CreatePromptVersion; import com.comet.opik.api.Prompt; import com.comet.opik.api.PromptVersion; import com.comet.opik.api.error.EntityAlreadyExistsException; @@ -19,6 +20,7 @@ import java.util.UUID; import static com.comet.opik.api.Prompt.PromptPage; +import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.READ_ONLY; import static com.comet.opik.infrastructure.db.TransactionTemplateAsync.WRITE; @ImplementedBy(PromptServiceImpl.class) @@ -26,6 +28,8 @@ public interface PromptService { Prompt prompt(Prompt prompt); PromptPage find(String name, int page, int size); + + PromptVersion createPromptVersion(CreatePromptVersion promptVersion); } @Singleton @@ -34,12 +38,13 @@ public interface PromptService { class PromptServiceImpl implements PromptService { public static final String ALREADY_EXISTS = "Prompt id or name already exists"; + public static final String VERSION_ALREADY_EXISTS = "Prompt version already exists"; private final @NonNull Provider requestContext; private final @NonNull IdGenerator idGenerator; private final @NonNull TransactionTemplate transactionTemplate; @Override - public Prompt prompt(Prompt prompt) { + public Prompt prompt(@NonNull Prompt prompt) { String workspaceId = requestContext.get().getWorkspaceId(); String userName = requestContext.get().getUserName(); @@ -50,9 +55,7 @@ public Prompt prompt(Prompt prompt) { .lastUpdatedBy(userName) .build(); - IdGenerator.validateVersion(prompt.id(), "prompt"); - - var createdPrompt = EntityConstraintHandler + Prompt createdPrompt = EntityConstraintHandler .handle(() -> savePrompt(workspaceId, newPrompt)) .withError(this::newConflict); @@ -62,15 +65,14 @@ public Prompt prompt(Prompt prompt) { if (!StringUtils.isEmpty(prompt.template())) { EntityConstraintHandler - .handle(() -> createPromptVersionFromPromptRequest(prompt, createdPrompt, workspaceId)) + .handle(() -> createPromptVersionFromPromptRequest(createdPrompt, workspaceId, prompt.template())) .withRetry(3, this::newConflict); } return createdPrompt; } - private PromptVersion createPromptVersionFromPromptRequest(Prompt prompt, Prompt createdPrompt, - String workspaceId) { + private PromptVersion createPromptVersionFromPromptRequest(Prompt createdPrompt, String workspaceId, String template) { log.info("Creating prompt version for prompt id '{}'", createdPrompt.id()); var createdVersion = transactionTemplate.inTransaction(WRITE, handle -> { @@ -81,7 +83,7 @@ private PromptVersion createPromptVersionFromPromptRequest(Prompt prompt, Prompt .id(versionId) .promptId(createdPrompt.id()) .commit(CommitGenerator.generateCommit(versionId)) - .template(prompt.template()) + .template(template) .createdBy(createdPrompt.createdBy()) .build(); @@ -95,16 +97,29 @@ private PromptVersion createPromptVersionFromPromptRequest(Prompt prompt, Prompt return createdVersion; } - private Prompt savePrompt(String workspaceId, Prompt newPrompt) { + private Prompt savePrompt(String workspaceId, Prompt prompt) { + + IdGenerator.validateVersion(prompt.id(), "prompt"); + return transactionTemplate.inTransaction(WRITE, handle -> { PromptDAO promptDAO = handle.attach(PromptDAO.class); - promptDAO.save(workspaceId, newPrompt); + promptDAO.save(workspaceId, prompt); - return promptDAO.findById(newPrompt.id(), workspaceId); + return promptDAO.findById(prompt.id(), workspaceId); }); } + private EntityAlreadyExistsException newConflict() { + log.info(ALREADY_EXISTS); + return new EntityAlreadyExistsException(new ErrorMessage(ALREADY_EXISTS)); + } + + private EntityAlreadyExistsException newVersionConflict() { + log.info(VERSION_ALREADY_EXISTS); + return new EntityAlreadyExistsException(new ErrorMessage(VERSION_ALREADY_EXISTS)); + } + @Override public PromptPage find(String name, int page, int size) { @@ -128,9 +143,93 @@ public PromptPage find(String name, int page, int size) { }); } - private EntityAlreadyExistsException newConflict() { - log.info(ALREADY_EXISTS); - return new EntityAlreadyExistsException(new ErrorMessage(ALREADY_EXISTS)); + private Prompt getOrCreatePrompt(String workspaceId, String name, String userName) { + + Prompt prompt = findByName(workspaceId, name); + + if (prompt != null) { + return prompt; + } + + var newPrompt = Prompt.builder() + .id(idGenerator.generateId()) + .name(name) + .createdBy(userName) + .lastUpdatedBy(userName) + .build(); + + return EntityConstraintHandler + .handle(() -> savePrompt(workspaceId, newPrompt)) + .onErrorDo(() -> findByName(workspaceId, name)); + } + + private Prompt findByName(String workspaceId, String name) { + return transactionTemplate.inTransaction(READ_ONLY, handle -> { + PromptDAO promptDAO = handle.attach(PromptDAO.class); + + return promptDAO.findByName(name, workspaceId); + }); + } + + @Override + public PromptVersion createPromptVersion(@NonNull CreatePromptVersion createPromptVersion) { + + String workspaceId = requestContext.get().getWorkspaceId(); + String userName = requestContext.get().getUserName(); + + Prompt prompt = getOrCreatePrompt(workspaceId, createPromptVersion.name(), userName); + + UUID id = createPromptVersion.version().id() == null ? idGenerator.generateId() : createPromptVersion.version().id(); + String commit = createPromptVersion.version().commit() == null ? CommitGenerator.generateCommit(id) : createPromptVersion.version().commit(); + + EntityConstraintHandler handler = EntityConstraintHandler.handle(() -> { + PromptVersion promptVersion = createPromptVersion.version().toBuilder() + .promptId(prompt.id()) + .createdBy(userName) + .id(id) + .commit(commit) + .build(); + + return savePromptVersion(workspaceId, promptVersion); + }); + + if (createPromptVersion.version().commit() != null) { + return handler.withError(this::newVersionConflict); + } else { + return handler.onErrorDo(() -> retryableCreateVersion(workspaceId, createPromptVersion, prompt, userName)); + } + } + + private PromptVersion retryableCreateVersion(String workspaceId, CreatePromptVersion request, Prompt prompt, String userName) { + return EntityConstraintHandler.handle(() -> { + UUID newId = idGenerator.generateId(); + + PromptVersion promptVersion = request.version().toBuilder() + .promptId(prompt.id()) + .createdBy(userName) + .id(newId) + .commit(CommitGenerator.generateCommit(newId)) + .build(); + + return savePromptVersion(workspaceId, promptVersion); + + }).withRetry(3, this::newVersionConflict); + } + + private PromptVersion savePromptVersion(String workspaceId, PromptVersion promptVersion) { + log.info("Creating prompt version for prompt id '{}'", promptVersion.promptId()); + + var createdVersion = transactionTemplate.inTransaction(WRITE, handle -> { + PromptVersionDAO promptVersionDAO = handle.attach(PromptVersionDAO.class); + + promptVersionDAO.save(workspaceId, promptVersion); + + return promptVersionDAO.findById(promptVersion.id(), workspaceId); + }); + + log.info("Created Prompt version for prompt id '{}'", promptVersion.promptId()); + + return createdVersion; } } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java index 8956861af..edeb5f615 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java @@ -2,6 +2,7 @@ import com.comet.opik.api.Project; import com.comet.opik.api.Prompt; +import com.comet.opik.api.PromptVersion; import com.comet.opik.api.error.ErrorMessage; import com.comet.opik.api.resources.utils.AuthTestUtils; import com.comet.opik.api.resources.utils.ClickHouseContainerUtils; @@ -320,7 +321,25 @@ class CreatePrompt { @DisplayName("Success: should create prompt") void shouldCreatePrompt() { - var prompt = factory.manufacturePojo(Prompt.class); + var prompt = factory.manufacturePojo(Prompt.class).toBuilder() + .lastUpdatedBy(USER) + .createdBy(USER) + .template(null) + .build(); + + var promptId = createPrompt(prompt, API_KEY, TEST_WORKSPACE); + + assertThat(promptId).isNotNull(); + } + + @Test + @DisplayName("when prompt contains first version template, then return created prompt") + void when__promptContainsFirstVersionTemplate__thenReturnCreatedPrompt() { + + var prompt = factory.manufacturePojo(Prompt.class).toBuilder() + .lastUpdatedBy(USER) + .createdBy(USER) + .build(); var promptId = createPrompt(prompt, API_KEY, TEST_WORKSPACE); @@ -512,6 +531,33 @@ void when__fetchPromptsUsingPagination__thenReturnPromptsPaginated() { } } + @Nested + @DisplayName("Create Prompt Version") + @TestInstance(TestInstance.Lifecycle.PER_CLASS) + class CreatePromptVersion { + + @Test + @DisplayName("Success: should create prompt version") + void shouldCreatePromptVersion() { + + var prompt = factory.manufacturePojo(Prompt.class).toBuilder() + .lastUpdatedBy(USER) + .createdBy(USER) + .template(null) + .build(); + + var promptId = createPrompt(prompt, API_KEY, TEST_WORKSPACE); + + var promptVersion = factory.manufacturePojo(PromptVersion.class).toBuilder() + .promptId(promptId) + .createdBy(USER) + .build(); + + assertThat(promptId).isNotNull(); + } + + } + private void findPromptsAndAssertPage(List expectedPrompts, String apiKey, String workspaceName, int expectedTotal, int page, String nameSearch) { From 300301b1cfb0c1b966616685abf45606c7f0c567 Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Mon, 4 Nov 2024 17:04:51 +0100 Subject: [PATCH 7/9] Address PR review --- .../resources/v1/priv/PromptResourceTest.java | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java index d8f7970d8..85dba06c9 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java @@ -20,6 +20,7 @@ import jakarta.ws.rs.client.WebTarget; import jakarta.ws.rs.core.HttpHeaders; import jakarta.ws.rs.core.MediaType; +import org.apache.commons.lang3.RandomStringUtils; import org.assertj.core.api.Assertions; import org.assertj.core.api.recursive.comparison.RecursiveComparisonConfiguration; import org.jdbi.v3.core.Jdbi; @@ -422,9 +423,36 @@ void when__searchByName__thenReturnPromptMatchingName() { findPromptsAndAssertPage(expectedPrompts, apiKey, workspaceName, expectedPrompts.size(), 1, prompt.name()); } + @Test + @DisplayName("when search by name with mismatched partial name, then return empty page") + void when__searchByNameWithMismatchedPartialName__thenReturnEmptyPage() { + + String apiKey = UUID.randomUUID().toString(); + String workspaceName = UUID.randomUUID().toString(); + String workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + String name = RandomStringUtils.randomAlphanumeric(10); + + String partialSearch = name.substring(0, 5) + "@" + RandomStringUtils.randomAlphanumeric(2); + + var prompt = factory.manufacturePojo(Prompt.class).toBuilder() + .name(name) + .lastUpdatedBy(USER) + .createdBy(USER) + .build(); + + createPrompt(prompt, apiKey, workspaceName); + + List expectedPrompts = List.of(); + + findPromptsAndAssertPage(expectedPrompts, apiKey, workspaceName, expectedPrompts.size(), 1, partialSearch); + } + @ParameterizedTest @MethodSource - @DisplayName("when search by partial name, then return prompt matching name") + @DisplayName("when search by partial name, then return prompt matching name") void when__searchByPartialName__thenReturnPromptMatchingName(String promptName, String partialSearch) { String apiKey = UUID.randomUUID().toString(); @@ -510,6 +538,7 @@ void when__fetchPromptsUsingPagination__thenReturnPromptsPaginated() { findPromptsAndAssertPage(promptPage1, apiKey, workspaceName, prompts.size(), 1, null); findPromptsAndAssertPage(promptPage2, apiKey, workspaceName, prompts.size(), 2, null); } + } private void findPromptsAndAssertPage(List expectedPrompts, String apiKey, String workspaceName, From 59a7c4542fc472ff2b8998a2588bbfc6cd24e8ab Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Mon, 4 Nov 2024 20:41:32 +0100 Subject: [PATCH 8/9] [OPIK-314] Create prompt version endpoint --- apps/opik-backend/pom.xml | 5 + .../comet/opik/api/CreatePromptVersion.java | 3 +- .../com/comet/opik/api/PromptVersion.java | 8 +- .../opik/api/validate/CommitValidation.java | 23 ++ .../opik/api/validate/CommitValidator.java | 15 ++ .../domain/MustacheVariableExtractor.java | 47 ++++ .../java/com/comet/opik/domain/PromptDAO.java | 4 +- .../com/comet/opik/domain/PromptService.java | 42 +++- .../com/comet/opik/utils/ValidationUtils.java | 1 + .../resources/v1/priv/PromptResourceTest.java | 232 +++++++++++++++++- .../comet/opik/podam/PodamFactoryUtils.java | 3 + .../PromptVersionManufacturer.java | 47 ++++ 12 files changed, 411 insertions(+), 19 deletions(-) create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/validate/CommitValidation.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/api/validate/CommitValidator.java create mode 100644 apps/opik-backend/src/main/java/com/comet/opik/domain/MustacheVariableExtractor.java create mode 100644 apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/PromptVersionManufacturer.java diff --git a/apps/opik-backend/pom.xml b/apps/opik-backend/pom.xml index 18776cfce..d49a4b974 100644 --- a/apps/opik-backend/pom.xml +++ b/apps/opik-backend/pom.xml @@ -62,6 +62,11 @@ + + com.github.spullara.mustache.java + compiler + 0.9.10 + io.opentelemetry.instrumentation opentelemetry-instrumentation-annotations diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java b/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java index 04064916c..7f92ef5c9 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/CreatePromptVersion.java @@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonView; import com.fasterxml.jackson.databind.PropertyNamingStrategies; import com.fasterxml.jackson.databind.annotation.JsonNaming; +import jakarta.validation.Valid; import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotNull; import lombok.Builder; @@ -13,5 +14,5 @@ @JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class) public record CreatePromptVersion(@JsonView( { PromptVersion.View.Detail.class}) @NotBlank String name, - @JsonView({PromptVersion.View.Detail.class}) @NotNull PromptVersion version){ + @JsonView({PromptVersion.View.Detail.class}) @NotNull @Valid PromptVersion version){ } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java index daecd65ef..1bdecaff4 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/PromptVersion.java @@ -1,12 +1,14 @@ package com.comet.opik.api; +import com.comet.opik.api.validate.CommitValidation; +import com.comet.opik.utils.ValidationUtils; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonView; import com.fasterxml.jackson.databind.PropertyNamingStrategies; import com.fasterxml.jackson.databind.annotation.JsonNaming; import io.swagger.v3.oas.annotations.media.Schema; import jakarta.annotation.Nullable; -import jakarta.validation.constraints.NotNull; +import jakarta.validation.constraints.NotBlank; import lombok.Builder; import java.time.Instant; @@ -24,8 +26,8 @@ public record PromptVersion( @JsonView({PromptVersion.View.Public.class, PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID promptId, @JsonView({PromptVersion.View.Public.class, - PromptVersion.View.Detail.class}) @Schema(description = "version short unique identifier, generated if absent") String commit, - @JsonView({PromptVersion.View.Detail.class}) @NotNull String template, + PromptVersion.View.Detail.class}) @Schema(description = "version short unique identifier, generated if absent. it must be 8 characters long", requiredMode = Schema.RequiredMode.NOT_REQUIRED, pattern = ValidationUtils.COMMIT_PATTERN) @CommitValidation String commit, + @JsonView({PromptVersion.View.Detail.class}) @NotBlank String template, @JsonView({ PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Set variables, @JsonView({PromptVersion.View.Public.class, diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/validate/CommitValidation.java b/apps/opik-backend/src/main/java/com/comet/opik/api/validate/CommitValidation.java new file mode 100644 index 000000000..98168b5e3 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/validate/CommitValidation.java @@ -0,0 +1,23 @@ +package com.comet.opik.api.validate; + +import jakarta.validation.Constraint; +import jakarta.validation.Payload; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Target({ElementType.FIELD, ElementType.ANNOTATION_TYPE, ElementType.TYPE}) +@Retention(RetentionPolicy.RUNTIME) +@Constraint(validatedBy = {CommitValidator.class}) +@Documented +public @interface CommitValidation { + + String message() default "if present, the commit message must be 8 alphanumeric characters long"; + + Class[] groups() default {}; + + Class[] payload() default {}; +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/api/validate/CommitValidator.java b/apps/opik-backend/src/main/java/com/comet/opik/api/validate/CommitValidator.java new file mode 100644 index 000000000..1e0c16100 --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/api/validate/CommitValidator.java @@ -0,0 +1,15 @@ +package com.comet.opik.api.validate; + +import com.comet.opik.utils.ValidationUtils; +import jakarta.validation.ConstraintValidator; +import jakarta.validation.ConstraintValidatorContext; + +import java.util.regex.Pattern; + +public class CommitValidator implements ConstraintValidator { + + @Override + public boolean isValid(String commit, ConstraintValidatorContext context) { + return commit == null || Pattern.matches(ValidationUtils.COMMIT_PATTERN, commit); + } +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/MustacheVariableExtractor.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/MustacheVariableExtractor.java new file mode 100644 index 000000000..985b4594a --- /dev/null +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/MustacheVariableExtractor.java @@ -0,0 +1,47 @@ +package com.comet.opik.domain; + +import com.github.mustachejava.Code; +import com.github.mustachejava.DefaultMustacheFactory; +import com.github.mustachejava.Mustache; +import com.github.mustachejava.MustacheFactory; +import com.github.mustachejava.codes.ValueCode; +import lombok.experimental.UtilityClass; + +import java.io.StringReader; +import java.util.HashSet; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +@UtilityClass +class MustacheVariableExtractor { + + public static final MustacheFactory MF = new DefaultMustacheFactory(); + + public static Set extractVariables(String template) { + Set variables = new HashSet<>(); + + // Initialize Mustache Factory + Mustache mustache = MF.compile(new StringReader(template), "template"); + + // Get th e root node of the template + Code[] codes = mustache.getCodes(); + collectVariables(codes, variables); + + return variables; + } + + private static void collectVariables(Code[] codes, Set variables) { + for (Code code : codes) { + if (Objects.requireNonNull(code) instanceof ValueCode valueCode) { + variables.add(valueCode.getName()); + } else { + Optional.ofNullable(code) + .map(Code::getCodes) + .map(it -> it.length > 0) + .ifPresent(it -> collectVariables(code.getCodes(), variables)); + } + } + } + +} diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java index e9a9e0084..f02f100cb 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptDAO.java @@ -43,6 +43,6 @@ List find(@Define("name") @Bind("name") String name, @Bind("workspace_Id @AllowUnusedBindings long count(@Define("name") @Bind("name") String name, @Bind("workspace_Id") String workspaceId); - @SqlQuery("SELECT * FROM prompts WHERE name = :name AND workspace_id = :workspaceId") - Prompt findByName(String name, String workspaceId); + @SqlQuery("SELECT * FROM prompts WHERE name = :name AND workspace_id = :workspace_id") + Prompt findByName(@Bind("name") String name, @Bind("workspace_id") String workspaceId); } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java index c033e26eb..86bd311c8 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java @@ -17,6 +17,7 @@ import ru.vyarus.guicey.jdbi3.tx.TransactionTemplate; import java.util.List; +import java.util.Set; import java.util.UUID; import static com.comet.opik.api.Prompt.PromptPage; @@ -72,7 +73,8 @@ public Prompt prompt(@NonNull Prompt prompt) { return createdPrompt; } - private PromptVersion createPromptVersionFromPromptRequest(Prompt createdPrompt, String workspaceId, String template) { + private PromptVersion createPromptVersionFromPromptRequest(Prompt createdPrompt, String workspaceId, + String template) { log.info("Creating prompt version for prompt id '{}'", createdPrompt.id()); var createdVersion = transactionTemplate.inTransaction(WRITE, handle -> { @@ -87,6 +89,8 @@ private PromptVersion createPromptVersionFromPromptRequest(Prompt createdPrompt, .createdBy(createdPrompt.createdBy()) .build(); + IdGenerator.validateVersion(promptVersion.id(), "prompt"); + promptVersionDAO.save(workspaceId, promptVersion); return promptVersionDAO.findById(versionId, workspaceId); @@ -179,8 +183,12 @@ public PromptVersion createPromptVersion(@NonNull CreatePromptVersion createProm Prompt prompt = getOrCreatePrompt(workspaceId, createPromptVersion.name(), userName); - UUID id = createPromptVersion.version().id() == null ? idGenerator.generateId() : createPromptVersion.version().id(); - String commit = createPromptVersion.version().commit() == null ? CommitGenerator.generateCommit(id) : createPromptVersion.version().commit(); + UUID id = createPromptVersion.version().id() == null + ? idGenerator.generateId() + : createPromptVersion.version().id(); + String commit = createPromptVersion.version().commit() == null + ? CommitGenerator.generateCommit(id) + : createPromptVersion.version().commit(); EntityConstraintHandler handler = EntityConstraintHandler.handle(() -> { PromptVersion promptVersion = createPromptVersion.version().toBuilder() @@ -196,11 +204,13 @@ public PromptVersion createPromptVersion(@NonNull CreatePromptVersion createProm if (createPromptVersion.version().commit() != null) { return handler.withError(this::newVersionConflict); } else { + // only retry if commit is not provided return handler.onErrorDo(() -> retryableCreateVersion(workspaceId, createPromptVersion, prompt, userName)); } } - private PromptVersion retryableCreateVersion(String workspaceId, CreatePromptVersion request, Prompt prompt, String userName) { + private PromptVersion retryableCreateVersion(String workspaceId, CreatePromptVersion request, Prompt prompt, + String userName) { return EntityConstraintHandler.handle(() -> { UUID newId = idGenerator.generateId(); @@ -219,17 +229,35 @@ private PromptVersion retryableCreateVersion(String workspaceId, CreatePromptVer private PromptVersion savePromptVersion(String workspaceId, PromptVersion promptVersion) { log.info("Creating prompt version for prompt id '{}'", promptVersion.promptId()); - var createdVersion = transactionTemplate.inTransaction(WRITE, handle -> { + IdGenerator.validateVersion(promptVersion.id(), "prompt version"); + + transactionTemplate.inTransaction(WRITE, handle -> { PromptVersionDAO promptVersionDAO = handle.attach(PromptVersionDAO.class); promptVersionDAO.save(workspaceId, promptVersion); - return promptVersionDAO.findById(promptVersion.id(), workspaceId); + return null; }); log.info("Created Prompt version for prompt id '{}'", promptVersion.promptId()); - return createdVersion; + return getById(workspaceId, promptVersion.id()); + } + + private PromptVersion getById(String workspaceId, UUID id) { + PromptVersion promptVersion = transactionTemplate.inTransaction(READ_ONLY, handle -> { + PromptVersionDAO promptVersionDAO = handle.attach(PromptVersionDAO.class); + + return promptVersionDAO.findById(id, workspaceId); + }); + + return promptVersion.toBuilder() + .variables(getVariables(promptVersion.template())) + .build(); + } + + private Set getVariables(String template) { + return MustacheVariableExtractor.extractVariables(template); } } diff --git a/apps/opik-backend/src/main/java/com/comet/opik/utils/ValidationUtils.java b/apps/opik-backend/src/main/java/com/comet/opik/utils/ValidationUtils.java index b8d140e70..b0bc78726 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/utils/ValidationUtils.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/utils/ValidationUtils.java @@ -26,6 +26,7 @@ public class ValidationUtils { * @see Ai Explainer */ public static final String NULL_OR_NOT_BLANK = "(?s)^\\s*(\\S.*\\S|\\S)\\s*$"; + public static final String COMMIT_PATTERN = "^[a-zA-Z0-9]{8}$"; /** * Canonical String representation to ensure precision over float or double. diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java index edeb5f615..f0556a269 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java @@ -1,5 +1,6 @@ package com.comet.opik.api.resources.v1.priv; +import com.comet.opik.api.CreatePromptVersion; import com.comet.opik.api.Project; import com.comet.opik.api.Prompt; import com.comet.opik.api.PromptVersion; @@ -47,6 +48,7 @@ import java.sql.SQLException; import java.time.Instant; import java.util.List; +import java.util.Set; import java.util.UUID; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -86,6 +88,14 @@ class PromptResourceTest { private static final WireMockUtils.WireMockRuntime wireMock; public static final String[] IGNORED_FIELDS = {"versionCount", "latestVersion", "template"}; + public static final String TEMPLATE = """ + Hi {{%s}}, + + This is a test prompt. The current time is {{%s}}. + + Regards, + {{%s}} + """; static { Startables.deepStart(REDIS, CLICKHOUSE_CONTAINER, MYSQL).join(); @@ -458,7 +468,11 @@ void when__searchByPartialName__thenReturnPromptMatchingName(String promptName, .createdBy(USER) .build(); - createPrompt(prompt, apiKey, workspaceName); + Prompt updatedPrompt = prompt.toBuilder() + .name(prompt.name().replace(partialSearch, "")) + .build(); + + createPrompt(updatedPrompt, apiKey, workspaceName); }); var prompt = factory.manufacturePojo(Prompt.class).toBuilder() @@ -534,7 +548,7 @@ void when__fetchPromptsUsingPagination__thenReturnPromptsPaginated() { @Nested @DisplayName("Create Prompt Version") @TestInstance(TestInstance.Lifecycle.PER_CLASS) - class CreatePromptVersion { + class CreatePromptVersions { @Test @DisplayName("Success: should create prompt version") @@ -546,16 +560,222 @@ void shouldCreatePromptVersion() { .template(null) .build(); - var promptId = createPrompt(prompt, API_KEY, TEST_WORKSPACE); + UUID promptId = createPrompt(prompt, API_KEY, TEST_WORKSPACE); + + String variable1 = UUID.randomUUID().toString(); + String variable2 = UUID.randomUUID().toString(); + String variable3 = UUID.randomUUID().toString(); - var promptVersion = factory.manufacturePojo(PromptVersion.class).toBuilder() - .promptId(promptId) + var expectedPromptVersion = factory.manufacturePojo(PromptVersion.class).toBuilder() .createdBy(USER) + .template(TEMPLATE.formatted(variable1, variable2, variable3)) + .variables(Set.of(variable1, variable2, variable3)) + .commit(null) + .id(null) .build(); - assertThat(promptId).isNotNull(); + var request = new CreatePromptVersion(prompt.name(), expectedPromptVersion); + + PromptVersion actualPromptVersion = createPromptVersion(request, API_KEY, TEST_WORKSPACE); + + assertPromptVersion(actualPromptVersion, expectedPromptVersion, promptId); + } + + @Test + @DisplayName("when prompt version contains commit, then return created prompt version") + void when__promptVersionContainsCommit__thenReturnCreatedPromptVersion() { + + var prompt = factory.manufacturePojo(Prompt.class).toBuilder() + .lastUpdatedBy(USER) + .createdBy(USER) + .template(null) + .build(); + + UUID promptId = createPrompt(prompt, API_KEY, TEST_WORKSPACE); + + String variable1 = UUID.randomUUID().toString(); + String variable2 = UUID.randomUUID().toString(); + String variable3 = UUID.randomUUID().toString(); + + var versionId = factory.manufacturePojo(UUID.class); + + var expectedPromptVersion = factory.manufacturePojo(PromptVersion.class).toBuilder() + .createdBy(USER) + .template(TEMPLATE.formatted(variable1, variable2, variable3)) + .variables(Set.of(variable1, variable2, variable3)) + .commit(versionId.toString().substring(versionId.toString().length() - 8)) + .id(versionId) + .build(); + + var request = new CreatePromptVersion(prompt.name(), expectedPromptVersion); + + PromptVersion actualPromptVersion = createPromptVersion(request, API_KEY, TEST_WORKSPACE); + + assertPromptVersion(actualPromptVersion, expectedPromptVersion, promptId); + } + + @Test + @DisplayName("when prompt doesn't exist, then return created prompt version") + void when__promptDoesNotExist__thenReturnCreatedPromptVersion() { + + var apiKey = UUID.randomUUID().toString(); + var workspaceName = UUID.randomUUID().toString(); + var workspaceId = UUID.randomUUID().toString(); + + mockTargetWorkspace(apiKey, workspaceName, workspaceId); + + var promptName = UUID.randomUUID().toString(); + + String variable1 = UUID.randomUUID().toString(); + String variable2 = UUID.randomUUID().toString(); + String variable3 = UUID.randomUUID().toString(); + + var versionId = factory.manufacturePojo(UUID.class); + + var expectedPromptVersion = factory.manufacturePojo(PromptVersion.class).toBuilder() + .createdBy(USER) + .template(TEMPLATE.formatted(variable1, variable2, variable3)) + .variables(Set.of(variable1, variable2, variable3)) + .commit(versionId.toString().substring(versionId.toString().length() - 8)) + .id(versionId) + .build(); + + var request = new CreatePromptVersion(promptName, expectedPromptVersion); + + PromptVersion actualPromptVersion = createPromptVersion(request, apiKey, workspaceName); + + List prompts = getPrompts(promptName, apiKey, workspaceName); + + assertPromptVersion(actualPromptVersion, expectedPromptVersion, prompts.getFirst().id()); + } + + @ParameterizedTest + @MethodSource + @DisplayName("when prompt version is invalid, then return error") + void when__promptVersionIsInvalid__thenReturnError(CreatePromptVersion promptVersion, int expectedStatusCode, + Object expectedBody, Class expectedResponseClass) { + + try (var response = client.target(RESOURCE_PATH.formatted(baseURI) + "/versions") + .request() + .header(HttpHeaders.AUTHORIZATION, API_KEY) + .header(RequestContext.WORKSPACE_HEADER, TEST_WORKSPACE) + .post(Entity.json(promptVersion))) { + + assertThat(response.getStatus()).isEqualTo(expectedStatusCode); + + var actualBody = response.readEntity(expectedResponseClass); + + assertThat(actualBody).isEqualTo(expectedBody); + } } + Stream when__promptVersionIsInvalid__thenReturnError() { + return Stream.of( + arguments(new CreatePromptVersion(null, factory.manufacturePojo(PromptVersion.class)), + 422, new ErrorMessage(List.of("name must not be blank")), ErrorMessage.class), + arguments(new CreatePromptVersion("", factory.manufacturePojo(PromptVersion.class)), + 422, new ErrorMessage(List.of("name must not be blank")), ErrorMessage.class), + arguments( + new CreatePromptVersion(UUID.randomUUID().toString(), + factory.manufacturePojo(PromptVersion.class) + .toBuilder().commit("").build()), + 422, + new ErrorMessage(List.of( + "version.commit if present, the commit message must be 8 alphanumeric characters long")), + ErrorMessage.class), + arguments( + new CreatePromptVersion(UUID.randomUUID().toString(), + factory.manufacturePojo(PromptVersion.class) + .toBuilder().commit("1234567").build()), + 422, + new ErrorMessage(List.of( + "version.commit if present, the commit message must be 8 alphanumeric characters long")), + ErrorMessage.class), + arguments( + new CreatePromptVersion(UUID.randomUUID().toString(), + factory.manufacturePojo(PromptVersion.class) + .toBuilder().commit("1234-567").build()), + 422, + new ErrorMessage(List.of( + "version.commit if present, the commit message must be 8 alphanumeric characters long")), + ErrorMessage.class), + arguments( + new CreatePromptVersion(UUID.randomUUID().toString(), + factory.manufacturePojo(PromptVersion.class) + .toBuilder().id(UUID.randomUUID()).build()), + 400, new ErrorMessage(List.of("prompt version id must be a version 7 UUID")), + ErrorMessage.class), + arguments( + new CreatePromptVersion(UUID.randomUUID().toString(), + factory.manufacturePojo(PromptVersion.class) + .toBuilder().template("").build()), + 422, new ErrorMessage(List.of("version.template must not be blank")), + ErrorMessage.class), + arguments( + new CreatePromptVersion(UUID.randomUUID().toString(), + factory.manufacturePojo(PromptVersion.class) + .toBuilder().template(null).build()), + 422, new ErrorMessage(List.of("version.template must not be blank")), + ErrorMessage.class)); + } + } + + private List getPrompts(String nameSearch, String apiKey, String workspaceName) { + WebTarget target = client.target(RESOURCE_PATH.formatted(baseURI)); + + if (nameSearch != null) { + target = target.queryParam("name", nameSearch); + } + + try (var response = target.request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(RequestContext.WORKSPACE_HEADER, workspaceName) + .get()) { + + assertThat(response.getStatus()).isEqualTo(200); + + return response.readEntity(Prompt.PromptPage.class).content(); + } + } + + private void assertPromptVersion(PromptVersion createdPromptVersion, PromptVersion promptVersion, UUID promptId) { + assertThat(createdPromptVersion).isNotNull(); + + if (promptVersion.commit() == null) { + assertThat(createdPromptVersion.commit()).isNotNull(); + } else { + assertThat(createdPromptVersion.commit()).isEqualTo(promptVersion.commit()); + } + + UUID id = createdPromptVersion.id(); + + if (promptVersion.id() == null) { + assertThat(id).isNotNull(); + } else { + assertThat(id).isEqualTo(promptVersion.id()); + } + + assertThat(id.toString().substring(id.toString().length() - 8)) + .isEqualTo(createdPromptVersion.commit()); + + assertThat(createdPromptVersion.promptId()).isEqualTo(promptId); + assertThat(createdPromptVersion.template()).isEqualTo(promptVersion.template()); + assertThat(createdPromptVersion.variables()).isEqualTo(promptVersion.variables()); + assertThat(createdPromptVersion.createdAt()).isBetween(promptVersion.createdAt(), Instant.now()); + assertThat(createdPromptVersion.createdBy()).isEqualTo(USER); + } + + private PromptVersion createPromptVersion(CreatePromptVersion promptVersion, String apiKey, String workspaceName) { + try (var response = client.target(RESOURCE_PATH.formatted(baseURI) + "/versions") + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(RequestContext.WORKSPACE_HEADER, workspaceName) + .post(Entity.json(promptVersion))) { + + assertThat(response.getStatus()).isEqualTo(200); + + return response.readEntity(PromptVersion.class); + } } private void findPromptsAndAssertPage(List expectedPrompts, String apiKey, String workspaceName, diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/PodamFactoryUtils.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/PodamFactoryUtils.java index 333da413e..6e632739c 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/podam/PodamFactoryUtils.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/PodamFactoryUtils.java @@ -1,11 +1,13 @@ package com.comet.opik.podam; import com.comet.opik.api.DatasetItem; +import com.comet.opik.api.PromptVersion; import com.comet.opik.podam.manufacturer.BigDecimalTypeManufacturer; import com.comet.opik.podam.manufacturer.CategoricalFeedbackDetailTypeManufacturer; import com.comet.opik.podam.manufacturer.DatasetItemTypeManufacturer; import com.comet.opik.podam.manufacturer.JsonNodeTypeManufacturer; import com.comet.opik.podam.manufacturer.NumericalFeedbackDetailTypeManufacturer; +import com.comet.opik.podam.manufacturer.PromptVersionManufacturer; import com.comet.opik.podam.manufacturer.UUIDTypeManufacturer; import com.fasterxml.jackson.databind.JsonNode; import jakarta.validation.constraints.DecimalMax; @@ -41,6 +43,7 @@ public static PodamFactory newPodamFactory() { new CategoricalFeedbackDetailTypeManufacturer()); strategy.addOrReplaceTypeManufacturer(JsonNode.class, JsonNodeTypeManufacturer.INSTANCE); strategy.addOrReplaceTypeManufacturer(DatasetItem.class, DatasetItemTypeManufacturer.INSTANCE); + strategy.addOrReplaceTypeManufacturer(PromptVersion.class, PromptVersionManufacturer.INSTANCE); return podamFactory; } diff --git a/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/PromptVersionManufacturer.java b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/PromptVersionManufacturer.java new file mode 100644 index 000000000..d6c05ed35 --- /dev/null +++ b/apps/opik-backend/src/test/java/com/comet/opik/podam/manufacturer/PromptVersionManufacturer.java @@ -0,0 +1,47 @@ +package com.comet.opik.podam.manufacturer; + +import com.comet.opik.api.PromptVersion; +import org.apache.commons.lang3.RandomStringUtils; +import uk.co.jemos.podam.api.AttributeMetadata; +import uk.co.jemos.podam.api.DataProviderStrategy; +import uk.co.jemos.podam.common.ManufacturingContext; +import uk.co.jemos.podam.typeManufacturers.AbstractTypeManufacturer; + +import java.time.Instant; +import java.util.Set; +import java.util.UUID; + +public class PromptVersionManufacturer extends AbstractTypeManufacturer { + + public static final PromptVersionManufacturer INSTANCE = new PromptVersionManufacturer(); + + public static final String TEMPLATE = """ + Hi {{%s}}, + + This is a test prompt. The current time is {{%s}}. + + Regards, + {{%s}} + """; + + @Override + public PromptVersion getType(DataProviderStrategy strategy, AttributeMetadata metadata, + ManufacturingContext context) { + + UUID id = strategy.getTypeValue(metadata, context, UUID.class); + + String variable1 = RandomStringUtils.randomAlphanumeric(5); + String variable2 = RandomStringUtils.randomAlphanumeric(5); + String variable3 = RandomStringUtils.randomAlphanumeric(5); + + return PromptVersion.builder() + .id(id) + .commit(id.toString().substring(id.toString().length() - 8)) + .template(TEMPLATE.format(variable1, variable2, variable3)) + .variables(Set.of(variable1, variable2, variable3)) + .promptId(strategy.getTypeValue(metadata, context, UUID.class)) + .createdBy(strategy.getTypeValue(metadata, context, String.class)) + .createdAt(Instant.now()) + .build(); + } +} From dca7f1fadeb8997a2afb13b75e70eace630c1c6c Mon Sep 17 00:00:00 2001 From: Thiago Hora Date: Mon, 4 Nov 2024 21:10:22 +0100 Subject: [PATCH 9/9] Fix error --- .../com/comet/opik/domain/PromptService.java | 14 ++-- .../resources/v1/priv/PromptResourceTest.java | 84 ++++++++++++++++++- 2 files changed, 89 insertions(+), 9 deletions(-) diff --git a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java index 86bd311c8..f3400a9a8 100644 --- a/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java +++ b/apps/opik-backend/src/main/java/com/comet/opik/domain/PromptService.java @@ -38,8 +38,8 @@ public interface PromptService { @RequiredArgsConstructor(onConstructor_ = @Inject) class PromptServiceImpl implements PromptService { - public static final String ALREADY_EXISTS = "Prompt id or name already exists"; - public static final String VERSION_ALREADY_EXISTS = "Prompt version already exists"; + private static final String ALREADY_EXISTS = "Prompt id or name already exists"; + private static final String VERSION_ALREADY_EXISTS = "Prompt version already exists"; private final @NonNull Provider requestContext; private final @NonNull IdGenerator idGenerator; private final @NonNull TransactionTemplate transactionTemplate; @@ -116,12 +116,12 @@ private Prompt savePrompt(String workspaceId, Prompt prompt) { private EntityAlreadyExistsException newConflict() { log.info(ALREADY_EXISTS); - return new EntityAlreadyExistsException(new ErrorMessage(ALREADY_EXISTS)); + return new EntityAlreadyExistsException(new ErrorMessage(409, ALREADY_EXISTS)); } private EntityAlreadyExistsException newVersionConflict() { log.info(VERSION_ALREADY_EXISTS); - return new EntityAlreadyExistsException(new ErrorMessage(VERSION_ALREADY_EXISTS)); + return new EntityAlreadyExistsException(new ErrorMessage(409, VERSION_ALREADY_EXISTS)); } @Override @@ -181,8 +181,6 @@ public PromptVersion createPromptVersion(@NonNull CreatePromptVersion createProm String workspaceId = requestContext.get().getWorkspaceId(); String userName = requestContext.get().getUserName(); - Prompt prompt = getOrCreatePrompt(workspaceId, createPromptVersion.name(), userName); - UUID id = createPromptVersion.version().id() == null ? idGenerator.generateId() : createPromptVersion.version().id(); @@ -190,6 +188,10 @@ public PromptVersion createPromptVersion(@NonNull CreatePromptVersion createProm ? CommitGenerator.generateCommit(id) : createPromptVersion.version().commit(); + IdGenerator.validateVersion(id, "prompt version"); + + Prompt prompt = getOrCreatePrompt(workspaceId, createPromptVersion.name(), userName); + EntityConstraintHandler handler = EntityConstraintHandler.handle(() -> { PromptVersion promptVersion = createPromptVersion.version().toBuilder() .promptId(prompt.id()) diff --git a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java index f0556a269..1ff8962a3 100644 --- a/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java +++ b/apps/opik-backend/src/test/java/com/comet/opik/api/resources/v1/priv/PromptResourceTest.java @@ -389,10 +389,10 @@ Stream when__promptIsInvalid__thenReturnError() { new ErrorMessage(List.of("prompt id must be a version 7 UUID")), ErrorMessage.class), Arguments.of(duplicatedPrompt.toBuilder().name(UUID.randomUUID().toString()).build(), 409, - new io.dropwizard.jersey.errors.ErrorMessage("Prompt id or name already exists"), + new io.dropwizard.jersey.errors.ErrorMessage(409, "Prompt id or name already exists"), io.dropwizard.jersey.errors.ErrorMessage.class), Arguments.of(duplicatedPrompt.toBuilder().id(factory.manufacturePojo(UUID.class)).build(), 409, - new io.dropwizard.jersey.errors.ErrorMessage("Prompt id or name already exists"), + new io.dropwizard.jersey.errors.ErrorMessage(409, "Prompt id or name already exists"), io.dropwizard.jersey.errors.ErrorMessage.class), Arguments.of(factory.manufacturePojo(Prompt.class).toBuilder().description("").build(), 422, new ErrorMessage(List.of("description must not be blank")), @@ -469,7 +469,7 @@ void when__searchByPartialName__thenReturnPromptMatchingName(String promptName, .build(); Prompt updatedPrompt = prompt.toBuilder() - .name(prompt.name().replace(partialSearch, "")) + .name(prompt.name().replaceAll(partialSearch, "")) .build(); createPrompt(updatedPrompt, apiKey, workspaceName); @@ -649,6 +649,65 @@ void when__promptDoesNotExist__thenReturnCreatedPromptVersion() { assertPromptVersion(actualPromptVersion, expectedPromptVersion, prompts.getFirst().id()); } + @Test + @DisplayName("when prompt version id already exists, then return error") + void when__promptVersionIdAlreadyExists__thenReturnError() { + + var prompt = factory.manufacturePojo(Prompt.class).toBuilder() + .lastUpdatedBy(USER) + .createdBy(USER) + .template(null) + .build(); + + var versionId = factory.manufacturePojo(UUID.class); + + var promptVersion = factory.manufacturePojo(PromptVersion.class).toBuilder() + .createdBy(USER) + .id(versionId) + .build(); + + var request = new CreatePromptVersion(prompt.name(), promptVersion); + + createPromptVersion(request, API_KEY, TEST_WORKSPACE); + + var promptVersion2 = factory.manufacturePojo(PromptVersion.class).toBuilder() + .createdBy(USER) + .id(versionId) + .build(); + + assertPromptVersionConflict( + new CreatePromptVersion(UUID.randomUUID().toString(), promptVersion2), + API_KEY, TEST_WORKSPACE, "Prompt version already exists"); + } + + @Test + @DisplayName("when prompt version commit already exists, then return error") + void when__promptVersionCommitAlreadyExists__thenReturnError() { + + var prompt = factory.manufacturePojo(Prompt.class).toBuilder() + .lastUpdatedBy(USER) + .createdBy(USER) + .template(null) + .build(); + + var promptVersion = factory.manufacturePojo(PromptVersion.class).toBuilder() + .createdBy(USER) + .build(); + + var request = new CreatePromptVersion(prompt.name(), promptVersion); + + createPromptVersion(request, API_KEY, TEST_WORKSPACE); + + var promptVersion2 = factory.manufacturePojo(PromptVersion.class).toBuilder() + .createdBy(USER) + .commit(promptVersion.commit()) + .build(); + + assertPromptVersionConflict( + new CreatePromptVersion(prompt.name(), promptVersion2), + API_KEY, TEST_WORKSPACE, "Prompt version already exists"); + } + @ParameterizedTest @MethodSource @DisplayName("when prompt version is invalid, then return error") @@ -720,6 +779,25 @@ Stream when__promptVersionIsInvalid__thenReturnError() { } } + private void assertPromptVersionConflict(CreatePromptVersion request, String apiKey, String workspaceName, + String message) { + try (var response = client.target(RESOURCE_PATH.formatted(baseURI) + "/versions") + .request() + .header(HttpHeaders.AUTHORIZATION, apiKey) + .header(RequestContext.WORKSPACE_HEADER, workspaceName) + .post(Entity.json(request))) { + + assertThat(response.getStatus()).isEqualTo(409); + + var errorMessage = response.readEntity(io.dropwizard.jersey.errors.ErrorMessage.class); + + io.dropwizard.jersey.errors.ErrorMessage expectedError = new io.dropwizard.jersey.errors.ErrorMessage(409, + message); + + assertThat(errorMessage).isEqualTo(expectedError); + } + } + private List getPrompts(String nameSearch, String apiKey, String workspaceName) { WebTarget target = client.target(RESOURCE_PATH.formatted(baseURI));