Skip to content

Commit

Permalink
atls: explicitly pass platform to issuer
Browse files Browse the repository at this point in the history
The aTLS issuer is currently configured according to the CPU vendor. To
allow for multiple attestation variants on a single CPU, forward the
platform enum to the issuer instantiation.
  • Loading branch information
burgerdev committed Oct 30, 2024
1 parent bae04c4 commit f0e3b59
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 43 deletions.
21 changes: 11 additions & 10 deletions cli/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,7 @@ func runGenerate(cmd *cobra.Command, args []string) error {
}
}

runtimeHandler, err := manifest.RuntimeHandler(flags.referenceValuesPlatform)
if err != nil {
return fmt.Errorf("get runtime handler: %w", err)
}

if err := patchTargets(paths, flags.imageReplacementsFile, runtimeHandler, flags.skipInitializer, log); err != nil {
if err := patchTargets(paths, flags.imageReplacementsFile, flags.referenceValuesPlatform, flags.skipInitializer, log); err != nil {
return fmt.Errorf("patch targets: %w", err)
}
fmt.Fprintln(cmd.OutOrStdout(), "✔️ Patched targets")
Expand Down Expand Up @@ -270,7 +265,7 @@ func generatePolicies(ctx context.Context, flags *generateFlags, yamlPaths []str
return nil
}

func patchTargets(paths []string, imageReplacementsFile, runtimeHandler string, skipInitializer bool, logger *slog.Logger) error {
func patchTargets(paths []string, imageReplacementsFile string, platform platforms.Platform, skipInitializer bool, logger *slog.Logger) error {
var replacements map[string]string
var err error
if imageReplacementsFile != "" {
Expand All @@ -290,6 +285,12 @@ func patchTargets(paths []string, imageReplacementsFile, runtimeHandler string,
return fmt.Errorf("parsing release image definitions %s: %w", ReleaseImageReplacements, err)
}
}

runtimeHandler, err := manifest.RuntimeHandler(platform)
if err != nil {
return fmt.Errorf("getting runtime handler: %w", err)
}

for _, path := range paths {
data, err := os.ReadFile(path)
if err != nil {
Expand All @@ -301,7 +302,7 @@ func patchTargets(paths []string, imageReplacementsFile, runtimeHandler string,
}

if !skipInitializer {
if err := injectInitializer(kubeObjs); err != nil {
if err := injectInitializer(kubeObjs, platform); err != nil {
return fmt.Errorf("injecting Initializer: %w", err)
}
}
Expand All @@ -328,7 +329,7 @@ func patchTargets(paths []string, imageReplacementsFile, runtimeHandler string,
return nil
}

func injectInitializer(resources []any) error {
func injectInitializer(resources []any, platform platforms.Platform) error {
for _, resource := range resources {
switch r := resource.(type) {
case *applyappsv1.StatefulSetApplyConfiguration:
Expand All @@ -337,7 +338,7 @@ func injectInitializer(resources []any) error {
continue
}
}
_, err := kuberesource.AddInitializer(resource, kuberesource.Initializer())
_, err := kuberesource.AddInitializer(resource, kuberesource.Initializer(platform))
if err != nil {
return err
}
Expand Down
14 changes: 9 additions & 5 deletions coordinator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package main
import (
"context"
"errors"
"flag"
"fmt"
"log/slog"
"net"
Expand All @@ -19,6 +20,7 @@ import (
"github.com/edgelesssys/contrast/internal/grpc/atlscredentials"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/meshapi"
"github.com/edgelesssys/contrast/internal/platforms"
"github.com/edgelesssys/contrast/internal/userapi"
grpcprometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus"
"github.com/prometheus/client_golang/prometheus"
Expand All @@ -33,12 +35,14 @@ const (
)

func main() {
if err := run(); err != nil {
platformFlag := platforms.RegisterFlag()
flag.Parse()
if err := run(platformFlag.Platform); err != nil {
os.Exit(1)
}
}

func run() (retErr error) {
func run(platform platforms.Platform) (retErr error) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand Down Expand Up @@ -69,7 +73,7 @@ func run() (retErr error) {
}

meshAuth := authority.New(hist, promRegistry, logger)
grpcServer, err := newGRPCServer(serverMetrics, logger)
grpcServer, err := newGRPCServer(serverMetrics, logger, platform)
if err != nil {
return fmt.Errorf("creating gRPC server: %w", err)
}
Expand Down Expand Up @@ -153,8 +157,8 @@ func newServerMetrics(reg *prometheus.Registry) *grpcprometheus.ServerMetrics {
return serverMetrics
}

func newGRPCServer(serverMetrics *grpcprometheus.ServerMetrics, log *slog.Logger) (*grpc.Server, error) {
issuer, err := atls.PlatformIssuer(log)
func newGRPCServer(serverMetrics *grpcprometheus.ServerMetrics, log *slog.Logger, platform platforms.Platform) (*grpc.Server, error) {
issuer, err := atls.PlatformIssuer(log, platform)
if err != nil {
return nil, fmt.Errorf("creating issuer: %w", err)
}
Expand Down
1 change: 0 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ require (
github.com/google/go-tdx-guest v0.3.1
github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.1
github.com/katexochen/sync v0.0.0-20240617152407-6a8003240db0
github.com/klauspost/cpuid/v2 v2.2.8
github.com/pelletier/go-toml/v2 v2.2.3
github.com/prometheus/client_golang v1.20.4
github.com/prometheus/common v0.60.0
Expand Down
3 changes: 0 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,6 @@ github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM=
github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
Expand Down Expand Up @@ -175,7 +173,6 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210426230700-d19ff857e887/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
Expand Down
10 changes: 7 additions & 3 deletions initializer/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"encoding/hex"
"encoding/pem"
"errors"
"flag"
"fmt"
"net"
"os"
Expand All @@ -21,15 +22,18 @@ import (
"github.com/edgelesssys/contrast/internal/grpc/dialer"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/edgelesssys/contrast/internal/meshapi"
"github.com/edgelesssys/contrast/internal/platforms"
)

func main() {
if err := run(); err != nil {
platformFlag := platforms.RegisterFlag()
flag.Parse()
if err := run(platformFlag.Platform); err != nil {
os.Exit(1)
}
}

func run() (retErr error) {
func run(platform platforms.Platform) (retErr error) {
log, err := logger.Default()
if err != nil {
fmt.Fprintf(os.Stderr, "Error: creating logger: %v\n", err)
Expand All @@ -55,7 +59,7 @@ func run() (retErr error) {
return fmt.Errorf("generating key: %w", err)
}

issuer, err := atls.PlatformIssuer(log)
issuer, err := atls.PlatformIssuer(log, platform)
if err != nil {
return fmt.Errorf("creating issuer: %w", err)
}
Expand Down
15 changes: 7 additions & 8 deletions internal/atls/issuer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,21 @@ import (
"github.com/edgelesssys/contrast/internal/attestation/snp"
"github.com/edgelesssys/contrast/internal/attestation/tdx"
"github.com/edgelesssys/contrast/internal/logger"
"github.com/klauspost/cpuid/v2"
"github.com/edgelesssys/contrast/internal/platforms"
)

// PlatformIssuer creates an attestation issuer for the current platform.
func PlatformIssuer(log *slog.Logger) (Issuer, error) {
cpuid.Detect()
switch {
case cpuid.CPU.Supports(cpuid.SEV_SNP):
// PlatformIssuer creates an attestation issuer for the target platform.
func PlatformIssuer(log *slog.Logger, platform platforms.Platform) (Issuer, error) {
switch platform {
case platforms.AKSCloudHypervisorSNP, platforms.K3sQEMUSNP:
return snp.NewIssuer(
logger.NewWithAttrs(logger.NewNamed(log, "issuer"), map[string]string{"tee-type": "snp"}),
), nil
case cpuid.CPU.Supports(cpuid.TDX_GUEST):
case platforms.K3sQEMUTDX, platforms.RKE2QEMUTDX:
return tdx.NewIssuer(
logger.NewWithAttrs(logger.NewNamed(log, "issuer"), map[string]string{"tee-type": "tdx"}),
), nil
default:
return nil, fmt.Errorf("unsupported platform: %T", cpuid.CPU)
return nil, fmt.Errorf("unsupported platform: %q", platform)
}
}
23 changes: 12 additions & 11 deletions internal/kuberesource/mutators_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package kuberesource
import (
"testing"

"github.com/edgelesssys/contrast/internal/platforms"
"github.com/stretchr/testify/require"
applyappsv1 "k8s.io/client-go/applyconfigurations/apps/v1"
applycorev1 "k8s.io/client-go/applyconfigurations/core/v1"
Expand Down Expand Up @@ -90,7 +91,7 @@ func TestAddInitializer(t *testing.T) {
WithTemplate(applycorev1.PodTemplateSpec().
WithSpec(applycorev1.PodSpec().
WithContainers(applycorev1.Container()).
WithInitContainers(Initializer()).
WithInitContainers(Initializer(platforms.Unknown)).
WithRuntimeClassName("contrast-cc"),
))),
wantError: false,
Expand All @@ -104,7 +105,7 @@ func TestAddInitializer(t *testing.T) {
WithContainers(applycorev1.Container()).
WithRuntimeClassName("contrast-cc").
WithVolumes(Volume().
WithName(*Initializer().VolumeMounts[0].Name).
WithName(*Initializer(platforms.Unknown).VolumeMounts[0].Name).
WithEmptyDir(EmptyDirVolumeSource().Inner()),
),
))),
Expand All @@ -119,7 +120,7 @@ func TestAddInitializer(t *testing.T) {
WithContainers(applycorev1.Container()).
WithRuntimeClassName("contrast-cc").
WithVolumes(Volume().
WithName(*Initializer().VolumeMounts[0].Name).
WithName(*Initializer(platforms.Unknown).VolumeMounts[0].Name).
WithConfigMap(Volume().ConfigMap),
),
))),
Expand All @@ -135,8 +136,8 @@ func TestAddInitializer(t *testing.T) {
applycorev1.Container().
WithVolumeMounts(
VolumeMount().
WithName(*Initializer().VolumeMounts[0].Name).
WithMountPath(*Initializer().VolumeMounts[0].Name),
WithName(*Initializer(platforms.Unknown).VolumeMounts[0].Name).
WithMountPath(*Initializer(platforms.Unknown).VolumeMounts[0].Name),
),
).
WithRuntimeClassName("contrast-cc"),
Expand All @@ -147,21 +148,21 @@ func TestAddInitializer(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
require := require.New(t)

_, err := AddInitializer(tc.d, Initializer())
_, err := AddInitializer(tc.d, Initializer(platforms.Unknown))
if tc.wantError {
require.Error(err)
return
}
require.NoError(err)

require.NotEmpty(tc.d.Spec.Template.Spec.InitContainers)
require.Equal(*tc.d.Spec.Template.Spec.InitContainers[0].Name, *Initializer().Name)
require.Equal(*tc.d.Spec.Template.Spec.InitContainers[0].Name, *Initializer(platforms.Unknown).Name)
require.NotEmpty(tc.d.Spec.Template.Spec.InitContainers[0].VolumeMounts)
require.Equal(*tc.d.Spec.Template.Spec.InitContainers[0].VolumeMounts[0].Name, *Initializer().VolumeMounts[0].Name)
require.Equal(*tc.d.Spec.Template.Spec.InitContainers[0].VolumeMounts[0].Name, *Initializer(platforms.Unknown).VolumeMounts[0].Name)

initializerCount := 0
for _, c := range tc.d.Spec.Template.Spec.InitContainers {
if *c.Name == *Initializer().Name {
if *c.Name == *Initializer(platforms.Unknown).Name {
initializerCount++
}
}
Expand All @@ -171,7 +172,7 @@ func TestAddInitializer(t *testing.T) {
for _, c := range tc.d.Spec.Template.Spec.Containers {
initializerVolumeMountCount := 0
for _, v := range c.VolumeMounts {
if *v.Name == *Initializer().VolumeMounts[0].Name {
if *v.Name == *Initializer(platforms.Unknown).VolumeMounts[0].Name {
initializerVolumeMountCount++
}
}
Expand All @@ -181,7 +182,7 @@ func TestAddInitializer(t *testing.T) {
require.NotEmpty(tc.d.Spec.Template.Spec.Volumes)
initializerVolumeCount := 0
for _, v := range tc.d.Spec.Template.Spec.Volumes {
if *v.Name == *Initializer().VolumeMounts[0].Name {
if *v.Name == *Initializer(platforms.Unknown).VolumeMounts[0].Name {
initializerVolumeCount++
}
}
Expand Down
3 changes: 2 additions & 1 deletion internal/kuberesource/parts.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,13 +426,14 @@ func PortForwarderForService(svc *applycorev1.ServiceApplyConfiguration) *applyc
}

// Initializer creates a new InitializerConfig.
func Initializer() *applycorev1.ContainerApplyConfiguration {
func Initializer(platform platforms.Platform) *applycorev1.ContainerApplyConfiguration {
return applycorev1.Container().
WithName("contrast-initializer").
WithImage("ghcr.io/edgelesssys/contrast/initializer:latest").
WithResources(ResourceRequirements().
WithMemoryRequest(50),
).
WithArgs("--platform", platform.String()).
WithEnv(NewEnvVar("COORDINATOR_HOST", "coordinator")).
WithVolumeMounts(VolumeMount().
WithName("contrast-secrets").
Expand Down
31 changes: 31 additions & 0 deletions internal/platforms/platforms.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package platforms

import (
"flag"
"fmt"
"strings"
)
Expand Down Expand Up @@ -71,3 +72,33 @@ func FromString(s string) (Platform, error) {
return Unknown, fmt.Errorf("unknown platform: %s", s)
}
}

// Flag implements flag.Value for Platform.
type Flag struct {
Platform
}

// RegisterFlag registers a --platform flag with the flag package.
func RegisterFlag() *Flag {
var f Flag
flag.Var(&f, "platform", fmt.Sprintf("target platform (one of %s)", strings.Join(AllStrings(), ", ")))
return &f
}

// String returns the string representation of the current flag value.
func (f *Flag) String() string {
if f == nil {
return Unknown.String()
}
return f.Platform.String()
}

// Set configures the flag from the given string.
func (f *Flag) Set(s string) error {
p, err := FromString(s)
if err != nil {
return err
}
f.Platform = p
return nil
}
2 changes: 1 addition & 1 deletion packages/by-name/contrast/package.nix
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ buildGoModule rec {
};

proxyVendor = true;
vendorHash = "sha256-mJz8VgNtw8hwME7mFSc4mLsotreU5Ql1eou0fGIQH7w=";
vendorHash = "sha256-61bgKwUNzmeedY2pg08f6UlZFbRMRGVPc8dLN+EMh4g=";

nativeBuildInputs = [ installShellFiles ];

Expand Down

0 comments on commit f0e3b59

Please sign in to comment.