diff --git a/piv/key.go b/piv/key.go index 1cec128..08a9824 100644 --- a/piv/key.go +++ b/piv/key.go @@ -511,7 +511,12 @@ func (yk *YubiKey) AttestationCertificate() (*x509.Certificate, error) { // // If the slot doesn't have a key, the returned error wraps ErrNotFound. func (yk *YubiKey) Attest(slot Slot) (*x509.Certificate, error) { - cert, err := ykAttest(yk.tx, slot) + var cert *x509.Certificate + err := yk.withTx(func(tx *scTx) error { + var err error + cert, err = ykAttest(tx, slot) + return err + }) if err == nil { return cert, nil } @@ -562,10 +567,17 @@ func (yk *YubiKey) Certificate(slot Slot) (*x509.Certificate, error) { byte(slot.Object), }, } - resp, err := yk.tx.Transmit(cmd) + + var resp []byte + err := yk.withTx(func(tx *scTx) error { + var err error + resp, err = tx.Transmit(cmd) + return err + }) if err != nil { return nil, fmt.Errorf("command failed: %w", err) } + // https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-73-4.pdf#page=85 obj, _, err := unmarshalASN1(resp, 1, 0x13) // tag 0x53 if err != nil { @@ -609,10 +621,13 @@ func marshalASN1(tag byte, data []byte) []byte { // certificate isn't required to use the associated key for signing or // decryption. func (yk *YubiKey) SetCertificate(key [24]byte, slot Slot, cert *x509.Certificate) error { - if err := ykAuthenticate(yk.tx, key, yk.rand); err != nil { - return fmt.Errorf("authenticating with management key: %w", err) - } - return ykStoreCertificate(yk.tx, slot, cert) + return yk.withTx(func(tx *scTx) error { + err := ykAuthenticate(tx, key, yk.rand) + if err != nil { + return fmt.Errorf("authenticating with management key: %w", err) + } + return ykStoreCertificate(tx, slot, cert) + }) } func ykStoreCertificate(tx *scTx, slot Slot, cert *x509.Certificate) error { @@ -663,10 +678,16 @@ type Key struct { // GenerateKey generates an asymmetric key on the card, returning the key's // public key. func (yk *YubiKey) GenerateKey(key [24]byte, slot Slot, opts Key) (crypto.PublicKey, error) { - if err := ykAuthenticate(yk.tx, key, yk.rand); err != nil { - return nil, fmt.Errorf("authenticating with management key: %w", err) - } - return ykGenerateKey(yk.tx, slot, opts) + var pk crypto.PublicKey + err := yk.withTx(func(tx *scTx) error { + err := ykAuthenticate(tx, key, yk.rand) + if err != nil { + return fmt.Errorf("authenticating with management key: %w", err) + } + pk, err = ykGenerateKey(tx, slot, opts) + return err + }) + return pk, err } func ykGenerateKey(tx *scTx, slot Slot, o Key) (crypto.PublicKey, error) { @@ -755,7 +776,7 @@ func isAuthErr(err error) bool { return e.sw1 == 0x69 && e.sw2 == 0x82 // "security status not satisfied" } -func (k KeyAuth) authTx(yk *YubiKey, pp PINPolicy) error { +func (k KeyAuth) authTx(yk *YubiKey, pp PINPolicy, tx *scTx) error { // PINPolicyNever shouldn't require a PIN. if pp == PINPolicyNever { return nil @@ -764,7 +785,11 @@ func (k KeyAuth) authTx(yk *YubiKey, pp PINPolicy) error { // PINPolicyAlways should always prompt a PIN even if the key says that // login isn't needed. // https://github.com/go-piv/piv-go/issues/49 - if pp != PINPolicyAlways && !ykLoginNeeded(yk.tx) { + + var flag bool + + flag = !ykLoginNeeded(tx) + if pp != PINPolicyAlways && flag { return nil } @@ -779,14 +804,16 @@ func (k KeyAuth) authTx(yk *YubiKey, pp PINPolicy) error { if pin == "" { return fmt.Errorf("pin required but wasn't provided") } - return ykLogin(yk.tx, pin) + return yk.withTx(func(tx *scTx) error { + return ykLogin(tx, pin) + }) } -func (k KeyAuth) do(yk *YubiKey, pp PINPolicy, f func(tx *scTx) ([]byte, error)) ([]byte, error) { - if err := k.authTx(yk, pp); err != nil { +func (k KeyAuth) do(yk *YubiKey, pp PINPolicy, tx *scTx, f func() ([]byte, error)) ([]byte, error) { + if err := k.authTx(yk, pp, tx); err != nil { return nil, err } - return f(yk.tx) + return f() } func pinPolicy(yk *YubiKey, slot Slot) (PINPolicy, error) { @@ -931,11 +958,12 @@ func (yk *YubiKey) SetPrivateKeyInsecure(key [24]byte, slot Slot, private crypto tags = append(tags, param...) } - if err := ykAuthenticate(yk.tx, key, yk.rand); err != nil { - return fmt.Errorf("authenticating with management key: %w", err) - } - - return ykImportKey(yk.tx, tags, slot, policy) + return yk.withTx(func(tx *scTx) error { + if err := ykAuthenticate(tx, key, yk.rand); err != nil { + return fmt.Errorf("authenticating with management key: %w", err) + } + return ykImportKey(tx, tags, slot, policy) + }) } func ykImportKey(tx *scTx, tags []byte, slot Slot, o Key) error { @@ -995,9 +1023,15 @@ var _ crypto.Signer = (*ECDSAPrivateKey)(nil) // Sign implements crypto.Signer. func (k *ECDSAPrivateKey) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { - return k.auth.do(k.yk, k.pp, func(tx *scTx) ([]byte, error) { - return ykSignECDSA(tx, k.slot, k.pub, digest) + var res []byte + err := k.yk.withTx(func(tx *scTx) error { + var err error + res, err = k.auth.do(k.yk, k.pp, tx, func() ([]byte, error) { + return ykSignECDSA(tx, k.slot, k.pub, digest) + }) + return err }) + return res, err } // SharedKey performs a Diffie-Hellman key agreement with the peer @@ -1014,42 +1048,49 @@ func (k *ECDSAPrivateKey) SharedKey(peer *ecdsa.PublicKey) ([]byte, error) { return nil, errMismatchingAlgorithms } msg := elliptic.Marshal(peer.Curve, peer.X, peer.Y) - return k.auth.do(k.yk, k.pp, func(tx *scTx) ([]byte, error) { - var alg byte - size := k.pub.Params().BitSize - switch size { - case 256: - alg = algECCP256 - case 384: - alg = algECCP384 - default: - return nil, unsupportedCurveError{curve: size} - } - // https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-73-4.pdf#page=118 - // https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-73-4.pdf#page=93 - cmd := apdu{ - instruction: insAuthenticate, - param1: alg, - param2: byte(k.slot.Key), - data: marshalASN1(0x7c, - append([]byte{0x82, 0x00}, - marshalASN1(0x85, msg)...)), - } - resp, err := tx.Transmit(cmd) - if err != nil { - return nil, fmt.Errorf("command failed: %w", err) - } - sig, _, err := unmarshalASN1(resp, 1, 0x1c) // 0x7c - if err != nil { - return nil, fmt.Errorf("unmarshal response: %v", err) - } - rs, _, err := unmarshalASN1(sig, 2, 0x02) // 0x82 - if err != nil { - return nil, fmt.Errorf("unmarshal response signature: %v", err) - } - return rs, nil + var res []byte + err := k.yk.withTx(func(tx *scTx) error { + var err error + res, err = k.auth.do(k.yk, k.pp, tx, func() ([]byte, error) { + var alg byte + size := k.pub.Params().BitSize + switch size { + case 256: + alg = algECCP256 + case 384: + alg = algECCP384 + default: + return nil, unsupportedCurveError{curve: size} + } + + // https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-73-4.pdf#page=118 + // https://nvlpubs.nist.gov/nistpubs/SpecialPublications/NIST.SP.800-73-4.pdf#page=93 + cmd := apdu{ + instruction: insAuthenticate, + param1: alg, + param2: byte(k.slot.Key), + data: marshalASN1(0x7c, + append([]byte{0x82, 0x00}, + marshalASN1(0x85, msg)...)), + } + resp, err := tx.Transmit(cmd) + if err != nil { + return nil, fmt.Errorf("command failed: %w", err) + } + sig, _, err := unmarshalASN1(resp, 1, 0x1c) // 0x7c + if err != nil { + return nil, fmt.Errorf("unmarshal response: %v", err) + } + rs, _, err := unmarshalASN1(sig, 2, 0x02) // 0x82 + if err != nil { + return nil, fmt.Errorf("unmarshal response signature: %v", err) + } + return rs, nil + }) + return err }) + return res, err } type keyEd25519 struct { @@ -1065,9 +1106,15 @@ func (k *keyEd25519) Public() crypto.PublicKey { } func (k *keyEd25519) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { - return k.auth.do(k.yk, k.pp, func(tx *scTx) ([]byte, error) { - return skSignEd25519(tx, k.slot, k.pub, digest) + var res []byte + err := k.yk.withTx(func(tx *scTx) error { + var err error + res, err = k.auth.do(k.yk, k.pp, tx, func() ([]byte, error) { + return skSignEd25519(tx, k.slot, k.pub, digest) + }) + return err }) + return res, err } type keyRSA struct { @@ -1083,15 +1130,27 @@ func (k *keyRSA) Public() crypto.PublicKey { } func (k *keyRSA) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { - return k.auth.do(k.yk, k.pp, func(tx *scTx) ([]byte, error) { - return ykSignRSA(tx, k.slot, k.pub, digest, opts) + var res []byte + err := k.yk.withTx(func(tx *scTx) error { + var err error + res, err = k.auth.do(k.yk, k.pp, tx, func() ([]byte, error) { + return ykSignRSA(tx, k.slot, k.pub, digest, opts) + }) + return err }) + return res, err } func (k *keyRSA) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) ([]byte, error) { - return k.auth.do(k.yk, k.pp, func(tx *scTx) ([]byte, error) { - return ykDecryptRSA(tx, k.slot, k.pub, msg) + var res []byte + err := k.yk.withTx(func(tx *scTx) error { + var err error + res, err = k.auth.do(k.yk, k.pp, tx, func() ([]byte, error) { + return ykDecryptRSA(tx, k.slot, k.pub, msg) + }) + return err }) + return res, err } func ykSignECDSA(tx *scTx, slot Slot, pub *ecdsa.PublicKey, digest []byte) ([]byte, error) { diff --git a/piv/key_test.go b/piv/key_test.go index e4727f0..05d5956 100644 --- a/piv/key_test.go +++ b/piv/key_test.go @@ -250,7 +250,7 @@ func TestSlots(t *testing.T) { } tmpl := &x509.Certificate{ - Subject: pkix.Name{CommonName: "my-client"}, + Subject: pkix.Name{CommonName: "my-Client"}, SerialNumber: big.NewInt(1), NotBefore: time.Now(), NotAfter: time.Now().Add(time.Hour), @@ -483,7 +483,7 @@ func TestYubiKeyStoreCertificate(t *testing.T) { } cliTmpl := &x509.Certificate{ - Subject: pkix.Name{CommonName: "my-client"}, + Subject: pkix.Name{CommonName: "my-Client"}, SerialNumber: big.NewInt(101), NotBefore: time.Now(), NotAfter: time.Now().Add(time.Hour), @@ -492,18 +492,18 @@ func TestYubiKeyStoreCertificate(t *testing.T) { } cliCertDER, err := x509.CreateCertificate(rand.Reader, cliTmpl, caCert, pub, caPriv) if err != nil { - t.Fatalf("creating client cert: %v", err) + t.Fatalf("creating Client cert: %v", err) } cliCert, err := x509.ParseCertificate(cliCertDER) if err != nil { t.Fatalf("parsing cli cert: %v", err) } if err := yk.SetCertificate(DefaultManagementKey, slot, cliCert); err != nil { - t.Fatalf("storing client cert: %v", err) + t.Fatalf("storing Client cert: %v", err) } gotCert, err := yk.Certificate(slot) if err != nil { - t.Fatalf("getting client cert: %v", err) + t.Fatalf("getting Client cert: %v", err) } if !bytes.Equal(gotCert.Raw, cliCert.Raw) { t.Errorf("stored cert didn't match cert retrieved") diff --git a/piv/pcsc_linux.go b/piv/pcsc_linux.go index 6a5c078..a8d5433 100644 --- a/piv/pcsc_linux.go +++ b/piv/pcsc_linux.go @@ -27,4 +27,4 @@ func scCheck(rc C.long) error { func isRCNoReaders(rc C.long) bool { return C.ulong(rc) == 0x8010002E -} +} \ No newline at end of file diff --git a/piv/pcsc_test.go b/piv/pcsc_test.go index 5902564..56814ef 100644 --- a/piv/pcsc_test.go +++ b/piv/pcsc_test.go @@ -21,7 +21,7 @@ import ( ) func runContextTest(t *testing.T, f func(t *testing.T, c *scContext)) { - ctx, err := newSCContext() + ctx, err := newSCContext(true) if err != nil { t.Fatalf("creating context: %v", err) } diff --git a/piv/pcsc_unix.go b/piv/pcsc_unix.go index 66d590e..110d66f 100644 --- a/piv/pcsc_unix.go +++ b/piv/pcsc_unix.go @@ -38,16 +38,17 @@ import ( const rcSuccess = C.SCARD_S_SUCCESS type scContext struct { - ctx C.SCARDCONTEXT + ctx C.SCARDCONTEXT + shared bool } -func newSCContext() (*scContext, error) { +func newSCContext(shared bool) (*scContext, error) { var ctx C.SCARDCONTEXT rc := C.SCardEstablishContext(C.SCARD_SCOPE_SYSTEM, nil, nil, &ctx) if err := scCheck(rc); err != nil { return nil, err } - return &scContext{ctx: ctx}, nil + return &scContext{ctx: ctx, shared: shared}, nil } func (c *scContext) Close() error { @@ -93,12 +94,23 @@ func (c *scContext) Connect(reader string) (*scHandle, error) { handle C.SCARDHANDLE activeProtocol C.DWORD ) - rc := C.SCardConnect(c.ctx, C.CString(reader), - C.SCARD_SHARE_EXCLUSIVE, C.SCARD_PROTOCOL_T1, - &handle, &activeProtocol) - if err := scCheck(rc); err != nil { - return nil, err + + if c.shared { + rc := C.SCardConnect(c.ctx, C.CString(reader), + C.SCARD_SHARE_SHARED, C.SCARD_PROTOCOL_T1, + &handle, &activeProtocol) + if err := scCheck(rc); err != nil { + return nil, err + } + } else { + rc := C.SCardConnect(c.ctx, C.CString(reader), + C.SCARD_SHARE_EXCLUSIVE, C.SCARD_PROTOCOL_T1, + &handle, &activeProtocol) + if err := scCheck(rc); err != nil { + return nil, err + } } + return &scHandle{handle}, nil } diff --git a/piv/pcsc_windows.go b/piv/pcsc_windows.go index 930ef38..e7d4399 100644 --- a/piv/pcsc_windows.go +++ b/piv/pcsc_windows.go @@ -35,6 +35,7 @@ var ( const ( scardScopeSystem = 2 scardShareExclusive = 1 + scardShareShared = 2 scardLeaveCard = 0 scardProtocolT1 = 2 scardPCIT1 = 0 @@ -54,10 +55,11 @@ func isRCNoReaders(rc uintptr) bool { } type scContext struct { - ctx syscall.Handle + ctx syscall.Handle + shared bool } -func newSCContext() (*scContext, error) { +func newSCContext(shared bool) (*scContext, error) { var ctx syscall.Handle r0, _, _ := procSCardEstablishContext.Call( @@ -69,7 +71,7 @@ func newSCContext() (*scContext, error) { if err := scCheck(r0); err != nil { return nil, err } - return &scContext{ctx: ctx}, nil + return &scContext{ctx: ctx, shared: shared}, nil } func (c *scContext) Close() error { @@ -127,14 +129,26 @@ func (c *scContext) Connect(reader string) (*scHandle, error) { handle syscall.Handle activeProtocol uint16 ) - r0, _, _ := procSCardConnectW.Call( - uintptr(c.ctx), - uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(reader))), - scardShareExclusive, - scardProtocolT1, - uintptr(unsafe.Pointer(&handle)), - uintptr(activeProtocol), - ) + var r0 uintptr + if c.shared { + r0, _, _ = procSCardConnectW.Call( + uintptr(c.ctx), + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(reader))), + scardShareShared, + scardProtocolT1, + uintptr(unsafe.Pointer(&handle)), + uintptr(activeProtocol), + ) + } else { + r0, _, _ = procSCardConnectW.Call( + uintptr(c.ctx), + uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(reader))), + scardShareExclusive, + scardProtocolT1, + uintptr(unsafe.Pointer(&handle)), + uintptr(activeProtocol), + ) + } if err := scCheck(r0); err != nil { return nil, err } diff --git a/piv/piv.go b/piv/piv.go index 001d5c3..fbb8916 100644 --- a/piv/piv.go +++ b/piv/piv.go @@ -51,8 +51,8 @@ var ( // // See: https://ludovicrousseau.blogspot.com/2010/05/what-is-in-pcsc-reader-name.html func Cards() ([]string, error) { - var c client - return c.Cards() + var c Client + return c.cards() } const ( @@ -102,7 +102,6 @@ const ( type YubiKey struct { ctx *scContext h *scHandle - tx *scTx rand io.Reader @@ -124,23 +123,31 @@ func (yk *YubiKey) Close() error { return err1 } -// Open connects to a YubiKey smart card. -func Open(card string) (*YubiKey, error) { - var c client - return c.Open(card) -} - -// client is a smart card client and may be exported in the future to allow +// Client is a smart card Client and may be exported in the future to allow // configuration for the top level Open() and Cards() APIs. -type client struct { +type Client struct { // Rand is a cryptographic source of randomness used for card challenges. // // If nil, defaults to crypto.Rand. Rand io.Reader + + // Shared enables a non-exclusive connection to the PIV application, allowing + // other applications that also request a non-exclusive connection to connect + // at the same time. + // + // Certain features, such as cached PINs, don't work when this feature is enabled. + // It's also common for other applications to require exclusive connections. + Shared bool } -func (c *client) Cards() ([]string, error) { - ctx, err := newSCContext() +// Open connects to a YubiKey smart card. +func Open(card string) (*YubiKey, error) { + var c Client + return c.Open(card) +} + +func (c *Client) cards() ([]string, error) { + ctx, err := newSCContext(c.Shared) if err != nil { return nil, fmt.Errorf("connecting to pscs: %w", err) } @@ -148,8 +155,18 @@ func (c *client) Cards() ([]string, error) { return ctx.ListReaders() } -func (c *client) Open(card string) (*YubiKey, error) { - ctx, err := newSCContext() +func (yk *YubiKey) withTx(fn func(tx *scTx) error) error { + tx, err := yk.h.Begin() + if err != nil { + return fmt.Errorf("beginning smart card transaction: %w", err) + } + defer tx.Close() + return fn(tx) +} + +// Open connects to a YubiKey smart card. +func (c *Client) Open(card string) (*YubiKey, error) { + ctx, err := newSCContext(c.Shared) if err != nil { return nil, fmt.Errorf("connecting to smart card daemon: %w", err) } @@ -159,27 +176,30 @@ func (c *client) Open(card string) (*YubiKey, error) { ctx.Close() return nil, fmt.Errorf("connecting to smart card: %w", err) } - tx, err := h.Begin() - if err != nil { - return nil, fmt.Errorf("beginning smart card transaction: %w", err) - } - if err := ykSelectApplication(tx, aidPIV[:]); err != nil { - tx.Close() - return nil, fmt.Errorf("selecting piv applet: %w", err) - } - yk := &YubiKey{ctx: ctx, h: h, tx: tx} - v, err := ykVersion(yk.tx) + yk := &YubiKey{ctx: ctx, h: h} + + err = yk.withTx(func(tx *scTx) error { + if err := ykSelectApplication(tx, aidPIV[:]); err != nil { + return fmt.Errorf("selecting piv applet: %w", err) + } + v, err := ykVersion(tx) + if err != nil { + return fmt.Errorf("getting yubikey version: %w", err) + } + yk.version = v + if c.Rand != nil { + yk.rand = c.Rand + } else { + yk.rand = rand.Reader + } + return nil + }) if err != nil { yk.Close() - return nil, fmt.Errorf("getting yubikey version: %w", err) - } - yk.version = v - if c.Rand != nil { - yk.rand = c.Rand - } else { - yk.rand = rand.Reader + return nil, err } + return yk, nil } @@ -198,7 +218,13 @@ func (yk *YubiKey) Version() Version { // Serial returns the YubiKey's serial number. func (yk *YubiKey) Serial() (uint32, error) { - return ykSerial(yk.tx, yk.version) + var serial uint32 + err := yk.withTx(func(tx *scTx) error { + var err error + serial, err = ykSerial(tx, yk.version) + return err + }) + return serial, err } func encodePIN(pin string) ([]byte, error) { @@ -225,7 +251,9 @@ func encodePIN(pin string) ([]byte, error) { // // Use DefaultPIN if the PIN hasn't been set. func (yk *YubiKey) authPIN(pin string) error { - return ykLogin(yk.tx, pin) + return yk.withTx(func(tx *scTx) error { + return ykLogin(tx, pin) + }) } func ykLogin(tx *scTx, pin string) error { @@ -250,7 +278,13 @@ func ykLoginNeeded(tx *scTx) bool { // Retries returns the number of attempts remaining to enter the correct PIN. func (yk *YubiKey) Retries() (int, error) { - return ykPINRetries(yk.tx) + var retry int + err := yk.withTx(func(tx *scTx) error { + var err error + retry, err = ykPINRetries(tx) + return err + }) + return retry, err } func ykPINRetries(tx *scTx) (int, error) { @@ -270,7 +304,9 @@ func ykPINRetries(tx *scTx) (int, error) { // and resetting the PIN, PUK, and Management Key to their default values. This // does NOT affect data on other applets, such as GPG or U2F. func (yk *YubiKey) Reset() error { - return ykReset(yk.tx, yk.rand) + return yk.withTx(func(tx *scTx) error { + return ykReset(tx, yk.rand) + }) } func ykReset(tx *scTx, r io.Reader) error { @@ -338,8 +374,8 @@ type version struct { // certificates to slots. // // Use DefaultManagementKey if the management key hasn't been set. -func (yk *YubiKey) authManagementKey(key [24]byte) error { - return ykAuthenticate(yk.tx, key, yk.rand) +func (yk *YubiKey) authManagementKey(key [24]byte, tx *scTx) error { + return ykAuthenticate(tx, key, yk.rand) } var ( @@ -457,13 +493,14 @@ func ykAuthenticate(tx *scTx, key [24]byte, rand io.Reader) error { // // func (yk *YubiKey) SetManagementKey(oldKey, newKey [24]byte) error { - if err := ykAuthenticate(yk.tx, oldKey, yk.rand); err != nil { - return fmt.Errorf("authenticating with old key: %w", err) - } - if err := ykSetManagementKey(yk.tx, newKey, false); err != nil { - return err - } - return nil + return yk.withTx(func(tx *scTx) error { + err := ykAuthenticate(tx, oldKey, yk.rand) + if err != nil { + return fmt.Errorf("authenticating with old key: %w", err) + } + return ykSetManagementKey(tx, newKey, false) + + }) } // ykSetManagementKey updates the management key to a new key. This requires @@ -503,7 +540,9 @@ func ykSetManagementKey(tx *scTx, key [24]byte, touch bool) error { // } // func (yk *YubiKey) SetPIN(oldPIN, newPIN string) error { - return ykChangePIN(yk.tx, oldPIN, newPIN) + return yk.withTx(func(tx *scTx) error { + return ykChangePIN(tx, oldPIN, newPIN) + }) } func ykChangePIN(tx *scTx, oldPIN, newPIN string) error { @@ -526,7 +565,9 @@ func ykChangePIN(tx *scTx, oldPIN, newPIN string) error { // Unblock unblocks the PIN, setting it to a new value. func (yk *YubiKey) Unblock(puk, newPIN string) error { - return ykUnblockPIN(yk.tx, puk, newPIN) + return yk.withTx(func(tx *scTx) error { + return ykUnblockPIN(tx, puk, newPIN) + }) } func ykUnblockPIN(tx *scTx, puk, newPIN string) error { @@ -564,7 +605,9 @@ func ykUnblockPIN(tx *scTx, puk, newPIN string) error { // } // func (yk *YubiKey) SetPUK(oldPUK, newPUK string) error { - return ykChangePUK(yk.tx, oldPUK, newPUK) + return yk.withTx(func(tx *scTx) error { + return ykChangePUK(tx, oldPUK, newPUK) + }) } func ykChangePUK(tx *scTx, oldPUK, newPUK string) error { @@ -676,7 +719,12 @@ func unmarshalDERField(b []byte, tag uint64) (obj []byte, err error) { // Metadata returns protected data stored on the card. This can be used to // retrieve PIN protected management keys. func (yk *YubiKey) Metadata(pin string) (*Metadata, error) { - m, err := ykGetProtectedMetadata(yk.tx, pin) + var m *Metadata + err := yk.withTx(func(tx *scTx) error { + var err error + m, err = ykGetProtectedMetadata(tx, pin) + return err + }) if err != nil { if errors.Is(err, ErrNotFound) { return &Metadata{}, nil @@ -690,7 +738,9 @@ func (yk *YubiKey) Metadata(pin string) (*Metadata, error) { // store the management key on the smart card instead of managing the PIN and // management key seperately. func (yk *YubiKey) SetMetadata(key [24]byte, m *Metadata) error { - return ykSetProtectedMetadata(yk.tx, key, m) + return yk.withTx(func(tx *scTx) error { + return ykSetProtectedMetadata(tx, key, m) + }) } // Metadata holds protected metadata. This is primarily used by YubiKey manager diff --git a/piv/piv_test.go b/piv/piv_test.go index 534b259..46c31c0 100644 --- a/piv/piv_test.go +++ b/piv/piv_test.go @@ -158,13 +158,30 @@ func TestYubiKeyLoginNeeded(t *testing.T) { testRequiresVersion(t, yk, 4, 3, 0) - if !ykLoginNeeded(yk.tx) { + var flag bool + + yk.withTx(func(tx *scTx) error { + flag = !ykLoginNeeded(tx) + return nil + }) + + if flag { t.Errorf("expected login needed") } - if err := ykLogin(yk.tx, DefaultPIN); err != nil { + + err := yk.withTx(func(tx *scTx) error { + return ykLogin(tx, DefaultPIN) + }) + if err != nil { t.Fatalf("login: %v", err) } - if ykLoginNeeded(yk.tx) { + + yk.withTx(func(tx *scTx) error { + flag = ykLoginNeeded(tx) + return nil + }) + + if flag { t.Errorf("expected no login needed") } } @@ -208,7 +225,11 @@ func TestYubiKeyAuthenticate(t *testing.T) { yk, close := newTestYubiKey(t) defer close() - if err := yk.authManagementKey(DefaultManagementKey); err != nil { + err := yk.withTx(func(tx *scTx) error { + return yk.authManagementKey(DefaultManagementKey, tx) + }) + + if err != nil { t.Errorf("authenticating: %v", err) } } @@ -225,9 +246,15 @@ func TestYubiKeySetManagementKey(t *testing.T) { if err := yk.SetManagementKey(DefaultManagementKey, mgmtKey); err != nil { t.Fatalf("setting management key: %v", err) } - if err := yk.authManagementKey(mgmtKey); err != nil { + + err := yk.withTx(func(tx *scTx) error { + return yk.authManagementKey(mgmtKey, tx) + }) + + if err != nil { t.Errorf("authenticating with new management key: %v", err) } + if err := yk.SetManagementKey(mgmtKey, DefaultManagementKey); err != nil { t.Fatalf("resetting management key: %v", err) } @@ -239,7 +266,11 @@ func TestYubiKeyUnblockPIN(t *testing.T) { badPIN := "0" for { - err := ykLogin(yk.tx, badPIN) + + err := yk.withTx(func(tx *scTx) error { + return ykLogin(tx, badPIN) + }) + if err == nil { t.Fatalf("login with bad pin succeeded") } @@ -255,7 +286,11 @@ func TestYubiKeyUnblockPIN(t *testing.T) { if err := yk.Unblock(DefaultPUK, DefaultPIN); err != nil { t.Fatalf("unblocking pin: %v", err) } - if err := ykLogin(yk.tx, DefaultPIN); err != nil { + + err := yk.withTx(func(tx *scTx) error { + return ykLogin(tx, DefaultPIN) + }) + if err != nil { t.Errorf("failed to login with pin after unblock: %v", err) } }