diff --git a/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java b/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java index 2f6d555d1..a88bbc445 100644 --- a/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java +++ b/src/test/java/org/opensearch/ad/transport/ADTaskProfileResponseTests.java @@ -21,6 +21,7 @@ import org.opensearch.ad.model.ADTaskProfile; import org.opensearch.cluster.ClusterName; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import com.google.common.collect.ImmutableList; @@ -53,4 +54,19 @@ public void testSerializeResponse() throws IOException { assertEquals(1, response2.getNodes().size()); assertEquals(taskId, response2.getNodes().get(0).getAdTaskProfile().getTaskId()); } + + public void testFromActionResponse() throws IOException { + String taskId = randomAlphaOfLength(5); + ADTaskProfile adTaskProfile = new ADTaskProfile(); + adTaskProfile.setTaskId(taskId); + Version remoteAdVersion = Version.CURRENT; + ADTaskProfileNodeResponse nodeResponse = new ADTaskProfileNodeResponse(randomDiscoveryNode(), adTaskProfile, remoteAdVersion); + + List nodeResponses = ImmutableList.of(nodeResponse); + ADTaskProfileResponse response = new ADTaskProfileResponse(new ClusterName("test"), nodeResponses, ImmutableList.of()); + + ADTaskProfileResponse reserializedResponse = ADTaskProfileResponse.fromActionResponse((ActionResponse) response); + assertEquals(1, reserializedResponse.getNodes().size()); + assertEquals(taskId, reserializedResponse.getNodes().get(0).getAdTaskProfile().getTaskId()); + } }