Skip to content

Commit

Permalink
Tests round 1
Browse files Browse the repository at this point in the history
  • Loading branch information
afoucret committed Apr 4, 2024
1 parent 16024b3 commit 2267ee2
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ public class TrainedModelCacheMetadata implements Metadata.Custom {
private static final ConstructingObjectParser<TrainedModelCacheMetadata, Void> PARSER = new ConstructingObjectParser<>(
NAME,
true,
args -> new TrainedModelCacheMetadata((Map<String, TrainedModelCustomMetadataEntry>) args[0])
args -> new TrainedModelCacheMetadata((Map<String, TrainedModelCacheMetadataEntry>) args[0])
);

static {
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> {
Map<String, TrainedModelCustomMetadataEntry> entries = new HashMap<>();
Map<String, TrainedModelCacheMetadataEntry> entries = new HashMap<>();
while (p.nextToken() != XContentParser.Token.END_OBJECT) {
String modelId = p.currentName();
entries.put(modelId, TrainedModelCustomMetadataEntry.fromXContent(p));
entries.put(modelId, TrainedModelCacheMetadataEntry.fromXContent(p));
}
return entries;
}, ENTRIES);
Expand All @@ -80,8 +80,8 @@ public static Set<String> getUpdatedModelIds(ClusterChangedEvent event) {
return Collections.emptySet();
}

Map<String, TrainedModelCustomMetadataEntry> oldCacheMetadataEntries = TrainedModelCacheMetadata.fromState(event.previousState()).entries();
Map<String, TrainedModelCustomMetadataEntry> newCacheMetadataEntries = TrainedModelCacheMetadata.fromState(event.state()).entries();
Map<String, TrainedModelCacheMetadataEntry> oldCacheMetadataEntries = TrainedModelCacheMetadata.fromState(event.previousState()).entries();
Map<String, TrainedModelCacheMetadataEntry> newCacheMetadataEntries = TrainedModelCacheMetadata.fromState(event.state()).entries();

return Sets.union(oldCacheMetadataEntries.keySet(), newCacheMetadataEntries.keySet()).stream()
.filter(modelId -> {
Expand All @@ -94,17 +94,17 @@ public static Set<String> getUpdatedModelIds(ClusterChangedEvent event) {
.collect(Collectors.toSet());
}

private final Map<String, TrainedModelCustomMetadataEntry> entries;
private final Map<String, TrainedModelCacheMetadataEntry> entries;

public TrainedModelCacheMetadata(Map<String, TrainedModelCustomMetadataEntry> entries) {
public TrainedModelCacheMetadata(Map<String, TrainedModelCacheMetadataEntry> entries) {
this.entries = entries;
}

public TrainedModelCacheMetadata(StreamInput in) throws IOException {
this.entries = in.readImmutableMap(TrainedModelCustomMetadataEntry::new);
this.entries = in.readImmutableMap(TrainedModelCacheMetadataEntry::new);
}

public Map<String, TrainedModelCustomMetadataEntry> entries() {
public Map<String, TrainedModelCacheMetadataEntry> entries() {
return entries;
}

Expand Down Expand Up @@ -153,7 +153,7 @@ public int hashCode() {
}

public static class TrainedModelCacheMetadataDiff implements NamedDiff<Metadata.Custom> {
final Diff<Map<String, TrainedModelCustomMetadataEntry>> entriesDiff;
final Diff<Map<String, TrainedModelCacheMetadataEntry>> entriesDiff;

TrainedModelCacheMetadataDiff(TrainedModelCacheMetadata before, TrainedModelCacheMetadata after) {
this.entriesDiff = DiffableUtils.diff(before.entries, after.entries, DiffableUtils.getStringKeySerializer());
Expand All @@ -163,8 +163,8 @@ public static class TrainedModelCacheMetadataDiff implements NamedDiff<Metadata.
this.entriesDiff = DiffableUtils.readJdkMapDiff(
in,
DiffableUtils.getStringKeySerializer(),
TrainedModelCustomMetadataEntry::new,
TrainedModelCustomMetadataEntry::readDiffFrom
TrainedModelCacheMetadataEntry::new,
TrainedModelCacheMetadataEntry::readDiffFrom
);
}

Expand All @@ -189,31 +189,31 @@ public TransportVersion getMinimalSupportedVersion() {
return TransportVersion.current();
}
}
public static class TrainedModelCustomMetadataEntry implements SimpleDiffable<TrainedModelCustomMetadataEntry>, ToXContentObject {
private static final ConstructingObjectParser<TrainedModelCustomMetadataEntry, Void> PARSER = new ConstructingObjectParser<>(
public static class TrainedModelCacheMetadataEntry implements SimpleDiffable<TrainedModelCacheMetadataEntry>, ToXContentObject {
private static final ConstructingObjectParser<TrainedModelCacheMetadataEntry, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_cache_metadata_entry",
true,
args -> new TrainedModelCustomMetadataEntry((String) args[0])
args -> new TrainedModelCacheMetadataEntry((String) args[0])
);
static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
}

private static Diff<TrainedModelCustomMetadataEntry> readDiffFrom(StreamInput in) throws IOException {
return SimpleDiffable.readDiffFrom(TrainedModelCustomMetadataEntry::new, in);
private static Diff<TrainedModelCacheMetadataEntry> readDiffFrom(StreamInput in) throws IOException {
return SimpleDiffable.readDiffFrom(TrainedModelCacheMetadataEntry::new, in);
}

private static TrainedModelCustomMetadataEntry fromXContent(XContentParser parser) {
private static TrainedModelCacheMetadataEntry fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final String modelId;

public TrainedModelCustomMetadataEntry(String modelId) {
public TrainedModelCacheMetadataEntry(String modelId) {
this.modelId = modelId;
}

TrainedModelCustomMetadataEntry(StreamInput in) throws IOException {
TrainedModelCacheMetadataEntry(StreamInput in) throws IOException {
this.modelId = in.readString();
}

Expand All @@ -238,7 +238,7 @@ public void writeTo(StreamOutput out) throws IOException {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelCustomMetadataEntry that = (TrainedModelCustomMetadataEntry) o;
TrainedModelCacheMetadataEntry that = (TrainedModelCacheMetadataEntry) o;
return Objects.equals(modelId, that.modelId);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractChunkedSerializingTestCase;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelCacheMetadata.TrainedModelCacheMetadataEntry;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class TrainedModelCacheMetadataTests extends AbstractChunkedSerializingTestCase<TrainedModelCacheMetadata> {
public static TrainedModelCacheMetadataEntry randomEntry() {
return new TrainedModelCacheMetadataEntry(randomIdentifier());
}

public static TrainedModelCacheMetadata randomInstance() {
Map<String, TrainedModelCacheMetadataEntry> entries = Stream.generate(TrainedModelCacheMetadataTests::randomEntry)
.limit(randomInt(5))
.collect(Collectors.toMap(TrainedModelCacheMetadataEntry::getModelId, Function.identity(), (k, k1) -> k, HashMap::new));

return new TrainedModelCacheMetadata(entries);
}

@Override
protected TrainedModelCacheMetadata doParseInstance(XContentParser parser) throws IOException {
return TrainedModelCacheMetadata.fromXContent(parser);
}

@Override
protected Writeable.Reader<TrainedModelCacheMetadata> instanceReader() {
return TrainedModelCacheMetadata::new;
}

@Override
protected TrainedModelCacheMetadata createTestInstance() {
return randomInstance();
}

@Override
protected TrainedModelCacheMetadata mutateInstance(TrainedModelCacheMetadata instance) {
return randomValueOtherThan(instance, () -> randomInstance());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class TrainedModelCacheMetadataTests extends AbstractBWCSerializationTestCase<TrainedModelMetadata> {
public class TrainedModelMetadataTests extends AbstractBWCSerializationTestCase<TrainedModelMetadata> {

private boolean lenient;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ private static class PutModelCacheMetadataTask extends ModelCacheMetadataManagem

protected TrainedModelCacheMetadata execute(TrainedModelCacheMetadata currentCacheMetadata, TaskContext<ModelCacheMetadataManagementTask> taskContext) {
var entries = new HashMap<>(currentCacheMetadata.entries());
entries.put(modelId, new TrainedModelCacheMetadata.TrainedModelCustomMetadataEntry(modelId));
entries.put(modelId, new TrainedModelCacheMetadata.TrainedModelCacheMetadataEntry(modelId));
taskContext.success(() -> listener.onResponse(AcknowledgedResponse.TRUE));
return new TrainedModelCacheMetadata(entries);
}
Expand Down
Loading

0 comments on commit 2267ee2

Please sign in to comment.