Skip to content

Commit

Permalink
generate master key if not exists
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna committed Jun 13, 2024
1 parent 512b8da commit 8abc11a
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,32 @@
package org.opensearch.ml.engine.encryptor;

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD;
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;

import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Instant;
import java.util.Base64;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

import javax.crypto.spec.SecretKeySpec;

import com.google.common.collect.ImmutableMap;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.engine.VersionConflictEngineException;
import org.opensearch.ml.common.exception.MLException;

import com.amazonaws.encryptionsdk.AwsCrypto;
Expand All @@ -34,6 +41,7 @@

import lombok.extern.log4j.Log4j2;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.engine.indices.MLIndicesHandler;

@Log4j2
public class EncryptorImpl implements Encryptor {
Expand All @@ -43,11 +51,13 @@ public class EncryptorImpl implements Encryptor {
private ClusterService clusterService;
private Client client;
private volatile String masterKey;
private MLIndicesHandler mlIndicesHandler;

public EncryptorImpl(ClusterService clusterService, Client client) {
public EncryptorImpl(ClusterService clusterService, Client client, MLIndicesHandler mlIndicesHandler) {
this.masterKey = null;
this.clusterService = clusterService;
this.client = client;
this.mlIndicesHandler = mlIndicesHandler;
}

public EncryptorImpl(String masterKey) {
Expand All @@ -68,7 +78,7 @@ public String getMasterKey() {
public void encrypt(String plainText, ActionListener<String> listener) {
initMasterKey(new ActionListener<>() {
@Override
public void onResponse(Void unused) {
public void onResponse(Boolean isMasterKeyInitialized) {
try {
final AwsCrypto crypto = AwsCrypto.builder()
.withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt)
Expand Down Expand Up @@ -105,7 +115,7 @@ public void onFailure(Exception e) {
public void decrypt(String encryptedText, ActionListener<String> listener) {
initMasterKey(new ActionListener<>() {
@Override
public void onResponse(Void unused) {
public void onResponse(Boolean isMasterKeyInitialized) {
try {
final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build();

Expand Down Expand Up @@ -145,54 +155,61 @@ public String generateMasterKey() {
return base64Key;
}

private void initMasterKey(ActionListener<Void> listener) {
private void initMasterKey(ActionListener<Boolean> listener) {
if (masterKey != null) {
listener.onResponse(null);
return;
}
// AtomicReference<Exception> exceptionRef = new AtomicReference<>();
Boolean mlConfig = clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX);

// CountDownLatch latch = new CountDownLatch(1);
if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) {
mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> {
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
client.get(getRequest, ActionListener.wrap(response -> {
if (response.isExists()) {
String retrievedMasterKey = (String) response.getSourceAsMap().get(MASTER_KEY);
this.masterKey = retrievedMasterKey;
listener.onResponse(null);
client.get(getRequest, ActionListener.wrap(getResponse -> {
if (!getResponse.isExists()) {
IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
final String generatedMasterKey = generateMasterKey();
indexRequest.source(ImmutableMap.of(MASTER_KEY, generatedMasterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli()));
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
indexRequest.opType(DocWriteRequest.OpType.CREATE);
client.index(indexRequest, ActionListener.wrap(indexResponse -> {
log.info("ML encryption master key indexed successfully");
this.masterKey = generatedMasterKey;
log.info("ML encryption master key initialized successfully");
listener.onResponse(Boolean.TRUE);
}, e -> {
if (e instanceof VersionConflictEngineException) {
GetRequest getMasterKeyRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
client.get(getMasterKeyRequest, ActionListener.wrap(getMasterKey -> {
if (getMasterKey.isExists()) {
final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY);
this.masterKey = masterKey;
log.info("ML encryption master key already initialized, no action needed");
listener.onResponse(Boolean.TRUE);
}
}, error -> {
log.debug("Failed to get ML encryption master key", e);
listener.onFailure(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
}));
}
} else {
log.debug("Failed to save ML encryption master key", e);
listener.onFailure(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
}
}));
} else {
listener.onFailure(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY);
this.masterKey = masterKey;
log.info("ML encryption master key already initialized, no action needed");
listener.onResponse(Boolean.TRUE);
}
}, e -> {
log.error("Failed to get ML encryption master key", e);
listener.onFailure(e);
log.debug("Failed to get ML encryption master key from config index", e);
listener.onFailure(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
}));
} catch (Exception e) {
log.error("Failed to get encryption master key", e);
listener.onFailure(e);
}
} else {
listener.onFailure(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
}

// try {
// latch.await(5, SECONDS);
// } catch (InterruptedException e) {
// throw new IllegalStateException(e);
// }

// if (exceptionRef.get() != null) {
// log.debug("Failed to init master key", exceptionRef.get());
// if (exceptionRef.get() instanceof RuntimeException) {
// throw (RuntimeException) exceptionRef.get();
// } else {
// throw new MLException(exceptionRef.get());
// }
// }
// if (masterKey == null) {
// throw new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR);
// }
}, e -> {
log.debug("Failed to init ML config index", e);
listener.onFailure(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));}));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.concurrent.Semaphore;
import java.util.stream.Collectors;

import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.index.IndexRequest;
Expand Down Expand Up @@ -231,6 +232,7 @@ void initMLConfig() {
final String masterKey = encryptor.generateMasterKey();
indexRequest.source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli()));
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
indexRequest.opType(DocWriteRequest.OpType.CREATE);
client.index(indexRequest, ActionListener.wrap(indexResponse -> {
log.info("ML configuration initialized successfully");
// encryptor.setMasterKey(masterKey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,9 @@ public Collection<Object> createComponents(
Path dataPath = environment.dataFiles()[0];
Path configFile = environment.configFile();

encryptor = new EncryptorImpl(clusterService, client);
mlIndicesHandler = new MLIndicesHandler(clusterService, client);

encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler);

mlEngine = new MLEngine(dataPath, encryptor);
nodeHelper = new DiscoveryNodeHelper(clusterService, settings);
Expand Down Expand Up @@ -493,7 +495,6 @@ public Collection<Object> createComponents(
stats.put(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier()));
this.mlStats = new MLStats(stats);

mlIndicesHandler = new MLIndicesHandler(clusterService, client);
mlTaskManager = new MLTaskManager(client, threadPool, mlIndicesHandler);
modelHelper = new ModelHelper(mlEngine);

Expand Down

0 comments on commit 8abc11a

Please sign in to comment.