Skip to content

Commit

Permalink
Merge pull request #69 from nstdio/opt
Browse files Browse the repository at this point in the history
chore: Optimize encrypted file header length reading
  • Loading branch information
nstdio authored Apr 8, 2022
2 parents 7c3b024 + 7cd8c1f commit 65e2707
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import javax.crypto.NoSuchPaddingException;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
Expand Down Expand Up @@ -155,7 +156,7 @@ private void writeParams(Cipher cipher, OutputStream out) throws IOException {
private AlgorithmParameters readParams(InputStream is, Cipher c) throws Exception {
if (c.getParameters() != null) {
try {
int len = new DataInputStream(is).readInt();
int len = readInt(is);
byte[] encodedParams = is.readNBytes(len);
var parameters = algorithmParameters();
parameters.init(encodedParams);
Expand All @@ -170,6 +171,20 @@ private AlgorithmParameters readParams(InputStream is, Cipher c) throws Exceptio
return null;
}

/**
* It's not worth to instantiate {@link DataInputStream}.
*/
private int readInt(InputStream is) throws IOException {
int ch1 = is.read();
int ch2 = is.read();
int ch3 = is.read();
int ch4 = is.read();
if ((ch1 | ch2 | ch3 | ch4) < 0)
throw new EOFException();

return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4));
}

private IOException asIOException(Exception e) {
if (e instanceof IOException)
return ((IOException) e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.github.nstdio.http.ext

import io.github.nstdio.http.ext.IOUtils.bufferedWriter
import io.kotest.assertions.assertSoftly
import io.kotest.assertions.throwables.shouldThrowExactly
import io.kotest.matchers.shouldBe
import io.kotest.matchers.shouldNotBe
Expand All @@ -34,6 +35,7 @@ import org.mockito.BDDMockito.times
import org.mockito.BDDMockito.verify
import org.mockito.Mockito.any
import org.mockito.Mockito.mock
import java.io.EOFException
import java.io.IOException
import java.io.InputStream
import java.io.OutputStream
Expand All @@ -42,6 +44,7 @@ import java.nio.file.Path
import java.security.InvalidKeyException
import java.security.Key
import java.security.NoSuchProviderException
import java.security.ProviderException
import kotlin.io.path.readText

class EncryptedStreamFactoryTest {
Expand Down Expand Up @@ -115,16 +118,36 @@ class EncryptedStreamFactoryTest {

@Test
fun `Should throw when key is invalid`() {
//given
val delegate = mock(StreamFactory::class.java).also {
given(it.input(any())).willReturn(InputStream.nullInputStream())
}

val factory = Crypto.rsaKeyPair().let {
EncryptedStreamFactory(delegate, it.public, it.private, "AES/CBC/PKCS5Padding", null)
}

//when + then
assertSoftly {
shouldThrowExactly<IOException> { factory.output(Path.of("any")) }
.shouldHaveCauseInstanceOf<InvalidKeyException>()

shouldThrowExactly<IOException> { factory.input(Path.of("any")) }
.shouldHaveCauseInstanceOf<ProviderException>()
}
}

@Test
fun `Should throw eof when cannot read header length`() {
//given
val delegate = mock(StreamFactory::class.java)
val key = Crypto.rsaKeyPair()
val transform = "AES/CBC/PKCS5Padding"
given(delegate.input(any())).willReturn(InputStream.nullInputStream())

val factory = EncryptedStreamFactory(delegate, key.public, key.private, transform, null)
val key = Crypto.pbe()
val factory = EncryptedStreamFactory(delegate, key, key, "AES/CBC/PKCS5Padding", null)

//when + then
shouldThrowExactly<IOException> { factory.output(Path.of("any")) }
.shouldHaveCauseInstanceOf<InvalidKeyException>()
shouldThrowExactly<EOFException> { factory.input(Path.of("any")) }
}

private fun InputStream.readText(): String {
Expand Down

0 comments on commit 65e2707

Please sign in to comment.