Skip to content

Commit

Permalink
treewide: allow multiple validators
Browse files Browse the repository at this point in the history
This changes the attestation (as of now, only SEV-SNP) to be passed
multiple validators. The aTLS code already handles multiple validators,
but the code previously passed only one. This way, attestation will now
work by being handed a list of validators, and returning success as soon
as one can successfully validate a report. Furthermore, the
`atls.NoValidator` is now obsolete, and semantically represented by
passing an empty list of validators.
  • Loading branch information
msanft committed Aug 13, 2024
1 parent fdd49a3 commit e72e418
Show file tree
Hide file tree
Showing 16 changed files with 247 additions and 216 deletions.
31 changes: 31 additions & 0 deletions cli/cmd/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@ package cmd
import (
"context"
_ "embed"
"fmt"
"log/slog"
"os"
"path/filepath"
"time"

"github.com/edgelesssys/contrast/cli/telemetry"
"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/fsstore"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/spf13/cobra"
)

Expand Down Expand Up @@ -72,3 +79,27 @@ func withTelemetry(runFunc func(*cobra.Command, []string) error) func(*cobra.Com
return cmdErr
}
}

// validatorsFromManifest returns a list of validators corresponding to the reference values in the given manifest.
func validatorsFromManifest(m *manifest.Manifest, log *slog.Logger, hostData []byte) ([]atls.Validator, error) {
kdsDir, err := cachedir("kds")
if err != nil {
return nil, fmt.Errorf("getting cache dir: %w", err)
}
log.Debug("Using KDS cache dir", "dir", kdsDir)
kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache"))
kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter"))

optsGens, err := m.SNPValidateOpts()
if err != nil {
return nil, fmt.Errorf("getting SNP validate options: %w", err)
}

var validators []atls.Validator
for _, gen := range optsGens {
validators = append(validators, snp.NewValidator(gen.WithStaticHostData(hostData), kdsGetter,
logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}),
))
}
return validators, nil
}
19 changes: 3 additions & 16 deletions cli/cmd/recover.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ import (
"path/filepath"

"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/fsstore"
"github.com/edgelesssys/contrast/internal/grpc/dialer"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/userapi"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -76,22 +73,12 @@ func runRecover(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("decrypting seed: %w", err)
}

kdsDir, err := cachedir("kds")
validators, err := validatorsFromManifest(&m, log, flags.policy)
if err != nil {
return fmt.Errorf("getting cache dir: %w", err)
return fmt.Errorf("getting validators: %w", err)
}
log.Debug("Using KDS cache dir", "dir", kdsDir)

validateOptsGen, err := newCoordinatorValidateOptsGen(m, flags.policy)
if err != nil {
return fmt.Errorf("generating validate opts: %w", err)
}
kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache"))
kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter"))
validator := snp.NewValidator(validateOptsGen, kdsGetter,
logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}),
)
dialer := dialer.NewWithKey(atls.NoIssuer, validator, &net.Dialer{}, workloadOwnerKey)
dialer := dialer.NewWithKey(atls.NoIssuer, validators, &net.Dialer{}, workloadOwnerKey)

log.Debug("Dialing coordinator", "endpoint", flags.coordinator)
conn, err := dialer.Dial(cmd.Context(), flags.coordinator)
Expand Down
20 changes: 3 additions & 17 deletions cli/cmd/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@ import (
"time"

"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/fsstore"
"github.com/edgelesssys/contrast/internal/grpc/dialer"
grpcRetry "github.com/edgelesssys/contrast/internal/grpc/retry"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/retry"
"github.com/edgelesssys/contrast/internal/spinner"
Expand Down Expand Up @@ -101,22 +98,11 @@ func runSet(cmd *cobra.Command, args []string) error {
return fmt.Errorf("checking policies match manifest: %w", err)
}

kdsDir, err := cachedir("kds")
validators, err := validatorsFromManifest(&m, log, flags.policy)
if err != nil {
return fmt.Errorf("getting cache dir: %w", err)
return fmt.Errorf("getting validators: %w", err)
}
log.Debug("Using KDS cache dir", "dir", kdsDir)

validateOptsGen, err := newCoordinatorValidateOptsGen(m, flags.policy)
if err != nil {
return fmt.Errorf("generating validate opts: %w", err)
}
kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache"))
kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter"))
validator := snp.NewValidator(validateOptsGen, kdsGetter,
logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}),
)
dialer := dialer.NewWithKey(atls.NoIssuer, validator, &net.Dialer{}, workloadOwnerKey)
dialer := dialer.NewWithKey(atls.NoIssuer, validators, &net.Dialer{}, workloadOwnerKey)

conn, err := dialer.Dial(cmd.Context(), flags.coordinator)
if err != nil {
Expand Down
31 changes: 3 additions & 28 deletions cli/cmd/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@ import (
"path/filepath"

"github.com/edgelesssys/contrast/internal/atls"
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/fsstore"
"github.com/edgelesssys/contrast/internal/grpc/dialer"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/userapi"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -71,22 +68,11 @@ func runVerify(cmd *cobra.Command, _ []string) error {
return fmt.Errorf("validating manifest: %w", err)
}

kdsDir, err := cachedir("kds")
validators, err := validatorsFromManifest(&m, log, flags.policy)
if err != nil {
return fmt.Errorf("getting cache dir: %w", err)
return fmt.Errorf("getting validators: %w", err)
}
log.Debug("Using KDS cache dir", "dir", kdsDir)

validateOptsGen, err := newCoordinatorValidateOptsGen(m, flags.policy)
if err != nil {
return fmt.Errorf("generating validate opts: %w", err)
}
kdsCache := fsstore.New(kdsDir, log.WithGroup("kds-cache"))
kdsGetter := snp.NewCachedHTTPSGetter(kdsCache, snp.NeverGCTicker, log.WithGroup("kds-getter"))
validator := snp.NewValidator(validateOptsGen, kdsGetter,
logger.NewWithAttrs(logger.NewNamed(log, "validator"), map[string]string{"tee-type": "snp"}),
)
dialer := dialer.New(atls.NoIssuer, validator, &net.Dialer{})
dialer := dialer.New(atls.NoIssuer, validators, &net.Dialer{})

log.Debug("Dialing coordinator", "endpoint", flags.coordinator)
conn, err := dialer.Dial(cmd.Context(), flags.coordinator)
Expand Down Expand Up @@ -174,17 +160,6 @@ func parseVerifyFlags(cmd *cobra.Command) (*verifyFlags, error) {
}, nil
}

func newCoordinatorValidateOptsGen(mnfst manifest.Manifest, hostData []byte) (*snp.StaticValidateOptsGenerator, error) {
validateOpts, err := mnfst.AKSValidateOpts()
if err != nil {
return nil, err
}
validateOpts.HostData = hostData
return &snp.StaticValidateOptsGenerator{
Opts: validateOpts,
}, nil
}

func writeFilelist(dir string, filelist map[string][]byte) error {
if dir != "" {
if err := os.MkdirAll(dir, 0o755); err != nil {
Expand Down
18 changes: 0 additions & 18 deletions coordinator/internal/authority/authority.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ import (
"github.com/edgelesssys/contrast/internal/ca"
"github.com/edgelesssys/contrast/internal/manifest"
"github.com/edgelesssys/contrast/internal/userapi"
"github.com/google/go-sev-guest/proto/sevsnp"
"github.com/google/go-sev-guest/validate"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
Expand Down Expand Up @@ -171,19 +169,3 @@ type State struct {
latest *history.LatestTransition
generation int
}

// SNPValidateOpts returns SNP validation options from reference values.
//
// It also ensures that the policy hash in the report's HOSTDATA is allowed by the current
// manifest.
// TODO(msanft): make the manifest authoritative and allow other types of reference values.
func (s *State) SNPValidateOpts(report *sevsnp.Report) (*validate.Options, error) {
mnfst := s.Manifest

hostData := manifest.NewHexString(report.HostData)
if _, ok := mnfst.Policies[hostData]; !ok {
return nil, fmt.Errorf("hostdata %s not found in manifest", hostData)
}

return mnfst.AKSValidateOpts()
}
12 changes: 8 additions & 4 deletions coordinator/internal/authority/authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,20 @@ func TestSNPValidateOpts(t *testing.T) {
_, err := a.SetManifest(context.Background(), req)
require.NoError(err)

opts, err := a.state.Load().SNPValidateOpts(report)
gens, err := a.state.Load().Manifest.SNPValidateOpts()
require.NoError(err)
require.NotNil(opts)
require.NotNil(gens)

// Change to unknown policy hash in HostData.
report.HostData[0]++

opts, err = a.state.Load().SNPValidateOpts(report)
gens, err = a.state.Load().Manifest.SNPValidateOpts()
require.NoError(err)
require.NotNil(gens)

gen := gens[0].WithReportHostData()
_, err = gen.SNPValidateOpts(report)
require.Error(err)
require.Nil(opts)
}

// TODO(burgerdev): test ValidateCallback and GetCertBundle
Expand Down
16 changes: 12 additions & 4 deletions coordinator/internal/authority/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,19 @@ func (c *Credentials) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.A

authInfo := AuthInfo{State: state}

validator := snp.NewValidatorWithCallbacks(state, c.kdsGetter,
logger.NewWithAttrs(logger.NewNamed(c.logger, "validator"), map[string]string{"tee-type": "snp"}),
c.attestationFailuresCounter, &authInfo)
optsGens, err := state.Manifest.SNPValidateOpts()
if err != nil {
return nil, nil, fmt.Errorf("generating SNP validation options: %w", err)
}

serverCfg, err := atls.CreateAttestationServerTLSConfig(c.issuer, []atls.Validator{validator})
var validators []atls.Validator
for _, gen := range optsGens {
validator := snp.NewValidatorWithCallbacks(gen.WithReportHostData(), c.kdsGetter,
logger.NewWithAttrs(logger.NewNamed(c.logger, "validator"), map[string]string{"tee-type": "snp"}),
c.attestationFailuresCounter, &authInfo)
validators = append(validators, validator)
}
serverCfg, err := atls.CreateAttestationServerTLSConfig(c.issuer, validators)
if err != nil {
return nil, nil, err
}
Expand Down
4 changes: 3 additions & 1 deletion initializer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ func run() (retErr error) {
}

requestCert := func() (*meshapi.NewMeshCertResponse, error) {
dial := dialer.NewWithKey(issuer, atls.NoValidator, &net.Dialer{}, privKey)
// Supply an empty list of validators, as the coordinator does not need to be
// validated by the initializer.
dial := dialer.NewWithKey(issuer, atls.NoValidators, &net.Dialer{}, privKey)
conn, err := dial.Dial(ctx, net.JoinHostPort(coordinatorHostname, meshapi.Port))
if err != nil {
return nil, fmt.Errorf("dialing: %w", err)
Expand Down
4 changes: 2 additions & 2 deletions internal/atls/atls.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ import (
const attestationTimeout = 30 * time.Second

var (
// NoValidator skips validation of the server's attestation document.
NoValidator Validator
// NoValidators skips validation of the server's attestation document.
NoValidators = []Validator{}
// NoIssuer skips embedding the client's attestation document.
NoIssuer Issuer

Expand Down
8 changes: 4 additions & 4 deletions internal/attestation/snp/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ type validateOptsGenerator interface {
SNPValidateOpts(report *sevsnp.Report) (*validate.Options, error)
}

// StaticValidateOptsGenerator returns validate.Options generator that returns
// StaticValidateOptsGenerator is a [validate.Options] generator that returns
// static validation options.
type StaticValidateOptsGenerator struct {
Opts *validate.Options
}

// SNPValidateOpts return the SNP validation options.
// SNPValidateOpts returns the SNP validation options.
func (v *StaticValidateOptsGenerator) SNPValidateOpts(_ *sevsnp.Report) (*validate.Options, error) {
return v.Opts, nil
}
Expand All @@ -65,13 +65,13 @@ func NewValidator(optsGen validateOptsGenerator, kdsGetter trust.HTTPSGetter, lo
}

// NewValidatorWithCallbacks returns a new Validator with callbacks.
func NewValidatorWithCallbacks(optsGen validateOptsGenerator, kdsGetter trust.HTTPSGetter, log *slog.Logger, attestataionFailures prometheus.Counter, callbacks ...validateCallbacker) *Validator {
func NewValidatorWithCallbacks(optsGen validateOptsGenerator, kdsGetter trust.HTTPSGetter, log *slog.Logger, attestationFailures prometheus.Counter, callbacks ...validateCallbacker) *Validator {
return &Validator{
validateOptsGen: optsGen,
callbackers: callbacks,
kdsGetter: kdsGetter,
logger: log,
metrics: metrics{attestationFailures: attestataionFailures},
metrics: metrics{attestationFailures: attestationFailures},
}
}

Expand Down
32 changes: 14 additions & 18 deletions internal/grpc/dialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,34 @@ import (

// Dialer can open grpc client connections with different levels of ATLS encryption / verification.
type Dialer struct {
issuer atls.Issuer
validator atls.Validator
netDialer NetDialer
privKey *ecdsa.PrivateKey
issuer atls.Issuer
validators []atls.Validator
netDialer NetDialer
privKey *ecdsa.PrivateKey
}

// New creates a new Dialer.
func New(issuer atls.Issuer, validator atls.Validator, netDialer NetDialer) *Dialer {
func New(issuer atls.Issuer, validators []atls.Validator, netDialer NetDialer) *Dialer {
return &Dialer{
issuer: issuer,
validator: validator,
netDialer: netDialer,
issuer: issuer,
validators: validators,
netDialer: netDialer,
}
}

// NewWithKey creates a new Dialer with the given private key.
func NewWithKey(issuer atls.Issuer, validator atls.Validator, netDialer NetDialer, privKey *ecdsa.PrivateKey) *Dialer {
func NewWithKey(issuer atls.Issuer, validators []atls.Validator, netDialer NetDialer, privKey *ecdsa.PrivateKey) *Dialer {
return &Dialer{
issuer: issuer,
validator: validator,
netDialer: netDialer,
privKey: privKey,
issuer: issuer,
validators: validators,
netDialer: netDialer,
privKey: privKey,
}
}

// Dial creates a new grpc client connection to the given target using the atls validator.
func (d *Dialer) Dial(_ context.Context, target string) (*grpc.ClientConn, error) {
var validators []atls.Validator
if d.validator != nil {
validators = append(validators, d.validator)
}
credentials := atlscredentials.NewWithKey(d.issuer, validators, d.privKey)
credentials := atlscredentials.NewWithKey(d.issuer, d.validators, d.privKey)

return grpc.NewClient(target,
d.grpcWithDialer(),
Expand Down
18 changes: 2 additions & 16 deletions internal/manifest/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,13 @@ import (
// Default returns a default manifest with reference values for the given platform.
func Default(platform platforms.Platform) (*Manifest, error) {
embeddedRefValues := GetEmbeddedReferenceValues()

refValues, err := embeddedRefValues.ForPlatform(platform)
if err != nil {
return nil, fmt.Errorf("get reference values for platform %s: %w", platform, err)
}

mnfst := Manifest{}
switch platform {
case platforms.AKSCloudHypervisorSNP:
return &Manifest{
ReferenceValues: ReferenceValues{
AKS: refValues.AKS,
},
}, nil
case platforms.RKE2QEMUTDX, platforms.K3sQEMUTDX:
return &Manifest{
ReferenceValues: ReferenceValues{
BareMetalTDX: refValues.BareMetalTDX,
},
}, nil
}
return &mnfst, nil
return &Manifest{ReferenceValues: *refValues}, nil
}

// GetEmbeddedReferenceValues returns the reference values embedded in the binary.
Expand Down
Loading

0 comments on commit e72e418

Please sign in to comment.