From 30e6e8ccfe0b62e875cdc861e16e8ee7748bf56f Mon Sep 17 00:00:00 2001 From: Philipp Winter Date: Mon, 14 Oct 2024 07:51:03 -0500 Subject: [PATCH] Set the certificate's hash upon init. --- cmd/main_test.go | 37 +++++++++++++++---------------------- internal/service/service.go | 2 ++ 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/cmd/main_test.go b/cmd/main_test.go index 104f6e1..88315e2 100644 --- a/cmd/main_test.go +++ b/cmd/main_test.go @@ -14,7 +14,6 @@ import ( "net/url" "os" "slices" - "strings" "sync" "syscall" "testing" @@ -277,7 +276,7 @@ func TestHashes(t *testing.T) { body, ) } - doGet = func() (*http.Response, error) { + doGet = func(_ io.Reader) (*http.Response, error) { return testutil.Client.Get(intSrv("/enclave/hashes")) } ) @@ -285,26 +284,26 @@ func TestHashes(t *testing.T) { cases := []struct { name string - method string + reqFunc func(io.Reader) (*http.Response, error) toMarshal any wantCode int wantHashes *attestation.Hashes }{ { name: "get empty hashes", - method: http.MethodGet, + reqFunc: doGet, wantCode: http.StatusOK, wantHashes: new(attestation.Hashes), }, { name: "post application hash", - method: http.MethodPost, + reqFunc: doPost, toMarshal: hashes, wantCode: http.StatusOK, }, { name: "get populated hashes", - method: http.MethodGet, + reqFunc: doGet, wantCode: http.StatusOK, wantHashes: hashes, }, @@ -312,16 +311,10 @@ func TestHashes(t *testing.T) { for _, c := range cases { t.Run(c.name, func(t *testing.T) { - var b []byte - var resp *http.Response - var err error - if c.method == http.MethodGet { - resp, err = doGet() - } else { - b, err = json.Marshal(c.toMarshal) - require.NoError(t, err) - resp, err = doPost(bytes.NewReader(b)) - } + // Either POST or GET the hashes. + reqBody, err := json.Marshal(c.toMarshal) + require.NoError(t, err) + resp, err := c.reqFunc(bytes.NewReader(reqBody)) require.NoError(t, err) require.Equal(t, c.wantCode, resp.StatusCode) @@ -334,13 +327,13 @@ func TestHashes(t *testing.T) { gotBody, err := io.ReadAll(resp.Body) require.NoError(t, err) defer resp.Body.Close() - wantBody, err := json.Marshal(c.wantHashes) - require.NoError(t, err) + var gotHashes attestation.Hashes + require.NoError(t, json.Unmarshal(gotBody, &gotHashes)) - require.Equal(t, - strings.TrimSpace(string(wantBody)), - strings.TrimSpace(string(gotBody)), - ) + // Make sure that the application hashes match. + require.Equal(t, c.wantHashes.AppKeyHash, gotHashes.AppKeyHash) + // Make sure that the TLS certificate hash is set. + require.NotEmpty(t, gotHashes.TlsKeyHash) }) } } diff --git a/internal/service/service.go b/internal/service/service.go index d085674..6304780 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -2,6 +2,7 @@ package service import ( "context" + "crypto/sha256" "crypto/tls" "errors" "log" @@ -41,6 +42,7 @@ func Run( // Initialize hashes for the attestation document. hashes := new(attestation.Hashes) + hashes.SetTLSHash(util.AddrOf(sha256.Sum256(cert))) // Initialize Web servers. intSrv := newIntSrv(config, keys, hashes, appReady)