diff --git a/.github/linters/comet-checkstyle.xml b/.github/linters/comet-checkstyle.xml index c7d17059..7e6b7a56 100644 --- a/.github/linters/comet-checkstyle.xml +++ b/.github/linters/comet-checkstyle.xml @@ -29,11 +29,11 @@ - - - - + + + + + @@ -47,7 +47,15 @@ + + + + + + + + diff --git a/comet-examples/src/main/java/ml/comet/examples/OnlineExperimentExample.java b/comet-examples/src/main/java/ml/comet/examples/OnlineExperimentExample.java index 58ee22c1..18707c63 100644 --- a/comet-examples/src/main/java/ml/comet/examples/OnlineExperimentExample.java +++ b/comet-examples/src/main/java/ml/comet/examples/OnlineExperimentExample.java @@ -2,8 +2,13 @@ import ml.comet.experiment.ExperimentBuilder; import ml.comet.experiment.OnlineExperiment; +import ml.comet.experiment.context.ExperimentContext; +import org.apache.commons.io.file.PathUtils; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Objects; import static ml.comet.examples.Utils.getResourceFile; import static ml.comet.examples.Utils.readResourceToString; @@ -22,21 +27,19 @@ */ public class OnlineExperimentExample { + private static final String CHART_IMAGE_FILE = "chart.png"; + private static final String MODEL_FILE = "model.hd5"; + private static final String HTML_REPORT_FILE = "report.html"; + private static final String GRAPH_JSON_FILE = "graph.json"; + private static final String CODE_FILE = "code_sample.py"; + /** * The main entry point to the example. * * @param args the command line arguments if any. */ public static void main(String[] args) { - OnlineExperimentExample main = new OnlineExperimentExample(); - try { - main.run(); - } catch (IOException e) { - e.printStackTrace(); - } - } - private void run() throws IOException { //this will take configs from /comet-java-sdk/comet-examples/src/main/resources/application.conf //be sure you have set up apiKey, project, workspace in defaults.conf before you start! @@ -48,6 +51,16 @@ private void run() throws IOException { //you can use a default builder or just inject params //OnlineExperiment experiment = ExperimentBuilder.OnlineExperiment().builder(); + try { + OnlineExperimentExample.run(experiment); + } catch (IOException e) { + e.printStackTrace(); + } finally { + experiment.end(); + } + } + + private static void run(OnlineExperiment experiment) throws IOException { experiment.setExperimentName("Java-SDK 2.0.2"); experiment.nextStep(); @@ -68,20 +81,33 @@ private void run() throws IOException { experiment.logParameter("batch_size", "500"); experiment.logParameter("learning_rate", 12); - experiment.uploadAsset(getResourceFile("chart.png"), "amazing chart.png", false); - experiment.uploadAsset(getResourceFile("model.hd5"), false); + experiment.uploadAsset(getResourceFile(CHART_IMAGE_FILE), "amazing chart.png", false); + experiment.uploadAsset(getResourceFile(MODEL_FILE), false, + ExperimentContext.builder().withContext("train").build()); + + experiment.nextStep(); + + // upload assets from folder + Path assetDir = copyResourcesToTmpDir(); + experiment.logAssetFolder(assetDir.toFile(), true, true); experiment.logOther("Parameter", 4); System.out.println("Epoch 1/20"); System.out.println("- loss: 0.7858 - acc: 0.7759 - val_loss: 0.3416 - val_acc: 0.9026"); - experiment.logGraph(readResourceToString("graph.json")); + experiment.logGraph(readResourceToString(GRAPH_JSON_FILE)); + + experiment.logCode(getResourceFile(CODE_FILE), + ExperimentContext.builder().withContext("test").build()); System.out.println("===== Experiment completed ===="); // will close connection, if not called connection will close on jvm exit experiment.end(); + + // remove tmp directory + PathUtils.deleteDirectory(assetDir); } private static void generateCharts(OnlineExperiment experiment) { @@ -108,4 +134,20 @@ private static long getUpdatedEpochValue(OnlineExperiment experiment) { return experiment.getEpoch() + experiment.getStep() / 5; } + private static Path copyResourcesToTmpDir() throws IOException { + Path root = Files.createTempDirectory("onlineExperimentExample"); + PathUtils.copyFileToDirectory( + Objects.requireNonNull(getResourceFile(CHART_IMAGE_FILE)).toPath(), root); + PathUtils.copyFileToDirectory( + Objects.requireNonNull(getResourceFile(MODEL_FILE)).toPath(), root); + Files.createTempFile(root, "empty_file", ".txt"); + + Path subDir = Files.createTempDirectory(root, "subDir"); + PathUtils.copyFileToDirectory( + Objects.requireNonNull(getResourceFile(HTML_REPORT_FILE)).toPath(), subDir); + PathUtils.copyFileToDirectory( + Objects.requireNonNull(getResourceFile(GRAPH_JSON_FILE)).toPath(), subDir); + + return root; + } } \ No newline at end of file diff --git a/comet-examples/src/main/resources/code_sample.py b/comet-examples/src/main/resources/code_sample.py new file mode 100644 index 00000000..fe01c726 --- /dev/null +++ b/comet-examples/src/main/resources/code_sample.py @@ -0,0 +1 @@ +# some code goes here \ No newline at end of file diff --git a/comet-java-client/pom.xml b/comet-java-client/pom.xml index 43e2d0aa..9d699c01 100644 --- a/comet-java-client/pom.xml +++ b/comet-java-client/pom.xml @@ -49,6 +49,11 @@ commons-lang3 3.12.0 + + commons-io + commons-io + 2.11.0 + com.typesafe config diff --git a/comet-java-client/src/main/java/ml/comet/experiment/ApiExperiment.java b/comet-java-client/src/main/java/ml/comet/experiment/ApiExperiment.java index 73788734..66772a6d 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/ApiExperiment.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/ApiExperiment.java @@ -1,24 +1,7 @@ package ml.comet.experiment; -import lombok.NonNull; -import lombok.experimental.UtilityClass; -import ml.comet.experiment.builder.ApiExperimentBuilder; -import ml.comet.experiment.impl.ApiExperimentImpl; - /** - * This is stub to support backward compatibility. - * - * @deprecated It would be replaced in the future with new experiment creation API. + * The {@code ApiExperiment} can be used to synchronously read/update data of your Comet.ml experiment. */ -@UtilityClass -public final class ApiExperiment { - /** - * Returns builder to create {@link Experiment} instance. - * - * @param experimentKey the unique identifier of the existing experiment. - * @return the initialized ApiExperiment instance. - */ - public static ApiExperimentBuilder builder(@NonNull final String experimentKey) { - return ApiExperimentImpl.builder(experimentKey); - } +public interface ApiExperiment extends Experiment { } diff --git a/comet-java-client/src/main/java/ml/comet/experiment/ApiExperimentImpl.java b/comet-java-client/src/main/java/ml/comet/experiment/ApiExperimentImpl.java new file mode 100644 index 00000000..f24bc5ed --- /dev/null +++ b/comet-java-client/src/main/java/ml/comet/experiment/ApiExperimentImpl.java @@ -0,0 +1,23 @@ +package ml.comet.experiment; + +import lombok.NonNull; +import lombok.experimental.UtilityClass; +import ml.comet.experiment.builder.ApiExperimentBuilder; + +/** + * This is stub to support backward compatibility. + * + * @deprecated It would be replaced in the future with new experiment creation API. + */ +@UtilityClass +public final class ApiExperimentImpl { + /** + * Returns builder to create {@link Experiment} instance. + * + * @param experimentKey the unique identifier of the existing experiment. + * @return the initialized ApiExperiment instance. + */ + public static ApiExperimentBuilder builder(@NonNull final String experimentKey) { + return ml.comet.experiment.impl.ApiExperimentImpl.builder(experimentKey); + } +} diff --git a/comet-java-client/src/main/java/ml/comet/experiment/Experiment.java b/comet-java-client/src/main/java/ml/comet/experiment/Experiment.java index e2235d92..7446e32b 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/Experiment.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/Experiment.java @@ -1,6 +1,7 @@ package ml.comet.experiment; -import ml.comet.experiment.impl.constants.AssetType; +import ml.comet.experiment.context.ExperimentContext; +import ml.comet.experiment.impl.asset.AssetType; import ml.comet.experiment.model.ExperimentAssetLink; import ml.comet.experiment.model.ExperimentMetadataRest; import ml.comet.experiment.model.GitMetadata; @@ -140,11 +141,22 @@ public interface Experiment { /** * Send logs to Comet. * - * @param line Text to be logged - * @param offset Offset describes the place for current text to be inserted - * @param stderr the flag to indicate if this is StdErr message. + * @param line Text to be logged + * @param offset Offset describes the place for current text to be inserted + * @param stderr the flag to indicate if this is StdErr message. + * @param context the context to be associated with the parameter. */ - void logLine(String line, long offset, boolean stderr); + void logLine(String line, long offset, boolean stderr, String context); + + /** + * Logs a metric with Comet. For running experiment updates current step to one from param! + * Metrics are generally values that change from step to step. + * + * @param metricName The name for the metric to be logged + * @param metricValue The new value for the metric. If the values for a metric are plottable we will plot them. + * @param context the context to be associated with the metric. + */ + void logMetric(String metricName, Object metricValue, ExperimentContext context); /** * Logs a metric with Comet. For running experiment updates current step to one from param! @@ -153,10 +165,19 @@ public interface Experiment { * @param metricName The name for the metric to be logged * @param metricValue The new value for the metric. If the values for a metric are plottable we will plot them * @param step The current step for this metric, this will set the given step for this experiment - * @param epoch The current epoch for this metric, this will set the given epoch for this experiment + * @param epoch The current epoch for this metric, this will set the given epoch for this experiment. */ void logMetric(String metricName, Object metricValue, long step, long epoch); + /** + * Logs a param with Comet. For running experiment updates current step to one from param! + * Params should be set at the start of the experiment. + * + * @param parameterName The name of the param being logged + * @param paramValue The value for the param being logged + * @param context the context to be associated with the parameter. + */ + void logParameter(String parameterName, Object paramValue, ExperimentContext context); /** * Logs a param with Comet. For running experiment updates current step to one from param! @@ -164,7 +185,7 @@ public interface Experiment { * * @param parameterName The name of the param being logged * @param paramValue The value for the param being logged - * @param step The current step for this metric, this will set the given step for this experiment + * @param step The current step for this metric, this will set the given step for this experiment. */ void logParameter(String parameterName, Object paramValue, long step); @@ -214,6 +235,15 @@ public interface Experiment { */ void logEndTime(long endTimeMillis); + /** + * Allows you to report code for the experiment. + * + * @param code Code to be sent to Comet + * @param fileName Name of source file to be displayed on UI 'code' tab + * @param context the context to be associated with the asset. + */ + void logCode(String code, String fileName, ExperimentContext context); + /** * Allows you to report code for the experiment. * @@ -222,6 +252,14 @@ public interface Experiment { */ void logCode(String code, String fileName); + /** + * Allows you to report code for the experiment. + * + * @param file Asset with source code to be sent + * @param context the context to be associated with the asset. + */ + void logCode(File file, ExperimentContext context); + /** * Allows you to report code for the experiment. * @@ -236,11 +274,41 @@ public interface Experiment { * @param asset The asset to be stored * @param fileName The file name under which the asset should be stored in Comet. E.g. "someFile.txt" * @param overwrite Whether to overwrite files of the same name in Comet - * @param step the step to be associated with asset - * @param epoch the epoch to be associated with asset + * @param context the context to be associated with the asset. + */ + void uploadAsset(File asset, String fileName, boolean overwrite, ExperimentContext context); + + /** + * Upload an asset to be associated with the experiment, for example the trained weights of a neural net. + * For running experiment updates current step to one from param! + * + * @param asset The asset to be stored + * @param fileName The file name under which the asset should be stored in Comet. E.g. "someFile.txt" + * @param overwrite Whether to overwrite files of the same name in Comet + * @param step the step to be associated with the asset + * @param epoch the epoch to be associated with the asset */ void uploadAsset(File asset, String fileName, boolean overwrite, long step, long epoch); + /** + * Upload an asset to be associated with the experiment, for example the trained weights of a neural net. + * For running experiment updates current step to one from param! + * + * @param asset The file asset to be stored. The name of the file will be used as assets identifier on Comet. + * @param overwrite Whether to overwrite files of the same name in Comet + * @param context the context to be associated with the asset. + */ + void uploadAsset(File asset, boolean overwrite, ExperimentContext context); + + /** + * Upload an asset to be associated with the experiment, for example the trained weights of a neural net. + * For running experiment updates current step to one from param! + * + * @param asset The file asset to be stored. The name of the file will be used as assets identifier on Comet. + * @param overwrite Whether to overwrite files of the same name in Comet + * @param step the step to be associated with the asset + * @param epoch the epoch to be associated with the asset + */ void uploadAsset(File asset, boolean overwrite, long step, long epoch); /** diff --git a/comet-java-client/src/main/java/ml/comet/experiment/ExperimentBuilder.java b/comet-java-client/src/main/java/ml/comet/experiment/ExperimentBuilder.java index 0c8e000f..922dc4c6 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/ExperimentBuilder.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/ExperimentBuilder.java @@ -1,7 +1,9 @@ package ml.comet.experiment; import lombok.experimental.UtilityClass; +import ml.comet.experiment.builder.ApiExperimentBuilder; import ml.comet.experiment.builder.OnlineExperimentBuilder; +import ml.comet.experiment.impl.ApiExperimentImpl; import ml.comet.experiment.impl.OnlineExperimentImpl; /** @@ -23,7 +25,19 @@ public class ExperimentBuilder { * * @return the instance of the {@link OnlineExperimentBuilder}. */ + @SuppressWarnings({"MethodName"}) public static OnlineExperimentBuilder OnlineExperiment() { return OnlineExperimentImpl.builder(); } + + /** + * The factory to create instance of the {@link ApiExperimentBuilder} which can be used + * to configure and create fully initialized instance of the {@link ApiExperiment}. + * + * @return the initialized instance of the {@link ApiExperimentBuilder}. + */ + @SuppressWarnings({"MethodName"}) + public static ApiExperimentBuilder ApiExperiment() { + return ApiExperimentImpl.builder(); + } } diff --git a/comet-java-client/src/main/java/ml/comet/experiment/OnlineExperiment.java b/comet-java-client/src/main/java/ml/comet/experiment/OnlineExperiment.java index cce3b968..9ea5003b 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/OnlineExperiment.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/OnlineExperiment.java @@ -1,13 +1,21 @@ package ml.comet.experiment; +import ml.comet.experiment.context.ExperimentContext; + import java.io.File; import java.io.IOException; /** - * Describes the public contract of the online experiment which extends functionality of the Experiment by providing - * additional methods to log various parameters in real time. + * The {@code OnlineExperiment} should be used to asynchronously update data of your Comet.ml experiment. + * + *

This experiment type allows you to automatically intercept {@code StdOut} and {@code StdErr} streams and send + * them to the Comet.ml. Use the {@link #setInterceptStdout()} to start automatic interception of {@code StdOut} and + * the {@link #stopInterceptStdout()} to stop. + * + *

Also, it is possible to use {@link #setStep(long)}, {@link #setEpoch(long)}, + * and {@link #setContext(String)} which will bbe automatically associated with related logged data records. */ -public interface OnlineExperiment extends Experiment { +public interface OnlineExperiment extends Experiment, AutoCloseable { /** * Turn on intercept of stdout and stderr and the logging of both in Comet. @@ -62,32 +70,31 @@ public interface OnlineExperiment extends Experiment { long getEpoch(); /** - * Sets the context for any logs and uploaded files. + * Sets the context identifier for any logs and uploaded files. * - * @param context the context to be associated with any log records, files, and assets. + * @param context the context identifier to be associated with any log records, files, and assets. */ void setContext(String context); /** - * Gets the current context as recorded in the Experiment object locally. + * Gets the current context identifier as recorded in the {@link OnlineExperiment} object locally. * * @return the current context which associated with log records of this experiment. */ String getContext(); /** - * Logs a metric with Comet under the current experiment step. + * Logs a metric with Comet. For running experiment updates current step to one from param! * Metrics are generally values that change from step to step. * * @param metricName The name for the metric to be logged - * @param metricValue The new value for the metric. If the values for a metric are plottable we will plot them. - * @param step The step to be associated with this metric + * @param metricValue The new value for the metric. If the values for a metric are plottable we will plot them + * @param step The current step for this metric, this will set the given step for this experiment */ void logMetric(String metricName, Object metricValue, long step); void logMetric(String metricName, Object metricValue); - /** * Logs a param with Comet under the current experiment step. * Params should be set at the start of the experiment. @@ -97,6 +104,15 @@ public interface OnlineExperiment extends Experiment { */ void logParameter(String parameterName, Object paramValue); + /** + * Send output logs to Comet. + * + * @param line Text to be logged + * @param offset Offset describes the place for current text to be inserted + * @param stderr the flag to indicate if this is StdErr message. + */ + void logLine(String line, long offset, boolean stderr); + /** * Upload an asset under the current experiment step to be associated with the experiment, * for example the trained weights of a neural net. @@ -111,4 +127,18 @@ public interface OnlineExperiment extends Experiment { void uploadAsset(File asset, String fileName, boolean overwrite); void uploadAsset(File asset, boolean overwrite); + + /** + * Logs all the files located in the given folder as assets. + * + * @param folder the folder you want to log. + * @param logFilePath if {@code true}, log the file path with each file. + * @param recursive if {@code true}, recurse the folder. + * @param context the experiment context to be associated with the logged assets. + */ + void logAssetFolder(File folder, boolean logFilePath, boolean recursive, ExperimentContext context); + + void logAssetFolder(File folder, boolean logFilePath, boolean recursive); + + void logAssetFolder(File folder, boolean logFilePath); } diff --git a/comet-java-client/src/main/java/ml/comet/experiment/builder/ApiExperimentBuilder.java b/comet-java-client/src/main/java/ml/comet/experiment/builder/ApiExperimentBuilder.java index 882c5a22..a7b8053b 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/builder/ApiExperimentBuilder.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/builder/ApiExperimentBuilder.java @@ -1,12 +1,12 @@ package ml.comet.experiment.builder; -import ml.comet.experiment.Experiment; +import ml.comet.experiment.ApiExperiment; /** - * Defines the public contract of the factory to create initialized instances of the {@link Experiment} allowing + * Defines the public contract of the factory to create initialized instances of the {@link ApiExperiment} allowing * to work with Comet API synchronously. */ -public interface ApiExperimentBuilder extends BaseCometBuilder { +public interface ApiExperimentBuilder extends BaseCometBuilder { /** * Allows to continue a previous experiment by providing the key of the existing experiment. * diff --git a/comet-java-client/src/main/java/ml/comet/experiment/context/ExperimentContext.java b/comet-java-client/src/main/java/ml/comet/experiment/context/ExperimentContext.java new file mode 100644 index 00000000..ee1020b9 --- /dev/null +++ b/comet-java-client/src/main/java/ml/comet/experiment/context/ExperimentContext.java @@ -0,0 +1,142 @@ +package ml.comet.experiment.context; + +import lombok.Data; +import lombok.NonNull; +import org.apache.commons.lang3.StringUtils; + +/** + * Describes context of the {@link ml.comet.experiment.Experiment}. + */ +@Data +public final class ExperimentContext { + private Long step; + private Long epoch; + private String context; + + ExperimentContext() { + } + + /** + * Creates new instance with specified parameters. + * + * @param step the current step of the experiment. + * @param epoch the current epoch of the experiment. + * @param context the current context identifier of the data log operation. + */ + public ExperimentContext(long step, long epoch, String context) { + this.step = step; + this.epoch = epoch; + this.context = context; + } + + /** + * Creates new instance with specified parameters. + * + * @param step the current step of the experiment. + * @param epoch the current epoch of the experiment. + */ + public ExperimentContext(long step, long epoch) { + this(step, epoch, StringUtils.EMPTY); + } + + /** + * Creates new instance with specified parameters. + * + * @param step the current step of the experiment. + */ + public ExperimentContext(long step) { + this(step, 0); + } + + /** + * Merges not empty values from other context into this one. + * + * @param other the context to be merged into this. + */ + public void mergeFrom(@NonNull ExperimentContext other) { + if (this == other) { + return; + } + + if (other.step != null) { + this.step = other.step; + } + if (other.epoch != null) { + this.epoch = other.epoch; + } + if (StringUtils.isNotBlank(other.context)) { + this.context = other.context; + } + } + + /** + * The factory to return empty {@link ExperimentContext} instance. + * + * @return the empty {@link ExperimentContext} + */ + public static ExperimentContext empty() { + return new ExperimentContext(); + } + + /** + * Returns builder to create populated instance of the {@link ExperimentContext}. + * + * @return the builder to create populated instance of the {@link ExperimentContext}. + */ + public static ExperimentContextBuilder builder() { + return new ExperimentContextBuilder(); + } + + /** + * Builder to create populated instance of the {@link ExperimentContext}. + */ + public static final class ExperimentContextBuilder { + private final ExperimentContext context; + + ExperimentContextBuilder() { + this.context = new ExperimentContext(); + } + + /** + * Populates context with specified step of the experiment. + * + * @param step the experiment's step. + * @return the instance of this builder. + */ + public ExperimentContextBuilder withStep(long step) { + this.context.step = step; + return this; + } + + /** + * Populates context with specified epoch of the experiment. + * + * @param epoch the epoch of the experiment. + * @return the instance of this builder. + */ + public ExperimentContextBuilder withEpoch(long epoch) { + this.context.epoch = epoch; + return this; + } + + /** + * Populates context with specified context identifier string. + * + * @param context the context identifier string. + * @return the instance of this builder. + */ + public ExperimentContextBuilder withContext(String context) { + this.context.context = context; + return this; + } + + /** + * Creates fully initialized {@link ExperimentContext} instance. + * + * @return the fully initialized {@link ExperimentContext} instance. + */ + public ExperimentContext build() { + return this.context; + } + } +} diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/ApiExperimentImpl.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/ApiExperimentImpl.java index c15a6b04..982c90fb 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/impl/ApiExperimentImpl.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/ApiExperimentImpl.java @@ -2,6 +2,7 @@ import lombok.Getter; import lombok.NonNull; +import ml.comet.experiment.ApiExperiment; import ml.comet.experiment.Experiment; import ml.comet.experiment.builder.ApiExperimentBuilder; import ml.comet.experiment.impl.config.CometConfig; @@ -22,7 +23,7 @@ /** * Implementation of the {@link Experiment} that allows to read/update existing experiment synchronously. */ -public final class ApiExperimentImpl extends BaseExperiment { +public final class ApiExperimentImpl extends BaseExperiment implements ApiExperiment { @Getter private Logger logger = LoggerFactory.getLogger(ApiExperimentImpl.class); @@ -64,7 +65,7 @@ public String getExperimentName() { @Override public Optional getExperimentLink() { - if (StringUtils.isEmpty(experimentKey)) { + if (StringUtils.isBlank(experimentKey)) { return Optional.empty(); } try { @@ -131,8 +132,8 @@ public ApiExperimentImpl.ApiExperimentBuilderImpl withConfigOverride(@NonNull fi } @Override - public ApiExperimentImpl build() { - if (StringUtils.isEmpty(this.apiKey)) { + public ApiExperiment build() { + if (StringUtils.isBlank(this.apiKey)) { this.apiKey = COMET_API_KEY.getString(); } return new ApiExperimentImpl( diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/BaseExperiment.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/BaseExperiment.java index 2c89908d..de5207a1 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/impl/BaseExperiment.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/BaseExperiment.java @@ -1,39 +1,27 @@ package ml.comet.experiment.impl; import io.reactivex.rxjava3.core.Single; -import io.reactivex.rxjava3.disposables.CompositeDisposable; -import io.reactivex.rxjava3.functions.Action; import io.reactivex.rxjava3.functions.BiFunction; import io.reactivex.rxjava3.functions.Function; -import io.reactivex.rxjava3.schedulers.Schedulers; import lombok.Getter; import lombok.NonNull; -import lombok.Setter; import ml.comet.experiment.Experiment; +import ml.comet.experiment.context.ExperimentContext; import ml.comet.experiment.exception.CometApiException; import ml.comet.experiment.exception.CometGeneralException; -import ml.comet.experiment.impl.constants.AssetType; -import ml.comet.experiment.impl.constants.QueryParamName; +import ml.comet.experiment.impl.asset.Asset; +import ml.comet.experiment.impl.asset.AssetType; import ml.comet.experiment.impl.http.Connection; import ml.comet.experiment.impl.http.ConnectionInitializer; import ml.comet.experiment.impl.utils.CometUtils; -import ml.comet.experiment.model.AddExperimentTagsRest; -import ml.comet.experiment.model.AddGraphRest; import ml.comet.experiment.model.CreateExperimentRequest; import ml.comet.experiment.model.CreateExperimentResponse; import ml.comet.experiment.model.ExperimentAssetLink; import ml.comet.experiment.model.ExperimentMetadataRest; import ml.comet.experiment.model.ExperimentStatusResponse; -import ml.comet.experiment.model.ExperimentTimeRequest; import ml.comet.experiment.model.GitMetadata; import ml.comet.experiment.model.GitMetadataRest; -import ml.comet.experiment.model.HtmlRest; import ml.comet.experiment.model.LogDataResponse; -import ml.comet.experiment.model.LogOtherRest; -import ml.comet.experiment.model.MetricRest; -import ml.comet.experiment.model.OutputLine; -import ml.comet.experiment.model.OutputUpdate; -import ml.comet.experiment.model.ParameterRest; import ml.comet.experiment.model.ValueMinMaxDto; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; @@ -41,26 +29,30 @@ import java.io.File; import java.nio.charset.StandardCharsets; import java.time.Duration; -import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Optional; -import static ml.comet.experiment.impl.constants.ApiEndpoints.ADD_ASSET; -import static ml.comet.experiment.impl.constants.AssetType.ASSET_TYPE_SOURCE_CODE; -import static ml.comet.experiment.impl.constants.QueryParamName.CONTEXT; -import static ml.comet.experiment.impl.constants.QueryParamName.EPOCH; -import static ml.comet.experiment.impl.constants.QueryParamName.EXPERIMENT_KEY; -import static ml.comet.experiment.impl.constants.QueryParamName.FILE_NAME; -import static ml.comet.experiment.impl.constants.QueryParamName.OVERWRITE; -import static ml.comet.experiment.impl.constants.QueryParamName.STEP; -import static ml.comet.experiment.impl.constants.QueryParamName.TYPE; +import static ml.comet.experiment.impl.asset.AssetType.ASSET_TYPE_ASSET; +import static ml.comet.experiment.impl.asset.AssetType.ASSET_TYPE_SOURCE_CODE; +import static ml.comet.experiment.impl.resources.LogMessages.EXPERIMENT_CLEANUP_PROMPT; +import static ml.comet.experiment.impl.resources.LogMessages.EXPERIMENT_LIVE; +import static ml.comet.experiment.impl.resources.LogMessages.FAILED_READ_DATA_FOR_EXPERIMENT; +import static ml.comet.experiment.impl.resources.LogMessages.getString; +import static ml.comet.experiment.impl.utils.DataUtils.createGraphRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createLogEndTimeRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createLogHtmlRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createLogLineRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createLogMetricRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createLogOtherRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createLogParamRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createLogStartTimeRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createTagRequest; /** - * The base class for all experiment implementations providing implementation of common routines. + * The base class for all synchronous experiment implementations providing implementation of common routines + * using synchronous networking. */ -public abstract class BaseExperiment implements Experiment { +abstract class BaseExperiment implements Experiment { final String apiKey; final String baseUrl; final int maxAuthRetries; @@ -73,19 +65,10 @@ public abstract class BaseExperiment implements Experiment { String experimentName; boolean alive; - @Setter @Getter - long step; - @Setter - @Getter - long epoch; - @Getter - @Setter - private String context = StringUtils.EMPTY; - private RestApiClient restApiClient; + @Getter private Connection connection; - private final CompositeDisposable disposables = new CompositeDisposable(); /** * Returns logger instance associated with particular experiment. The subclasses should override this method to @@ -138,10 +121,10 @@ void init() { * @throws IllegalArgumentException if validation failed. */ private void validateInitialParams() throws IllegalArgumentException { - if (StringUtils.isEmpty(this.apiKey)) { + if (StringUtils.isBlank(this.apiKey)) { throw new IllegalArgumentException("API key is not specified!"); } - if (StringUtils.isEmpty(this.baseUrl)) { + if (StringUtils.isBlank(this.baseUrl)) { throw new IllegalArgumentException("The Comet base URL is not specified!"); } } @@ -152,21 +135,21 @@ private void validateInitialParams() throws IllegalArgumentException { * @throws CometGeneralException if failed to register experiment. */ void registerExperiment() throws CometGeneralException { - if (!StringUtils.isEmpty(this.experimentKey)) { + if (StringUtils.isNotBlank(this.experimentKey)) { getLogger().debug("Not registering a new experiment. Using previous experiment key {}", this.experimentKey); return; } // do synchronous call to register experiment CreateExperimentResponse result = this.restApiClient.registerExperiment( - new CreateExperimentRequest(this.workspaceName, this.projectName, this.experimentName)) + new CreateExperimentRequest(this.workspaceName, this.projectName, this.experimentName)) .blockingGet(); this.experimentKey = result.getExperimentKey(); this.experimentLink = result.getLink(); - getLogger().info("Experiment is live on comet.ml " + this.experimentLink); + getLogger().info(getString(EXPERIMENT_LIVE, this.experimentLink)); - if (StringUtils.isEmpty(this.experimentKey)) { + if (StringUtils.isBlank(this.experimentKey)) { throw new CometGeneralException("Failed to register onlineExperiment with Comet ML"); } } @@ -202,38 +185,23 @@ public void setExperimentName(@NonNull String experimentName) { * * @param metricName The name for the metric to be logged * @param metricValue The new value for the metric. If the values for a metric are plottable we will plot them - * @param step The current step for this metric, this will set the given step for this experiment - * @param epoch The current epoch for this metric, this will set the given epoch for this experiment + * @param context the context to be associated with the parameter. * @throws CometApiException if received response with failure code. */ @Override - public void logMetric(@NonNull String metricName, @NonNull Object metricValue, long step, long epoch) { + public void logMetric(@NonNull String metricName, @NonNull Object metricValue, + @NonNull ExperimentContext context) { if (getLogger().isDebugEnabled()) { - getLogger().debug("logMetric {} = {}, step: {}, epoch: {}", metricName, metricValue, step, epoch); + getLogger().debug("logMetric {} = {}, context: {}", metricName, metricValue, context); } sendSynchronously(restApiClient::logMetric, - createLogMetricRequest(metricName, metricValue, step, epoch, this.context)); + createLogMetricRequest(metricName, metricValue, context)); } - /** - * Asynchronous version that only logs any received exceptions or failures. - * - * @param metricName The name for the metric to be logged - * @param metricValue The new value for the metric. If the values for a metric are plottable we will plot them - * @param step The current step for this metric, this will set the given step for this experiment - * @param epoch The current epoch for this metric, this will set the given epoch for this experiment - * @param onComplete The optional action to be invoked when this operation asynchronously completes. - * Can be {@code null} if not interested in completion signal. - */ - void logMetricAsync(@NonNull String metricName, @NonNull Object metricValue, - long step, long epoch, Action onComplete) { - if (getLogger().isDebugEnabled()) { - getLogger().debug("logMetricAsync {} = {}, step: {}, epoch: {}", metricName, metricValue, step, epoch); - } - - MetricRest metricRequest = createLogMetricRequest(metricName, metricValue, step, epoch, this.context); - this.sendAsynchronously(restApiClient::logMetric, metricRequest, onComplete); + @Override + public void logMetric(String metricName, Object metricValue, long step, long epoch) { + this.logMetric(metricName, metricValue, new ExperimentContext(step, epoch)); } /** @@ -241,49 +209,37 @@ void logMetricAsync(@NonNull String metricName, @NonNull Object metricValue, * * @param parameterName The name of the param being logged * @param paramValue The value for the param being logged - * @param step The current step for this metric, this will set the given step for this experiment + * @param context the context to be associated with the parameter. */ @Override - public void logParameter(@NonNull String parameterName, @NonNull Object paramValue, long step) { + public void logParameter(String parameterName, Object paramValue, ExperimentContext context) { if (getLogger().isDebugEnabled()) { - getLogger().debug("logParameter {} = {}, step: {}", parameterName, paramValue, step); + getLogger().debug("logParameter {} = {}, context: {}", parameterName, paramValue, context); } sendSynchronously(restApiClient::logParameter, - createLogParamRequest(parameterName, paramValue, step, this.context)); + createLogParamRequest(parameterName, paramValue, context)); } - /** - * Asynchronous version that only logs any received exceptions or failures. - * - * @param parameterName The name of the param being logged - * @param paramValue The value for the param being logged - * @param step The current step for this metric, this will set the given step for this experiment - * @param onComplete The optional action to be invoked when this operation asynchronously completes. - * Can be {@code null} if not interested in completion signal. - */ - void logParameterAsync(@NonNull String parameterName, @NonNull Object paramValue, long step, Action onComplete) { - if (getLogger().isDebugEnabled()) { - getLogger().debug("logParameterAsync {} = {}, step: {}", parameterName, paramValue, step); - } - - ParameterRest paramRequest = createLogParamRequest(parameterName, paramValue, step, this.context); - this.sendAsynchronously(restApiClient::logParameter, paramRequest, onComplete); + @Override + public void logParameter(String parameterName, Object paramValue, long step) { + this.logParameter(parameterName, paramValue, new ExperimentContext(step)); } /** * Synchronous version that waits for result or exception. Also, it checks the response status for failure. * - * @param line Text to be logged - * @param offset Offset describes the place for current text to be inserted - * @param stderr the flag to indicate if this is StdErr message. + * @param line Text to be logged + * @param offset Offset describes the place for current text to be inserted + * @param stderr the flag to indicate if this is StdErr message. + * @param context the context to be associated with the parameter. */ @Override - public void logLine(String line, long offset, boolean stderr) { + public void logLine(String line, long offset, boolean stderr, String context) { validate(); sendSynchronously(restApiClient::logOutputLine, - createLogLineRequest(line, offset, stderr, this.context)); + createLogLineRequest(line, offset, stderr, context)); } /** @@ -299,26 +255,7 @@ public void logHtml(@NonNull String html, boolean override) { getLogger().debug("logHtml {}, override: {}", html, override); } - sendSynchronously(restApiClient::logHtml, - createLogHtmlRequest(html, override)); - } - - /** - * Asynchronous version that only logs any received exceptions or failures. - * - * @param html A block of html to be sent to Comet - * @param override Whether previous html sent should be deleted. - * If true the old html will be deleted. - * @param onComplete The optional action to be invoked when this operation asynchronously completes. - * Can be {@code null} if not interested in completion signal. - */ - void logHtmlAsync(@NonNull String html, boolean override, Action onComplete) { - if (getLogger().isDebugEnabled()) { - getLogger().debug("logHtmlAsync {}, override: {}", html, override); - } - - HtmlRest htmlRequest = createLogHtmlRequest(html, override); - this.sendAsynchronously(restApiClient::logHtml, htmlRequest, onComplete); + sendSynchronously(restApiClient::logHtml, createLogHtmlRequest(html, override)); } /** @@ -336,23 +273,6 @@ public void logOther(@NonNull String key, @NonNull Object value) { sendSynchronously(restApiClient::logOther, createLogOtherRequest(key, value)); } - /** - * Asynchronous version that only logs any received exceptions or failures. - * - * @param key The key for the data to be stored - * @param value The value for said key - * @param onComplete The optional action to be invoked when this operation asynchronously completes. - * Can be {@code null} if not interested in completion signal. - */ - void logOtherAsync(@NonNull String key, @NonNull Object value, Action onComplete) { - if (getLogger().isDebugEnabled()) { - getLogger().debug("logOtherAsync {} {}", key, value); - } - - LogOtherRest request = createLogOtherRequest(key, value); - sendAsynchronously(restApiClient::logOther, request, onComplete); - } - /** * Synchronous version that waits for result or exception. Also, it checks the response status for failure. * @@ -367,21 +287,6 @@ public void addTag(@NonNull String tag) { sendSynchronously(restApiClient::addTag, createTagRequest(tag)); } - /** - * Asynchronous version that only logs any received exceptions or failures. - * - * @param tag The tag to be added - * @param onComplete The optional action to be invoked when this operation asynchronously completes. - * Can be {@code null} if not interested in completion signal. - */ - public void addTagAsync(@NonNull String tag, Action onComplete) { - if (getLogger().isDebugEnabled()) { - getLogger().debug("addTagAsync {}", tag); - } - - sendAsynchronously(restApiClient::addTag, createTagRequest(tag), onComplete); - } - /** * Synchronous version that waits for result or exception. Also, it checks the response status for failure. * @@ -396,21 +301,6 @@ public void logGraph(@NonNull String graph) { sendSynchronously(restApiClient::logGraph, createGraphRequest(graph)); } - /** - * Asynchronous version that only logs any received exceptions or failures. - * - * @param graph The graph to be logged - * @param onComplete The optional action to be invoked when this operation asynchronously completes. - * Can be {@code null} if not interested in completion signal. - */ - void logGraphAsync(@NonNull String graph, Action onComplete) { - if (getLogger().isDebugEnabled()) { - getLogger().debug("logGraphAsync {}", graph); - } - - sendAsynchronously(restApiClient::logGraph, createGraphRequest(graph), onComplete); - } - /** * Synchronous version that waits for result or exception. Also, it checks the response status for failure. * @@ -425,21 +315,6 @@ public void logStartTime(long startTimeMillis) { sendSynchronously(restApiClient::logStartEndTime, createLogStartTimeRequest(startTimeMillis)); } - /** - * Asynchronous version that only logs any received exceptions or failures. - * - * @param startTimeMillis When you want to say that the experiment started - * @param onComplete The optional action to be invoked when this operation asynchronously completes. - * Can be {@code null} if not interested in completion signal. - */ - void logStartTimeAsync(long startTimeMillis, Action onComplete) { - if (getLogger().isDebugEnabled()) { - getLogger().debug("logStartTimeAsync {}", startTimeMillis); - } - - sendAsynchronously(restApiClient::logStartEndTime, createLogStartTimeRequest(startTimeMillis), onComplete); - } - /** * Synchronous version that waits for result or exception. Also, it checks the response status for failure. * @@ -454,21 +329,6 @@ public void logEndTime(long endTimeMillis) { sendSynchronously(restApiClient::logStartEndTime, createLogEndTimeRequest(endTimeMillis)); } - /** - * Asynchronous version that only logs any received exceptions or failures. - * - * @param endTimeMillis When you want to say that the experiment ended - * @param onComplete The optional action to be invoked when this operation asynchronously completes. - * Can be {@code null} if not interested in completion signal. - */ - void logEndTimeAsync(long endTimeMillis, Action onComplete) { - if (getLogger().isDebugEnabled()) { - getLogger().debug("logEndTimeAsync {}", endTimeMillis); - } - - sendAsynchronously(restApiClient::logStartEndTime, createLogEndTimeRequest(endTimeMillis), onComplete); - } - /** * Synchronous version that waits for result or exception. Also, it checks the response status for failure. * @@ -483,95 +343,75 @@ public void logGitMetadata(GitMetadata gitMetadata) { sendSynchronously(restApiClient::logGitMetadata, gitMetadata); } - /** - * Asynchronous version that only logs any received exceptions or failures. - * - * @param gitMetadata The Git Metadata for the experiment. - * @param onComplete The optional action to be invoked when this operation asynchronously completes. - * Can be {@code null} if not interested in completion signal. - */ - void logGitMetadataAsync(GitMetadata gitMetadata, Action onComplete) { - if (getLogger().isDebugEnabled()) { - getLogger().debug("logGitMetadata {}", gitMetadata); - } - - sendAsynchronously(restApiClient::logGitMetadata, gitMetadata, onComplete); - } - @Override - public void logCode(@NonNull String code, @NonNull String fileName) { + public void logCode(@NonNull String code, @NonNull String fileName, @NonNull ExperimentContext context) { if (getLogger().isDebugEnabled()) { getLogger().debug("log raw source code, file name: {}", fileName); } - validate(); + Asset asset = new Asset(); + asset.setFileLikeData(code.getBytes(StandardCharsets.UTF_8)); + asset.setFileName(fileName); + asset.setExperimentContext(context); + asset.setType(ASSET_TYPE_SOURCE_CODE); - Map params = new HashMap() {{ - put(EXPERIMENT_KEY, getExperimentKey()); - put(FILE_NAME, fileName); - put(CONTEXT, getContext()); - put(TYPE, ASSET_TYPE_SOURCE_CODE.type()); - put(OVERWRITE, Boolean.toString(false)); - }}; + sendSynchronously(restApiClient::logAsset, asset); + } - this.connection.sendPostAsync(code.getBytes(StandardCharsets.UTF_8), ADD_ASSET, params) - .toCompletableFuture() - .exceptionally(t -> { - getLogger().error("failed to log raw source code with file name {}", fileName, t); - return null; - }); + @Override + public void logCode(String code, String fileName) { + this.logCode(code, fileName, ExperimentContext.empty()); } @Override - public void logCode(@NonNull File asset) { + public void logCode(@NonNull File file, @NonNull ExperimentContext context) { if (getLogger().isDebugEnabled()) { - getLogger().debug("log source code from file {}", asset.getName()); + getLogger().debug("log source code from file {}", file.getName()); } - validate(); + Asset asset = new Asset(); + asset.setFile(file); + asset.setFileName(file.getName()); + asset.setExperimentContext(context); + asset.setType(ASSET_TYPE_SOURCE_CODE); - Map params = new HashMap() {{ - put(EXPERIMENT_KEY, getExperimentKey()); - put(FILE_NAME, asset.getName()); - put(CONTEXT, getContext()); - put(TYPE, ASSET_TYPE_SOURCE_CODE.type()); - put(OVERWRITE, Boolean.toString(false)); - }}; + sendSynchronously(restApiClient::logAsset, asset); + } - this.connection.sendPostAsync(asset, ADD_ASSET, params) - .toCompletableFuture() - .exceptionally(t -> { - getLogger().error("failed to log source code from file {}", asset, t); - return null; - }); + @Override + public void logCode(File file) { + this.logCode(file, ExperimentContext.empty()); } @Override - public void uploadAsset(@NonNull File asset, @NonNull String fileName, boolean overwrite, long step, long epoch) { + public void uploadAsset(@NonNull File file, @NonNull String fileName, + boolean overwrite, @NonNull ExperimentContext context) { if (getLogger().isDebugEnabled()) { - getLogger().debug("uploadAsset from file {}, name {}, override {}, step {}, epoch {}", - asset.getName(), fileName, overwrite, step, epoch); + getLogger().debug("uploadAsset from file {}, name {}, override {}, context {}", + file.getName(), fileName, overwrite, context); } - validate(); + Asset asset = new Asset(); + asset.setFile(file); + asset.setFileName(fileName); + asset.setExperimentContext(context); + asset.setOverwrite(overwrite); + asset.setType(ASSET_TYPE_ASSET); + + sendSynchronously(restApiClient::logAsset, asset); + } - this.connection - .sendPostAsync(asset, ADD_ASSET, new HashMap() {{ - put(EXPERIMENT_KEY, getExperimentKey()); - put(FILE_NAME, fileName); - put(STEP, Long.toString(step)); - put(EPOCH, Long.toString(epoch)); - put(CONTEXT, getContext()); - put(OVERWRITE, Boolean.toString(overwrite)); - }}) - .toCompletableFuture() - .exceptionally(t -> { - getLogger().error("failed to upload asset from file {} with name {}", asset, fileName, t); - return null; - }); + @Override + public void uploadAsset(File asset, String fileName, boolean overwrite, long step, long epoch) { + this.uploadAsset(asset, fileName, overwrite, new ExperimentContext(step, epoch)); + } + + @Override + public void uploadAsset(File asset, boolean overwrite, ExperimentContext context) { + this.uploadAsset(asset, asset.getName(), overwrite, context); } @Override public void uploadAsset(@NonNull File asset, boolean overwrite, long step, long epoch) { - uploadAsset(asset, asset.getName(), overwrite, step, epoch); + this.uploadAsset(asset, overwrite, new ExperimentContext(step, epoch)); } @Override @@ -674,14 +514,15 @@ public void end() { if (!this.alive) { return; } - getLogger().info("Waiting for all scheduled uploads to complete. It can take up to {} seconds.", - cleaningTimeout.getSeconds()); + getLogger().info(getString(EXPERIMENT_CLEANUP_PROMPT, cleaningTimeout.getSeconds())); // mark as not alive this.alive = false; // close REST API - this.restApiClient.dispose(); + if (this.restApiClient != null) { + this.restApiClient.dispose(); + } // close connection if (this.connection != null) { @@ -692,12 +533,6 @@ public void end() { getLogger().error("failed to close connection", e); } } - - // dispose all pending calls - if (disposables.size() > 0) { - getLogger().warn("{} calls still has not been processed, disposing", disposables.size()); - } - this.disposables.dispose(); } /** @@ -723,41 +558,11 @@ Optional sendExperimentStatus() { private T loadRemote(final Function> loadFunc, String alias) { return validateAndGetExperimentKey() .concatMap(loadFunc) - .doOnError(ex -> getLogger().error("Failed to read {} for the experiment, experiment key: {}", - alias, this.experimentKey, ex)) + .doOnError(ex -> getLogger().error( + getString(FAILED_READ_DATA_FOR_EXPERIMENT, alias, this.experimentKey), ex)) .blockingGet(); } - /** - * Uses provided function to send request data asynchronously and log received output. Optionally, can use - * provided {@link Action} handler to notify about completion of the operation. - * - * @param func the function to be invoked to send request data. - * @param request the request data object. - * @param onComplete the optional {@link Action} to be notified the operation completes either - * successfully or erroneously. - * @param the type of the request data object. - */ - private void sendAsynchronously(final BiFunction> func, - final T request, final Action onComplete) { - Single single = validateAndGetExperimentKey() - .subscribeOn(Schedulers.io()) - .concatMap(experimentKey -> func.apply(request, experimentKey)); - - // register notification action if provided - if (onComplete != null) { - single = single.doFinally(onComplete); - } - - // subscribe to receive operation results - single - .observeOn(Schedulers.single()) - .subscribe( - (logDataResponse) -> DataResponseLogger.checkAndLog(logDataResponse, getLogger(), request), - (throwable) -> getLogger().error("failed to log {}", request, throwable), - disposables); - } - /** * Uses provided function to send request data synchronously. If response indicating the remote error * received the {@link CometApiException} will be thrown. @@ -785,7 +590,7 @@ private void sendSynchronously(final BiFunction validateAndGetExperimentKey() { - if (StringUtils.isEmpty(this.experimentKey)) { + Single validateAndGetExperimentKey() { + if (StringUtils.isBlank(this.experimentKey)) { return Single.error(new IllegalStateException("Experiment key must be present!")); } if (!this.alive) { @@ -807,93 +612,4 @@ private Single validateAndGetExperimentKey() { } return Single.just(getExperimentKey()); } - - static MetricRest createLogMetricRequest( - @NonNull String metricName, @NonNull Object metricValue, long step, long epoch, String context) { - MetricRest request = new MetricRest(); - request.setMetricName(metricName); - request.setMetricValue(metricValue.toString()); - request.setStep(step); - request.setEpoch(epoch); - request.setTimestamp(System.currentTimeMillis()); - request.setContext(context); - return request; - } - - static ParameterRest createLogParamRequest( - @NonNull String parameterName, @NonNull Object paramValue, long step, String context) { - ParameterRest request = new ParameterRest(); - request.setParameterName(parameterName); - request.setParameterValue(paramValue.toString()); - request.setStep(step); - request.setTimestamp(System.currentTimeMillis()); - request.setContext(context); - return request; - } - - private HtmlRest createLogHtmlRequest(@NonNull String html, boolean override) { - HtmlRest request = new HtmlRest(); - request.setHtml(html); - request.setOverride(override); - request.setTimestamp(System.currentTimeMillis()); - return request; - } - - private LogOtherRest createLogOtherRequest(@NonNull String key, @NonNull Object value) { - LogOtherRest request = new LogOtherRest(); - request.setKey(key); - request.setValue(value.toString()); - request.setTimestamp(System.currentTimeMillis()); - return request; - } - - private AddExperimentTagsRest createTagRequest(@NonNull String tag) { - AddExperimentTagsRest request = new AddExperimentTagsRest(); - request.setAddedTags(Collections.singletonList(tag)); - return request; - } - - private AddGraphRest createGraphRequest(@NonNull String graph) { - AddGraphRest request = new AddGraphRest(); - request.setGraph(graph); - return request; - } - - private ExperimentTimeRequest createLogStartTimeRequest(long startTimeMillis) { - ExperimentTimeRequest request = new ExperimentTimeRequest(); - request.setStartTimeMillis(startTimeMillis); - return request; - } - - private ExperimentTimeRequest createLogEndTimeRequest(long endTimeMillis) { - ExperimentTimeRequest request = new ExperimentTimeRequest(); - request.setEndTimeMillis(endTimeMillis); - return request; - } - - static OutputUpdate createLogLineRequest(@NonNull String line, long offset, boolean stderr, String context) { - OutputLine outputLine = new OutputLine(); - outputLine.setOutput(line); - outputLine.setStderr(stderr); - outputLine.setLocalTimestamp(System.currentTimeMillis()); - outputLine.setOffset(offset); - - OutputUpdate outputUpdate = new OutputUpdate(); - outputUpdate.setRunContext(context); - outputUpdate.setOutputLines(Collections.singletonList(outputLine)); - return outputUpdate; - } - - /** - * Utility class to log asynchronously received data responses. - */ - static final class DataResponseLogger { - static void checkAndLog(LogDataResponse logDataResponse, Logger logger, Object request) { - if (logDataResponse.hasFailed()) { - logger.error("failed to log {}, reason: {}", request, logDataResponse.getMsg()); - } else if (logger.isDebugEnabled()) { - logger.debug("success {}", logDataResponse); - } - } - } } diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/BaseExperimentAsync.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/BaseExperimentAsync.java new file mode 100644 index 00000000..d469ed8e --- /dev/null +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/BaseExperimentAsync.java @@ -0,0 +1,462 @@ +package ml.comet.experiment.impl; + +import io.reactivex.rxjava3.core.Observable; +import io.reactivex.rxjava3.core.Single; +import io.reactivex.rxjava3.disposables.CompositeDisposable; +import io.reactivex.rxjava3.functions.Action; +import io.reactivex.rxjava3.functions.BiFunction; +import io.reactivex.rxjava3.schedulers.Schedulers; +import lombok.NonNull; +import ml.comet.experiment.context.ExperimentContext; +import ml.comet.experiment.impl.asset.Asset; +import ml.comet.experiment.impl.utils.AssetUtils; +import ml.comet.experiment.model.GitMetadata; +import ml.comet.experiment.model.HtmlRest; +import ml.comet.experiment.model.LogDataResponse; +import ml.comet.experiment.model.LogOtherRest; +import ml.comet.experiment.model.MetricRest; +import ml.comet.experiment.model.OutputUpdate; +import ml.comet.experiment.model.ParameterRest; +import org.slf4j.Logger; + +import java.io.File; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; + +import static ml.comet.experiment.impl.asset.AssetType.ASSET_TYPE_ASSET; +import static ml.comet.experiment.impl.asset.AssetType.ASSET_TYPE_SOURCE_CODE; +import static ml.comet.experiment.impl.resources.LogMessages.ASSETS_FOLDER_UPLOAD_COMPLETED; +import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_LOG_ASSET_FOLDER; +import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_LOG_SOME_ASSET_FROM_FOLDER; +import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_SEND_LOG_REQUEST; +import static ml.comet.experiment.impl.resources.LogMessages.LOG_ASSET_FOLDER_EMPTY; +import static ml.comet.experiment.impl.resources.LogMessages.getString; +import static ml.comet.experiment.impl.utils.DataUtils.createGraphRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createLogEndTimeRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createLogHtmlRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createLogLineRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createLogMetricRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createLogOtherRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createLogParamRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createLogStartTimeRequest; +import static ml.comet.experiment.impl.utils.DataUtils.createTagRequest; + +/** + * The base class for all asynchronous experiment implementations providing implementation of common routines + * using asynchronous networking. + */ +abstract class BaseExperimentAsync extends BaseExperiment { + ExperimentContext baseContext; + + final CompositeDisposable disposables = new CompositeDisposable(); + + BaseExperimentAsync(@NonNull final String apiKey, + @NonNull final String baseUrl, + int maxAuthRetries, + final String experimentKey, + @NonNull final Duration cleaningTimeout, + final String projectName, + final String workspaceName) { + super(apiKey, baseUrl, maxAuthRetries, experimentKey, cleaningTimeout, projectName, workspaceName); + this.baseContext = ExperimentContext.empty(); + } + + @Override + public void end() { + if (!this.alive) { + return; + } + super.end(); + + // dispose all pending asynchronous calls + if (disposables.size() > 0) { + getLogger().warn("{} calls still has not been processed, disposing", disposables.size()); + } + this.disposables.dispose(); + } + + void updateContext(ExperimentContext context) { + this.baseContext.mergeFrom(context); + } + + /** + * Asynchronous version that only logs any received exceptions or failures. + * + * @param metricName The name for the metric to be logged + * @param metricValue The new value for the metric. If the values for a metric are plottable we will plot them + * @param context the context to be associated with the parameter. + * @param onComplete The optional action to be invoked when this operation asynchronously completes. + * Can be {@code null} if not interested in completion signal. + */ + void logMetric(@NonNull String metricName, @NonNull Object metricValue, + @NonNull ExperimentContext context, Action onComplete) { + this.updateContext(context); + + if (getLogger().isDebugEnabled()) { + getLogger().debug("logMetricAsync {} = {}, context: {}", metricName, metricValue, context); + } + + MetricRest metricRequest = createLogMetricRequest(metricName, metricValue, this.baseContext); + this.sendAsynchronously(getRestApiClient()::logMetric, metricRequest, onComplete); + } + + /** + * Asynchronous version that only logs any received exceptions or failures. + * + * @param parameterName The name of the param being logged + * @param paramValue The value for the param being logged + * @param context the context to be associated with the parameter. + * @param onComplete The optional action to be invoked when this operation asynchronously completes. + * Can be {@code null} if not interested in completion signal. + */ + void logParameter(@NonNull String parameterName, @NonNull Object paramValue, + @NonNull ExperimentContext context, Action onComplete) { + this.updateContext(context); + + if (getLogger().isDebugEnabled()) { + getLogger().debug("logParameterAsync {} = {}, context: {}", parameterName, paramValue, context); + } + + ParameterRest paramRequest = createLogParamRequest(parameterName, paramValue, this.baseContext); + this.sendAsynchronously(getRestApiClient()::logParameter, paramRequest, onComplete); + } + + /** + * Asynchronous version that only logs any received exceptions or failures. + * + * @param html A block of html to be sent to Comet + * @param override Whether previous html sent should be deleted. + * If true the old html will be deleted. + * @param onComplete The optional action to be invoked when this operation asynchronously completes. + * Can be {@code null} if not interested in completion signal. + */ + void logHtml(@NonNull String html, boolean override, Action onComplete) { + if (getLogger().isDebugEnabled()) { + getLogger().debug("logHtmlAsync {}, override: {}", html, override); + } + + HtmlRest htmlRequest = createLogHtmlRequest(html, override); + this.sendAsynchronously(getRestApiClient()::logHtml, htmlRequest, onComplete); + } + + /** + * Asynchronous version that only logs any received exceptions or failures. + * + * @param key The key for the data to be stored + * @param value The value for said key + * @param onComplete The optional action to be invoked when this operation asynchronously completes. + * Can be {@code null} if not interested in completion signal. + */ + void logOther(@NonNull String key, @NonNull Object value, Action onComplete) { + if (getLogger().isDebugEnabled()) { + getLogger().debug("logOtherAsync {} {}", key, value); + } + + LogOtherRest request = createLogOtherRequest(key, value); + sendAsynchronously(getRestApiClient()::logOther, request, onComplete); + } + + /** + * Asynchronous version that only logs any received exceptions or failures. + * + * @param tag The tag to be added + * @param onComplete The optional action to be invoked when this operation asynchronously completes. + * Can be {@code null} if not interested in completion signal. + */ + public void addTag(@NonNull String tag, Action onComplete) { + if (getLogger().isDebugEnabled()) { + getLogger().debug("addTagAsync {}", tag); + } + + sendAsynchronously(getRestApiClient()::addTag, createTagRequest(tag), onComplete); + } + + /** + * Asynchronous version that only logs any received exceptions or failures. + * + * @param graph The graph to be logged + * @param onComplete The optional action to be invoked when this operation asynchronously completes. + * Can be {@code null} if not interested in completion signal. + */ + void logGraph(@NonNull String graph, Action onComplete) { + if (getLogger().isDebugEnabled()) { + getLogger().debug("logGraphAsync {}", graph); + } + + sendAsynchronously(getRestApiClient()::logGraph, createGraphRequest(graph), onComplete); + } + + /** + * Asynchronous version that only logs any received exceptions or failures. + * + * @param startTimeMillis When you want to say that the experiment started + * @param onComplete The optional action to be invoked when this operation asynchronously completes. + * Can be {@code null} if not interested in completion signal. + */ + void logStartTime(long startTimeMillis, Action onComplete) { + if (getLogger().isDebugEnabled()) { + getLogger().debug("logStartTimeAsync {}", startTimeMillis); + } + + sendAsynchronously(getRestApiClient()::logStartEndTime, createLogStartTimeRequest(startTimeMillis), onComplete); + } + + /** + * Asynchronous version that only logs any received exceptions or failures. + * + * @param endTimeMillis When you want to say that the experiment ended + * @param onComplete The optional action to be invoked when this operation asynchronously completes. + * Can be {@code null} if not interested in completion signal. + */ + void logEndTime(long endTimeMillis, Action onComplete) { + if (getLogger().isDebugEnabled()) { + getLogger().debug("logEndTimeAsync {}", endTimeMillis); + } + + sendAsynchronously(getRestApiClient()::logStartEndTime, createLogEndTimeRequest(endTimeMillis), onComplete); + } + + /** + * Asynchronous version that only logs any received exceptions or failures. + * + * @param gitMetadata The Git Metadata for the experiment. + * @param onComplete The optional action to be invoked when this operation asynchronously completes. + * Can be {@code null} if not interested in completion signal. + */ + void logGitMetadataAsync(@NonNull GitMetadata gitMetadata, Action onComplete) { + if (getLogger().isDebugEnabled()) { + getLogger().debug("logGitMetadata {}", gitMetadata); + } + + sendAsynchronously(getRestApiClient()::logGitMetadata, gitMetadata, onComplete); + } + + /** + * Asynchronous version that only logs any received exceptions or failures. + * + * @param line Text to be logged + * @param offset Offset describes the place for current text to be inserted + * @param stderr the flag to indicate if this is StdErr message. + * @param context the context to be associated with the parameter. + * @param onComplete The optional action to be invoked when this operation asynchronously completes. + * Can be {@code null} if not interested in completion signal. + */ + void logLine(String line, long offset, boolean stderr, String context, Action onComplete) { + OutputUpdate request = createLogLineRequest(line, offset, stderr, context); + Single single = validateAndGetExperimentKey() + .subscribeOn(Schedulers.io()) + .concatMap(experimentKey -> getRestApiClient().logOutputLine(request, experimentKey)); + + // register notification action if provided + if (onComplete != null) { + single = single.doFinally(onComplete); + } + + // subscribe to receive operation results but do not log anything + single.subscribe(); + } + + /** + * Asynchronous version that only logs any received exceptions or failures. + * + * @param folder the folder you want to log. + * @param logFilePath if {@code true}, log the file path with each file. + * @param recursive if {@code true}, recurse the folder. + * @param prefixWithFolderName if {@code true} then path of each asset file will be prefixed with folder name + * in case if {@code logFilePath} is {@code true}. + * @param context the context to be associated with logged assets. + * @param onComplete onComplete The optional action to be invoked when this operation + * asynchronously completes. Can be {@code null} if not interested in completion signal. + */ + void logAssetFolder(@NonNull File folder, boolean logFilePath, boolean recursive, boolean prefixWithFolderName, + @NonNull ExperimentContext context, Action onComplete) { + if (!folder.isDirectory()) { + getLogger().error(getString(LOG_ASSET_FOLDER_EMPTY, folder)); + return; + } + this.updateContext(context); + + AtomicInteger count = new AtomicInteger(); + try { + Stream assets = AssetUtils.walkFolderAssets(folder, logFilePath, recursive, prefixWithFolderName) + .peek(asset -> { + asset.setExperimentContext(this.baseContext); + asset.setType(ASSET_TYPE_ASSET); + count.incrementAndGet(); + }); + + // create parallel execution flow with errors delaying + // allowing processing of items even if some of them failed + Observable observable = + Observable.fromStream(assets) + .flatMap(asset -> Observable.fromSingle(sendAssetAsync(asset)), true); + + // register on completion action + if (onComplete != null) { + observable = observable.doFinally(onComplete); + } + + // subscribe for processing results + observable + .ignoreElements() // ignore items which already processed, see: logAsset + .subscribe( + () -> getLogger().info( + getString(ASSETS_FOLDER_UPLOAD_COMPLETED, folder, count.get())), + (throwable) -> getLogger().error( + getString(FAILED_TO_LOG_SOME_ASSET_FROM_FOLDER, folder), throwable), + disposables); + } catch (Throwable t) { + getLogger().error(getString(FAILED_TO_LOG_ASSET_FOLDER, folder), t); + } + } + + /** + * Asynchronous version that only logs any received exceptions or failures. + * + * @param file The file asset to be stored + * @param fileName The file name under which the asset should be stored in Comet. E.g. "someFile.txt" + * @param overwrite Whether to overwrite files of the same name in Comet + * @param onComplete onComplete The optional action to be invoked when this operation asynchronously completes. + * Can be {@code null} if not interested in completion signal. + */ + void uploadAsset(@NonNull File file, @NonNull String fileName, + boolean overwrite, @NonNull ExperimentContext context, Action onComplete) { + this.updateContext(context); + + Asset asset = new Asset(); + asset.setFile(file); + asset.setFileName(fileName); + asset.setOverwrite(overwrite); + asset.setType(ASSET_TYPE_ASSET); + + this.logAsset(asset, onComplete); + } + + /** + * Asynchronous version that only logs any received exceptions or failures. + * + * @param code Code to be sent to Comet + * @param fileName Name of source file to be displayed on UI 'code' tab + * @param context the context to be associated with the asset. + * @param onComplete onComplete The optional action to be invoked when this operation asynchronously completes. + * Can be {@code null} if not interested in completion signal. + */ + void logCode(@NonNull String code, @NonNull String fileName, + @NonNull ExperimentContext context, Action onComplete) { + this.updateContext(context); + + Asset asset = new Asset(); + asset.setFileLikeData(code.getBytes(StandardCharsets.UTF_8)); + asset.setFileName(fileName); + asset.setType(ASSET_TYPE_SOURCE_CODE); + + this.logAsset(asset, onComplete); + } + + /** + * Asynchronous version that only logs any received exceptions or failures. + * + * @param file Asset with source code to be sent + * @param context the context to be associated with the asset. + * @param onComplete onComplete The optional action to be invoked when this operation asynchronously completes. + * Can be {@code null} if not interested in completion signal. + */ + void logCode(@NonNull File file, @NonNull ExperimentContext context, Action onComplete) { + this.updateContext(context); + + Asset asset = new Asset(); + asset.setFile(file); + asset.setFileName(file.getName()); + asset.setType(ASSET_TYPE_SOURCE_CODE); + + this.logAsset(asset, onComplete); + } + + /** + * Asynchronously logs provided asset and signals upload completion if {@code onComplete} action provided. + * + * @param asset the {@link Asset} to be uploaded. + * @param onComplete the optional {@link Action} to be called upon operation completed, + * either successful or failure. + */ + void logAsset(@NonNull final Asset asset, Action onComplete) { + asset.setExperimentContext(this.baseContext); + + Single single = this.sendAssetAsync(asset); + if (onComplete != null) { + single = single.doFinally(onComplete); + } + + // subscribe to get operation completed + single.subscribe( + (logDataResponse) -> { + // ignore - already logged, see: sendAssetAsync + }, + (throwable) -> { + // ignore - already logged, see: sendAssetAsync + }, + disposables); + } + + /** + * Attempts to send given {@link Asset} asynchronously. + * This method will wrap send operation into {@link Single} and transparently log any errors that may happen. + * + * @param asset the {@link Asset} to be sent. + * @return the {@link Single} which can be used to subscribe for operation results. + */ + private Single sendAssetAsync(@NonNull final Asset asset) { + return validateAndGetExperimentKey() + .subscribeOn(Schedulers.io()) + .concatMap(experimentKey -> getRestApiClient().logAsset(asset, experimentKey)) + .doOnSuccess(logDataResponse -> + AsyncDataResponseLogger.checkAndLog(logDataResponse, getLogger(), asset)) + .doOnError(throwable -> + getLogger().error(getString(FAILED_TO_SEND_LOG_REQUEST, asset), throwable)); + } + + /** + * Uses provided function to send request data asynchronously and log received output. Optionally, can use + * provided {@link Action} handler to notify about completion of the operation. + * + * @param func the function to be invoked to send request data. + * @param request the request data object. + * @param onComplete the optional {@link Action} to be notified the operation completes either + * successfully or erroneously. + * @param the type of the request data object. + */ + private void sendAsynchronously(final BiFunction> func, + final T request, final Action onComplete) { + Single single = validateAndGetExperimentKey() + .subscribeOn(Schedulers.io()) + .concatMap(experimentKey -> func.apply(request, experimentKey)); + + // register notification action if provided + if (onComplete != null) { + single = single.doFinally(onComplete); + } + + // subscribe to receive operation results + single + .observeOn(Schedulers.single()) + .subscribe( + (logDataResponse) -> AsyncDataResponseLogger.checkAndLog(logDataResponse, getLogger(), request), + (throwable) -> getLogger().error(getString(FAILED_TO_SEND_LOG_REQUEST, request), throwable), + disposables); + } + + /** + * Utility class to log asynchronously received data responses. + */ + static final class AsyncDataResponseLogger { + static void checkAndLog(LogDataResponse logDataResponse, Logger logger, Object request) { + if (logDataResponse.hasFailed()) { + logger.error("failed to log {}, reason: {}", request, logDataResponse.getMsg()); + } else if (logger.isDebugEnabled()) { + logger.debug("success {}", logDataResponse); + } + } + } +} diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/CometApiImpl.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/CometApiImpl.java index bd8fca99..2d8a42d2 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/impl/CometApiImpl.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/CometApiImpl.java @@ -103,6 +103,7 @@ static final class CometApiBuilderImpl implements CometApiBuilder { private String apiKey; private Logger logger; + @Override public CometApiBuilder withConfigOverride(@NonNull File overrideConfig) { CometConfig.applyConfigOverride(overrideConfig); return this; @@ -114,6 +115,7 @@ public BaseCometBuilder withLogger(@NonNull Logger logger) { return this; } + @Override public CometApiBuilder withApiKey(@NonNull String apiKey) { this.apiKey = apiKey; return this; @@ -125,7 +127,7 @@ public CometApiBuilder withApiKey(@NonNull String apiKey) { * @return the fully initialized instance of the CometApiImpl. */ public CometApi build() { - if (StringUtils.isEmpty(this.apiKey)) { + if (StringUtils.isBlank(this.apiKey)) { this.apiKey = COMET_API_KEY.getString(); } return new CometApiImpl( diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/OnlineExperimentBuilderImpl.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/OnlineExperimentBuilderImpl.java new file mode 100644 index 00000000..ba069f72 --- /dev/null +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/OnlineExperimentBuilderImpl.java @@ -0,0 +1,132 @@ +package ml.comet.experiment.impl; + +import lombok.NonNull; +import ml.comet.experiment.builder.OnlineExperimentBuilder; +import ml.comet.experiment.impl.config.CometConfig; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; + +import java.io.File; +import java.time.Duration; + +import static ml.comet.experiment.impl.config.CometConfig.COMET_API_KEY; +import static ml.comet.experiment.impl.config.CometConfig.COMET_BASE_URL; +import static ml.comet.experiment.impl.config.CometConfig.COMET_MAX_AUTH_RETRIES; +import static ml.comet.experiment.impl.config.CometConfig.COMET_PROJECT_NAME; +import static ml.comet.experiment.impl.config.CometConfig.COMET_TIMEOUT_CLEANING_SECONDS; +import static ml.comet.experiment.impl.config.CometConfig.COMET_WORKSPACE_NAME; + +/** + * The builder to create properly configured instance of the OnlineExperimentImpl. + */ +final class OnlineExperimentBuilderImpl implements OnlineExperimentBuilder { + private String projectName; + private String workspace; + private String apiKey; + private String baseUrl; + private int maxAuthRetries = -1; + private String experimentName; + private String experimentKey; + private Logger logger; + private boolean interceptStdout = false; + + /** + * Default constructor to avoid direct initialization from the outside. + */ + OnlineExperimentBuilderImpl() { + } + + @Override + public OnlineExperimentBuilderImpl withProjectName(@NonNull String projectName) { + this.projectName = projectName; + return this; + } + + @Override + public OnlineExperimentBuilderImpl withWorkspace(@NonNull String workspace) { + this.workspace = workspace; + return this; + } + + @Override + public OnlineExperimentBuilderImpl withApiKey(@NonNull String apiKey) { + this.apiKey = apiKey; + return this; + } + + @Override + public OnlineExperimentBuilderImpl withMaxAuthRetries(int maxAuthRetries) { + this.maxAuthRetries = maxAuthRetries; + return this; + } + + @Override + public OnlineExperimentBuilderImpl withUrlOverride(@NonNull String urlOverride) { + this.baseUrl = urlOverride; + return this; + } + + @Override + public OnlineExperimentBuilderImpl withExperimentName(@NonNull String experimentName) { + this.experimentName = experimentName; + return this; + } + + @Override + public OnlineExperimentBuilderImpl withExistingExperimentKey(@NonNull String experimentKey) { + this.experimentKey = experimentKey; + return this; + } + + @Override + public OnlineExperimentBuilderImpl withLogger(@NonNull Logger logger) { + this.logger = logger; + return this; + } + + @Override + public OnlineExperimentBuilderImpl withConfigOverride(@NonNull File overrideConfig) { + CometConfig.applyConfigOverride(overrideConfig); + return this; + } + + @Override + public OnlineExperimentBuilderImpl interceptStdout() { + this.interceptStdout = true; + return this; + } + + @Override + public OnlineExperimentImpl build() { + + if (StringUtils.isBlank(this.apiKey)) { + this.apiKey = COMET_API_KEY.getString(); + } + if (StringUtils.isBlank(this.projectName)) { + this.projectName = COMET_PROJECT_NAME.getOptionalString().orElse(null); + } + if (StringUtils.isBlank(this.workspace)) { + this.workspace = COMET_WORKSPACE_NAME.getOptionalString().orElse(null); + } + if (StringUtils.isBlank(this.baseUrl)) { + this.baseUrl = COMET_BASE_URL.getString(); + } + if (this.maxAuthRetries == -1) { + this.maxAuthRetries = COMET_MAX_AUTH_RETRIES.getInt(); + } + Duration cleaningTimeout = COMET_TIMEOUT_CLEANING_SECONDS.getDuration(); + + OnlineExperimentImpl experiment = new OnlineExperimentImpl( + this.apiKey, this.projectName, this.workspace, this.experimentName, this.experimentKey, + this.logger, this.interceptStdout, this.baseUrl, this.maxAuthRetries, cleaningTimeout); + try { + // initialize experiment + experiment.init(); + } catch (Throwable ex) { + // release hold resources and signal to user about failure + experiment.end(); + throw ex; + } + return experiment; + } +} diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/OnlineExperimentImpl.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/OnlineExperimentImpl.java index a5710616..276d0dfd 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/impl/OnlineExperimentImpl.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/OnlineExperimentImpl.java @@ -3,12 +3,10 @@ import lombok.Getter; import lombok.NonNull; import ml.comet.experiment.OnlineExperiment; -import ml.comet.experiment.builder.OnlineExperimentBuilder; -import ml.comet.experiment.impl.config.CometConfig; +import ml.comet.experiment.context.ExperimentContext; import ml.comet.experiment.impl.log.StdOutLogger; import ml.comet.experiment.model.ExperimentStatusResponse; import ml.comet.experiment.model.GitMetadata; -import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -22,17 +20,13 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import static ml.comet.experiment.impl.config.CometConfig.COMET_API_KEY; -import static ml.comet.experiment.impl.config.CometConfig.COMET_BASE_URL; -import static ml.comet.experiment.impl.config.CometConfig.COMET_MAX_AUTH_RETRIES; -import static ml.comet.experiment.impl.config.CometConfig.COMET_PROJECT_NAME; -import static ml.comet.experiment.impl.config.CometConfig.COMET_TIMEOUT_CLEANING_SECONDS; -import static ml.comet.experiment.impl.config.CometConfig.COMET_WORKSPACE_NAME; +import static ml.comet.experiment.impl.resources.LogMessages.EXPERIMENT_HEARTBEAT_STOPPED_PROMPT; +import static ml.comet.experiment.impl.resources.LogMessages.getString; /** * The implementation of the {@link OnlineExperiment} to work with Comet API asynchronously. */ -public final class OnlineExperimentImpl extends BaseExperiment implements OnlineExperiment { +public final class OnlineExperimentImpl extends BaseExperimentAsync implements OnlineExperiment { private static final int SCHEDULED_EXECUTOR_TERMINATION_WAIT_SEC = 60; private static final int STD_OUT_LOGGER_FLUSH_WAIT_DELAY_MS = 2000; @@ -83,7 +77,6 @@ public final class OnlineExperimentImpl extends BaseExperiment implements Online if (logger != null) { this.logger = logger; } - this.init(); } @Override @@ -96,7 +89,7 @@ public void end() { if (!heartbeatSendFuture.cancel(true)) { this.logger.error("failed to stop experiment's heartbeat sender"); } else { - this.logger.info("Experiment's heartbeat sender stopped"); + this.logger.info(getString(EXPERIMENT_HEARTBEAT_STOPPED_PROMPT)); } heartbeatSendFuture = null; } @@ -124,6 +117,16 @@ public void end() { super.end(); } + /** + * Allows using {@link OnlineExperiment} with try-with-resources statement with automatic closing after usage. + * + * @throws Exception if an exception occurs. + */ + @Override + public void close() throws Exception { + this.end(); + } + @Override public void setInterceptStdout() throws IOException { if (!interceptStdout) { @@ -147,12 +150,50 @@ public void stopInterceptStdout() throws IOException { @Override public void nextStep() { - this.step++; + this.setStep(this.getStep() + 1); + } + + @Override + public long getStep() { + if (this.baseContext.getStep() != null) { + return this.baseContext.getStep(); + } else { + return 0; + } + } + + @Override + public void setStep(long step) { + this.baseContext.setStep(step); } @Override public void nextEpoch() { - this.epoch++; + this.setEpoch(this.getEpoch() + 1); + } + + @Override + public long getEpoch() { + if (this.baseContext.getEpoch() != null) { + return this.baseContext.getEpoch(); + } else { + return 0; + } + } + + @Override + public void setEpoch(long epoch) { + this.baseContext.setEpoch(epoch); + } + + @Override + public void setContext(String context) { + this.baseContext.setContext(context); + } + + @Override + public String getContext() { + return this.baseContext.getContext(); } @Override @@ -161,61 +202,71 @@ public Optional getExperimentLink() { } @Override - public void logMetric(@NonNull String metricName, @NonNull Object metricValue) { - this.logMetric(metricName, metricValue, this.step, this.epoch); + public void logMetric(@NonNull String metricName, @NonNull Object metricValue, @NonNull ExperimentContext context) { + this.logMetric(metricName, metricValue, context, null); + } + + @Override + public void logMetric(String metricName, Object metricValue, long step, long epoch) { + this.logMetric(metricName, metricValue, + new ExperimentContext(step, epoch, this.getContext())); } @Override - public void logMetric(@NonNull String metricName, @NonNull Object metricValue, long step) { - this.logMetric(metricName, metricValue, step, this.epoch); + public void logMetric(String metricName, Object metricValue, long step) { + this.logMetric(metricName, metricValue, + new ExperimentContext(step, this.getEpoch(), this.getContext())); } @Override - public void logMetric(@NonNull String metricName, @NonNull Object metricValue, long step, long epoch) { - this.setStep(step); - this.setEpoch(epoch); - this.logMetricAsync(metricName, metricValue, step, epoch, null); + public void logMetric(String metricName, Object metricValue) { + this.logMetric(metricName, metricValue, this.baseContext); } @Override public void logParameter(@NonNull String parameterName, @NonNull Object paramValue) { - this.logParameter(parameterName, paramValue, this.step); + this.logParameter(parameterName, paramValue, this.baseContext); } @Override public void logParameter(@NonNull String parameterName, @NonNull Object paramValue, long step) { - this.setStep(step); - this.logParameterAsync(parameterName, paramValue, step, null); + this.logParameter(parameterName, paramValue, + new ExperimentContext(step, this.getEpoch(), this.getContext())); + } + + @Override + public void logParameter(String parameterName, Object paramValue, @NonNull ExperimentContext context) { + this.logParameter(parameterName, paramValue, context, null); } @Override public void logHtml(@NonNull String html, boolean override) { - this.logHtmlAsync(html, override, null); + this.logHtml(html, override, null); } @Override public void logOther(@NonNull String key, @NonNull Object value) { - this.logOtherAsync(key, value, null); + this.logOther(key, value, null); } @Override public void addTag(@NonNull String tag) { - this.addTagAsync(tag, null); + this.addTag(tag, null); } @Override public void logGraph(@NonNull String graph) { - this.logGraphAsync(graph, null); + this.logGraph(graph, null); } @Override public void logStartTime(long startTimeMillis) { - this.logStartTimeAsync(startTimeMillis, null); + this.logStartTime(startTimeMillis, null); } @Override public void logEndTime(long endTimeMillis) { - this.logEndTimeAsync(endTimeMillis, null); + this.logEndTime(endTimeMillis, null); } @Override @@ -224,18 +275,81 @@ public void logGitMetadata(GitMetadata gitMetadata) { } @Override + public void logLine(String line, long offset, boolean stderr) { + this.logLine(line, offset, stderr, this.getContext()); + } + + @Override + public void logLine(String line, long offset, boolean stderr, String context) { + this.setContext(context); + this.logLine(line, offset, stderr, context, null); + } + + @Override + public void logAssetFolder(File folder, boolean logFilePath, boolean recursive, ExperimentContext context) { + this.logAssetFolder(folder, logFilePath, recursive, true, context, null); + } + + @Override + public void logAssetFolder(File folder, boolean logFilePath, boolean recursive) { + this.logAssetFolder(folder, logFilePath, recursive, this.baseContext); + } + + @Override + public void logAssetFolder(File folder, boolean logFilePath) { + this.logAssetFolder(folder, logFilePath, false); + } + + @Override + public void uploadAsset(@NonNull File asset, @NonNull String fileName, + boolean overwrite, @NonNull ExperimentContext context) { + this.uploadAsset(asset, fileName, overwrite, context, null); + } + + @Override + public void uploadAsset(File asset, boolean overwrite, ExperimentContext context) { + this.uploadAsset(asset, asset.getName(), overwrite, context); + } + + @Override + public void uploadAsset(@NonNull File asset, boolean overwrite, long step, long epoch) { + this.uploadAsset(asset, asset.getName(), overwrite, + new ExperimentContext(step, epoch, getContext())); + } + public void uploadAsset(@NonNull File asset, @NonNull String fileName, boolean overwrite, long step) { - super.uploadAsset(asset, fileName, overwrite, step, this.epoch); + this.uploadAsset(asset, fileName, overwrite, + new ExperimentContext(step, this.getEpoch(), this.getContext())); + } + + @Override + public void uploadAsset(@NonNull File asset, @NonNull String fileName, boolean overwrite) { + this.uploadAsset(asset, fileName, overwrite, this.baseContext); } @Override public void uploadAsset(@NonNull File asset, boolean overwrite) { - uploadAsset(asset, asset.getName(), overwrite); + this.uploadAsset(asset, asset.getName(), overwrite, this.baseContext); } @Override - public void uploadAsset(@NonNull File asset, @NonNull String fileName, boolean overwrite) { - super.uploadAsset(asset, fileName, overwrite, this.step, this.epoch); + public void logCode(@NonNull String code, @NonNull String fileName, @NonNull ExperimentContext context) { + this.logCode(code, fileName, context, null); + } + + @Override + public void logCode(@NonNull File file, @NonNull ExperimentContext context) { + this.logCode(file, context, null); + } + + @Override + public void logCode(@NonNull String code, @NonNull String fileName) { + this.logCode(code, fileName, this.baseContext); + } + + @Override + public void logCode(@NonNull File file) { + this.logCode(file, this.baseContext); } @Override @@ -319,110 +433,4 @@ public void run() { public static OnlineExperimentBuilderImpl builder() { return new OnlineExperimentBuilderImpl(); } - - /** - * The builder to create properly configured instance of the OnlineExperimentImpl. - */ - public static final class OnlineExperimentBuilderImpl implements OnlineExperimentBuilder { - private String projectName; - private String workspace; - private String apiKey; - private String baseUrl; - private int maxAuthRetries = -1; - private String experimentName; - private String experimentKey; - private Logger logger; - private boolean interceptStdout = false; - - /** - * Default constructor to avoid direct initialization from the outside. - */ - private OnlineExperimentBuilderImpl() { - } - - @Override - public OnlineExperimentBuilderImpl withProjectName(@NonNull String projectName) { - this.projectName = projectName; - return this; - } - - @Override - public OnlineExperimentBuilderImpl withWorkspace(@NonNull String workspace) { - this.workspace = workspace; - return this; - } - - @Override - public OnlineExperimentBuilderImpl withApiKey(@NonNull String apiKey) { - this.apiKey = apiKey; - return this; - } - - @Override - public OnlineExperimentBuilderImpl withMaxAuthRetries(int maxAuthRetries) { - this.maxAuthRetries = maxAuthRetries; - return this; - } - - @Override - public OnlineExperimentBuilderImpl withUrlOverride(@NonNull String urlOverride) { - this.baseUrl = urlOverride; - return this; - } - - @Override - public OnlineExperimentBuilderImpl withExperimentName(@NonNull String experimentName) { - this.experimentName = experimentName; - return this; - } - - @Override - public OnlineExperimentBuilderImpl withExistingExperimentKey(@NonNull String experimentKey) { - this.experimentKey = experimentKey; - return this; - } - - @Override - public OnlineExperimentBuilderImpl withLogger(@NonNull Logger logger) { - this.logger = logger; - return this; - } - - @Override - public OnlineExperimentBuilderImpl withConfigOverride(@NonNull File overrideConfig) { - CometConfig.applyConfigOverride(overrideConfig); - return this; - } - - @Override - public OnlineExperimentBuilderImpl interceptStdout() { - this.interceptStdout = true; - return this; - } - - @Override - public OnlineExperimentImpl build() { - - if (StringUtils.isEmpty(this.apiKey)) { - this.apiKey = COMET_API_KEY.getString(); - } - if (StringUtils.isEmpty(this.projectName)) { - this.projectName = COMET_PROJECT_NAME.getOptionalString().orElse(null); - } - if (StringUtils.isEmpty(this.workspace)) { - this.workspace = COMET_WORKSPACE_NAME.getOptionalString().orElse(null); - } - if (StringUtils.isEmpty(this.baseUrl)) { - this.baseUrl = COMET_BASE_URL.getString(); - } - if (this.maxAuthRetries == -1) { - this.maxAuthRetries = COMET_MAX_AUTH_RETRIES.getInt(); - } - Duration cleaningTimeout = COMET_TIMEOUT_CLEANING_SECONDS.getDuration(); - - return new OnlineExperimentImpl( - this.apiKey, this.projectName, this.workspace, this.experimentName, this.experimentKey, - this.logger, this.interceptStdout, this.baseUrl, this.maxAuthRetries, cleaningTimeout); - } - } } diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/RestApiClient.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/RestApiClient.java index a67be113..de56bfeb 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/impl/RestApiClient.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/RestApiClient.java @@ -4,7 +4,9 @@ import io.reactivex.rxjava3.disposables.Disposable; import lombok.NonNull; import ml.comet.experiment.exception.CometApiException; -import ml.comet.experiment.impl.constants.AssetType; +import ml.comet.experiment.impl.asset.Asset; +import ml.comet.experiment.impl.asset.AssetType; +import ml.comet.experiment.impl.constants.FormParamName; import ml.comet.experiment.impl.constants.QueryParamName; import ml.comet.experiment.impl.http.Connection; import ml.comet.experiment.impl.utils.JsonUtils; @@ -33,10 +35,12 @@ import ml.comet.experiment.model.ParameterRest; import ml.comet.experiment.model.TagsResponse; +import java.io.File; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import static ml.comet.experiment.impl.constants.ApiEndpoints.ADD_ASSET; import static ml.comet.experiment.impl.constants.ApiEndpoints.ADD_GIT_METADATA; import static ml.comet.experiment.impl.constants.ApiEndpoints.ADD_GRAPH; import static ml.comet.experiment.impl.constants.ApiEndpoints.ADD_HTML; @@ -61,10 +65,18 @@ import static ml.comet.experiment.impl.constants.ApiEndpoints.PROJECTS; import static ml.comet.experiment.impl.constants.ApiEndpoints.SET_EXPERIMENT_STATUS; import static ml.comet.experiment.impl.constants.ApiEndpoints.WORKSPACES; +import static ml.comet.experiment.impl.constants.QueryParamName.CONTEXT; +import static ml.comet.experiment.impl.constants.QueryParamName.EPOCH; import static ml.comet.experiment.impl.constants.QueryParamName.EXPERIMENT_KEY; +import static ml.comet.experiment.impl.constants.QueryParamName.EXTENSION; +import static ml.comet.experiment.impl.constants.QueryParamName.FILE_NAME; +import static ml.comet.experiment.impl.constants.QueryParamName.IS_REMOTE; +import static ml.comet.experiment.impl.constants.QueryParamName.OVERWRITE; import static ml.comet.experiment.impl.constants.QueryParamName.PROJECT_ID; +import static ml.comet.experiment.impl.constants.QueryParamName.STEP; import static ml.comet.experiment.impl.constants.QueryParamName.TYPE; import static ml.comet.experiment.impl.constants.QueryParamName.WORKSPACE_NAME; +import static ml.comet.experiment.impl.utils.CometUtils.putNotNull; /** * Represents Comet REST API client providing access to all exposed REST endpoints. @@ -190,6 +202,71 @@ Single registerExperiment(final CreateExperimentReques return singleFromSyncPost(request, NEW_EXPERIMENT, true, CreateExperimentResponse.class); } + Single logAsset(final Asset asset, String experimentKey) { + // populate query parameters + HashMap queryParams = new HashMap() {{ + put(EXPERIMENT_KEY, experimentKey); + put(TYPE, asset.getType().type()); + }}; + putNotNull(queryParams, OVERWRITE, asset.getOverwrite()); + putNotNull(queryParams, IS_REMOTE, asset.getRemote()); + putNotNull(queryParams, FILE_NAME, asset.getFileName()); + putNotNull(queryParams, EXTENSION, asset.getFileExtension()); + putNotNull(queryParams, CONTEXT, asset.getContext()); + putNotNull(queryParams, STEP, asset.getStep()); + putNotNull(queryParams, EPOCH, asset.getEpoch()); + + // populate form parameters + HashMap formParams = null; + if (asset.getMetadata() != null) { + // encode metadata to JSON + formParams = new HashMap() {{ + put(FormParamName.METADATA, JsonUtils.toJson(asset.getMetadata())); + }}; + } + + // call appropriate send method + if (asset.getFile() != null) { + return singleFromAsyncPost(asset.getFile(), ADD_ASSET, queryParams, + formParams, LogDataResponse.class); + } else if (asset.getFileLikeData() != null) { + return singleFromAsyncPost(asset.getFileLikeData(), ADD_ASSET, queryParams, + formParams, LogDataResponse.class); + } + + // no data response + LogDataResponse response = new LogDataResponse(); + response.setMsg("asset has no data"); + response.setCode(-1); + return Single.just(response); + } + + private Single singleFromAsyncPost( + byte[] fileLikeData, @NonNull String endpoint, + @NonNull HashMap queryParams, + HashMap formParams, @NonNull Class clazz) { + if (isDisposed()) { + return Single.error(ALREADY_DISPOSED); + } + + return Single.fromFuture(this.connection.sendPostAsync(fileLikeData, endpoint, queryParams, formParams)) + .onTerminateDetach() + .map(response -> JsonUtils.fromJson(response.getResponseBody(), clazz)); + } + + private Single singleFromAsyncPost( + @NonNull File file, @NonNull String endpoint, + @NonNull HashMap queryParams, + HashMap formParams, @NonNull Class clazz) { + if (isDisposed()) { + return Single.error(ALREADY_DISPOSED); + } + + return Single.fromFuture(this.connection.sendPostAsync(file, endpoint, queryParams, formParams)) + .onTerminateDetach() + .map(response -> JsonUtils.fromJson(response.getResponseBody(), clazz)); + } + private Single singleFromAsyncPost( @NonNull Object payload, @NonNull String endpoint, @NonNull Class clazz) { if (isDisposed()) { diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/asset/Asset.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/asset/Asset.java new file mode 100644 index 00000000..6f0ccdee --- /dev/null +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/asset/Asset.java @@ -0,0 +1,42 @@ +package ml.comet.experiment.impl.asset; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import ml.comet.experiment.context.ExperimentContext; + +import java.io.File; +import java.util.Map; + +/** + * Describes asset data. + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class Asset { + private File file; + private byte[] fileLikeData; + private String fileName; + private AssetType type; + private Boolean overwrite; + private Long step; + private Long epoch; + private String groupingName; + private Map metadata; + private String assetId; + private String fileExtension; + private String context; + private Boolean remote; + + /** + * Updates this asset with values from provided {@link ExperimentContext}. + * + * @param context the {@link ExperimentContext} with context values. + */ + public void setExperimentContext(ExperimentContext context) { + this.step = context.getStep(); + this.epoch = context.getEpoch(); + this.context = context.getContext(); + } +} diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/constants/AssetType.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/asset/AssetType.java similarity index 97% rename from comet-java-client/src/main/java/ml/comet/experiment/impl/constants/AssetType.java rename to comet-java-client/src/main/java/ml/comet/experiment/impl/asset/AssetType.java index d29125f6..b8bd6665 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/impl/constants/AssetType.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/asset/AssetType.java @@ -1,4 +1,4 @@ -package ml.comet.experiment.impl.constants; +package ml.comet.experiment.impl.asset; /** * Represents known types of the assets. diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/config/EnvironmentConfig.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/config/EnvironmentConfig.java index eb495fd6..9e03b0eb 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/impl/config/EnvironmentConfig.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/config/EnvironmentConfig.java @@ -19,11 +19,11 @@ final class EnvironmentConfig { * @throws IllegalStateException if optionName is */ public Optional getEnvVariable(String optionName) throws IllegalStateException { - if (StringUtils.isEmpty(optionName)) { + if (StringUtils.isBlank(optionName)) { throw new IllegalStateException("optionName is empty"); } String res = System.getenv(optionName); - if (StringUtils.isEmpty(res)) { + if (StringUtils.isBlank(res)) { return Optional.empty(); } return Optional.of(res); diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/constants/ApiEndpoints.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/constants/ApiEndpoints.java index d7e8d259..a0085624 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/impl/constants/ApiEndpoints.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/constants/ApiEndpoints.java @@ -1,8 +1,11 @@ package ml.comet.experiment.impl.constants; +import lombok.experimental.UtilityClass; + /** * Definitions of the Comet API endpoints. */ +@UtilityClass public final class ApiEndpoints { public static final String UPDATE_API_URL = "/api/rest/v2/write"; public static final String ADD_OUTPUT = UPDATE_API_URL + "/experiment/output"; diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/constants/FormParamName.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/constants/FormParamName.java new file mode 100644 index 00000000..a621dd5b --- /dev/null +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/constants/FormParamName.java @@ -0,0 +1,27 @@ +package ml.comet.experiment.impl.constants; + +/** + * Enumeration of all known form parameter names of the REST endpoints. + */ +public enum FormParamName { + + TAGS("tags"), // string list + LINK("link"), // string + METADATA("metadata"), // json string + FILE("file"); // InputStream or FormDataContentDisposition + + private final String paramName; + + FormParamName(String paramName) { + this.paramName = paramName; + } + + public String paramName() { + return this.paramName; + } + + @Override + public String toString() { + return this.paramName; + } +} diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/constants/QueryParamName.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/constants/QueryParamName.java index 373095e4..44aeeae7 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/impl/constants/QueryParamName.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/constants/QueryParamName.java @@ -1,19 +1,25 @@ package ml.comet.experiment.impl.constants; /** - * The enumeration of all known query parameter names. + * The enumeration of all known query parameter names of the REST endpoints. */ public enum QueryParamName { - EXPERIMENT_KEY("experimentKey"), - FILE_NAME("fileName"), - CONTEXT("context"), - TYPE("type"), - OVERWRITE("overwrite"), - STEP("step"), - EPOCH("epoch"), - PROJECT_ID("projectId"), - WORKSPACE_NAME("workspaceName"); + EXPERIMENT_KEY("experimentKey"), // string + EXTENSION("extension"), // string + EPOCH("epoch"), // integer + STEP("step"), // integer + SOURCE("source"), // string + CONTEXT("context"), // string + TYPE("type"), // string + METADATA("metadata"), // json string + FILE_NAME("fileName"), // string + GROUPING_NAME("groupingName"), // string + ARTIFACT_VERSION_ID("artifactVersionId"), // string + IS_REMOTE("isRemote"), // boolean + OVERWRITE("overwrite"), // boolean + PROJECT_ID("projectId"), // string + WORKSPACE_NAME("workspaceName"); // string private final String paramName; diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/http/Connection.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/http/Connection.java index 86a0b555..d86d0fcc 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/impl/http/Connection.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/http/Connection.java @@ -4,6 +4,7 @@ import lombok.Value; import ml.comet.experiment.exception.CometApiException; import ml.comet.experiment.exception.CometGeneralException; +import ml.comet.experiment.impl.constants.FormParamName; import ml.comet.experiment.impl.constants.QueryParamName; import org.asynchttpclient.AsyncCompletionHandler; import org.asynchttpclient.AsyncHttpClient; @@ -126,35 +127,39 @@ public ListenableFuture sendPostAsync(@NonNull String json, @NonNull S /** * Allows asynchronous posting the content of the file as multipart form data to the specified endpoint. * - * @param file the file to be included. - * @param endpoint the relative path to the endpoint. - * @param params the request parameters + * @param file the file to be included. + * @param endpoint the relative path to the endpoint. + * @param queryParams the request query parameters + * @param formParams the form parameters * @return the ListenableFuture<Response> which can be used to monitor status of * the request execution. */ public ListenableFuture sendPostAsync(@NonNull File file, @NonNull String endpoint, - @NonNull Map params) { + @NonNull Map queryParams, + Map formParams) { return executeRequestAsync( - ConnectionUtils.createPostFileRequest(file, this.buildCometUrl(endpoint), params)); + ConnectionUtils.createPostFileRequest(file, this.buildCometUrl(endpoint), queryParams, formParams)); } /** * Allows asynchronous sending of provided byte array as POST request to the specified endpoint. * - * @param bytes the data array - * @param endpoint the relative path to the endpoint. - * @param params the request parameters map. + * @param bytes the data array + * @param endpoint the relative path to the endpoint. + * @param params the request parameters map. + * @param formParams the form parameters * @return the ListenableFuture<Response> which can be used to monitor status of * the request execution. */ public ListenableFuture sendPostAsync(byte[] bytes, @NonNull String endpoint, - @NonNull Map params) { + @NonNull Map params, + Map formParams) { String url = this.buildCometUrl(endpoint); if (logger.isDebugEnabled()) { logger.debug("sending POST bytearray with length {} to {}", bytes.length, url); } - return executeRequestAsync(ConnectionUtils.createPostByteArrayRequest(bytes, url, params)); + return executeRequestAsync(ConnectionUtils.createPostByteArrayRequest(bytes, url, params, formParams)); } /** @@ -263,7 +268,8 @@ Optional executeRequestSyncWithRetries( // attempt failed - check if to retry if (i < this.maxAuthRetries - 1) { // sleep for a while and repeat - this.logger.debug("for endpoint {} response {}, retrying\n", endpoint, response.getStatusText()); + this.logger.debug("for endpoint {} response {}, retrying\n", + endpoint, response.getStatusText()); Thread.sleep((2 ^ i) * 1000L); } else { // maximal number of attempts exceeded - throw or return diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/http/ConnectionInitializer.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/http/ConnectionInitializer.java index 2f85c7f8..cf80690b 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/impl/http/ConnectionInitializer.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/http/ConnectionInitializer.java @@ -20,7 +20,7 @@ public class ConnectionInitializer { * @return the properly initialized Connection instance. */ public Connection initConnection(String apiKey, String cometBaseUrl, int maxAuthRetries, Logger logger) { - if (StringUtils.isEmpty(apiKey)) { + if (StringUtils.isBlank(apiKey)) { throw new IllegalArgumentException("Api key required!"); } return new Connection(cometBaseUrl, apiKey, maxAuthRetries, logger); diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/http/ConnectionUtils.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/http/ConnectionUtils.java index d685efcb..be136e74 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/impl/http/ConnectionUtils.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/http/ConnectionUtils.java @@ -1,28 +1,35 @@ package ml.comet.experiment.impl.http; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; import lombok.NonNull; import lombok.Value; +import ml.comet.experiment.impl.constants.FormParamName; import ml.comet.experiment.impl.constants.QueryParamName; import org.asynchttpclient.Request; import org.asynchttpclient.RequestBuilder; import org.asynchttpclient.Response; import org.asynchttpclient.request.body.generator.ByteArrayBodyGenerator; import org.asynchttpclient.request.body.multipart.ByteArrayPart; +import org.asynchttpclient.request.body.multipart.FileLikePart; import org.asynchttpclient.request.body.multipart.FilePart; +import org.asynchttpclient.request.body.multipart.Part; +import org.asynchttpclient.request.body.multipart.StringPart; import org.asynchttpclient.util.HttpConstants; import org.slf4j.Logger; import java.io.File; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.function.Function; +import static ml.comet.experiment.impl.constants.FormParamName.FILE; + /** * Collection of the utilities used by Connection. */ public class ConnectionUtils { - static final String FILE = "file"; - static final String FORM_MIME_TYPE = "multipart/form-data"; - static final String JSON_MIME_TYPE = "application/json"; /** * Creates GET request to the given endpoint with specified query parameters. @@ -40,32 +47,37 @@ static Request createGetRequest(@NonNull String url, Map /** * Creates POST request from given file to the specified endpoint. * - * @param file the file to be included in the body parts. - * @param url the URL of the endpoint. - * @param params the query parameters of the request. + * @param file the file to be included in the body parts. + * @param url the URL of the endpoint. + * @param params the query parameters of the request. + * @param formParams the form parameters to be added. * @return the POST request with specified file. */ - static Request createPostFileRequest(@NonNull File file, @NonNull String url, Map params) { - return createRequestBuilder(HttpConstants.Methods.POST, params) + static Request createPostFileRequest(@NonNull File file, @NonNull String url, + Map params, + Map formParams) { + return createMultipartRequestBuilder( + new FilePart(FILE.paramName(), file), params, formParams) .setUrl(url) - .setHeader("Content-Type", FORM_MIME_TYPE) - .addBodyPart(new FilePart(FILE, file, FORM_MIME_TYPE)) .build(); } /** * Creates POST request from given byte array to the specified endpoint. * - * @param bytes the bytes array to include into request. - * @param url the URL of the endpoint. - * @param params the query parameters of the request. + * @param bytes the bytes array to include into request. + * @param url the URL of the endpoint. + * @param params the query parameters of the request. + * @param formParams the form parameters to be added * @return the POST request with specified byte array as body part. */ - static Request createPostByteArrayRequest(byte[] bytes, @NonNull String url, Map params) { - return createRequestBuilder(HttpConstants.Methods.POST, params) + static Request createPostByteArrayRequest(byte[] bytes, @NonNull String url, + Map params, + Map formParams) { + return createMultipartRequestBuilder( + new ByteArrayPart(FILE.paramName(), bytes, HttpHeaderValues.APPLICATION_OCTET_STREAM.toString()), + params, formParams) .setUrl(url) - .setHeader("Content-Type", FORM_MIME_TYPE) - .addBodyPart(new ByteArrayPart(FILE, bytes, FORM_MIME_TYPE)) .build(); } @@ -75,7 +87,7 @@ static Request createPostByteArrayRequest(byte[] bytes, @NonNull String url, Map static Request createPostJsonRequest(@NonNull String body, @NonNull String url) { return new RequestBuilder() .setUrl(url) - .setHeader("Content-Type", JSON_MIME_TYPE) + .setHeader(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON) .setBody(new ByteArrayBodyGenerator(body.getBytes())) .setMethod(HttpConstants.Methods.POST) .build(); @@ -99,11 +111,46 @@ static boolean isResponseSuccessful(int statusCode) { static RequestBuilder createRequestBuilder(@NonNull String httpMethod, Map params) { RequestBuilder builder = new RequestBuilder(httpMethod); if (params != null) { - params.forEach((k, v) -> builder.addQueryParam(k.paramName(), v)); + params.forEach((k, v) -> { + if (v != null) { + builder.addQueryParam(k.paramName(), v); + } + }); } return builder; } + /** + * Creates multipart request builder using provided parameters. + * + * @param fileLikePart the file like part to be added + * @param params the query parameters to be added + * @param formParams the form parameters + * @return the pre-configured request builder. + */ + static RequestBuilder createMultipartRequestBuilder( + @NonNull FileLikePart fileLikePart, Map params, + Map formParams) { + RequestBuilder builder = createRequestBuilder(HttpConstants.Methods.POST, params); + List parts = new ArrayList<>(); + parts.add(fileLikePart); + if (formParams != null) { + formParams.forEach((k, v) -> { + if (v != null) { + parts.add(createPart(k.paramName(), v)); + } + }); + } + builder + .setHeader(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.MULTIPART_FORM_DATA) + .setBodyParts(parts); + return builder; + } + + private static Part createPart(String name, @NonNull Object value) { + return new StringPart(name, value.toString()); + } + /** * The function allowing to debug response. */ @@ -116,9 +163,9 @@ public static class DebugLogResponse implements Function { public Response apply(Response response) { // log response for debug purposes if (ConnectionUtils.isResponseSuccessful(response.getStatusCode())) { - logger.debug("for endpoint {} response {}\n", endpoint, response.getResponseBody()); + logger.debug("for endpoint {} got response {}\n", endpoint, response.getResponseBody()); } else { - logger.error("for endpoint {} response {}\n", endpoint, response.getStatusText()); + logger.error("for endpoint {} got response {}\n", endpoint, response.getStatusText()); } return response; } diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/resources/LogMessages.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/resources/LogMessages.java new file mode 100644 index 00000000..80128adf --- /dev/null +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/resources/LogMessages.java @@ -0,0 +1,86 @@ +package ml.comet.experiment.impl.resources; + +import lombok.experimental.UtilityClass; +import org.apache.commons.lang3.StringUtils; + +import java.io.IOException; +import java.io.InputStream; +import java.util.PropertyResourceBundle; + +/** + * Provides access to the log messages to be presented to the user. + */ +@UtilityClass +public class LogMessages { + + public static final String EXPERIMENT_LIVE = "EXPERIMENT_LIVE"; + public static final String EXPERIMENT_CLEANUP_PROMPT = "EXPERIMENT_CLEANUP_PROMPT"; + public static final String EXPERIMENT_HEARTBEAT_STOPPED_PROMPT = "EXPERIMENT_HEARTBEAT_STOPPED_PROMPT"; + public static final String LOG_ASSET_FOLDER_EMPTY = "LOG_ASSET_FOLDER_EMPTY"; + public static final String ASSETS_FOLDER_UPLOAD_COMPLETED = "ASSETS_FOLDER_UPLOAD_COMPLETED"; + + public static final String FAILED_READ_DATA_FOR_EXPERIMENT = "FAILED_READ_DATA_FOR_EXPERIMENT"; + public static final String FAILED_TO_SEND_LOG_REQUEST = "FAILED_TO_SEND_LOG_REQUEST"; + public static final String FAILED_TO_LOG_ASSET_FOLDER = "FAILED_TO_LOG_ASSET_FOLDER"; + public static final String FAILED_TO_LOG_SOME_ASSET_FROM_FOLDER = "FAILED_TO_LOG_SOME_ASSET_FROM_FOLDER"; + + + /** + * Gets a formatted string for the given key from this resource bundle. + * + * @param key the key for the desired string + * @param args the formatting arguments. See {@link String#format(String, Object...)} + * @return the formatted string for the given key + */ + public static String getString(String key, Object... args) { + String format = getString(key); + if (StringUtils.isNotBlank(format)) { + if (args != null) { + try { + return String.format(format, args); + } catch (Throwable t) { + System.err.println("Failed to format message for key: " + key); + t.printStackTrace(); + } + } else { + return format; + } + } + return StringUtils.EMPTY; + } + + /** + * Gets a string for the given key from this resource bundle. + * + * @param key the key for the desired string + * @return the string for the given key + */ + public static String getString(String key) { + if (res != null) { + try { + return res.getString(key); + } catch (Throwable t) { + System.err.println("Failed to get message for key: " + key); + t.printStackTrace(); + } + } + return StringUtils.EMPTY; + } + + static PropertyResourceBundle res; + + static { + InputStream is = Thread.currentThread() + .getContextClassLoader().getResourceAsStream("messages.properties"); + if (is != null) { + try { + res = new PropertyResourceBundle(is); + } catch (IOException e) { + System.err.println("Failed to initialize messages bundle"); + e.printStackTrace(); + } + } else { + System.err.println("Failed to find messages bundle"); + } + } +} diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/utils/AssetUtils.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/utils/AssetUtils.java new file mode 100644 index 00000000..8cbc9e6e --- /dev/null +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/utils/AssetUtils.java @@ -0,0 +1,49 @@ +package ml.comet.experiment.impl.utils; + +import lombok.experimental.UtilityClass; +import ml.comet.experiment.impl.asset.Asset; +import org.apache.commons.io.FilenameUtils; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.stream.Stream; + +/** + * Utilities to work with assets. + */ +@UtilityClass +public class AssetUtils { + + /** + * Walks through the asset files in the given folder and produce stream of {@link Asset} objects holding information + * about file assets found in the folder. + * + * @param folder the folder where to look for asset files + * @param logFilePath if {@code true} the file path relative to the folder will be used. + * Otherwise, basename of the asset file will be used. + * @param recursive if {@code true} then subfolder files will be included recursively. + * @param prefixWithFolderName if {@code true} then path of each asset file will be prefixed with folder name + * in case if {@code logFilePath} is {@code true}. + * @return the stream of {@link Asset} objects. + * @throws IOException if an I/O exception occurred. + */ + public static Stream walkFolderAssets(File folder, boolean logFilePath, + boolean recursive, boolean prefixWithFolderName) + throws IOException { + // list files in the directory and process each file as an asset + return FileUtils.listFiles(folder, recursive) + .map(path -> mapToFileAsset(folder, path, logFilePath, prefixWithFolderName)); + } + + static Asset mapToFileAsset(File folder, Path assetPath, + boolean logFilePath, boolean prefixWithFolderName) { + Asset asset = new Asset(); + asset.setFile(assetPath.toFile()); + String fileName = FileUtils.resolveAssetFileName(folder, assetPath, logFilePath, prefixWithFolderName); + asset.setFileName(fileName); + asset.setFileExtension(FilenameUtils.getExtension(fileName)); + return asset; + } + +} diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/utils/CometUtils.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/utils/CometUtils.java index d6f711e2..0822960b 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/impl/utils/CometUtils.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/utils/CometUtils.java @@ -1,8 +1,12 @@ package ml.comet.experiment.impl.utils; +import lombok.NonNull; import lombok.experimental.UtilityClass; +import org.apache.commons.lang3.StringUtils; import java.net.URI; +import java.util.Map; +import java.util.UUID; import static ml.comet.experiment.impl.utils.ResourceUtils.readCometSdkVersion; @@ -41,4 +45,28 @@ public String createExperimentLink(String baseUrl, String workspaceName, String URI uri = URI.create(url); return uri.toString(); } + + /** + * Generates global unique identifier in format supported by Comet.ml + * + * @return the global unique identifier in format supported by Comet.ml. + */ + @SuppressWarnings("checkstyle:AbbreviationAsWordInName") + public static String generateGUID() { + String guid = UUID.randomUUID().toString(); + return StringUtils.remove(guid, '-'); + } + + /** + * Puts provided value into the map as string if it is not {@code null}. + * + * @param map the container map. + * @param key the key to use. + * @param value the optional value. + */ + public static void putNotNull(@NonNull Map map, @NonNull T key, Object value) { + if (value != null) { + map.put(key, value.toString()); + } + } } diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/utils/DataUtils.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/utils/DataUtils.java new file mode 100644 index 00000000..d2bf7c8a --- /dev/null +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/utils/DataUtils.java @@ -0,0 +1,161 @@ +package ml.comet.experiment.impl.utils; + +import lombok.NonNull; +import lombok.experimental.UtilityClass; +import ml.comet.experiment.context.ExperimentContext; +import ml.comet.experiment.model.AddExperimentTagsRest; +import ml.comet.experiment.model.AddGraphRest; +import ml.comet.experiment.model.ExperimentTimeRequest; +import ml.comet.experiment.model.HtmlRest; +import ml.comet.experiment.model.LogOtherRest; +import ml.comet.experiment.model.MetricRest; +import ml.comet.experiment.model.OutputLine; +import ml.comet.experiment.model.OutputUpdate; +import ml.comet.experiment.model.ParameterRest; + +import java.util.Collections; + +/** + * The common factory methods to create initialized DTO instances. + */ +@UtilityClass +public class DataUtils { + /** + * The factory to create {@link MetricRest} instance. + * + * @param metricName the metric name + * @param metricValue the metric value + * @param context the current context + * @return the initialized {@link MetricRest} instance. + */ + public static MetricRest createLogMetricRequest( + @NonNull String metricName, @NonNull Object metricValue, @NonNull ExperimentContext context) { + MetricRest request = new MetricRest(); + request.setMetricName(metricName); + request.setMetricValue(metricValue.toString()); + request.setStep(context.getStep()); + request.setEpoch(context.getEpoch()); + request.setTimestamp(System.currentTimeMillis()); + request.setContext(context.getContext()); + return request; + } + + /** + * The factory to create {@link ParameterRest} instance. + * + * @param parameterName the name of the parameter + * @param paramValue the value of the parameter + * @param context the current context + * @return the initialized {@link ParameterRest} instance. + */ + public static ParameterRest createLogParamRequest( + @NonNull String parameterName, @NonNull Object paramValue, @NonNull ExperimentContext context) { + ParameterRest request = new ParameterRest(); + request.setParameterName(parameterName); + request.setParameterValue(paramValue.toString()); + request.setStep(context.getStep()); + request.setTimestamp(System.currentTimeMillis()); + request.setContext(context.getContext()); + return request; + } + + /** + * The factory to create {@link OutputUpdate} instance. + * + * @param line the log line + * @param offset the log line offset + * @param stderr the flag to indicate if it's from StdErr + * @param context the current context + * @return the initialized {@link OutputUpdate} instance. + */ + public static OutputUpdate createLogLineRequest(@NonNull String line, long offset, boolean stderr, String context) { + OutputLine outputLine = new OutputLine(); + outputLine.setOutput(line); + outputLine.setStderr(stderr); + outputLine.setLocalTimestamp(System.currentTimeMillis()); + outputLine.setOffset(offset); + + OutputUpdate outputUpdate = new OutputUpdate(); + outputUpdate.setRunContext(context); + outputUpdate.setOutputLines(Collections.singletonList(outputLine)); + return outputUpdate; + } + + /** + * The factory to create {@link HtmlRest} instance. + * + * @param html the HTML code to be logged. + * @param override the flag to indicate whether it should override already saved version. + * @return the initialized {@link HtmlRest} instance. + */ + public static HtmlRest createLogHtmlRequest(@NonNull String html, boolean override) { + HtmlRest request = new HtmlRest(); + request.setHtml(html); + request.setOverride(override); + request.setTimestamp(System.currentTimeMillis()); + return request; + } + + /** + * The factory to create {@link LogOtherRest} instance. + * + * @param key the parameter name/key. + * @param value the parameter value. + * @return the initialized {@link LogOtherRest} instance. + */ + public static LogOtherRest createLogOtherRequest(@NonNull String key, @NonNull Object value) { + LogOtherRest request = new LogOtherRest(); + request.setKey(key); + request.setValue(value.toString()); + request.setTimestamp(System.currentTimeMillis()); + return request; + } + + /** + * The factory to create {@link AddExperimentTagsRest} instance. + * + * @param tag the tag value + * @return the initialized {@link AddExperimentTagsRest} instance + */ + public static AddExperimentTagsRest createTagRequest(@NonNull String tag) { + AddExperimentTagsRest request = new AddExperimentTagsRest(); + request.setAddedTags(Collections.singletonList(tag)); + return request; + } + + /** + * The factory to create {@link AddGraphRest} instance. + * + * @param graph the NN graph representation. + * @return the initialized {@link AddGraphRest} instance. + */ + public static AddGraphRest createGraphRequest(@NonNull String graph) { + AddGraphRest request = new AddGraphRest(); + request.setGraph(graph); + return request; + } + + /** + * The factory to create {@link ExperimentTimeRequest} instance. + * + * @param startTimeMillis the experiment's start time in milliseconds. + * @return the initialized {@link ExperimentTimeRequest} instance. + */ + public static ExperimentTimeRequest createLogStartTimeRequest(long startTimeMillis) { + ExperimentTimeRequest request = new ExperimentTimeRequest(); + request.setStartTimeMillis(startTimeMillis); + return request; + } + + /** + * The factory to create {@link ExperimentTimeRequest} instance. + * + * @param endTimeMillis the experiment's end time in milliseconds. + * @return the initialized {@link ExperimentTimeRequest} instance. + */ + public static ExperimentTimeRequest createLogEndTimeRequest(long endTimeMillis) { + ExperimentTimeRequest request = new ExperimentTimeRequest(); + request.setEndTimeMillis(endTimeMillis); + return request; + } +} diff --git a/comet-java-client/src/main/java/ml/comet/experiment/impl/utils/FileUtils.java b/comet-java-client/src/main/java/ml/comet/experiment/impl/utils/FileUtils.java new file mode 100644 index 00000000..97334c26 --- /dev/null +++ b/comet-java-client/src/main/java/ml/comet/experiment/impl/utils/FileUtils.java @@ -0,0 +1,66 @@ +package ml.comet.experiment.impl.utils; + +import lombok.experimental.UtilityClass; + +import java.io.File; +import java.io.IOException; +import java.nio.file.DirectoryStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.stream.Stream; + +/** + * Provides common file system utilities. + */ +@UtilityClass +public class FileUtils { + + /** + * Lists files under given folder. + * + * @param folder the folder to list files in. + * @param recursive if {@code true} then subfolder files will be included recursively. + * @return the list of files under given directory. + * @throws IOException if an I/O exception occurs. + */ + public static Stream listFiles(File folder, boolean recursive) throws IOException { + ArrayList res; + if (recursive) { + try (Stream files = Files.walk(folder.toPath())) { + res = files.collect(ArrayList::new, (paths, path) -> { + if (!path.toFile().isDirectory()) { + paths.add(path); + } + }, ArrayList::addAll); + } + } else { + res = new ArrayList<>(); + try (DirectoryStream files = Files.newDirectoryStream(folder.toPath())) { + files.forEach(path -> { + if (!path.toFile().isDirectory()) { + res.add(path); + } + }); + } + } + return res.stream().sorted(Comparator.naturalOrder()); + } + + static String resolveAssetFileName(File folder, Path path, boolean logFilePath, + boolean prefixWithFolderName) { + if (logFilePath) { + // the path relative to the assets' folder root + Path filePath = folder.toPath().relativize(path); + + if (prefixWithFolderName) { + filePath = folder.toPath().getFileName().resolve(filePath); + } + return filePath.toString(); + } else { + // the asset's file name + return path.getFileName().toString(); + } + } +} diff --git a/comet-java-client/src/main/java/ml/comet/experiment/model/ExperimentAssetLink.java b/comet-java-client/src/main/java/ml/comet/experiment/model/ExperimentAssetLink.java index ced3650b..c8f9ba14 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/model/ExperimentAssetLink.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/model/ExperimentAssetLink.java @@ -16,9 +16,9 @@ @JsonIgnoreProperties(ignoreUnknown = true) public class ExperimentAssetLink { private String fileName; - private long fileSize; + private Long fileSize; private String runContext; - private Integer step; + private Long step; private boolean remote = false; private String link; private String compressedAssetLink; diff --git a/comet-java-client/src/main/java/ml/comet/experiment/model/LogDataResponse.java b/comet-java-client/src/main/java/ml/comet/experiment/model/LogDataResponse.java index a91ce7ec..3c378184 100644 --- a/comet-java-client/src/main/java/ml/comet/experiment/model/LogDataResponse.java +++ b/comet-java-client/src/main/java/ml/comet/experiment/model/LogDataResponse.java @@ -18,6 +18,7 @@ public class LogDataResponse { private String msg; private int code; private int sdkErrorCode; + private String data; public boolean hasFailed() { return code != 200 || sdkErrorCode != 0; diff --git a/comet-java-client/src/main/java/ml/comet/experiment/model/package-info.java b/comet-java-client/src/main/java/ml/comet/experiment/model/package-info.java new file mode 100644 index 00000000..c20bafb9 --- /dev/null +++ b/comet-java-client/src/main/java/ml/comet/experiment/model/package-info.java @@ -0,0 +1,4 @@ +/** + * Contains all classes used to model Comet REST API data transfer objects. + */ +package ml.comet.experiment.model; \ No newline at end of file diff --git a/comet-java-client/src/main/resources/messages.properties b/comet-java-client/src/main/resources/messages.properties new file mode 100644 index 00000000..e3433515 --- /dev/null +++ b/comet-java-client/src/main/resources/messages.properties @@ -0,0 +1,11 @@ +EXPERIMENT_LIVE=Experiment is live on comet.ml %s +EXPERIMENT_CLEANUP_PROMPT=Waiting for all scheduled uploads to complete. It can take up to %d seconds. +EXPERIMENT_HEARTBEAT_STOPPED_PROMPT=Experiment's heartbeat sender stopped +ASSETS_FOLDER_UPLOAD_COMPLETED=The asset folder '%s' has been uploaded. Processed %d asset files. + +LOG_ASSET_FOLDER_EMPTY=Directory %s is empty; no files were uploaded.\nPlease double-check the directory path and the recursive parameter. + +FAILED_READ_DATA_FOR_EXPERIMENT=Failed to read %s for the experiment, experiment key: %s +FAILED_TO_SEND_LOG_REQUEST=Failed to send log request: %s +FAILED_TO_LOG_ASSET_FOLDER=We failed to read directory '%s' for uploading.\nPlease double-check the file path, permissions, and that it is a directory. +FAILED_TO_LOG_SOME_ASSET_FROM_FOLDER=We failed to upload some asset from directory '%s'.\nPlease check previous logs for details about failed assets. \ No newline at end of file diff --git a/comet-java-client/src/test/java/ml/comet/experiment/impl/ApiExperimentImplTest.java b/comet-java-client/src/test/java/ml/comet/experiment/impl/ApiExperimentImplTest.java index fd663d01..6c17bd87 100644 --- a/comet-java-client/src/test/java/ml/comet/experiment/impl/ApiExperimentImplTest.java +++ b/comet-java-client/src/test/java/ml/comet/experiment/impl/ApiExperimentImplTest.java @@ -1,5 +1,6 @@ package ml.comet.experiment.impl; +import ml.comet.experiment.ApiExperiment; import ml.comet.experiment.OnlineExperiment; import ml.comet.experiment.exception.CometGeneralException; import org.junit.jupiter.api.Test; @@ -26,7 +27,7 @@ public void testApiExperimentInitialized() { String experimentKey = experiment.getExperimentKey(); experiment.end(); - ApiExperimentImpl apiExperiment = ApiExperimentImpl.builder(experimentKey) + ApiExperiment apiExperiment = ApiExperimentImpl.builder(experimentKey) .withApiKey(API_KEY) .build(); diff --git a/comet-java-client/src/test/java/ml/comet/experiment/impl/OnlineExperimentTest.java b/comet-java-client/src/test/java/ml/comet/experiment/impl/OnlineExperimentTest.java index 3b313b80..1980ea33 100644 --- a/comet-java-client/src/test/java/ml/comet/experiment/impl/OnlineExperimentTest.java +++ b/comet-java-client/src/test/java/ml/comet/experiment/impl/OnlineExperimentTest.java @@ -1,22 +1,31 @@ package ml.comet.experiment.impl; import io.reactivex.rxjava3.functions.Action; +import ml.comet.experiment.ApiExperiment; import ml.comet.experiment.Experiment; import ml.comet.experiment.OnlineExperiment; +import ml.comet.experiment.context.ExperimentContext; import ml.comet.experiment.impl.utils.TestUtils; -import ml.comet.experiment.model.GitMetadata; import ml.comet.experiment.model.ExperimentAssetLink; import ml.comet.experiment.model.ExperimentMetadataRest; +import ml.comet.experiment.model.GitMetadata; import ml.comet.experiment.model.GitMetadataRest; import ml.comet.experiment.model.ValueMinMaxDto; +import org.apache.commons.io.file.PathUtils; import org.apache.commons.lang3.StringUtils; import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.function.BiConsumer; import java.util.function.BooleanSupplier; @@ -24,9 +33,8 @@ import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; -import static ml.comet.experiment.impl.constants.AssetType.ASSET_TYPE_ALL; -import static ml.comet.experiment.impl.constants.AssetType.ASSET_TYPE_SOURCE_CODE; -import static ml.comet.experiment.impl.constants.AssetType.ASSET_TYPE_UNKNOWN; +import static ml.comet.experiment.impl.asset.AssetType.ASSET_TYPE_ALL; +import static ml.comet.experiment.impl.asset.AssetType.ASSET_TYPE_SOURCE_CODE; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; @@ -58,6 +66,71 @@ public class OnlineExperimentTest extends BaseApiTest { private static final String LOGGED_ERROR_LINE = "This error should also get to Comet ML."; private static final String NON_LOGGED_LINE = "This should not end up in Comet ML."; + private static Path root; + private static Path emptyFile; + private static List assetFolderFiles; + + @BeforeAll + static void setup() throws IOException { + assetFolderFiles = new ArrayList<>(); + // create temporary directory tree + root = Files.createTempDirectory("testFileUtils"); + assetFolderFiles.add( + PathUtils.copyFileToDirectory( + Objects.requireNonNull(TestUtils.getFile(SOME_TEXT_FILE_NAME)).toPath(), root)); + assetFolderFiles.add( + PathUtils.copyFileToDirectory( + Objects.requireNonNull(TestUtils.getFile(ANOTHER_TEXT_FILE_NAME)).toPath(), root)); + emptyFile = Files.createTempFile(root, "c_file", ".txt"); + assetFolderFiles.add(emptyFile); + + Path subDir = Files.createTempDirectory(root, "subDir"); + assetFolderFiles.add( + PathUtils.copyFileToDirectory( + Objects.requireNonNull(TestUtils.getFile(IMAGE_FILE_NAME)).toPath(), subDir)); + assetFolderFiles.add( + PathUtils.copyFileToDirectory( + Objects.requireNonNull(TestUtils.getFile(CODE_FILE_NAME)).toPath(), subDir)); + } + + @AfterAll + static void tearDown() throws IOException { + PathUtils.delete(root); + assertFalse(Files.exists(root), "Directory still exists"); + } + + @Test + public void testLogAndGetAssetsFolder() { + OnlineExperimentImpl experiment = (OnlineExperimentImpl) createOnlineExperiment(); + + // Make sure experiment has no assets + // + assertTrue(experiment.getAssetList(ASSET_TYPE_ALL).isEmpty()); + + // Log assets folder nd wait for completion + // + ExperimentContext context = new ExperimentContext(123, 1042, "train"); + OnCompleteAction onComplete = new OnCompleteAction(); + experiment.logAssetFolder(root.toFile(), false, true, false, context, onComplete); + + awaitForCondition(onComplete, "log assets' folder timeout", 60); + + // wait for assets become available and validate results + // + awaitForCondition(() -> + experiment.getAssetList(ASSET_TYPE_ALL).size() == assetFolderFiles.size(), "Assets was uploaded"); + + List assets = experiment.getAssetList(ASSET_TYPE_ALL); + + validateAsset(assets, SOME_TEXT_FILE_NAME, SOME_TEXT_FILE_SIZE, context); + validateAsset(assets, ANOTHER_TEXT_FILE_NAME, ANOTHER_TEXT_FILE_SIZE, context); + validateAsset(assets, emptyFile.getFileName().toString(), 0, context); + validateAsset(assets, IMAGE_FILE_NAME, IMAGE_FILE_SIZE, context); + validateAsset(assets, CODE_FILE_NAME, CODE_FILE_SIZE, context); + + experiment.end(); + } + @Test public void testExperimentCreatedAndShutDown() { OnlineExperiment experiment = createOnlineExperiment(); @@ -78,7 +151,7 @@ public void testExperimentCreatedAndShutDown() { experiment.end(); // use REST API to check experiment status - ApiExperimentImpl apiExperiment = ApiExperimentImpl.builder(experimentKey).build(); + ApiExperiment apiExperiment = ApiExperimentImpl.builder(experimentKey).build(); awaitForCondition(() -> !apiExperiment.getMetadata().isRunning(), "Experiment running status updated", 60); assertFalse(apiExperiment.getMetadata().isRunning(), "Experiment must have status not running"); @@ -96,7 +169,7 @@ public void testInitAndUpdateExistingExperiment() { // get previous experiment by key and check that update is working String experimentKey = experiment.getExperimentKey(); - OnlineExperiment updatedExperiment = fetchExperiment(experimentKey); + OnlineExperiment updatedExperiment = onlineExperiment(experimentKey); updatedExperiment.setExperimentName(SOME_NAME); awaitForCondition( @@ -131,7 +204,8 @@ public void testLogAndGetMetric() { testLogParameters(experiment, Experiment::getMetrics, (key, value) -> { OnCompleteAction onCompleteAction = new OnCompleteAction(); - ((BaseExperiment) experiment).logMetricAsync(key, value, 1, 1, onCompleteAction); + ((OnlineExperimentImpl) experiment).logMetric(key, value, + new ExperimentContext(1, 1), onCompleteAction); awaitForCondition(onCompleteAction, "logMetricAsync onComplete timeout"); }); @@ -144,7 +218,7 @@ public void testLogAndGetParameter() { testLogParameters(experiment, Experiment::getParameters, (key, value) -> { OnCompleteAction onCompleteAction = new OnCompleteAction(); - ((BaseExperiment) experiment).logParameterAsync(key, value, 1, onCompleteAction); + ((OnlineExperimentImpl) experiment).logParameter(key, value, new ExperimentContext(1), onCompleteAction); awaitForCondition(onCompleteAction, "logParameterAsync onComplete timeout"); }); @@ -169,7 +243,7 @@ public void testLogAndGetOther() { params.forEach((key, value) -> { OnCompleteAction onCompleteAction = new OnCompleteAction(); - ((BaseExperiment) experiment).logOtherAsync(key, value, onCompleteAction); + ((OnlineExperimentImpl) experiment).logOther(key, value, onCompleteAction); awaitForCondition(onCompleteAction, "logOtherAsync onComplete timeout"); }); @@ -185,14 +259,14 @@ public void testLogAndGetOther() { @Test public void testLogAndGetHtml() { - BaseExperiment experiment = (BaseExperiment) createOnlineExperiment(); + OnlineExperimentImpl experiment = (OnlineExperimentImpl) createOnlineExperiment(); assertFalse(experiment.getHtml().isPresent()); // Create first HTML record // OnCompleteAction onComplete = new OnCompleteAction(); - experiment.logHtmlAsync(SOME_HTML, true, onComplete); + experiment.logHtml(SOME_HTML, true, onComplete); // sleep to make sure the request was sent awaitForCondition(onComplete, "onComplete timeout"); @@ -205,7 +279,7 @@ public void testLogAndGetHtml() { // Override first HTML record // onComplete = new OnCompleteAction(); - experiment.logHtmlAsync(ANOTHER_HTML, true, onComplete); + experiment.logHtml(ANOTHER_HTML, true, onComplete); // sleep to make sure the request was sent awaitForCondition(onComplete, "onComplete timeout"); @@ -218,7 +292,7 @@ public void testLogAndGetHtml() { // Check that HTML record was not overridden but appended // onComplete = new OnCompleteAction(); - experiment.logHtmlAsync(SOME_HTML, false, onComplete); + experiment.logHtml(SOME_HTML, false, onComplete); // sleep to make sure the request was sent awaitForCondition(onComplete, "onComplete timeout"); @@ -233,7 +307,7 @@ public void testLogAndGetHtml() { @Test public void testAddAndGetTag() { - BaseExperiment experiment = (BaseExperiment) createOnlineExperiment(); + OnlineExperimentImpl experiment = (OnlineExperimentImpl) createOnlineExperiment(); // Check that experiment has no TAGs assertTrue(experiment.getTags().isEmpty()); @@ -241,11 +315,11 @@ public void testAddAndGetTag() { // Add TAGs and wait for response // OnCompleteAction onComplete = new OnCompleteAction(); - experiment.addTagAsync(SOME_TEXT, onComplete); + experiment.addTag(SOME_TEXT, onComplete); awaitForCondition(onComplete, "onComplete timeout"); onComplete = new OnCompleteAction(); - experiment.addTagAsync(ANOTHER_TAG, onComplete); + experiment.addTag(ANOTHER_TAG, onComplete); awaitForCondition(onComplete, "onComplete timeout"); // Get new TAGs and check @@ -261,7 +335,7 @@ public void testAddAndGetTag() { @Test public void testLogAndGetGraph() { - BaseExperiment experiment = (BaseExperiment) createOnlineExperiment(); + OnlineExperimentImpl experiment = (OnlineExperimentImpl) createOnlineExperiment(); // Check that experiment has no Graph // @@ -271,7 +345,7 @@ public void testLogAndGetGraph() { // Log Graph and wait for response // OnCompleteAction onComplete = new OnCompleteAction(); - experiment.logGraphAsync(SOME_GRAPH, onComplete); + experiment.logGraph(SOME_GRAPH, onComplete); awaitForCondition(onComplete, "onComplete timeout"); // Get graph and check result @@ -298,15 +372,15 @@ public void testLogAndGetExperimentTime() { // fetch existing experiment and update time // - BaseExperiment existingExperiment = (BaseExperiment) fetchExperiment(experimentKey); + OnlineExperimentImpl existingExperiment = (OnlineExperimentImpl) onlineExperiment(experimentKey); long now = System.currentTimeMillis(); OnCompleteAction onComplete = new OnCompleteAction(); - existingExperiment.logStartTimeAsync(now, onComplete); + existingExperiment.logStartTime(now, onComplete); awaitForCondition(onComplete, "logStartTime onComplete timeout", 120); onComplete = new OnCompleteAction(); - existingExperiment.logEndTimeAsync(now, onComplete); + existingExperiment.logEndTime(now, onComplete); awaitForCondition(onComplete, "logEndTime onComplete timeout", 120); // Get updated experiment metadata and check results @@ -323,25 +397,48 @@ public void testLogAndGetExperimentTime() { @Test public void testUploadAndGetAssets() { - OnlineExperiment experiment = createOnlineExperiment(); + OnlineExperimentImpl experiment = (OnlineExperimentImpl) createOnlineExperiment(); + // Make sure experiment has no assets + // assertTrue(experiment.getAssetList(ASSET_TYPE_ALL).isEmpty()); - experiment.uploadAsset(TestUtils.getFile(IMAGE_FILE_NAME), false); - experiment.uploadAsset(TestUtils.getFile(SOME_TEXT_FILE_NAME), false); + // Upload few assets and wait for completion + // + ExperimentContext context = new ExperimentContext(10, 101, "train"); + OnCompleteAction onComplete = new OnCompleteAction(); + experiment.uploadAsset(Objects.requireNonNull(TestUtils.getFile(IMAGE_FILE_NAME)), IMAGE_FILE_NAME, + false, context, onComplete); + awaitForCondition(onComplete, "image file onComplete timeout", 30); + + onComplete = new OnCompleteAction(); + experiment.uploadAsset(Objects.requireNonNull(TestUtils.getFile(SOME_TEXT_FILE_NAME)), SOME_TEXT_FILE_NAME, + false, context, onComplete); + awaitForCondition(onComplete, "text file onComplete timeout", 30); - awaitForCondition(() -> experiment.getAssetList(ASSET_TYPE_ALL).size() == 2, "Assets uploaded"); + // wait for assets become available and validate results + // + awaitForCondition(() -> experiment.getAssetList(ASSET_TYPE_ALL).size() == 2, "Assets was uploaded"); List assets = experiment.getAssetList(ASSET_TYPE_ALL); - validateAsset(assets, IMAGE_FILE_NAME, IMAGE_FILE_SIZE); - validateAsset(assets, SOME_TEXT_FILE_NAME, SOME_TEXT_FILE_SIZE); + validateAsset(assets, IMAGE_FILE_NAME, IMAGE_FILE_SIZE, context); + validateAsset(assets, SOME_TEXT_FILE_NAME, SOME_TEXT_FILE_SIZE, context); - experiment.uploadAsset(TestUtils.getFile(ANOTHER_TEXT_FILE_NAME), SOME_TEXT_FILE_NAME, true); + // update one of the assets and validate + // + onComplete = new OnCompleteAction(); + experiment.uploadAsset(Objects.requireNonNull(TestUtils.getFile(ANOTHER_TEXT_FILE_NAME)), + SOME_TEXT_FILE_NAME, true, context, onComplete); + awaitForCondition(onComplete, "update text file onComplete timeout", 30); awaitForCondition(() -> { - List textFiles = experiment.getAssetList(ASSET_TYPE_UNKNOWN); - ExperimentAssetLink file = textFiles.get(0); - return ANOTHER_TEXT_FILE_SIZE == file.getFileSize(); + List assetList = experiment.getAssetList(ASSET_TYPE_ALL); + return assetList.stream() + .filter(asset -> SOME_TEXT_FILE_NAME.equals(asset.getFileName())) + .anyMatch(asset -> + ANOTHER_TEXT_FILE_SIZE == asset.getFileSize() + && Objects.equals(asset.getStep(), context.getStep()) + && asset.getRunContext().equals(context.getContext())); }, "Asset was updated"); experiment.end(); @@ -382,7 +479,7 @@ public void testLogAndGetGitMetadata() { OnCompleteAction onComplete = new OnCompleteAction(); GitMetadata request = new GitMetadata(experiment.getExperimentKey(), "user", "root", "branch", "parent", "origin"); - ((BaseExperiment)experiment).logGitMetadataAsync(request, onComplete); + ((OnlineExperimentImpl) experiment).logGitMetadataAsync(request, onComplete); awaitForCondition(onComplete, "onComplete timeout"); // Get GIT metadata and check results @@ -436,23 +533,43 @@ public void testCopyStdout() throws IOException { @Test public void testLogAndGetFileCode() { - OnlineExperiment experiment = createOnlineExperiment(); + OnlineExperimentImpl experiment = (OnlineExperimentImpl) createOnlineExperiment(); + + // check that no code was logged + // assertTrue(experiment.getAssetList(ASSET_TYPE_ALL).isEmpty()); - experiment.logCode(TestUtils.getFile(CODE_FILE_NAME)); - awaitForCondition(() -> !experiment.getAssetList(ASSET_TYPE_SOURCE_CODE).isEmpty(), "Experiment code from file added"); + + // log code and check results + // + ExperimentContext context = new ExperimentContext(10, 101, "test"); + experiment.logCode(Objects.requireNonNull(TestUtils.getFile(CODE_FILE_NAME)), context); + + awaitForCondition(() -> !experiment.getAssetList(ASSET_TYPE_SOURCE_CODE).isEmpty(), + "Experiment code from file added"); List assets = experiment.getAssetList(ASSET_TYPE_SOURCE_CODE); - validateAsset(assets, CODE_FILE_NAME, CODE_FILE_SIZE); + validateAsset(assets, CODE_FILE_NAME, CODE_FILE_SIZE, context); + experiment.end(); } @Test public void testLogAndGetRawCode() { - OnlineExperiment experiment = createOnlineExperiment(); + OnlineExperimentImpl experiment = (OnlineExperimentImpl) createOnlineExperiment(); + + // check that no code was logged + // assertTrue(experiment.getAssetList(ASSET_TYPE_ALL).isEmpty()); - experiment.logCode(SOME_TEXT, CODE_FILE_NAME); - awaitForCondition(() -> !experiment.getAssetList(ASSET_TYPE_SOURCE_CODE).isEmpty(), "Experiment raw code added"); + + // log code and check results + // + ExperimentContext context = new ExperimentContext(10, 101, "test"); + experiment.logCode(SOME_TEXT, CODE_FILE_NAME, context); + + awaitForCondition(() -> !experiment.getAssetList(ASSET_TYPE_SOURCE_CODE).isEmpty(), + "Experiment raw code added"); List assets = experiment.getAssetList(ASSET_TYPE_SOURCE_CODE); - validateAsset(assets, CODE_FILE_NAME, SOME_TEXT_FILE_SIZE); + validateAsset(assets, CODE_FILE_NAME, SOME_TEXT_FILE_SIZE, context); + experiment.end(); } @@ -470,17 +587,20 @@ public boolean getAsBoolean() { } } - static OnlineExperiment fetchExperiment(String experimentKey) { + static OnlineExperiment onlineExperiment(String experimentKey) { return OnlineExperimentImpl.builder() .withApiKey(API_KEY) .withExistingExperimentKey(experimentKey) .build(); } - static void validateAsset(List assets, String expectedAssetName, long expectedSize) { + static void validateAsset(List assets, String expectedAssetName, + long expectedSize, ExperimentContext context) { assertTrue(assets.stream() .filter(asset -> expectedAssetName.equals(asset.getFileName())) - .anyMatch(asset -> expectedSize == asset.getFileSize())); + .anyMatch(asset -> expectedSize == asset.getFileSize() + && Objects.equals(context.getStep(), asset.getStep()) + && context.getContext().equals(asset.getRunContext()))); } static void testLogParameters(OnlineExperiment experiment, diff --git a/comet-java-client/src/test/java/ml/comet/experiment/impl/http/ConnectionTest.java b/comet-java-client/src/test/java/ml/comet/experiment/impl/http/ConnectionTest.java index 738b4a19..d031f006 100644 --- a/comet-java-client/src/test/java/ml/comet/experiment/impl/http/ConnectionTest.java +++ b/comet-java-client/src/test/java/ml/comet/experiment/impl/http/ConnectionTest.java @@ -30,6 +30,8 @@ import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; import static com.github.tomakehurst.wiremock.client.WireMock.urlPathEqualTo; import static com.github.tomakehurst.wiremock.client.WireMock.verify; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE; +import static io.netty.handler.codec.http.HttpHeaderValues.APPLICATION_JSON; import static ml.comet.experiment.impl.constants.QueryParamName.EXPERIMENT_KEY; import static ml.comet.experiment.impl.constants.QueryParamName.OVERWRITE; import static ml.comet.experiment.impl.http.Connection.COMET_SDK_API_HEADER; @@ -64,7 +66,7 @@ public void testSendGetWithRetries(@NonNull WireMockRuntimeInfo wmRuntimeInfo) { // stubFor(get(urlPathEqualTo(endpoint)) .withQueryParams(queryParams) - .willReturn(ok(responseStr).withHeader("Content-Type", ConnectionUtils.JSON_MIME_TYPE))); + .willReturn(ok(responseStr).withHeader(CONTENT_TYPE.toString(), APPLICATION_JSON.toString()))); // execute request and check results // @@ -97,7 +99,7 @@ public void testSendPostWithRetries(@NonNull WireMockRuntimeInfo wmRuntimeInfo) // create test HTTP stub // stubFor(post(urlPathEqualTo(endpoint)) - .willReturn(ok(responseStr).withHeader("Content-Type", ConnectionUtils.JSON_MIME_TYPE))); + .willReturn(ok(responseStr).withHeader(CONTENT_TYPE.toString(), APPLICATION_JSON.toString()))); // execute request and check results // @@ -128,7 +130,7 @@ public void testSendPostWithRetriesException(@NonNull WireMockRuntimeInfo wmRunt // create test HTTP stub // stubFor(post(urlPathEqualTo(endpoint)) - .willReturn(badRequest().withHeader("Content-Type", ConnectionUtils.JSON_MIME_TYPE))); + .willReturn(badRequest().withHeader(CONTENT_TYPE.toString(), APPLICATION_JSON.toString()))); // execute request and check results // @@ -152,7 +154,7 @@ public void testSendPostWithRetriesEmptyOptional(@NonNull WireMockRuntimeInfo wm // create test HTTP stub // stubFor(post(urlPathEqualTo(endpoint)) - .willReturn(badRequest().withHeader("Content-Type", ConnectionUtils.JSON_MIME_TYPE))); + .willReturn(badRequest().withHeader(CONTENT_TYPE.toString(), APPLICATION_JSON.toString()))); // execute request and check results // @@ -176,7 +178,7 @@ public void testSendPostAsync(@NonNull WireMockRuntimeInfo wmRuntimeInfo) { // create test HTTP stub // stubFor(post(urlPathEqualTo(endpoint)) - .willReturn(ok(responseStr).withHeader("Content-Type", ConnectionUtils.JSON_MIME_TYPE))); + .willReturn(ok(responseStr).withHeader(CONTENT_TYPE.toString(), APPLICATION_JSON.toString()))); // execute request and check results // @@ -209,7 +211,7 @@ public void testSendPostAsyncErrorStatus(@NonNull WireMockRuntimeInfo wmRuntimeI // create test HTTP stub // stubFor(post(urlPathEqualTo(endpoint)) - .willReturn(badRequest().withHeader("Content-Type", ConnectionUtils.JSON_MIME_TYPE))); + .willReturn(badRequest().withHeader(CONTENT_TYPE.toString(), APPLICATION_JSON.toString()))); // execute request and check results // diff --git a/comet-java-client/src/test/java/ml/comet/experiment/impl/http/ConnectionUtilsTest.java b/comet-java-client/src/test/java/ml/comet/experiment/impl/http/ConnectionUtilsTest.java index cd31d68f..2902f577 100644 --- a/comet-java-client/src/test/java/ml/comet/experiment/impl/http/ConnectionUtilsTest.java +++ b/comet-java-client/src/test/java/ml/comet/experiment/impl/http/ConnectionUtilsTest.java @@ -1,6 +1,7 @@ package ml.comet.experiment.impl.http; import ml.comet.experiment.impl.constants.ApiEndpoints; +import ml.comet.experiment.impl.constants.FormParamName; import ml.comet.experiment.impl.constants.QueryParamName; import ml.comet.experiment.impl.utils.JsonUtils; import ml.comet.experiment.impl.utils.TestUtils; @@ -8,6 +9,7 @@ import org.asynchttpclient.Request; import org.asynchttpclient.request.body.multipart.ByteArrayPart; import org.asynchttpclient.request.body.multipart.FilePart; +import org.asynchttpclient.request.body.multipart.StringPart; import org.asynchttpclient.util.HttpConstants; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -17,8 +19,16 @@ import java.net.URI; import java.util.HashMap; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE; +import static io.netty.handler.codec.http.HttpHeaderValues.APPLICATION_JSON; +import static io.netty.handler.codec.http.HttpHeaderValues.APPLICATION_OCTET_STREAM; +import static io.netty.handler.codec.http.HttpHeaderValues.MULTIPART_FORM_DATA; +import static io.netty.handler.codec.http.HttpHeaderValues.TEXT_PLAIN; +import static ml.comet.experiment.impl.constants.FormParamName.FILE; +import static ml.comet.experiment.impl.constants.FormParamName.METADATA; import static ml.comet.experiment.impl.constants.QueryParamName.EXPERIMENT_KEY; import static ml.comet.experiment.impl.constants.QueryParamName.OVERWRITE; +import static org.asynchttpclient.util.HttpConstants.Methods.POST; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -42,43 +52,69 @@ public void testCreateGetRequest() { @Test public void testCreatePostFileRequest() { + // Create test data + // String url = "http://test.com" + ApiEndpoints.ADD_ASSET; - HashMap params = new HashMap() {{ + HashMap queryParams = new HashMap() {{ put(EXPERIMENT_KEY, "someValue"); put(OVERWRITE, Boolean.toString(true)); }}; + HashMap formParams = new HashMap() {{ + put(METADATA, "some string"); + }}; File file = TestUtils.getFile(SOME_TEXT_FILE_NAME); assertNotNull(file, "test file not found"); - Request r = ConnectionUtils.createPostFileRequest(file, url, params); - this.validateRequest(r, url, params, HttpConstants.Methods.POST, ConnectionUtils.FORM_MIME_TYPE); - - // check body parts - assertEquals(1, r.getBodyParts().size(), "wrong number of body parts"); - FilePart part = (FilePart) r.getBodyParts().get(0); - assertEquals(ConnectionUtils.FILE, part.getName(), "wrong name"); - assertEquals(ConnectionUtils.FORM_MIME_TYPE, part.getContentType(), "wrong content type"); - assertEquals(file, part.getFile(), "wrong file"); + // Create request + // + Request r = ConnectionUtils.createPostFileRequest(file, url, queryParams, formParams); + this.validateRequest(r, url, queryParams, POST, MULTIPART_FORM_DATA.toString()); + + // Check body parts + // + assertEquals(2, r.getBodyParts().size(), "wrong number of body parts"); + // file part + FilePart filePart = (FilePart) r.getBodyParts().get(0); + assertEquals(FILE.paramName(), filePart.getName(), "wrong name"); + assertEquals(TEXT_PLAIN.toString(), filePart.getContentType(), "wrong content type"); + assertEquals(file, filePart.getFile(), "wrong file"); + // metadata part + StringPart stringPart = (StringPart) r.getBodyParts().get(1); + assertEquals(METADATA.paramName(), stringPart.getName(), "wrong name"); + assertEquals(formParams.get(METADATA), stringPart.getValue(), "wrong value"); } @Test public void testCreatePostByteArrayRequest() { + // Create test data + // String url = "http://test.com" + ApiEndpoints.ADD_ASSET; HashMap params = new HashMap() {{ put(EXPERIMENT_KEY, "someValue"); put(OVERWRITE, Boolean.toString(true)); }}; + HashMap formParams = new HashMap() {{ + put(METADATA, "some string"); + }}; byte[] data = "The test byte data".getBytes(); - Request r = ConnectionUtils.createPostByteArrayRequest(data, url, params); - this.validateRequest(r, url, params, HttpConstants.Methods.POST, ConnectionUtils.FORM_MIME_TYPE); + // Create request + // + Request r = ConnectionUtils.createPostByteArrayRequest(data, url, params, formParams); + this.validateRequest(r, url, params, POST, MULTIPART_FORM_DATA.toString()); - // check body parts - assertEquals(1, r.getBodyParts().size(), "wrong number of body parts"); + // Check body parts + // + assertEquals(2, r.getBodyParts().size(), "wrong number of body parts"); + // data part ByteArrayPart part = (ByteArrayPart) r.getBodyParts().get(0); - assertEquals(ConnectionUtils.FILE, part.getName(), "wrong name"); - assertEquals(ConnectionUtils.FORM_MIME_TYPE, part.getContentType(), "wrong content type"); + assertEquals(FILE.paramName(), part.getName(), "wrong name"); + assertEquals(APPLICATION_OCTET_STREAM.toString(), part.getContentType(), "wrong content type"); assertEquals(data, part.getBytes(), "wrong data array"); + // metadata part + StringPart stringPart = (StringPart) r.getBodyParts().get(1); + assertEquals(METADATA.paramName(), stringPart.getName(), "wrong name"); + assertEquals(formParams.get(METADATA), stringPart.getValue(), "wrong value"); } @Test @@ -89,7 +125,7 @@ public void testCreatePostJsonRequest() { String json = JsonUtils.toJson(html); Request r = ConnectionUtils.createPostJsonRequest(json, url); - this.validateRequest(r, url, null, HttpConstants.Methods.POST, ConnectionUtils.JSON_MIME_TYPE); + this.validateRequest(r, url, null, POST, APPLICATION_JSON.toString()); assertEquals(json.length(), r.getBodyGenerator().createBody().getContentLength(), "wrong body"); } @@ -119,7 +155,7 @@ private void validateRequest(Request r, String url, HashMap allFolderFiles; + private static List subFolderFiles; + + + private static final String someFileExtension = "txt"; + + @BeforeAll + static void setup() throws IOException { + allFolderFiles = new ArrayList<>(); + subFolderFiles = new ArrayList<>(); + // create temporary directory tree + root = Files.createTempDirectory("testAssetUtils"); + allFolderFiles.add( + Files.createTempFile(root, "a_file", "." + someFileExtension)); + allFolderFiles.add( + Files.createTempFile(root, "b_file", "." + someFileExtension)); + allFolderFiles.add( + Files.createTempFile(root, "c_file", "." + someFileExtension)); + + subDir = Files.createTempDirectory(root, "subDir"); + subFolderFiles.add( + Files.createTempFile(subDir, "d_file", "." + someFileExtension)); + subFolderFiles.add( + Files.createTempFile(subDir, "e_file", "." + someFileExtension)); + allFolderFiles.addAll(subFolderFiles); + } + + @AfterAll + static void teardown() throws IOException { + PathUtils.delete(root); + assertFalse(Files.exists(root), "Directory still exists"); + } + + @Test + public void testMapToFileAsset() { + Path file = subFolderFiles.get(0); + Asset asset = AssetUtils.mapToFileAsset( + root.toFile(), file, false, false); + assertNotNull(asset, "asset expected"); + assertEquals(asset.getFile(), file.toFile(), "wrong asset file"); + assertEquals(asset.getFileName(), file.getFileName().toString(), "wrong asset file name"); + assertEquals(someFileExtension, asset.getFileExtension(), "wrong file extension"); + } + + @ParameterizedTest(name = "[{index}] logFilePath: {0}, recursive: {1}, prefixWithFolderName: {2}") + @MethodSource("walkFolderAssetsModifiers") + void testWalkFolderAssets(boolean logFilePath, boolean recursive, boolean prefixWithFolderName) throws IOException { + // tests that correct number of assets returned + // + int expected = allFolderFiles.size(); + if (!recursive) { + expected -= subFolderFiles.size(); + } + Stream assets = AssetUtils.walkFolderAssets( + root.toFile(), logFilePath, recursive, prefixWithFolderName); + assertEquals(expected, assets.count(), "wrong assets count"); + + // tests that assets has been properly populated + // + assets = AssetUtils.walkFolderAssets(root.toFile(), logFilePath, recursive, prefixWithFolderName); + assertTrue( + assets.allMatch( + asset -> Objects.equals(asset.getFileExtension(), someFileExtension) + && StringUtils.isNotBlank(asset.getFileName()) + && asset.getFile() != null), + "wrong asset data"); + + // tests that correct file names recorded + // + assets = AssetUtils.walkFolderAssets(root.toFile(), logFilePath, recursive, prefixWithFolderName); + assets.forEach(asset -> checkAssetFilename(asset, logFilePath, recursive, prefixWithFolderName)); + } + + void checkAssetFilename(Asset asset, boolean logFilePath, boolean recursive, boolean prefixWithFolderName) { + String name = asset.getFileName(); + if (logFilePath && prefixWithFolderName) { + assertTrue(name.startsWith(root.getFileName().toString()), "must have folder name prefix"); + } + if (subFolderFiles.contains(asset.getFile().toPath()) && logFilePath) { + assertTrue(name.contains(subDir.getFileName().toString()), "must include relative file path"); + } + + assertTrue(allFolderFiles.contains(asset.getFile().toPath()), "must be in all files list"); + if (!recursive) { + assertFalse(subFolderFiles.contains(asset.getFile().toPath()), "must be only in top folder files list"); + } + } + + static Stream walkFolderAssetsModifiers() { + // create matrix of all possible combinations: 2^3 = 8 + // the order: logFilePath, recursive, prefixWithFolderName + return Stream.of( + arguments(false, false, false), + arguments(true, false, false), + arguments(true, true, false), + arguments(true, true, true), + arguments(false, false, true), + arguments(false, true, true), + arguments(true, false, true), + arguments(false, true, false) + ); + } +} diff --git a/comet-java-client/src/test/java/ml/comet/experiment/impl/utils/CometUtilsTest.java b/comet-java-client/src/test/java/ml/comet/experiment/impl/utils/CometUtilsTest.java new file mode 100644 index 00000000..fa709f2d --- /dev/null +++ b/comet-java-client/src/test/java/ml/comet/experiment/impl/utils/CometUtilsTest.java @@ -0,0 +1,26 @@ +package ml.comet.experiment.impl.utils; + +import org.apache.commons.lang3.StringUtils; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class CometUtilsTest { + + @Test + public void testJavaSdkVersionParsing() { + // Tests that Comet Java SDK version was set + assertNotNull(CometUtils.COMET_JAVA_SDK_VERSION); + assertTrue(StringUtils.isNotBlank(CometUtils.COMET_JAVA_SDK_VERSION)); + } + + @Test + public void testGenerateGUID() { + String guid = CometUtils.generateGUID(); + assertTrue(StringUtils.isNotBlank(guid), "GUID expected"); + assertEquals(32, guid.length(), "wrong length"); + } +} diff --git a/comet-java-client/src/test/java/ml/comet/experiment/impl/utils/FileUtilsTest.java b/comet-java-client/src/test/java/ml/comet/experiment/impl/utils/FileUtilsTest.java new file mode 100644 index 00000000..bf01a49b --- /dev/null +++ b/comet-java-client/src/test/java/ml/comet/experiment/impl/utils/FileUtilsTest.java @@ -0,0 +1,104 @@ +package ml.comet.experiment.impl.utils; + +import org.apache.commons.io.file.PathUtils; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class FileUtilsTest { + private static Path root; + private static Path subFolderFile; + private static Path subDir; + private static List allFiles; + private static List topFiles; + + @BeforeAll + static void setup() throws IOException { + allFiles = new ArrayList<>(); + topFiles = new ArrayList<>(); + // create temporary directory tree + root = Files.createTempDirectory("testFileUtils"); + topFiles.add( + Files.createTempFile(root, "a_file", ".txt")); + topFiles.add( + Files.createTempFile(root, "b_file", ".txt")); + topFiles.add( + Files.createTempFile(root, "c_file", ".txt")); + + allFiles.addAll(topFiles); + + subDir = Files.createTempDirectory(root, "subDir"); + subFolderFile = Files.createTempFile(subDir, "d_file", ".txt"); + allFiles.add(subFolderFile); + allFiles.add( + Files.createTempFile(subDir, "e_file", ".txt")); + } + + @AfterAll + static void teardown() throws IOException { + PathUtils.delete(root); + assertFalse(Files.exists(root), "Directory still exists"); + } + + @Test + public void testListFilesPlain() throws IOException { + Stream files = FileUtils.listFiles(root.toFile(), false); + assertEquals(topFiles.size(), files.count()); + + files = FileUtils.listFiles(root.toFile(), false); + assertTrue(files.peek(System.out::println) + .allMatch(path -> topFiles.contains(path))); + } + + @Test + public void testListFilesRecursive() throws IOException { + Stream files = FileUtils.listFiles(root.toFile(), true); + assertEquals(allFiles.size(), files.count()); + + files = FileUtils.listFiles(root.toFile(), true); + assertTrue(files.peek(System.out::println) + .allMatch(path -> allFiles.contains(path))); + } + + @Test + public void testResolveAssetFileNameSimple() { + // test only file name + String expected = subFolderFile.getFileName().toString(); + String name = FileUtils.resolveAssetFileName(root.toFile(), subFolderFile, false, false); + System.out.println(name); + assertEquals(expected, name, "wrong simple file name"); + } + + @Test + public void testResolveAssetFileNameRelative() { + // test relative path + String expected = subDir.getFileName().resolve( + subFolderFile.getFileName()).toString(); + String name = FileUtils.resolveAssetFileName(root.toFile(), subFolderFile, true, false); + System.out.println(name); + assertEquals(expected, name, "wrong relative file name"); + } + + @Test + public void testResolveAssetFileNameWithPrefix() { + // test absolute path + String expected = root.getFileName().resolve( + subDir.getFileName()).resolve( + subFolderFile.getFileName()) + .toString(); + String name = FileUtils.resolveAssetFileName(root.toFile(), subFolderFile, true, true); + System.out.println(name); + assertEquals(expected, name, "wrong absolute file name"); + } +} diff --git a/comet-java-client/src/test/java/ml/comet/experiment/impl/utils/TestCometUtils.java b/comet-java-client/src/test/java/ml/comet/experiment/impl/utils/TestCometUtils.java deleted file mode 100644 index a80c23ad..00000000 --- a/comet-java-client/src/test/java/ml/comet/experiment/impl/utils/TestCometUtils.java +++ /dev/null @@ -1,17 +0,0 @@ -package ml.comet.experiment.impl.utils; - -import org.apache.commons.lang3.StringUtils; -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; - -public class TestCometUtils { - - @Test - public void testJavaSdkVersionParsing() { - // Tests that Comet Java SDK version was set - assertNotNull(CometUtils.COMET_JAVA_SDK_VERSION); - assertFalse(StringUtils.isEmpty(CometUtils.COMET_JAVA_SDK_VERSION)); - } -}