Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CM-2344]: Implement equivalent of API.add_registry_model_version_stage #77

Merged
merged 3 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class RegistryModelExample {
static final String SOME_MODEL_VERSION = "1.0.0";
static final String SOME_MODEL_VERSION_UP = "1.0.1";
static final String STAGE_PRODUCTION = "production";
static final String STAGE_STAGING = "staging";
static final String SOME_NOTES = "some model notes";

/**
Expand Down Expand Up @@ -147,6 +148,13 @@ record = api.registerModel(updatedModel, experiment.getExperimentKey());
System.out.printf("Overview of the model '%s' not found\n", registryName);
}

// add stage to the model
//
System.out.printf("Adding stage `%s' to the registered model version '%s:%s'\n",
STAGE_STAGING, registryName, SOME_MODEL_VERSION_UP);
api.addRegistryModelVersionStage(registryName, experiment.getWorkspaceName(),
SOME_MODEL_VERSION_UP, STAGE_STAGING);

// get details about model version
//
System.out.printf("Retrieving details of the model version '%s:%s'\n", registryName, SOME_MODEL_VERSION_UP);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public final class MnistExperimentExample {
* The number of epochs to perform.
*/
@Parameter(names = {"--epochs", "-e"}, description = "number of epochs to perform")
final
int numEpochs = 2;

/**
Expand Down
10 changes: 10 additions & 0 deletions comet-java-client/src/main/java/ml/comet/experiment/CometApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -206,4 +206,14 @@ void updateRegistryModelVersion(String registryName, String workspace, String ve
* @param version the version of the registered model to be deleted.
*/
void deleteRegistryModelVersion(String registryName, String workspace, String version);

/**
* Adds a stage to a registered model version.
*
* @param registryName the name of the model.
* @param workspace the name of the model's workspace.
* @param version the version of the registered model to be updated.
* @param stage the name of the stage to be added.
*/
void addRegistryModelVersionStage(String registryName, String workspace, String version, String stage);
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public interface OnlineExperimentBuilder extends BaseCometBuilder<OnlineExperime
/**
* Set the URL of your comet installation.
*
* @param urlOverride full url of comet installation. Default is https://www.comet.ml
* @param urlOverride full url of comet installation. Default is https://www.comet.com
* @return the builder configured with specified URL of the Comet installation.
*/
OnlineExperimentBuilder withUrlOverride(String urlOverride);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
* using asynchronous networking.
*/
abstract class BaseExperimentAsync extends BaseExperiment {
ExperimentContext baseContext;
final ExperimentContext baseContext;

BaseExperimentAsync(@NonNull final String apiKey,
@NonNull final String baseUrl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@
import ml.comet.experiment.impl.rest.ExperimentModelResponse;
import ml.comet.experiment.impl.rest.RegistryModelCountResponse;
import ml.comet.experiment.impl.rest.RegistryModelCreateRequest;
import ml.comet.experiment.impl.rest.RegistryModelDeleteRequest;
import ml.comet.experiment.impl.rest.RegistryModelDetailsResponse;
import ml.comet.experiment.impl.rest.RegistryModelItemCreateRequest;
import ml.comet.experiment.impl.rest.RegistryModelNotesResponse;
import ml.comet.experiment.impl.rest.RegistryModelNotesUpdateRequest;
import ml.comet.experiment.impl.rest.RegistryModelOverviewListResponse;
import ml.comet.experiment.impl.rest.RegistryModelUpdateItemRequest;
import ml.comet.experiment.impl.rest.RegistryModelUpdateRequest;
import ml.comet.experiment.impl.rest.RegistryModelVersionStageAddRequest;
import ml.comet.experiment.impl.rest.RestApiResponse;
import ml.comet.experiment.impl.utils.CometUtils;
import ml.comet.experiment.impl.utils.ExceptionUtils;
Expand Down Expand Up @@ -71,6 +73,7 @@
import static ml.comet.experiment.impl.resources.LogMessages.EXPERIMENT_HAS_NO_MODELS;
import static ml.comet.experiment.impl.resources.LogMessages.EXPERIMENT_WITH_KEY_NOT_FOUND;
import static ml.comet.experiment.impl.resources.LogMessages.EXTRACTED_N_REGISTRY_MODEL_FILES;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_ADD_REGISTRY_MODEL_VERSION_STAGE;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_DELETE_REGISTRY_MODEL;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_DELETE_REGISTRY_MODEL_VERSION;
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_DOWNLOAD_REGISTRY_MODEL;
Expand Down Expand Up @@ -505,11 +508,29 @@ public void updateRegistryModelVersion(String registryName, String workspace, St
this.updateRegistryModelVersion(registryName, workspace, version, null, null);
}

@Override
public void addRegistryModelVersionStage(@NonNull String registryName, @NonNull String workspace,
@NonNull String version, @NonNull String stage) {
// get version details
Optional<ModelVersionOverview> versionOverviewOptional = this.getRegistryModelVersion(
registryName, workspace, version);
if (!versionOverviewOptional.isPresent()) {
throw new ModelVersionNotFoundException(
getString(REGISTRY_MODEL_VERSION_NOT_FOUND, version, workspace, registryName));
}

String errorMessage = getString(
FAILED_TO_ADD_REGISTRY_MODEL_VERSION_STAGE, stage, workspace, registryName, version);
this.executeSyncRequest(this.restApiClient::addRegistryModelVersionStage,
new RegistryModelVersionStageAddRequest(versionOverviewOptional.get().getRegistryModelItemId(), stage),
errorMessage);
}

@Override
public void deleteRegistryModel(@NonNull String registryName, @NonNull String workspace) {
RestApiResponse response = this.restApiClient.deleteRegistryModel(registryName, workspace)
.blockingGet();
this.checkRestApiResponse(response, getString(FAILED_TO_DELETE_REGISTRY_MODEL, registryName, workspace));
String errorMsg = getString(FAILED_TO_DELETE_REGISTRY_MODEL, registryName, workspace);
this.executeSyncRequest(this.restApiClient::deleteRegistryModel,
new RegistryModelDeleteRequest(registryName, workspace), errorMsg);
}

@Override
Expand All @@ -523,9 +544,8 @@ public void deleteRegistryModelVersion(@NonNull String registryName, @NonNull St
getString(REGISTRY_MODEL_VERSION_NOT_FOUND, version, workspace, registryName));
}
String errorMsg = getString(FAILED_TO_DELETE_REGISTRY_MODEL_VERSION, workspace, registryName, version);
RestApiResponse response = this.executeSyncRequest(this.restApiClient::deleteRegistryModelVersion,
this.executeSyncRequest(this.restApiClient::deleteRegistryModelVersion,
versionOverviewOptional.get().getRegistryModelItemId(), errorMsg);
this.checkRestApiResponse(response, errorMsg);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,7 @@ Optional<Action> getLogAssetOnCompleteAction() {
* The runnable to be invoked to send periodic heartbeat ping to mark this experiment as still running.
*/
static class HeartbeatPing implements Runnable {
OnlineExperimentImpl onlineExperiment;
final OnlineExperimentImpl onlineExperiment;

HeartbeatPing(OnlineExperimentImpl onlineExperiment) {
this.onlineExperiment = onlineExperiment;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import ml.comet.experiment.impl.rest.RegistryModelCountResponse;
import ml.comet.experiment.impl.rest.RegistryModelCreateRequest;
import ml.comet.experiment.impl.rest.RegistryModelCreateResponse;
import ml.comet.experiment.impl.rest.RegistryModelDeleteRequest;
import ml.comet.experiment.impl.rest.RegistryModelDetailsResponse;
import ml.comet.experiment.impl.rest.RegistryModelItemCreateRequest;
import ml.comet.experiment.impl.rest.RegistryModelItemCreateResponse;
Expand All @@ -52,6 +53,7 @@
import ml.comet.experiment.impl.rest.RegistryModelOverviewListResponse;
import ml.comet.experiment.impl.rest.RegistryModelUpdateItemRequest;
import ml.comet.experiment.impl.rest.RegistryModelUpdateRequest;
import ml.comet.experiment.impl.rest.RegistryModelVersionStageAddRequest;
import ml.comet.experiment.impl.rest.RestApiResponse;
import ml.comet.experiment.impl.rest.SetSystemDetailsRequest;
import ml.comet.experiment.impl.rest.TagsResponse;
Expand All @@ -74,6 +76,7 @@
import static ml.comet.experiment.impl.constants.ApiEndpoints.ADD_METRIC;
import static ml.comet.experiment.impl.constants.ApiEndpoints.ADD_OUTPUT;
import static ml.comet.experiment.impl.constants.ApiEndpoints.ADD_PARAMETER;
import static ml.comet.experiment.impl.constants.ApiEndpoints.ADD_REGISTRY_MODEL_VERSION_STAGE;
import static ml.comet.experiment.impl.constants.ApiEndpoints.ADD_START_END_TIME;
import static ml.comet.experiment.impl.constants.ApiEndpoints.ADD_TAG;
import static ml.comet.experiment.impl.constants.ApiEndpoints.CREATE_REGISTRY_MODEL;
Expand Down Expand Up @@ -118,6 +121,7 @@
import static ml.comet.experiment.impl.constants.QueryParamName.MODEL_NAME;
import static ml.comet.experiment.impl.constants.QueryParamName.PROJECT_ID;
import static ml.comet.experiment.impl.constants.QueryParamName.PROJECT_NAME;
import static ml.comet.experiment.impl.constants.QueryParamName.STAGE;
import static ml.comet.experiment.impl.constants.QueryParamName.TYPE;
import static ml.comet.experiment.impl.constants.QueryParamName.WORKSPACE_NAME;
import static ml.comet.experiment.impl.http.ConnectionUtils.checkResponseStatus;
Expand Down Expand Up @@ -383,6 +387,13 @@ Single<RegistryModelCountResponse> getRegistryModelsCount(String workspaceName)
RegistryModelCountResponse.class);
}

Single<RestApiResponse> addRegistryModelVersionStage(RegistryModelVersionStageAddRequest request) {
Map<QueryParamName, String> queryParams = new HashMap<>();
queryParams.put(MODEL_ITEM_ID, request.getRegistryModelItemId());
queryParams.put(STAGE, request.getStage());
return this.singleFromSyncGetWithRetries(ADD_REGISTRY_MODEL_VERSION_STAGE, queryParams);
}

Single<RestApiResponse> downloadRegistryModel(
final OutputStream output, String workspace, String registryName, final DownloadModelOptions options) {
Map<QueryParamName, String> queryParams = downloadModelParams(workspace, registryName, options);
Expand All @@ -409,10 +420,10 @@ Single<RestApiResponse> updateRegistryModelVersion(RegistryModelUpdateItemReques
return singleFromAsyncPost(request, UPDATE_REGISTRY_MODEL_VERSION);
}

Single<RestApiResponse> deleteRegistryModel(String modelName, String workspaceName) {
Single<RestApiResponse> deleteRegistryModel(RegistryModelDeleteRequest request) {
Map<QueryParamName, String> queryParams = new HashMap<>();
queryParams.put(WORKSPACE_NAME, workspaceName);
queryParams.put(MODEL_NAME, modelName);
queryParams.put(WORKSPACE_NAME, request.getWorkspace());
queryParams.put(MODEL_NAME, request.getRegistryModelName());
return singleFromSyncGetWithRetries(DELETE_REGISTRY_MODEL, queryParams);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package ml.comet.experiment.impl.asset;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import ml.comet.experiment.asset.RemoteAsset;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public final class ApiEndpoints {
public static final String UPDATE_REGISTRY_MODEL_VERSION = UPDATE_API_URL + "/registry-model/item/update";
public static final String DELETE_REGISTRY_MODEL = UPDATE_API_URL + "/registry-model/delete";
public static final String DELETE_REGISTRY_MODEL_ITEM = UPDATE_API_URL + "/registry-model/item/delete";
public static final String ADD_REGISTRY_MODEL_VERSION_STAGE = UPDATE_API_URL + "/registry-model/item/stage";

public static final String READ_API_URL = "/api/rest/v2";
public static final String GET_ASSETS_LIST = READ_API_URL + "/experiment/asset/list";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
public class StdOutLogger implements Runnable, Closeable {
final AtomicLong offset = new AtomicLong();

OutputStream outputStream;
InputStream inputStream;
PrintStream original;
OnlineExperiment experiment;
boolean stdOut;
final OutputStream outputStream;
final InputStream inputStream;
final PrintStream original;
final OnlineExperiment experiment;
final boolean stdOut;

/**
* Creates logger instance that captures StdOut stream for a given OnlineExperiment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ public class LogMessages {
public static final String FAILED_TO_UPDATE_REGISTRY_MODEL_NOTES = "FAILED_TO_UPDATE_REGISTRY_MODEL_NOTES";
public static final String FAILED_TO_UPDATE_REGISTRY_MODEL = "FAILED_TO_UPDATE_REGISTRY_MODEL";
public static final String FAILED_TO_UPDATE_REGISTRY_MODEL_VERSION = "FAILED_TO_UPDATE_REGISTRY_MODEL_VERSION";
public static final String FAILED_TO_ADD_REGISTRY_MODEL_VERSION_STAGE =
"FAILED_TO_ADD_REGISTRY_MODEL_VERSION_STAGE";
public static final String REGISTRY_MODEL_VERSION_NOT_FOUND = "REGISTRY_MODEL_VERSION_NOT_FOUND";
public static final String FAILED_TO_DELETE_REGISTRY_MODEL = "FAILED_TO_DELETE_REGISTRY_MODEL";
public static final String NO_RESPONSE_RETURNED_BY_REMOTE_ENDPOINT = "NO_RESPONSE_RETURNED_BY_REMOTE_ENDPOINT";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package ml.comet.experiment.impl.rest;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@NoArgsConstructor
@AllArgsConstructor
public class RegistryModelDeleteRequest {
private String registryModelName;
String workspace;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package ml.comet.experiment.impl.rest;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
@Data
@NoArgsConstructor
@AllArgsConstructor
public class RegistryModelVersionStageAddRequest {
private String registryModelItemId;
private String stage;
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package ml.comet.experiment.impl.utils;

import ml.comet.experiment.exception.CometApiException;

import java.util.Objects;

/**
Expand Down
1 change: 1 addition & 0 deletions comet-java-client/src/main/resources/messages.properties
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ FAILED_TO_GET_REGISTRY_MODELS_COUNT=Failed to get number of registry models unde
FAILED_TO_UPDATE_REGISTRY_MODEL_NOTES=Failed to update notes of the registry model '%s/%s'.
FAILED_TO_UPDATE_REGISTRY_MODEL=Failed to update registry model '%s/%s' with data: '%s'.
FAILED_TO_UPDATE_REGISTRY_MODEL_VERSION=Failed to update registry model's version '%s/%s:%s' with data: '%s'.
FAILED_TO_ADD_REGISTRY_MODEL_VERSION_STAGE=Failed to add stage `%s` to the registry model version '%s/%s:%s'.
REGISTRY_MODEL_VERSION_NOT_FOUND=Version '%s' of the registry model '%s/%s' is not found.
FAILED_TO_DELETE_REGISTRY_MODEL=Failed to delete registry model '%s/%s'.
NO_RESPONSE_RETURNED_BY_REMOTE_ENDPOINT=No response was returned by endpoint '%s'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@
@DisplayName("ApiExperimentTest INTEGRATION")
@Tag("integration")
public class ApiExperimentTest {
static Map<String, Object> SOME_METADATA = new HashMap<String, Object>() {{
static final Map<String, Object> SOME_METADATA = new HashMap<String, Object>() {{
put("someString", "string");
put("someInt", 10);
}};
static String SOME_TEXT = "this is some text to be used";
static final String SOME_TEXT = "this is some text to be used";

@Test
public void testApiExperimentInitializedWithInvalidValues() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,20 @@
@DisplayName("Artifact")
public class ArtifactImplTest extends AssetsBaseTest {

static String SOME_ARTIFACT_NAME = "artifactName";
static String SOME_ARTIFACT_TYPE = "artifactType";
static List<String> SOME_ALIASES = Arrays.asList("one", "two", "three", "three");
static Set<String> UNIQUE_ALIASES = new HashSet<>(SOME_ALIASES);
static List<String> SOME_TAGS = Arrays.asList("tag_1", "tag_2", "tag_3", "tag_3");
static Set<String> UNIQUE_TAGS = new HashSet<>(SOME_TAGS);
static String SOME_VERSION = "1.2.3-beta.4+sha899d8g79f87";
static String INVALID_VERSION = "1.2";
static Map<String, Object> SOME_METADATA = new HashMap<String, Object>() {{
static final String SOME_ARTIFACT_NAME = "artifactName";
static final String SOME_ARTIFACT_TYPE = "artifactType";
static final List<String> SOME_ALIASES = Arrays.asList("one", "two", "three", "three");
static final Set<String> UNIQUE_ALIASES = new HashSet<>(SOME_ALIASES);
static final List<String> SOME_TAGS = Arrays.asList("tag_1", "tag_2", "tag_3", "tag_3");
static final Set<String> UNIQUE_TAGS = new HashSet<>(SOME_TAGS);
static final String SOME_VERSION = "1.2.3-beta.4+sha899d8g79f87";
static final String INVALID_VERSION = "1.2";
static final Map<String, Object> SOME_METADATA = new HashMap<String, Object>() {{
put("someString", "string");
put("someInt", 10);
}};
static String SOME_REMOTE_ASSET_LINK = "s3://bucket/folder/someFile";
static String SOME_REMOTE_ASSET_NAME = "someRemoteAsset";
static final String SOME_REMOTE_ASSET_LINK = "s3://bucket/folder/someFile";
static final String SOME_REMOTE_ASSET_NAME = "someRemoteAsset";

@Test
@DisplayName("is created with newArtifact()")
Expand Down Expand Up @@ -143,7 +143,7 @@ void hasNoAssets() {
class AfterAddingFileAssetTest {
File assetFile;
String assetFileName;
boolean overwrite = true;
final boolean overwrite = true;

@BeforeEach
void addFileAsset() {
Expand Down Expand Up @@ -175,7 +175,7 @@ void throwsExceptionWhenAddingSameName() {
class AfterAddingFileLikeAssetTest {
byte[] data;
String assetFileName;
boolean overwrite = true;
final boolean overwrite = true;

@BeforeEach
void addFileLikeAsset() {
Expand Down Expand Up @@ -207,7 +207,7 @@ void throwsExceptionWhenAddingSameName() {
class AfterAddingRemoteAssetTest {
URI uri;
String assetFileName;
boolean overwrite = true;
final boolean overwrite = true;

@BeforeEach
void addRemoteAsset() throws URISyntaxException {
Expand Down Expand Up @@ -238,8 +238,8 @@ void throwsExceptionWhenAddingSameName() {
@Nested
@DisplayName("after adding assets folder")
class AfterAddingAssetsFolderTest {
boolean logFilePath = true;
boolean recursive = true;
final boolean logFilePath = true;
final boolean recursive = true;

@BeforeEach
void addAssetsFolder() throws IOException {
Expand Down
Loading