From 42d2750c903e81e04be47b09331914a02b9963ac Mon Sep 17 00:00:00 2001 From: Tim Holm Date: Thu, 6 Jul 2023 11:26:47 +1000 Subject: [PATCH] AWS websocket implementation --- cloud/aws/deploy/policy/iam.go | 9 ++ cloud/aws/deploy/up.go | 19 +++ cloud/aws/deploy/websocket/apigateway.go | 184 ++++++++++++++++++++++ cloud/aws/go.mod | 7 +- cloud/aws/go.sum | 9 ++ cloud/aws/mocks/provider/aws.go | 16 ++ cloud/aws/runtime/cmd/membrane.go | 2 + cloud/aws/runtime/core/provider.go | 24 ++- cloud/aws/runtime/gateway/event.go | 22 ++- cloud/aws/runtime/gateway/lambda.go | 64 ++++++++ cloud/aws/runtime/websocket/apigateway.go | 99 ++++++++++++ cloud/azure/deploy/up.go | 8 + cloud/gcp/deploy/up.go | 8 + 13 files changed, 460 insertions(+), 11 deletions(-) create mode 100644 cloud/aws/deploy/websocket/apigateway.go create mode 100644 cloud/aws/runtime/websocket/apigateway.go diff --git a/cloud/aws/deploy/policy/iam.go b/cloud/aws/deploy/policy/iam.go index 8ba055a7e..eef11d8d3 100644 --- a/cloud/aws/deploy/policy/iam.go +++ b/cloud/aws/deploy/policy/iam.go @@ -30,6 +30,7 @@ import ( "github.com/nitrictech/nitric/cloud/aws/deploy/queue" "github.com/nitrictech/nitric/cloud/aws/deploy/secret" "github.com/nitrictech/nitric/cloud/aws/deploy/topic" + "github.com/nitrictech/nitric/cloud/aws/deploy/websocket" deploy "github.com/nitrictech/nitric/core/pkg/api/nitric/deploy/v1" v1 "github.com/nitrictech/nitric/core/pkg/api/nitric/v1" ) @@ -54,6 +55,7 @@ type StackResources struct { Buckets map[string]*bucket.S3Bucket Collections map[string]*collection.DynamodbCollection Secrets map[string]*secret.SecretsManagerSecret + Websockets map[string]*websocket.AwsWebsocketApiGateway } type PrincipalMap = map[v1.ResourceType]map[string]*iam.Role @@ -132,6 +134,9 @@ var awsActionsMap map[v1.Action][]string = map[v1.Action][]string{ v1.Action_SecretPut: { "secretsmanager:PutSecretValue", }, + v1.Action_WebsocketManage: { + "execute-api:ManageConnections", + }, } func actionsToAwsActions(actions []v1.Action) []string { @@ -167,6 +172,10 @@ func arnForResource(resource *deploy.Resource, resources *StackResources) ([]int if s, ok := resources.Secrets[resource.Name]; ok { return []interface{}{s.SecretsManager.Arn}, nil } + case v1.ResourceType_Websocket: + if w, ok := resources.Websockets[resource.Name]; ok { + return []interface{}{pulumi.Sprintf("%s/*", w.Api.ExecutionArn)}, nil + } default: return nil, fmt.Errorf( "invalid resource type: %s. Did you mean to define it as a principal?", resource.Type) diff --git a/cloud/aws/deploy/up.go b/cloud/aws/deploy/up.go index 230eb9a45..6d30f969e 100644 --- a/cloud/aws/deploy/up.go +++ b/cloud/aws/deploy/up.go @@ -37,6 +37,7 @@ import ( "github.com/nitrictech/nitric/cloud/aws/deploy/secret" "github.com/nitrictech/nitric/cloud/aws/deploy/stack" "github.com/nitrictech/nitric/cloud/aws/deploy/topic" + "github.com/nitrictech/nitric/cloud/aws/deploy/websocket" commonDeploy "github.com/nitrictech/nitric/cloud/common/deploy" "github.com/nitrictech/nitric/cloud/common/deploy/image" pulumiutils "github.com/nitrictech/nitric/cloud/common/deploy/pulumi" @@ -309,6 +310,23 @@ func (d *DeployServer) Up(request *deploy.DeployUpRequest, stream deploy.DeployS } } + // deploy websockets + websockets := map[string]*websocket.AwsWebsocketApiGateway{} + for _, res := range request.Spec.Resources { + switch ws := res.Config.(type) { + case *deploy.Resource_Websocket: + websockets[res.Name], err = websocket.NewAwsWebsocketApiGateway(ctx, res.Name, &websocket.AwsWebsocketApiGatewayArgs{ + DefaultTarget: execs[ws.Websocket.MessageTarget.GetExecutionUnit()], + ConnectTarget: execs[ws.Websocket.ConnectTarget.GetExecutionUnit()], + DisconnectTarget: execs[ws.Websocket.DisconnectTarget.GetExecutionUnit()], + StackID: stackID, + }) + if err != nil { + return err + } + } + } + // Deploy all schedules schedules := map[string]*schedule.AwsEventbridgeSchedule{} for _, res := range request.Spec.Resources { @@ -381,6 +399,7 @@ func (d *DeployServer) Up(request *deploy.DeployUpRequest, stream deploy.DeployS Queues: queues, Collections: collections, Secrets: secrets, + Websockets: websockets, }, Principals: principals, }) diff --git a/cloud/aws/deploy/websocket/apigateway.go b/cloud/aws/deploy/websocket/apigateway.go new file mode 100644 index 000000000..eeea255d4 --- /dev/null +++ b/cloud/aws/deploy/websocket/apigateway.go @@ -0,0 +1,184 @@ +// Copyright Nitric Pty Ltd. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package websocket + +import ( + "fmt" + + "github.com/pulumi/pulumi-aws/sdk/v5/go/aws/apigatewayv2" + awslambda "github.com/pulumi/pulumi-aws/sdk/v5/go/aws/lambda" + "github.com/pulumi/pulumi/sdk/v3/go/pulumi" + + "github.com/nitrictech/nitric/cloud/aws/deploy/exec" + common "github.com/nitrictech/nitric/cloud/common/deploy/tags" +) + +type AwsWebsocketApiGatewayArgs struct { + DefaultTarget *exec.LambdaExecUnit + ConnectTarget *exec.LambdaExecUnit + DisconnectTarget *exec.LambdaExecUnit + + StackID pulumi.StringInput +} + +type AwsWebsocketApiGateway struct { + pulumi.ResourceState + + Name string + Api *apigatewayv2.Api +} + +func NewAwsWebsocketApiGateway(ctx *pulumi.Context, name string, args *AwsWebsocketApiGatewayArgs, opts ...pulumi.ResourceOption) (*AwsWebsocketApiGateway, error) { + res := &AwsWebsocketApiGateway{Name: name} + + err := ctx.RegisterComponentResource("nitric:websocket:AwsApiGateway", name, res, opts...) + if err != nil { + return nil, err + } + + opts = append(opts, pulumi.Parent(res)) + + res.Api, err = apigatewayv2.NewApi(ctx, name, &apigatewayv2.ApiArgs{ + ProtocolType: pulumi.String("WEBSOCKET"), + Tags: common.Tags(ctx, args.StackID, name), + // TODO: We won't actually be using this, but it is required. + // Instead we'll be using the $default route + RouteSelectionExpression: pulumi.String("$request.body.action"), + }, opts...) + if err != nil { + return nil, err + } + + // Create the API integrations + integrationDefault, err := apigatewayv2.NewIntegration(ctx, fmt.Sprintf("%s-default-integration", name), &apigatewayv2.IntegrationArgs{ + ApiId: res.Api.ID(), + IntegrationType: pulumi.String("AWS_PROXY"), + IntegrationUri: args.DefaultTarget.Function.Arn, + }) + if err != nil { + return nil, err + } + + _, err = awslambda.NewPermission(ctx, fmt.Sprintf("%s-default-permission", name), &awslambda.PermissionArgs{ + Function: args.DefaultTarget.Function.Name, + Action: pulumi.String("lambda:InvokeFunction"), + Principal: pulumi.String("apigateway.amazonaws.com"), + SourceArn: pulumi.Sprintf("%s/*/*", res.Api.ExecutionArn), + }, opts...) + if err != nil { + return nil, err + } + + // check if the function name is different if not assign to default + integrationConnect := integrationDefault + if args.ConnectTarget != args.DefaultTarget { + integrationConnect, err = apigatewayv2.NewIntegration(ctx, fmt.Sprintf("%s-connect-integration", name), &apigatewayv2.IntegrationArgs{ + ApiId: res.Api.ID(), + IntegrationType: pulumi.String("AWS_PROXY"), + IntegrationUri: args.ConnectTarget.Function.Arn, + }) + if err != nil { + return nil, err + } + + _, err = awslambda.NewPermission(ctx, fmt.Sprintf("%s-connect-permission", name), &awslambda.PermissionArgs{ + Function: args.DefaultTarget.Function.Name, + Action: pulumi.String("lambda:InvokeFunction"), + Principal: pulumi.String("apigateway.amazonaws.com"), + SourceArn: pulumi.Sprintf("%s/*/*", res.Api.ExecutionArn), + }, opts...) + if err != nil { + return nil, err + } + } + + // check if the function name is different if not assign to default + integrationDisconnect := integrationDefault + if args.DisconnectTarget != args.DefaultTarget { + integrationDisconnect, err = apigatewayv2.NewIntegration(ctx, fmt.Sprintf("%s-disconnect-integration", name), &apigatewayv2.IntegrationArgs{ + ApiId: res.Api.ID(), + IntegrationType: pulumi.String("AWS_PROXY"), + IntegrationUri: args.DisconnectTarget.Function.Arn, + }) + if err != nil { + return nil, err + } + + _, err = awslambda.NewPermission(ctx, fmt.Sprintf("%s-disconnect-permission", name), &awslambda.PermissionArgs{ + Function: args.DefaultTarget.Function.Name, + Action: pulumi.String("lambda:InvokeFunction"), + Principal: pulumi.String("apigateway.amazonaws.com"), + SourceArn: pulumi.Sprintf("%s/*/*", res.Api.ExecutionArn), + }, opts...) + if err != nil { + return nil, err + } + } + + // Create the routes for the websocket handler + // The default message route + _, err = apigatewayv2.NewRoute(ctx, fmt.Sprintf("%s-default-route", name), &apigatewayv2.RouteArgs{ + ApiId: res.Api.ID(), + RouteKey: pulumi.String("$default"), + Target: pulumi.Sprintf("integrations/%s", integrationDefault.ID()), + }) + if err != nil { + return nil, err + } + + // The client connection route + _, err = apigatewayv2.NewRoute(ctx, fmt.Sprintf("%s-connect-route", name), &apigatewayv2.RouteArgs{ + ApiId: res.Api.ID(), + RouteKey: pulumi.String("$connect"), + Target: pulumi.Sprintf("integrations/%s", integrationConnect.ID()), + }) + if err != nil { + return nil, err + } + + // the client disconnection route + _, err = apigatewayv2.NewRoute(ctx, fmt.Sprintf("%s-disconnect-route", name), &apigatewayv2.RouteArgs{ + ApiId: res.Api.ID(), + RouteKey: pulumi.String("$disconnect"), + Target: pulumi.Sprintf("integrations/%s", integrationDisconnect.ID()), + }) + if err != nil { + return nil, err + } + + _, err = apigatewayv2.NewStage(ctx, name+"DefaultStage", &apigatewayv2.StageArgs{ + AutoDeploy: pulumi.BoolPtr(true), + Name: pulumi.String("$default"), + ApiId: res.Api.ID(), + Tags: common.Tags(ctx, args.StackID, name+"DefaultStage"), + }, opts...) + if err != nil { + return nil, err + } + + if err != nil { + return nil, err + } + + endPoint := res.Api.ApiEndpoint.ApplyT(func(ep string) string { + return ep + }).(pulumi.StringInput) + + ctx.Export("api:"+name, endPoint) + + return res, nil +} diff --git a/cloud/aws/go.mod b/cloud/aws/go.mod index 6caf5ae40..3eafc2586 100644 --- a/cloud/aws/go.mod +++ b/cloud/aws/go.mod @@ -6,7 +6,7 @@ require ( github.com/avast/retry-go v3.0.0+incompatible github.com/aws/aws-lambda-go v1.34.1 github.com/aws/aws-sdk-go v1.44.146 - github.com/aws/aws-sdk-go-v2 v1.17.2 + github.com/aws/aws-sdk-go-v2 v1.18.0 github.com/aws/aws-sdk-go-v2/config v1.18.4 github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.10.6 github.com/aws/aws-sdk-go-v2/service/apigatewayv2 v1.12.20 @@ -70,10 +70,11 @@ require ( github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.9 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.13.4 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.20 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.26 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.20 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.27 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.3.27 // indirect github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.16 // indirect + github.com/aws/aws-sdk-go-v2/service/apigatewaymanagementapi v1.11.10 // indirect github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.13.26 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11 // indirect github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.20 // indirect diff --git a/cloud/aws/go.sum b/cloud/aws/go.sum index 3f98c40ab..5f04f0581 100644 --- a/cloud/aws/go.sum +++ b/cloud/aws/go.sum @@ -117,6 +117,8 @@ github.com/aws/aws-sdk-go v1.44.146/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8 github.com/aws/aws-sdk-go-v2 v1.17.1/go.mod h1:JLnGeGONAyi2lWXI1p0PCIOIy333JMVK1U7Hf0aRFLw= github.com/aws/aws-sdk-go-v2 v1.17.2 h1:r0yRZInwiPBNpQ4aDy/Ssh3ROWsGtKDwar2JS8Lm+N8= github.com/aws/aws-sdk-go-v2 v1.17.2/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= +github.com/aws/aws-sdk-go-v2 v1.18.0 h1:882kkTpSFhdgYRKVZ/VCgf7sd0ru57p2JCxz4/oN5RY= +github.com/aws/aws-sdk-go-v2 v1.18.0/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.9 h1:RKci2D7tMwpvGpDNZnGQw9wk6v7o/xSwFcUAuNPoB8k= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.9/go.mod h1:vCmV1q1VK8eoQJ5+aYE7PkK1K6v41qJ5pJdK3ggCDvg= github.com/aws/aws-sdk-go-v2/config v1.18.4 h1:VZKhr3uAADXHStS/Gf9xSYVmmaluTUfkc0dcbPiDsKE= @@ -130,13 +132,19 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.20/go.mod h1:d9xFpWd3qYwdIXM github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.25/go.mod h1:Zb29PYkf42vVYQY6pvSyJCJcFHlPIiY+YKdPtwnvMkY= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.26 h1:5WU31cY7m0tG+AiaXuXGoMzo2GBQ1IixtWa8Yywsgco= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.26/go.mod h1:2E0LdbJW6lbeU4uxjum99GZzI0ZjDpAb0CoSCM0oeEY= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33 h1:kG5eQilShqmJbv11XL1VpyDbaEJzWxd4zRiCG30GSn4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33/go.mod h1:7i0PF1ME/2eUPFcjkVIwq+DOygHEoK92t5cDqNgYbIw= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.19/go.mod h1:6Q0546uHDp421okhmmGfbxzq2hBqbXFNpi4k+Q1JnQA= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.20 h1:WW0qSzDWoiWU2FS5DbKpxGilFVlCEJPwx4YtjdfI0Jw= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.20/go.mod h1:/+6lSiby8TBFpTVXZgKiN/rCfkYXEGvhlM4zCgPpt7w= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.27 h1:vFQlirhuM8lLlpI7imKOMsjdQLuN9CPi+k44F/OFVsk= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.27/go.mod h1:UrHnn3QV/d0pBZ6QBAEQcqFLf8FAzLmoUfPVIueOvoM= github.com/aws/aws-sdk-go-v2/internal/ini v1.3.27 h1:N2eKFw2S+JWRCtTt0IhIX7uoGGQciD4p6ba+SJv4WEU= github.com/aws/aws-sdk-go-v2/internal/ini v1.3.27/go.mod h1:RdwFVc7PBYWY33fa2+8T1mSqQ7ZEK4ILpM0wfioDC3w= github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.16 h1:2EXB7dtGwRYIN3XQ9qwIW504DVbKIw3r89xQnonGdsQ= github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.16/go.mod h1:XH+3h395e3WVdd6T2Z3mPxuI+x/HVtdqVOREkTiyubs= +github.com/aws/aws-sdk-go-v2/service/apigatewaymanagementapi v1.11.10 h1:os9Aix72xeiZ9+wQ2LZJSoHOzGUqKYLLS9S7Y4BaRmI= +github.com/aws/aws-sdk-go-v2/service/apigatewaymanagementapi v1.11.10/go.mod h1:rFWa3WA43LkZ9pAJkGuO90kU+N0Ru2dCJwjRfZ8kKZ8= github.com/aws/aws-sdk-go-v2/service/apigatewayv2 v1.12.20 h1:7N4o3yLag3c3c22POkmCAfrr/OQG5807a9NRh9lUUKw= github.com/aws/aws-sdk-go-v2/service/apigatewayv2 v1.12.20/go.mod h1:BEIWaGqO27qq9JeFeY746S4+SFmBajpV+yhGne2qbMo= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.17.7/go.mod h1:BiglbKCG56L8tmMnUEyEQo422BO9xnNR8vVHnOsByf8= @@ -1076,6 +1084,7 @@ golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20220223155221-ee480838109b/go.mod h1:DAh4E804XQdzx2j+YRIaUnCqCV2RuMz24cGBJ5QYIrc= golang.org/x/oauth2 v0.6.0 h1:Lh8GPgSKBfWSwFvtuWOfeI3aAAnbXTSutYxJiOJFgIw= +golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/cloud/aws/mocks/provider/aws.go b/cloud/aws/mocks/provider/aws.go index 2d4be8032..c1828b9b5 100644 --- a/cloud/aws/mocks/provider/aws.go +++ b/cloud/aws/mocks/provider/aws.go @@ -8,6 +8,7 @@ import ( context "context" reflect "reflect" + apigatewayv2 "github.com/aws/aws-sdk-go-v2/service/apigatewayv2" gomock "github.com/golang/mock/gomock" v1 "github.com/nitrictech/nitric/core/pkg/api/nitric/v1" resource "github.com/nitrictech/nitric/core/pkg/plugins/resource" @@ -65,6 +66,21 @@ func (mr *MockAwsProviderMockRecorder) Details(arg0, arg1, arg2 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Details", reflect.TypeOf((*MockAwsProvider)(nil).Details), arg0, arg1, arg2) } +// GetApiGatewayById mocks base method. +func (m *MockAwsProvider) GetApiGatewayById(arg0 context.Context, arg1 string) (*apigatewayv2.GetApiOutput, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetApiGatewayById", arg0, arg1) + ret0, _ := ret[0].(*apigatewayv2.GetApiOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetApiGatewayById indicates an expected call of GetApiGatewayById. +func (mr *MockAwsProviderMockRecorder) GetApiGatewayById(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetApiGatewayById", reflect.TypeOf((*MockAwsProvider)(nil).GetApiGatewayById), arg0, arg1) +} + // GetResources mocks base method. func (m *MockAwsProvider) GetResources(arg0 context.Context, arg1 string) (map[string]string, error) { m.ctrl.T.Helper() diff --git a/cloud/aws/runtime/cmd/membrane.go b/cloud/aws/runtime/cmd/membrane.go index 511f956d4..ebd8a9822 100644 --- a/cloud/aws/runtime/cmd/membrane.go +++ b/cloud/aws/runtime/cmd/membrane.go @@ -27,6 +27,7 @@ import ( sqs_service "github.com/nitrictech/nitric/cloud/aws/runtime/queue" secrets_manager_secret_service "github.com/nitrictech/nitric/cloud/aws/runtime/secret" s3_service "github.com/nitrictech/nitric/cloud/aws/runtime/storage" + "github.com/nitrictech/nitric/cloud/aws/runtime/websocket" base_http "github.com/nitrictech/nitric/cloud/common/runtime/gateway" "github.com/nitrictech/nitric/core/pkg/membrane" "github.com/nitrictech/nitric/core/pkg/utils" @@ -62,6 +63,7 @@ func main() { membraneOpts.StoragePlugin, _ = s3_service.New(provider) membraneOpts.ResourcesPlugin = provider membraneOpts.CreateTracerProvider = newTracerProvider + membraneOpts.WebsocketPlugin, _ = websocket.NewAwsApiGatewayWebsocket(provider) m, err := membrane.New(membraneOpts) if err != nil { diff --git a/cloud/aws/runtime/core/provider.go b/cloud/aws/runtime/core/provider.go index cf6399e67..b646794da 100644 --- a/cloud/aws/runtime/core/provider.go +++ b/cloud/aws/runtime/core/provider.go @@ -48,7 +48,8 @@ const ( ) var resourceTypeMap = map[resource.ResourceType]AwsResource{ - resource.ResourceType_Api: AwsResource_Api, + resource.ResourceType_Api: AwsResource_Api, + resource.ResourceType_Websocket: AwsResource_Api, } type AwsProvider interface { @@ -57,6 +58,7 @@ type AwsProvider interface { // GetResources API operation for AWS Provider. // Returns requested aws resources for the given resource type GetResources(context.Context, AwsResource) (map[string]string, error) + GetApiGatewayById(context.Context, string) (*apigatewayv2.GetApiOutput, error) } // Aws core utility provider @@ -102,16 +104,20 @@ func (a *awsProviderImpl) Details(ctx context.Context, typ resource.ResourceType arnParts := strings.Split(arn, "/") apiId := arnParts[len(arnParts)-1] // Get api detail - api, err := a.apiClient.GetApi(context.TODO(), &apigatewayv2.GetApiInput{ - ApiId: aws.String(apiId), - }) + api, err := a.GetApiGatewayById(ctx, apiId) if err != nil { return nil, err } details.Service = "ApiGateway" - details.Detail = resource.ApiDetails{ - URL: *api.ApiEndpoint, + if typ == resource.ResourceType_Api { + details.Detail = resource.ApiDetails{ + URL: *api.ApiEndpoint, + } + } else { + details.Detail = resource.WebsocketDetails{ + URL: fmt.Sprintf("%s/$default", *api.ApiEndpoint), + } } return details, nil @@ -120,6 +126,12 @@ func (a *awsProviderImpl) Details(ctx context.Context, typ resource.ResourceType } } +func (a *awsProviderImpl) GetApiGatewayById(ctx context.Context, apiId string) (*apigatewayv2.GetApiOutput, error) { + return a.apiClient.GetApi(context.TODO(), &apigatewayv2.GetApiInput{ + ApiId: aws.String(apiId), + }) +} + func resourceTypeFromArn(arn string) (resource.ResourceType, error) { if !awsArn.IsARN(arn) { return "", fmt.Errorf("invalid ARN provided") diff --git a/cloud/aws/runtime/gateway/event.go b/cloud/aws/runtime/gateway/event.go index b42f9520c..e4c690fa4 100644 --- a/cloud/aws/runtime/gateway/event.go +++ b/cloud/aws/runtime/gateway/event.go @@ -28,6 +28,7 @@ const ( sns s3 httpEvent + websocketEvent healthcheck // cloudwatch schedule @@ -55,6 +56,7 @@ type nitricScheduleEvent struct { // An event struct that embeds the AWS event types that we handle type Event struct { events.APIGatewayV2HTTPRequest + events.APIGatewayWebsocketProxyRequest healthCheckEvent Records []Record nitricScheduleEvent @@ -62,7 +64,9 @@ type Event struct { func (e *Event) Type() eventType { // check if this event type contains valid data - if e.APIGatewayV2HTTPRequest.RouteKey != "" { + if e.APIGatewayWebsocketProxyRequest.RequestContext.ConnectionID != "" { + return websocketEvent + } else if e.APIGatewayV2HTTPRequest.RouteKey != "" { return httpEvent } else if e.Check { return healthcheck @@ -132,6 +136,15 @@ func (e *Event) UnmarshalJSON(data []byte) error { } e.APIGatewayV2HTTPRequest = apiEvent + case websocketEvent: + websocketEvent := events.APIGatewayWebsocketProxyRequest{} + err = json.Unmarshal(data, &websocketEvent) + + if err != nil { + return err + } + + e.APIGatewayWebsocketProxyRequest = websocketEvent case schedule: nitricSchedule := nitricScheduleEvent{} err = json.Unmarshal(data, &nitricSchedule) @@ -164,8 +177,13 @@ func (e *Event) getEventType(data []byte) eventType { return unknown } + requestContext, isRequest := temp["requestContext"].(map[string]interface{}) + // Handle non-record events - if _, ok := temp["routeKey"]; ok { + if isRequest { + if _, ok := requestContext["connectionId"]; ok { + return websocketEvent + } return httpEvent } else if _, ok := temp["x-nitric-healthcheck"]; ok { return healthcheck diff --git a/cloud/aws/runtime/gateway/lambda.go b/cloud/aws/runtime/gateway/lambda.go index d5218cda5..fc2a614cc 100644 --- a/cloud/aws/runtime/gateway/lambda.go +++ b/cloud/aws/runtime/gateway/lambda.go @@ -76,6 +76,68 @@ type LambdaGateway struct { finished chan int } +// Handle websocket events +func (s *LambdaGateway) handleWebsocketEvent(ctx context.Context, evt events.APIGatewayWebsocketProxyRequest) (interface{}, error) { + // Use the routekey to get the event type + + var wsEvent = v1.WebsocketEvent_Message + switch evt.RequestContext.RouteKey { + case "$connect": + wsEvent = v1.WebsocketEvent_Connect + case "$disconnect": + wsEvent = v1.WebsocketEvent_Disconnect + } + + api, err := s.provider.GetApiGatewayById(ctx, evt.RequestContext.APIID) + if err != nil { + return nil, err + } + + nitricName, ok := api.Tags["x-nitric-name"] + if !ok { + return nil, fmt.Errorf("recieved websocket trigger from non-nitric API gateway") + } + + req := &v1.TriggerRequest{ + Data: []byte(evt.Body), + Context: &v1.TriggerRequest_Websocket{ + Websocket: &v1.WebsocketTriggerContext{ + ConnectionId: evt.RequestContext.ConnectionID, + Event: wsEvent, + // Get the API gateways nitric name + Socket: nitricName, + }, + }, + } + + wrk, err := s.pool.GetWorker(&pool.GetWorkerOptions{ + Trigger: req, + }) + if err != nil { + return nil, err + } + + _, err = wrk.HandleTrigger(ctx, req) + if err != nil { + return events.APIGatewayProxyResponse{ + StatusCode: 500, + Body: "Error processing lambda request", + // TODO: Need to determine best case when to use this... + IsBase64Encoded: false, + }, nil + } + + // if response.GetWebsocket() == nil || !response.GetWebsocket().Success { + // return events.APIGatewayProxyResponse{ + // StatusCode: 500, + // }, nil + // } + + return events.APIGatewayProxyResponse{ + StatusCode: 200, + }, nil +} + // Handle API events func (s *LambdaGateway) handleApiEvent(ctx context.Context, evt events.APIGatewayV2HTTPRequest) (interface{}, error) { // Copy the headers and re-write for the proxy @@ -324,6 +386,8 @@ func (s *LambdaGateway) handleS3Event(ctx context.Context, records []Record) (in func (s *LambdaGateway) routeEvent(ctx context.Context, evt Event) (interface{}, error) { switch evt.Type() { + case websocketEvent: + return s.handleWebsocketEvent(ctx, evt.APIGatewayWebsocketProxyRequest) case httpEvent: return s.handleApiEvent(ctx, evt.APIGatewayV2HTTPRequest) case healthcheck: diff --git a/cloud/aws/runtime/websocket/apigateway.go b/cloud/aws/runtime/websocket/apigateway.go new file mode 100644 index 000000000..8476aa9ee --- /dev/null +++ b/cloud/aws/runtime/websocket/apigateway.go @@ -0,0 +1,99 @@ +package websocket + +import ( + "context" + "fmt" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/apigatewaymanagementapi" + "github.com/nitrictech/nitric/cloud/aws/runtime/core" + "github.com/nitrictech/nitric/core/pkg/plugins/resource" + "github.com/nitrictech/nitric/core/pkg/plugins/websocket" + "github.com/nitrictech/nitric/core/pkg/utils" + "go.opentelemetry.io/contrib/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws" +) + +type ApiGatewayWebsocketService struct { + websocket.UnimplementedWebsocketService + provider core.AwsProvider + clients map[string]*apigatewaymanagementapi.Client +} + +var _ websocket.WebsocketService = &ApiGatewayWebsocketService{} + +func (a *ApiGatewayWebsocketService) getClientForSocket(socket string) (*apigatewaymanagementapi.Client, error) { + awsRegion := utils.GetEnv("AWS_REGION", "us-east-1") + + if client, ok := a.clients[socket]; ok { + return client, nil + } + + details, err := a.provider.Details(context.TODO(), resource.ResourceType_Api, socket) + if err != nil { + return nil, err + } + + apiDetails, ok := details.Detail.(resource.ApiDetails) + if !ok { + return nil, fmt.Errorf("an error occurred resolving API Gateway details") + } + + cfg, sessionError := config.LoadDefaultConfig(context.TODO(), config.WithRegion(awsRegion)) + if sessionError != nil { + return nil, fmt.Errorf("error creating new AWS session %w", sessionError) + } + + callbackUrl := strings.Replace(apiDetails.URL, "wss", "https", 1) + callbackUrl = callbackUrl + "/$default" + + otelaws.AppendMiddlewares(&cfg.APIOptions) + + a.clients[socket] = apigatewaymanagementapi.NewFromConfig(cfg, apigatewaymanagementapi.WithEndpointResolver(apigatewaymanagementapi.EndpointResolverFromURL( + callbackUrl, + ))) + + return a.clients[socket], nil +} + +func (a *ApiGatewayWebsocketService) Send(ctx context.Context, socket string, connectionId string, message []byte) error { + client, err := a.getClientForSocket(socket) + if err != nil { + return err + } + + _, err = client.PostToConnection(ctx, &apigatewaymanagementapi.PostToConnectionInput{ + ConnectionId: aws.String(connectionId), + Data: message, + }) + + if err != nil { + return err + } + + return nil +} + +func (a *ApiGatewayWebsocketService) Close(ctx context.Context, socket string, connectionId string) error { + client, err := a.getClientForSocket(socket) + if err != nil { + return err + } + + _, err = client.DeleteConnection(ctx, &apigatewaymanagementapi.DeleteConnectionInput{ + ConnectionId: aws.String(connectionId), + }) + if err != nil { + return err + } + + return nil +} + +func NewAwsApiGatewayWebsocket(provider core.AwsProvider) (*ApiGatewayWebsocketService, error) { + return &ApiGatewayWebsocketService{ + provider: provider, + clients: make(map[string]*apigatewaymanagementapi.Client), + }, nil +} diff --git a/cloud/azure/deploy/up.go b/cloud/azure/deploy/up.go index 9421fa8d5..3d2accf3d 100644 --- a/cloud/azure/deploy/up.go +++ b/cloud/azure/deploy/up.go @@ -70,6 +70,14 @@ func (d *DeployServer) Up(request *deploy.DeployUpRequest, stream deploy.DeployS } }() + // Get Websockets + websockets := lo.Filter[*deploy.Resource](request.Spec.Resources, func(item *deploy.Resource, index int) bool { + return item.GetWebsocket() != nil + }) + if len(websockets) > 0 { + return fmt.Errorf("websockets currently in preview not supported in the Azure provider.") + } + // Get Execution units executionUnits := lo.Filter[*deploy.Resource](request.Spec.Resources, func(item *deploy.Resource, index int) bool { return item.GetExecutionUnit() != nil diff --git a/cloud/gcp/deploy/up.go b/cloud/gcp/deploy/up.go index 47f107343..889b7976d 100644 --- a/cloud/gcp/deploy/up.go +++ b/cloud/gcp/deploy/up.go @@ -81,6 +81,14 @@ func (d *DeployServer) Up(request *deploy.DeployUpRequest, stream deploy.DeployS } }() + // Get Websockets + websockets := lo.Filter[*deploy.Resource](request.Spec.Resources, func(item *deploy.Resource, index int) bool { + return item.GetWebsocket() != nil + }) + if len(websockets) > 0 { + return fmt.Errorf("websockets currently in preview not supported in the GCP provider.") + } + project, err := organizations.LookupProject(ctx, &organizations.LookupProjectArgs{ ProjectId: &details.ProjectId, }, nil)