From 7cd8c1fabd15b28f91c02dea4aaa9dd80c5d489e Mon Sep 17 00:00:00 2001 From: Edgar Asatryan Date: Sat, 9 Apr 2022 01:32:14 +0400 Subject: [PATCH] chore: Optimize encrypted file header length reading --- .../http/ext/EncryptedStreamFactory.java | 17 +++++++++- .../http/ext/EncryptedStreamFactoryTest.kt | 33 ++++++++++++++++--- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/main/java/io/github/nstdio/http/ext/EncryptedStreamFactory.java b/src/main/java/io/github/nstdio/http/ext/EncryptedStreamFactory.java index 932e8e6..01af92b 100644 --- a/src/main/java/io/github/nstdio/http/ext/EncryptedStreamFactory.java +++ b/src/main/java/io/github/nstdio/http/ext/EncryptedStreamFactory.java @@ -22,6 +22,7 @@ import javax.crypto.NoSuchPaddingException; import java.io.DataInputStream; import java.io.DataOutputStream; +import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -155,7 +156,7 @@ private void writeParams(Cipher cipher, OutputStream out) throws IOException { private AlgorithmParameters readParams(InputStream is, Cipher c) throws Exception { if (c.getParameters() != null) { try { - int len = new DataInputStream(is).readInt(); + int len = readInt(is); byte[] encodedParams = is.readNBytes(len); var parameters = algorithmParameters(); parameters.init(encodedParams); @@ -170,6 +171,20 @@ private AlgorithmParameters readParams(InputStream is, Cipher c) throws Exceptio return null; } + /** + * It's not worth to instantiate {@link DataInputStream}. + */ + private int readInt(InputStream is) throws IOException { + int ch1 = is.read(); + int ch2 = is.read(); + int ch3 = is.read(); + int ch4 = is.read(); + if ((ch1 | ch2 | ch3 | ch4) < 0) + throw new EOFException(); + + return ((ch1 << 24) + (ch2 << 16) + (ch3 << 8) + (ch4)); + } + private IOException asIOException(Exception e) { if (e instanceof IOException) return ((IOException) e); diff --git a/src/test/kotlin/io/github/nstdio/http/ext/EncryptedStreamFactoryTest.kt b/src/test/kotlin/io/github/nstdio/http/ext/EncryptedStreamFactoryTest.kt index 55c8f52..3b532fa 100644 --- a/src/test/kotlin/io/github/nstdio/http/ext/EncryptedStreamFactoryTest.kt +++ b/src/test/kotlin/io/github/nstdio/http/ext/EncryptedStreamFactoryTest.kt @@ -17,6 +17,7 @@ package io.github.nstdio.http.ext import io.github.nstdio.http.ext.IOUtils.bufferedWriter +import io.kotest.assertions.assertSoftly import io.kotest.assertions.throwables.shouldThrowExactly import io.kotest.matchers.shouldBe import io.kotest.matchers.shouldNotBe @@ -34,6 +35,7 @@ import org.mockito.BDDMockito.times import org.mockito.BDDMockito.verify import org.mockito.Mockito.any import org.mockito.Mockito.mock +import java.io.EOFException import java.io.IOException import java.io.InputStream import java.io.OutputStream @@ -42,6 +44,7 @@ import java.nio.file.Path import java.security.InvalidKeyException import java.security.Key import java.security.NoSuchProviderException +import java.security.ProviderException import kotlin.io.path.readText class EncryptedStreamFactoryTest { @@ -115,16 +118,36 @@ class EncryptedStreamFactoryTest { @Test fun `Should throw when key is invalid`() { + //given + val delegate = mock(StreamFactory::class.java).also { + given(it.input(any())).willReturn(InputStream.nullInputStream()) + } + + val factory = Crypto.rsaKeyPair().let { + EncryptedStreamFactory(delegate, it.public, it.private, "AES/CBC/PKCS5Padding", null) + } + + //when + then + assertSoftly { + shouldThrowExactly { factory.output(Path.of("any")) } + .shouldHaveCauseInstanceOf() + + shouldThrowExactly { factory.input(Path.of("any")) } + .shouldHaveCauseInstanceOf() + } + } + + @Test + fun `Should throw eof when cannot read header length`() { //given val delegate = mock(StreamFactory::class.java) - val key = Crypto.rsaKeyPair() - val transform = "AES/CBC/PKCS5Padding" + given(delegate.input(any())).willReturn(InputStream.nullInputStream()) - val factory = EncryptedStreamFactory(delegate, key.public, key.private, transform, null) + val key = Crypto.pbe() + val factory = EncryptedStreamFactory(delegate, key, key, "AES/CBC/PKCS5Padding", null) //when + then - shouldThrowExactly { factory.output(Path.of("any")) } - .shouldHaveCauseInstanceOf() + shouldThrowExactly { factory.input(Path.of("any")) } } private fun InputStream.readText(): String {