Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OPIK-314] Create prompt version endpoint #554

Merged
merged 18 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions apps/opik-backend/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@
</dependencyManagement>

<dependencies>
<dependency>
<groupId>com.github.spullara.mustache.java</groupId>
<artifactId>compiler</artifactId>
<version>0.9.10</version>
</dependency>
<dependency>
<groupId>io.opentelemetry.instrumentation</groupId>
<artifactId>opentelemetry-instrumentation-annotations</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.Valid;
import lombok.Builder;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
public record CreatePromptVersion(@JsonView( {
PromptVersion.View.Detail.class}) @NotBlank String name,
@JsonView({PromptVersion.View.Detail.class}) @NotNull PromptVersion version){
@JsonView({PromptVersion.View.Detail.class}) @NotNull @Valid PromptVersion version){
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public static class Public {
public static class Detail {
}
}

@Builder
public record PromptPage(
@JsonView( {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package com.comet.opik.api;

import com.comet.opik.api.validate.CommitValidation;
import com.comet.opik.utils.ValidationUtils;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonView;
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.annotation.Nullable;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.NotBlank;
import lombok.Builder;

import java.time.Instant;
Expand All @@ -24,8 +26,8 @@ public record PromptVersion(
@JsonView({PromptVersion.View.Public.class,
PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) UUID promptId,
@JsonView({PromptVersion.View.Public.class,
PromptVersion.View.Detail.class}) @Schema(description = "version short unique identifier, generated if absent") String commit,
@JsonView({PromptVersion.View.Detail.class}) @NotNull String template,
PromptVersion.View.Detail.class}) @Schema(description = "version short unique identifier, generated if absent. it must be 8 characters long", requiredMode = Schema.RequiredMode.NOT_REQUIRED, pattern = ValidationUtils.COMMIT_PATTERN) @CommitValidation String commit,
@JsonView({PromptVersion.View.Detail.class}) @NotBlank String template,
@JsonView({
PromptVersion.View.Detail.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @Nullable Set<String> variables,
@JsonView({PromptVersion.View.Public.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,17 @@
import jakarta.inject.Inject;
import jakarta.inject.Provider;
import jakarta.validation.Valid;
import jakarta.validation.constraints.Min;
import jakarta.ws.rs.Consumes;
import jakarta.ws.rs.DELETE;
import jakarta.ws.rs.DefaultValue;
import jakarta.ws.rs.GET;
import jakarta.ws.rs.POST;
import jakarta.ws.rs.PUT;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.PathParam;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.QueryParam;
import jakarta.ws.rs.core.Context;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
Expand Down Expand Up @@ -205,27 +212,12 @@ public Response createPromptVersion(
log.info("Creating prompt version commit '{}' on workspace_id '{}'", promptVersion.version().commit(),
workspaceId);

UUID id = idGenerator.generateId();
log.info("Created prompt version commit '{}' with id '{}' on workspace_id '{}'",
promptVersion.version().commit(), id, workspaceId);
var createdVersion = promptService.createPromptVersion(promptVersion);

return Response.status(Response.Status.NOT_IMPLEMENTED)
.entity(generatePromptVersion(promptVersion, id))
.build();
}
log.info("Created prompt version commit '{}' with id '{}' on workspace_id '{}'",
promptVersion.version().commit(), createdVersion.id(), workspaceId);

private PromptVersion generatePromptVersion(CreatePromptVersion promptVersion, UUID id) {
return PromptVersion.builder()
.id(id)
.commit(promptVersion.version().commit() == null
? id.toString().substring(id.toString().length() - 7)
: promptVersion.version().commit())
.template(promptVersion.version().template())
.variables(
Set.of("user_message"))
.createdAt(Instant.now())
.createdBy("User 1")
.build();
return Response.ok(createdVersion).build();
}

@GET
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.comet.opik.api.validate;

import jakarta.validation.Constraint;
import jakarta.validation.Payload;

import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target({ElementType.FIELD, ElementType.ANNOTATION_TYPE, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Constraint(validatedBy = {CommitValidator.class})
@Documented
public @interface CommitValidation {

String message() default "if present, the commit message must be 8 alphanumeric characters long";

Class<?>[] groups() default {};

Class<? extends Payload>[] payload() default {};
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.comet.opik.api.validate;

import com.comet.opik.utils.ValidationUtils;
import jakarta.validation.ConstraintValidator;
import jakarta.validation.ConstraintValidatorContext;

import java.util.regex.Pattern;

public class CommitValidator implements ConstraintValidator<CommitValidation, String> {

@Override
public boolean isValid(String commit, ConstraintValidatorContext context) {
return commit == null || Pattern.matches(ValidationUtils.COMMIT_PATTERN, commit);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ default T withError(Supplier<EntityAlreadyExistsException> errorProvider) {
}
}

default T onErrorDo(Supplier<T> errorProvider) {
try {
return wrappedAction().execute();
} catch (UnableToExecuteStatementException e) {
if (e.getCause() instanceof SQLIntegrityConstraintViolationException) {
return errorProvider.get();
} else {
throw e;
}
}
}

default T withRetry(int times, Supplier<EntityAlreadyExistsException> errorProvider) {
Preconditions.checkArgument(times > 0, "Retry times must be greater than 0");

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package com.comet.opik.domain;

import com.github.mustachejava.Code;
import com.github.mustachejava.DefaultMustacheFactory;
import com.github.mustachejava.Mustache;
import com.github.mustachejava.MustacheFactory;
import com.github.mustachejava.codes.ValueCode;
import lombok.experimental.UtilityClass;

import java.io.StringReader;
import java.util.HashSet;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

@UtilityClass
class MustacheVariableExtractor {

public static final MustacheFactory MF = new DefaultMustacheFactory();

public static Set<String> extractVariables(String template) {
Set<String> variables = new HashSet<>();

// Initialize Mustache Factory
Mustache mustache = MF.compile(new StringReader(template), "template");

// Get th e root node of the template
Code[] codes = mustache.getCodes();
collectVariables(codes, variables);

return variables;
}

private static void collectVariables(Code[] codes, Set<String> variables) {
for (Code code : codes) {
if (Objects.requireNonNull(code) instanceof ValueCode valueCode) {
variables.add(valueCode.getName());
} else {
Optional.ofNullable(code)
.map(Code::getCodes)
.map(it -> it.length > 0)
.ifPresent(it -> collectVariables(code.getCodes(), variables));
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,7 @@ List<Prompt> find(@Define("name") @Bind("name") String name, @Bind("workspace_Id
@UseStringTemplateEngine
@AllowUnusedBindings
long count(@Define("name") @Bind("name") String name, @Bind("workspace_Id") String workspaceId);

@SqlQuery("SELECT * FROM prompts WHERE name = :name AND workspace_id = :workspace_id")
Prompt findByName(@Bind("name") String name, @Bind("workspace_id") String workspaceId);
}
Loading