From 8f572b8652ae48cb5669d4f56e57673aed211033 Mon Sep 17 00:00:00 2001 From: Georgy Moiseev Date: Tue, 1 Aug 2023 11:28:28 +0300 Subject: [PATCH] key: extract error on read fail Before this patch, read fail had generated a Go error to return it to a user, even though read errors (for example, caused by bad input or wrong password) were placed in error queue. Since goroutines are not binded to a thread and error queue is per thread, sometimes it had resulted in valid connections failing with error. This patch fixes the issue. --- key.go | 22 +++-- ssl_test.go | 245 +++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 260 insertions(+), 7 deletions(-) diff --git a/key.go b/key.go index 1d2abcb1..de80bfb8 100644 --- a/key.go +++ b/key.go @@ -19,6 +19,7 @@ import "C" import ( "errors" + "fmt" "io/ioutil" "runtime" "unsafe" @@ -302,9 +303,11 @@ func LoadPrivateKeyFromPEM(pem_block []byte) (PrivateKey, error) { } defer C.BIO_free(bio) + runtime.LockOSThread() + defer runtime.UnlockOSThread() key := C.PEM_read_bio_PrivateKey(bio, nil, nil, nil) if key == nil { - return nil, errors.New("failed reading private key") + return nil, fmt.Errorf("failed reading private key: %w", errorFromErrorQueue()) } p := &pKey{key: key} @@ -328,9 +331,12 @@ func LoadPrivateKeyFromPEMWithPassword(pem_block []byte, password string) ( defer C.BIO_free(bio) cs := C.CString(password) defer C.free(unsafe.Pointer(cs)) + + runtime.LockOSThread() + defer runtime.UnlockOSThread() key := C.PEM_read_bio_PrivateKey(bio, nil, nil, unsafe.Pointer(cs)) if key == nil { - return nil, errors.New("failed reading private key") + return nil, fmt.Errorf("failed reading private key: %w", errorFromErrorQueue()) } p := &pKey{key: key} @@ -352,9 +358,11 @@ func LoadPrivateKeyFromDER(der_block []byte) (PrivateKey, error) { } defer C.BIO_free(bio) + runtime.LockOSThread() + defer runtime.UnlockOSThread() key := C.d2i_PrivateKey_bio(bio, nil) if key == nil { - return nil, errors.New("failed reading private key der") + return nil, fmt.Errorf("failed reading private key der: %w", errorFromErrorQueue()) } p := &pKey{key: key} @@ -383,9 +391,11 @@ func LoadPublicKeyFromPEM(pem_block []byte) (PublicKey, error) { } defer C.BIO_free(bio) + runtime.LockOSThread() + defer runtime.UnlockOSThread() key := C.PEM_read_bio_PUBKEY(bio, nil, nil, nil) if key == nil { - return nil, errors.New("failed reading public key") + return nil, fmt.Errorf("failed reading public key: %w", errorFromErrorQueue()) } p := &pKey{key: key} @@ -407,9 +417,11 @@ func LoadPublicKeyFromDER(der_block []byte) (PublicKey, error) { } defer C.BIO_free(bio) + runtime.LockOSThread() + defer runtime.UnlockOSThread() key := C.d2i_PUBKEY_bio(bio, nil) if key == nil { - return nil, errors.New("failed reading public key der") + return nil, fmt.Errorf("failed reading public key der: %w", errorFromErrorQueue()) } p := &pKey{key: key} diff --git a/ssl_test.go b/ssl_test.go index 66b733a6..42e1aa7f 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -18,6 +18,8 @@ import ( "bytes" "crypto/rand" "crypto/tls" + "encoding/base64" + "fmt" "io" "io/ioutil" "net" @@ -79,6 +81,77 @@ ucCCa4lOGgPtXJ0Qf1c8yq5vh4yqkQjrgUTkr+CFDGR6y4CxmNDQxEMYIajaIiSY qmgvgyRayemfO2zR0CPgC6wSoGBth+xW6g+WA8y0z76ZSaWpFi8lVM4= -----END RSA PRIVATE KEY----- `) + keyDERBase64 = `MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDdf3icNvFsrlrnNLi8SocscqlS +bFq+pEvmhcSoqgDLqebnqu8Ld73HJJ74MGXEgRX8xZT5FinOML31CR6t9E/j3dqV6p+GfdlFLe3I +qtC0/bPVnCDBirBygBI4uCrMq+1VhAxPWclrDo7l9QRYbsExH9lfn+RyvxeNMZiOASasvVZNncY8 +E9usBGRdH17EfDL/TPwXqWOLyxSN5o54GTztjjy9w9CGQP7jcCueKYyQJQCtEmnwc6P/q6/EPv5R +6drBkX6loAPtmCUAkHqxkWOJrRq/v7PwzRYhfY+ZpVHGc7WEkDnLzRiUypr1C9oxvLKS10etZEIw +EdKyOkSg2fdPAgMBAAECggEBALCCnpjOaAIVx7csGnNiaOoQzcIzOvVldF7WBuvp3gxu7uV7IFfh +KkkCc/SQjOjVfbIbuiXtdY8s5JPamqpBYVDTQRfrCwlgTL6GZVFeXkd9TcxSSQAzB32Xde3hRaoo +8E8Plce+Y3Z++X1jjfzy9d2x5cYAY0rV4WzRMyMm460PEPKuYlyfGknvL/vvL8dvmtRr0MK+Py2Q +oeAeZ4jBATQeEgWh+P6oJhup5XKNs8BphkyshJdvYeZzJQsoQ4ZFswfbgjMicoDQ3FWR0/2qZXxj +ZsnrT0V+3/LxWmPfWIR2r84VfKht1W+dcydpXIB62hpbLsA5MhDYKfDAZsn8XwECgYEA3+kKctCt +ZER5Uy4WN/4yBiAK1FFpDwQ5W+NhxdeLQ+G04wVMBvt2OqGWCYE4WyZ5dxvxhFmwX6k227XGTlEk +6q7Ip8ryPWtR0Al+jEYeJpQUjb/WL0TO5df8CjTE1GRYzEwq7BgGaa3V8y5Wj7UzxQTAUx0dG2DF +77/545XUnY8CgYEA/T3sYs5hJgAjU47cE39sRU1q7q0DE/lhKEaZfZDHLT6ijEdnVXcPzKjoLSFK +D4f589OaMyAZy9pbor5DIJ0cOOBuTaF91n4F6r3YkCk+sf2AxXsw+HAfODCPMMOKj04XFhje+MFD +bJf72O/P2WS+rY0zd+oWpuwboS37mYJzqkECgYEAjrZvFW0SBuVp2u119exLoAHORTM6XfrYQEv2 +Jm5Sckqqy0O2CIFAAvC4u4gkDlzAcH1b+3pa4y3sLC94nLQ1bmtGs0O0EBeWBp32jZunXfll/E74 +Shp2MKLwHuUxSxpGSriFZwONGtBUnHG9dE0PGRUFLDRTN/7/Sec3c6os4NsCgYAIJj7+KwALVgPN +A5LneblFPamMRrsLoIHU5vi3hroyJYrbksyrfmpevqzCDwkwGMMdapjSvly2J6+9O/wzB3tKBUbn +bqP7DBEqrbNTaFBhL/Q95qn7xLfsefuRqSlDVVL+3gwG20lNLFLpd0YsC8brFNksKbdS5dQ5yp4H +IaCRQQKBgQCainjmRGL3viDgw+/IumVwTizJvX2J+sePvqA9EGL3gKnyNmixjVO5wIJriU4aA+1c +nRB/VzzKrm+HjKqRCOuBROSv4IUMZHrLgLGY0NDEQxghqNoiJJiqaC+DJFrJ6Z87bNHQI+ALrBKg +YG2H7FbqD5YDzLTPvplJpakWLyVUzg== +` + keyEncryptedBytes = []byte(`-----BEGIN ENCRYPTED PRIVATE KEY----- +MIIFLTBXBgkqhkiG9w0BBQ0wSjApBgkqhkiG9w0BBQwwHAQIEyjG5ZrEc7ACAggA +MAwGCCqGSIb3DQIJBQAwHQYJYIZIAWUDBAEqBBDOuZqPxXCetvSxecEZgZyFBIIE +0PQTvLPEM6mh/yqURhnfqg/sQvKnm9AoaVfeucK9E25wpuAr24mR3/QBmL+cIGQx +oNPohmf0MU8CHBgzg4dNL6cRFohHdzrEemV02hk3NRv4z+UhQQelT/ZwF3PbcegI +Zbj6POzZjoK5NXDDuEqxG2SN59+oEEmF/fJkPuK0iqgVsAEvOFYFb117/IbRMjNJ +vse4ZYmNEfknww4OKPL8D8gBYbtPbsKaQTVcoQuJMiaUCypU5uBucYFIjzY2otkM +mVNL4YaS0YdZpzp6JfNLQF80IILRtW+JpYzBALQJH5pTjXH6mA1/RDJwCGwwZrMu +18UtNre2bMPDdCL+GX8uPG0HuTpBjojELaVZz1aimJN9vad/Q+X6QxiXYRJseTML +IO8nHuEu36HAg7OzOU3umCGdlQ7z3GJ9eP6npE1p44h89zbHOMYcGp/doG2f1fO5 +2lAqpfG/fAtefW7yUmSrGXVe0g8L62qoGyv4DJdSPaa0Nc+N/FeKc8e/V+kWOoDm +LY0XIy8TATuiqS9NwaKFSGC/kUoDt0UTPqUGeAjObfabiLOOCsuUJmohF+BxxpO/ +xNIcylDUuYDbDFVNSWeDToloVH8i5RZeLy2vskLM4uHrOraaRH9HUnqMQ9jQ7SXh +1/lCmDJgStrjkYL9IhVzXfrmtOZqASwwUiiFiQoJLsnN3ic/6PHx8gu890wpL0Se +jUgLxX21m42tZ1ismGcmzL8U00RAEth+fO+0dLQx1c6yfsSywlb2Fb7kuMW3HU15 +tbpA7AfZviqarXLcECFsbzOMt/pfUbMUG3OOJ6q/4gMiAEPi+TIrECFCkjHP0Tgw +aeCC2I3yfboaSNeI6dH422JJwPvfRc2I3MHOHlpXnRCgF3btDKW8vw96b7X/P5uV +9/KpXirP/O3JYWYg/co1KaT6LCtuCfUf8Z9gZYbcwn6Kxh9g5LQPysxMVQVx9R2H +ktjWUWwNUVOPA4GtbiNbQXjAgyyTPxv2wJSJav0yJrJUkqkvz1nrnIyocGk+xiJ8 +BAUl/GOGeiS5gskxumJzG6iIv8LRTFKQ85Lp5oD9EwAbxloASjDVwzMSVCqcZQfb +q4VIpbcBUpvyH6tchxQUujmI2ZQ/54C9u100Z4gAVYsLuRaDKcJ1c3kuffLIz/fI +Cfa/kzt+o9YyeCxz2w0aVHsbOk+0P1dL8usBc5b1MB7RaH7So8/7j8mluzyiUNn9 +64VUiSNWEEVlDN40Ar+BBBdRUUJJkWgpPuLQpCj6dSPxevAeHjbytjByUSPt+v0d +oGW4XCDw/72IBa/S5kSgE+n5FV3lrxq4DAgEeVVmZkO35cis7mqcgSgxwBzDclku +UZ56N5FOw5WwELz7+zC1fvdJGpKAwr9uOu6mKArvIshCGhRLeOSjBe3biERcrCJt +WrJra8Zt/E7fOFi0KedYQtdu/7oV31NrTFj5+eL+j/D6+tKGA1+goDDLt5xPFbu+ +l8yjbswmGrOTCHMrd4SJqmGUMWz1dWMdjIeQrwGok25mIE9BtUGzyQN84oWMe75Z +eD+ArpnO4zJcny47LG9PwtyBVn3GDinB/RGi2qcJPYg7xmV7PDNlFuvExctBEQJD +7onbx6HP5kKOZHTZvkm2viGuZbG2Pgz2kk32CxsxWTUL +-----END ENCRYPTED PRIVATE KEY----- +`) + keyEncryptedPassword = "mysslpassword" + keyPublicBytes = []byte(`-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3X94nDbxbK5a5zS4vEqH +LHKpUmxavqRL5oXEqKoAy6nm56rvC3e9xySe+DBlxIEV/MWU+RYpzjC99QkerfRP +493aleqfhn3ZRS3tyKrQtP2z1ZwgwYqwcoASOLgqzKvtVYQMT1nJaw6O5fUEWG7B +MR/ZX5/kcr8XjTGYjgEmrL1WTZ3GPBPbrARkXR9exHwy/0z8F6lji8sUjeaOeBk8 +7Y48vcPQhkD+43ArnimMkCUArRJp8HOj/6uvxD7+UenawZF+paAD7ZglAJB6sZFj +ia0av7+z8M0WIX2PmaVRxnO1hJA5y80YlMqa9QvaMbyyktdHrWRCMBHSsjpEoNn3 +TwIDAQAB +-----END PUBLIC KEY----- +`) + keyPublicDERBase64 = `MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA3X94nDbxbK5a5zS4vEqHLHKpUmxavqRL +5oXEqKoAy6nm56rvC3e9xySe+DBlxIEV/MWU+RYpzjC99QkerfRP493aleqfhn3ZRS3tyKrQtP2z +1ZwgwYqwcoASOLgqzKvtVYQMT1nJaw6O5fUEWG7BMR/ZX5/kcr8XjTGYjgEmrL1WTZ3GPBPbrARk +XR9exHwy/0z8F6lji8sUjeaOeBk87Y48vcPQhkD+43ArnimMkCUArRJp8HOj/6uvxD7+UenawZF+ +paAD7ZglAJB6sZFjia0av7+z8M0WIX2PmaVRxnO1hJA5y80YlMqa9QvaMbyyktdHrWRCMBHSsjpE +oNn3TwIDAQAB` prime256v1KeyBytes = []byte(`-----BEGIN EC PRIVATE KEY----- MHcCAQEEIB/XL0zZSsAu+IQF1AI/nRneabb2S126WFlvvhzmYr1KoAoGCCqGSM49 AwEHoUQDQgAESSFGWwF6W1hoatKGPPorh4+ipyk0FqpiWdiH+4jIiU39qtOeZGSh @@ -665,7 +738,7 @@ func TestStdlibLotsOfConns(t *testing.T) { }) } -func TestOpenSSLLotsOfConns(t *testing.T) { +func getCtx(t *testing.T) *Ctx { ctx, err := NewCtx() if err != nil { t.Fatal(err) @@ -684,7 +757,12 @@ func TestOpenSSLLotsOfConns(t *testing.T) { if err = ctx.UseCertificate(cert); err != nil { t.Fatal(err) } - if err = ctx.SetCipherList("AES128-SHA"); err != nil { + return ctx +} + +func TestOpenSSLLotsOfConns(t *testing.T) { + ctx := getCtx(t) + if err := ctx.SetCipherList("AES128-SHA"); err != nil { t.Fatal(err) } LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, @@ -694,3 +772,166 @@ func TestOpenSSLLotsOfConns(t *testing.T) { return Client(c, ctx) }) } + +func getCtxWithPrivateKeyAfterFail(t *testing.T, + getPrivateKeyAfterFail func(t *testing.T) PrivateKey) *Ctx { + ctx, err := NewCtx() + if err != nil { + t.Fatal(err) + } + + key := getPrivateKeyAfterFail(t) + + if err = ctx.UsePrivateKey(key); err != nil { + t.Fatal(err) + } + + cert, err := LoadCertificateFromPEM(certBytes) + if err != nil { + t.Fatal(err) + } + + if err = ctx.UseCertificate(cert); err != nil { + t.Fatal(err) + } + + return ctx + +} + +func getPrivatePEMKeyAfterFail(t *testing.T) PrivateKey { + key, err := LoadPrivateKeyFromPEM([]byte("badbadkey")) + if err == nil { + t.Fatal("Expected error, got none") + } + + key, err = LoadPrivateKeyFromPEM(keyBytes) + if err != nil { + t.Fatal(err) + } + + return key +} + +func getPrivateEncryptedPEMKeyAfterFail(t *testing.T) PrivateKey { + badPassword := fmt.Sprintf("wrong_%s", keyEncryptedPassword) + key, err := LoadPrivateKeyFromPEMWithPassword(keyEncryptedBytes, badPassword) + if err == nil { + t.Fatal("Expected error, got none") + } + + key, err = LoadPrivateKeyFromPEMWithPassword(keyEncryptedBytes, keyEncryptedPassword) + if err != nil { + t.Fatal(err) + } + + return key +} + +func getPrivateDERKeyAfterFail(t *testing.T) PrivateKey { + keyDERBytes, err := base64.StdEncoding.DecodeString(keyDERBase64) + if err != nil { + t.Fatal(err) + } + + key, err := LoadPrivateKeyFromDER([]byte("badbadkey")) + if err == nil { + t.Fatal("Expected error, got none") + } + + key, err = LoadPrivateKeyFromDER(keyDERBytes) + if err != nil { + t.Fatal(err) + } + + return key +} + +func getCtxWithPublicKeyAfterFail(t *testing.T, + getPublicKeyAfterFail func(t *testing.T) PublicKey) *Ctx { + ctx, err := NewCtx() + if err != nil { + t.Fatal(err) + } + + cert, err := LoadCertificateFromPEM(certBytes) + if err != nil { + t.Fatal(err) + } + + key := getPublicKeyAfterFail(t) + + if err = cert.SetPubKey(key); err != nil { + t.Fatal(err) + } + + if err = ctx.UseCertificate(cert); err != nil { + t.Fatal(err) + } + + return ctx +} + +func getPublicPEMKeyAfterFail(t *testing.T) PublicKey { + key, err := LoadPublicKeyFromPEM([]byte("badbadkey")) + if err == nil { + t.Fatal("Expected error, got none") + } + + key, err = LoadPublicKeyFromPEM(keyPublicBytes) + if err != nil { + t.Fatal(err) + } + + return key +} + +func getPublicDERKeyAfterFail(t *testing.T) PublicKey { + keyPublicDERBytes, err := base64.StdEncoding.DecodeString(keyPublicDERBase64) + if err != nil { + t.Fatal(err) + } + + key, err := LoadPublicKeyFromDER([]byte("badbadkey")) + if err == nil { + t.Fatal("Expected error, got none") + } + + key, err = LoadPublicKeyFromDER(keyPublicDERBytes) + if err != nil { + t.Fatal(err) + } + + return key +} + +var lotsOfConnsWithFailCases = map[string]func(t *testing.T) *Ctx{ + "PrivatePEM": func(t *testing.T) *Ctx { + return getCtxWithPrivateKeyAfterFail(t, getPrivatePEMKeyAfterFail) + }, + "PrivateEncryptedPEM": func(t *testing.T) *Ctx { + return getCtxWithPrivateKeyAfterFail(t, getPrivateEncryptedPEMKeyAfterFail) + }, + "PrivateDER": func(t *testing.T) *Ctx { + return getCtxWithPrivateKeyAfterFail(t, getPrivateDERKeyAfterFail) + }, + "PublicPEM": func(t *testing.T) *Ctx { + return getCtxWithPublicKeyAfterFail(t, getPublicPEMKeyAfterFail) + }, + "PublicDER": func(t *testing.T) *Ctx { + return getCtxWithPublicKeyAfterFail(t, getPublicDERKeyAfterFail) + }, +} + +func TestOpenSSLLotsOfConnsWithFail(t *testing.T) { + for name, getClientCtx := range lotsOfConnsWithFailCases { + t.Run(name, func(t *testing.T) { + LotsOfConns(t, 1024*64, 10, 100, 0*time.Second, + func(l net.Listener) net.Listener { + return NewListener(l, getCtx(t)) + }, func(c net.Conn) (net.Conn, error) { + return Client(c, getClientCtx(t)) + }) + }) + } +}