Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Bogan <[email protected]>
  • Loading branch information
ryanbogan committed Feb 27, 2024
1 parent d0d82f0 commit ec35d3d
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ public void onFailure(Exception e) {
return modelIds;
}

private void updateModelStateAsFailed(String modelId, ModelMetadata modelMetadata, String msg) throws IOException, ExecutionException, InterruptedException {
private void updateModelStateAsFailed(String modelId, ModelMetadata modelMetadata, String msg) throws IOException, ExecutionException,
InterruptedException {
modelMetadata.setState(ModelState.FAILED);
modelMetadata.setError(msg);
Model model = new Model(modelMetadata, null, modelId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,11 @@ private void train(TrainingJob trainingJob) {
private void serializeModel(TrainingJob trainingJob, ActionListener<IndexResponse> listener, boolean update) throws IOException,
ExecutionException, InterruptedException {
if (update) {
ModelMetadata modelMetadata = modelDao.getMetadata(trainingJob.getModelId());
if (modelMetadata.getState().equals(ModelState.TRAINING)) {
Model model = modelDao.get(trainingJob.getModelId());
if (model.getModelMetadata().getState().equals(ModelState.TRAINING)) {
modelDao.update(trainingJob.getModel(), listener);
} else {
logger.info("Model state is {}. Skipping serialization of trained data", modelMetadata.getState());
logger.info("Model state is {}. Skipping serialization of trained data", model.getModelMetadata().getState());
}
} else {
modelDao.put(trainingJob.getModel(), listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public void testUpdateModelsNewCluster() throws IOException, InterruptedExceptio
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.isCreated()).thenReturn(true);
when(modelDao.get(modelId)).thenReturn(model);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
doAnswer(invocationOnMock -> {
SearchResponse searchResponse = mock(SearchResponse.class);
SearchHits searchHits = mock(SearchHits.class);
Expand Down Expand Up @@ -144,6 +145,7 @@ public void testUpdateModelsNodesRemoved() throws IOException, InterruptedExcept
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.isCreated()).thenReturn(true);
when(modelDao.get(modelId)).thenReturn(model);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
DiscoveryNode node1 = mock(DiscoveryNode.class);
when(node1.getEphemeralId()).thenReturn("test-node-model-match");
DiscoveryNode node2 = mock(DiscoveryNode.class);
Expand Down

0 comments on commit ec35d3d

Please sign in to comment.