diff --git a/CHANGES.md b/CHANGES.md index 6e26722f96b3..824b7982c4d1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -59,6 +59,7 @@ ## I/Os * Support for Bigtable sink (Write and WriteBatch) added (Go) ([#23324](https://github.com/apache/beam/issues/23324)). +* S3 implementation of the Beam filesystem (Go) ([#23991](https://github.com/apache/beam/issues/23991)). ## New Features / Improvements diff --git a/sdks/go.mod b/sdks/go.mod index 98358d9b2af7..998f4be4801f 100644 --- a/sdks/go.mod +++ b/sdks/go.mod @@ -28,6 +28,10 @@ require ( cloud.google.com/go/profiler v0.3.0 cloud.google.com/go/pubsub v1.26.0 cloud.google.com/go/storage v1.28.0 + github.com/aws/aws-sdk-go-v2 v1.7.1 + github.com/aws/aws-sdk-go-v2/config v1.5.0 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.3.2 + github.com/aws/aws-sdk-go-v2/service/s3 v1.11.1 github.com/docker/go-connections v0.4.0 github.com/dustin/go-humanize v1.0.0 github.com/go-sql-driver/mysql v1.6.0 @@ -67,6 +71,15 @@ require ( github.com/Microsoft/hcsshim v0.9.4 // indirect github.com/apache/arrow/go/arrow v0.0.0-20200730104253-651201b0f516 // indirect github.com/apache/thrift v0.14.2 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.3.1 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.3.0 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.1.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.3.1 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.6.0 // indirect + github.com/aws/smithy-go v1.6.0 // indirect github.com/cenkalti/backoff/v4 v4.1.3 // indirect github.com/census-instrumentation/opencensus-proto v0.2.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect @@ -86,6 +99,7 @@ require ( github.com/googleapis/enterprise-certificate-proxy 
v0.2.0 // indirect github.com/googleapis/gax-go/v2 v2.6.0 // indirect github.com/inconshreveable/mousetrap v1.0.1 // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/klauspost/compress v1.13.1 // indirect github.com/magiconair/properties v1.8.6 // indirect github.com/moby/sys/mount v0.3.3 // indirect diff --git a/sdks/go.sum b/sdks/go.sum index 683669e8f392..35ba044dfd6c 100644 --- a/sdks/go.sum +++ b/sdks/go.sum @@ -154,18 +154,31 @@ github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/aws/aws-sdk-go v1.15.11/go.mod h1:mFuSZ37Z9YOHbQEwBWztmVzqXrEkub65tZoCYDt7FT0= github.com/aws/aws-sdk-go v1.30.19/go.mod h1:5zCpMtNQVjRREroY7sYe8lOMRSxkhG6MZveU8YkpAk0= +github.com/aws/aws-sdk-go-v2 v1.7.1 h1:TswSc7KNqZ/K1Ijt3IkpXk/2+62vi3Q82Yrr5wSbRBQ= github.com/aws/aws-sdk-go-v2 v1.7.1/go.mod h1:L5LuPC1ZgDr2xQS7AmIec/Jlc7O/Y1u2KxJyNVab250= +github.com/aws/aws-sdk-go-v2/config v1.5.0 h1:tRQcWXVmO7wC+ApwYc2LiYKfIBoIrdzcJ+7HIh6AlR0= github.com/aws/aws-sdk-go-v2/config v1.5.0/go.mod h1:RWlPOAW3E3tbtNAqTwvSW54Of/yP3oiZXMI0xfUdjyA= +github.com/aws/aws-sdk-go-v2/credentials v1.3.1 h1:fFeqL5+9kwFKsCb2oci5yAIDsWYqn/Nga8oQ5bIasI8= github.com/aws/aws-sdk-go-v2/credentials v1.3.1/go.mod h1:r0n73xwsIVagq8RsxmZbGSRQFj9As3je72C2WzUIToc= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.3.0 h1:s4vtv3Mv1CisI3qm2HGHi1Ls9ZtbCOEqeQn6oz7fTyU= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.3.0/go.mod h1:2LAuqPx1I6jNfaGDucWfA2zqQCYCOMCDHiCOciALyNw= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.3.2 h1:fzEMxnHQWh+bUV0ZzfhMbgUG8zjIPnAgApjtdHtC9Yg= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.3.2/go.mod h1:qaqQiHSrOUVOfKe6fhgQ6UzhxjwqVW8aHNegd6Ws4w4= +github.com/aws/aws-sdk-go-v2/internal/ini v1.1.1 h1:SDLwr1NKyowP7uqxuLNdvFZhjnoVWxNv456zAp+ZFjU= github.com/aws/aws-sdk-go-v2/internal/ini v1.1.1/go.mod 
h1:Zy8smImhTdOETZqfyn01iNOe0CNggVbPjCajyaz6Gvg= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.1 h1:s/uV8UyMB4UcO0ERHxG9BJhYJAD9MiY0QeYvJmlC7PE= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.2.1/go.mod h1:v33JQ57i2nekYTA70Mb+O18KeH4KqhdqxTJZNK1zdRE= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.1 h1:VJe/XEhrfyfBLupcGg1BfUSK2VMZNdbDcZQ49jnp+h0= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.2.1/go.mod h1:zceowr5Z1Nh2WVP8bf/3ikB41IZW59E4yIYbg+pC6mw= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.1 h1:1ds3HkMQEBx9XvOkqsPuqBmNFn0w8XEDuB4LOi6KepU= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.5.1/go.mod h1:6EQZIwNNvHpq/2/QSJnp4+ECvqIy55w95Ofs0ze+nGQ= +github.com/aws/aws-sdk-go-v2/service/s3 v1.11.1 h1:HiXhafnqG0AkVJIZA/BHhFvuc/8xFdUO1uaeqF2Artc= github.com/aws/aws-sdk-go-v2/service/s3 v1.11.1/go.mod h1:XLAGFrEjbvMCLvAtWLLP32yTv8GpBquCApZEycDLunI= +github.com/aws/aws-sdk-go-v2/service/sso v1.3.1 h1:H2ZLWHUbbeYtghuqCY5s/7tbBM99PAwCioRJF8QvV/U= github.com/aws/aws-sdk-go-v2/service/sso v1.3.1/go.mod h1:J3A3RGUvuCZjvSuZEcOpHDnzZP/sKbhDWV2T1EOzFIM= +github.com/aws/aws-sdk-go-v2/service/sts v1.6.0 h1:Y9r6mrzOyAYz4qKaluSH19zqH1236il/nGbsPKOUT0s= github.com/aws/aws-sdk-go-v2/service/sts v1.6.0/go.mod h1:q7o0j7d7HrJk/vr9uUt3BVRASvcU7gYZB9PUgPiByXg= +github.com/aws/smithy-go v1.6.0 h1:T6puApfBcYiTIsaI+SYWqanjMt5pc3aoyyDrI+0YH54= github.com/aws/smithy-go v1.6.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= github.com/benbjohnson/clock v1.0.3/go.mod h1:bGMdMPoPVvcYyt1gHDf4J2KE153Yf9BuiUKYMaxlTDM= github.com/beorn7/perks v0.0.0-20160804104726-4c0e84591b9a/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= @@ -661,7 +674,9 @@ github.com/jcmturner/gofork v0.0.0-20180107083740-2aebee971930/go.mod h1:MK8+TM0 github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath 
v0.0.0-20160803190731-bd40a432e4c7/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/joefitzgerald/rainbow-reporter v0.1.0/go.mod h1:481CNgqmVHQZzdIbN52CupLJyoVwB10FQ/IQlF1pdL8= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= diff --git a/sdks/go/pkg/beam/io/filesystem/filesystem.go b/sdks/go/pkg/beam/io/filesystem/filesystem.go index 43624ad45f08..43b4fd068b5d 100644 --- a/sdks/go/pkg/beam/io/filesystem/filesystem.go +++ b/sdks/go/pkg/beam/io/filesystem/filesystem.go @@ -34,12 +34,13 @@ import ( var registry = make(map[string]func(context.Context) Interface) -// wellKnownSchemeImportPaths is used for deliverng useful error messages when a +// wellKnownSchemeImportPaths is used for delivering useful error messages when a // scheme is not found. var wellKnownSchemeImportPaths = map[string]string{ "memfs": "github.com/apache/beam/sdks/v2/go/pkg/beam/io/filesystem/memfs", "default": "github.com/apache/beam/sdks/v2/go/pkg/beam/io/filesystem/local", "gs": "github.com/apache/beam/sdks/v2/go/pkg/beam/io/filesystem/gcs", + "s3": "github.com/apache/beam/sdks/v2/go/pkg/beam/io/filesystem/s3", } // Register registers a file system backend under the given scheme. 
For diff --git a/sdks/go/pkg/beam/io/filesystem/s3/s3.go b/sdks/go/pkg/beam/io/filesystem/s3/s3.go new file mode 100644 index 000000000000..08dd5cd99d7d --- /dev/null +++ b/sdks/go/pkg/beam/io/filesystem/s3/s3.go @@ -0,0 +1,208 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package s3 contains an AWS S3 implementation of the Beam file system. +package s3 + +import ( + "context" + "fmt" + "io" + "path/filepath" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/io/filesystem" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +func init() { + filesystem.Register("s3", New) +} + +type fs struct { + client *s3.Client +} + +// New creates a new S3 filesystem using AWS default configuration sources. +func New(ctx context.Context) filesystem.Interface { + cfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + panic(fmt.Sprintf("error loading AWS config: %v", err)) + } + + client := s3.NewFromConfig(cfg) + return &fs{client: client} +} + +// Close closes the filesystem. +func (f *fs) Close() error { + return nil +} + +// List returns a slice of the files in the filesystem that match the glob pattern. 
+func (f *fs) List(ctx context.Context, glob string) ([]string, error) { + bucket, keyPattern, err := parseUri(glob) + if err != nil { + return nil, fmt.Errorf("error parsing S3 uri: %v", err) + } + + keys, err := f.listObjectKeys(ctx, bucket, keyPattern) + if err != nil { + return nil, fmt.Errorf("error listing object keys: %v", err) + } + + uris := make([]string, len(keys)) + for i, key := range keys { + uris[i] = makeUri(bucket, key) + } + + return uris, nil +} + +// listObjectKeys returns a slice of the keys in the bucket that match the key pattern. +func (f *fs) listObjectKeys( + ctx context.Context, + bucket string, + keyPattern string, +) ([]string, error) { + prefix := getPrefix(keyPattern) + params := &s3.ListObjectsV2Input{ + Bucket: aws.String(bucket), + Prefix: aws.String(prefix), + } + paginator := s3.NewListObjectsV2Paginator(f.client, params) + + var objects []string + for paginator.HasMorePages() { + output, err := paginator.NextPage(ctx) + if err != nil { + return nil, fmt.Errorf("error retrieving page: %v", err) + } + + for _, object := range output.Contents { + key := aws.ToString(object.Key) + match, err := filepath.Match(keyPattern, key) + if err != nil { + return nil, fmt.Errorf("invalid key pattern: %s", keyPattern) + } + + if match { + objects = append(objects, key) + } + } + } + + return objects, nil +} + +// OpenRead returns a new io.ReadCloser to read contents from the file. The caller must call Close +// on the returned io.ReadCloser when done reading. 
+func (f *fs) OpenRead(ctx context.Context, filename string) (io.ReadCloser, error) { + bucket, key, err := parseUri(filename) + if err != nil { + return nil, fmt.Errorf("error parsing S3 uri %s: %v", filename, err) + } + + params := &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + } + output, err := f.client.GetObject(ctx, params) + if err != nil { + return nil, fmt.Errorf("error getting object %s: %v", filename, err) + } + + return output.Body, nil +} + +// OpenWrite returns a new io.WriteCloser to write contents to the file. The caller must call Close +// on the returned io.WriteCloser when done writing. +func (f *fs) OpenWrite(ctx context.Context, filename string) (io.WriteCloser, error) { + bucket, key, err := parseUri(filename) + if err != nil { + return nil, fmt.Errorf("error parsing S3 uri %s: %v", filename, err) + } + + return newWriter(ctx, f.client, bucket, key), nil +} + +// Size returns the size of the file. +func (f *fs) Size(ctx context.Context, filename string) (int64, error) { + bucket, key, err := parseUri(filename) + if err != nil { + return -1, fmt.Errorf("error parsing S3 uri %s: %v", filename, err) + } + + params := &s3.HeadObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + } + output, err := f.client.HeadObject(ctx, params) + if err != nil { + return -1, fmt.Errorf("error getting metadata for object %s: %v", filename, err) + } + + return output.ContentLength, err +} + +// Remove removes the file from the filesystem. 
+func (f *fs) Remove(ctx context.Context, filename string) error { + bucket, key, err := parseUri(filename) + if err != nil { + return fmt.Errorf("error parsing S3 uri %s: %v", filename, err) + } + + params := &s3.DeleteObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + } + if _, err = f.client.DeleteObject(ctx, params); err != nil { + return fmt.Errorf("error deleting object %s: %v", filename, err) + } + + return nil +} + +// Copy copies the file from the old path to the new path. +func (f *fs) Copy(ctx context.Context, oldpath, newpath string) error { + sourceBucket, sourceKey, err := parseUri(oldpath) + if err != nil { + return fmt.Errorf("error parsing S3 source uri %s: %v", oldpath, err) + } + + copySource := fmt.Sprintf("%s/%s", sourceBucket, sourceKey) + destBucket, destKey, err := parseUri(newpath) + if err != nil { + return fmt.Errorf("error parsing S3 destination uri %s: %v", newpath, err) + } + + params := &s3.CopyObjectInput{ + Bucket: aws.String(destBucket), + CopySource: aws.String(copySource), + Key: aws.String(destKey), + } + if _, err = f.client.CopyObject(ctx, params); err != nil { + return fmt.Errorf("error copying object %s: %v", oldpath, err) + } + + return nil +} + +// Compile time check for interface implementations. +var ( + _ filesystem.Remover = (*fs)(nil) + _ filesystem.Copier = (*fs)(nil) +) diff --git a/sdks/go/pkg/beam/io/filesystem/s3/util.go b/sdks/go/pkg/beam/io/filesystem/s3/util.go new file mode 100644 index 000000000000..db076594dac5 --- /dev/null +++ b/sdks/go/pkg/beam/io/filesystem/s3/util.go @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
// parseUri deconstructs the S3 uri in the format 's3://bucket/key' to (bucket, key).
//
// It returns an error when the uri cannot be parsed, when the scheme is not
// 's3', or when the bucket (host) part is empty. A uri without a path yields
// an empty key.
//
// NOTE(review): url.Parse treats '?' and '#' as query/fragment delimiters and
// percent-decodes the path, so keys containing those characters are presumably
// not returned verbatim - confirm whether such keys need to be supported.
func parseUri(uri string) (string, string, error) {
	parsed, err := url.Parse(uri)
	if err != nil {
		return "", "", err
	}

	if parsed.Scheme != "s3" {
		return "", "", errors.New("scheme must be 's3'")
	}

	bucket := parsed.Host
	if bucket == "" {
		return "", "", errors.New("bucket must not be empty")
	}

	// With a non-empty host the path is either empty or starts with '/';
	// drop that separator to obtain the object key.
	key := strings.TrimPrefix(parsed.Path, "/")

	return bucket, key, nil
}

// makeUri constructs an S3 uri from the bucket and key to the format 's3://bucket/key'.
func makeUri(bucket string, key string) string {
	return fmt.Sprintf("s3://%s/%s", bucket, key)
}

// getPrefix returns the literal prefix of the key pattern, i.e. everything
// before the first glob metacharacter. A pattern without metacharacters is
// returned unchanged.
//
// The cut has to happen at every character that filepath.Match treats
// specially ('*', '?', '[' and the escape '\\'), not only at '*'. Otherwise a
// pattern such as "dir/[ab].json" would be used as a literal listing prefix
// and match no objects. Cutting early is always safe: it only widens the
// listing, and callers still filter each key against the full pattern.
func getPrefix(keyPattern string) string {
	if index := strings.IndexAny(keyPattern, "*?[\\"); index >= 0 {
		return keyPattern[:index]
	}
	return keyPattern
}
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package s3 + +import "testing" + +func Test_parseUri(t *testing.T) { + tests := []struct { + name string + uri string + wantBucket string + wantKey string + wantErr bool + }{ + { + name: "Valid uri with non-empty key", + uri: "s3://bucket/path/to/key", + wantBucket: "bucket", + wantKey: "path/to/key", + wantErr: false, + }, + { + name: "Valid uri with empty key", + uri: "s3://bucket", + wantBucket: "bucket", + wantKey: "", + wantErr: false, + }, + { + name: "Invalid uri: missing scheme", + uri: "bucket/path/to/key", + wantBucket: "", + wantKey: "", + wantErr: true, + }, + { + name: "Invalid uri: wrong scheme", + uri: "file://bucket/path/to/key", + wantBucket: "", + wantKey: "", + wantErr: true, + }, + { + name: "Invalid uri: missing bucket", + uri: "s3://", + wantBucket: "", + wantKey: "", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotBucket, gotKey, err := parseUri(tt.uri) + if (err != nil) != tt.wantErr { + t.Errorf("parseUri() err = %v, want %v", err, tt.wantErr) + } + if gotBucket != tt.wantBucket { + t.Errorf("parseUri() bucket = %v, want %v", gotBucket, tt.wantBucket) + } + if gotKey != tt.wantKey { + t.Errorf("parseUri() key = %v, want %v", gotKey, tt.wantKey) + } + }) + } +} + +func Test_makeUri(t *testing.T) { + bucket := "bucket" + key := "path/to/key" + want := "s3://bucket/path/to/key" + + if got := makeUri(bucket, key); 
got != want { + t.Errorf("makeUri() = %v, want %v", got, want) + } +} + +func Test_getPrefix(t *testing.T) { + tests := []struct { + name string + keyPattern string + want string + }{ + { + name: "Key pattern with wildcards", + keyPattern: "path/**/*.json", + want: "path/", + }, + { + name: "Key pattern without wildcards", + keyPattern: "path/file.json", + want: "path/file.json", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getPrefix(tt.keyPattern); got != tt.want { + t.Errorf("getPrefix() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sdks/go/pkg/beam/io/filesystem/s3/writer.go b/sdks/go/pkg/beam/io/filesystem/s3/writer.go new file mode 100644 index 000000000000..3f16e770d1d5 --- /dev/null +++ b/sdks/go/pkg/beam/io/filesystem/s3/writer.go @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package s3 + +import ( + "context" + "io" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +type writer struct { + ctx context.Context + client *s3.Client + bucket string + key string + done chan struct{} + isOpened bool + pw *io.PipeWriter + err error +} + +// newWriter returns a writer that creates and writes to an S3 object. If an object with the same +// bucket and key already exists, it will be overwritten. The caller must call Close on the writer +// when done writing for the object to become available. +func newWriter( + ctx context.Context, + client *s3.Client, + bucket string, + key string, +) *writer { + return &writer{ + ctx: ctx, + client: client, + bucket: bucket, + key: key, + done: make(chan struct{}), + } +} + +// Write writes data to a pipe. +func (w *writer) Write(p []byte) (int, error) { + if !w.isOpened { + w.open() + } + + return w.pw.Write(p) +} + +// Close completes the write operation. +func (w *writer) Close() error { + if !w.isOpened { + w.open() + } + + if err := w.pw.Close(); err != nil { + return err + } + + <-w.done + return w.err +} + +// open creates a pipe for writing to the S3 object. +func (w *writer) open() { + pr, pw := io.Pipe() + w.pw = pw + + go func() { + defer close(w.done) + + params := &s3.PutObjectInput{ + Bucket: aws.String(w.bucket), + Key: aws.String(w.key), + Body: io.Reader(pr), + } + uploader := manager.NewUploader(w.client) + if _, err := uploader.Upload(w.ctx, params); err != nil { + w.err = err + pr.CloseWithError(err) + } + }() + + w.isOpened = true +}