Skip to content

Commit

Permalink
Merge pull request #580 from smallstep/josh/tpm-capalgs
Browse files Browse the repository at this point in the history
Add method to obtain TPM capabilities
  • Loading branch information
joshdrake authored Sep 10, 2024
2 parents c7de661 + 6463150 commit 1412681
Show file tree
Hide file tree
Showing 9 changed files with 485 additions and 22 deletions.
87 changes: 70 additions & 17 deletions kms/tpmkms/tpmkms.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"go.step.sm/crypto/kms/apiv1"
"go.step.sm/crypto/kms/uri"
"go.step.sm/crypto/tpm"
"go.step.sm/crypto/tpm/algorithm"
"go.step.sm/crypto/tpm/attestation"
"go.step.sm/crypto/tpm/storage"
"go.step.sm/crypto/tpm/tss2"
Expand All @@ -39,6 +40,32 @@ func init() {
})
}

// PreferredSignatureAlgorithms indicates the preferred selection of signature
// algorithms when an explicit value is omitted in CreateKeyRequest
var preferredSignatureAlgorithms []apiv1.SignatureAlgorithm

// SetPreferredSignatureAlgorithms sets the preferred signature algorithms
// to select from when explicit values are omitted in CreateKeyRequest
//
// # Experimental
//
// Notice: This method is EXPERIMENTAL and may be changed or removed in a later
// release.
func SetPreferredSignatureAlgorithms(algs []apiv1.SignatureAlgorithm) {
preferredSignatureAlgorithms = algs
}

// PreferredSignatureAlgorithms returns the preferred signature algorithms
// to select from when explicit values are omitted in CreateKeyRequest
//
// # Experimental
//
// Notice: This method is EXPERIMENTAL and may be changed or removed in a later
// release.
func PreferredSignatureAlgorithms() []apiv1.SignatureAlgorithm {
return preferredSignatureAlgorithms
}

// Scheme is the scheme used in TPM KMS URIs, the string "tpmkms".
const Scheme = string(apiv1.TPMKMS)

Expand Down Expand Up @@ -73,21 +100,22 @@ type TPMKMS struct {
}

type algorithmAttributes struct {
Type string
Curve int
Type string
Curve int
Requires []algorithm.Algorithm
}

var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]algorithmAttributes{
apiv1.UnspecifiedSignAlgorithm: {"RSA", -1},
apiv1.SHA256WithRSA: {"RSA", -1},
apiv1.SHA384WithRSA: {"RSA", -1},
apiv1.SHA512WithRSA: {"RSA", -1},
apiv1.SHA256WithRSAPSS: {"RSA", -1},
apiv1.SHA384WithRSAPSS: {"RSA", -1},
apiv1.SHA512WithRSAPSS: {"RSA", -1},
apiv1.ECDSAWithSHA256: {"ECDSA", 256},
apiv1.ECDSAWithSHA384: {"ECDSA", 384},
apiv1.ECDSAWithSHA512: {"ECDSA", 521},
apiv1.UnspecifiedSignAlgorithm: {"RSA", -1, []algorithm.Algorithm{algorithm.AlgorithmRSA}},
apiv1.SHA256WithRSA: {"RSA", -1, []algorithm.Algorithm{algorithm.AlgorithmRSA, algorithm.AlgorithmSHA256}},
apiv1.SHA384WithRSA: {"RSA", -1, []algorithm.Algorithm{algorithm.AlgorithmRSA, algorithm.AlgorithmSHA384}},
apiv1.SHA512WithRSA: {"RSA", -1, []algorithm.Algorithm{algorithm.AlgorithmRSA, algorithm.AlgorithmSHA512}},
apiv1.SHA256WithRSAPSS: {"RSA", -1, []algorithm.Algorithm{algorithm.AlgorithmRSAPSS, algorithm.AlgorithmSHA256}},
apiv1.SHA384WithRSAPSS: {"RSA", -1, []algorithm.Algorithm{algorithm.AlgorithmRSAPSS, algorithm.AlgorithmSHA384}},
apiv1.SHA512WithRSAPSS: {"RSA", -1, []algorithm.Algorithm{algorithm.AlgorithmRSAPSS, algorithm.AlgorithmSHA512}},
apiv1.ECDSAWithSHA256: {"ECDSA", 256, []algorithm.Algorithm{algorithm.AlgorithmECDSA, algorithm.AlgorithmSHA256}},
apiv1.ECDSAWithSHA384: {"ECDSA", 384, []algorithm.Algorithm{algorithm.AlgorithmECDSA, algorithm.AlgorithmSHA384}},
apiv1.ECDSAWithSHA512: {"ECDSA", 521, []algorithm.Algorithm{algorithm.AlgorithmECDSA, algorithm.AlgorithmSHA512}},
}

const (
Expand Down Expand Up @@ -326,9 +354,36 @@ func (k *TPMKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons
return nil, fmt.Errorf("failed parsing %q: %w", req.Name, err)
}

v, ok := signatureAlgorithmMapping[req.SignatureAlgorithm]
if !ok {
return nil, fmt.Errorf("TPMKMS does not support signature algorithm %q", req.SignatureAlgorithm)
ctx := context.Background()
caps, err := k.tpm.GetCapabilities(ctx)
if err != nil {
return nil, fmt.Errorf("could not get TPM capabilities: %w", err)
}

var (
v algorithmAttributes
ok bool
)
if !properties.ak && req.SignatureAlgorithm == apiv1.UnspecifiedSignAlgorithm && len(preferredSignatureAlgorithms) > 0 {
for _, alg := range preferredSignatureAlgorithms {
v, ok = signatureAlgorithmMapping[alg]
if !ok {
return nil, fmt.Errorf("TPMKMS does not support signature algorithm %q", alg)
}

if caps.SupportsAlgorithms(v.Requires) {
break
}
}
} else {
v, ok = signatureAlgorithmMapping[req.SignatureAlgorithm]
if !ok {
return nil, fmt.Errorf("TPMKMS does not support signature algorithm %q", req.SignatureAlgorithm)
}

if !caps.SupportsAlgorithms(v.Requires) {
return nil, fmt.Errorf("signature algorithm %q not supported by the TPM device", req.SignatureAlgorithm)
}
}

if properties.ak && v.Type == "ECDSA" {
Expand All @@ -348,8 +403,6 @@ func (k *TPMKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons
size = v.Curve
}

ctx := context.Background()

var privateKey any
if properties.ak {
ak, err := k.tpm.CreateAK(ctx, properties.name) // NOTE: size is never passed for AKs; it's hardcoded to 2048 in lower levels.
Expand Down
90 changes: 85 additions & 5 deletions kms/tpmkms/tpmkms_simulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,19 @@ import (
"go.step.sm/crypto/tpm/tss2"
)

type newSimulatedTPMOption func(t *testing.T, tpm *tpmp.TPM)
type newSimulatedTPMOption any

func withAK(name string) newSimulatedTPMOption {
type newSimulatedTPMPreparerOption func(t *testing.T, tpm *tpmp.TPM)

func withAK(name string) newSimulatedTPMPreparerOption {
return func(t *testing.T, tpm *tpmp.TPM) {
t.Helper()
_, err := tpm.CreateAK(context.Background(), name)
require.NoError(t, err)
}
}

func withKey(name string) newSimulatedTPMOption {
func withKey(name string) newSimulatedTPMPreparerOption {
return func(t *testing.T, tpm *tpmp.TPM) {
t.Helper()
config := tpmp.CreateKeyConfig{
Expand All @@ -59,14 +61,38 @@ func withKey(name string) newSimulatedTPMOption {
}
}

func withCapabilities(caps *tpmp.Capabilities) tpmp.NewTPMOption {
return tpmp.WithCapabilities(caps)
}

func newSimulatedTPM(t *testing.T, opts ...newSimulatedTPMOption) *tpmp.TPM {
t.Helper()

tmpDir := t.TempDir()
tpm, err := tpmp.New(withSimulator(t), tpmp.WithStore(storage.NewDirstore(tmpDir)))
tpmOpts := []tpmp.NewTPMOption{
withSimulator(t),
tpmp.WithStore(storage.NewDirstore(tmpDir)),
}

var preparers []newSimulatedTPMPreparerOption
for _, opt := range opts {
switch o := opt.(type) {
case tpmp.NewTPMOption:
tpmOpts = append(tpmOpts, o)
case newSimulatedTPMPreparerOption:
preparers = append(preparers, o)
default:
require.Fail(t, "invalid TPM option type provided", `TPM option type "%T"`, o)
}
}

tpm, err := tpmp.New(tpmOpts...)
require.NoError(t, err)
for _, applyTo := range opts {

for _, applyTo := range preparers {
applyTo(t, tpm)
}

return tpm
}

Expand All @@ -87,6 +113,60 @@ func withSimulator(t *testing.T) tpmp.NewTPMOption {
return tpmp.WithSimulator(sim)
}

func TestTPMKMS_CreateKey_Capabilities(t *testing.T) {
tpmWithNoCaps := newSimulatedTPM(t, withCapabilities(&tpmp.Capabilities{}))
type fields struct {
tpm *tpmp.TPM
}
type args struct {
req *apiv1.CreateKeyRequest
}
tests := []struct {
name string
fields fields
args args
assertFunc assert.ValueAssertionFunc
expErr error
}{
{
name: "fail/unsupported-algorithm",
fields: fields{
tpm: tpmWithNoCaps,
},
args: args{
req: &apiv1.CreateKeyRequest{
Name: "tpmkms:name=key1",
SignatureAlgorithm: apiv1.SHA256WithRSA,
Bits: 2048,
},
},
assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool {
if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) {
r, _ := i1.(*apiv1.CreateKeyResponse)
return assert.Nil(t, r)
}
return false
},
expErr: errors.New(`signature algorithm "SHA256-RSA" not supported by the TPM device`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
k := &TPMKMS{
tpm: tt.fields.tpm,
}
got, err := k.CreateKey(tt.args.req)
if tt.expErr != nil {
assert.EqualError(t, err, tt.expErr.Error())
return
}

assert.NoError(t, err)
assert.True(t, tt.assertFunc(t, got))
})
}
}

func TestTPMKMS_CreateKey(t *testing.T) {
tpmWithAK := newSimulatedTPM(t, withAK("ak1"))
type fields struct {
Expand Down
14 changes: 14 additions & 0 deletions kms/tpmkms/tpmkms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,17 @@ func Test_notFoundError(t *testing.T) {
})
}
}

func Test_SetPreferredSignatureAlgorithms(t *testing.T) {
old := preferredSignatureAlgorithms
want := []apiv1.SignatureAlgorithm{
apiv1.ECDSAWithSHA256,
}
SetPreferredSignatureAlgorithms(want)
assert.Equal(t, preferredSignatureAlgorithms, want)
SetPreferredSignatureAlgorithms(old)
}

func Test_PreferredSignatureAlgorithms(t *testing.T) {
assert.Equal(t, PreferredSignatureAlgorithms(), preferredSignatureAlgorithms)
}
113 changes: 113 additions & 0 deletions tpm/algorithm/algorithm.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package algorithm

import (
"encoding/json"
)

// Supported Algorithms.
const (
AlgorithmUnknown Algorithm = 0x0000
AlgorithmRSA Algorithm = 0x0001
Algorithm3DES Algorithm = 0x0003
AlgorithmSHA1 Algorithm = 0x0004
AlgorithmHMAC Algorithm = 0x0005
AlgorithmAES Algorithm = 0x0006
AlgorithmMGF1 Algorithm = 0x0007
AlgorithmKeyedHash Algorithm = 0x0008
AlgorithmXOR Algorithm = 0x000A
AlgorithmSHA256 Algorithm = 0x000B
AlgorithmSHA384 Algorithm = 0x000C
AlgorithmSHA512 Algorithm = 0x000D
AlgorithmNull Algorithm = 0x0010
AlgorithmSM3256 Algorithm = 0x0012
AlgorithmSM4 Algorithm = 0x0013
AlgorithmRSASSA Algorithm = 0x0014
AlgorithmRSAES Algorithm = 0x0015
AlgorithmRSAPSS Algorithm = 0x0016
AlgorithmOAEP Algorithm = 0x0017
AlgorithmECDSA Algorithm = 0x0018
AlgorithmECDH Algorithm = 0x0019
AlgorithmECDAA Algorithm = 0x001A
AlgorithmECSchnorr Algorithm = 0x001C
AlgorithmKDF1_56A Algorithm = 0x0020
AlgorithmKDF2 Algorithm = 0x0021
AlgorithmKDF1_108 Algorithm = 0x0022
AlgorithmECC Algorithm = 0x0023
AlgorithmSymCipher Algorithm = 0x0025
AlgorithmCamellia Algorithm = 0x0026
AlgorithmSHA3_256 Algorithm = 0x0027
AlgorithmSHA3_384 Algorithm = 0x0028
AlgorithmSHA3_512 Algorithm = 0x0029
AlgorithmCMAC Algorithm = 0x003F
AlgorithmCTR Algorithm = 0x0040
AlgorithmOFB Algorithm = 0x0041
AlgorithmCBC Algorithm = 0x0042
AlgorithmCFB Algorithm = 0x0043
AlgorithmECB Algorithm = 0x0044
)

// https://trustedcomputinggroup.org/wp-content/uploads/TCG_TPM2_r1p59_Part2_Structures_pub.pdf
var algs = map[Algorithm]string{
// object types
AlgorithmRSA: "RSA",
AlgorithmECC: "ECC",

// encryption algs
AlgorithmRSAES: "RSAES",

// block ciphers
Algorithm3DES: "3DES",
AlgorithmAES: "AES",
AlgorithmCamellia: "Camellia",
AlgorithmECB: "ECB",
AlgorithmCFB: "CFB",
AlgorithmOFB: "OFB",
AlgorithmCBC: "CBC",
AlgorithmCTR: "CTR",
AlgorithmSymCipher: "Symmetric Cipher",
AlgorithmCMAC: "CMAC",

// other ciphers
AlgorithmXOR: "XOR",
AlgorithmNull: "Null Cipher",

// hash algs
AlgorithmSHA1: "SHA-1",
AlgorithmHMAC: "HMAC",
AlgorithmMGF1: "MGF1",
AlgorithmKeyedHash: "Keyed Hash",
AlgorithmSM3256: "SM3-256",
AlgorithmSHA256: "SHA-256",
AlgorithmSHA384: "SHA-384",
AlgorithmSHA512: "SHA-512",
AlgorithmSHA3_256: "SHA3-256",
AlgorithmSHA3_384: "SHA3-384",
AlgorithmSHA3_512: "SHA3-512",

// signature algs
AlgorithmSM4: "SM4",
AlgorithmRSASSA: "RSA-SSA",
AlgorithmRSAPSS: "RSA-PSS",
AlgorithmECDSA: "ECDSA",
AlgorithmECDAA: "ECDAA",
AlgorithmECSchnorr: "EC-Schnorr",

// encryption schemes
AlgorithmOAEP: "OAEP",
AlgorithmECDH: "ECDH",

// key derivation
AlgorithmKDF1_56A: "KDF1-SP800-56A",
AlgorithmKDF1_108: "KDF1-SP800-108",
AlgorithmKDF2: "KDF2",
}

type Algorithm uint16

func (a Algorithm) String() string {
return algs[Algorithm(int(a))]
}

func (a Algorithm) MarshalJSON() ([]byte, error) {
return json.Marshal(a.String())
}
Loading

0 comments on commit 1412681

Please sign in to comment.