Skip to content
This repository has been archived by the owner on Dec 12, 2024. It is now read-only.

Commit

Permalink
ctx for db write (#246)
Browse files Browse the repository at this point in the history
* ctx for db write

* builds

* fixed build and tests
  • Loading branch information
nitro-neal authored Jan 10, 2023
1 parent d0dde58 commit 82a6233
Show file tree
Hide file tree
Showing 47 changed files with 576 additions and 530 deletions.
13 changes: 7 additions & 6 deletions internal/credential/verification.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package credential

import (
"context"
"crypto"
"fmt"

Expand Down Expand Up @@ -49,7 +50,7 @@ func NewCredentialVerifier(didResolver *didsdk.Resolver, schemaResolver schema.R

// VerifyJWTCredential first parses and checks the signature on the given JWT credential. Next, it runs
// a set of static verification checks on the credential as per the credential service's configuration.
func (v Verifier) VerifyJWTCredential(token keyaccess.JWT) error {
func (v Verifier) VerifyJWTCredential(ctx context.Context, token keyaccess.JWT) error {
// first, parse the token to see if it contains a valid verifiable credential
cred, err := signing.ParseVerifiableCredentialFromJWT(token.String())
if err != nil {
Expand Down Expand Up @@ -84,12 +85,12 @@ func (v Verifier) VerifyJWTCredential(token keyaccess.JWT) error {
return util.LoggingErrorMsg(err, "could not verify credential's signature")
}

return v.staticVerificationChecks(*cred)
return v.staticVerificationChecks(ctx, *cred)
}

// VerifyDataIntegrityCredential first checks the signature on the given data integrity credential. Next, it runs
// a set of static verification checks on the credential as per the credential service's configuration.
func (v Verifier) VerifyDataIntegrityCredential(credential credsdk.VerifiableCredential) error {
func (v Verifier) VerifyDataIntegrityCredential(ctx context.Context, credential credsdk.VerifiableCredential) error {
// resolve the issuer's key material
kid, pubKey, err := v.resolveCredentialIssuerKey(credential)
if err != nil {
Expand All @@ -108,7 +109,7 @@ func (v Verifier) VerifyDataIntegrityCredential(credential credsdk.VerifiableCre
return util.LoggingErrorMsg(err, "could not verify the credential's signature")
}

return v.staticVerificationChecks(credential)
return v.staticVerificationChecks(ctx, credential)
}

func (v Verifier) VerifyJWT(did string, token keyaccess.JWT) error {
Expand Down Expand Up @@ -144,12 +145,12 @@ func (v Verifier) resolveCredentialIssuerKey(credential credsdk.VerifiableCreden

// staticVerificationChecks runs a set of static verification checks on the credential as per the credential
// service's configuration, such as checking the credential's schema, expiration, and object validity.
func (v Verifier) staticVerificationChecks(credential credsdk.VerifiableCredential) error {
func (v Verifier) staticVerificationChecks(ctx context.Context, credential credsdk.VerifiableCredential) error {
// if the credential has a schema, resolve it before it is to be used in verification
var verificationOpts []verification.VerificationOption
if credential.CredentialSchema != nil {
schemaID := credential.CredentialSchema.ID
resolvedSchema, err := v.schemaResolver.Resolve(schemaID)
resolvedSchema, err := v.schemaResolver.Resolve(ctx, schemaID)
if err != nil {
return errors.Wrapf(err, "for credential<%s> failed to resolve schemas: %s", credential.ID, schemaID)
}
Expand Down
4 changes: 3 additions & 1 deletion internal/schema/resolver.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package schema

import (
"context"

"github.com/TBD54566975/ssi-sdk/credential/schema"
)

// Resolution is an interface that defines a generic method of resolving a schema
type Resolution interface {
Resolve(id string) (*schema.VCJSONSchema, error)
Resolve(ctx context.Context, id string) (*schema.VCJSONSchema, error)
}
20 changes: 10 additions & 10 deletions pkg/server/router/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (cr CredentialRouter) CreateCredential(ctx context.Context, w http.Response
}

req := request.ToServiceRequest()
createCredentialResponse, err := cr.service.CreateCredential(req)
createCredentialResponse, err := cr.service.CreateCredential(ctx, req)
if err != nil {
errMsg := "could not create credential"
logrus.WithError(err).Error(errMsg)
Expand Down Expand Up @@ -132,7 +132,7 @@ func (cr CredentialRouter) GetCredential(ctx context.Context, w http.ResponseWri
return framework.NewRequestErrorMsg(errMsg, http.StatusBadRequest)
}

gotCredential, err := cr.service.GetCredential(credential.GetCredentialRequest{ID: *id})
gotCredential, err := cr.service.GetCredential(ctx, credential.GetCredentialRequest{ID: *id})
if err != nil {
errMsg := fmt.Sprintf("could not get credential with id: %s", *id)
logrus.WithError(err).Error(errMsg)
Expand Down Expand Up @@ -169,7 +169,7 @@ func (cr CredentialRouter) GetCredentialStatus(ctx context.Context, w http.Respo
return framework.NewRequestErrorMsg(errMsg, http.StatusBadRequest)
}

getCredentialStatusResponse, err := cr.service.GetCredentialStatus(credential.GetCredentialStatusRequest{ID: *id})
getCredentialStatusResponse, err := cr.service.GetCredentialStatus(ctx, credential.GetCredentialStatusRequest{ID: *id})
if err != nil {
errMsg := fmt.Sprintf("could not get credential with id: %s", *id)
logrus.WithError(err).Error(errMsg)
Expand Down Expand Up @@ -207,7 +207,7 @@ func (cr CredentialRouter) GetCredentialStatusList(ctx context.Context, w http.R
return framework.NewRequestErrorMsg(errMsg, http.StatusBadRequest)
}

gotCredential, err := cr.service.GetCredentialStatusList(credential.GetCredentialStatusListRequest{ID: *id})
gotCredential, err := cr.service.GetCredentialStatusList(ctx, credential.GetCredentialStatusListRequest{ID: *id})
if err != nil {
errMsg := fmt.Sprintf("could not get credential status list with id: %s", *id)
logrus.WithError(err).Error(errMsg)
Expand Down Expand Up @@ -272,7 +272,7 @@ func (cr CredentialRouter) UpdateCredentialStatus(ctx context.Context, w http.Re
}

req := request.ToServiceRequest(*id)
gotCredential, err := cr.service.UpdateCredentialStatus(req)
gotCredential, err := cr.service.UpdateCredentialStatus(ctx, req)

if err != nil {
errMsg := fmt.Sprintf("could not update credential with id: %s", req.ID)
Expand Down Expand Up @@ -326,7 +326,7 @@ func (cr CredentialRouter) VerifyCredential(ctx context.Context, w http.Response
return framework.NewRequestError(err, http.StatusBadRequest)
}

verificationResult, err := cr.service.VerifyCredential(credential.VerifyCredentialRequest{
verificationResult, err := cr.service.VerifyCredential(ctx, credential.VerifyCredentialRequest{
DataIntegrityCredential: request.DataIntegrityCredential,
CredentialJWT: request.CredentialJWT,
})
Expand Down Expand Up @@ -382,7 +382,7 @@ func (cr CredentialRouter) GetCredentials(ctx context.Context, w http.ResponseWr
}

func (cr CredentialRouter) getCredentialsByIssuer(ctx context.Context, issuer string, w http.ResponseWriter, _ *http.Request) error {
gotCredentials, err := cr.service.GetCredentialsByIssuer(credential.GetCredentialByIssuerRequest{Issuer: issuer})
gotCredentials, err := cr.service.GetCredentialsByIssuer(ctx, credential.GetCredentialByIssuerRequest{Issuer: issuer})
if err != nil {
errMsg := fmt.Sprintf("could not get credentials for issuer: %s", util.SanitizeLog(issuer))
logrus.WithError(err).Error(errMsg)
Expand All @@ -394,7 +394,7 @@ func (cr CredentialRouter) getCredentialsByIssuer(ctx context.Context, issuer st
}

func (cr CredentialRouter) getCredentialsBySubject(ctx context.Context, subject string, w http.ResponseWriter, _ *http.Request) error {
gotCredentials, err := cr.service.GetCredentialsBySubject(credential.GetCredentialBySubjectRequest{Subject: subject})
gotCredentials, err := cr.service.GetCredentialsBySubject(ctx, credential.GetCredentialBySubjectRequest{Subject: subject})
if err != nil {
errMsg := fmt.Sprintf("could not get credentials for subject: %s", util.SanitizeLog(subject))
logrus.WithError(err).Error(errMsg)
Expand All @@ -406,7 +406,7 @@ func (cr CredentialRouter) getCredentialsBySubject(ctx context.Context, subject
}

func (cr CredentialRouter) getCredentialsBySchema(ctx context.Context, schema string, w http.ResponseWriter, _ *http.Request) error {
gotCredentials, err := cr.service.GetCredentialsBySchema(credential.GetCredentialBySchemaRequest{Schema: schema})
gotCredentials, err := cr.service.GetCredentialsBySchema(ctx, credential.GetCredentialBySchemaRequest{Schema: schema})
if err != nil {
errMsg := fmt.Sprintf("could not get credentials for schema: %s", util.SanitizeLog(schema))
logrus.WithError(err).Error(errMsg)
Expand Down Expand Up @@ -436,7 +436,7 @@ func (cr CredentialRouter) DeleteCredential(ctx context.Context, w http.Response
return framework.NewRequestErrorMsg(errMsg, http.StatusBadRequest)
}

if err := cr.service.DeleteCredential(credential.DeleteCredentialRequest{ID: *id}); err != nil {
if err := cr.service.DeleteCredential(ctx, credential.DeleteCredentialRequest{ID: *id}); err != nil {
errMsg := fmt.Sprintf("could not delete credential with id: %s", *id)
logrus.WithError(err).Error(errMsg)
return framework.NewRequestError(errors.Wrap(err, errMsg), http.StatusInternalServerError)
Expand Down
41 changes: 21 additions & 20 deletions pkg/server/router/credential_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package router

import (
"context"
"fmt"
"testing"
"time"
Expand Down Expand Up @@ -52,13 +53,13 @@ func TestCredentialRouter(t *testing.T) {

// create a credential

issuerDID, err := didService.CreateDIDByMethod(did.CreateDIDRequest{Method: didsdk.KeyMethod, KeyType: crypto.Ed25519})
issuerDID, err := didService.CreateDIDByMethod(context.Background(), did.CreateDIDRequest{Method: didsdk.KeyMethod, KeyType: crypto.Ed25519})
assert.NoError(tt, err)
assert.NotEmpty(tt, issuerDID)

issuer := issuerDID.DID.ID
subject := "did:test:345"
createdCred, err := credService.CreateCredential(credential.CreateCredentialRequest{
createdCred, err := credService.CreateCredential(context.Background(), credential.CreateCredentialRequest{
Issuer: issuer,
Subject: subject,
Data: map[string]any{
Expand All @@ -80,42 +81,42 @@ func TestCredentialRouter(t *testing.T) {
assert.Equal(tt, "Nakamoto", cred.CredentialSubject["lastName"])

// get it back
gotCred, err := credService.GetCredential(credential.GetCredentialRequest{ID: cred.ID})
gotCred, err := credService.GetCredential(context.Background(), credential.GetCredentialRequest{ID: cred.ID})
assert.NoError(tt, err)
assert.NotEmpty(tt, gotCred)

// compare for object equality
assert.Equal(tt, createdCred.CredentialJWT, gotCred.CredentialJWT)

// get a cred that doesn't exist
_, err = credService.GetCredential(credential.GetCredentialRequest{ID: "bad"})
_, err = credService.GetCredential(context.Background(), credential.GetCredentialRequest{ID: "bad"})
assert.Error(tt, err)
assert.Contains(tt, err.Error(), "credential not found with id: bad")

// get by schema - no schema
bySchema, err := credService.GetCredentialsBySchema(credential.GetCredentialBySchemaRequest{Schema: ""})
bySchema, err := credService.GetCredentialsBySchema(context.Background(), credential.GetCredentialBySchemaRequest{Schema: ""})
assert.NoError(tt, err)
assert.Len(tt, bySchema.Credentials, 1)
assert.EqualValues(tt, cred.CredentialSchema, bySchema.Credentials[0].Credential.CredentialSchema)

// get by subject
bySubject, err := credService.GetCredentialsBySubject(credential.GetCredentialBySubjectRequest{Subject: subject})
bySubject, err := credService.GetCredentialsBySubject(context.Background(), credential.GetCredentialBySubjectRequest{Subject: subject})
assert.NoError(tt, err)
assert.Len(tt, bySubject.Credentials, 1)

assert.Equal(tt, cred.ID, bySubject.Credentials[0].ID)
assert.Equal(tt, cred.CredentialSubject[credsdk.VerifiableCredentialIDProperty], bySubject.Credentials[0].Credential.CredentialSubject[credsdk.VerifiableCredentialIDProperty])

// get by issuer
byIssuer, err := credService.GetCredentialsByIssuer(credential.GetCredentialByIssuerRequest{Issuer: issuer})
byIssuer, err := credService.GetCredentialsByIssuer(context.Background(), credential.GetCredentialByIssuerRequest{Issuer: issuer})
assert.NoError(tt, err)
assert.Len(tt, byIssuer.Credentials, 1)

assert.Equal(tt, cred.ID, byIssuer.Credentials[0].Credential.ID)
assert.Equal(tt, cred.Issuer, byIssuer.Credentials[0].Credential.Issuer)

// create another cred with the same issuer, different subject, different schema that doesn't exist
_, err = credService.CreateCredential(credential.CreateCredentialRequest{
_, err = credService.CreateCredential(context.Background(), credential.CreateCredentialRequest{
Issuer: issuer,
Subject: "did:abcd:efghi",
JSONSchema: "https://test-schema.com",
Expand All @@ -138,12 +139,12 @@ func TestCredentialRouter(t *testing.T) {
"required": []any{"email"},
"additionalProperties": false,
}
createdSchema, err := schemaService.CreateSchema(schema.CreateSchemaRequest{Author: "me", Name: "simple schema", Schema: emailSchema})
createdSchema, err := schemaService.CreateSchema(context.Background(), schema.CreateSchemaRequest{Author: "me", Name: "simple schema", Schema: emailSchema})
assert.NoError(tt, err)
assert.NotEmpty(tt, createdSchema)

// create another cred with the same issuer, different subject, different schema that does exist
createdCredWithSchema, err := credService.CreateCredential(credential.CreateCredentialRequest{
createdCredWithSchema, err := credService.CreateCredential(context.Background(), credential.CreateCredentialRequest{
Issuer: issuer,
Subject: "did:abcd:efghi",
JSONSchema: createdSchema.ID,
Expand All @@ -156,35 +157,35 @@ func TestCredentialRouter(t *testing.T) {
assert.NotEmpty(tt, createdCredWithSchema)

// get by issuer
byIssuer, err = credService.GetCredentialsByIssuer(credential.GetCredentialByIssuerRequest{Issuer: issuer})
byIssuer, err = credService.GetCredentialsByIssuer(context.Background(), credential.GetCredentialByIssuerRequest{Issuer: issuer})
assert.NoError(tt, err)
assert.Len(tt, byIssuer.Credentials, 2)

// make sure the schema and subject queries are consistent
bySchema, err = credService.GetCredentialsBySchema(credential.GetCredentialBySchemaRequest{Schema: ""})
bySchema, err = credService.GetCredentialsBySchema(context.Background(), credential.GetCredentialBySchemaRequest{Schema: ""})
assert.NoError(tt, err)
assert.Len(tt, bySchema.Credentials, 1)

assert.Equal(tt, cred.ID, bySchema.Credentials[0].ID)
assert.EqualValues(tt, cred.CredentialSchema, bySchema.Credentials[0].Credential.CredentialSchema)

bySubject, err = credService.GetCredentialsBySubject(credential.GetCredentialBySubjectRequest{Subject: subject})
bySubject, err = credService.GetCredentialsBySubject(context.Background(), credential.GetCredentialBySubjectRequest{Subject: subject})
assert.NoError(tt, err)
assert.Len(tt, bySubject.Credentials, 1)

assert.Equal(tt, cred.ID, bySubject.Credentials[0].ID)
assert.Equal(tt, cred.CredentialSubject[credsdk.VerifiableCredentialIDProperty], bySubject.Credentials[0].Credential.CredentialSubject[credsdk.VerifiableCredentialIDProperty])

// delete a cred that doesn't exist (no error since idempotent)
err = credService.DeleteCredential(credential.DeleteCredentialRequest{ID: "bad"})
err = credService.DeleteCredential(context.Background(), credential.DeleteCredentialRequest{ID: "bad"})
assert.NoError(tt, err)

// delete a credential that does exist
err = credService.DeleteCredential(credential.DeleteCredentialRequest{ID: cred.ID})
err = credService.DeleteCredential(context.Background(), credential.DeleteCredentialRequest{ID: cred.ID})
assert.NoError(tt, err)

// get it back
_, err = credService.GetCredential(credential.GetCredentialRequest{ID: cred.ID})
_, err = credService.GetCredential(context.Background(), credential.GetCredentialRequest{ID: cred.ID})
assert.Error(tt, err)
assert.Contains(tt, err.Error(), fmt.Sprintf("credential not found with id: %s", cred.ID))
})
Expand All @@ -206,7 +207,7 @@ func TestCredentialRouter(t *testing.T) {
assert.Equal(tt, framework.StatusReady, credService.Status().Status)

// create a did
issuerDID, err := didService.CreateDIDByMethod(did.CreateDIDRequest{Method: didsdk.KeyMethod, KeyType: crypto.Ed25519})
issuerDID, err := didService.CreateDIDByMethod(context.Background(), did.CreateDIDRequest{Method: didsdk.KeyMethod, KeyType: crypto.Ed25519})
assert.NoError(tt, err)
assert.NotEmpty(tt, issuerDID)

Expand All @@ -222,14 +223,14 @@ func TestCredentialRouter(t *testing.T) {
"additionalProperties": false,
}

createdSchema, err := schemaService.CreateSchema(schema.CreateSchemaRequest{Author: "me", Name: "simple schema", Schema: emailSchema})
createdSchema, err := schemaService.CreateSchema(context.Background(), schema.CreateSchemaRequest{Author: "me", Name: "simple schema", Schema: emailSchema})
assert.NoError(tt, err)
assert.NotEmpty(tt, createdSchema)

issuer := issuerDID.DID.ID
subject := "did:test:345"

createdCred, err := credService.CreateCredential(credential.CreateCredentialRequest{
createdCred, err := credService.CreateCredential(context.Background(), credential.CreateCredentialRequest{
Issuer: issuer,
Subject: subject,
JSONSchema: createdSchema.ID,
Expand All @@ -251,7 +252,7 @@ func TestCredentialRouter(t *testing.T) {
assert.Contains(tt, credStatusMap["statusListCredential"], "v1/credentials/status")
assert.NotEmpty(tt, credStatusMap["statusListIndex"])

createdCredTwo, err := credService.CreateCredential(credential.CreateCredentialRequest{
createdCredTwo, err := credService.CreateCredential(context.Background(), credential.CreateCredentialRequest{
Issuer: issuer,
Subject: subject,
JSONSchema: createdSchema.ID,
Expand Down
6 changes: 3 additions & 3 deletions pkg/server/router/did.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func (dr DIDRouter) CreateDIDByMethod(ctx context.Context, w http.ResponseWriter

// TODO(gabe) check if the key type is supported for the method, to tell whether this is a bad req or internal error
createDIDRequest := did.CreateDIDRequest{Method: didsdk.Method(*method), KeyType: request.KeyType, DIDWebID: request.DIDWebID}
createDIDResponse, err := dr.service.CreateDIDByMethod(createDIDRequest)
createDIDResponse, err := dr.service.CreateDIDByMethod(ctx, createDIDRequest)
if err != nil {
errMsg := fmt.Sprintf("could not create DID for method<%s> with key type: %s", *method, request.KeyType)
logrus.WithError(err).Error(errMsg)
Expand Down Expand Up @@ -152,7 +152,7 @@ func (dr DIDRouter) GetDIDByMethod(ctx context.Context, w http.ResponseWriter, _
// TODO(gabe) check if the method is supported, to tell whether this is a bad req or internal error
// TODO(gabe) differentiate between internal errors and not found DIDs
getDIDRequest := did.GetDIDRequest{Method: didsdk.Method(*method), ID: *id}
gotDID, err := dr.service.GetDIDByMethod(getDIDRequest)
gotDID, err := dr.service.GetDIDByMethod(ctx, getDIDRequest)
if err != nil {
errMsg := fmt.Sprintf("could not get DID for method<%s> with id: %s", *method, *id)
logrus.WithError(err).Error(errMsg)
Expand Down Expand Up @@ -188,7 +188,7 @@ func (dr DIDRouter) GetDIDsByMethod(ctx context.Context, w http.ResponseWriter,
// TODO(gabe) check if the method is supported, to tell whether this is a bad req or internal error
// TODO(gabe) differentiate between internal errors and not found DIDs
getDIDsRequest := did.GetDIDsRequest{Method: didsdk.Method(*method)}
gotDIDs, err := dr.service.GetDIDsByMethod(getDIDsRequest)
gotDIDs, err := dr.service.GetDIDsByMethod(ctx, getDIDsRequest)
if err != nil {
errMsg := fmt.Sprintf("could not get DIDs for method: %s", *method)
logrus.WithError(err).Error(errMsg)
Expand Down
Loading

0 comments on commit 82a6233

Please sign in to comment.