diff --git a/core/crypto/rsa_common.go b/core/crypto/rsa_common.go index c7e305439a..2b05eb6a35 100644 --- a/core/crypto/rsa_common.go +++ b/core/crypto/rsa_common.go @@ -12,9 +12,12 @@ const WeakRsaKeyEnv = "LIBP2P_ALLOW_WEAK_RSA_KEYS" var MinRsaKeyBits = 2048 +var maxRsaKeyBits = 8192 + // ErrRsaKeyTooSmall is returned when trying to generate or parse an RSA key // that's smaller than MinRsaKeyBits bits. In test var ErrRsaKeyTooSmall error +var ErrRsaKeyTooBig error = fmt.Errorf("rsa keys must be <= %d bits", maxRsaKeyBits) func init() { if _, ok := os.LookupEnv(WeakRsaKeyEnv); ok { diff --git a/core/crypto/rsa_go.go b/core/crypto/rsa_go.go index 7927d17d18..f15393094a 100644 --- a/core/crypto/rsa_go.go +++ b/core/crypto/rsa_go.go @@ -31,6 +31,9 @@ func GenerateRSAKeyPair(bits int, src io.Reader) (PrivKey, PubKey, error) { if bits < MinRsaKeyBits { return nil, nil, ErrRsaKeyTooSmall } + if bits > maxRsaKeyBits { + return nil, nil, ErrRsaKeyTooBig + } priv, err := rsa.GenerateKey(src, bits) if err != nil { return nil, nil, err @@ -124,6 +127,9 @@ func UnmarshalRsaPrivateKey(b []byte) (key PrivKey, err error) { if sk.N.BitLen() < MinRsaKeyBits { return nil, ErrRsaKeyTooSmall } + if sk.N.BitLen() > maxRsaKeyBits { + return nil, ErrRsaKeyTooBig + } return &RsaPrivateKey{sk: *sk}, nil } @@ -141,6 +147,9 @@ func UnmarshalRsaPublicKey(b []byte) (key PubKey, err error) { if pk.N.BitLen() < MinRsaKeyBits { return nil, ErrRsaKeyTooSmall } + if pk.N.BitLen() > maxRsaKeyBits { + return nil, ErrRsaKeyTooBig + } return &RsaPublicKey{k: *pk}, nil } diff --git a/core/crypto/rsa_test.go b/core/crypto/rsa_test.go index 69151b86c9..f4a7971a59 100644 --- a/core/crypto/rsa_test.go +++ b/core/crypto/rsa_test.go @@ -68,6 +68,44 @@ func TestRSASmallKey(t *testing.T) { } } +func TestRSABigKeyFailsToGenerate(t *testing.T) { + _, _, err := GenerateRSAKeyPair(maxRsaKeyBits*2, rand.Reader) + if err != ErrRsaKeyTooBig { + t.Fatal("should have refused to create too big RSA key") + } +} + +func TestRSABigKey(t *testing.T) { + // Make the global limit smaller for this test to run faster. + // Note we also change the limit below, but this is different + origSize := maxRsaKeyBits + maxRsaKeyBits = 2048 + defer func() { maxRsaKeyBits = origSize }() // + + maxRsaKeyBits *= 2 + badPriv, badPub, err := GenerateRSAKeyPair(maxRsaKeyBits, rand.Reader) + if err != nil { + t.Fatalf("should have succeeded, got: %s", err) + } + pubBytes, err := MarshalPublicKey(badPub) + if err != nil { + t.Fatal(err) + } + privBytes, err := MarshalPrivateKey(badPriv) + if err != nil { + t.Fatal(err) + } + maxRsaKeyBits /= 2 + _, err = UnmarshalPublicKey(pubBytes) + if err != ErrRsaKeyTooBig { + t.Fatal("should have refused to unmarshal a too big key") + } + _, err = UnmarshalPrivateKey(privBytes) + if err != ErrRsaKeyTooBig { + t.Fatal("should have refused to unmarshal a too big key") + } +} + func TestRSASignZero(t *testing.T) { priv, pub, err := GenerateRSAKeyPair(2048, rand.Reader) if err != nil {