From 8f5b880afd5f3b6fef9533145ff654a6717a7ea2 Mon Sep 17 00:00:00 2001 From: Yaron Schneider Date: Fri, 1 Nov 2024 15:42:58 -0700 Subject: [PATCH 01/12] Update sarama dependency (#3587) Signed-off-by: yaron2 --- go.mod | 18 ++++++++-------- go.sum | 36 +++++++++++++++---------------- tests/certification/go.mod | 18 ++++++++-------- tests/certification/go.sum | 36 +++++++++++++++---------------- tests/e2e/pubsub/jetstream/go.mod | 6 +++--- tests/e2e/pubsub/jetstream/go.sum | 12 +++++------ 6 files changed, 63 insertions(+), 63 deletions(-) diff --git a/go.mod b/go.mod index 2f194602f4..13a6af3ab6 100644 --- a/go.mod +++ b/go.mod @@ -24,7 +24,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/storage/azqueue v1.0.0 github.com/Azure/go-amqp v1.0.5 github.com/DATA-DOG/go-sqlmock v1.5.0 - github.com/IBM/sarama v1.42.2 + github.com/IBM/sarama v1.43.3 github.com/aerospike/aerospike-client-go/v6 v6.12.0 github.com/alibaba/sentinel-golang v1.0.4 github.com/alibabacloud-go/darabonba-openapi v0.2.1 @@ -121,10 +121,10 @@ require ( go.uber.org/goleak v1.2.1 go.uber.org/multierr v1.11.0 go.uber.org/ratelimit v0.3.0 - golang.org/x/crypto v0.24.0 + golang.org/x/crypto v0.26.0 golang.org/x/exp v0.0.0-20240119083558-1b970713d09a golang.org/x/mod v0.17.0 - golang.org/x/net v0.26.0 + golang.org/x/net v0.28.0 golang.org/x/oauth2 v0.20.0 google.golang.org/api v0.180.0 google.golang.org/grpc v1.64.0 @@ -217,7 +217,7 @@ require ( github.com/dubbogo/triple v1.1.8 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/dvsekhvalnov/jose2go v1.6.0 // indirect - github.com/eapache/go-resiliency v1.5.0 // indirect + github.com/eapache/go-resiliency v1.7.0 // indirect github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3 // indirect github.com/eapache/queue v1.1.0 // indirect github.com/emicklei/go-restful/v3 v3.10.1 // indirect @@ -297,7 +297,7 @@ require ( github.com/kataras/go-errors v0.0.3 // indirect github.com/kataras/go-serializer v0.0.4 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect - github.com/klauspost/compress v1.17.7 // indirect + github.com/klauspost/compress v1.17.9 // indirect github.com/knadh/koanf v1.4.1 // indirect github.com/kr/fs v0.1.0 // indirect github.com/kubemq-io/protobuf v1.3.1 // indirect @@ -393,10 +393,10 @@ require ( go.uber.org/atomic v1.10.0 // indirect go.uber.org/zap v1.24.0 // indirect golang.org/x/arch v0.10.0 // indirect - golang.org/x/sync v0.7.0 // indirect - golang.org/x/sys v0.21.0 // indirect - golang.org/x/term v0.21.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/sys v0.23.0 // indirect + golang.org/x/term v0.23.0 // indirect + golang.org/x/text v0.17.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda // indirect diff --git a/go.sum b/go.sum index bc2344e5a3..54c28416d8 100644 --- a/go.sum +++ b/go.sum @@ -121,8 +121,8 @@ github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3 github.com/DataDog/zstd v1.5.2 h1:vUG4lAyuPCXO0TLbXvPv7EB7cNK1QV/luu55UHLrrn8= github.com/DataDog/zstd v1.5.2/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo= -github.com/IBM/sarama v1.42.2 h1:VoY4hVIZ+WQJ8G9KNY/SQlWguBQXQ9uvFPOnrcu8hEw= -github.com/IBM/sarama v1.42.2/go.mod h1:FLPGUGwYqEs62hq2bVG6Io2+5n+pS6s/WOXVKWSLFtE= +github.com/IBM/sarama v1.43.3 h1:Yj6L2IaNvb2mRBop39N7mmJAHBVY3dTPncr3qGVkxPA= +github.com/IBM/sarama v1.43.3/go.mod h1:FVIRaLrhK3Cla/9FfRF5X9Zua2KpS3SYIXxhac1H+FQ= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= github.com/Netflix/go-env v0.0.0-20220526054621-78278af1949d h1:wvStE9wLpws31NiWUx+38wny1msZ/tm+eL5xmm4Y7So= github.com/Netflix/go-env v0.0.0-20220526054621-78278af1949d/go.mod h1:9XMFaCeRyW7fC9XJOWQ+NdAv8VLG7ys7l3x4ozEGLUQ= @@ -518,8 +518,8 @@ github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+m github.com/dvsekhvalnov/jose2go v1.6.0 h1:Y9gnSnP4qEI0+/uQkHvFXeD2PLPJeXEL+ySMEA2EjTY= github.com/dvsekhvalnov/jose2go v1.6.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= -github.com/eapache/go-resiliency v1.5.0 h1:dRsaR00whmQD+SgVKlq/vCRFNgtEb5yppyeVos3Yce0= -github.com/eapache/go-resiliency v1.5.0/go.mod h1:5yPzW0MIvSe0JDsv0v+DvcjEv2FyD6iZYSs1ZI+iQho= +github.com/eapache/go-resiliency v1.7.0 h1:n3NRTnBn5N0Cbi/IeOHuQn9s2UwVUH7Ga0ZWcP+9JTA= +github.com/eapache/go-resiliency v1.7.0/go.mod h1:5yPzW0MIvSe0JDsv0v+DvcjEv2FyD6iZYSs1ZI+iQho= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3 h1:Oy0F4ALJ04o5Qqpdz8XLIpNA3WM/iSIXqxtqo7UGVws= github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3/go.mod h1:YvSRo5mw33fLEx1+DlK6L2VV43tJt5Eyel9n9XBcR+0= @@ -1056,8 +1056,8 @@ github.com/kitex-contrib/obs-opentelemetry/logging/logrus v0.0.0-20220601144657- github.com/kitex-contrib/tracer-opentracing v0.0.2/go.mod h1:mprt5pxqywFQxlHb7ugfiMdKbABTLI9YrBYs9WmlK5Q= github.com/klauspost/compress v1.10.7/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= -github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.1.0/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= @@ -1796,8 +1796,8 @@ golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1930,8 +1930,8 @@ golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.11.0/go.mod h1:2L/ixqYpgIVXmeoSA/4Lu7BzTG4KIyPIryS4IsOd1oQ= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1957,8 +1957,8 @@ golang.org/x/sync v0.0.0-20220513210516-0976fa681c29/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180828065106-d99a578cf41b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -2076,8 +2076,8 @@ golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= +golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -2088,8 +2088,8 @@ golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.9.0/go.mod h1:M6DEAAIenWoTxdKrOltXcmDY3rSplQUkrvaDU5FcQyo= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= +golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= +golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -2107,8 +2107,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/tests/certification/go.mod b/tests/certification/go.mod index cda79b8d90..1dc9c0ad44 100644 --- a/tests/certification/go.mod +++ b/tests/certification/go.mod @@ -8,7 +8,7 @@ require ( cloud.google.com/go/pubsub v1.37.0 dubbo.apache.org/dubbo-go/v3 v3.0.3-0.20230118042253-4f159a2b38f3 github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.3.2 - github.com/IBM/sarama v1.42.2 + github.com/IBM/sarama v1.43.3 github.com/a8m/documentdb v1.3.0 github.com/apache/dubbo-go-hessian2 v1.11.5 github.com/apache/pulsar-client-go v0.11.0 @@ -128,7 +128,7 @@ require ( github.com/dubbogo/triple v1.1.8 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/dvsekhvalnov/jose2go v1.6.0 // indirect - github.com/eapache/go-resiliency v1.5.0 // indirect + github.com/eapache/go-resiliency v1.7.0 // indirect github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3 // indirect github.com/eapache/queue v1.1.0 // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect @@ -213,7 +213,7 @@ require ( github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/k0kubun/pp v3.0.1+incompatible // indirect - github.com/klauspost/compress v1.17.7 // indirect + github.com/klauspost/compress v1.17.9 // indirect github.com/knadh/koanf v1.4.1 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/leodido/go-urn v1.2.1 // indirect @@ -314,15 +314,15 @@ require ( go.uber.org/atomic v1.10.0 // indirect go.uber.org/zap v1.27.0 // indirect golang.org/x/arch v0.10.0 // indirect - golang.org/x/crypto v0.24.0 // indirect + golang.org/x/crypto v0.26.0 // indirect golang.org/x/exp v0.0.0-20240613232115-7f521ea00fb8 // indirect golang.org/x/mod v0.18.0 // indirect - golang.org/x/net v0.26.0 // indirect + golang.org/x/net v0.28.0 // indirect golang.org/x/oauth2 v0.20.0 // indirect - golang.org/x/sync v0.7.0 // indirect - golang.org/x/sys v0.21.0 // indirect - golang.org/x/term v0.21.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/sync v0.8.0 // indirect + golang.org/x/sys v0.23.0 // indirect + golang.org/x/term v0.23.0 // indirect + golang.org/x/text v0.17.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.22.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect diff --git a/tests/certification/go.sum b/tests/certification/go.sum index 83a47faa47..145f05a305 100644 --- a/tests/certification/go.sum +++ b/tests/certification/go.sum @@ -110,8 +110,8 @@ github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3 github.com/DataDog/zstd v1.5.2 h1:vUG4lAyuPCXO0TLbXvPv7EB7cNK1QV/luu55UHLrrn8= github.com/DataDog/zstd v1.5.2/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXYg9BVp4O+o5f6V/ehm6Oo= -github.com/IBM/sarama v1.42.2 h1:VoY4hVIZ+WQJ8G9KNY/SQlWguBQXQ9uvFPOnrcu8hEw= -github.com/IBM/sarama v1.42.2/go.mod h1:FLPGUGwYqEs62hq2bVG6Io2+5n+pS6s/WOXVKWSLFtE= +github.com/IBM/sarama v1.43.3 h1:Yj6L2IaNvb2mRBop39N7mmJAHBVY3dTPncr3qGVkxPA= +github.com/IBM/sarama v1.43.3/go.mod h1:FVIRaLrhK3Cla/9FfRF5X9Zua2KpS3SYIXxhac1H+FQ= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= @@ -443,8 +443,8 @@ github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+m github.com/dvsekhvalnov/jose2go v1.6.0 h1:Y9gnSnP4qEI0+/uQkHvFXeD2PLPJeXEL+ySMEA2EjTY= github.com/dvsekhvalnov/jose2go v1.6.0/go.mod h1:QsHjhyTlD/lAVqn/NSbVZmSCGeDehTB/mPZadG+mhXU= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= -github.com/eapache/go-resiliency v1.5.0 h1:dRsaR00whmQD+SgVKlq/vCRFNgtEb5yppyeVos3Yce0= -github.com/eapache/go-resiliency v1.5.0/go.mod h1:5yPzW0MIvSe0JDsv0v+DvcjEv2FyD6iZYSs1ZI+iQho= +github.com/eapache/go-resiliency v1.7.0 h1:n3NRTnBn5N0Cbi/IeOHuQn9s2UwVUH7Ga0ZWcP+9JTA= +github.com/eapache/go-resiliency v1.7.0/go.mod h1:5yPzW0MIvSe0JDsv0v+DvcjEv2FyD6iZYSs1ZI+iQho= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3 h1:Oy0F4ALJ04o5Qqpdz8XLIpNA3WM/iSIXqxtqo7UGVws= github.com/eapache/go-xerial-snappy v0.0.0-20230731223053-c322873962e3/go.mod h1:YvSRo5mw33fLEx1+DlK6L2VV43tJt5Eyel9n9XBcR+0= @@ -931,8 +931,8 @@ github.com/kitex-contrib/monitor-prometheus v0.0.0-20210817080809-024dd7bd51e1/g github.com/kitex-contrib/obs-opentelemetry v0.0.0-20220601144657-c60210e3c928/go.mod h1:VvMzPMfgL7iUG92eVZGuRybGVMKzuSrsfMvHHpL7/Ac= github.com/kitex-contrib/obs-opentelemetry/logging/logrus v0.0.0-20220601144657-c60210e3c928/go.mod h1:Eml/0Z+CqgGIPf9JXzLGu+N9NJoy2x5pqypN+hmKArE= github.com/kitex-contrib/tracer-opentracing v0.0.2/go.mod h1:mprt5pxqywFQxlHb7ugfiMdKbABTLI9YrBYs9WmlK5Q= -github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= -github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.1.0/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= @@ -1540,8 +1540,8 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1656,8 +1656,8 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= -golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1683,8 +1683,8 @@ golang.org/x/sync v0.0.0-20220513210516-0976fa681c29/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= +golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1788,13 +1788,13 @@ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= +golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= -golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= +golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= +golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1807,8 +1807,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/tests/e2e/pubsub/jetstream/go.mod b/tests/e2e/pubsub/jetstream/go.mod index c9383d0c2a..801e078b2e 100644 --- a/tests/e2e/pubsub/jetstream/go.mod +++ b/tests/e2e/pubsub/jetstream/go.mod @@ -17,7 +17,7 @@ require ( github.com/golang/protobuf v1.5.4 // indirect github.com/google/uuid v1.6.0 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.17.7 // indirect + github.com/klauspost/compress v1.17.9 // indirect github.com/mitchellh/mapstructure v1.5.1-0.20220423185008-bf980b35cac4 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect @@ -26,8 +26,8 @@ require ( github.com/nats-io/nuid v1.0.1 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/cast v1.5.1 // indirect - golang.org/x/crypto v0.24.0 // indirect - golang.org/x/sys v0.21.0 // indirect + golang.org/x/crypto v0.26.0 // indirect + golang.org/x/sys v0.23.0 // indirect google.golang.org/protobuf v1.34.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect k8s.io/apimachinery v0.26.10 // indirect diff --git a/tests/e2e/pubsub/jetstream/go.sum b/tests/e2e/pubsub/jetstream/go.sum index 281c3f7ec6..3118370999 100644 --- a/tests/e2e/pubsub/jetstream/go.sum +++ b/tests/e2e/pubsub/jetstream/go.sum @@ -27,8 +27,8 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= -github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -79,8 +79,8 @@ go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= -golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -96,8 +96,8 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= -golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM= +golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= From b969bbfe88679f8800083850d696fa44cc6ea7b5 Mon Sep 17 00:00:00 2001 From: Yaron Schneider Date: Sun, 3 Nov 2024 21:25:09 -0800 Subject: [PATCH 02/12] Add receiverQueueSize to pulsar (#3589) Signed-off-by: yaron2 --- pubsub/pulsar/metadata.go | 1 + pubsub/pulsar/metadata.yaml | 9 ++++++++- pubsub/pulsar/pulsar.go | 4 ++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/pubsub/pulsar/metadata.go b/pubsub/pulsar/metadata.go index 3533c6de37..62b3b06bbc 100644 --- a/pubsub/pulsar/metadata.go +++ b/pubsub/pulsar/metadata.go @@ -36,6 +36,7 @@ type pulsarMetadata struct { PrivateKey string `mapstructure:"privateKey"` Keys string `mapstructure:"keys"` MaxConcurrentHandlers uint `mapstructure:"maxConcurrentHandlers"` + ReceiverQueueSize int `mapstructure:"receiverQueueSize"` Token string `mapstructure:"token"` oauth2.ClientCredentialsMetadata `mapstructure:",squash"` diff --git a/pubsub/pulsar/metadata.yaml b/pubsub/pulsar/metadata.yaml index e81d6f75f8..7cc216cf12 100644 --- a/pubsub/pulsar/metadata.yaml +++ b/pubsub/pulsar/metadata.yaml @@ -176,4 +176,11 @@ metadata: description: | Sets the maximum number of concurrent messages sent to the application. Default is 100. default: '"100"' - example: '"100"' \ No newline at end of file + example: '"100"' + - name: receiverQueueSize + type: number + description: | + Sets the size of the consumer receive queue. + Controls how many messages can be accumulated by the consumer before it is explicitly called to read messages by Dapr. + default: '"1000"' + example: '"1000"' \ No newline at end of file diff --git a/pubsub/pulsar/pulsar.go b/pubsub/pulsar/pulsar.go index 6e4fb94d87..7822d63f5e 100644 --- a/pubsub/pulsar/pulsar.go +++ b/pubsub/pulsar/pulsar.go @@ -80,6 +80,8 @@ const ( defaultRedeliveryDelay = 30 * time.Second // defaultConcurrency controls the number of concurrent messages sent to the app. defaultConcurrency = 100 + // defaultReceiverQueueSize controls the number of messages the pulsar sdk pulls before dapr explicitly consumes the messages. + defaultReceiverQueueSize = 1000 subscribeTypeKey = "subscribeType" @@ -125,6 +127,7 @@ func parsePulsarMetadata(meta pubsub.Metadata) (*pulsarMetadata, error) { BatchingMaxSize: defaultMaxBatchSize, RedeliveryDelay: defaultRedeliveryDelay, MaxConcurrentHandlers: defaultConcurrency, + ReceiverQueueSize: defaultReceiverQueueSize, } if err := kitmd.DecodeMetadata(meta.Properties, &m); err != nil { @@ -403,6 +406,7 @@ func (p *Pulsar) Subscribe(ctx context.Context, req pubsub.SubscribeRequest, han Type: getSubscribeType(req.Metadata), MessageChannel: channel, NackRedeliveryDelay: p.metadata.RedeliveryDelay, + ReceiverQueueSize: p.metadata.ReceiverQueueSize, } if p.useConsumerEncryption() { From f0a99c114c5ab9c59e7d481461825861e738a898 Mon Sep 17 00:00:00 2001 From: bhagya <43932219+bhagya05@users.noreply.github.com> Date: Wed, 6 Nov 2024 03:58:38 +0530 Subject: [PATCH 03/12] Fix metadata header value sanitization (#3581) Signed-off-by: Bhagya Singh Purba Co-authored-by: bhagyapurba Co-authored-by: Yaron Schneider --- common/component/azure/blobstorage/request.go | 9 +++- .../azure/blobstorage/request_test.go | 48 +++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/common/component/azure/blobstorage/request.go b/common/component/azure/blobstorage/request.go index 1b969f7273..cc6922e225 100644 --- a/common/component/azure/blobstorage/request.go +++ b/common/component/azure/blobstorage/request.go @@ -107,7 +107,7 @@ func SanitizeMetadata(log logger.Logger, metadata map[string]string) map[string] n = 0 newVal := make([]byte, len(val)) for i := range len(val) { - if val[i] > 127 || val[i] == 0 { + if val[i] > 127 || (isCTL(val[i]) && !isLWS(val[i])) { continue } newVal[n] = val[i] @@ -118,3 +118,10 @@ func SanitizeMetadata(log logger.Logger, metadata map[string]string) map[string] return res } + +func isLWS(b byte) bool { return b == ' ' || b == '\t' } + +func isCTL(b byte) bool { + const del = 0x7f // a CTL + return b < ' ' || b == del +} diff --git a/common/component/azure/blobstorage/request_test.go b/common/component/azure/blobstorage/request_test.go index 222cbe4576..aed6d0e570 100644 --- a/common/component/azure/blobstorage/request_test.go +++ b/common/component/azure/blobstorage/request_test.go @@ -60,6 +60,7 @@ func TestSanitizeRequestMetadata(t *testing.T) { "somecustomfield": "some-custom-value", "specialfield": "special:valueÜ", "not-allowed:": "not-allowed", + "ctr-characters": string([]byte{72, 20, 1, 0, 101, 120}), } meta := SanitizeMetadata(log, m) _ = assert.NotNil(t, meta["somecustomfield"]) && @@ -68,5 +69,52 @@ func TestSanitizeRequestMetadata(t *testing.T) { assert.Equal(t, "special:value", *meta["specialfield"]) _ = assert.NotNil(t, meta["notallowed"]) && assert.Equal(t, "not-allowed", *meta["notallowed"]) + _ = assert.NotNil(t, meta["ctrcharacters"]) && + assert.Equal(t, string([]byte{72, 101, 120}), *meta["ctrcharacters"]) }) } + +func TestIsLWS(t *testing.T) { + // Test cases for isLWS + tests := []struct { + input byte + expected bool + }{ + {' ', true}, // Space character, should return true + {'\t', true}, // Tab character, should return true + {'A', false}, // Non-LWS character, should return false + {'1', false}, // Non-LWS character, should return false + {'\n', false}, // Newline, a CTL but not LWS, should return false + {0x7F, false}, // DEL character, a CTL but not LWS, should return false + } + + for _, tt := range tests { + t.Run("Testing for LWS", func(t *testing.T) { + result := isLWS(tt.input) + assert.Equal(t, tt.expected, result, "input: %v", tt.input) + }) + } +} + +func TestIsCTL(t *testing.T) { + // Test cases for isCTL + tests := []struct { + input byte + expected bool + }{ + {0x00, true}, // NUL, a control character + {0x1F, true}, // US (Unit Separator), a control character + {'\n', true}, // Newline, a control character + {0x7F, true}, // DEL, a control character + {'A', false}, // Non-CTL character + {'1', false}, // Non-CTL character + {' ', false}, // Space is not a CTL (although LWS) + } + + for _, tt := range tests { + t.Run("Testing for CTL characters", func(t *testing.T) { + result := isCTL(tt.input) + assert.Equal(t, tt.expected, result, "input: %v", tt.input) + }) + } +} From 9833e560208139a60db2df38ba09da586a3328d7 Mon Sep 17 00:00:00 2001 From: Elena Kolevska Date: Wed, 6 Nov 2024 01:27:24 +0000 Subject: [PATCH 04/12] Add elena-kolevska to codeowners for dapr bot (#3592) Signed-off-by: Elena Kolevska --- .github/scripts/dapr_bot.js | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/scripts/dapr_bot.js b/.github/scripts/dapr_bot.js index f5179db75f..279799958a 100644 --- a/.github/scripts/dapr_bot.js +++ b/.github/scripts/dapr_bot.js @@ -7,6 +7,7 @@ const owners = [ 'berndverst', 'daixiang0', 'DeepanshuA', + 'elena-kolevska', 'halspang', 'ItalyPaleAle', 'jjcollinge', From b05e19a431da2a3eb190e9261b517c663e7eb485 Mon Sep 17 00:00:00 2001 From: Elena Kolevska Date: Fri, 8 Nov 2024 05:57:57 +0000 Subject: [PATCH 05/12] Adds conformance tests for AWS Secrets store component (#3588) Signed-off-by: Elena Kolevska Co-authored-by: Yaron Schneider --- .../docker-compose-secrets-manager.yml | 15 ++++++ .../aws/secretsmanager/secretsmanager.tf | 54 +++++++++++++++++++ ...s.secretsmanager.secretsmanager-destroy.sh | 9 ++++ ...aws.secretsmanager.secretsmanager-setup.sh | 15 ++++++ ...t-conformance-state-aws-secrets-manager.sh | 9 ++++ .github/scripts/test-info.mjs | 11 ++++ secretstores/aws/secretmanager/metadata.yaml | 19 +++++++ .../aws/secretmanager/secretmanager.go | 11 ++-- .../secretsmanager/docker/secretsmanager.yml | 16 ++++++ .../terraform/secretsmanager.yml | 15 ++++++ tests/config/secretstores/tests.yml | 4 ++ .../conformance/secretstores/secretstores.go | 2 +- tests/conformance/secretstores_test.go | 5 ++ 13 files changed, 179 insertions(+), 6 deletions(-) create mode 100644 .github/infrastructure/docker-compose-secrets-manager.yml create mode 100644 .github/infrastructure/terraform/conformance/secretstores/aws/secretsmanager/secretsmanager.tf create mode 100755 .github/scripts/components-scripts/conformance-secretstores.aws.secretsmanager.secretsmanager-destroy.sh create mode 100755 .github/scripts/components-scripts/conformance-secretstores.aws.secretsmanager.secretsmanager-setup.sh create mode 100755 .github/scripts/docker-compose-init/init-conformance-state-aws-secrets-manager.sh create mode 100644 secretstores/aws/secretmanager/metadata.yaml create mode 100644 tests/config/secretstores/aws/secretsmanager/docker/secretsmanager.yml create mode 100644 tests/config/secretstores/aws/secretsmanager/terraform/secretsmanager.yml diff --git a/.github/infrastructure/docker-compose-secrets-manager.yml b/.github/infrastructure/docker-compose-secrets-manager.yml new file mode 100644 index 0000000000..4d9911e409 --- /dev/null +++ b/.github/infrastructure/docker-compose-secrets-manager.yml @@ -0,0 +1,15 @@ +version: "3.8" + +services: + localstack: + container_name: "conformance-aws-secrets-manager" + image: localstack/localstack + ports: + - "127.0.0.1:4566:4566" + environment: + - DEBUG=1 + - DOCKER_HOST=unix:///var/run/docker.sock + volumes: + - "${PWD}/.github/scripts/docker-compose-init/init-conformance-state-aws-secrets-manager.sh:/etc/localstack/init/ready.d/init-aws.sh" # ready hook + - "${LOCALSTACK_VOLUME_DIR:-./volume}:/var/lib/localstack" + - "/var/run/docker.sock:/var/run/docker.sock" \ No newline at end of file diff --git a/.github/infrastructure/terraform/conformance/secretstores/aws/secretsmanager/secretsmanager.tf b/.github/infrastructure/terraform/conformance/secretstores/aws/secretsmanager/secretsmanager.tf new file mode 100644 index 0000000000..7c1f3b2921 --- /dev/null +++ b/.github/infrastructure/terraform/conformance/secretstores/aws/secretsmanager/secretsmanager.tf @@ -0,0 +1,54 @@ +terraform { + required_version = ">=0.13" + + required_providers { + aws = { + source = "hashicorp/aws" + version = "~> 4.0" + } + } +} + +variable "TIMESTAMP" { + type = string + description = "Timestamp of the GitHub workflow run." +} + +variable "UNIQUE_ID" { + type = string + description = "Unique ID of the GitHub workflow run." +} + +provider "aws" { + region = "us-east-1" + default_tags { + tags = { + Purpose = "AutomatedConformanceTesting" + Timestamp = "${var.TIMESTAMP}" + } + } +} + +# Create the first secret in AWS Secrets Manager +resource "aws_secretsmanager_secret" "conftestsecret" { + name = "conftestsecret" + description = "Secret for conformance test" + recovery_window_in_days = 0 +} + +resource "aws_secretsmanager_secret_version" "conftestsecret_value" { + secret_id = aws_secretsmanager_secret.conftestsecret.id + secret_string = "abcd" +} + +# Create the second secret in AWS Secrets Manager +resource "aws_secretsmanager_secret" "secondsecret" { + name = "secondsecret" + description = "Another secret for conformance test" + recovery_window_in_days = 0 +} + +resource "aws_secretsmanager_secret_version" "secondsecret_value" { + secret_id = aws_secretsmanager_secret.secondsecret.id + secret_string = "efgh" +} diff --git a/.github/scripts/components-scripts/conformance-secretstores.aws.secretsmanager.secretsmanager-destroy.sh b/.github/scripts/components-scripts/conformance-secretstores.aws.secretsmanager.secretsmanager-destroy.sh new file mode 100755 index 0000000000..fa23ac24a3 --- /dev/null +++ b/.github/scripts/components-scripts/conformance-secretstores.aws.secretsmanager.secretsmanager-destroy.sh @@ -0,0 +1,9 @@ +#!/bin/sh + +set +e + +# Navigate to the Terraform directory +cd ".github/infrastructure/terraform/conformance/secretstores/aws/secretsmanager" + +# Run Terraform +terraform destroy -auto-approve -var="UNIQUE_ID=$UNIQUE_ID" -var="TIMESTAMP=$CURRENT_TIME" diff --git a/.github/scripts/components-scripts/conformance-secretstores.aws.secretsmanager.secretsmanager-setup.sh b/.github/scripts/components-scripts/conformance-secretstores.aws.secretsmanager.secretsmanager-setup.sh new file mode 100755 index 0000000000..759024a7dc --- /dev/null +++ b/.github/scripts/components-scripts/conformance-secretstores.aws.secretsmanager.secretsmanager-setup.sh @@ -0,0 +1,15 @@ +#!/bin/sh + +set -e + +# Set variables for GitHub Actions +echo "AWS_REGION=us-east-1" >> $GITHUB_ENV + +# Navigate to the Terraform directory +cd ".github/infrastructure/terraform/conformance/secretstores/aws/secretsmanager" + +# Run Terraform +terraform init +terraform validate -no-color +terraform plan -no-color -var="UNIQUE_ID=$UNIQUE_ID" -var="TIMESTAMP=$CURRENT_TIME" +terraform apply -auto-approve -var="UNIQUE_ID=$UNIQUE_ID" -var="TIMESTAMP=$CURRENT_TIME" diff --git a/.github/scripts/docker-compose-init/init-conformance-state-aws-secrets-manager.sh b/.github/scripts/docker-compose-init/init-conformance-state-aws-secrets-manager.sh new file mode 100755 index 0000000000..6cb9f25917 --- /dev/null +++ b/.github/scripts/docker-compose-init/init-conformance-state-aws-secrets-manager.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +awslocal secretsmanager create-secret \ + --name conftestsecret \ + --secret-string "abcd" + +awslocal secretsmanager create-secret \ + --name secondsecret \ + --secret-string "efgh" \ No newline at end of file diff --git a/.github/scripts/test-info.mjs b/.github/scripts/test-info.mjs index 67e7df5a9b..4d16848e4d 100644 --- a/.github/scripts/test-info.mjs +++ b/.github/scripts/test-info.mjs @@ -492,6 +492,17 @@ const components = { conformance: true, certification: true, }, + 'secretstores.aws.secretsmanager.terraform': { + conformance: true, + requireAWSCredentials: true, + requireTerraform: true, + conformanceSetup: 'conformance-secretstores.aws.secretsmanager.secretsmanager-setup.sh', + conformanceDestroy: 'conformance-secretstores.aws.secretsmanager.secretsmanager-destroy.sh', + }, + 'secretstores.aws.secretsmanager.docker': { + conformance: true, + conformanceSetup: 'docker-compose.sh secrets-manager', + }, 'state.aws.dynamodb': { certification: true, requireAWSCredentials: true, diff --git a/secretstores/aws/secretmanager/metadata.yaml b/secretstores/aws/secretmanager/metadata.yaml new file mode 100644 index 0000000000..21bfbd5b2c --- /dev/null +++ b/secretstores/aws/secretmanager/metadata.yaml @@ -0,0 +1,19 @@ +# yaml-language-server: $schema=../../../component-metadata-schema.json +schemaVersion: v1 +type: secretstores +name: aws.secretsmanager +version: v1 +status: beta +title: "AWS Secrets manager" +urls: + - title: Reference + url: https://docs.dapr.io/reference/components-reference/supported-secret-stores/aws-secret-manager/ +builtinAuthenticationProfiles: + - name: "aws" +metadata: + - name: endpoint + required: false + description: | + The Secrets manager endpoint. The AWS SDK will generate a default endpoint if not specified. Useful for local testing with AWS LocalStack + example: '"http://localhost:4566"' + type: string \ No newline at end of file diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 33f5a54c9d..54ed329d35 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -41,10 +41,11 @@ func NewSecretManager(logger logger.Logger) secretstores.SecretStore { } type SecretManagerMetaData struct { - Region string `json:"region"` - AccessKey string `json:"accessKey"` - SecretKey string `json:"secretKey"` - SessionToken string `json:"sessionToken"` + Region string `json:"region" mapstructure:"region" mdignore:"true"` + AccessKey string `json:"accessKey" mapstructure:"accessKey" mdignore:"true"` + SecretKey string `json:"secretKey" mapstructure:"secretKey" mdignore:"true"` + SessionToken string `json:"sessionToken" mapstructure:"sessionToken" mdignore:"true"` + Endpoint string `json:"endpoint" mapstructure:"endpoint"` } type smSecretStore struct { @@ -136,7 +137,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk } func (s *smSecretStore) getClient(metadata *SecretManagerMetaData) (*secretsmanager.SecretsManager, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, "") + sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) if err != nil { return nil, err } diff --git a/tests/config/secretstores/aws/secretsmanager/docker/secretsmanager.yml b/tests/config/secretstores/aws/secretsmanager/docker/secretsmanager.yml new file mode 100644 index 0000000000..10b8f61787 --- /dev/null +++ b/tests/config/secretstores/aws/secretsmanager/docker/secretsmanager.yml @@ -0,0 +1,16 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: awssecretmanager +spec: + type: secretstores.aws.secretmanager + version: v1 + metadata: + - name: endpoint + value: "http://localhost:4566" # AWS LocalStack address + - name: accessKey + value: "test" # AWS LocalStack placeholder + - name: secretKey + value: "test" # AWS LocalStack placeholder + - name: region + value: "us-east-1" # AWS LocalStack placeholder diff --git a/tests/config/secretstores/aws/secretsmanager/terraform/secretsmanager.yml b/tests/config/secretstores/aws/secretsmanager/terraform/secretsmanager.yml new file mode 100644 index 0000000000..0824f93bf2 --- /dev/null +++ b/tests/config/secretstores/aws/secretsmanager/terraform/secretsmanager.yml @@ -0,0 +1,15 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: awssecretmanager + namespace: default +spec: + type: secretstores.aws.secretmanager + version: v1 + metadata: + - name: accessKey + value: ${{AWS_ACCESS_KEY_ID}} + - name: secretKey + value: ${{AWS_SECRET_ACCESS_KEY}} + - name: region + value: ${{AWS_REGION}} diff --git a/tests/config/secretstores/tests.yml b/tests/config/secretstores/tests.yml index 134abffa08..40ed714c2f 100644 --- a/tests/config/secretstores/tests.yml +++ b/tests/config/secretstores/tests.yml @@ -5,6 +5,10 @@ components: operations: [] - component: local.file operations: [] + - component: aws.secretsmanager.docker + operations: [] + - component: aws.secretsmanager.terraform + operations: [] - component: azure.keyvault.certificate operations: [] - component: azure.keyvault.serviceprincipal diff --git a/tests/conformance/secretstores/secretstores.go b/tests/conformance/secretstores/secretstores.go index b086f1aa93..e680e7500d 100644 --- a/tests/conformance/secretstores/secretstores.go +++ b/tests/conformance/secretstores/secretstores.go @@ -58,7 +58,7 @@ func ConformanceTests(t *testing.T, props map[string]string, store secretstores. t.Run("ping", func(t *testing.T) { err := secretstores.Ping(context.Background(), store) - // TODO: Ideally, all stable components should implenment ping function, + // TODO: Ideally, all stable components should implement a ping function, // so will only assert require.NoError(t, err) finally, i.e. when current implementation // implements ping in existing stable components if err != nil { diff --git a/tests/conformance/secretstores_test.go b/tests/conformance/secretstores_test.go index ee524e13f4..2b9e7c9cd5 100644 --- a/tests/conformance/secretstores_test.go +++ b/tests/conformance/secretstores_test.go @@ -23,6 +23,7 @@ import ( "github.com/stretchr/testify/require" "github.com/dapr/components-contrib/secretstores" + ss_aws "github.com/dapr/components-contrib/secretstores/aws/secretmanager" ss_azure "github.com/dapr/components-contrib/secretstores/azure/keyvault" ss_hashicorp_vault "github.com/dapr/components-contrib/secretstores/hashicorp/vault" ss_kubernetes "github.com/dapr/components-contrib/secretstores/kubernetes" @@ -71,6 +72,10 @@ func loadSecretStore(name string) secretstores.SecretStore { return ss_local_file.NewLocalSecretStore(testLogger) case "hashicorp.vault": return ss_hashicorp_vault.NewHashiCorpVaultSecretStore(testLogger) + case "aws.secretsmanager.docker": + return ss_aws.NewSecretManager(testLogger) + case "aws.secretsmanager.terraform": + return ss_aws.NewSecretManager(testLogger) default: return nil } From 2b924c46c772bcd5088e44f9dc8349182f116f11 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 13 Nov 2024 14:47:41 -0600 Subject: [PATCH 06/12] feat: add me to bot owners to run tests (#3600) Signed-off-by: Samantha Coyle --- .github/scripts/dapr_bot.js | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/scripts/dapr_bot.js b/.github/scripts/dapr_bot.js index 279799958a..328972d27e 100644 --- a/.github/scripts/dapr_bot.js +++ b/.github/scripts/dapr_bot.js @@ -21,6 +21,7 @@ const owners = [ 'RyanLettieri', 'shivamkm07', 'shubham1172', + 'sicoyle', 'skyao', 'Taction', 'tmacam', From a00a85355663d36a38b4029e6e538de24dfc340b Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 14 Nov 2024 13:04:56 -0600 Subject: [PATCH 07/12] feat(iam auth): allow iam roles anywhere auth profile (#3591) Signed-off-by: Samantha Coyle Signed-off-by: Sam --- .../builtin-authentication-profiles.yaml | 20 +- bindings/aws/dynamodb/dynamodb.go | 37 +- bindings/aws/kinesis/kinesis.go | 89 +-- bindings/aws/s3/s3.go | 87 ++- bindings/aws/ses/ses.go | 42 +- bindings/aws/sns/sns.go | 35 +- bindings/aws/sqs/sqs.go | 71 +- common/authentication/aws/aws.go | 125 ++-- common/authentication/aws/aws_test.go | 44 ++ common/authentication/aws/client.go | 209 ++++++ common/authentication/aws/client_fake.go | 79 +++ common/authentication/aws/client_test.go | 265 +++++++ common/authentication/aws/static.go | 272 +++++++ common/authentication/aws/static_test.go | 66 ++ common/authentication/aws/x509.go | 449 ++++++++++++ common/authentication/aws/x509_test.go | 125 ++++ common/authentication/postgresql/metadata.go | 2 +- go.mod | 4 + go.sum | 12 + pubsub/aws/snssqs/metadata.go | 2 +- pubsub/aws/snssqs/snssqs.go | 75 +- pubsub/aws/snssqs/snssqs_test.go | 22 +- .../aws/parameterstore/parameterstore.go | 40 +- .../aws/parameterstore/parameterstore_test.go | 332 +++++---- .../aws/secretmanager/secretmanager.go | 36 +- .../aws/secretmanager/secretmanager_test.go | 138 ++-- state/aws/dynamodb/dynamodb.go | 52 +- state/aws/dynamodb/dynamodb_test.go | 661 ++++++++++++------ .../certification/bindings/aws/s3/s3_test.go | 2 +- tests/certification/go.mod | 3 + tests/certification/go.sum | 4 + 31 files changed, 2647 insertions(+), 753 deletions(-) create mode 100644 common/authentication/aws/aws_test.go create mode 100644 common/authentication/aws/client.go create mode 100644 common/authentication/aws/client_fake.go create mode 100644 common/authentication/aws/client_test.go create mode 100644 common/authentication/aws/static.go create mode 100644 common/authentication/aws/static_test.go create mode 100644 common/authentication/aws/x509.go create mode 100644 common/authentication/aws/x509_test.go diff --git a/.build-tools/builtin-authentication-profiles.yaml b/.build-tools/builtin-authentication-profiles.yaml index 9113cf286b..cccb195a44 100644 --- a/.build-tools/builtin-authentication-profiles.yaml +++ b/.build-tools/builtin-authentication-profiles.yaml @@ -29,7 +29,25 @@ aws: type: string - title: "AWS: Credentials from Environment Variables" description: Use AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY from the environment - + - title: "AWS: IAM Roles Anywhere" + description: Use X.509 certificates to establish trust between AWS and your AWS account and the Dapr cluster using AWS IAM Roles Anywhere. + metadata: + - name: trustAnchorArn + description: | + ARN of the AWS Trust Anchor in the AWS account granting trust to the Dapr Certificate Authority. + example: arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901 + required: true + - name: trustProfileArn + description: | + ARN of the AWS IAM Profile in the trusting AWS account. + example: arn:aws:rolesanywhere:us-west-1:012345678910:profile/01234568-0123-0123-0123-012345678901 + required: true + - name: assumeRoleArn + description: | + ARN of the AWS IAM role to assume in the trusting AWS account. + example: arn:aws:iam:012345678910:role/exampleIAMRoleName + required: true + azuread: - title: "Azure AD: Managed identity" description: Authenticate using Azure AD and a managed identity. diff --git a/bindings/aws/dynamodb/dynamodb.go b/bindings/aws/dynamodb/dynamodb.go index bd882e7b55..755b3158d3 100644 --- a/bindings/aws/dynamodb/dynamodb.go +++ b/bindings/aws/dynamodb/dynamodb.go @@ -31,9 +31,9 @@ import ( // DynamoDB allows performing stateful operations on AWS DynamoDB. type DynamoDB struct { - client *dynamodb.DynamoDB - table string - logger logger.Logger + authProvider awsAuth.Provider + table string + logger logger.Logger } type dynamoDBMetadata struct { @@ -51,18 +51,27 @@ func NewDynamoDB(logger logger.Logger) bindings.OutputBinding { } // Init performs connection parsing for DynamoDB. -func (d *DynamoDB) Init(_ context.Context, metadata bindings.Metadata) error { +func (d *DynamoDB) Init(ctx context.Context, metadata bindings.Metadata) error { meta, err := d.getDynamoDBMetadata(metadata) if err != nil { return err } - client, err := d.getClient(meta) + opts := awsAuth.Options{ + Logger: d.logger, + Properties: metadata.Properties, + Region: meta.Region, + Endpoint: meta.Endpoint, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + } + + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - - d.client = client + d.authProvider = provider d.table = meta.Table return nil @@ -84,7 +93,7 @@ func (d *DynamoDB) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bi return nil, err } - _, err = d.client.PutItemWithContext(ctx, &dynamodb.PutItemInput{ + _, err = d.authProvider.DynamoDB().DynamoDB.PutItemWithContext(ctx, &dynamodb.PutItemInput{ Item: item, TableName: aws.String(d.table), }) @@ -105,16 +114,6 @@ func (d *DynamoDB) getDynamoDBMetadata(spec bindings.Metadata) (*dynamoDBMetadat return &meta, nil } -func (d *DynamoDB) getClient(metadata *dynamoDBMetadata) (*dynamodb.DynamoDB, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := dynamodb.New(sess) - - return c, nil -} - // GetComponentMetadata returns the metadata of the component. func (d *DynamoDB) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { metadataStruct := dynamoDBMetadata{} @@ -123,5 +122,5 @@ func (d *DynamoDB) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { } func (d *DynamoDB) Close() error { - return nil + return d.authProvider.Close() } diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index dbe0ceb918..7ede7ba245 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -27,7 +27,6 @@ import ( "github.com/aws/aws-sdk-go/service/kinesis" "github.com/cenkalti/backoff/v4" "github.com/google/uuid" - "github.com/vmware/vmware-go-kcl/clientlibrary/config" "github.com/vmware/vmware-go-kcl/clientlibrary/interfaces" "github.com/vmware/vmware-go-kcl/clientlibrary/worker" @@ -40,15 +39,16 @@ import ( // AWSKinesis allows receiving and sending data to/from AWS Kinesis stream. type AWSKinesis struct { - client *kinesis.Kinesis - metadata *kinesisMetadata + authProvider awsAuth.Provider + metadata *kinesisMetadata - worker *worker.Worker - workerConfig *config.KinesisClientLibConfiguration + worker *worker.Worker - streamARN *string - consumerARN *string - logger logger.Logger + streamName string + consumerName string + consumerARN *string + logger logger.Logger + consumerMode string closed atomic.Bool closeCh chan struct{} @@ -112,30 +112,25 @@ func (a *AWSKinesis) Init(ctx context.Context, metadata bindings.Metadata) error return fmt.Errorf("%s invalid \"mode\" field %s", "aws.kinesis", m.KinesisConsumerMode) } - client, err := a.getClient(m) - if err != nil { - return err - } + a.consumerMode = m.KinesisConsumerMode + a.streamName = m.StreamName + a.consumerName = m.ConsumerName + a.metadata = m - streamName := aws.String(m.StreamName) - stream, err := client.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{ - StreamName: streamName, - }) + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: "", + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - - if m.KinesisConsumerMode == SharedThroughput { - kclConfig := config.NewKinesisClientLibConfigWithCredential(m.ConsumerName, - m.StreamName, m.Region, m.ConsumerName, - client.Config.Credentials) - a.workerConfig = kclConfig - } - - a.streamARN = stream.StreamDescription.StreamARN - a.metadata = m - a.client = client - + a.authProvider = provider return nil } @@ -148,7 +143,7 @@ func (a *AWSKinesis) Invoke(ctx context.Context, req *bindings.InvokeRequest) (* if partitionKey == "" { partitionKey = uuid.New().String() } - _, err := a.client.PutRecordWithContext(ctx, &kinesis.PutRecordInput{ + _, err := a.authProvider.Kinesis().Kinesis.PutRecordWithContext(ctx, &kinesis.PutRecordInput{ StreamName: &a.metadata.StreamName, Data: req.Data, PartitionKey: &partitionKey, @@ -161,16 +156,15 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er if a.closed.Load() { return errors.New("binding is closed") } - if a.metadata.KinesisConsumerMode == SharedThroughput { - a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.workerConfig) + a.worker = worker.NewWorker(a.recordProcessorFactory(ctx, handler), a.authProvider.Kinesis().WorkerCfg(ctx, a.streamName, a.consumerName, a.consumerMode)) err = a.worker.Start() if err != nil { return err } } else if a.metadata.KinesisConsumerMode == ExtendedFanout { var stream *kinesis.DescribeStreamOutput - stream, err = a.client.DescribeStream(&kinesis.DescribeStreamInput{StreamName: &a.metadata.StreamName}) + stream, err = a.authProvider.Kinesis().Kinesis.DescribeStream(&kinesis.DescribeStreamInput{StreamName: &a.metadata.StreamName}) if err != nil { return err } @@ -180,6 +174,10 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er } } + stream, err := a.authProvider.Kinesis().Stream(ctx, a.streamName) + if err != nil { + return fmt.Errorf("failed to get kinesis stream arn: %v", err) + } // Wait for context cancelation then stop a.wg.Add(1) go func() { @@ -191,7 +189,7 @@ func (a *AWSKinesis) Read(ctx context.Context, handler bindings.Handler) (err er if a.metadata.KinesisConsumerMode == SharedThroughput { a.worker.Shutdown() } else if a.metadata.KinesisConsumerMode == ExtendedFanout { - a.deregisterConsumer(a.streamARN, a.consumerARN) + a.deregisterConsumer(ctx, stream, a.consumerARN) } }() @@ -226,8 +224,7 @@ func (a *AWSKinesis) Subscribe(ctx context.Context, streamDesc kinesis.StreamDes return default: } - - sub, err := a.client.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{ + sub, err := a.authProvider.Kinesis().Kinesis.SubscribeToShardWithContext(ctx, &kinesis.SubscribeToShardInput{ ConsumerARN: consumerARN, ShardId: s.ShardId, StartingPosition: &kinesis.StartingPosition{Type: aws.String(kinesis.ShardIteratorTypeLatest)}, @@ -269,14 +266,14 @@ func (a *AWSKinesis) Close() error { close(a.closeCh) } a.wg.Wait() - return nil + return a.authProvider.Close() } func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*string, error) { // Only set timeout on consumer call. conCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - consumer, err := a.client.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{ + consumer, err := a.authProvider.Kinesis().Kinesis.DescribeStreamConsumerWithContext(conCtx, &kinesis.DescribeStreamConsumerInput{ ConsumerName: &a.metadata.ConsumerName, StreamARN: streamARN, }) @@ -288,7 +285,7 @@ func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*st } func (a *AWSKinesis) registerConsumer(ctx context.Context, streamARN *string) (*string, error) { - consumer, err := a.client.RegisterStreamConsumerWithContext(ctx, &kinesis.RegisterStreamConsumerInput{ + consumer, err := a.authProvider.Kinesis().Kinesis.RegisterStreamConsumerWithContext(ctx, &kinesis.RegisterStreamConsumerInput{ ConsumerName: &a.metadata.ConsumerName, StreamARN: streamARN, }) @@ -307,11 +304,11 @@ func (a *AWSKinesis) registerConsumer(ctx context.Context, streamARN *string) (* return consumer.Consumer.ConsumerARN, nil } -func (a *AWSKinesis) deregisterConsumer(streamARN *string, consumerARN *string) error { +func (a *AWSKinesis) deregisterConsumer(ctx context.Context, streamARN *string, consumerARN *string) error { if a.consumerARN != nil { // Use a background context because the running context may have been canceled already ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - _, err := a.client.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{ + _, err := a.authProvider.Kinesis().Kinesis.DeregisterStreamConsumerWithContext(ctx, &kinesis.DeregisterStreamConsumerInput{ ConsumerARN: consumerARN, StreamARN: streamARN, ConsumerName: &a.metadata.ConsumerName, @@ -342,7 +339,7 @@ func (a *AWSKinesis) waitUntilConsumerExists(ctx aws.Context, input *kinesis.Des tmp := *input inCpy = &tmp } - req, _ := a.client.DescribeStreamConsumerRequest(inCpy) + req, _ := a.authProvider.Kinesis().Kinesis.DescribeStreamConsumerRequest(inCpy) req.SetContext(ctx) req.ApplyOptions(opts...) @@ -354,16 +351,6 @@ func (a *AWSKinesis) waitUntilConsumerExists(ctx aws.Context, input *kinesis.Des return w.WaitWithContext(ctx) } -func (a *AWSKinesis) getClient(metadata *kinesisMetadata) (*kinesis.Kinesis, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - k := kinesis.New(sess) - - return k, nil -} - func (a *AWSKinesis) parseMetadata(meta bindings.Metadata) (*kinesisMetadata, error) { var m kinesisMetadata err := kitmd.DecodeMetadata(meta.Properties, &m) diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index cc67cec94f..13f8730e78 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -29,9 +29,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/google/uuid" @@ -61,11 +59,9 @@ const ( // AWSS3 is a binding for an AWS S3 storage bucket. type AWSS3 struct { - metadata *s3Metadata - s3Client *s3.S3 - uploader *s3manager.Uploader - downloader *s3manager.Downloader - logger logger.Logger + metadata *s3Metadata + authProvider awsAuth.Provider + logger logger.Logger } type s3Metadata struct { @@ -109,23 +105,11 @@ func NewAWSS3(logger logger.Logger) bindings.OutputBinding { return &AWSS3{logger: logger} } -// Init does metadata parsing and connection creation. -func (s *AWSS3) Init(_ context.Context, metadata bindings.Metadata) error { - m, err := s.parseMetadata(metadata) - if err != nil { - return err - } - session, err := s.getSession(m) - if err != nil { - return err - } - - cfg := aws.NewConfig(). - WithS3ForcePathStyle(m.ForcePathStyle). - WithDisableSSL(m.DisableSSL) +func (s *AWSS3) getAWSConfig(opts awsAuth.Options) *aws.Config { + cfg := awsAuth.GetConfig(opts).WithS3ForcePathStyle(s.metadata.ForcePathStyle).WithDisableSSL(s.metadata.DisableSSL) // Use a custom HTTP client to allow self-signed certs - if m.InsecureSSL { + if s.metadata.InsecureSSL { customTransport := http.DefaultTransport.(*http.Transport).Clone() customTransport.TLSClientConfig = &tls.Config{ //nolint:gosec @@ -138,17 +122,38 @@ func (s *AWSS3) Init(_ context.Context, metadata bindings.Metadata) error { s.logger.Infof("aws s3: you are using 'insecureSSL' to skip server config verify which is unsafe!") } + return cfg +} +// Init does metadata parsing and connection creation. +func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { + m, err := s.parseMetadata(metadata) + if err != nil { + return err + } s.metadata = m - s.s3Client = s3.New(session, cfg) - s.downloader = s3manager.NewDownloaderWithClient(s.s3Client) - s.uploader = s3manager.NewUploaderWithClient(s.s3Client) + + opts := awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, s.getAWSConfig(opts)) + if err != nil { + return err + } + s.authProvider = provider return nil } func (s *AWSS3) Close() error { - return nil + return s.authProvider.Close() } func (s *AWSS3) Operations() []bindings.OperationKind { @@ -201,8 +206,7 @@ func (s *AWSS3) create(ctx context.Context, req *bindings.InvokeRequest) (*bindi if metadata.StorageClass != "" { storageClass = aws.String(metadata.StorageClass) } - - resultUpload, err := s.uploader.UploadWithContext(ctx, &s3manager.UploadInput{ + resultUpload, err := s.authProvider.S3().Uploader.UploadWithContext(ctx, &s3manager.UploadInput{ Bucket: ptr.Of(metadata.Bucket), Key: ptr.Of(key), Body: r, @@ -215,7 +219,7 @@ func (s *AWSS3) create(ctx context.Context, req *bindings.InvokeRequest) (*bindi var presignURL string if metadata.PresignTTL != "" { - url, presignErr := s.presignObject(metadata.Bucket, key, metadata.PresignTTL) + url, presignErr := s.presignObject(ctx, metadata.Bucket, key, metadata.PresignTTL) if presignErr != nil { return nil, fmt.Errorf("s3 binding error: %s", presignErr) } @@ -255,7 +259,7 @@ func (s *AWSS3) presign(ctx context.Context, req *bindings.InvokeRequest) (*bind return nil, fmt.Errorf("s3 binding error: required metadata '%s' missing", metadataPresignTTL) } - url, err := s.presignObject(metadata.Bucket, key, metadata.PresignTTL) + url, err := s.presignObject(ctx, metadata.Bucket, key, metadata.PresignTTL) if err != nil { return nil, fmt.Errorf("s3 binding error: %w", err) } @@ -272,13 +276,12 @@ func (s *AWSS3) presign(ctx context.Context, req *bindings.InvokeRequest) (*bind }, nil } -func (s *AWSS3) presignObject(bucket, key, ttl string) (string, error) { +func (s *AWSS3) presignObject(ctx context.Context, bucket, key, ttl string) (string, error) { d, err := time.ParseDuration(ttl) if err != nil { return "", fmt.Errorf("s3 binding error: cannot parse duration %s: %w", ttl, err) } - - objReq, _ := s.s3Client.GetObjectRequest(&s3.GetObjectInput{ + objReq, _ := s.authProvider.S3().S3.GetObjectRequest(&s3.GetObjectInput{ Bucket: ptr.Of(bucket), Key: ptr.Of(key), }) @@ -302,8 +305,7 @@ func (s *AWSS3) get(ctx context.Context, req *bindings.InvokeRequest) (*bindings } buff := &aws.WriteAtBuffer{} - - _, err = s.downloader.DownloadWithContext(ctx, + _, err = s.authProvider.S3().Downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{ Bucket: ptr.Of(s.metadata.Bucket), @@ -337,8 +339,7 @@ func (s *AWSS3) delete(ctx context.Context, req *bindings.InvokeRequest) (*bindi if key == "" { return nil, fmt.Errorf("s3 binding error: required metadata '%s' missing", metadataKey) } - - _, err := s.s3Client.DeleteObjectWithContext( + _, err := s.authProvider.S3().S3.DeleteObjectWithContext( ctx, &s3.DeleteObjectInput{ Bucket: ptr.Of(s.metadata.Bucket), @@ -367,8 +368,7 @@ func (s *AWSS3) list(ctx context.Context, req *bindings.InvokeRequest) (*binding if payload.MaxResults < 1 { payload.MaxResults = defaultMaxResults } - - result, err := s.s3Client.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ + result, err := s.authProvider.S3().S3.ListObjectsWithContext(ctx, &s3.ListObjectsInput{ Bucket: ptr.Of(s.metadata.Bucket), MaxKeys: ptr.Of(int64(payload.MaxResults)), Marker: ptr.Of(payload.Marker), @@ -415,15 +415,6 @@ func (s *AWSS3) parseMetadata(md bindings.Metadata) (*s3Metadata, error) { return &m, nil } -func (s *AWSS3) getSession(metadata *s3Metadata) (*session.Session, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - - return sess, nil -} - // Helper to merge config and request metadata. func (metadata s3Metadata) mergeWithRequestMetadata(req *bindings.InvokeRequest) (s3Metadata, error) { merged := metadata diff --git a/bindings/aws/ses/ses.go b/bindings/aws/ses/ses.go index 483fde8c64..4cd752bac5 100644 --- a/bindings/aws/ses/ses.go +++ b/bindings/aws/ses/ses.go @@ -38,9 +38,9 @@ const ( // AWSSES is an AWS SNS binding. type AWSSES struct { - metadata *sesMetadata - logger logger.Logger - svc *ses.SES + authProvider awsAuth.Provider + metadata *sesMetadata + logger logger.Logger } type sesMetadata struct { @@ -61,19 +61,29 @@ func NewAWSSES(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing. -func (a *AWSSES) Init(_ context.Context, metadata bindings.Metadata) error { +func (a *AWSSES) Init(ctx context.Context, metadata bindings.Metadata) error { // Parse input metadata - meta, err := a.parseMetadata(metadata) + m, err := a.parseMetadata(metadata) if err != nil { return err } - svc, err := a.getClient(meta) + a.metadata = m + + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: "", + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - a.metadata = meta - a.svc = svc + a.authProvider = provider return nil } @@ -141,7 +151,7 @@ func (a *AWSSES) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind } // Attempt to send the email. - result, err := a.svc.SendEmail(input) + result, err := a.authProvider.Ses().Ses.SendEmail(input) if err != nil { return nil, fmt.Errorf("SES binding error. Sending email failed: %w", err) } @@ -158,18 +168,6 @@ func (metadata sesMetadata) mergeWithRequestMetadata(req *bindings.InvokeRequest return merged } -func (a *AWSSES) getClient(metadata *sesMetadata) (*ses.SES, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, "") - if err != nil { - return nil, fmt.Errorf("SES binding error: error creating AWS session %w", err) - } - - // Create an SES instance - svc := ses.New(sess) - - return svc, nil -} - // GetComponentMetadata returns the metadata of the component. func (a *AWSSES) GetComponentMetadata() (metadataInfo contribMetadata.MetadataMap) { metadataStruct := sesMetadata{} @@ -178,5 +176,5 @@ func (a *AWSSES) GetComponentMetadata() (metadataInfo contribMetadata.MetadataMa } func (a *AWSSES) Close() error { - return nil + return a.authProvider.Close() } diff --git a/bindings/aws/sns/sns.go b/bindings/aws/sns/sns.go index 43b63cd2b1..55e3ccefa5 100644 --- a/bindings/aws/sns/sns.go +++ b/bindings/aws/sns/sns.go @@ -30,8 +30,8 @@ import ( // AWSSNS is an AWS SNS binding. type AWSSNS struct { - client *sns.SNS - topicARN string + authProvider awsAuth.Provider + topicARN string logger logger.Logger } @@ -58,16 +58,27 @@ func NewAWSSNS(logger logger.Logger) bindings.OutputBinding { } // Init does metadata parsing. -func (a *AWSSNS) Init(_ context.Context, metadata bindings.Metadata) error { +func (a *AWSSNS) Init(ctx context.Context, metadata bindings.Metadata) error { m, err := a.parseMetadata(metadata) if err != nil { return err } - client, err := a.getClient(m) + + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - a.client = client + a.authProvider = provider a.topicARN = m.TopicArn return nil @@ -83,16 +94,6 @@ func (a *AWSSNS) parseMetadata(meta bindings.Metadata) (*snsMetadata, error) { return &m, nil } -func (a *AWSSNS) getClient(metadata *snsMetadata) (*sns.SNS, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := sns.New(sess) - - return c, nil -} - func (a *AWSSNS) Operations() []bindings.OperationKind { return []bindings.OperationKind{bindings.CreateOperation} } @@ -107,7 +108,7 @@ func (a *AWSSNS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bind msg := fmt.Sprintf("%v", payload.Message) subject := fmt.Sprintf("%v", payload.Subject) - _, err = a.client.PublishWithContext(ctx, &sns.PublishInput{ + _, err = a.authProvider.Sns().Sns.PublishWithContext(ctx, &sns.PublishInput{ Message: &msg, Subject: &subject, TopicArn: &a.topicARN, @@ -127,5 +128,5 @@ func (a *AWSSNS) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { } func (a *AWSSNS) Close() error { - return nil + return a.authProvider.Close() } diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index 465e061b61..d803bafc5a 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -33,13 +33,12 @@ import ( // AWSSQS allows receiving and sending data to/from AWS SQS. type AWSSQS struct { - Client *sqs.SQS - QueueURL *string - - logger logger.Logger - wg sync.WaitGroup - closeCh chan struct{} - closed atomic.Bool + authProvider awsAuth.Provider + queueName string + logger logger.Logger + wg sync.WaitGroup + closeCh chan struct{} + closed atomic.Bool } type sqsMetadata struct { @@ -66,21 +65,22 @@ func (a *AWSSQS) Init(ctx context.Context, metadata bindings.Metadata) error { return err } - client, err := a.getClient(m) - if err != nil { - return err + opts := awsAuth.Options{ + Logger: a.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, } - - queueName := m.QueueName - resultURL, err := client.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ - QueueName: aws.String(queueName), - }) + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - - a.QueueURL = resultURL.QueueUrl - a.Client = client + a.authProvider = provider + a.queueName = m.QueueName return nil } @@ -91,9 +91,14 @@ func (a *AWSSQS) Operations() []bindings.OperationKind { func (a *AWSSQS) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) { msgBody := string(req.Data) - _, err := a.Client.SendMessageWithContext(ctx, &sqs.SendMessageInput{ + url, err := a.authProvider.Sqs().QueueURL(ctx, a.queueName) + if err != nil { + a.logger.Errorf("failed to get queue url: %v", err) + } + + _, err = a.authProvider.Sqs().Sqs.SendMessageWithContext(ctx, &sqs.SendMessageInput{ MessageBody: &msgBody, - QueueUrl: a.QueueURL, + QueueUrl: url, }) return nil, err @@ -113,9 +118,13 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { if ctx.Err() != nil || a.closed.Load() { return } + url, err := a.authProvider.Sqs().QueueURL(ctx, a.queueName) + if err != nil { + a.logger.Errorf("failed to get queue url: %v", err) + } - result, err := a.Client.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ - QueueUrl: a.QueueURL, + result, err := a.authProvider.Sqs().Sqs.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{ + QueueUrl: url, AttributeNames: aws.StringSlice([]string{ "SentTimestamp", }), @@ -126,7 +135,7 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { WaitTimeSeconds: aws.Int64(20), }) if err != nil { - a.logger.Errorf("Unable to receive message from queue %q, %v.", *a.QueueURL, err) + a.logger.Errorf("Unable to receive message from queue %q, %v.", url, err) } if len(result.Messages) > 0 { @@ -140,8 +149,8 @@ func (a *AWSSQS) Read(ctx context.Context, handler bindings.Handler) error { msgHandle := m.ReceiptHandle // Use a background context here because ctx may be canceled already - a.Client.DeleteMessageWithContext(context.Background(), &sqs.DeleteMessageInput{ - QueueUrl: a.QueueURL, + a.authProvider.Sqs().Sqs.DeleteMessageWithContext(context.Background(), &sqs.DeleteMessageInput{ + QueueUrl: url, ReceiptHandle: msgHandle, }) } @@ -164,7 +173,7 @@ func (a *AWSSQS) Close() error { close(a.closeCh) } a.wg.Wait() - return nil + return a.authProvider.Close() } func (a *AWSSQS) parseSQSMetadata(meta bindings.Metadata) (*sqsMetadata, error) { @@ -177,16 +186,6 @@ func (a *AWSSQS) parseSQSMetadata(meta bindings.Metadata) (*sqsMetadata, error) return &m, nil } -func (a *AWSSQS) getClient(metadata *sqsMetadata) (*sqs.SQS, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := sqs.New(sess) - - return c, nil -} - // GetComponentMetadata returns the metadata of the component. func (a *AWSSQS) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { metadataStruct := sqsMetadata{} diff --git a/common/authentication/aws/aws.go b/common/authentication/aws/aws.go index 48c8b209a4..a45eb48277 100644 --- a/common/authentication/aws/aws.go +++ b/common/authentication/aws/aws.go @@ -20,14 +20,10 @@ import ( "strconv" "time" - awsv2 "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" v2creds "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/feature/rds/auth" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" @@ -38,59 +34,78 @@ type EnvironmentSettings struct { Metadata map[string]string } -func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) { - optFns := []func(*config.LoadOptions) error{} - if region != "" { - optFns = append(optFns, config.WithRegion(region)) - } +type AWSIAM struct { + // Ignored by metadata parser because included in built-in authentication profile + // Access key to use for accessing PostgreSQL. + AWSAccessKey string `json:"awsAccessKey" mapstructure:"awsAccessKey"` + // Secret key to use for accessing PostgreSQL. + AWSSecretKey string `json:"awsSecretKey" mapstructure:"awsSecretKey"` + // AWS region in which PostgreSQL is deployed. + AWSRegion string `json:"awsRegion" mapstructure:"awsRegion"` +} - if accessKey != "" && secretKey != "" { - provider := v2creds.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken) - optFns = append(optFns, config.WithCredentialsProvider(provider)) - } +type AWSIAMAuthOptions struct { + PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` + ConnectionString string `json:"connectionString" mapstructure:"connectionString"` + Region string `json:"region" mapstructure:"region"` + AccessKey string `json:"accessKey" mapstructure:"accessKey"` + SecretKey string `json:"secretKey" mapstructure:"secretKey"` +} - awsCfg, err := config.LoadDefaultConfig(context.Background(), optFns...) - if err != nil { - return awsv2.Config{}, err - } +type Options struct { + Logger logger.Logger + Properties map[string]string - if endpoint != "" { - awsCfg.BaseEndpoint = &endpoint - } + PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` + ConnectionString string `json:"connectionString" mapstructure:"connectionString"` - return awsCfg, nil + Region string `json:"region" mapstructure:"region"` + AccessKey string `json:"accessKey" mapstructure:"accessKey"` + SecretKey string `json:"secretKey" mapstructure:"secretKey"` + + Endpoint string + SessionToken string } -func GetClient(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (*session.Session, error) { - awsConfig := aws.NewConfig() +func GetConfig(opts Options) *aws.Config { + cfg := aws.NewConfig() - if region != "" { - awsConfig = awsConfig.WithRegion(region) + switch { + case opts.Region != "": + cfg.WithRegion(opts.Region) + case opts.Endpoint != "": + cfg.WithEndpoint(opts.Endpoint) } - if accessKey != "" && secretKey != "" { - awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(accessKey, secretKey, sessionToken)) - } + return cfg +} - if endpoint != "" { - awsConfig = awsConfig.WithEndpoint(endpoint) - } +type Provider interface { + S3() *S3Clients + DynamoDB() *DynamoDBClients + Sqs() *SqsClients + Sns() *SnsClients + SnsSqs() *SnsSqsClients + SecretManager() *SecretManagerClients + ParameterStore() *ParameterStoreClients + Kinesis() *KinesisClients + Ses() *SesClients + + Close() error +} - awsSession, err := session.NewSessionWithOptions(session.Options{ - Config: *awsConfig, - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return nil, err - } +func isX509Auth(m map[string]string) bool { + tp, _ := m["trustProfileArn"] + ta, _ := m["trustAnchorArn"] + ar, _ := m["assumeRoleArn"] + return tp != "" && ta != "" && ar != "" +} - userAgentHandler := request.NamedHandler{ - Name: "UserAgentHandler", - Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion), +func NewProvider(ctx context.Context, opts Options, cfg *aws.Config) (Provider, error) { + if isX509Auth(opts.Properties) { + return newX509(ctx, opts, cfg) } - awsSession.Handlers.Build.PushBackNamed(userAgentHandler) - - return awsSession, nil + return newStaticIAM(ctx, opts, cfg) } // NewEnvironmentSettings returns a new EnvironmentSettings configured for a given AWS resource. @@ -102,25 +117,7 @@ func NewEnvironmentSettings(md map[string]string) (EnvironmentSettings, error) { return es, nil } -type AWSIAM struct { - // Ignored by metadata parser because included in built-in authentication profile - // Access key to use for accessing PostgreSQL. - AWSAccessKey string `json:"awsAccessKey" mapstructure:"awsAccessKey"` - // Secret key to use for accessing PostgreSQL. - AWSSecretKey string `json:"awsSecretKey" mapstructure:"awsSecretKey"` - // AWS region in which PostgreSQL is deployed. - AWSRegion string `json:"awsRegion" mapstructure:"awsRegion"` -} - -type AWSIAMAuthOptions struct { - PoolConfig *pgxpool.Config `json:"poolConfig" mapstructure:"poolConfig"` - ConnectionString string `json:"connectionString" mapstructure:"connectionString"` - Region string `json:"region" mapstructure:"region"` - AccessKey string `json:"accessKey" mapstructure:"accessKey"` - SecretKey string `json:"secretKey" mapstructure:"secretKey"` -} - -func (opts *AWSIAMAuthOptions) GetAccessToken(ctx context.Context) (string, error) { +func (opts *Options) GetAccessToken(ctx context.Context) (string, error) { dbEndpoint := opts.PoolConfig.ConnConfig.Host + ":" + strconv.Itoa(int(opts.PoolConfig.ConnConfig.Port)) var authenticationToken string @@ -160,7 +157,7 @@ func (opts *AWSIAMAuthOptions) GetAccessToken(ctx context.Context) (string, erro return authenticationToken, nil } -func (opts *AWSIAMAuthOptions) InitiateAWSIAMAuth() error { +func (opts *Options) InitiateAWSIAMAuth() error { // Set max connection lifetime to 8 minutes in postgres connection pool configuration. // Note: this will refresh connections before the 15 min expiration on the IAM AWS auth token, // while leveraging the BeforeConnect hook to recreate the token in time dynamically. diff --git a/common/authentication/aws/aws_test.go b/common/authentication/aws/aws_test.go new file mode 100644 index 0000000000..15aac78ad7 --- /dev/null +++ b/common/authentication/aws/aws_test.go @@ -0,0 +1,44 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewEnvironmentSettings(t *testing.T) { + tests := []struct { + name string + metadata map[string]string + }{ + { + name: "valid metadata", + metadata: map[string]string{ + "key1": "value1", + "key2": "value2", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := NewEnvironmentSettings(tt.metadata) + require.NoError(t, err) + assert.NotNil(t, result) + }) + } +} diff --git a/common/authentication/aws/client.go b/common/authentication/aws/client.go new file mode 100644 index 0000000000..8d0e9de20b --- /dev/null +++ b/common/authentication/aws/client.go @@ -0,0 +1,209 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "errors" + "sync" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" + "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/aws/aws-sdk-go/service/secretsmanager" + "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" + "github.com/aws/aws-sdk-go/service/ses" + "github.com/aws/aws-sdk-go/service/sns" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go/service/ssm/ssmiface" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/vmware/vmware-go-kcl/clientlibrary/config" +) + +type Clients struct { + mu sync.RWMutex + + s3 *S3Clients + Dynamo *DynamoDBClients + sns *SnsClients + sqs *SqsClients + snssqs *SnsSqsClients + Secret *SecretManagerClients + ParameterStore *ParameterStoreClients + kinesis *KinesisClients + ses *SesClients +} + +func newClients() *Clients { + return new(Clients) +} + +func (c *Clients) refresh(session *session.Session) { + c.mu.Lock() + defer c.mu.Unlock() + switch { + case c.s3 != nil: + c.s3.New(session) + case c.Dynamo != nil: + c.Dynamo.New(session) + case c.sns != nil: + c.sns.New(session) + case c.sqs != nil: + c.sqs.New(session) + case c.snssqs != nil: + c.snssqs.New(session) + case c.Secret != nil: + c.Secret.New(session) + case c.ParameterStore != nil: + c.ParameterStore.New(session) + case c.kinesis != nil: + c.kinesis.New(session) + case c.ses != nil: + c.ses.New(session) + } +} + +type S3Clients struct { + S3 *s3.S3 + Uploader *s3manager.Uploader + Downloader *s3manager.Downloader +} + +type DynamoDBClients struct { + DynamoDB dynamodbiface.DynamoDBAPI +} + +type SnsSqsClients struct { + Sns *sns.SNS + Sqs *sqs.SQS + Sts *sts.STS +} + +type SnsClients struct { + Sns *sns.SNS +} + +type SqsClients struct { + Sqs sqsiface.SQSAPI +} + +type SecretManagerClients struct { + Manager secretsmanageriface.SecretsManagerAPI +} + +type ParameterStoreClients struct { + Store ssmiface.SSMAPI +} + +type KinesisClients struct { + Kinesis kinesisiface.KinesisAPI + Region string + Credentials *credentials.Credentials +} + +type SesClients struct { + Ses *ses.SES +} + +func (c *S3Clients) New(session *session.Session) { + refreshedS3 := s3.New(session, session.Config) + c.S3 = refreshedS3 + c.Uploader = s3manager.NewUploaderWithClient(refreshedS3) + c.Downloader = s3manager.NewDownloaderWithClient(refreshedS3) +} + +func (c *DynamoDBClients) New(session *session.Session) { + c.DynamoDB = dynamodb.New(session, session.Config) +} + +func (c *SnsClients) New(session *session.Session) { + c.Sns = sns.New(session, session.Config) +} + +func (c *SnsSqsClients) New(session *session.Session) { + c.Sns = sns.New(session, session.Config) + c.Sqs = sqs.New(session, session.Config) + c.Sts = sts.New(session, session.Config) +} + +func (c *SqsClients) New(session *session.Session) { + c.Sqs = sqs.New(session, session.Config) +} + +func (c *SqsClients) QueueURL(ctx context.Context, queueName string) (*string, error) { + if c.Sqs != nil { + resultURL, err := c.Sqs.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{ + QueueName: aws.String(queueName), + }) + if resultURL != nil { + return resultURL.QueueUrl, err + } + } + return nil, errors.New("unable to get queue url due to empty client") +} + +func (c *SecretManagerClients) New(session *session.Session) { + c.Manager = secretsmanager.New(session, session.Config) +} + +func (c *ParameterStoreClients) New(session *session.Session) { + c.Store = ssm.New(session, session.Config) +} + +func (c *KinesisClients) New(session *session.Session) { + c.Kinesis = kinesis.New(session, session.Config) + c.Region = *session.Config.Region + c.Credentials = session.Config.Credentials +} + +func (c *KinesisClients) Stream(ctx context.Context, streamName string) (*string, error) { + if c.Kinesis != nil { + stream, err := c.Kinesis.DescribeStreamWithContext(ctx, &kinesis.DescribeStreamInput{ + StreamName: aws.String(streamName), + }) + if stream != nil { + return stream.StreamDescription.StreamARN, err + } + } + + return nil, errors.New("unable to get stream arn due to empty client") +} + +func (c *KinesisClients) WorkerCfg(ctx context.Context, stream, consumer, mode string) *config.KinesisClientLibConfiguration { + const sharedMode = "shared" + if c.Kinesis != nil { + if mode == sharedMode { + if c.Credentials != nil { + kclConfig := config.NewKinesisClientLibConfigWithCredential(consumer, + stream, c.Region, consumer, + c.Credentials) + return kclConfig + } + } + } + + return nil +} + +func (c *SesClients) New(session *session.Session) { + c.Ses = ses.New(session, session.Config) +} diff --git a/common/authentication/aws/client_fake.go b/common/authentication/aws/client_fake.go new file mode 100644 index 0000000000..c9e23641ba --- /dev/null +++ b/common/authentication/aws/client_fake.go @@ -0,0 +1,79 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" + "github.com/aws/aws-sdk-go/service/secretsmanager" + "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" + "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go/service/ssm/ssmiface" +) + +type MockParameterStore struct { + GetParameterFn func(context.Context, *ssm.GetParameterInput, ...request.Option) (*ssm.GetParameterOutput, error) + DescribeParametersFn func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) + ssmiface.SSMAPI +} + +func (m *MockParameterStore) GetParameterWithContext(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + return m.GetParameterFn(ctx, input, option...) +} + +func (m *MockParameterStore) DescribeParametersWithContext(ctx context.Context, input *ssm.DescribeParametersInput, option ...request.Option) (*ssm.DescribeParametersOutput, error) { + return m.DescribeParametersFn(ctx, input, option...) +} + +type MockSecretManager struct { + GetSecretValueFn func(context.Context, *secretsmanager.GetSecretValueInput, ...request.Option) (*secretsmanager.GetSecretValueOutput, error) + secretsmanageriface.SecretsManagerAPI +} + +func (m *MockSecretManager) GetSecretValueWithContext(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + return m.GetSecretValueFn(ctx, input, option...) +} + +type MockDynamoDB struct { + GetItemWithContextFn func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) + PutItemWithContextFn func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) + DeleteItemWithContextFn func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) + BatchWriteItemWithContextFn func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) + TransactWriteItemsWithContextFn func(aws.Context, *dynamodb.TransactWriteItemsInput, ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) + dynamodbiface.DynamoDBAPI +} + +func (m *MockDynamoDB) GetItemWithContext(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { + return m.GetItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) PutItemWithContext(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) { + return m.PutItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) DeleteItemWithContext(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) { + return m.DeleteItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) BatchWriteItemWithContext(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) { + return m.BatchWriteItemWithContextFn(ctx, input, op...) +} + +func (m *MockDynamoDB) TransactWriteItemsWithContext(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) { + return m.TransactWriteItemsWithContextFn(ctx, input, op...) +} diff --git a/common/authentication/aws/client_test.go b/common/authentication/aws/client_test.go new file mode 100644 index 0000000000..67d2ac88f3 --- /dev/null +++ b/common/authentication/aws/client_test.go @@ -0,0 +1,265 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "errors" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/kinesis" + "github.com/aws/aws-sdk-go/service/kinesis/kinesisiface" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/vmware/vmware-go-kcl/clientlibrary/config" +) + +type mockedSQS struct { + sqsiface.SQSAPI + GetQueueURLFn func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) +} + +func (m *mockedSQS) GetQueueUrlWithContext(ctx context.Context, input *sqs.GetQueueUrlInput, opts ...request.Option) (*sqs.GetQueueUrlOutput, error) { //nolint:stylecheck + return m.GetQueueURLFn(ctx, input) +} + +type mockedKinesis struct { + kinesisiface.KinesisAPI + DescribeStreamFn func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) +} + +func (m *mockedKinesis) DescribeStreamWithContext(ctx context.Context, input *kinesis.DescribeStreamInput, opts ...request.Option) (*kinesis.DescribeStreamOutput, error) { + return m.DescribeStreamFn(ctx, input) +} + +func TestS3Clients_New(t *testing.T) { + tests := []struct { + name string + s3Client *S3Clients + session *session.Session + }{ + {"initializes S3 client", &S3Clients{}, session.Must(session.NewSession())}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.s3Client.New(tt.session) + require.NotNil(t, tt.s3Client.S3) + require.NotNil(t, tt.s3Client.Uploader) + require.NotNil(t, tt.s3Client.Downloader) + }) + } +} + +func TestSqsClients_QueueURL(t *testing.T) { + tests := []struct { + name string + mockFn func() *mockedSQS + queueName string + expectedURL *string + expectError bool + }{ + { + name: "returns queue URL successfully", + mockFn: func() *mockedSQS { + return &mockedSQS{ + GetQueueURLFn: func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) { + return &sqs.GetQueueUrlOutput{ + QueueUrl: aws.String("https://sqs.aws.com/123456789012/queue"), + }, nil + }, + } + }, + queueName: "valid-queue", + expectedURL: aws.String("https://sqs.aws.com/123456789012/queue"), + expectError: false, + }, + { + name: "returns error when queue URL not found", + mockFn: func() *mockedSQS { + return &mockedSQS{ + GetQueueURLFn: func(ctx context.Context, input *sqs.GetQueueUrlInput) (*sqs.GetQueueUrlOutput, error) { + return nil, errors.New("unable to get stream arn due to empty client") + }, + } + }, + queueName: "missing-queue", + expectedURL: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSQS := tt.mockFn() + + // Initialize SqsClients with the mocked SQS client + client := &SqsClients{ + Sqs: mockSQS, + } + + url, err := client.QueueURL(context.Background(), tt.queueName) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedURL, url) + } + }) + } +} + +func TestKinesisClients_Stream(t *testing.T) { + tests := []struct { + name string + kinesisClient *KinesisClients + streamName string + mockStreamARN *string + mockError error + expectedStream *string + expectedErr error + }{ + { + name: "successfully retrieves stream ARN", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return &kinesis.DescribeStreamOutput{ + StreamDescription: &kinesis.StreamDescription{ + StreamARN: aws.String("arn:aws:kinesis:some-region:123456789012:stream/some-stream"), + }, + }, nil + }}, + Region: "us-west-1", + Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""), + }, + streamName: "some-stream", + expectedStream: aws.String("arn:aws:kinesis:some-region:123456789012:stream/some-stream"), + expectedErr: nil, + }, + { + name: "returns error when stream not found", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return nil, errors.New("stream not found") + }}, + Region: "us-west-1", + Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""), + }, + streamName: "nonexistent-stream", + expectedStream: nil, + expectedErr: errors.New("unable to get stream arn due to empty client"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.kinesisClient.Stream(context.Background(), tt.streamName) + if tt.expectedErr != nil { + require.Error(t, err) + assert.Equal(t, tt.expectedErr.Error(), err.Error()) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedStream, got) + } + }) + } +} + +func TestKinesisClients_WorkerCfg(t *testing.T) { + testCreds := credentials.NewStaticCredentials("accessKey", "secretKey", "") + tests := []struct { + name string + kinesisClient *KinesisClients + streamName string + consumer string + mode string + expectedConfig *config.KinesisClientLibConfiguration + }{ + { + name: "successfully creates shared mode worker config", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{ + DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return &kinesis.DescribeStreamOutput{ + StreamDescription: &kinesis.StreamDescription{ + StreamARN: aws.String("arn:aws:kinesis:us-east-1:123456789012:stream/existing-stream"), + }, + }, nil + }, + }, + Region: "us-west-1", + Credentials: testCreds, + }, + streamName: "existing-stream", + consumer: "consumer1", + mode: "shared", + expectedConfig: config.NewKinesisClientLibConfigWithCredential( + "consumer1", "existing-stream", "us-west-1", "consumer1", testCreds, + ), + }, + { + name: "returns nil when mode is not shared", + kinesisClient: &KinesisClients{ + Kinesis: &mockedKinesis{ + DescribeStreamFn: func(ctx context.Context, input *kinesis.DescribeStreamInput) (*kinesis.DescribeStreamOutput, error) { + return &kinesis.DescribeStreamOutput{ + StreamDescription: &kinesis.StreamDescription{ + StreamARN: aws.String("arn:aws:kinesis:us-east-1:123456789012:stream/existing-stream"), + }, + }, nil + }, + }, + Region: "us-west-1", + Credentials: testCreds, + }, + streamName: "existing-stream", + consumer: "consumer1", + mode: "exclusive", + expectedConfig: nil, + }, + { + name: "returns nil when client is nil", + kinesisClient: &KinesisClients{ + Kinesis: nil, + Region: "us-west-1", + Credentials: credentials.NewStaticCredentials("accessKey", "secretKey", ""), + }, + streamName: "existing-stream", + consumer: "consumer1", + mode: "shared", + expectedConfig: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := tt.kinesisClient.WorkerCfg(context.Background(), tt.streamName, tt.consumer, tt.mode) + if tt.expectedConfig == nil { + assert.Equal(t, tt.expectedConfig, cfg) + return + } + assert.Equal(t, tt.expectedConfig.StreamName, cfg.StreamName) + assert.Equal(t, tt.expectedConfig.EnhancedFanOutConsumerName, cfg.EnhancedFanOutConsumerName) + assert.Equal(t, tt.expectedConfig.EnableEnhancedFanOutConsumer, cfg.EnableEnhancedFanOutConsumer) + assert.Equal(t, tt.expectedConfig.RegionName, cfg.RegionName) + }) + } +} diff --git a/common/authentication/aws/static.go b/common/authentication/aws/static.go new file mode 100644 index 0000000000..a66ef86e1e --- /dev/null +++ b/common/authentication/aws/static.go @@ -0,0 +1,272 @@ +/* +Copyright 2021 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "fmt" + "sync" + + awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + v2creds "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + + "github.com/dapr/kit/logger" +) + +type StaticAuth struct { + mu sync.RWMutex + logger logger.Logger + + region *string + endpoint *string + accessKey *string + secretKey *string + sessionToken *string + + session *session.Session + cfg *aws.Config + clients *Clients +} + +func newStaticIAM(_ context.Context, opts Options, cfg *aws.Config) (*StaticAuth, error) { + auth := &StaticAuth{ + logger: opts.Logger, + region: &opts.Region, + endpoint: &opts.Endpoint, + accessKey: &opts.AccessKey, + secretKey: &opts.SecretKey, + sessionToken: &opts.SessionToken, + cfg: func() *aws.Config { + // if nil is passed or it's just a default cfg, + // then we use the options to build the aws cfg. + if cfg != nil && cfg != aws.NewConfig() { + return cfg + } + return GetConfig(opts) + }(), + clients: newClients(), + } + + initialSession, err := auth.getTokenClient() + if err != nil { + return nil, fmt.Errorf("failed to get token client: %v", err) + } + + auth.session = initialSession + + return auth, nil +} + +// This is to be used only for test purposes to inject mocked clients +func (a *StaticAuth) WithMockClients(clients *Clients) { + a.clients = clients +} + +func (a *StaticAuth) S3() *S3Clients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.s3 != nil { + return a.clients.s3 + } + + s3Clients := S3Clients{} + a.clients.s3 = &s3Clients + a.clients.s3.New(a.session) + return a.clients.s3 +} + +func (a *StaticAuth) DynamoDB() *DynamoDBClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.Dynamo != nil { + return a.clients.Dynamo + } + + clients := DynamoDBClients{} + a.clients.Dynamo = &clients + a.clients.Dynamo.New(a.session) + + return a.clients.Dynamo +} + +func (a *StaticAuth) Sqs() *SqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sqs != nil { + return a.clients.sqs + } + + clients := SqsClients{} + a.clients.sqs = &clients + a.clients.sqs.New(a.session) + + return a.clients.sqs +} + +func (a *StaticAuth) Sns() *SnsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sns != nil { + return a.clients.sns + } + + clients := SnsClients{} + a.clients.sns = &clients + a.clients.sns.New(a.session) + return a.clients.sns +} + +func (a *StaticAuth) SnsSqs() *SnsSqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.snssqs != nil { + return a.clients.snssqs + } + + clients := SnsSqsClients{} + a.clients.snssqs = &clients + a.clients.snssqs.New(a.session) + return a.clients.snssqs +} + +func (a *StaticAuth) SecretManager() *SecretManagerClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.Secret != nil { + return a.clients.Secret + } + + clients := SecretManagerClients{} + a.clients.Secret = &clients + a.clients.Secret.New(a.session) + return a.clients.Secret +} + +func (a *StaticAuth) ParameterStore() *ParameterStoreClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.ParameterStore != nil { + return a.clients.ParameterStore + } + + clients := ParameterStoreClients{} + a.clients.ParameterStore = &clients + a.clients.ParameterStore.New(a.session) + return a.clients.ParameterStore +} + +func (a *StaticAuth) Kinesis() *KinesisClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.kinesis != nil { + return a.clients.kinesis + } + + clients := KinesisClients{} + a.clients.kinesis = &clients + a.clients.kinesis.New(a.session) + return a.clients.kinesis +} + +func (a *StaticAuth) Ses() *SesClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.ses != nil { + return a.clients.ses + } + + clients := SesClients{} + a.clients.ses = &clients + a.clients.ses.New(a.session) + return a.clients.ses +} + +func (a *StaticAuth) getTokenClient() (*session.Session, error) { + var awsConfig *aws.Config + if a.cfg == nil { + awsConfig = aws.NewConfig() + } else { + awsConfig = a.cfg + } + + if a.region != nil { + awsConfig = awsConfig.WithRegion(*a.region) + } + + if a.accessKey != nil && a.secretKey != nil { + // session token is an option field + awsConfig = awsConfig.WithCredentials(credentials.NewStaticCredentials(*a.accessKey, *a.secretKey, *a.sessionToken)) + } + + if a.endpoint != nil { + awsConfig = awsConfig.WithEndpoint(*a.endpoint) + } + + awsSession, err := session.NewSessionWithOptions(session.Options{ + Config: *awsConfig, + SharedConfigState: session.SharedConfigEnable, + }) + if err != nil { + return nil, err + } + + userAgentHandler := request.NamedHandler{ + Name: "UserAgentHandler", + Fn: request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion), + } + awsSession.Handlers.Build.PushBackNamed(userAgentHandler) + + return awsSession, nil +} + +func (a *StaticAuth) Close() error { + return nil +} + +func GetConfigV2(accessKey string, secretKey string, sessionToken string, region string, endpoint string) (awsv2.Config, error) { + optFns := []func(*config.LoadOptions) error{} + if region != "" { + optFns = append(optFns, config.WithRegion(region)) + } + + if accessKey != "" && secretKey != "" { + provider := v2creds.NewStaticCredentialsProvider(accessKey, secretKey, sessionToken) + optFns = append(optFns, config.WithCredentialsProvider(provider)) + } + + awsCfg, err := config.LoadDefaultConfig(context.Background(), optFns...) + if err != nil { + return awsv2.Config{}, err + } + + if endpoint != "" { + awsCfg.BaseEndpoint = &endpoint + } + + return awsCfg, nil +} diff --git a/common/authentication/aws/static_test.go b/common/authentication/aws/static_test.go new file mode 100644 index 0000000000..a1a17a093c --- /dev/null +++ b/common/authentication/aws/static_test.go @@ -0,0 +1,66 @@ +package aws + +import ( + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetConfigV2(t *testing.T) { + tests := []struct { + name string + accessKey string + secretKey string + sessionToken string + region string + endpoint string + }{ + { + name: "valid config", + accessKey: "testAccessKey", + secretKey: "testSecretKey", + sessionToken: "testSessionToken", + region: "us-west-2", + endpoint: "https://test.endpoint.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + awsCfg, err := GetConfigV2(tt.accessKey, tt.secretKey, tt.sessionToken, tt.region, tt.endpoint) + require.NoError(t, err) + assert.NotNil(t, awsCfg) + assert.Equal(t, tt.region, awsCfg.Region) + assert.Equal(t, tt.endpoint, *awsCfg.BaseEndpoint) + }) + } +} + +func TestGetTokenClient(t *testing.T) { + tests := []struct { + name string + awsInstance *StaticAuth + }{ + { + name: "valid token client", + awsInstance: &StaticAuth{ + accessKey: aws.String("testAccessKey"), + secretKey: aws.String("testSecretKey"), + sessionToken: aws.String("testSessionToken"), + region: aws.String("us-west-2"), + endpoint: aws.String("https://test.endpoint.com"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + session, err := tt.awsInstance.getTokenClient() + require.NotNil(t, session) + require.NoError(t, err) + assert.Equal(t, tt.awsInstance.region, session.Config.Region) + }) + } +} diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go new file mode 100644 index 0000000000..cb1bafdeb3 --- /dev/null +++ b/common/authentication/aws/x509.go @@ -0,0 +1,449 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + "crypto/ecdsa" + "crypto/tls" + cryptoX509 "crypto/x509" + "errors" + "fmt" + "net/http" + "runtime" + "sync" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + awssh "github.com/aws/rolesanywhere-credential-helper/aws_signing_helper" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" + + cryptopem "github.com/dapr/kit/crypto/pem" + spiffecontext "github.com/dapr/kit/crypto/spiffe/context" + "github.com/dapr/kit/logger" + kitmd "github.com/dapr/kit/metadata" + "github.com/dapr/kit/ptr" +) + +type x509Options struct { + TrustProfileArn *string `json:"trustProfileArn" mapstructure:"trustProfileArn"` + TrustAnchorArn *string `json:"trustAnchorArn" mapstructure:"trustAnchorArn"` + AssumeRoleArn *string `json:"assumeRoleArn" mapstructure:"assumeRoleArn"` +} + +type x509 struct { + mu sync.RWMutex + wg sync.WaitGroup + closeCh chan struct{} + + logger logger.Logger + clients *Clients + rolesAnywhereClient rolesanywhereiface.RolesAnywhereAPI // this is so we can mock it in tests + session *session.Session + cfg *aws.Config + + chainPEM []byte + keyPEM []byte + + region *string + trustProfileArn *string + trustAnchorArn *string + assumeRoleArn *string +} + +func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) { + var x509Auth x509Options + if err := kitmd.DecodeMetadata(opts.Properties, &x509Auth); err != nil { + return nil, err + } + + switch { + case x509Auth.TrustProfileArn == nil: + return nil, errors.New("trustProfileArn is required") + case x509Auth.TrustAnchorArn == nil: + return nil, errors.New("trustAnchorArn is required") + case x509Auth.AssumeRoleArn == nil: + return nil, errors.New("assumeRoleArn is required") + } + + auth := &x509{ + logger: opts.Logger, + trustProfileArn: x509Auth.TrustProfileArn, + trustAnchorArn: x509Auth.TrustAnchorArn, + assumeRoleArn: x509Auth.AssumeRoleArn, + cfg: func() *aws.Config { + // if nil is passed or it's just a default cfg, + // then we use the options to build the aws cfg. + if cfg != nil && cfg != aws.NewConfig() { + return cfg + } + return GetConfig(opts) + }(), + clients: newClients(), + } + + if err := auth.getCertPEM(ctx); err != nil { + return nil, fmt.Errorf("failed to get x.509 credentials: %v", err) + } + + // Parse trust anchor and profile ARNs + if err := auth.initializeTrustAnchors(); err != nil { + return nil, err + } + + initialSession, err := auth.createOrRefreshSession(ctx) + if err != nil { + return nil, fmt.Errorf("failed to create the initial session: %v", err) + } + auth.session = initialSession + auth.startSessionRefresher() + + return auth, nil +} + +func (a *x509) Close() error { + close(a.closeCh) + a.wg.Wait() + return nil +} + +func (a *x509) getCertPEM(ctx context.Context) error { + // retrieve svid from spiffe context + svid, ok := spiffecontext.From(ctx) + if !ok { + return errors.New("no SVID found in context") + } + // get x.509 svid + svidx, err := svid.GetX509SVID() + if err != nil { + return err + } + + // marshal x.509 svid to pem format + chainPEM, keyPEM, err := svidx.Marshal() + if err != nil { + return fmt.Errorf("failed to marshal SVID: %w", err) + } + + a.chainPEM = chainPEM + a.keyPEM = keyPEM + return nil +} + +func (a *x509) S3() *S3Clients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.s3 != nil { + return a.clients.s3 + } + + s3Clients := S3Clients{} + a.clients.s3 = &s3Clients + a.clients.s3.New(a.session) + return a.clients.s3 +} + +func (a *x509) DynamoDB() *DynamoDBClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.Dynamo != nil { + return a.clients.Dynamo + } + + clients := DynamoDBClients{} + a.clients.Dynamo = &clients + a.clients.Dynamo.New(a.session) + + return a.clients.Dynamo +} + +func (a *x509) Sqs() *SqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sqs != nil { + return a.clients.sqs + } + + clients := SqsClients{} + a.clients.sqs = &clients + a.clients.sqs.New(a.session) + + return a.clients.sqs +} + +func (a *x509) Sns() *SnsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.sns != nil { + return a.clients.sns + } + + clients := SnsClients{} + a.clients.sns = &clients + a.clients.sns.New(a.session) + return a.clients.sns +} + +func (a *x509) SnsSqs() *SnsSqsClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.snssqs != nil { + return a.clients.snssqs + } + + clients := SnsSqsClients{} + a.clients.snssqs = &clients + a.clients.snssqs.New(a.session) + return a.clients.snssqs +} + +func (a *x509) SecretManager() *SecretManagerClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.Secret != nil { + return a.clients.Secret + } + + clients := SecretManagerClients{} + a.clients.Secret = &clients + a.clients.Secret.New(a.session) + return a.clients.Secret +} + +func (a *x509) ParameterStore() *ParameterStoreClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.ParameterStore != nil { + return a.clients.ParameterStore + } + + clients := ParameterStoreClients{} + a.clients.ParameterStore = &clients + a.clients.ParameterStore.New(a.session) + return a.clients.ParameterStore +} + +func (a *x509) Kinesis() *KinesisClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.kinesis != nil { + return a.clients.kinesis + } + + clients := KinesisClients{} + a.clients.kinesis = &clients + a.clients.kinesis.New(a.session) + return a.clients.kinesis +} + +func (a *x509) Ses() *SesClients { + a.mu.Lock() + defer a.mu.Unlock() + + if a.clients.ses != nil { + return a.clients.ses + } + + clients := SesClients{} + a.clients.ses = &clients + a.clients.ses.New(a.session) + return a.clients.ses +} + +func (a *x509) initializeTrustAnchors() error { + var ( + trustAnchor arn.ARN + profile arn.ARN + err error + ) + if a.trustAnchorArn != nil { + trustAnchor, err = arn.Parse(*a.trustAnchorArn) + if err != nil { + return err + } + a.region = &trustAnchor.Region + } + + if a.trustProfileArn != nil { + profile, err = arn.Parse(*a.trustProfileArn) + if err != nil { + return err + } + + if profile.Region != "" && trustAnchor.Region != profile.Region { + return fmt.Errorf("trust anchor and profile must be in the same region: trustAnchor=%s, profile=%s", + trustAnchor.Region, profile.Region) + } + } + return nil +} + +func (a *x509) setSigningFunction(rolesAnywhereClient *rolesanywhere.RolesAnywhere) error { + certs, err := cryptopem.DecodePEMCertificatesChain(a.chainPEM) + if err != nil { + return err + } + + ints := make([]cryptoX509.Certificate, 0, len(certs)-1) + for i := range certs[1:] { + ints = append(ints, *certs[i+1]) + } + + key, err := cryptopem.DecodePEMPrivateKey(a.keyPEM) + if err != nil { + return err + } + + keyECDSA := key.(*ecdsa.PrivateKey) + signFunc := awssh.CreateSignFunction(*keyECDSA, *certs[0], ints) + + agentHandlerFunc := request.MakeAddToUserAgentHandler("dapr", logger.DaprVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) + rolesAnywhereClient.Handlers.Build.RemoveByName("core.SDKVersionUserAgentHandler") + rolesAnywhereClient.Handlers.Build.PushBackNamed(request.NamedHandler{Name: "v4x509.CredHelperUserAgentHandler", Fn: agentHandlerFunc}) + rolesAnywhereClient.Handlers.Sign.Clear() + rolesAnywhereClient.Handlers.Sign.PushBackNamed(request.NamedHandler{Name: "v4x509.SignRequestHandler", Fn: signFunc}) + + return nil +} + +func (a *x509) createOrRefreshSession(ctx context.Context) (*session.Session, error) { + a.mu.Lock() + defer a.mu.Unlock() + + client := &http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12}, + }} + var mySession *session.Session + + var awsConfig *aws.Config + if a.cfg == nil { + awsConfig = aws.NewConfig().WithHTTPClient(client).WithLogLevel(aws.LogOff) + } else { + awsConfig = a.cfg.WithHTTPClient(client).WithLogLevel(aws.LogOff) + } + if a.region != nil { + awsConfig.WithRegion(*a.region) + } + // this is needed for testing purposes to mock the client, + // so code never sets the client, but tests do. + var rolesClient *rolesanywhere.RolesAnywhere + if a.rolesAnywhereClient == nil { + mySession = session.Must(session.NewSession(awsConfig)) + rolesAnywhereClient := rolesanywhere.New(mySession, awsConfig) + // Set up signing function and handlers + if err := a.setSigningFunction(rolesAnywhereClient); err != nil { + return nil, err + } + rolesClient = rolesAnywhereClient + } + + createSessionRequest := rolesanywhere.CreateSessionInput{ + Cert: ptr.Of(string(a.chainPEM)), + ProfileArn: a.trustProfileArn, + TrustAnchorArn: a.trustAnchorArn, + RoleArn: a.assumeRoleArn, + // https://aws.amazon.com/about-aws/whats-new/2024/03/iam-roles-anywhere-credentials-valid-12-hours/#:~:text=The%20duration%20can%20range%20from,and%20applications%2C%20to%20use%20X. + DurationSeconds: aws.Int64(int64(time.Hour.Seconds())), // AWS default is 1hr timeout + InstanceProperties: nil, + SessionName: nil, + } + + var output *rolesanywhere.CreateSessionOutput + if a.rolesAnywhereClient != nil { + var err error + output, err = a.rolesAnywhereClient.CreateSessionWithContext(ctx, &createSessionRequest) + if err != nil { + return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) + } + } else { + var err error + output, err = rolesClient.CreateSessionWithContext(ctx, &createSessionRequest) + if err != nil { + return nil, fmt.Errorf("failed to create session using dapr app identity: %w", err) + } + } + + if output == nil || len(output.CredentialSet) != 1 { + return nil, fmt.Errorf("expected 1 credential set from X.509 rolesanyway response, got %d", len(output.CredentialSet)) + } + + accessKey := output.CredentialSet[0].Credentials.AccessKeyId + secretKey := output.CredentialSet[0].Credentials.SecretAccessKey + sessionToken := output.CredentialSet[0].Credentials.SessionToken + awsCreds := credentials.NewStaticCredentials(*accessKey, *secretKey, *sessionToken) + sess := session.Must(session.NewSession(&aws.Config{ + Credentials: awsCreds, + }, awsConfig)) + if sess == nil { + return nil, errors.New("session is nil") + } + + return sess, nil +} + +func (a *x509) startSessionRefresher() { + a.logger.Infof("starting session refresher for x509 auth") + + a.wg.Add(1) + go func() { + defer a.wg.Done() + for { + // renew at ~half the lifespan + expiration, err := a.session.Config.Credentials.ExpiresAt() + if err != nil { + a.logger.Errorf("Failed to retrieve session expiration time, using 30 minute interval: %w", err) + expiration = time.Now().Add(time.Hour) + } + timeUntilExpiration := time.Until(expiration) + refreshInterval := timeUntilExpiration / 2 + select { + case <-time.After(refreshInterval): + a.refreshClient() + case <-a.closeCh: + a.logger.Debugf("Session refresher is stopped") + return + } + } + }() +} + +func (a *x509) refreshClient() { + for { + newSession, err := a.createOrRefreshSession(context.Background()) + if err == nil { + a.clients.refresh(newSession) + a.logger.Debugf("AWS IAM Roles Anywhere session credentials refreshed successfully") + return + } + a.logger.Errorf("Failed to refresh session, retrying in 5 seconds: %w", err) + select { + case <-time.After(time.Second * 5): + case <-a.closeCh: + return + } + } +} diff --git a/common/authentication/aws/x509_test.go b/common/authentication/aws/x509_test.go new file mode 100644 index 0000000000..3f7d2189c3 --- /dev/null +++ b/common/authentication/aws/x509_test.go @@ -0,0 +1,125 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "context" + cryptoX509 "crypto/x509" + "sync/atomic" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere" + "github.com/aws/rolesanywhere-credential-helper/rolesanywhere/rolesanywhereiface" + "github.com/spiffe/go-spiffe/v2/spiffeid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/dapr/kit/crypto/spiffe" + spiffecontext "github.com/dapr/kit/crypto/spiffe/context" + "github.com/dapr/kit/crypto/test" + "github.com/dapr/kit/logger" +) + +type mockRolesAnywhereClient struct { + rolesanywhereiface.RolesAnywhereAPI + + CreateSessionOutput *rolesanywhere.CreateSessionOutput + CreateSessionError error +} + +func (m *mockRolesAnywhereClient) CreateSessionWithContext(ctx context.Context, input *rolesanywhere.CreateSessionInput, opts ...request.Option) (*rolesanywhere.CreateSessionOutput, error) { + return m.CreateSessionOutput, m.CreateSessionError +} + +func TestGetX509Client(t *testing.T) { + tests := []struct { + name string + mockOutput *rolesanywhere.CreateSessionOutput + mockError error + }{ + { + name: "valid x509 client", + mockOutput: &rolesanywhere.CreateSessionOutput{ + CredentialSet: []*rolesanywhere.CredentialResponse{ + { + Credentials: &rolesanywhere.Credentials{ + AccessKeyId: aws.String("mockAccessKeyId"), + SecretAccessKey: aws.String("mockSecretAccessKey"), + SessionToken: aws.String("mockSessionToken"), + Expiration: aws.String(time.Now().Add(15 * time.Minute).Format(time.RFC3339)), + }, + }, + }, + }, + mockError: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockSvc := &mockRolesAnywhereClient{ + CreateSessionOutput: tt.mockOutput, + CreateSessionError: tt.mockError, + } + mockAWS := x509{ + logger: logger.NewLogger("testLogger"), + assumeRoleArn: aws.String("arn:aws:iam:012345678910:role/exampleIAMRoleName"), + trustAnchorArn: aws.String("arn:aws:rolesanywhere:us-west-1:012345678910:trust-anchor/01234568-0123-0123-0123-012345678901"), + trustProfileArn: aws.String("arn:aws:rolesanywhere:us-west-1:012345678910:profile/01234568-0123-0123-0123-012345678901"), + rolesAnywhereClient: mockSvc, + } + pki := test.GenPKI(t, test.PKIOptions{ + LeafID: spiffeid.RequireFromString("spiffe://example.com/foo/bar"), + }) + + respCert := []*cryptoX509.Certificate{pki.LeafCert} + var respErr error + + var fetches atomic.Int32 + s := spiffe.New(spiffe.Options{ + Log: logger.NewLogger("test"), + RequestSVIDFn: func(context.Context, []byte) ([]*cryptoX509.Certificate, error) { + fetches.Add(1) + return respCert, respErr + }, + }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + errCh := make(chan error) + go func() { + errCh <- s.Run(ctx) + }() + + select { + case err := <-errCh: + require.NoError(t, err) + default: + } + + err := s.Ready(ctx) + require.NoError(t, err) + + // inject the SVID source into the context + ctx = spiffecontext.With(ctx, s) + session, err := mockAWS.createOrRefreshSession(ctx) + + require.NoError(t, err) + assert.NotNil(t, session) + }) + } +} diff --git a/common/authentication/postgresql/metadata.go b/common/authentication/postgresql/metadata.go index 7cacecfaa4..4b2135ba6a 100644 --- a/common/authentication/postgresql/metadata.go +++ b/common/authentication/postgresql/metadata.go @@ -162,7 +162,7 @@ func (m *PostgresAuthMetadata) GetPgxPoolConfig() (*pgxpool.Config, error) { return nil, err } - awsOpts := aws.AWSIAMAuthOptions{ + awsOpts := aws.Options{ PoolConfig: config, ConnectionString: m.ConnectionString, Region: awsRegion, diff --git a/go.mod b/go.mod index 13a6af3ab6..a8ece053a6 100644 --- a/go.mod +++ b/go.mod @@ -46,6 +46,7 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.17.37 github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.3.10 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.17.3 + github.com/aws/rolesanywhere-credential-helper v1.0.4 github.com/bradfitz/gomemcache v0.0.0-20230905024940-24af94b03874 github.com/camunda/zeebe/clients/go/v8 v8.2.12 github.com/cenkalti/backoff/v4 v4.2.1 @@ -106,6 +107,7 @@ require ( github.com/sendgrid/sendgrid-go v3.13.0+incompatible github.com/sijms/go-ora/v2 v2.7.18 github.com/spf13/cast v1.5.1 + github.com/spiffe/go-spiffe/v2 v2.1.7 github.com/stealthrocket/wasi-go v0.8.1-0.20230912180546-8efbab50fb58 github.com/stretchr/testify v1.9.0 github.com/supplyon/gremcos v0.1.40 @@ -379,6 +381,7 @@ require ( github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect github.com/yuin/gopher-lua v1.1.0 // indirect github.com/yusufpapurcu/wmi v1.2.3 // indirect + github.com/zeebo/errs v1.3.0 // indirect go.etcd.io/etcd/api/v3 v3.5.9 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.9 // indirect go.opencensus.io v0.24.0 // indirect @@ -402,6 +405,7 @@ require ( google.golang.org/genproto v0.0.0-20240401170217-c3f982113cda // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240513163218-0867130af1f8 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240509183442-62759503f434 // indirect + google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect gopkg.in/couchbase/gocbcore.v7 v7.1.18 // indirect gopkg.in/couchbaselabs/gocbconnstr.v1 v1.0.4 // indirect diff --git a/go.sum b/go.sum index 54c28416d8..9bafb4a502 100644 --- a/go.sum +++ b/go.sum @@ -124,6 +124,8 @@ github.com/HdrHistogram/hdrhistogram-go v1.1.2/go.mod h1:yDgFjdqOqDEKOvasDdhWNXY github.com/IBM/sarama v1.43.3 h1:Yj6L2IaNvb2mRBop39N7mmJAHBVY3dTPncr3qGVkxPA= github.com/IBM/sarama v1.43.3/go.mod h1:FVIRaLrhK3Cla/9FfRF5X9Zua2KpS3SYIXxhac1H+FQ= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= +github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= +github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/Netflix/go-env v0.0.0-20220526054621-78278af1949d h1:wvStE9wLpws31NiWUx+38wny1msZ/tm+eL5xmm4Y7So= github.com/Netflix/go-env v0.0.0-20220526054621-78278af1949d/go.mod h1:9XMFaCeRyW7fC9XJOWQ+NdAv8VLG7ys7l3x4ozEGLUQ= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= @@ -289,6 +291,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3/go.mod h1:FnvDM4sfa+isJ3kDX github.com/aws/aws-sdk-go-v2/service/sts v1.7.2/go.mod h1:8EzeIqfWt2wWT4rJVu3f21TfrhJ8AEMzVybRNSb/b4g= github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 h1:VzudTFrDCIDakXtemR7l6Qzt2+JYsVqo2MxBPt5k8T8= github.com/aws/aws-sdk-go-v2/service/sts v1.31.3/go.mod h1:yMWe0F+XG0DkRZK5ODZhG7BEFYhLXi2dqGsv6tX0cgI= +github.com/aws/rolesanywhere-credential-helper v1.0.4 h1:kHIVVdyQQiFZoKBP+zywBdFilGCS8It+UvW5LolKbW8= +github.com/aws/rolesanywhere-credential-helper v1.0.4/go.mod h1:QVGNxlDlYhjR0/ZUee7uGl0hNChWidNpe2+GD87Buqk= github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= github.com/aws/smithy-go v1.21.0 h1:H7L8dtDRk0P1Qm6y0ji7MCYMQObJ5R9CRpyPhRUkLYA= github.com/aws/smithy-go v1.21.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= @@ -604,6 +608,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= +github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA= +github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.10.0 h1:dXFJfIHVvUcpSgDOV+Ne6t7jXri8Tfv2uOLHUZ2XNuo= @@ -1517,6 +1523,8 @@ github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5q github.com/spf13/viper v1.7.1/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/spf13/viper v1.15.0 h1:js3yy885G8xwJa6iOISGFwd+qlUo5AvyXb7CiihdtiU= github.com/spf13/viper v1.15.0/go.mod h1:fFcTBJxvhhzSJiZy8n+PeW6t8l+KeT/uTARa0jHOQLA= +github.com/spiffe/go-spiffe/v2 v2.1.7 h1:VUkM1yIyg/x8X7u1uXqSRVRCdMdfRIEdFBzpqoeASGk= +github.com/spiffe/go-spiffe/v2 v2.1.7/go.mod h1:QJDGdhXllxjxvd5B+2XnhhXB/+rC8gr+lNrtOryiWeE= github.com/stealthrocket/wasi-go v0.8.1-0.20230912180546-8efbab50fb58 h1:mTC4gyv3lcJ1XpzZMAckqkvWUqeT5Bva4RAT1IoHAAA= github.com/stealthrocket/wasi-go v0.8.1-0.20230912180546-8efbab50fb58/go.mod h1:ZAYCOqLJkc9P6fcq14TV4cf+gJ2fHthp9kCGxBViagE= github.com/stealthrocket/wazergo v0.19.1 h1:BPrITETPgSFwiytwmToO0MbUC/+RGC39JScz1JmmG6c= @@ -1652,6 +1660,8 @@ github.com/yuin/gopher-lua v1.1.0/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7 github.com/yusufpapurcu/wmi v1.2.2/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw= github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +github.com/zeebo/errs v1.3.0 h1:hmiaKqgYZzcVgRL1Vkc1Mn2914BbzB0IBxs+ebeutGs= +github.com/zeebo/errs v1.3.0/go.mod h1:sgbWHsvVuTPHcqJJGQ1WhI5KbWlHYz+2+2C/LSEtCw4= github.com/zouyx/agollo/v3 v3.4.5 h1:7YCxzY9ZYaH9TuVUBvmI6Tk0mwMggikah+cfbYogcHQ= github.com/zouyx/agollo/v3 v3.4.5/go.mod h1:LJr3kDmm23QSW+F1Ol4TMHDa7HvJvscMdVxJ2IpUTVc= go.einride.tech/aip v0.66.0 h1:XfV+NQX6L7EOYK11yoHHFtndeaWh3KbD9/cN/6iWEt8= @@ -2316,6 +2326,8 @@ google.golang.org/grpc v1.47.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACu google.golang.org/grpc v1.48.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= +google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20 h1:MLBCGN1O7GzIx+cBiwfYPwtmZ41U3Mn/cotLJciaArI= +google.golang.org/grpc/examples v0.0.0-20230224211313-3775f633ce20/go.mod h1:Nr5H8+MlGWr5+xX/STzdoEqJrO+YteqFbMyCsrb6mH0= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/pubsub/aws/snssqs/metadata.go b/pubsub/aws/snssqs/metadata.go index db45fb8d84..4b469106b7 100644 --- a/pubsub/aws/snssqs/metadata.go +++ b/pubsub/aws/snssqs/metadata.go @@ -67,7 +67,7 @@ func maskLeft(s string) string { return string(rs) } -func (s *snsSqs) getSnsSqsMetatdata(meta pubsub.Metadata) (*snsSqsMetadata, error) { +func (s *snsSqs) getSnsSqsMetadata(meta pubsub.Metadata) (*snsSqsMetadata, error) { md := &snsSqsMetadata{ AssetsManagementTimeoutSeconds: assetsManagementDefaultTimeoutSeconds, MessageVisibilityTimeout: 10, diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 357cfcabb9..93481fb733 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -49,9 +49,7 @@ type snsSqs struct { queues map[string]*sqsQueueInfo // key is a composite key of queue ARN and topic ARN mapping to subscription ARN. subscriptions map[string]string - snsClient *sns.SNS - sqsClient *sqs.SQS - stsClient *sts.STS + authProvider awsAuth.Provider metadata *snsSqsMetadata logger logger.Logger id string @@ -138,23 +136,33 @@ func nameToAWSSanitizedName(name string, isFifo bool) string { } func (s *snsSqs) Init(ctx context.Context, metadata pubsub.Metadata) error { - md, err := s.getSnsSqsMetatdata(metadata) + m, err := s.getSnsSqsMetadata(metadata) if err != nil { return err } - s.metadata = md + s.metadata = m - sess, err := awsAuth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint) - if err != nil { - return fmt.Errorf("error creating an AWS client: %w", err) + if s.authProvider == nil { + opts := awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: m.Region, + Endpoint: m.Endpoint, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: m.SessionToken, + } + // extra configs needed per component type + var provider awsAuth.Provider + provider, err = awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) + if err != nil { + return err + } + s.authProvider = provider } - // AWS sns,sqs,sts client. - s.snsClient = sns.New(sess) - s.sqsClient = sqs.New(sess) - s.stsClient = sts.New(sess) - s.opsTimeout = time.Duration(md.AssetsManagementTimeoutSeconds * float64(time.Second)) + s.opsTimeout = time.Duration(m.AssetsManagementTimeoutSeconds * float64(time.Second)) err = s.setAwsAccountIDIfNotProvided(ctx) if err != nil { @@ -181,9 +189,8 @@ func (s *snsSqs) setAwsAccountIDIfNotProvided(parentCtx context.Context) error { if len(s.metadata.AccountID) == awsAccountIDLength { return nil } - ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - callerIDOutput, err := s.stsClient.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) + callerIDOutput, err := s.authProvider.SnsSqs().Sts.GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) cancelFn() if err != nil { return fmt.Errorf("error fetching sts caller ID: %w", err) @@ -208,9 +215,8 @@ func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, e attributes := map[string]*string{"FifoTopic": aws.String("true"), "ContentBasedDeduplication": aws.String("true")} snsCreateTopicInput.SetAttributes(attributes) } - ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - createTopicResponse, err := s.snsClient.CreateTopicWithContext(ctx, snsCreateTopicInput) + createTopicResponse, err := s.authProvider.SnsSqs().Sns.CreateTopicWithContext(ctx, snsCreateTopicInput) cancelFn() if err != nil { return "", fmt.Errorf("error while creating an SNS topic: %w", err) @@ -222,7 +228,7 @@ func (s *snsSqs) createTopic(parentCtx context.Context, topic string) (string, e func (s *snsSqs) getTopicArn(parentCtx context.Context, topic string) (string, error) { ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) arn := s.buildARN("sns", topic) - getTopicOutput, err := s.snsClient.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{ + getTopicOutput, err := s.authProvider.SnsSqs().Sns.GetTopicAttributesWithContext(ctx, &sns.GetTopicAttributesInput{ TopicArn: &arn, }) cancelFn() @@ -288,15 +294,16 @@ func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQ attributes := map[string]*string{"FifoQueue": aws.String("true"), "ContentBasedDeduplication": aws.String("true")} sqsCreateQueueInput.SetAttributes(attributes) } + ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - createQueueResponse, err := s.sqsClient.CreateQueueWithContext(ctx, sqsCreateQueueInput) + createQueueResponse, err := s.authProvider.SnsSqs().Sqs.CreateQueueWithContext(ctx, sqsCreateQueueInput) cancel() if err != nil { return nil, fmt.Errorf("error creaing an SQS queue: %w", err) } ctx, cancel = context.WithTimeout(parentCtx, s.opsTimeout) - queueAttributesResponse, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ + queueAttributesResponse, err := s.authProvider.SnsSqs().Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ AttributeNames: []*string{aws.String("QueueArn")}, QueueUrl: createQueueResponse.QueueUrl, }) @@ -313,7 +320,7 @@ func (s *snsSqs) createQueue(parentCtx context.Context, queueName string) (*sqsQ func (s *snsSqs) getQueueArn(parentCtx context.Context, queueName string) (*sqsQueueInfo, error) { ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - queueURLOutput, err := s.sqsClient.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.AccountID)}) + queueURLOutput, err := s.authProvider.SnsSqs().Sqs.GetQueueUrlWithContext(ctx, &sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.AccountID)}) cancel() if err != nil { return nil, fmt.Errorf("error: %w while getting url of queue: %s", err, queueName) @@ -321,7 +328,7 @@ func (s *snsSqs) getQueueArn(parentCtx context.Context, queueName string) (*sqsQ url := queueURLOutput.QueueUrl ctx, cancel = context.WithTimeout(parentCtx, s.opsTimeout) - getQueueOutput, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}}) + getQueueOutput, err := s.authProvider.SnsSqs().Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}}) cancel() if err != nil { return nil, fmt.Errorf("error: %w while getting information for queue: %s, with url: %s", err, queueName, *url) @@ -382,7 +389,7 @@ func (s *snsSqs) getMessageGroupID(req *pubsub.PublishRequest) *string { func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, topicArn string) (string, error) { ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - subscribeOutput, err := s.snsClient.SubscribeWithContext(ctx, &sns.SubscribeInput{ + subscribeOutput, err := s.authProvider.SnsSqs().Sns.SubscribeWithContext(ctx, &sns.SubscribeInput{ Attributes: nil, Endpoint: aws.String(queueArn), // create SQS queue per subscription. Protocol: aws.String("sqs"), @@ -402,7 +409,7 @@ func (s *snsSqs) createSnsSqsSubscription(parentCtx context.Context, queueArn, t func (s *snsSqs) getSnsSqsSubscriptionArn(parentCtx context.Context, topicArn string) (string, error) { ctx, cancel := context.WithTimeout(parentCtx, s.opsTimeout) - listSubscriptionsOutput, err := s.snsClient.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)}) + listSubscriptionsOutput, err := s.authProvider.SnsSqs().Sns.ListSubscriptionsByTopicWithContext(ctx, &sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)}) cancel() if err != nil { return "", fmt.Errorf("error listing subsriptions for topic arn: %v: %w", topicArn, err) @@ -451,7 +458,7 @@ func (s *snsSqs) getOrCreateSnsSqsSubscription(ctx context.Context, queueArn, to func (s *snsSqs) acknowledgeMessage(parentCtx context.Context, queueURL string, receiptHandle *string) error { ctx, cancelFn := context.WithCancel(parentCtx) - _, err := s.sqsClient.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ + _, err := s.authProvider.SnsSqs().Sqs.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ QueueUrl: aws.String(queueURL), ReceiptHandle: receiptHandle, }) @@ -466,7 +473,7 @@ func (s *snsSqs) acknowledgeMessage(parentCtx context.Context, queueURL string, func (s *snsSqs) resetMessageVisibilityTimeout(parentCtx context.Context, queueURL string, receiptHandle *string) error { ctx, cancelFn := context.WithCancel(parentCtx) // reset the timeout to its initial value so that the remaining timeout would be overridden by the initial value for other consumer to attempt processing. - _, err := s.sqsClient.ChangeMessageVisibilityWithContext(ctx, &sqs.ChangeMessageVisibilityInput{ + _, err := s.authProvider.SnsSqs().Sqs.ChangeMessageVisibilityWithContext(ctx, &sqs.ChangeMessageVisibilityInput{ QueueUrl: aws.String(queueURL), ReceiptHandle: receiptHandle, VisibilityTimeout: aws.Int64(0), @@ -593,12 +600,11 @@ func (s *snsSqs) consumeSubscription(ctx context.Context, queueInfo, deadLetters if ctx.Err() != nil { break } - - // Internally, by default, aws go sdk performs 3 retires with exponential backoff to contact + // Internally, by default, aws go sdk performs 3 retries with exponential backoff to contact // sqs and try pull messages. Since we are iteratively short polling (based on the defined // s.metadata.messageWaitTimeSeconds) the sdk backoff is not effective as it gets reset per each polling // iteration. Therefore, a global backoff (to the internal backoff) is used (sqsPullExponentialBackoff). - messageResponse, err := s.sqsClient.ReceiveMessageWithContext(ctx, receiveMessageInput) + messageResponse, err := s.authProvider.SnsSqs().Sqs.ReceiveMessageWithContext(ctx, receiveMessageInput) if err != nil { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) || ctx.Err() != nil { s.logger.Warn("context canceled; stopping consuming from queue arn: %v", queueInfo.arn) @@ -690,9 +696,8 @@ func (s *snsSqs) setDeadLettersQueueAttributes(parentCtx context.Context, queueI return wrappedErr } - ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) - _, derr = s.sqsClient.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput) + _, derr = s.authProvider.SnsSqs().Sqs.SetQueueAttributesWithContext(ctx, sqsSetQueueAttributesInput) cancelFn() if derr != nil { wrappedErr := fmt.Errorf("error updating queue attributes with dead-letter queue: %w", derr) @@ -712,7 +717,7 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, ctx, cancelFn := context.WithTimeout(parentCtx, s.opsTimeout) // only permit SNS to send messages to SQS using the created subscription. - getQueueAttributesOutput, err := s.sqsClient.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ + getQueueAttributesOutput, err := s.authProvider.SnsSqs().Sqs.GetQueueAttributesWithContext(ctx, &sqs.GetQueueAttributesInput{ QueueUrl: &sqsQueueInfo.url, AttributeNames: []*string{aws.String(sqs.QueueAttributeNamePolicy)}, }) @@ -739,7 +744,7 @@ func (s *snsSqs) restrictQueuePublishPolicyToOnlySNS(parentCtx context.Context, } ctx, cancelFn = context.WithTimeout(parentCtx, s.opsTimeout) - _, err = s.sqsClient.SetQueueAttributesWithContext(ctx, &(sqs.SetQueueAttributesInput{ + _, err = s.authProvider.SnsSqs().Sqs.SetQueueAttributesWithContext(ctx, &(sqs.SetQueueAttributesInput{ Attributes: map[string]*string{ "Policy": aws.String(string(b)), }, @@ -852,7 +857,7 @@ func (s *snsSqs) Publish(ctx context.Context, req *pubsub.PublishRequest) error } // sns client has internal exponential backoffs. - _, err = s.snsClient.PublishWithContext(ctx, snsPublishInput) + _, err = s.authProvider.SnsSqs().Sns.PublishWithContext(ctx, snsPublishInput) if err != nil { wrappedErr := fmt.Errorf("error publishing to topic: %s with topic ARN %s: %w", req.Topic, topicArn, err) s.logger.Error(wrappedErr) @@ -870,7 +875,7 @@ func (s *snsSqs) Close() error { s.subscriptionManager.Close() } - return nil + return s.authProvider.Close() } func (s *snsSqs) Features() []pubsub.Feature { diff --git a/pubsub/aws/snssqs/snssqs_test.go b/pubsub/aws/snssqs/snssqs_test.go index 1c789b67be..f396ddef47 100644 --- a/pubsub/aws/snssqs/snssqs_test.go +++ b/pubsub/aws/snssqs/snssqs_test.go @@ -38,7 +38,7 @@ func Test_parseTopicArn(t *testing.T) { } // Verify that all metadata ends up in the correct spot. -func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) { +func Test_getSnsSqsMetadata_AllConfiguration(t *testing.T) { t.Parallel() r := require.New(t) l := logger.NewLogger("SnsSqs unit test") @@ -47,7 +47,7 @@ func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "consumer", "Endpoint": "endpoint", "concurrencyMode": string(pubsub.Single), @@ -80,7 +80,7 @@ func Test_getSnsSqsMetatdata_AllConfiguration(t *testing.T) { r.Equal(int64(6), md.MessageReceiveLimit) } -func Test_getSnsSqsMetatdata_defaults(t *testing.T) { +func Test_getSnsSqsMetadata_defaults(t *testing.T) { t.Parallel() r := require.New(t) l := logger.NewLogger("SnsSqs unit test") @@ -89,7 +89,7 @@ func Test_getSnsSqsMetatdata_defaults(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", @@ -114,7 +114,7 @@ func Test_getSnsSqsMetatdata_defaults(t *testing.T) { r.False(md.DisableDeleteOnRetryLimit) } -func Test_getSnsSqsMetatdata_legacyaliases(t *testing.T) { +func Test_getSnsSqsMetadata_legacyaliases(t *testing.T) { t.Parallel() r := require.New(t) l := logger.NewLogger("SnsSqs unit test") @@ -123,7 +123,7 @@ func Test_getSnsSqsMetatdata_legacyaliases(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "consumer", "awsAccountID": "acctId", "awsSecret": "secret", @@ -151,13 +151,13 @@ func testMetadataParsingShouldFail(t *testing.T, metadata pubsub.Metadata, l log logger: l, } - md, err := ps.getSnsSqsMetatdata(metadata) + md, err := ps.getSnsSqsMetadata(metadata) r.Error(err) r.Nil(md) } -func Test_getSnsSqsMetatdata_invalidMetadataSetup(t *testing.T) { +func Test_getSnsSqsMetadata_invalidMetadataSetup(t *testing.T) { t.Parallel() fixtures := []testUnitFixture{ @@ -432,7 +432,7 @@ func Test_buildARN_DefaultPartition(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", @@ -455,7 +455,7 @@ func Test_buildARN_StandardPartition(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", @@ -478,7 +478,7 @@ func Test_buildARN_NonStandardPartition(t *testing.T) { logger: l, } - md, err := ps.getSnsSqsMetatdata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ + md, err := ps.getSnsSqsMetadata(pubsub.Metadata{Base: metadata.Base{Properties: map[string]string{ "consumerID": "c", "accessKey": "a", "secretKey": "s", diff --git a/secretstores/aws/parameterstore/parameterstore.go b/secretstores/aws/parameterstore/parameterstore.go index 1f82031ba8..abf9c6c4de 100644 --- a/secretstores/aws/parameterstore/parameterstore.go +++ b/secretstores/aws/parameterstore/parameterstore.go @@ -20,7 +20,6 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" "github.com/dapr/components-contrib/metadata" @@ -53,23 +52,33 @@ type ParameterStoreMetaData struct { } type ssmSecretStore struct { - client ssmiface.SSMAPI - prefix string - logger logger.Logger + authProvider awsAuth.Provider + prefix string + logger logger.Logger } // Init creates an AWS secret manager client. func (s *ssmSecretStore) Init(ctx context.Context, metadata secretstores.Metadata) error { - meta, err := s.getSecretManagerMetadata(metadata) + m, err := s.getSecretManagerMetadata(metadata) if err != nil { return err } - s.client, err = s.getClient(meta) + opts := awsAuth.Options{ + Logger: s.logger, + Properties: metadata.Properties, + Region: m.Region, + AccessKey: m.AccessKey, + SecretKey: m.SecretKey, + SessionToken: "", + } + // extra configs needed per component type + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - s.prefix = meta.Prefix + s.authProvider = provider + s.prefix = m.Prefix return nil } @@ -84,7 +93,7 @@ func (s *ssmSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecr name = fmt.Sprintf("%s:%s", req.Name, versionID) } - output, err := s.client.GetParameterWithContext(ctx, &ssm.GetParameterInput{ + output, err := s.authProvider.ParameterStore().Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ Name: ptr.Of(s.prefix + name), WithDecryption: ptr.Of(true), }) @@ -124,7 +133,7 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul } for search { - output, err := s.client.DescribeParametersWithContext(ctx, &ssm.DescribeParametersInput{ + output, err := s.authProvider.ParameterStore().Store.DescribeParametersWithContext(ctx, &ssm.DescribeParametersInput{ MaxResults: nil, NextToken: nextToken, ParameterFilters: filters, @@ -134,7 +143,7 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul } for _, entry := range output.Parameters { - params, err := s.client.GetParameterWithContext(ctx, &ssm.GetParameterInput{ + params, err := s.authProvider.ParameterStore().Store.GetParameterWithContext(ctx, &ssm.GetParameterInput{ Name: entry.Name, WithDecryption: aws.Bool(true), }) @@ -155,15 +164,6 @@ func (s *ssmSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bul return resp, nil } -func (s *ssmSecretStore) getClient(metadata *ParameterStoreMetaData) (*ssm.SSM, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, "") - if err != nil { - return nil, err - } - - return ssm.New(sess), nil -} - func (s *ssmSecretStore) getSecretManagerMetadata(spec secretstores.Metadata) (*ParameterStoreMetaData, error) { meta := ParameterStoreMetaData{} err := kitmd.DecodeMetadata(spec.Properties, &meta) @@ -182,5 +182,5 @@ func (s *ssmSecretStore) GetComponentMetadata() (metadataInfo metadata.MetadataM } func (s *ssmSecretStore) Close() error { - return nil + return s.authProvider.Close() } diff --git a/secretstores/aws/parameterstore/parameterstore_test.go b/secretstores/aws/parameterstore/parameterstore_test.go index 8d9bcf6065..04c7a6995e 100644 --- a/secretstores/aws/parameterstore/parameterstore_test.go +++ b/secretstores/aws/parameterstore/parameterstore_test.go @@ -22,9 +22,11 @@ import ( "testing" "github.com/aws/aws-sdk-go/aws" + + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -34,20 +36,6 @@ import ( const secretValue = "secret" -type mockedSSM struct { - GetParameterFn func(context.Context, *ssm.GetParameterInput, ...request.Option) (*ssm.GetParameterOutput, error) - DescribeParametersFn func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) - ssmiface.SSMAPI -} - -func (m *mockedSSM) GetParameterWithContext(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - return m.GetParameterFn(ctx, input, option...) -} - -func (m *mockedSSM) DescribeParametersWithContext(ctx context.Context, input *ssm.DescribeParametersInput, option ...request.Option) (*ssm.DescribeParametersOutput, error) { - return m.DescribeParametersFn(ctx, input, option...) -} - func TestInit(t *testing.T) { m := secretstores.Metadata{} s := NewParameterStore(logger.NewLogger("test")) @@ -68,21 +56,32 @@ func TestInit(t *testing.T) { func TestGetSecret(t *testing.T) { t.Run("successfully retrieve secret", func(t *testing.T) { t.Run("with valid path", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := secretValue - - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + mockSSM := &awsAuth.MockParameterStore{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := secretValue + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, } + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/dev/secret", Metadata: map[string]string{}, @@ -93,25 +92,36 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version id", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := secretValue - keys := strings.Split(*input.Name, ":") - assert.NotNil(t, keys) - assert.Len(t, keys, 2) - assert.Equalf(t, "1", keys[1], "Version IDs are same") - - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: &keys[0], - Value: &secret, - }, - }, nil - }, + mockSSM := &awsAuth.MockParameterStore{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := secretValue + keys := strings.Split(*input.Name, ":") + assert.NotNil(t, keys) + assert.Len(t, keys, 2) + assert.Equalf(t, "1", keys[1], "Version IDs are same") + + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: &keys[0], + Value: &secret, + }, + }, nil }, } + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/dev/secret", Metadata: map[string]string{ @@ -124,21 +134,33 @@ func TestGetSecret(t *testing.T) { }) t.Run("with prefix", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - assert.Equal(t, "/prefix/aws/dev/secret", *input.Name) - secret := secretValue - - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + mockSSM := &awsAuth.MockParameterStore{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + assert.Equal(t, "/prefix/aws/dev/secret", *input.Name) + secret := secretValue + + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, - prefix: "/prefix", + } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + + s := ssmSecretStore{ + authProvider: mockAuthProvider, + prefix: "/prefix", } req := secretstores.GetSecretRequest{ @@ -152,13 +174,27 @@ func TestGetSecret(t *testing.T) { }) t.Run("unsuccessfully retrieve secret", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &awsAuth.MockParameterStore{ + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + return nil, errors.New("failed due to any reason") }, } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + + s := ssmSecretStore{ + authProvider: mockAuthProvider, + prefix: "/prefix", + } + req := secretstores.GetSecretRequest{ Name: "/aws/dev/secret", Metadata: map[string]string{}, @@ -170,31 +206,42 @@ func TestGetSecret(t *testing.T) { func TestGetBulkSecrets(t *testing.T) { t.Run("successfully retrieve bulk secrets", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ - { - Name: aws.String("/aws/dev/secret1"), - }, - { - Name: aws.String("/aws/dev/secret2"), - }, - }}, nil - }, - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) + mockSSM := &awsAuth.MockParameterStore{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ + { + Name: aws.String("/aws/dev/secret1"), + }, + { + Name: aws.String("/aws/dev/secret2"), + }, + }}, nil + }, + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, } + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.BulkGetSecretRequest{ Metadata: map[string]string{}, } @@ -205,30 +252,41 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("successfully retrieve bulk secrets with prefix", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ - { - Name: aws.String("/prefix/aws/dev/secret1"), - }, - { - Name: aws.String("/prefix/aws/dev/secret2"), - }, - }}, nil - }, - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) + mockSSM := &awsAuth.MockParameterStore{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ + { + Name: aws.String("/prefix/aws/dev/secret1"), + }, + { + Name: aws.String("/prefix/aws/dev/secret2"), + }, + }}, nil + }, + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + secret := fmt.Sprintf("%s-%s", *input.Name, secretValue) - return &ssm.GetParameterOutput{ - Parameter: &ssm.Parameter{ - Name: input.Name, - Value: &secret, - }, - }, nil - }, + return &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: input.Name, + Value: &secret, + }, + }, nil }, - prefix: "/prefix", + } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + prefix: "/prefix", } req := secretstores.BulkGetSecretRequest{ @@ -241,23 +299,35 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("unsuccessfully retrieve bulk secrets on get parameter", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ - { - Name: aws.String("/aws/dev/secret1"), - }, - { - Name: aws.String("/aws/dev/secret2"), - }, - }}, nil - }, - GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &awsAuth.MockParameterStore{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return &ssm.DescribeParametersOutput{NextToken: nil, Parameters: []*ssm.ParameterMetadata{ + { + Name: aws.String("/aws/dev/secret1"), + }, + { + Name: aws.String("/aws/dev/secret2"), + }, + }}, nil }, + GetParameterFn: func(ctx context.Context, input *ssm.GetParameterInput, option ...request.Option) (*ssm.GetParameterOutput, error) { + return nil, errors.New("failed due to any reason") + }, + } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.BulkGetSecretRequest{ Metadata: map[string]string{}, } @@ -266,13 +336,25 @@ func TestGetBulkSecrets(t *testing.T) { }) t.Run("unsuccessfully retrieve bulk secrets on describe parameter", func(t *testing.T) { - s := ssmSecretStore{ - client: &mockedSSM{ - DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &awsAuth.MockParameterStore{ + DescribeParametersFn: func(context.Context, *ssm.DescribeParametersInput, ...request.Option) (*ssm.DescribeParametersOutput, error) { + return nil, errors.New("failed due to any reason") }, } + + paramStore := awsAuth.ParameterStoreClients{ + Store: mockSSM, + } + + mockedClients := awsAuth.Clients{ + ParameterStore: ¶mStore, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := ssmSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.BulkGetSecretRequest{ Metadata: map[string]string{}, } diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 54ed329d35..6faf1f1eab 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -20,7 +20,6 @@ import ( "reflect" "github.com/aws/aws-sdk-go/service/secretsmanager" - "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" "github.com/dapr/components-contrib/metadata" @@ -49,8 +48,8 @@ type SecretManagerMetaData struct { } type smSecretStore struct { - client secretsmanageriface.SecretsManagerAPI - logger logger.Logger + authProvider awsAuth.Provider + logger logger.Logger } // Init creates an AWS secret manager client. @@ -60,11 +59,20 @@ func (s *smSecretStore) Init(ctx context.Context, metadata secretstores.Metadata return err } - s.client, err = s.getClient(meta) + opts := awsAuth.Options{ + Logger: s.logger, + Region: meta.Region, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + Endpoint: meta.Endpoint, + } + + provider, err := awsAuth.NewProvider(ctx, opts, awsAuth.GetConfig(opts)) if err != nil { return err } - + s.authProvider = provider return nil } @@ -78,8 +86,7 @@ func (s *smSecretStore) GetSecret(ctx context.Context, req secretstores.GetSecre if value, ok := req.Metadata[VersionStage]; ok { versionStage = &value } - - output, err := s.client.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + output, err := s.authProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: &req.Name, VersionId: versionID, VersionStage: versionStage, @@ -108,7 +115,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk var nextToken *string = nil for search { - output, err := s.client.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ + output, err := s.authProvider.SecretManager().Manager.ListSecretsWithContext(ctx, &secretsmanager.ListSecretsInput{ MaxResults: nil, NextToken: nextToken, }) @@ -117,7 +124,7 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk } for _, entry := range output.SecretList { - secrets, err := s.client.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ + secrets, err := s.authProvider.SecretManager().Manager.GetSecretValueWithContext(ctx, &secretsmanager.GetSecretValueInput{ SecretId: entry.Name, }) if err != nil { @@ -136,15 +143,6 @@ func (s *smSecretStore) BulkGetSecret(ctx context.Context, req secretstores.Bulk return resp, nil } -func (s *smSecretStore) getClient(metadata *SecretManagerMetaData) (*secretsmanager.SecretsManager, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - - return secretsmanager.New(sess), nil -} - func (s *smSecretStore) getSecretManagerMetadata(spec secretstores.Metadata) (*SecretManagerMetaData, error) { b, err := json.Marshal(spec.Properties) if err != nil { @@ -172,5 +170,5 @@ func (s *smSecretStore) GetComponentMetadata() (metadataInfo metadata.MetadataMa } func (s *smSecretStore) Close() error { - return nil + return s.authProvider.Close() } diff --git a/secretstores/aws/secretmanager/secretmanager_test.go b/secretstores/aws/secretmanager/secretmanager_test.go index 85918237a3..7fbd8493af 100644 --- a/secretstores/aws/secretmanager/secretmanager_test.go +++ b/secretstores/aws/secretmanager/secretmanager_test.go @@ -21,25 +21,17 @@ import ( "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/secretsmanager" - "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/dapr/components-contrib/secretstores" "github.com/dapr/kit/logger" ) const secretValue = "secret" -type mockedSM struct { - GetSecretValueFn func(context.Context, *secretsmanager.GetSecretValueInput, ...request.Option) (*secretsmanager.GetSecretValueOutput, error) - secretsmanageriface.SecretsManagerAPI -} - -func (m *mockedSM) GetSecretValueWithContext(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - return m.GetSecretValueFn(ctx, input, option...) -} - func TestInit(t *testing.T) { m := secretstores.Metadata{} s := NewSecretManager(logger.NewLogger("test")) @@ -60,21 +52,32 @@ func TestInit(t *testing.T) { func TestGetSecret(t *testing.T) { t.Run("successfully retrieve secret", func(t *testing.T) { t.Run("without version id and version stage", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - assert.Nil(t, input.VersionId) - assert.Nil(t, input.VersionStage) - secret := secretValue - - return &secretsmanager.GetSecretValueOutput{ - Name: input.SecretId, - SecretString: &secret, - }, nil - }, + mockSSM := &awsAuth.MockSecretManager{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + assert.Nil(t, input.VersionId) + assert.Nil(t, input.VersionStage) + secret := secretValue + + return &secretsmanager.GetSecretValueOutput{ + Name: input.SecretId, + SecretString: &secret, + }, nil }, } + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{}, @@ -85,20 +88,32 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version id", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - assert.NotNil(t, input.VersionId) - secret := secretValue - - return &secretsmanager.GetSecretValueOutput{ - Name: input.SecretId, - SecretString: &secret, - }, nil - }, + mockSSM := &awsAuth.MockSecretManager{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + assert.NotNil(t, input.VersionId) + secret := secretValue + + return &secretsmanager.GetSecretValueOutput{ + Name: input.SecretId, + SecretString: &secret, + }, nil }, } + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{ @@ -111,20 +126,32 @@ func TestGetSecret(t *testing.T) { }) t.Run("with version stage", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - assert.NotNil(t, input.VersionStage) - secret := secretValue - - return &secretsmanager.GetSecretValueOutput{ - Name: input.SecretId, - SecretString: &secret, - }, nil - }, + mockSSM := &awsAuth.MockSecretManager{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + assert.NotNil(t, input.VersionStage) + secret := secretValue + + return &secretsmanager.GetSecretValueOutput{ + Name: input.SecretId, + SecretString: &secret, + }, nil }, } + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{ @@ -138,13 +165,26 @@ func TestGetSecret(t *testing.T) { }) t.Run("unsuccessfully retrieve secret", func(t *testing.T) { - s := smSecretStore{ - client: &mockedSM{ - GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { - return nil, errors.New("failed due to any reason") - }, + mockSSM := &awsAuth.MockSecretManager{ + GetSecretValueFn: func(ctx context.Context, input *secretsmanager.GetSecretValueInput, option ...request.Option) (*secretsmanager.GetSecretValueOutput, error) { + return nil, errors.New("failed due to any reason") }, } + + secret := awsAuth.SecretManagerClients{ + Manager: mockSSM, + } + + mockedClients := awsAuth.Clients{ + Secret: &secret, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := smSecretStore{ + authProvider: mockAuthProvider, + } + req := secretstores.GetSecretRequest{ Name: "/aws/secret/testing", Metadata: map[string]string{}, diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index 503d7082c7..ae4ba7c5e9 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -25,7 +25,6 @@ import ( "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" jsoniterator "github.com/json-iterator/go" awsAuth "github.com/dapr/components-contrib/common/authentication/aws" @@ -41,7 +40,8 @@ import ( type StateStore struct { state.BulkStore - client dynamodbiface.DynamoDBAPI + authProvider awsAuth.Provider + logger logger.Logger table string ttlAttributeName string partitionKey string @@ -66,9 +66,10 @@ const ( ) // NewDynamoDBStateStore returns a new dynamoDB state store. -func NewDynamoDBStateStore(_ logger.Logger) state.Store { +func NewDynamoDBStateStore(logger logger.Logger) state.Store { s := &StateStore{ partitionKey: defaultPartitionKeyName, + logger: logger, } s.BulkStore = state.NewDefaultBulkStore(s) return s @@ -80,14 +81,24 @@ func (d *StateStore) Init(ctx context.Context, metadata state.Metadata) error { if err != nil { return err } - - // This check is needed because d.client is set to a mock in tests - if d.client == nil { - d.client, err = d.getClient(meta) + if d.authProvider == nil { + opts := awsAuth.Options{ + Logger: d.logger, + Properties: metadata.Properties, + Region: meta.Region, + Endpoint: meta.Endpoint, + AccessKey: meta.AccessKey, + SecretKey: meta.SecretKey, + SessionToken: meta.SessionToken, + } + cfg := awsAuth.GetConfig(opts) + provider, err := awsAuth.NewProvider(ctx, opts, cfg) if err != nil { return err } + d.authProvider = provider } + d.table = meta.Table d.ttlAttributeName = meta.TTLAttributeName d.partitionKey = meta.PartitionKey @@ -111,8 +122,7 @@ func (d *StateStore) validateTableAccess(ctx context.Context) error { }, }, } - - _, err := d.client.GetItemWithContext(ctx, input) + _, err := d.authProvider.DynamoDB().DynamoDB.GetItemWithContext(ctx, input) return err } @@ -144,8 +154,7 @@ func (d *StateStore) Get(ctx context.Context, req *state.GetRequest) (*state.Get }, }, } - - result, err := d.client.GetItemWithContext(ctx, input) + result, err := d.authProvider.DynamoDB().DynamoDB.GetItemWithContext(ctx, input) if err != nil { return nil, err } @@ -217,8 +226,7 @@ func (d *StateStore) Set(ctx context.Context, req *state.SetRequest) error { condExpr := "attribute_not_exists(etag)" input.ConditionExpression = &condExpr } - - _, err = d.client.PutItemWithContext(ctx, input) + _, err = d.authProvider.DynamoDB().DynamoDB.PutItemWithContext(ctx, input) if err != nil && req.HasETag() { switch cErr := err.(type) { case *dynamodb.ConditionalCheckFailedException: @@ -249,8 +257,7 @@ func (d *StateStore) Delete(ctx context.Context, req *state.DeleteRequest) error } input.ExpressionAttributeValues = exprAttrValues } - - _, err := d.client.DeleteItemWithContext(ctx, input) + _, err := d.authProvider.DynamoDB().DynamoDB.DeleteItemWithContext(ctx, input) if err != nil { switch cErr := err.(type) { case *dynamodb.ConditionalCheckFailedException: @@ -268,7 +275,7 @@ func (d *StateStore) GetComponentMetadata() (metadataInfo metadata.MetadataMap) } func (d *StateStore) Close() error { - return nil + return d.authProvider.Close() } func (d *StateStore) getDynamoDBMetadata(meta state.Metadata) (*dynamoDBMetadata, error) { @@ -281,16 +288,6 @@ func (d *StateStore) getDynamoDBMetadata(meta state.Metadata) (*dynamoDBMetadata return &m, err } -func (d *StateStore) getClient(metadata *dynamoDBMetadata) (*dynamodb.DynamoDB, error) { - sess, err := awsAuth.GetClient(metadata.AccessKey, metadata.SecretKey, metadata.SessionToken, metadata.Region, metadata.Endpoint) - if err != nil { - return nil, err - } - c := dynamodb.New(sess) - - return c, nil -} - // getItemFromReq converts a dapr state.SetRequest into an dynamodb item func (d *StateStore) getItemFromReq(req *state.SetRequest) (map[string]*dynamodb.AttributeValue, error) { value, err := d.marshalToString(req.Value) @@ -431,8 +428,7 @@ func (d *StateStore) Multi(ctx context.Context, request *state.TransactionalStat } twinput.TransactItems = append(twinput.TransactItems, twi) } - - _, err := d.client.TransactWriteItemsWithContext(ctx, twinput) + _, err := d.authProvider.DynamoDB().DynamoDB.TransactWriteItemsWithContext(ctx, twinput) return err } diff --git a/state/aws/dynamodb/dynamodb_test.go b/state/aws/dynamodb/dynamodb_test.go index d1f98b70ba..7b667b6c78 100644 --- a/state/aws/dynamodb/dynamodb_test.go +++ b/state/aws/dynamodb/dynamodb_test.go @@ -21,26 +21,18 @@ import ( "testing" "time" + awsAuth "github.com/dapr/components-contrib/common/authentication/aws" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute" - "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/dapr/components-contrib/state" ) -type mockedDynamoDB struct { - GetItemWithContextFn func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) - PutItemWithContextFn func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) - DeleteItemWithContextFn func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) - BatchWriteItemWithContextFn func(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) - TransactWriteItemsWithContextFn func(aws.Context, *dynamodb.TransactWriteItemsInput, ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) - dynamodbiface.DynamoDBAPI -} - type DynamoDBItem struct { Key string `json:"key"` Value string `json:"value"` @@ -52,36 +44,28 @@ const ( pkey = "partitionKey" ) -func (m *mockedDynamoDB) GetItemWithContext(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { - return m.GetItemWithContextFn(ctx, input, op...) -} - -func (m *mockedDynamoDB) PutItemWithContext(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (*dynamodb.PutItemOutput, error) { - return m.PutItemWithContextFn(ctx, input, op...) -} - -func (m *mockedDynamoDB) DeleteItemWithContext(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (*dynamodb.DeleteItemOutput, error) { - return m.DeleteItemWithContextFn(ctx, input, op...) -} +func TestInit(t *testing.T) { + m := state.Metadata{} + mockedDB := &awsAuth.MockDynamoDB{ + // We're adding this so we can pass the connection check on Init + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { + return nil, nil + }, + } -func (m *mockedDynamoDB) BatchWriteItemWithContext(ctx context.Context, input *dynamodb.BatchWriteItemInput, op ...request.Option) (*dynamodb.BatchWriteItemOutput, error) { - return m.BatchWriteItemWithContextFn(ctx, input, op...) -} + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } -func (m *mockedDynamoDB) TransactWriteItemsWithContext(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) { - return m.TransactWriteItemsWithContextFn(ctx, input, op...) -} + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } -func TestInit(t *testing.T) { - m := state.Metadata{} - s := &StateStore{ + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, partitionKey: defaultPartitionKeyName, - client: &mockedDynamoDB{ - // We're adding this so we can pass the connection check on Init - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { - return nil, nil - }, - }, } t.Run("NewDynamoDBStateStore Default Partition Key", func(t *testing.T) { @@ -132,16 +116,29 @@ func TestInit(t *testing.T) { }) t.Run("Init with bad table name or permissions", func(t *testing.T) { + table := "does-not-exist" m.Properties = map[string]string{ - "Table": "does-not-exist", - "Region": "eu-west-1", + "Table": table, } - s.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (*dynamodb.GetItemOutput, error) { return nil, errors.New("Requested resource not found") }, } + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + table: table, + } err := s.Init(context.Background(), m) require.Error(t, err) @@ -151,10 +148,7 @@ func TestInit(t *testing.T) { func TestGet(t *testing.T) { t.Run("Successfully retrieve item", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { return &dynamodb.GetItemOutput{ Item: map[string]*dynamodb.AttributeValue{ @@ -172,6 +166,20 @@ func TestGet(t *testing.T) { }, } + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.GetRequest{ Key: "someKey", Metadata: nil, @@ -179,34 +187,46 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Equal(t, []byte("some value"), out.Data) assert.Equal(t, "1bdead4badc0ffee", *out.ETag) assert.NotContains(t, out.Metadata, "ttlExpireTime") }) t.Run("Successfully retrieve item (with unexpired ttl)", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("someKey"), - }, - "value": { - S: aws.String("some value"), - }, - "testAttributeName": { - N: aws.String("4074862051"), - }, - "etag": { - S: aws.String("1bdead4badc0ffee"), - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("someKey"), }, - }, nil - }, + "value": { + S: aws.String("some value"), + }, + "testAttributeName": { + N: aws.String("4074862051"), + }, + "etag": { + S: aws.String("1bdead4badc0ffee"), + }, + }, + }, nil }, + } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", } req := &state.GetRequest{ @@ -216,7 +236,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Equal(t, []byte("some value"), out.Data) assert.Equal(t, "1bdead4badc0ffee", *out.ETag) @@ -226,27 +246,39 @@ func TestGet(t *testing.T) { assert.Equal(t, int64(4074862051), expireTime.Unix()) }) t.Run("Successfully retrieve item (with expired ttl)", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("someKey"), - }, - "value": { - S: aws.String("some value"), - }, - "testAttributeName": { - N: aws.String("35489251"), - }, - "etag": { - S: aws.String("1bdead4badc0ffee"), - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("someKey"), + }, + "value": { + S: aws.String("some value"), + }, + "testAttributeName": { + N: aws.String("35489251"), }, - }, nil - }, + "etag": { + S: aws.String("1bdead4badc0ffee"), + }, + }, + }, nil }, + } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", } req := &state.GetRequest{ @@ -256,20 +288,33 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Nil(t, out.Data) assert.Nil(t, out.ETag) assert.Nil(t, out.Metadata) }) t.Run("Unsuccessfully get item", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return nil, errors.New("failed to retrieve data") - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return nil, errors.New("failed to retrieve data") }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + } + req := &state.GetRequest{ Key: "key", Metadata: nil, @@ -277,20 +322,32 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.Error(t, err) assert.Nil(t, out) }) t.Run("Unsuccessfully with empty response", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{}, - }, nil - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{}, + }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + } req := &state.GetRequest{ Key: "key", Metadata: nil, @@ -298,26 +355,38 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Nil(t, out.Data) assert.Nil(t, out.ETag) assert.Nil(t, out.Metadata) }) t.Run("Unsuccessfully with no required key", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { - return &dynamodb.GetItemOutput{ - Item: map[string]*dynamodb.AttributeValue{ - "value2": { - S: aws.String("value"), - }, + mockedDB := &awsAuth.MockDynamoDB{ + GetItemWithContextFn: func(ctx context.Context, input *dynamodb.GetItemInput, op ...request.Option) (output *dynamodb.GetItemOutput, err error) { + return &dynamodb.GetItemOutput{ + Item: map[string]*dynamodb.AttributeValue{ + "value2": { + S: aws.String("value"), }, - }, nil - }, + }, + }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + } req := &state.GetRequest{ Key: "key", Metadata: nil, @@ -325,7 +394,7 @@ func TestGet(t *testing.T) { Consistency: "strong", }, } - out, err := ss.Get(context.Background(), req) + out, err := s.Get(context.Background(), req) require.NoError(t, err) assert.Empty(t, out.Data) assert.Nil(t, out.ETag) @@ -338,10 +407,7 @@ func TestSet(t *testing.T) { } t.Run("Successfully set item", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -360,21 +426,34 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + req := &state.SetRequest{ Key: "key", Value: value{ Value: "value", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Successfully set item with matching etag", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -397,6 +476,21 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } etag := "1bdead4badc0ffee" req := &state.SetRequest{ ETag: &etag, @@ -405,15 +499,12 @@ func TestSet(t *testing.T) { Value: "value", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item with mismatched etag", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -431,6 +522,21 @@ func TestSet(t *testing.T) { return nil, &checkErr }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } etag := "bogusetag" req := &state.SetRequest{ ETag: &etag, @@ -440,7 +546,7 @@ func TestSet(t *testing.T) { }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) switch tagErr := err.(type) { case *state.ETagError: @@ -451,10 +557,7 @@ func TestSet(t *testing.T) { }) t.Run("Successfully set item with first-write-concurrency", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -474,6 +577,21 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "key", Value: value{ @@ -483,15 +601,12 @@ func TestSet(t *testing.T) { Concurrency: state.FirstWrite, }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item with first-write-concurrency", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("key"), @@ -506,6 +621,21 @@ func TestSet(t *testing.T) { return nil, &checkErr }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "key", Value: value{ @@ -515,7 +645,7 @@ func TestSet(t *testing.T) { Concurrency: state.FirstWrite, }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) switch err.(type) { case *state.ETagError: @@ -525,10 +655,7 @@ func TestSet(t *testing.T) { }) t.Run("Successfully set item with ttl = -1", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Len(t, input.Item, 4) result := DynamoDBItem{} @@ -547,7 +674,22 @@ func TestSet(t *testing.T) { }, nil }, } - ss.ttlAttributeName = "testAttributeName" + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + ttlAttributeName: "testAttributeName", + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "someKey", @@ -558,14 +700,11 @@ func TestSet(t *testing.T) { "ttlInSeconds": "-1", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Successfully set item with 'correct' ttl", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Len(t, input.Item, 4) result := DynamoDBItem{} @@ -584,7 +723,22 @@ func TestSet(t *testing.T) { }, nil }, } - ss.ttlAttributeName = "testAttributeName" + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + ttlAttributeName: "testAttributeName", + } req := &state.SetRequest{ Key: "someKey", @@ -595,33 +749,42 @@ func TestSet(t *testing.T) { "ttlInSeconds": "180", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { return nil, errors.New("unable to put item") }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "key", Value: value{ Value: "value", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) }) t.Run("Successfully set item with correct ttl but without component metadata", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { assert.Equal(t, dynamodb.AttributeValue{ S: aws.String("someKey"), @@ -640,6 +803,21 @@ func TestSet(t *testing.T) { }, nil }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } req := &state.SetRequest{ Key: "someKey", Value: value{ @@ -649,34 +827,46 @@ func TestSet(t *testing.T) { "ttlInSeconds": "180", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.NoError(t, err) }) t.Run("Unsuccessfully set item with ttl (invalid value)", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { - assert.Equal(t, map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("somekey"), - }, - "value": { - S: aws.String(`{"Value":"somevalue"}`), - }, - "ttlInSeconds": { - N: aws.String("180"), - }, - }, input.Item) + mockedDB := &awsAuth.MockDynamoDB{ + PutItemWithContextFn: func(ctx context.Context, input *dynamodb.PutItemInput, op ...request.Option) (output *dynamodb.PutItemOutput, err error) { + assert.Equal(t, map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("somekey"), + }, + "value": { + S: aws.String(`{"Value":"somevalue"}`), + }, + "ttlInSeconds": { + N: aws.String("180"), + }, + }, input.Item) - return &dynamodb.PutItemOutput{ - Attributes: map[string]*dynamodb.AttributeValue{ - "key": { - S: aws.String("value"), - }, + return &dynamodb.PutItemOutput{ + Attributes: map[string]*dynamodb.AttributeValue{ + "key": { + S: aws.String("value"), }, - }, nil - }, + }, + }, nil }, + } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, ttlAttributeName: "testAttributeName", } req := &state.SetRequest{ @@ -688,7 +878,7 @@ func TestSet(t *testing.T) { "ttlInSeconds": "invalidvalue", }, } - err := ss.Set(context.Background(), req) + err := s.Set(context.Background(), req) require.Error(t, err) assert.Equal(t, "dynamodb error: failed to parse ttlInSeconds: strconv.ParseInt: parsing \"invalidvalue\": invalid syntax", err.Error()) }) @@ -700,10 +890,7 @@ func TestDelete(t *testing.T) { Key: "key", } - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -715,7 +902,22 @@ func TestDelete(t *testing.T) { }, } - err := ss.Delete(context.Background(), req) + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + + err := s.Delete(context.Background(), req) require.NoError(t, err) }) @@ -725,10 +927,8 @@ func TestDelete(t *testing.T) { ETag: &etag, Key: "key", } - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + + mockedDB := &awsAuth.MockDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -744,7 +944,22 @@ func TestDelete(t *testing.T) { }, } - err := ss.Delete(context.Background(), req) + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + + err := s.Delete(context.Background(), req) require.NoError(t, err) }) @@ -755,10 +970,7 @@ func TestDelete(t *testing.T) { Key: "key", } - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { assert.Equal(t, map[string]*dynamodb.AttributeValue{ "key": { @@ -775,7 +987,21 @@ func TestDelete(t *testing.T) { }, } - err := ss.Delete(context.Background(), req) + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + partitionKey: defaultPartitionKeyName, + } + err := s.Delete(context.Background(), req) require.Error(t, err) switch tagErr := err.(type) { case *state.ETagError: @@ -786,26 +1012,36 @@ func TestDelete(t *testing.T) { }) t.Run("Unsuccessfully delete item", func(t *testing.T) { - ss := StateStore{ - client: &mockedDynamoDB{ - DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { - return nil, errors.New("unable to delete item") - }, + mockedDB := &awsAuth.MockDynamoDB{ + DeleteItemWithContextFn: func(ctx context.Context, input *dynamodb.DeleteItemInput, op ...request.Option) (output *dynamodb.DeleteItemOutput, err error) { + return nil, errors.New("unable to delete item") }, } + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + } + req := &state.DeleteRequest{ Key: "key", } - err := ss.Delete(context.Background(), req) + err := s.Delete(context.Background(), req) require.Error(t, err) }) } func TestMultiTx(t *testing.T) { t.Run("Successfully Multiple Transaction Operations", func(t *testing.T) { - ss := &StateStore{ - partitionKey: defaultPartitionKeyName, - } firstKey := "key1" secondKey := "key2" secondValue := "value2" @@ -829,7 +1065,7 @@ func TestMultiTx(t *testing.T) { }, } - ss.client = &mockedDynamoDB{ + mockedDB := &awsAuth.MockDynamoDB{ TransactWriteItemsWithContextFn: func(ctx context.Context, input *dynamodb.TransactWriteItemsInput, op ...request.Option) (*dynamodb.TransactWriteItemsOutput, error) { // ops - duplicates exOps := len(ops) - 1 @@ -853,13 +1089,28 @@ func TestMultiTx(t *testing.T) { return &dynamodb.TransactWriteItemsOutput{}, nil }, } - ss.table = tableName + + dynamo := awsAuth.DynamoDBClients{ + DynamoDB: mockedDB, + } + + mockedClients := awsAuth.Clients{ + Dynamo: &dynamo, + } + + mockAuthProvider := &awsAuth.StaticAuth{} + mockAuthProvider.WithMockClients(&mockedClients) + s := StateStore{ + authProvider: mockAuthProvider, + table: tableName, + partitionKey: defaultPartitionKeyName, + } req := &state.TransactionalStateRequest{ Operations: ops, Metadata: map[string]string{}, } - err := ss.Multi(context.Background(), req) + err := s.Multi(context.Background(), req) require.NoError(t, err) }) } diff --git a/tests/certification/bindings/aws/s3/s3_test.go b/tests/certification/bindings/aws/s3/s3_test.go index 16e41abc49..c815901f8c 100644 --- a/tests/certification/bindings/aws/s3/s3_test.go +++ b/tests/certification/bindings/aws/s3/s3_test.go @@ -279,7 +279,7 @@ func S3SForcePathStyle(t *testing.T) { Step(sidecar.Run(sidecarName, append(componentRuntimeOptions(), embedded.WithoutApp(), - embedded.WithComponentsPath("./components/forcePathStyleTrue"), + embedded.WithResourcesPath("./components/forcePathStyleTrue"), embedded.WithDaprGRPCPort(strconv.Itoa(currentGRPCPort)), embedded.WithDaprHTTPPort(strconv.Itoa(currentHTTPPort)), )..., diff --git a/tests/certification/go.mod b/tests/certification/go.mod index 1dc9c0ad44..3fe60b4224 100644 --- a/tests/certification/go.mod +++ b/tests/certification/go.mod @@ -98,6 +98,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/sso v1.23.3 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 // indirect + github.com/aws/rolesanywhere-credential-helper v1.0.4 // indirect github.com/aws/smithy-go v1.21.0 // indirect github.com/benbjohnson/clock v1.3.5 // indirect github.com/beorn7/perks v1.0.1 // indirect @@ -290,6 +291,7 @@ require ( github.com/tklauser/numcpus v0.6.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect + github.com/vmware/vmware-go-kcl v1.5.1 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect @@ -333,6 +335,7 @@ require ( google.golang.org/grpc v1.64.1 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/inf.v0 v0.9.1 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/api v0.30.2 // indirect diff --git a/tests/certification/go.sum b/tests/certification/go.sum index 145f05a305..d21aec5c3c 100644 --- a/tests/certification/go.sum +++ b/tests/certification/go.sum @@ -234,6 +234,8 @@ github.com/aws/aws-sdk-go-v2/service/ssooidc v1.27.3/go.mod h1:FnvDM4sfa+isJ3kDX github.com/aws/aws-sdk-go-v2/service/sts v1.7.2/go.mod h1:8EzeIqfWt2wWT4rJVu3f21TfrhJ8AEMzVybRNSb/b4g= github.com/aws/aws-sdk-go-v2/service/sts v1.31.3 h1:VzudTFrDCIDakXtemR7l6Qzt2+JYsVqo2MxBPt5k8T8= github.com/aws/aws-sdk-go-v2/service/sts v1.31.3/go.mod h1:yMWe0F+XG0DkRZK5ODZhG7BEFYhLXi2dqGsv6tX0cgI= +github.com/aws/rolesanywhere-credential-helper v1.0.4 h1:kHIVVdyQQiFZoKBP+zywBdFilGCS8It+UvW5LolKbW8= +github.com/aws/rolesanywhere-credential-helper v1.0.4/go.mod h1:QVGNxlDlYhjR0/ZUee7uGl0hNChWidNpe2+GD87Buqk= github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= github.com/aws/smithy-go v1.21.0 h1:H7L8dtDRk0P1Qm6y0ji7MCYMQObJ5R9CRpyPhRUkLYA= github.com/aws/smithy-go v1.21.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= @@ -1381,6 +1383,8 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= +github.com/vmware/vmware-go-kcl v1.5.1 h1:1rJLfAX4sDnCyatNoD/WJzVafkwST6u/cgY/Uf2VgHk= +github.com/vmware/vmware-go-kcl v1.5.1/go.mod h1:kXJmQ6h0dRMRrp1uWU9XbIXvwelDpTxSPquvQUBdpbo= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= From 1a6a75a1ce73de95382511bc0695d0fc038ca9a2 Mon Sep 17 00:00:00 2001 From: Yaron Schneider Date: Mon, 18 Nov 2024 13:18:20 -0800 Subject: [PATCH 08/12] Enable in order processing of eventhubs messages (#3605) Signed-off-by: yaron2 --- bindings/azure/eventhubs/metadata.yaml | 8 ++++++ common/component/azure/eventhubs/eventhubs.go | 6 ++++- .../azure/eventhubs/eventhubs_test.go | 12 +++++++++ common/component/azure/eventhubs/metadata.go | 25 ++++++++++--------- pubsub/azure/eventhubs/metadata.yaml | 7 ++++++ 5 files changed, 45 insertions(+), 13 deletions(-) diff --git a/bindings/azure/eventhubs/metadata.yaml b/bindings/azure/eventhubs/metadata.yaml index 439fc5a9e6..841758b940 100644 --- a/bindings/azure/eventhubs/metadata.yaml +++ b/bindings/azure/eventhubs/metadata.yaml @@ -55,7 +55,15 @@ builtinAuthenticationProfiles: default: "false" example: "false" description: | + Allow management of the Event Hub namespace and storage account. + - name: enableInOrderMessageDelivery + type: bool + required: false + default: "false" + example: "false" + description: | + Enable in order processing of messages within a partition. - name: resourceGroupName type: string required: false diff --git a/common/component/azure/eventhubs/eventhubs.go b/common/component/azure/eventhubs/eventhubs.go index f5724e891f..e5b4cfd6c0 100644 --- a/common/component/azure/eventhubs/eventhubs.go +++ b/common/component/azure/eventhubs/eventhubs.go @@ -393,7 +393,11 @@ func (aeh *AzureEventHubs) processEvents(subscribeCtx context.Context, partition if len(events) != 0 { // Handle received message - go aeh.handleAsync(subscribeCtx, config.Topic, events, config.Handler) + if aeh.metadata.EnableInOrderMessageDelivery { + aeh.handleAsync(subscribeCtx, config.Topic, events, config.Handler) + } else { + go aeh.handleAsync(subscribeCtx, config.Topic, events, config.Handler) + } // Checkpointing disabled for CheckPointFrequencyPerPartition == 0 if config.CheckPointFrequencyPerPartition > 0 { diff --git a/common/component/azure/eventhubs/eventhubs_test.go b/common/component/azure/eventhubs/eventhubs_test.go index bf63c19b6b..8381c07743 100644 --- a/common/component/azure/eventhubs/eventhubs_test.go +++ b/common/component/azure/eventhubs/eventhubs_test.go @@ -72,6 +72,18 @@ func TestParseEventHubsMetadata(t *testing.T) { require.Error(t, err) require.ErrorContains(t, err, "one of connectionString or eventHubNamespace is required") }) + + t.Run("test in order delivery", func(t *testing.T) { + metadata := map[string]string{ + "enableInOrderMessageDelivery": "true", + "connectionString": "fake", + } + + m, err := parseEventHubsMetadata(metadata, false, testLogger) + + require.NoError(t, err) + require.True(t, m.EnableInOrderMessageDelivery) + }) } func TestConstructConnectionStringFromTopic(t *testing.T) { diff --git a/common/component/azure/eventhubs/metadata.go b/common/component/azure/eventhubs/metadata.go index b5e94e114e..00ed07fa7d 100644 --- a/common/component/azure/eventhubs/metadata.go +++ b/common/component/azure/eventhubs/metadata.go @@ -26,18 +26,19 @@ import ( ) type AzureEventHubsMetadata struct { - ConnectionString string `json:"connectionString" mapstructure:"connectionString"` - EventHubNamespace string `json:"eventHubNamespace" mapstructure:"eventHubNamespace"` - ConsumerID string `json:"consumerID" mapstructure:"consumerID"` - StorageConnectionString string `json:"storageConnectionString" mapstructure:"storageConnectionString"` - StorageAccountName string `json:"storageAccountName" mapstructure:"storageAccountName"` - StorageAccountKey string `json:"storageAccountKey" mapstructure:"storageAccountKey"` - StorageContainerName string `json:"storageContainerName" mapstructure:"storageContainerName"` - EnableEntityManagement bool `json:"enableEntityManagement,string" mapstructure:"enableEntityManagement"` - MessageRetentionInDays int32 `json:"messageRetentionInDays,string" mapstructure:"messageRetentionInDays"` - PartitionCount int32 `json:"partitionCount,string" mapstructure:"partitionCount"` - SubscriptionID string `json:"subscriptionID" mapstructure:"subscriptionID"` - ResourceGroupName string `json:"resourceGroupName" mapstructure:"resourceGroupName"` + ConnectionString string `json:"connectionString" mapstructure:"connectionString"` + EventHubNamespace string `json:"eventHubNamespace" mapstructure:"eventHubNamespace"` + ConsumerID string `json:"consumerID" mapstructure:"consumerID"` + StorageConnectionString string `json:"storageConnectionString" mapstructure:"storageConnectionString"` + StorageAccountName string `json:"storageAccountName" mapstructure:"storageAccountName"` + StorageAccountKey string `json:"storageAccountKey" mapstructure:"storageAccountKey"` + StorageContainerName string `json:"storageContainerName" mapstructure:"storageContainerName"` + EnableEntityManagement bool `json:"enableEntityManagement,string" mapstructure:"enableEntityManagement"` + MessageRetentionInDays int32 `json:"messageRetentionInDays,string" mapstructure:"messageRetentionInDays"` + PartitionCount int32 `json:"partitionCount,string" mapstructure:"partitionCount"` + SubscriptionID string `json:"subscriptionID" mapstructure:"subscriptionID"` + ResourceGroupName string `json:"resourceGroupName" mapstructure:"resourceGroupName"` + EnableInOrderMessageDelivery bool `json:"enableInOrderMessageDelivery,string" mapstructure:"enableInOrderMessageDelivery"` // Binding only EventHub string `json:"eventHub" mapstructure:"eventHub" mdonly:"bindings"` diff --git a/pubsub/azure/eventhubs/metadata.yaml b/pubsub/azure/eventhubs/metadata.yaml index 768d472252..57c73c721c 100644 --- a/pubsub/azure/eventhubs/metadata.yaml +++ b/pubsub/azure/eventhubs/metadata.yaml @@ -35,6 +35,13 @@ builtinAuthenticationProfiles: example: "false" description: | Allow management of the Event Hub namespace and storage account. + - name: enableInOrderMessageDelivery + type: bool + required: false + default: "false" + example: "false" + description: | + Enable in order processing of messages within a partition. # The following four properties are needed only if enableEntityManagement is set to true - name: resourceGroupName From 0f09d25bcd4638af758e444cab4c99237645b298 Mon Sep 17 00:00:00 2001 From: Fabian Martinez <46371672+famarting@users.noreply.github.com> Date: Wed, 20 Nov 2024 15:47:40 +0100 Subject: [PATCH 09/12] postgres binding, ping on init (#3595) Signed-off-by: Fabian Martinez <46371672+famarting@users.noreply.github.com> --- bindings/postgres/metadata.go | 5 ++ bindings/postgres/metadata_test.go | 88 ++++++++++++++++++++++++++++++ bindings/postgres/postgres.go | 11 +++- bindings/postgres/postgres_test.go | 45 +++++++++++++++ state/postgresql/v2/postgresql.go | 4 +- 5 files changed, 150 insertions(+), 3 deletions(-) create mode 100644 bindings/postgres/metadata_test.go diff --git a/bindings/postgres/metadata.go b/bindings/postgres/metadata.go index b4747c33ff..33eae83f58 100644 --- a/bindings/postgres/metadata.go +++ b/bindings/postgres/metadata.go @@ -14,6 +14,7 @@ limitations under the License. package postgres import ( + "errors" "time" "github.com/dapr/components-contrib/common/authentication/aws" @@ -53,5 +54,9 @@ func (m *psqlMetadata) InitWithMetadata(meta map[string]string) error { return err } + if m.Timeout < 1*time.Second { + return errors.New("invalid value for 'timeout': must be greater than 1s") + } + return nil } diff --git a/bindings/postgres/metadata_test.go b/bindings/postgres/metadata_test.go new file mode 100644 index 0000000000..ece5e433ed --- /dev/null +++ b/bindings/postgres/metadata_test.go @@ -0,0 +1,88 @@ +/* +Copyright 2023 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package postgres + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMetadata(t *testing.T) { + t.Run("missing connection string", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{} + + err := m.InitWithMetadata(props) + require.Error(t, err) + require.ErrorContains(t, err, "connection string") + }) + + t.Run("has connection string", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + } + + err := m.InitWithMetadata(props) + require.NoError(t, err) + }) + + t.Run("default timeout", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + } + + err := m.InitWithMetadata(props) + require.NoError(t, err) + assert.Equal(t, 20*time.Second, m.Timeout) + }) + + t.Run("invalid timeout", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + "timeout": "NaN", + } + + err := m.InitWithMetadata(props) + require.Error(t, err) + }) + + t.Run("positive timeout", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + "timeout": "42", + } + + err := m.InitWithMetadata(props) + require.NoError(t, err) + assert.Equal(t, 42*time.Second, m.Timeout) + }) + + t.Run("zero timeout", func(t *testing.T) { + m := psqlMetadata{} + props := map[string]string{ + "connectionString": "foo", + "timeout": "0", + } + + err := m.InitWithMetadata(props) + require.Error(t, err) + }) +} diff --git a/bindings/postgres/postgres.go b/bindings/postgres/postgres.go index c9dc6bcfbe..e6dce08e0d 100644 --- a/bindings/postgres/postgres.go +++ b/bindings/postgres/postgres.go @@ -73,11 +73,20 @@ func (p *Postgres) Init(ctx context.Context, meta bindings.Metadata) error { // This context doesn't control the lifetime of the connection pool, and is // only scoped to postgres creating resources at init. - p.db, err = pgxpool.NewWithConfig(ctx, poolConfig) + connCtx, connCancel := context.WithTimeout(ctx, m.Timeout) + defer connCancel() + p.db, err = pgxpool.NewWithConfig(connCtx, poolConfig) if err != nil { return fmt.Errorf("unable to connect to the DB: %w", err) } + pingCtx, pingCancel := context.WithTimeout(ctx, m.Timeout) + defer pingCancel() + err = p.db.Ping(pingCtx) + if err != nil { + return fmt.Errorf("failed to ping the DB: %w", err) + } + return nil } diff --git a/bindings/postgres/postgres_test.go b/bindings/postgres/postgres_test.go index c24fc099fb..6a517fcd6a 100644 --- a/bindings/postgres/postgres_test.go +++ b/bindings/postgres/postgres_test.go @@ -15,6 +15,7 @@ package postgres import ( "context" + "errors" "fmt" "os" "testing" @@ -62,6 +63,10 @@ func TestPostgresIntegration(t *testing.T) { t.SkipNow() } + t.Run("Test init configurations", func(t *testing.T) { + testInitConfiguration(t, url) + }) + // live DB test b := NewPostgres(logger.NewLogger("test")).(*Postgres) m := bindings.Metadata{Base: metadata.Base{Properties: map[string]string{"connectionString": url}}} @@ -131,6 +136,46 @@ func TestPostgresIntegration(t *testing.T) { }) } +// testInitConfiguration tests valid and invalid config settings. +func testInitConfiguration(t *testing.T, connectionString string) { + logger := logger.NewLogger("test") + tests := []struct { + name string + props map[string]string + expectedErr error + }{ + { + name: "Empty", + props: map[string]string{}, + expectedErr: errors.New("missing connection string"), + }, + { + name: "Valid connection string", + props: map[string]string{"connectionString": connectionString}, + expectedErr: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewPostgres(logger).(*Postgres) + defer p.Close() + + metadata := bindings.Metadata{ + Base: metadata.Base{Properties: tt.props}, + } + + err := p.Init(context.Background(), metadata) + if tt.expectedErr == nil { + require.NoError(t, err) + } else { + require.Error(t, err) + assert.Equal(t, tt.expectedErr, err) + } + }) + } +} + func assertResponse(t *testing.T, res *bindings.InvokeResponse, err error) { require.NoError(t, err) assert.NotNil(t, res) diff --git a/state/postgresql/v2/postgresql.go b/state/postgresql/v2/postgresql.go index d323ca5c90..a0f44ec043 100644 --- a/state/postgresql/v2/postgresql.go +++ b/state/postgresql/v2/postgresql.go @@ -99,16 +99,16 @@ func (p *PostgreSQL) Init(ctx context.Context, meta state.Metadata) (err error) } connCtx, connCancel := context.WithTimeout(ctx, p.metadata.Timeout) + defer connCancel() p.db, err = pgxpool.NewWithConfig(connCtx, config) - connCancel() if err != nil { err = fmt.Errorf("failed to connect to the database: %w", err) return err } pingCtx, pingCancel := context.WithTimeout(ctx, p.metadata.Timeout) + defer pingCancel() err = p.db.Ping(pingCtx) - pingCancel() if err != nil { err = fmt.Errorf("failed to ping the database: %w", err) return err From f3bd794b12033a27b8d9292ca424daa13fc99bc3 Mon Sep 17 00:00:00 2001 From: luigirende Date: Wed, 20 Nov 2024 17:27:46 +0100 Subject: [PATCH 10/12] Mongo State: fix serialization value in the transaction method (#3576) Signed-off-by: Luigi Rende Co-authored-by: Yaron Schneider --- state/mongodb/mongodb.go | 15 +++++++- tests/conformance/state/state.go | 65 ++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/state/mongodb/mongodb.go b/state/mongodb/mongodb.go index 1bd5472ef9..112043d637 100644 --- a/state/mongodb/mongodb.go +++ b/state/mongodb/mongodb.go @@ -25,6 +25,8 @@ import ( "strings" "time" + "github.com/dapr/components-contrib/contenttype" + "github.com/google/uuid" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/bsonrw" @@ -528,7 +530,18 @@ func (m *MongoDB) doTransaction(sessCtx mongo.SessionContext, operations []state var err error switch req := o.(type) { case state.SetRequest: - err = m.setInternal(sessCtx, &req) + { + isJSON := (len(req.Metadata) > 0 && req.Metadata[metadata.ContentType] == contenttype.JSONContentType) + if isJSON { + if bytes, ok := req.Value.([]byte); ok { + err = json.Unmarshal(bytes, &req.Value) + if err != nil { + break + } + } + } + err = m.setInternal(sessCtx, &req) + } case state.DeleteRequest: err = m.deleteInternal(sessCtx, &req) } diff --git a/tests/conformance/state/state.go b/tests/conformance/state/state.go index 25de650cf9..cfb94dbfbb 100644 --- a/tests/conformance/state/state.go +++ b/tests/conformance/state/state.go @@ -15,6 +15,7 @@ package state import ( "context" + "encoding/base64" "encoding/json" "fmt" "slices" @@ -784,6 +785,70 @@ func ConformanceTests(t *testing.T, props map[string]string, statestore state.St assert.Equal(t, v, res.Data) } }) + + t.Run("transaction-serialization-grpc-json", func(t *testing.T) { + features := statestore.Features() + // this check for exclude redis 7 + if state.FeatureQueryAPI.IsPresent(features) { + json := "{\"id\":1223,\"name\":\"test\"}" + keyTest1 := key + "-key-grpc" + valueTest := []byte(json) + keyTest2 := key + "-key-grpc-no-json" + + metadataTest1 := map[string]string{ + "contentType": "application/json", + } + + operations := []state.TransactionalStateOperation{ + state.SetRequest{ + Key: keyTest1, + Value: valueTest, + Metadata: metadataTest1, + }, + state.SetRequest{ + Key: keyTest2, + Value: valueTest, + }, + } + + expected := map[string][]byte{ + keyTest1: []byte(json), + keyTest2: []byte(json), + } + + expectedMetadata := map[string]map[string]string{ + keyTest1: metadataTest1, + } + + // Act + transactionStore, ok := statestore.(state.TransactionalStore) + assert.True(t, ok) + err := transactionStore.Multi(context.Background(), &state.TransactionalStateRequest{ + Operations: operations, + }) + require.NoError(t, err) + + // Assert + for k, v := range expected { + res, err := statestore.Get(context.Background(), &state.GetRequest{ + Key: k, + Metadata: expectedMetadata[k], + }) + expectedValue := res.Data + + // In redisjson when set the value with contentType = application/Json store the value in base64 + if strings.HasPrefix(string(expectedValue), "\"ey") { + valueBase64 := strings.Trim(string(expectedValue), "\"") + expectedValueDecoded, _ := base64.StdEncoding.DecodeString(valueBase64) + require.NoError(t, err) + assert.Equal(t, expectedValueDecoded, v) + } else { + require.NoError(t, err) + assert.Equal(t, expectedValue, v) + } + } + } + }) } else { t.Run("component does not implement TransactionalStore interface", func(t *testing.T) { _, ok := statestore.(state.TransactionalStore) From e2b27d3538e84957abc62112ca8b22fb56987e97 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 20 Nov 2024 11:24:49 -0600 Subject: [PATCH 11/12] fix(aws): update close if aws auth provider is nil (#3607) Signed-off-by: Samantha Coyle Co-authored-by: Yaron Schneider --- bindings/aws/dynamodb/dynamodb.go | 5 ++++- bindings/aws/kinesis/kinesis.go | 5 ++++- bindings/aws/s3/s3.go | 5 ++++- bindings/aws/ses/ses.go | 5 ++++- bindings/aws/sns/sns.go | 5 ++++- bindings/aws/sqs/sqs.go | 5 ++++- pubsub/aws/snssqs/snssqs.go | 5 ++++- secretstores/aws/parameterstore/parameterstore.go | 5 ++++- secretstores/aws/secretmanager/secretmanager.go | 5 ++++- state/aws/dynamodb/dynamodb.go | 5 ++++- 10 files changed, 40 insertions(+), 10 deletions(-) diff --git a/bindings/aws/dynamodb/dynamodb.go b/bindings/aws/dynamodb/dynamodb.go index 755b3158d3..2096f22433 100644 --- a/bindings/aws/dynamodb/dynamodb.go +++ b/bindings/aws/dynamodb/dynamodb.go @@ -122,5 +122,8 @@ func (d *DynamoDB) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { } func (d *DynamoDB) Close() error { - return d.authProvider.Close() + if d.authProvider != nil { + return d.authProvider.Close() + } + return nil } diff --git a/bindings/aws/kinesis/kinesis.go b/bindings/aws/kinesis/kinesis.go index 7ede7ba245..bf684f8bbb 100644 --- a/bindings/aws/kinesis/kinesis.go +++ b/bindings/aws/kinesis/kinesis.go @@ -266,7 +266,10 @@ func (a *AWSKinesis) Close() error { close(a.closeCh) } a.wg.Wait() - return a.authProvider.Close() + if a.authProvider != nil { + return a.authProvider.Close() + } + return nil } func (a *AWSKinesis) ensureConsumer(ctx context.Context, streamARN *string) (*string, error) { diff --git a/bindings/aws/s3/s3.go b/bindings/aws/s3/s3.go index 13f8730e78..fa20c70a6b 100644 --- a/bindings/aws/s3/s3.go +++ b/bindings/aws/s3/s3.go @@ -153,7 +153,10 @@ func (s *AWSS3) Init(ctx context.Context, metadata bindings.Metadata) error { } func (s *AWSS3) Close() error { - return s.authProvider.Close() + if s.authProvider != nil { + return s.authProvider.Close() + } + return nil } func (s *AWSS3) Operations() []bindings.OperationKind { diff --git a/bindings/aws/ses/ses.go b/bindings/aws/ses/ses.go index 4cd752bac5..b8d2ff3faa 100644 --- a/bindings/aws/ses/ses.go +++ b/bindings/aws/ses/ses.go @@ -176,5 +176,8 @@ func (a *AWSSES) GetComponentMetadata() (metadataInfo contribMetadata.MetadataMa } func (a *AWSSES) Close() error { - return a.authProvider.Close() + if a.authProvider != nil { + return a.authProvider.Close() + } + return nil } diff --git a/bindings/aws/sns/sns.go b/bindings/aws/sns/sns.go index 55e3ccefa5..5464f1f044 100644 --- a/bindings/aws/sns/sns.go +++ b/bindings/aws/sns/sns.go @@ -128,5 +128,8 @@ func (a *AWSSNS) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { } func (a *AWSSNS) Close() error { - return a.authProvider.Close() + if a.authProvider != nil { + return a.authProvider.Close() + } + return nil } diff --git a/bindings/aws/sqs/sqs.go b/bindings/aws/sqs/sqs.go index d803bafc5a..b09fde61f6 100644 --- a/bindings/aws/sqs/sqs.go +++ b/bindings/aws/sqs/sqs.go @@ -173,7 +173,10 @@ func (a *AWSSQS) Close() error { close(a.closeCh) } a.wg.Wait() - return a.authProvider.Close() + if a.authProvider != nil { + return a.authProvider.Close() + } + return nil } func (a *AWSSQS) parseSQSMetadata(meta bindings.Metadata) (*sqsMetadata, error) { diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 93481fb733..4e2371764b 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -875,7 +875,10 @@ func (s *snsSqs) Close() error { s.subscriptionManager.Close() } - return s.authProvider.Close() + if s.authProvider != nil { + return s.authProvider.Close() + } + return nil } func (s *snsSqs) Features() []pubsub.Feature { diff --git a/secretstores/aws/parameterstore/parameterstore.go b/secretstores/aws/parameterstore/parameterstore.go index abf9c6c4de..038399b30c 100644 --- a/secretstores/aws/parameterstore/parameterstore.go +++ b/secretstores/aws/parameterstore/parameterstore.go @@ -182,5 +182,8 @@ func (s *ssmSecretStore) GetComponentMetadata() (metadataInfo metadata.MetadataM } func (s *ssmSecretStore) Close() error { - return s.authProvider.Close() + if s.authProvider != nil { + return s.authProvider.Close() + } + return nil } diff --git a/secretstores/aws/secretmanager/secretmanager.go b/secretstores/aws/secretmanager/secretmanager.go index 6faf1f1eab..979739be5b 100644 --- a/secretstores/aws/secretmanager/secretmanager.go +++ b/secretstores/aws/secretmanager/secretmanager.go @@ -170,5 +170,8 @@ func (s *smSecretStore) GetComponentMetadata() (metadataInfo metadata.MetadataMa } func (s *smSecretStore) Close() error { - return s.authProvider.Close() + if s.authProvider != nil { + return s.authProvider.Close() + } + return nil } diff --git a/state/aws/dynamodb/dynamodb.go b/state/aws/dynamodb/dynamodb.go index ae4ba7c5e9..d3bbd39a85 100644 --- a/state/aws/dynamodb/dynamodb.go +++ b/state/aws/dynamodb/dynamodb.go @@ -275,7 +275,10 @@ func (d *StateStore) GetComponentMetadata() (metadataInfo metadata.MetadataMap) } func (d *StateStore) Close() error { - return d.authProvider.Close() + if d.authProvider != nil { + return d.authProvider.Close() + } + return nil } func (d *StateStore) getDynamoDBMetadata(meta state.Metadata) (*dynamoDBMetadata, error) { From f521a76f7b070a1eb939eee90db7f3789c025680 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 20 Nov 2024 17:22:13 -0600 Subject: [PATCH 12/12] fix: initialize the close chan (#3608) Signed-off-by: Samantha Coyle --- common/authentication/aws/x509.go | 1 + 1 file changed, 1 insertion(+) diff --git a/common/authentication/aws/x509.go b/common/authentication/aws/x509.go index cb1bafdeb3..52af56d271 100644 --- a/common/authentication/aws/x509.go +++ b/common/authentication/aws/x509.go @@ -96,6 +96,7 @@ func newX509(ctx context.Context, opts Options, cfg *aws.Config) (*x509, error) return GetConfig(opts) }(), clients: newClients(), + closeCh: make(chan struct{}), } if err := auth.getCertPEM(ctx); err != nil {