Skip to content

Commit

Permalink
Add support for encrypted async blob read
Browse files Browse the repository at this point in the history
Signed-off-by: Kunal Kotwani <[email protected]>
  • Loading branch information
kotwanikunal committed Sep 12, 2023
1 parent b90a888 commit 5c5b8c0
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import org.opensearch.common.blobstore.stream.read.ReadContext;
import org.opensearch.common.blobstore.stream.write.WriteContext;
import org.opensearch.common.crypto.CryptoHandler;
import org.opensearch.common.crypto.DecryptedRangedStreamProvider;
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.core.action.ActionListener;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.nio.file.Path;
import java.io.InputStream;
import java.util.List;
import java.util.stream.Collectors;

/**
* EncryptedBlobContainer is an encrypted BlobContainer that is backed by a
Expand All @@ -44,12 +46,8 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp

@Override
public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) {
throw new UnsupportedOperationException();
}

@Override
public void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
throw new UnsupportedOperationException();
DecryptingActionListener decryptingActionListener = new DecryptingActionListener(listener, cryptoHandler);
blobContainer.readBlobAsync(blobName, decryptingActionListener);
}

@Override
Expand Down Expand Up @@ -108,4 +106,69 @@ public InputStreamContainer provideStream(int partNumber) throws IOException {
}

}

static class DecryptedReadContext<T, U> extends ReadContext {

private final U cryptoContext;
private final CryptoHandler<T, U> cryptoHandler;
private final long fileSize;

public DecryptedReadContext(ReadContext readContext, CryptoHandler<T, U> cryptoHandler) {
super(readContext);
this.cryptoHandler = cryptoHandler;
try {
this.cryptoContext = this.cryptoHandler.loadEncryptionMetadata(null);
} catch (IOException e) {
throw new RuntimeException(e);
}
this.fileSize = this.cryptoHandler.estimateDecryptedLength(cryptoContext, readContext.getBlobSize());
}

@Override
public long getBlobSize() {
return fileSize;
}

@Override
public List<InputStreamContainer> getPartStreams() {
return super.getPartStreams().stream().map(this::decrpytInputStreamContainer).collect(Collectors.toList());
}

private InputStreamContainer decrpytInputStreamContainer(InputStreamContainer inputStreamContainer) {
long startOfStream = inputStreamContainer.getOffset();
long endOfStream = startOfStream + inputStreamContainer.getContentLength();
DecryptedRangedStreamProvider rangedStreamProvider = cryptoHandler.createDecryptingStreamOfRange(
cryptoContext,
startOfStream,
endOfStream
);

InputStream decryptedStream = cryptoHandler.createDecryptingStream(inputStreamContainer.getInputStream());
long offset = rangedStreamProvider.getAdjustedRange()[0];
long contentLength = rangedStreamProvider.getAdjustedRange()[1];
return new InputStreamContainer(decryptedStream, contentLength, offset);
}
}

static class DecryptingActionListener<T, U> implements ActionListener<ReadContext> {

private final ActionListener<ReadContext> completionListener;
private final CryptoHandler<T, U> cryptoHandler;

public DecryptingActionListener(ActionListener<ReadContext> completionListener, CryptoHandler<T, U> cryptoHandler) {
this.completionListener = completionListener;
this.cryptoHandler = cryptoHandler;
}

@Override
public void onResponse(ReadContext readContext) {
DecryptedReadContext decryptedReadContext = new DecryptedReadContext(readContext, cryptoHandler);
completionListener.onResponse(decryptedReadContext);
}

@Override
public void onFailure(Exception e) {
completionListener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ public ReadContext(long blobSize, List<InputStreamContainer> partStreams, String
this.blobChecksum = blobChecksum;
}

public ReadContext(ReadContext readContext) {
this.blobSize = readContext.blobSize;
this.partStreams = readContext.partStreams;
this.blobChecksum = readContext.blobChecksum;
}

public String getBlobChecksum() {
return blobChecksum;
}
Expand Down

0 comments on commit 5c5b8c0

Please sign in to comment.