diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java
index 4f9dfdc033..f2d8bdc2c5 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImpl.java
@@ -51,15 +51,16 @@ public CreateAsyncQueryResponse createAsyncQuery(
             sparkExecutionEngineConfig.getSparkSubmitParameters(),
             createAsyncQueryRequest.getSessionId()));
     asyncQueryJobMetadataStorageService.storeJobMetadata(
-        new AsyncQueryJobMetadata(
-            dispatchQueryResponse.getQueryId(),
-            sparkExecutionEngineConfig.getApplicationId(),
-            dispatchQueryResponse.getJobId(),
-            dispatchQueryResponse.getResultIndex(),
-            dispatchQueryResponse.getSessionId(),
-            dispatchQueryResponse.getDatasourceName(),
-            dispatchQueryResponse.getJobType(),
-            dispatchQueryResponse.getIndexName()));
+        AsyncQueryJobMetadata.builder()
+            .queryId(dispatchQueryResponse.getQueryId())
+            .applicationId(sparkExecutionEngineConfig.getApplicationId())
+            .jobId(dispatchQueryResponse.getJobId())
+            .resultIndex(dispatchQueryResponse.getResultIndex())
+            .sessionId(dispatchQueryResponse.getSessionId())
+            .datasourceName(dispatchQueryResponse.getDatasourceName())
+            .jobType(dispatchQueryResponse.getJobType())
+            .indexName(dispatchQueryResponse.getIndexName())
+            .build());
     return new CreateAsyncQueryResponse(
         dispatchQueryResponse.getQueryId().getId(), dispatchQueryResponse.getSessionId());
   }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java
index cef3b6ede2..2ac67b96ba 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java
@@ -7,8 +7,6 @@
 
 package org.opensearch.sql.spark.asyncquery;
 
-import static org.opensearch.sql.spark.execution.statestore.StateStore.createJobMetaData;
-
 import java.util.Optional;
 import lombok.RequiredArgsConstructor;
 import org.apache.logging.log4j.LogManager;
@@ -16,7 +14,9 @@
 import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
+import org.opensearch.sql.spark.execution.statestore.OpenSearchStateStoreUtil;
 import org.opensearch.sql.spark.execution.statestore.StateStore;
+import org.opensearch.sql.spark.execution.xcontent.AsyncQueryJobMetadataXContentSerializer;
 
 /** Opensearch implementation of {@link AsyncQueryJobMetadataStorageService} */
 @RequiredArgsConstructor
@@ -24,6 +24,7 @@ public class OpensearchAsyncQueryJobMetadataStorageService
     implements AsyncQueryJobMetadataStorageService {
 
   private final StateStore stateStore;
+  private final AsyncQueryJobMetadataXContentSerializer asyncQueryJobMetadataXContentSerializer;
 
   private static final Logger LOGGER =
       LogManager.getLogger(OpensearchAsyncQueryJobMetadataStorageService.class);
@@ -31,15 +32,20 @@ public class OpensearchAsyncQueryJobMetadataStorageService
   @Override
   public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) {
     AsyncQueryId queryId = asyncQueryJobMetadata.getQueryId();
-    createJobMetaData(stateStore, queryId.getDataSourceName()).apply(asyncQueryJobMetadata);
+    stateStore.create(
+        asyncQueryJobMetadata,
+        AsyncQueryJobMetadata::copy,
+        OpenSearchStateStoreUtil.getIndexName(queryId.getDataSourceName()));
   }
 
   @Override
   public Optional<AsyncQueryJobMetadata> getJobMetadata(String qid) {
     try {
       AsyncQueryId queryId = new AsyncQueryId(qid);
-      return StateStore.getJobMetaData(stateStore, queryId.getDataSourceName())
-          .apply(queryId.docId());
+      return stateStore.get(
+          queryId.docId(),
+          asyncQueryJobMetadataXContentSerializer::fromXContent,
+          OpenSearchStateStoreUtil.getIndexName(queryId.getDataSourceName()));
     } catch (Exception e) {
       LOGGER.error("Error while fetching the job metadata.", e);
       throw new AsyncQueryNotFoundException(String.format("Invalid QueryId: %s", qid));
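For orientation, a minimal usage sketch of the refactored storage service: the serializer is now injected alongside the `StateStore`, writes go through the generic `stateStore.create(...)`, and reads through `stateStore.get(...)`. Identifiers mirror the patch; the datasource name and ids are illustrative.

```java
StateStore stateStore = new StateStore(client, clusterService);
OpensearchAsyncQueryJobMetadataStorageService service =
    new OpensearchAsyncQueryJobMetadataStorageService(
        stateStore, new AsyncQueryJobMetadataXContentSerializer());

// "my_glue", "application-id" and "job-id" are illustrative values.
AsyncQueryJobMetadata metadata =
    AsyncQueryJobMetadata.builder()
        .queryId(AsyncQueryId.newAsyncQueryId("my_glue"))
        .applicationId("application-id")
        .jobId("job-id")
        .build();
service.storeJobMetadata(metadata);
Optional<AsyncQueryJobMetadata> fetched =
    service.getJobMetadata(metadata.getQueryId().getId());
```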
diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java
index bef8218b15..08770c7588 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/model/AsyncQueryJobMetadata.java
@@ -7,15 +7,18 @@
 
 package org.opensearch.sql.spark.asyncquery.model;
 
+import com.google.common.collect.ImmutableMap;
 import com.google.gson.Gson;
+import lombok.Builder.Default;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
-import org.opensearch.index.seqno.SequenceNumbers;
+import lombok.experimental.SuperBuilder;
 import org.opensearch.sql.spark.dispatcher.model.JobType;
 import org.opensearch.sql.spark.execution.statestore.StateModel;
 
 /** This class models all the metadata required for a job. */
 @Data
+@SuperBuilder
 @EqualsAndHashCode(callSuper = false)
 public class AsyncQueryJobMetadata extends StateModel {
   private final AsyncQueryId queryId;
@@ -27,94 +30,12 @@ public class AsyncQueryJobMetadata extends StateModel {
   // since 2.13
   // jobType could be null before OpenSearch 2.12. SparkQueryDispatcher use jobType to choose
   // cancel query handler. if jobType is null, it will invoke BatchQueryHandler.cancel().
-  private final JobType jobType;
+  @Default private final JobType jobType = JobType.INTERACTIVE;
   // null if JobType is null
   private final String datasourceName;
   // null if JobType is INTERACTIVE or null
   private final String indexName;
 
-  @EqualsAndHashCode.Exclude private final long seqNo;
-  @EqualsAndHashCode.Exclude private final long primaryTerm;
-
-  public AsyncQueryJobMetadata(
-      AsyncQueryId queryId, String applicationId, String jobId, String resultIndex) {
-    this(
-        queryId,
-        applicationId,
-        jobId,
-        resultIndex,
-        null,
-        null,
-        JobType.INTERACTIVE,
-        null,
-        SequenceNumbers.UNASSIGNED_SEQ_NO,
-        SequenceNumbers.UNASSIGNED_PRIMARY_TERM);
-  }
-
-  public AsyncQueryJobMetadata(
-      AsyncQueryId queryId,
-      String applicationId,
-      String jobId,
-      String resultIndex,
-      String sessionId) {
-    this(
-        queryId,
-        applicationId,
-        jobId,
-        resultIndex,
-        sessionId,
-        null,
-        JobType.INTERACTIVE,
-        null,
-        SequenceNumbers.UNASSIGNED_SEQ_NO,
-        SequenceNumbers.UNASSIGNED_PRIMARY_TERM);
-  }
-
-  public AsyncQueryJobMetadata(
-      AsyncQueryId queryId,
-      String applicationId,
-      String jobId,
-      String resultIndex,
-      String sessionId,
-      String datasourceName,
-      JobType jobType,
-      String indexName) {
-    this(
-        queryId,
-        applicationId,
-        jobId,
-        resultIndex,
-        sessionId,
-        datasourceName,
-        jobType,
-        indexName,
-        SequenceNumbers.UNASSIGNED_SEQ_NO,
-        SequenceNumbers.UNASSIGNED_PRIMARY_TERM);
-  }
-
-  public AsyncQueryJobMetadata(
-      AsyncQueryId queryId,
-      String applicationId,
-      String jobId,
-      String resultIndex,
-      String sessionId,
-      String datasourceName,
-      JobType jobType,
-      String indexName,
-      long seqNo,
-      long primaryTerm) {
-    this.queryId = queryId;
-    this.applicationId = applicationId;
-    this.jobId = jobId;
-    this.resultIndex = resultIndex;
-    this.sessionId = sessionId;
-    this.datasourceName = datasourceName;
-    this.jobType = jobType;
-    this.indexName = indexName;
-    this.seqNo = seqNo;
-    this.primaryTerm = primaryTerm;
-  }
-
   @Override
   public String toString() {
     return new Gson().toJson(this);
@@ -122,18 +43,18 @@ public String toString() {
 
   /** copy builder. update seqNo and primaryTerm */
   public static AsyncQueryJobMetadata copy(
-      AsyncQueryJobMetadata copy, long seqNo, long primaryTerm) {
-    return new AsyncQueryJobMetadata(
-        copy.getQueryId(),
-        copy.getApplicationId(),
-        copy.getJobId(),
-        copy.getResultIndex(),
-        copy.getSessionId(),
-        copy.datasourceName,
-        copy.jobType,
-        copy.indexName,
-        seqNo,
-        primaryTerm);
+      AsyncQueryJobMetadata copy, ImmutableMap<String, Object> metadata) {
+    return builder()
+        .queryId(copy.queryId)
+        .applicationId(copy.getApplicationId())
+        .jobId(copy.getJobId())
+        .resultIndex(copy.getResultIndex())
+        .sessionId(copy.getSessionId())
+        .datasourceName(copy.datasourceName)
+        .jobType(copy.jobType)
+        .indexName(copy.indexName)
+        .metadata(metadata)
+        .build();
   }
 
   @Override
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java
index 9bfead67b6..72980dcb1f 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/IndexDMLHandler.java
@@ -96,13 +96,14 @@ private AsyncQueryId storeIndexDMLResult(
       long queryRunTime) {
     AsyncQueryId asyncQueryId = AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName());
     IndexDMLResult indexDMLResult =
-        new IndexDMLResult(
-            asyncQueryId.getId(),
-            status,
-            error,
-            dispatchQueryRequest.getDatasource(),
-            queryRunTime,
-            System.currentTimeMillis());
+        IndexDMLResult.builder()
+            .queryId(asyncQueryId.getId())
+            .status(status)
+            .error(error)
+            .datasourceName(dispatchQueryRequest.getDatasource())
+            .queryRunTime(queryRunTime)
+            .updateTime(System.currentTimeMillis())
+            .build();
     indexDMLResultStorageService.createIndexDMLResult(indexDMLResult);
     return asyncQueryId;
   }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDMLResult.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDMLResult.java
index d0b99e883e..42bddf6c15 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDMLResult.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/IndexDMLResult.java
@@ -5,13 +5,15 @@
 
 package org.opensearch.sql.spark.dispatcher.model;
 
+import com.google.common.collect.ImmutableMap;
 import lombok.Data;
 import lombok.EqualsAndHashCode;
-import org.opensearch.index.seqno.SequenceNumbers;
+import lombok.experimental.SuperBuilder;
 import org.opensearch.sql.spark.execution.statestore.StateModel;
 
 /** Plugin create Index DML result. */
 @Data
+@SuperBuilder
 @EqualsAndHashCode(callSuper = false)
 public class IndexDMLResult extends StateModel {
   public static final String DOC_ID_PREFIX = "index";
@@ -23,28 +25,20 @@ public class IndexDMLResult extends StateModel {
   private final Long queryRunTime;
   private final Long updateTime;
 
-  public static IndexDMLResult copy(IndexDMLResult copy, long seqNo, long primaryTerm) {
-    return new IndexDMLResult(
-        copy.queryId,
-        copy.status,
-        copy.error,
-        copy.datasourceName,
-        copy.queryRunTime,
-        copy.updateTime);
+  public static IndexDMLResult copy(IndexDMLResult copy, ImmutableMap<String, Object> metadata) {
+    return builder()
+        .queryId(copy.queryId)
+        .status(copy.status)
+        .error(copy.error)
+        .datasourceName(copy.datasourceName)
+        .queryRunTime(copy.queryRunTime)
+        .updateTime(copy.updateTime)
+        .metadata(metadata)
+        .build();
   }
 
   @Override
   public String getId() {
     return DOC_ID_PREFIX + queryId;
   }
-
-  @Override
-  public long getSeqNo() {
-    return SequenceNumbers.UNASSIGNED_SEQ_NO;
-  }
-
-  @Override
-  public long getPrimaryTerm() {
-    return SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
-  }
 }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java
index 09e83ea41c..b79bef7b27 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java
@@ -8,14 +8,14 @@
 import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED;
 import static org.opensearch.sql.spark.execution.session.SessionType.INTERACTIVE;
 
-import lombok.Builder;
+import com.google.common.collect.ImmutableMap;
 import lombok.Data;
-import org.opensearch.index.seqno.SequenceNumbers;
+import lombok.experimental.SuperBuilder;
 import org.opensearch.sql.spark.execution.statestore.StateModel;
 
 /** Session data in flint.ql.sessions index. */
 @Data
-@Builder
+@SuperBuilder
 public class SessionModel extends StateModel {
   public static final String UNKNOWN = "unknown";
 
@@ -30,10 +30,7 @@ public class SessionModel extends StateModel {
   private final String error;
   private final long lastUpdateTime;
 
-  private final long seqNo;
-  private final long primaryTerm;
-
-  public static SessionModel of(SessionModel copy, long seqNo, long primaryTerm) {
+  public static SessionModel of(SessionModel copy, ImmutableMap<String, Object> metadata) {
     return builder()
         .version(copy.version)
         .sessionType(copy.sessionType)
@@ -44,13 +41,12 @@ public static SessionModel of(SessionModel copy, long seqNo, long primaryTerm) {
         .jobId(copy.jobId)
         .error(UNKNOWN)
         .lastUpdateTime(copy.getLastUpdateTime())
-        .seqNo(seqNo)
-        .primaryTerm(primaryTerm)
+        .metadata(metadata)
         .build();
   }
 
   public static SessionModel copyWithState(
-      SessionModel copy, SessionState state, long seqNo, long primaryTerm) {
+      SessionModel copy, SessionState state, ImmutableMap<String, Object> metadata) {
     return builder()
         .version(copy.version)
         .sessionType(copy.sessionType)
@@ -61,8 +57,7 @@ public static SessionModel copyWithState(
         .jobId(copy.jobId)
         .error(UNKNOWN)
         .lastUpdateTime(copy.getLastUpdateTime())
-        .seqNo(seqNo)
-        .primaryTerm(primaryTerm)
+        .metadata(metadata)
         .build();
   }
 
@@ -78,8 +73,6 @@ public static SessionModel initInteractiveSession(
         .jobId(jobId)
         .error(UNKNOWN)
         .lastUpdateTime(System.currentTimeMillis())
-        .seqNo(SequenceNumbers.UNASSIGNED_SEQ_NO)
-        .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM)
         .build();
   }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java
index f58e3a4f1c..86e8d6e156 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java
@@ -7,16 +7,16 @@
 
 import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING;
 
-import lombok.Builder;
+import com.google.common.collect.ImmutableMap;
 import lombok.Data;
-import org.opensearch.index.seqno.SequenceNumbers;
+import lombok.experimental.SuperBuilder;
 import org.opensearch.sql.spark.execution.session.SessionId;
 import org.opensearch.sql.spark.execution.statestore.StateModel;
 import org.opensearch.sql.spark.rest.model.LangType;
 
 /** Statement data in flint.ql.sessions index. */
 @Data
-@Builder
+@SuperBuilder
 public class StatementModel extends StateModel {
   public static final String UNKNOWN = "";
 
@@ -33,10 +33,7 @@ public class StatementModel extends StateModel {
   private final long submitTime;
   private final String error;
 
-  private final long seqNo;
-  private final long primaryTerm;
-
-  public static StatementModel copy(StatementModel copy, long seqNo, long primaryTerm) {
+  public static StatementModel copy(StatementModel copy, ImmutableMap<String, Object> metadata) {
     return builder()
         .version("1.0")
         .statementState(copy.statementState)
@@ -50,13 +47,12 @@ public static StatementModel copy(StatementModel copy, long seqNo, long primaryT
         .queryId(copy.queryId)
         .submitTime(copy.submitTime)
         .error(copy.error)
-        .seqNo(seqNo)
-        .primaryTerm(primaryTerm)
+        .metadata(metadata)
         .build();
   }
 
   public static StatementModel copyWithState(
-      StatementModel copy, StatementState state, long seqNo, long primaryTerm) {
+      StatementModel copy, StatementState state, ImmutableMap<String, Object> metadata) {
     return builder()
         .version("1.0")
         .statementState(state)
@@ -70,8 +66,7 @@ public static StatementModel copyWithState(
         .queryId(copy.queryId)
         .submitTime(copy.submitTime)
         .error(copy.error)
-        .seqNo(seqNo)
-        .primaryTerm(primaryTerm)
+        .metadata(metadata)
         .build();
   }
 
@@ -97,8 +92,6 @@ public static StatementModel submitStatement(
         .queryId(queryId)
         .submitTime(System.currentTimeMillis())
         .error(UNKNOWN)
-        .seqNo(SequenceNumbers.UNASSIGNED_SEQ_NO)
-        .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM)
         .build();
   }
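The model classes above all follow the same copy contract: rebuild every field and attach the metadata map handed back by the store. A minimal sketch of that contract through the `CopyBuilder` interface introduced next (`original` is an assumed existing instance; the seqNo/primaryTerm values are illustrative):

```java
// AsyncQueryJobMetadata::copy satisfies CopyBuilder<T>.of(T, ImmutableMap<String, Object>).
CopyBuilder<AsyncQueryJobMetadata> copyBuilder = AsyncQueryJobMetadata::copy;
AsyncQueryJobMetadata withNewVersion =
    copyBuilder.of(original, XContentSerializerUtil.buildMetadata(5L, 1L));
```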
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/CopyBuilder.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/CopyBuilder.java
index 3ab2c9eb47..e9de7064d5 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/CopyBuilder.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/CopyBuilder.java
@@ -5,7 +5,9 @@
 
 package org.opensearch.sql.spark.execution.statestore;
 
+import com.google.common.collect.ImmutableMap;
+
 /** Interface for copying StateModel object. Refer {@link StateStore} */
 public interface CopyBuilder<T extends StateModel> {
-  T of(T copy, long seqNo, long primaryTerm);
+  T of(T copy, ImmutableMap<String, Object> metadata);
 }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateCopyBuilder.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateCopyBuilder.java
index 7bc14f5a2e..1f38e5a1c5 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateCopyBuilder.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateCopyBuilder.java
@@ -5,6 +5,8 @@
 
 package org.opensearch.sql.spark.execution.statestore;
 
+import com.google.common.collect.ImmutableMap;
+
 public interface StateCopyBuilder<T extends StateModel, S> {
-  T of(T copy, S state, long seqNo, long primaryTerm);
+  T of(T copy, S state, ImmutableMap<String, Object> metadata);
 }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java
index cc1b9d56d4..9d29299818 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java
@@ -5,10 +5,33 @@
 
 package org.opensearch.sql.spark.execution.statestore;
 
+import com.google.common.collect.ImmutableMap;
+import java.util.Optional;
+import lombok.Builder.Default;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.experimental.SuperBuilder;
+
+@SuperBuilder
 public abstract class StateModel {
-  public abstract String getId();
+  @Getter @EqualsAndHashCode.Exclude @Default
+  private final ImmutableMap<String, Object> metadata = ImmutableMap.of();
 
-  public abstract long getSeqNo();
+  public abstract String getId();
 
-  public abstract long getPrimaryTerm();
+  public <T> Optional<T> getMetadataItem(String name, Class<T> type) {
+    if (metadata.containsKey(name)) {
+      Object value = metadata.get(name);
+      if (type.isInstance(value)) {
+        return Optional.of(type.cast(value));
+      } else {
+        throw new RuntimeException(
+            String.format(
+                "The metadata field %s is an instance of %s instead of %s",
+                name, value.getClass(), type));
+      }
+    } else {
+      return Optional.empty();
+    }
+  }
 }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java
index 56d2a0f179..d4141c54d2 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java
@@ -42,6 +42,7 @@
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.index.query.QueryBuilder;
 import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.index.seqno.SequenceNumbers;
 import org.opensearch.search.builder.SearchSourceBuilder;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.dispatcher.model.IndexDMLResult;
@@ -57,6 +58,7 @@
 import org.opensearch.sql.spark.execution.xcontent.StatementModelXContentSerializer;
 import org.opensearch.sql.spark.execution.xcontent.XContentCommonAttributes;
 import org.opensearch.sql.spark.execution.xcontent.XContentSerializer;
+import org.opensearch.sql.spark.execution.xcontent.XContentSerializerUtil;
 import org.opensearch.sql.spark.flint.FlintIndexState;
 import org.opensearch.sql.spark.flint.FlintIndexStateModel;
 
@@ -86,8 +88,8 @@ public <T extends StateModel> T create(T st, CopyBuilder<T> builder, String inde
         new IndexRequest(indexName)
             .id(st.getId())
             .source(serializer.toXContent(st, ToXContent.EMPTY_PARAMS))
-            .setIfSeqNo(st.getSeqNo())
-            .setIfPrimaryTerm(st.getPrimaryTerm())
+            .setIfSeqNo(getSeqNo(st))
+            .setIfPrimaryTerm(getPrimaryTerm(st))
             .create(true)
             .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
     try (ThreadContext.StoredContext ignored =
@@ -95,7 +97,10 @@
       IndexResponse indexResponse = client.index(indexRequest).actionGet();
       if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) {
         LOG.debug("Successfully created doc. id: {}", st.getId());
-        return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm());
+        return builder.of(
+            st,
+            XContentSerializerUtil.buildMetadata(
+                indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()));
       } else {
         throw new RuntimeException(
             String.format(
@@ -146,14 +151,14 @@ public <T extends StateModel> Optional<T> get(
   public <T extends StateModel, S> T updateState(
       T st, S state, StateCopyBuilder<T, S> builder, String indexName) {
     try {
-      T model = builder.of(st, state, st.getSeqNo(), st.getPrimaryTerm());
+      T model = builder.of(st, state, st.getMetadata());
       XContentSerializer<T> serializer = getXContentSerializer(st);
       UpdateRequest updateRequest =
           new UpdateRequest()
               .index(indexName)
               .id(model.getId())
-              .setIfSeqNo(model.getSeqNo())
-              .setIfPrimaryTerm(model.getPrimaryTerm())
+              .setIfSeqNo(getSeqNo(model))
+              .setIfPrimaryTerm(getPrimaryTerm(model))
               .doc(serializer.toXContent(model, ToXContent.EMPTY_PARAMS))
               .fetchSource(true)
               .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL);
@@ -161,13 +166,27 @@ public <T extends StateModel, S> T updateState(
           client.threadPool().getThreadContext().stashContext()) {
         UpdateResponse updateResponse = client.update(updateRequest).actionGet();
         LOG.debug("Successfully update doc. id: {}", st.getId());
-        return builder.of(model, state, updateResponse.getSeqNo(), updateResponse.getPrimaryTerm());
+        return builder.of(
+            model,
+            state,
+            XContentSerializerUtil.buildMetadata(
+                updateResponse.getSeqNo(), updateResponse.getPrimaryTerm()));
       }
     } catch (IOException e) {
       throw new RuntimeException(e);
     }
   }
 
+  private long getSeqNo(StateModel model) {
+    return model.getMetadataItem("seqNo", Long.class).orElse(SequenceNumbers.UNASSIGNED_SEQ_NO);
+  }
+
+  private long getPrimaryTerm(StateModel model) {
+    return model
+        .getMetadataItem("primaryTerm", Long.class)
+        .orElse(SequenceNumbers.UNASSIGNED_PRIMARY_TERM);
+  }
+
   /**
    * Delete the index state document with the given ID.
    *
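The net effect in `StateStore`: seqNo/primaryTerm now travel inside the generic metadata map and are unpacked only at the OpenSearch boundary. A sketch of that round trip (`model` is an assumed persisted `StateModel`; the values are illustrative):

```java
// After a write, the store copies the model with the response's concurrency tokens...
ImmutableMap<String, Object> metadata = XContentSerializerUtil.buildMetadata(7L, 2L);

// ...and before the next write it unpacks them, falling back to the
// UNASSIGNED_* sentinels when the model has never been persisted.
long seqNo =
    model.getMetadataItem(XContentSerializerUtil.SEQ_NO, Long.class)
        .orElse(SequenceNumbers.UNASSIGNED_SEQ_NO);
```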
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializer.java b/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializer.java
index bf61818b9f..a4209a0ce7 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializer.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializer.java
@@ -52,42 +52,37 @@ public XContentBuilder toXContent(AsyncQueryJobMetadata jobMetadata, ToXContent.
   @Override
   @SneakyThrows
   public AsyncQueryJobMetadata fromXContent(XContentParser parser, long seqNo, long primaryTerm) {
-    AsyncQueryId queryId = null;
-    String jobId = null;
-    String applicationId = null;
-    String resultIndex = null;
-    String sessionId = null;
-    String datasourceName = null;
-    String jobTypeStr = null;
-    String indexName = null;
+    AsyncQueryJobMetadata.AsyncQueryJobMetadataBuilder<?, ?> builder =
+        AsyncQueryJobMetadata.builder();
     ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
     while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) {
       String fieldName = parser.currentName();
       parser.nextToken();
       switch (fieldName) {
         case QUERY_ID:
-          queryId = new AsyncQueryId(parser.textOrNull());
+          builder.queryId(new AsyncQueryId(parser.textOrNull()));
           break;
         case JOB_ID:
-          jobId = parser.textOrNull();
+          builder.jobId(parser.textOrNull());
           break;
         case APPLICATION_ID:
-          applicationId = parser.textOrNull();
+          builder.applicationId(parser.textOrNull());
           break;
         case RESULT_INDEX:
-          resultIndex = parser.textOrNull();
+          builder.resultIndex(parser.textOrNull());
           break;
         case SESSION_ID:
-          sessionId = parser.textOrNull();
+          builder.sessionId(parser.textOrNull());
           break;
         case DATASOURCE_NAME:
-          datasourceName = parser.textOrNull();
+          builder.datasourceName(parser.textOrNull());
           break;
         case JOB_TYPE:
-          jobTypeStr = parser.textOrNull();
+          String jobTypeStr = parser.textOrNull();
+          builder.jobType(
+              Strings.isNullOrEmpty(jobTypeStr) ? null : JobType.fromString(jobTypeStr));
           break;
         case INDEX_NAME:
-          indexName = parser.textOrNull();
+          builder.indexName(parser.textOrNull());
           break;
         case TYPE:
           break;
@@ -95,19 +90,11 @@ public AsyncQueryJobMetadata fromXContent(XContentParser parser, long seqNo, lon
           throw new IllegalArgumentException("Unknown field: " + fieldName);
       }
     }
-    if (jobId == null || applicationId == null) {
+    builder.metadata(XContentSerializerUtil.buildMetadata(seqNo, primaryTerm));
+    AsyncQueryJobMetadata result = builder.build();
+    if (result.getJobId() == null || result.getApplicationId() == null) {
       throw new IllegalArgumentException("jobId and applicationId are required fields.");
     }
-    return new AsyncQueryJobMetadata(
-        queryId,
-        applicationId,
-        jobId,
-        resultIndex,
-        sessionId,
-        datasourceName,
-        Strings.isNullOrEmpty(jobTypeStr) ? null : JobType.fromString(jobTypeStr),
-        indexName,
-        seqNo,
-        primaryTerm);
+    return result;
   }
 }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/FlintIndexStateModelXContentSerializer.java b/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/FlintIndexStateModelXContentSerializer.java
index 87ddc6f719..5e47fa2462 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/FlintIndexStateModelXContentSerializer.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/FlintIndexStateModelXContentSerializer.java
@@ -50,7 +50,6 @@ public XContentBuilder toXContent(
   @Override
   @SneakyThrows
   public FlintIndexStateModel fromXContent(XContentParser parser, long seqNo, long primaryTerm) {
-    // Implement the fromXContent logic here
     FlintIndexStateModel.FlintIndexStateModelBuilder builder = FlintIndexStateModel.builder();
     XContentParserUtils.ensureExpectedToken(
         XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
@@ -81,8 +80,7 @@ public FlintIndexStateModel fromXContent(XContentParser parser, long seqNo, long
           break;
       }
     }
-    builder.seqNo(seqNo);
-    builder.primaryTerm(primaryTerm);
+    builder.metadata(XContentSerializerUtil.buildMetadata(seqNo, primaryTerm));
     return builder.build();
   }
 }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/SessionModelXContentSerializer.java b/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/SessionModelXContentSerializer.java
index d453b6ffa9..3ce20ca8b2 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/SessionModelXContentSerializer.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/SessionModelXContentSerializer.java
@@ -52,7 +52,6 @@ public XContentBuilder toXContent(SessionModel sessionModel, ToXContent.Params p
   @Override
   @SneakyThrows
   public SessionModel fromXContent(XContentParser parser, long seqNo, long primaryTerm) {
-    // Implement the fromXContent logic here
     SessionModel.SessionModelBuilder builder = SessionModel.builder();
     XContentParserUtils.ensureExpectedToken(
         XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
@@ -92,8 +91,7 @@ public SessionModel fromXContent(XContentParser parser, long seqNo, long primary
           break;
       }
     }
-    builder.seqNo(seqNo);
-    builder.primaryTerm(primaryTerm);
+    builder.metadata(XContentSerializerUtil.buildMetadata(seqNo, primaryTerm));
     return builder.build();
   }
 }
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/StatementModelXContentSerializer.java b/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/StatementModelXContentSerializer.java
index 2323df998d..39fbbd6279 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/StatementModelXContentSerializer.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/StatementModelXContentSerializer.java
@@ -110,8 +110,7 @@ public StatementModel fromXContent(XContentParser parser, long seqNo, long prima
         throw new IllegalArgumentException("Unexpected field: " + fieldName);
       }
     }
-    builder.seqNo(seqNo);
-    builder.primaryTerm(primaryTerm);
+    builder.metadata(XContentSerializerUtil.buildMetadata(seqNo, primaryTerm));
     return builder.build();
   }
 }
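All four serializers now delegate to the helper class added next. Under that scheme a deserialized model carries its document version in the metadata map rather than in dedicated fields, e.g. (a sketch; parser setup is as in the serializer tests later in this patch):

```java
AsyncQueryJobMetadataXContentSerializer serializer =
    new AsyncQueryJobMetadataXContentSerializer();
AsyncQueryJobMetadata restored = serializer.fromXContent(parser, 1L, 1L);

// seqNo/primaryTerm ride along in the metadata map:
long seqNo =
    restored.getMetadataItem(XContentSerializerUtil.SEQ_NO, Long.class).orElseThrow();
```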
diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/XContentSerializerUtil.java b/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/XContentSerializerUtil.java
new file mode 100644
index 0000000000..2f8558d723
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/execution/xcontent/XContentSerializerUtil.java
@@ -0,0 +1,14 @@
+package org.opensearch.sql.spark.execution.xcontent;
+
+import com.google.common.collect.ImmutableMap;
+import lombok.experimental.UtilityClass;
+
+@UtilityClass
+public class XContentSerializerUtil {
+  public static final String SEQ_NO = "seqNo";
+  public static final String PRIMARY_TERM = "primaryTerm";
+
+  public static ImmutableMap<String, Object> buildMetadata(long seqNo, long primaryTerm) {
+    return ImmutableMap.of(SEQ_NO, seqNo, PRIMARY_TERM, primaryTerm);
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModel.java b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModel.java
index 9c03b084db..2b071a1516 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModel.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/flint/FlintIndexStateModel.java
@@ -5,14 +5,15 @@
 
 package org.opensearch.sql.spark.flint;
 
-import lombok.Builder;
+import com.google.common.collect.ImmutableMap;
 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import lombok.experimental.SuperBuilder;
 import org.opensearch.sql.spark.execution.statestore.StateModel;
 
 /** Flint Index Model maintain the index state. */
 @Getter
-@Builder
+@SuperBuilder
 @EqualsAndHashCode(callSuper = false)
 public class FlintIndexStateModel extends StateModel {
   private final FlintIndexState indexState;
@@ -23,55 +24,32 @@ public class FlintIndexStateModel extends StateModel {
   private final long lastUpdateTime;
   private final String error;
 
-  @EqualsAndHashCode.Exclude private final long seqNo;
-  @EqualsAndHashCode.Exclude private final long primaryTerm;
-
-  public FlintIndexStateModel(
-      FlintIndexState indexState,
-      String applicationId,
-      String jobId,
-      String latestId,
-      String datasourceName,
-      long lastUpdateTime,
-      String error,
-      long seqNo,
-      long primaryTerm) {
-    this.indexState = indexState;
-    this.applicationId = applicationId;
-    this.jobId = jobId;
-    this.latestId = latestId;
-    this.datasourceName = datasourceName;
-    this.lastUpdateTime = lastUpdateTime;
-    this.error = error;
-    this.seqNo = seqNo;
-    this.primaryTerm = primaryTerm;
-  }
-
-  public static FlintIndexStateModel copy(FlintIndexStateModel copy, long seqNo, long primaryTerm) {
-    return new FlintIndexStateModel(
-        copy.indexState,
-        copy.applicationId,
-        copy.jobId,
-        copy.latestId,
-        copy.datasourceName,
-        copy.lastUpdateTime,
-        copy.error,
-        seqNo,
-        primaryTerm);
+  public static FlintIndexStateModel copy(
+      FlintIndexStateModel copy, ImmutableMap<String, Object> metadata) {
+    return builder()
+        .indexState(copy.indexState)
+        .applicationId(copy.applicationId)
+        .jobId(copy.jobId)
+        .latestId(copy.latestId)
+        .datasourceName(copy.datasourceName)
+        .lastUpdateTime(copy.lastUpdateTime)
+        .error(copy.error)
+        .metadata(metadata)
+        .build();
   }
 
   public static FlintIndexStateModel copyWithState(
-      FlintIndexStateModel copy, FlintIndexState state, long seqNo, long primaryTerm) {
-    return new FlintIndexStateModel(
-        state,
-        copy.applicationId,
-        copy.jobId,
-        copy.latestId,
-        copy.datasourceName,
-        copy.lastUpdateTime,
-        copy.error,
-        seqNo,
-        primaryTerm);
+      FlintIndexStateModel copy, FlintIndexState state, ImmutableMap<String, Object> metadata) {
+    return builder()
+        .indexState(state)
+        .applicationId(copy.applicationId)
+        .jobId(copy.jobId)
+        .latestId(copy.latestId)
+        .datasourceName(copy.datasourceName)
+        .lastUpdateTime(copy.lastUpdateTime)
+        .error(copy.error)
+        .metadata(metadata)
+        .build();
   }
 
   @Override
diff --git a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java
index 0b1ccc988e..97ddccaf8f 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/flint/operation/FlintIndexOp.java
@@ -16,7 +16,6 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.jetbrains.annotations.NotNull;
-import org.opensearch.index.seqno.SequenceNumbers;
 import org.opensearch.sql.spark.client.EMRServerlessClient;
 import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
 import org.opensearch.sql.spark.flint.FlintIndexMetadata;
@@ -81,16 +80,15 @@ private FlintIndexStateModel getFlintIndexStateModel(String latestId) {
   private void takeActionWithoutOCC(FlintIndexMetadata metadata) {
     // take action without occ.
     FlintIndexStateModel fakeModel =
-        new FlintIndexStateModel(
-            FlintIndexState.REFRESHING,
-            metadata.getAppId(),
-            metadata.getJobId(),
-            "",
-            datasourceName,
-            System.currentTimeMillis(),
-            "",
-            SequenceNumbers.UNASSIGNED_SEQ_NO,
-            SequenceNumbers.UNASSIGNED_PRIMARY_TERM);
+        FlintIndexStateModel.builder()
+            .indexState(FlintIndexState.REFRESHING)
+            .applicationId(metadata.getAppId())
+            .jobId(metadata.getJobId())
+            .latestId("")
+            .datasourceName(datasourceName)
+            .lastUpdateTime(System.currentTimeMillis())
+            .error("")
+            .build();
     runOp(metadata, fakeModel);
   }
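Note that the fake model above simply omits `.metadata(...)`: because `StateModel` defaults the map to empty, `StateStore`'s seqNo/primaryTerm lookups fall back to the `UNASSIGNED_*` sentinels, preserving the old no-OCC behavior. A sketch of that equivalence (assertions for illustration only):

```java
FlintIndexStateModel fake =
    FlintIndexStateModel.builder().indexState(FlintIndexState.REFRESHING).build();

// No metadata was attached, so the concurrency lookups fall back to the sentinels.
assert fake.getMetadata().isEmpty();
assert !fake.getMetadataItem(XContentSerializerUtil.SEQ_NO, Long.class).isPresent();
```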
diff --git a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java
index 25f31dcc69..5007cff64e 100644
--- a/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java
+++ b/spark/src/main/java/org/opensearch/sql/spark/transport/config/AsyncExecutorServiceModule.java
@@ -33,6 +33,7 @@
 import org.opensearch.sql.spark.execution.statestore.SessionStorageService;
 import org.opensearch.sql.spark.execution.statestore.StateStore;
 import org.opensearch.sql.spark.execution.statestore.StatementStorageService;
+import org.opensearch.sql.spark.execution.xcontent.AsyncQueryJobMetadataXContentSerializer;
 import org.opensearch.sql.spark.execution.xcontent.FlintIndexStateModelXContentSerializer;
 import org.opensearch.sql.spark.execution.xcontent.SessionModelXContentSerializer;
 import org.opensearch.sql.spark.execution.xcontent.StatementModelXContentSerializer;
@@ -64,8 +65,8 @@ public AsyncQueryExecutorService asyncQueryExecutorService(
 
   @Provides
   public AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService(
-      StateStore stateStore) {
-    return new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
+      StateStore stateStore, AsyncQueryJobMetadataXContentSerializer serializer) {
+    return new OpensearchAsyncQueryJobMetadataStorageService(stateStore, serializer);
   }
 
   @Provides
@@ -137,14 +138,14 @@ public SessionManager sessionManager(
 
   @Provides
   public SessionStorageService sessionStorageService(
-      StateStore stateStore, SessionModelXContentSerializer sessionModelXContentSerializer) {
-    return new OpenSearchSessionStorageService(stateStore, sessionModelXContentSerializer);
+      StateStore stateStore, SessionModelXContentSerializer serializer) {
+    return new OpenSearchSessionStorageService(stateStore, serializer);
   }
 
   @Provides
   public StatementStorageService statementStorageService(
-      StateStore stateStore, StatementModelXContentSerializer statementModelXContentSerializer) {
-    return new OpenSearchStatementStorageService(stateStore, statementModelXContentSerializer);
+      StateStore stateStore, StatementModelXContentSerializer serializer) {
+    return new OpenSearchStatementStorageService(stateStore, serializer);
   }
 
   @Provides
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
index f3c17914d2..74b18d0332 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java
@@ -149,6 +149,7 @@ public void withSessionCreateAsyncQueryThenGetResultThenCancel() {
     // 2. fetch async query result.
     AsyncQueryExecutionResponse asyncQueryResults =
         asyncQueryExecutorService.getAsyncQueryResults(response.getQueryId());
+    assertEquals("", asyncQueryResults.getError());
     assertTrue(Strings.isEmpty(asyncQueryResults.getError()));
     assertEquals(StatementState.WAITING.getState(), asyncQueryResults.getStatus());
 
@@ -314,8 +315,7 @@ public void withSessionCreateAsyncQueryFailed() {
             .queryId(submitted.getQueryId())
             .submitTime(submitted.getSubmitTime())
             .error("mock error")
-            .seqNo(submitted.getSeqNo())
-            .primaryTerm(submitted.getPrimaryTerm())
+            .metadata(submitted.getMetadata())
             .build();
     statementStorageService.updateStatementState(mocked, StatementState.FAILED);
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
index 634df6670d..a5dee8f4e8 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplTest.java
@@ -13,6 +13,7 @@
 import static org.mockito.Mockito.when;
 import static org.opensearch.sql.spark.asyncquery.OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.DS_NAME;
 import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID;
+import static org.opensearch.sql.spark.constants.TestConstants.EMRS_EXECUTION_ROLE;
 import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID;
 import static org.opensearch.sql.spark.constants.TestConstants.TEST_CLUSTER_NAME;
 import static org.opensearch.sql.spark.utils.TestUtils.getJson;
@@ -68,35 +69,25 @@ void testCreateAsyncQuery() {
     when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig())
         .thenReturn(
             new SparkExecutionEngineConfig(
-                "00fd775baqpu4g0p",
-                "eu-west-1",
-                "arn:aws:iam::270824043731:role/emr-job-execution-role",
-                null,
-                TEST_CLUSTER_NAME));
-    when(sparkQueryDispatcher.dispatch(
-            new DispatchQueryRequest(
-                "00fd775baqpu4g0p",
-                "select * from my_glue.default.http_logs",
-                "my_glue",
-                LangType.SQL,
-                "arn:aws:iam::270824043731:role/emr-job-execution-role",
-                TEST_CLUSTER_NAME)))
+                EMRS_APPLICATION_ID, "eu-west-1", EMRS_EXECUTION_ROLE, null, TEST_CLUSTER_NAME));
+    DispatchQueryRequest expectedDispatchQueryRequest =
+        new DispatchQueryRequest(
+            EMRS_APPLICATION_ID,
+            "select * from my_glue.default.http_logs",
+            "my_glue",
+            LangType.SQL,
+            EMRS_EXECUTION_ROLE,
+            TEST_CLUSTER_NAME);
+    when(sparkQueryDispatcher.dispatch(expectedDispatchQueryRequest))
         .thenReturn(new DispatchQueryResponse(QUERY_ID, EMR_JOB_ID, null, null));
+
     CreateAsyncQueryResponse createAsyncQueryResponse =
         jobExecutorService.createAsyncQuery(createAsyncQueryRequest);
+
     verify(asyncQueryJobMetadataStorageService, times(1))
-        .storeJobMetadata(
-            new AsyncQueryJobMetadata(QUERY_ID, "00fd775baqpu4g0p", EMR_JOB_ID, null));
+        .storeJobMetadata(getAsyncQueryJobMetadata());
     verify(sparkExecutionEngineConfigSupplier, times(1)).getSparkExecutionEngineConfig();
-    verify(sparkQueryDispatcher, times(1))
-        .dispatch(
-            new DispatchQueryRequest(
-                "00fd775baqpu4g0p",
-                "select * from my_glue.default.http_logs",
-                "my_glue",
-                LangType.SQL,
-                "arn:aws:iam::270824043731:role/emr-job-execution-role",
-                TEST_CLUSTER_NAME));
+    verify(sparkQueryDispatcher, times(1)).dispatch(expectedDispatchQueryRequest);
 
     Assertions.assertEquals(QUERY_ID.getId(), createAsyncQueryResponse.getQueryId());
   }
@@ -105,9 +96,9 @@ void testCreateAsyncQueryWithExtraSparkSubmitParameter() {
     when(sparkExecutionEngineConfigSupplier.getSparkExecutionEngineConfig())
         .thenReturn(
             new SparkExecutionEngineConfig(
-                "00fd775baqpu4g0p",
+                EMRS_APPLICATION_ID,
                 "eu-west-1",
-                "arn:aws:iam::270824043731:role/emr-job-execution-role",
+                EMRS_EXECUTION_ROLE,
                 "--conf spark.dynamicAllocation.enabled=false",
                 TEST_CLUSTER_NAME));
     when(sparkQueryDispatcher.dispatch(any()))
@@ -143,14 +134,10 @@ void testGetAsyncQueryResultsWithJobNotFoundException() {
 
   @Test
   void testGetAsyncQueryResultsWithInProgressJob() {
     when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
-        .thenReturn(
-            Optional.of(
-                new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
+        .thenReturn(Optional.of(getAsyncQueryJobMetadata()));
     JSONObject jobResult = new JSONObject();
     jobResult.put("status", JobRunState.PENDING.toString());
-    when(sparkQueryDispatcher.getQueryResponse(
-            new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
-        .thenReturn(jobResult);
+    when(sparkQueryDispatcher.getQueryResponse(getAsyncQueryJobMetadata())).thenReturn(jobResult);
 
     AsyncQueryExecutionResponse asyncQueryExecutionResponse =
         jobExecutorService.getAsyncQueryResults(EMR_JOB_ID);
@@ -163,14 +150,10 @@ void testGetAsyncQueryResultsWithInProgressJob() {
 
   @Test
   void testGetAsyncQueryResultsWithSuccessJob() throws IOException {
     when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
-        .thenReturn(
-            Optional.of(
-                new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
+        .thenReturn(Optional.of(getAsyncQueryJobMetadata()));
     JSONObject jobResult = new JSONObject(getJson("select_query_response.json"));
     jobResult.put("status", JobRunState.SUCCESS.toString());
-    when(sparkQueryDispatcher.getQueryResponse(
-            new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
-        .thenReturn(jobResult);
+    when(sparkQueryDispatcher.getQueryResponse(getAsyncQueryJobMetadata())).thenReturn(jobResult);
 
     AsyncQueryExecutionResponse asyncQueryExecutionResponse =
         jobExecutorService.getAsyncQueryResults(EMR_JOB_ID);
@@ -202,14 +185,18 @@ void testCancelJobWithJobNotFound() {
 
   @Test
   void testCancelJob() {
     when(asyncQueryJobMetadataStorageService.getJobMetadata(EMR_JOB_ID))
-        .thenReturn(
-            Optional.of(
-                new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)));
-    when(sparkQueryDispatcher.cancelJob(
-            new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null)))
-        .thenReturn(EMR_JOB_ID);
+        .thenReturn(Optional.of(getAsyncQueryJobMetadata()));
+    when(sparkQueryDispatcher.cancelJob(getAsyncQueryJobMetadata())).thenReturn(EMR_JOB_ID);
 
     String jobId = jobExecutorService.cancelQuery(EMR_JOB_ID);
 
     Assertions.assertEquals(EMR_JOB_ID, jobId);
     verifyNoInteractions(sparkExecutionEngineConfigSupplier);
   }
+
+  private AsyncQueryJobMetadata getAsyncQueryJobMetadata() {
+    return AsyncQueryJobMetadata.builder()
+        .queryId(QUERY_ID)
+        .applicationId(EMRS_APPLICATION_ID)
+        .jobId(EMR_JOB_ID)
+        .build();
+  }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java
index ba75da5dda..85bb92bba2 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java
@@ -66,6 +66,7 @@
 import org.opensearch.sql.spark.execution.statestore.SessionStorageService;
 import org.opensearch.sql.spark.execution.statestore.StateStore;
 import org.opensearch.sql.spark.execution.statestore.StatementStorageService;
+import org.opensearch.sql.spark.execution.xcontent.AsyncQueryJobMetadataXContentSerializer;
 import org.opensearch.sql.spark.execution.xcontent.FlintIndexStateModelXContentSerializer;
 import org.opensearch.sql.spark.execution.xcontent.SessionModelXContentSerializer;
 import org.opensearch.sql.spark.execution.xcontent.StatementModelXContentSerializer;
@@ -230,7 +231,8 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService(
       JobExecutionResponseReader jobExecutionResponseReader) {
     StateStore stateStore = new StateStore(client, clusterService);
     AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService =
-        new OpensearchAsyncQueryJobMetadataStorageService(stateStore);
+        new OpensearchAsyncQueryJobMetadataStorageService(
+            stateStore, new AsyncQueryJobMetadataXContentSerializer());
     QueryHandlerFactory queryHandlerFactory =
         new QueryHandlerFactory(
             jobExecutionResponseReader,
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java
index 20c944fd0a..431f5b2b15 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java
@@ -16,6 +16,7 @@
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId;
 import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata;
 import org.opensearch.sql.spark.execution.statestore.StateStore;
+import org.opensearch.sql.spark.execution.xcontent.AsyncQueryJobMetadataXContentSerializer;
 import org.opensearch.test.OpenSearchIntegTestCase;
 
 public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest
@@ -31,17 +32,19 @@ public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest
   public void setup() {
     opensearchJobMetadataStorageService =
         new OpensearchAsyncQueryJobMetadataStorageService(
-            new StateStore(client(), clusterService()));
+            new StateStore(client(), clusterService()),
+            new AsyncQueryJobMetadataXContentSerializer());
   }
 
   @Test
   public void testStoreJobMetadata() {
     AsyncQueryJobMetadata expected =
-        new AsyncQueryJobMetadata(
-            AsyncQueryId.newAsyncQueryId(DS_NAME),
-            EMR_JOB_ID,
-            EMRS_APPLICATION_ID,
-            MOCK_RESULT_INDEX);
+        AsyncQueryJobMetadata.builder()
+            .queryId(AsyncQueryId.newAsyncQueryId(DS_NAME))
+            .jobId(EMR_JOB_ID)
+            .applicationId(EMRS_APPLICATION_ID)
+            .resultIndex(MOCK_RESULT_INDEX)
+            .build();
 
     opensearchJobMetadataStorageService.storeJobMetadata(expected);
     Optional<AsyncQueryJobMetadata> actual =
@@ -56,12 +59,13 @@ public void testStoreJobMetadata() {
   @Test
   public void testStoreJobMetadataWithResultExtraData() {
     AsyncQueryJobMetadata expected =
-        new AsyncQueryJobMetadata(
-            AsyncQueryId.newAsyncQueryId(DS_NAME),
-            EMR_JOB_ID,
-            EMRS_APPLICATION_ID,
-            MOCK_RESULT_INDEX,
-            MOCK_SESSION_ID);
+        AsyncQueryJobMetadata.builder()
+            .queryId(AsyncQueryId.newAsyncQueryId(DS_NAME))
+            .jobId(EMR_JOB_ID)
+            .applicationId(EMRS_APPLICATION_ID)
+            .resultIndex(MOCK_RESULT_INDEX)
+            .sessionId(MOCK_SESSION_ID)
+            .build();
 
     opensearchJobMetadataStorageService.storeJobMetadata(expected);
     Optional<AsyncQueryJobMetadata> actual =
@@ -69,7 +73,7 @@ public void testStoreJobMetadataWithResultExtraData() {
 
     assertTrue(actual.isPresent());
     assertEquals(expected, actual.get());
-    assertEquals("resultIndex", actual.get().getResultIndex());
+    assertEquals(MOCK_RESULT_INDEX, actual.get().getResultIndex());
     assertEquals(MOCK_SESSION_ID, actual.get().getSessionId());
   }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java
index 87cc765071..6c82188ee6 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/model/MockFlintSparkJob.java
@@ -10,7 +10,6 @@
 import static org.junit.jupiter.api.Assertions.assertTrue;
 
 import java.util.Optional;
-import org.opensearch.index.seqno.SequenceNumbers;
 import org.opensearch.sql.spark.flint.FlintIndexState;
 import org.opensearch.sql.spark.flint.FlintIndexStateModel;
 import org.opensearch.sql.spark.flint.FlintIndexStateModelService;
@@ -26,16 +25,15 @@ public MockFlintSparkJob(
     this.flintIndexStateModelService = flintIndexStateModelService;
     this.datasource = datasource;
     stateModel =
-        new FlintIndexStateModel(
-            FlintIndexState.EMPTY,
-            "mockAppId",
-            "mockJobId",
-            latestId,
-            datasource,
-            System.currentTimeMillis(),
-            "",
-            SequenceNumbers.UNASSIGNED_SEQ_NO,
-            SequenceNumbers.UNASSIGNED_PRIMARY_TERM);
+        FlintIndexStateModel.builder()
+            .indexState(FlintIndexState.EMPTY)
+            .applicationId("mockAppId")
+            .jobId("mockJobId")
+            .latestId(latestId)
+            .datasourceName(datasource)
+            .lastUpdateTime(System.currentTimeMillis())
+            .error("")
+            .build();
     stateModel = flintIndexStateModelService.createFlintIndexStateModel(stateModel);
   }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
index 19be7fd9fb..08aa0e4d0e 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java
@@ -1199,12 +1199,20 @@ private DispatchQueryRequest dispatchQueryRequestWithSessionId(String query, Str
   }
 
   private AsyncQueryJobMetadata asyncQueryJobMetadata() {
-    return new AsyncQueryJobMetadata(QUERY_ID, EMRS_APPLICATION_ID, EMR_JOB_ID, null);
+    return AsyncQueryJobMetadata.builder()
+        .queryId(QUERY_ID)
+        .applicationId(EMRS_APPLICATION_ID)
+        .jobId(EMR_JOB_ID)
+        .build();
   }
 
   private AsyncQueryJobMetadata asyncQueryJobMetadataWithSessionId(
       String statementId, String sessionId) {
-    return new AsyncQueryJobMetadata(
-        new AsyncQueryId(statementId), EMRS_APPLICATION_ID, EMR_JOB_ID, null, sessionId);
+    return AsyncQueryJobMetadata.builder()
+        .queryId(new AsyncQueryId(statementId))
+        .applicationId(EMRS_APPLICATION_ID)
+        .jobId(EMR_JOB_ID)
+        .sessionId(sessionId)
+        .build();
   }
 }
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
index 010c8b7c6a..e3f610000c 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java
@@ -158,10 +158,7 @@ public void cancelSuccessStatementFailed() {
     StatementModel model = st.getStatementModel();
     st.setStatementModel(
         StatementModel.copyWithState(
-            st.getStatementModel(),
-            StatementState.SUCCESS,
-            model.getSeqNo(),
-            model.getPrimaryTerm()));
+            st.getStatementModel(), StatementState.SUCCESS, model.getMetadata()));
 
     // cancel conflict
     IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel);
@@ -179,10 +176,7 @@ public void cancelFailedStatementFailed() {
     StatementModel model = st.getStatementModel();
     st.setStatementModel(
         StatementModel.copyWithState(
-            st.getStatementModel(),
-            StatementState.FAILED,
-            model.getSeqNo(),
-            model.getPrimaryTerm()));
+            st.getStatementModel(), StatementState.FAILED, model.getMetadata()));
 
     // cancel conflict
     IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel);
@@ -199,8 +193,7 @@ public void cancelCancelledStatementFailed() {
     // update to running state
     StatementModel model = st.getStatementModel();
     st.setStatementModel(
-        StatementModel.copyWithState(
-            st.getStatementModel(), CANCELLED, model.getSeqNo(), model.getPrimaryTerm()));
+        StatementModel.copyWithState(st.getStatementModel(), CANCELLED, model.getMetadata()));
 
     // cancel conflict
     IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel);
public void whenTypeDoNotMatch() { + assertThrows(RuntimeException.class, () -> model.getMetadataItem(METADATA_KEY, Long.class)); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java index d393c383c6..cf658ea017 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/AsyncQueryJobMetadataXContentSerializerTest.java @@ -28,17 +28,17 @@ class AsyncQueryJobMetadataXContentSerializerTest { @Test void toXContentShouldSerializeAsyncQueryJobMetadata() throws Exception { AsyncQueryJobMetadata jobMetadata = - new AsyncQueryJobMetadata( - new AsyncQueryId("query1"), - "app1", - "job1", - "result1", - "session1", - "datasource1", - JobType.INTERACTIVE, - "index1", - 1L, - 1L); + AsyncQueryJobMetadata.builder() + .queryId(new AsyncQueryId("query1")) + .applicationId("app1") + .jobId("job1") + .resultIndex("result1") + .sessionId("session1") + .datasourceName("datasource1") + .jobType(JobType.INTERACTIVE) + .indexName("index1") + .metadata(XContentSerializerUtil.buildMetadata(1L, 1L)) + .build(); XContentBuilder xContentBuilder = serializer.toXContent(jobMetadata, ToXContent.EMPTY_PARAMS); String json = xContentBuilder.toString(); @@ -56,23 +56,19 @@ void toXContentShouldSerializeAsyncQueryJobMetadata() throws Exception { @Test void fromXContentShouldDeserializeAsyncQueryJobMetadata() throws Exception { - String json = - "{\n" - + " \"queryId\": \"query1\",\n" - + " \"type\": \"jobmeta\",\n" - + " \"jobId\": \"job1\",\n" - + " \"applicationId\": \"app1\",\n" - + " \"resultIndex\": \"result1\",\n" - + " \"sessionId\": \"session1\",\n" - + " \"dataSourceName\": \"datasource1\",\n" - + " \"jobType\": \"interactive\",\n" - + " \"indexName\": \"index1\"\n" - + "}"; XContentParser parser = - XContentType.JSON - .xContent() - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json); - parser.nextToken(); + prepareParserForJson( + "{\n" + + " \"queryId\": \"query1\",\n" + + " \"type\": \"jobmeta\",\n" + + " \"jobId\": \"job1\",\n" + + " \"applicationId\": \"app1\",\n" + + " \"resultIndex\": \"result1\",\n" + + " \"sessionId\": \"session1\",\n" + + " \"dataSourceName\": \"datasource1\",\n" + + " \"jobType\": \"interactive\",\n" + + " \"indexName\": \"index1\"\n" + + "}"); AsyncQueryJobMetadata jobMetadata = serializer.fromXContent(parser, 1L, 1L); @@ -88,87 +84,61 @@ void fromXContentShouldDeserializeAsyncQueryJobMetadata() throws Exception { @Test void fromXContentShouldThrowExceptionWhenMissingRequiredFields() throws Exception { - String json = - "{\n" - + " \"queryId\": \"query1\",\n" - + " \"type\": \"asyncqueryjobmeta\",\n" - + " \"resultIndex\": \"result1\",\n" - + " \"sessionId\": \"session1\",\n" - + " \"dataSourceName\": \"datasource1\",\n" - + " \"jobType\": \"async_query\",\n" - + " \"indexName\": \"index1\"\n" - + "}"; XContentParser parser = - XContentType.JSON - .xContent() - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json); - parser.nextToken(); + prepareParserForJson( + "{\n" + + " \"queryId\": \"query1\",\n" + + " \"type\": \"asyncqueryjobmeta\",\n" + + " \"resultIndex\": \"result1\",\n" + + " \"sessionId\": \"session1\",\n" + + " \"dataSourceName\": \"datasource1\",\n" + + " \"jobType\": 
\"async_query\",\n" + + " \"indexName\": \"index1\"\n" + + "}"); assertThrows(IllegalArgumentException.class, () -> serializer.fromXContent(parser, 1L, 1L)); } @Test void fromXContentShouldDeserializeWithMissingApplicationId() throws Exception { - String json = - "{\n" - + " \"queryId\": \"query1\",\n" - + " \"type\": \"jobmeta\",\n" - + " \"jobId\": \"job1\",\n" - + " \"resultIndex\": \"result1\",\n" - + " \"sessionId\": \"session1\",\n" - + " \"dataSourceName\": \"datasource1\",\n" - + " \"jobType\": \"interactive\",\n" - + " \"indexName\": \"index1\"\n" - + "}"; XContentParser parser = - XContentType.JSON - .xContent() - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json); - parser.nextToken(); + prepareParserForJson( + "{\n" + + " \"queryId\": \"query1\",\n" + + " \"type\": \"jobmeta\",\n" + + " \"jobId\": \"job1\",\n" + + " \"resultIndex\": \"result1\",\n" + + " \"sessionId\": \"session1\",\n" + + " \"dataSourceName\": \"datasource1\",\n" + + " \"jobType\": \"interactive\",\n" + + " \"indexName\": \"index1\"\n" + + "}"); assertThrows(IllegalArgumentException.class, () -> serializer.fromXContent(parser, 1L, 1L)); } @Test void fromXContentShouldThrowExceptionWhenUnknownFields() throws Exception { - String json = - "{\n" - + " \"queryId\": \"query1\",\n" - + " \"type\": \"asyncqueryjobmeta\",\n" - + " \"resultIndex\": \"result1\",\n" - + " \"sessionId\": \"session1\",\n" - + " \"dataSourceName\": \"datasource1\",\n" - + " \"jobType\": \"async_query\",\n" - + " \"indexame\": \"index1\"\n" - + "}"; - XContentParser parser = - XContentType.JSON - .xContent() - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json); - parser.nextToken(); + XContentParser parser = prepareParserForJson("{\"unknownAttr\": \"index1\"}"); assertThrows(IllegalArgumentException.class, () -> serializer.fromXContent(parser, 1L, 1L)); } @Test void fromXContentShouldDeserializeAsyncQueryWithJobTypeNUll() throws Exception { - String json = - "{\n" - + " \"queryId\": \"query1\",\n" - + " \"type\": \"jobmeta\",\n" - + " \"jobId\": \"job1\",\n" - + " \"applicationId\": \"app1\",\n" - + " \"resultIndex\": \"result1\",\n" - + " \"sessionId\": \"session1\",\n" - + " \"dataSourceName\": \"datasource1\",\n" - + " \"jobType\": \"\",\n" - + " \"indexName\": \"index1\"\n" - + "}"; XContentParser parser = - XContentType.JSON - .xContent() - .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, json); - parser.nextToken(); + prepareParserForJson( + "{\n" + + " \"queryId\": \"query1\",\n" + + " \"type\": \"jobmeta\",\n" + + " \"jobId\": \"job1\",\n" + + " \"applicationId\": \"app1\",\n" + + " \"resultIndex\": \"result1\",\n" + + " \"sessionId\": \"session1\",\n" + + " \"dataSourceName\": \"datasource1\",\n" + + " \"jobType\": \"\",\n" + + " \"indexName\": \"index1\"\n" + + "}"); AsyncQueryJobMetadata jobMetadata = serializer.fromXContent(parser, 1L, 1L); @@ -181,4 +151,28 @@ void fromXContentShouldDeserializeAsyncQueryWithJobTypeNUll() throws Exception { assertNull(jobMetadata.getJobType()); assertEquals("index1", jobMetadata.getIndexName()); } + + @Test + void fromXContentShouldDeserializeAsyncQueryWithoutJobId() throws Exception { + XContentParser parser = + prepareParserForJson("{\"queryId\": \"query1\", \"applicationId\": \"app1\"}"); + + assertThrows(IllegalArgumentException.class, () -> serializer.fromXContent(parser, 1L, 1L)); + } + + @Test + void fromXContentShouldDeserializeAsyncQueryWithoutApplicationId() throws Exception { + 
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/IndexDMLResultXContentSerializerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/IndexDMLResultXContentSerializerTest.java
index de614235f6..edf88bad42 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/IndexDMLResultXContentSerializerTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/IndexDMLResultXContentSerializerTest.java
@@ -21,7 +21,14 @@ class IndexDMLResultXContentSerializerTest {
   @Test
   void toXContentShouldSerializeIndexDMLResult() throws IOException {
     IndexDMLResult dmlResult =
-        new IndexDMLResult("query1", "SUCCESS", null, "datasource1", 1000L, 2000L);
+        IndexDMLResult.builder()
+            .queryId("query1")
+            .status("SUCCESS")
+            .error(null)
+            .datasourceName("datasource1")
+            .queryRunTime(1000L)
+            .updateTime(2000L)
+            .build();

     XContentBuilder xContentBuilder = serializer.toXContent(dmlResult, ToXContent.EMPTY_PARAMS);
     String json = xContentBuilder.toString();
@@ -39,7 +46,14 @@ void toXContentShouldSerializeIndexDMLResult() throws IOException {
   @Test
   void toXContentShouldHandleErrorInIndexDMLResult() throws IOException {
     IndexDMLResult dmlResult =
-        new IndexDMLResult("query1", "FAILURE", "An error occurred", "datasource1", 1000L, 2000L);
+        IndexDMLResult.builder()
+            .queryId("query1")
+            .status("FAILURE")
+            .error("An error occurred")
+            .datasourceName("datasource1")
+            .queryRunTime(1000L)
+            .updateTime(2000L)
+            .build();

     XContentBuilder xContentBuilder = serializer.toXContent(dmlResult, ToXContent.EMPTY_PARAMS);
diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/XContentSerializerUtilTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/XContentSerializerUtilTest.java
new file mode 100644
index 0000000000..5bd8795663
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/execution/xcontent/XContentSerializerUtilTest.java
@@ -0,0 +1,17 @@
+package org.opensearch.sql.spark.execution.xcontent;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import com.google.common.collect.ImmutableMap;
+import org.junit.jupiter.api.Test;
+
+class XContentSerializerUtilTest {
+  @Test
+  public void testBuildMetadata() {
+    ImmutableMap<String, Object> result = XContentSerializerUtil.buildMetadata(1, 2);
+
+    assertEquals(2, result.size());
+    assertEquals(1L, result.get(XContentSerializerUtil.SEQ_NO));
+    assertEquals(2L, result.get(XContentSerializerUtil.PRIMARY_TERM));
+  }
+}
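XContentSerializerUtilTest fixes the observable behavior of buildMetadata: a two-entry immutable map, keyed by SEQ_NO and PRIMARY_TERM, holding boxed longs. A sketch consistent with those assertions (the constants' string values are assumptions; the test only references them by name):

    import com.google.common.collect.ImmutableMap;

    // Hedged reconstruction of XContentSerializerUtil from its test.
    final class XContentSerializerUtilSketch {
      static final String SEQ_NO = "seqNo";
      static final String PRIMARY_TERM = "primaryTerm";

      static ImmutableMap<String, Object> buildMetadata(long seqNo, long primaryTerm) {
        // Boxing to Long is what makes assertEquals(1L, result.get(SEQ_NO)) pass.
        return ImmutableMap.of(SEQ_NO, seqNo, PRIMARY_TERM, primaryTerm);
      }
    }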
diff --git a/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java b/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java
index 6c2a3a81a4..0c82733ae6 100644
--- a/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/flint/operation/FlintIndexOpTest.java
@@ -16,8 +16,8 @@
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
-import org.opensearch.index.seqno.SequenceNumbers;
 import org.opensearch.sql.spark.client.EMRServerlessClientFactory;
+import org.opensearch.sql.spark.execution.xcontent.XContentSerializerUtil;
 import org.opensearch.sql.spark.flint.FlintIndexMetadata;
 import org.opensearch.sql.spark.flint.FlintIndexState;
 import org.opensearch.sql.spark.flint.FlintIndexStateModel;
@@ -33,25 +33,17 @@ public class FlintIndexOpTest {
   public void testApplyWithTransitioningStateFailure() {
     FlintIndexMetadata metadata = mock(FlintIndexMetadata.class);
     when(metadata.getLatestId()).thenReturn(Optional.of("latestId"));
-    FlintIndexStateModel fakeModel =
-        new FlintIndexStateModel(
-            FlintIndexState.ACTIVE,
-            metadata.getAppId(),
-            metadata.getJobId(),
-            "latestId",
-            "myS3",
-            System.currentTimeMillis(),
-            "",
-            SequenceNumbers.UNASSIGNED_SEQ_NO,
-            SequenceNumbers.UNASSIGNED_PRIMARY_TERM);
+    FlintIndexStateModel fakeModel = getFlintIndexStateModel(metadata);
     when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any()))
         .thenReturn(Optional.of(fakeModel));
     when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any()))
         .thenThrow(new RuntimeException("Transitioning state failed"));
     FlintIndexOp flintIndexOp =
         new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory);
+
     IllegalStateException illegalStateException =
         Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata));
+
     Assertions.assertEquals(
         "Moving to transition state:DELETING failed.", illegalStateException.getMessage());
   }
@@ -60,27 +52,21 @@ public void testApplyWithTransitioningStateFailure() {
   public void testApplyWithCommitFailure() {
     FlintIndexMetadata metadata = mock(FlintIndexMetadata.class);
     when(metadata.getLatestId()).thenReturn(Optional.of("latestId"));
-    FlintIndexStateModel fakeModel =
-        new FlintIndexStateModel(
-            FlintIndexState.ACTIVE,
-            metadata.getAppId(),
-            metadata.getJobId(),
-            "latestId",
-            "myS3",
-            System.currentTimeMillis(),
-            "",
-            SequenceNumbers.UNASSIGNED_SEQ_NO,
-            SequenceNumbers.UNASSIGNED_PRIMARY_TERM);
+    FlintIndexStateModel fakeModel = getFlintIndexStateModel(metadata);
     when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any()))
         .thenReturn(Optional.of(fakeModel));
     when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any()))
-        .thenReturn(FlintIndexStateModel.copy(fakeModel, 1, 2))
+        .thenReturn(
+            FlintIndexStateModel.copy(fakeModel, XContentSerializerUtil.buildMetadata(1, 2)))
         .thenThrow(new RuntimeException("Commit state failed"))
-        .thenReturn(FlintIndexStateModel.copy(fakeModel, 1, 3));
+        .thenReturn(
+            FlintIndexStateModel.copy(fakeModel, XContentSerializerUtil.buildMetadata(1, 3)));
     FlintIndexOp flintIndexOp =
         new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory);
+
     IllegalStateException illegalStateException =
         Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata));
+
     Assertions.assertEquals(
         "commit failed. target stable state: [DELETED]", illegalStateException.getMessage());
   }
@@ -89,31 +75,36 @@ public void testApplyWithCommitFailure() {
   public void testApplyWithRollBackFailure() {
     FlintIndexMetadata metadata = mock(FlintIndexMetadata.class);
     when(metadata.getLatestId()).thenReturn(Optional.of("latestId"));
-    FlintIndexStateModel fakeModel =
-        new FlintIndexStateModel(
-            FlintIndexState.ACTIVE,
-            metadata.getAppId(),
-            metadata.getJobId(),
-            "latestId",
-            "myS3",
-            System.currentTimeMillis(),
-            "",
-            SequenceNumbers.UNASSIGNED_SEQ_NO,
-            SequenceNumbers.UNASSIGNED_PRIMARY_TERM);
+    FlintIndexStateModel fakeModel = getFlintIndexStateModel(metadata);
     when(flintIndexStateModelService.getFlintIndexStateModel(eq("latestId"), any()))
         .thenReturn(Optional.of(fakeModel));
     when(flintIndexStateModelService.updateFlintIndexState(any(), any(), any()))
-        .thenReturn(FlintIndexStateModel.copy(fakeModel, 1, 2))
+        .thenReturn(
+            FlintIndexStateModel.copy(fakeModel, XContentSerializerUtil.buildMetadata(1, 2)))
         .thenThrow(new RuntimeException("Commit state failed"))
         .thenThrow(new RuntimeException("Rollback failure"));
     FlintIndexOp flintIndexOp =
         new TestFlintIndexOp(flintIndexStateModelService, "myS3", mockEmrServerlessClientFactory);
+
     IllegalStateException illegalStateException =
         Assertions.assertThrows(IllegalStateException.class, () -> flintIndexOp.apply(metadata));
+
     Assertions.assertEquals(
         "commit failed. target stable state: [DELETED]", illegalStateException.getMessage());
   }

+  private FlintIndexStateModel getFlintIndexStateModel(FlintIndexMetadata metadata) {
+    return FlintIndexStateModel.builder()
+        .indexState(FlintIndexState.ACTIVE)
+        .applicationId(metadata.getAppId())
+        .jobId(metadata.getJobId())
+        .latestId("latestId")
+        .datasourceName("myS3")
+        .lastUpdateTime(System.currentTimeMillis())
+        .error("")
+        .build();
+  }
+
   static class TestFlintIndexOp extends FlintIndexOp {
     public TestFlintIndexOp(