diff --git a/hmac.go b/hmac.go index 10e4579..3a00e33 100644 --- a/hmac.go +++ b/hmac.go @@ -72,6 +72,9 @@ type hmacImplementation struct { // Cleanup function cleanup func() + // Count of updates + updates uint64 + // Result, or nil if we don't have the answer yet result []byte } @@ -167,6 +170,7 @@ func (hi *hmacImplementation) initialize() (err error) { hi.cleanup() return } + hi.updates = 0 hi.result = nil return } @@ -181,6 +185,7 @@ func (hi *hmacImplementation) Write(p []byte) (n int, err error) { if err = hi.session.Ctx.SignUpdate(hi.session.Handle, p); err != nil { return } + hi.updates++ n = len(p) return } @@ -188,6 +193,13 @@ func (hi *hmacImplementation) Write(p []byte) (n int, err error) { func (hi *hmacImplementation) Sum(b []byte) []byte { if hi.result == nil { var err error + if hi.updates == 0 { + // http://docs.oasis-open.org/pkcs11/pkcs11-base/v2.40/os/pkcs11-base-v2.40-os.html#_Toc322855304 + // We must ensure that C_SignUpdate is called _at least once_. + if err = hi.session.Ctx.SignUpdate(hi.session.Handle, []byte{}); err != nil { + panic(err) + } + } hi.result, err = hi.session.Ctx.SignFinal(hi.session.Handle) hi.cleanup() if err != nil { diff --git a/hmac_test.go b/hmac_test.go index 14443bf..9baa766 100644 --- a/hmac_test.go +++ b/hmac_test.go @@ -98,6 +98,15 @@ func testHmac(t *testing.T, keytype int, mech int, length int, xlength int, full } }) if full { // Independent of hash, only do these once + t.Run("Empty", func(t *testing.T) { + // Must be able to MAC empty inputs without panicing + var h1 hash.Hash + if h1, err = key.NewHMAC(mech, length); err != nil { + t.Errorf("key.NewHMAC: %v", err) + return + } + h1.Sum([]byte{}) + }) t.Run("MultiSum", func(t *testing.T) { input := []byte("a different short string") var h1 hash.Hash @@ -158,5 +167,19 @@ func testHmac(t *testing.T, keytype int, mech int, length int, xlength int, full return } }) + t.Run("ResetFast", func(t *testing.T) { + // Reset() immediately after creation should be safe + var h1 hash.Hash + if h1, err = key.NewHMAC(mech, length); err != nil { + t.Errorf("key.NewHMAC: %v", err) + return + } + h1.Reset() + if n, err := h1.Write([]byte{2}); err != nil || n != 1 { + t.Errorf("h1.Write: %v/%d", err, n) + return + } + h1.Sum([]byte{}) + }) } }