From fa8776ae648c402a55d4cad500a9b0a86608aa60 Mon Sep 17 00:00:00 2001 From: Cody Littley <56973212+cody-littley@users.noreply.github.com> Date: Fri, 1 Nov 2024 16:27:05 -0500 Subject: [PATCH] Revert "Revert "S3 relay interface"" (#853) Signed-off-by: Cody Littley --- common/aws/cli.go | 102 ++++++++-- common/aws/s3/client.go | 229 ++++++++++++++++++++--- common/aws/s3/fragment.go | 128 +++++++++++++ common/aws/s3/fragment_test.go | 333 +++++++++++++++++++++++++++++++++ common/aws/s3/s3.go | 43 ++++- common/aws/test/client_test.go | 176 +++++++++++++++++ common/mock/s3_client.go | 31 ++- 7 files changed, 1001 insertions(+), 41 deletions(-) create mode 100644 common/aws/s3/fragment.go create mode 100644 common/aws/s3/fragment_test.go create mode 100644 common/aws/test/client_test.go diff --git a/common/aws/cli.go b/common/aws/cli.go index 5a6d11503..e88618d45 100644 --- a/common/aws/cli.go +++ b/common/aws/cli.go @@ -3,20 +3,48 @@ package aws import ( "github.com/Layr-Labs/eigenda/common" "github.com/urfave/cli" + "time" ) var ( - RegionFlagName = "aws.region" - AccessKeyIdFlagName = "aws.access-key-id" - SecretAccessKeyFlagName = "aws.secret-access-key" - EndpointURLFlagName = "aws.endpoint-url" + RegionFlagName = "aws.region" + AccessKeyIdFlagName = "aws.access-key-id" + SecretAccessKeyFlagName = "aws.secret-access-key" + EndpointURLFlagName = "aws.endpoint-url" + FragmentPrefixCharsFlagName = "aws.fragment-prefix-chars" + FragmentParallelismFactorFlagName = "aws.fragment-parallelism-factor" + FragmentParallelismConstantFlagName = "aws.fragment-parallelism-constant" + FragmentReadTimeoutFlagName = "aws.fragment-read-timeout" + FragmentWriteTimeoutFlagName = "aws.fragment-write-timeout" ) type ClientConfig struct { - Region string - AccessKey string + // Region is the region to use when interacting with S3. Default is "us-east-2". + Region string + // AccessKey to use when interacting with S3. + AccessKey string + // SecretAccessKey to use when interacting with S3. SecretAccessKey string - EndpointURL string + // EndpointURL of the S3 endpoint to use. If this is not set then the default AWS S3 endpoint will be used. + EndpointURL string + + // FragmentPrefixChars is the number of characters of the key to use as the prefix for fragmented files. + // A value of "3" for the key "ABCDEFG" will result in the prefix "ABC". Default is 3. + FragmentPrefixChars int + // FragmentParallelismFactor helps determine the size of the pool of workers to help upload/download files. + // A non-zero value for this parameter adds a number of workers equal to the number of cores times this value. + // Default is 8. In general, the number of workers here can be a lot larger than the number of cores because the + // workers will be blocked on I/O most of the time. + FragmentParallelismFactor int + // FragmentParallelismConstant helps determine the size of the pool of workers to help upload/download files. + // A non-zero value for this parameter adds a constant number of workers. Default is 0. + FragmentParallelismConstant int + // FragmentReadTimeout is used to bound the maximum time to wait for a single fragmented read. + // Default is 30 seconds. + FragmentReadTimeout time.Duration + // FragmentWriteTimeout is used to bound the maximum time to wait for a single fragmented write. + // Default is 30 seconds. + FragmentWriteTimeout time.Duration } func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag { @@ -48,14 +76,66 @@ func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag { Value: "", EnvVar: common.PrefixEnvVar(envPrefix, "AWS_ENDPOINT_URL"), }, + cli.IntFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName), + Usage: "The number of characters of the key to use as the prefix for fragmented files", + Required: false, + Value: 3, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PREFIX_CHARS"), + }, + cli.IntFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), + Usage: "Add this many threads times the number of cores to the worker pool", + Required: false, + Value: 8, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_FACTOR"), + }, + cli.IntFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName), + Usage: "Add this many threads to the worker pool", + Required: false, + Value: 0, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_CONSTANT"), + }, + cli.DurationFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName), + Usage: "The maximum time to wait for a single fragmented read", + Required: false, + Value: 30 * time.Second, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_READ_TIMEOUT"), + }, + cli.DurationFlag{ + Name: common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName), + Usage: "The maximum time to wait for a single fragmented write", + Required: false, + Value: 30 * time.Second, + EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_WRITE_TIMEOUT"), + }, } } func ReadClientConfig(ctx *cli.Context, flagPrefix string) ClientConfig { return ClientConfig{ - Region: ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)), - AccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)), - SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)), - EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)), + Region: ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)), + AccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)), + SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)), + EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)), + FragmentPrefixChars: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName)), + FragmentParallelismFactor: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName)), + FragmentParallelismConstant: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName)), + FragmentReadTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName)), + FragmentWriteTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName)), + } +} + +// DefaultClientConfig returns a new ClientConfig with default values. +func DefaultClientConfig() *ClientConfig { + return &ClientConfig{ + Region: "us-east-2", + FragmentPrefixChars: 3, + FragmentParallelismFactor: 8, + FragmentParallelismConstant: 0, + FragmentReadTimeout: 30 * time.Second, + FragmentWriteTimeout: 30 * time.Second, } } diff --git a/common/aws/s3/client.go b/common/aws/s3/client.go index 231d546ae..ddc3ce4e1 100644 --- a/common/aws/s3/client.go +++ b/common/aws/s3/client.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "errors" + "golang.org/x/sync/errgroup" + "runtime" "sync" commonaws "github.com/Layr-Labs/eigenda/common/aws" @@ -27,7 +29,9 @@ type Object struct { } type client struct { + cfg *commonaws.ClientConfig s3Client *s3.Client + pool *errgroup.Group logger logging.Logger } @@ -36,18 +40,19 @@ var _ Client = (*client)(nil) func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.Logger) (*client, error) { var err error once.Do(func() { - customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) { - if cfg.EndpointURL != "" { - return aws.Endpoint{ - PartitionID: "aws", - URL: cfg.EndpointURL, - SigningRegion: cfg.Region, - }, nil - } - - // returning EndpointNotFoundError will allow the service to fallback to its default resolution - return aws.Endpoint{}, &aws.EndpointNotFoundError{} - }) + customResolver := aws.EndpointResolverWithOptionsFunc( + func(service, region string, options ...interface{}) (aws.Endpoint, error) { + if cfg.EndpointURL != "" { + return aws.Endpoint{ + PartitionID: "aws", + URL: cfg.EndpointURL, + SigningRegion: cfg.Region, + }, nil + } + + // returning EndpointNotFoundError will allow the service to fallback to its default resolution + return aws.Endpoint{}, &aws.EndpointNotFoundError{} + }) options := [](func(*config.LoadOptions) error){ config.WithRegion(cfg.Region), @@ -56,7 +61,9 @@ func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.L } // If access key and secret access key are not provided, use the default credential provider if len(cfg.AccessKey) > 0 && len(cfg.SecretAccessKey) > 0 { - options = append(options, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, ""))) + options = append(options, + config.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, ""))) } awsConfig, errCfg := config.LoadDefaultConfig(context.Background(), options...) @@ -64,23 +71,34 @@ func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.L err = errCfg return } + s3Client := s3.NewFromConfig(awsConfig, func(o *s3.Options) { o.UsePathStyle = true }) - ref = &client{s3Client: s3Client, logger: logger.With("component", "S3Client")} - }) - return ref, err -} -func (s *client) CreateBucket(ctx context.Context, bucket string) error { - _, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{ - Bucket: aws.String(bucket), - }) - if err != nil { - return err - } + workers := 0 + if cfg.FragmentParallelismConstant > 0 { + workers = cfg.FragmentParallelismConstant + } + if cfg.FragmentParallelismFactor > 0 { + workers = cfg.FragmentParallelismFactor * runtime.NumCPU() + } - return nil + if workers == 0 { + workers = 1 + } + + pool, _ := errgroup.WithContext(ctx) + pool.SetLimit(workers) + + ref = &client{ + cfg: &cfg, + s3Client: s3Client, + pool: pool, + logger: logger.With("component", "S3Client"), + } + }) + return ref, err } func (s *client) DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) { @@ -159,3 +177,164 @@ func (s *client) ListObjects(ctx context.Context, bucket string, prefix string) } return objects, nil } + +func (s *client) CreateBucket(ctx context.Context, bucket string) error { + _, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{ + Bucket: aws.String(bucket), + }) + if err != nil { + return err + } + + return nil +} + +func (s *client) FragmentedUploadObject( + ctx context.Context, + bucket string, + key string, + data []byte, + fragmentSize int) error { + + fragments, err := breakIntoFragments(key, data, s.cfg.FragmentPrefixChars, fragmentSize) + if err != nil { + return err + } + resultChannel := make(chan error, len(fragments)) + + ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout) + defer cancel() + + for _, fragment := range fragments { + fragmentCapture := fragment + s.pool.Go(func() error { + s.fragmentedWriteTask(ctx, resultChannel, fragmentCapture, bucket) + return nil + }) + } + + for range fragments { + err := <-resultChannel + if err != nil { + return err + } + } + return ctx.Err() + +} + +// fragmentedWriteTask writes a single file to S3. +func (s *client) fragmentedWriteTask( + ctx context.Context, + resultChannel chan error, + fragment *Fragment, + bucket string) { + + _, err := s.s3Client.PutObject(ctx, + &s3.PutObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(fragment.FragmentKey), + Body: bytes.NewReader(fragment.Data), + }) + + resultChannel <- err +} + +func (s *client) FragmentedDownloadObject( + ctx context.Context, + bucket string, + key string, + fileSize int, + fragmentSize int) ([]byte, error) { + + if fragmentSize <= 0 { + return nil, errors.New("fragmentSize must be greater than 0") + } + + fragmentKeys, err := getFragmentKeys(key, s.cfg.FragmentPrefixChars, getFragmentCount(fileSize, fragmentSize)) + if err != nil { + return nil, err + } + resultChannel := make(chan *readResult, len(fragmentKeys)) + + ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout) + defer cancel() + + for i, fragmentKey := range fragmentKeys { + boundFragmentKey := fragmentKey + boundI := i + s.pool.Go(func() error { + s.readTask(ctx, resultChannel, bucket, boundFragmentKey, boundI) + return nil + }) + } + + fragments := make([]*Fragment, len(fragmentKeys)) + for i := 0; i < len(fragmentKeys); i++ { + result := <-resultChannel + if result.err != nil { + return nil, result.err + } + fragments[result.fragment.Index] = result.fragment + } + + if ctx.Err() != nil { + return nil, ctx.Err() + } + + return recombineFragments(fragments) + +} + +// readResult is the result of a read task. +type readResult struct { + fragment *Fragment + err error +} + +// readTask reads a single file from S3. +func (s *client) readTask( + ctx context.Context, + resultChannel chan *readResult, + bucket string, + key string, + index int) { + + result := &readResult{} + defer func() { + resultChannel <- result + }() + + ret, err := s.s3Client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }) + + if err != nil { + result.err = err + return + } + + data := make([]byte, *ret.ContentLength) + bytesRead := 0 + + for bytesRead < len(data) && ctx.Err() == nil { + count, err := ret.Body.Read(data[bytesRead:]) + if err != nil && err.Error() != "EOF" { + result.err = err + return + } + bytesRead += count + } + + result.fragment = &Fragment{ + FragmentKey: key, + Data: data, + Index: index, + } + + err = ret.Body.Close() + if err != nil { + result.err = err + } +} diff --git a/common/aws/s3/fragment.go b/common/aws/s3/fragment.go new file mode 100644 index 000000000..21da697d9 --- /dev/null +++ b/common/aws/s3/fragment.go @@ -0,0 +1,128 @@ +package s3 + +import ( + "fmt" + "sort" + "strings" +) + +// getFragmentCount returns the number of fragments that a file of the given size will be broken into. +func getFragmentCount(fileSize int, fragmentSize int) int { + if fileSize < fragmentSize { + return 1 + } else if fileSize%fragmentSize == 0 { + return fileSize / fragmentSize + } else { + return fileSize/fragmentSize + 1 + } +} + +// getFragmentKey returns the key for the fragment at the given index. +// +// Fragment keys take the form of "prefix/body-index[f]". The prefix is the first prefixLength characters +// of the file key. The body is the file key. The index is the index of the fragment. The character "f" is appended +// to the key of the last fragment in the series. +// +// Example: fileKey="abc123", prefixLength=2, fragmentCount=3 +// The keys will be "ab/abc123-0", "ab/abc123-1", "ab/abc123-2f" +func getFragmentKey(fileKey string, prefixLength int, fragmentCount int, index int) (string, error) { + var prefix string + if prefixLength > len(fileKey) { + prefix = fileKey + } else { + prefix = fileKey[:prefixLength] + } + + postfix := "" + if fragmentCount-1 == index { + postfix = "f" + } + + if index >= fragmentCount { + return "", fmt.Errorf("index %d is too high for fragment count %d", index, fragmentCount) + } + + return fmt.Sprintf("%s/%s-%d%s", prefix, fileKey, index, postfix), nil +} + +// Fragment is a subset of a file. +type Fragment struct { + FragmentKey string + Data []byte + Index int +} + +// breakIntoFragments breaks a file into fragments of the given size. +func breakIntoFragments(fileKey string, data []byte, prefixLength int, fragmentSize int) ([]*Fragment, error) { + fragmentCount := getFragmentCount(len(data), fragmentSize) + fragments := make([]*Fragment, fragmentCount) + for i := 0; i < fragmentCount; i++ { + start := i * fragmentSize + end := start + fragmentSize + if end > len(data) { + end = len(data) + } + + fragmentKey, err := getFragmentKey(fileKey, prefixLength, fragmentCount, i) + if err != nil { + return nil, err + } + fragments[i] = &Fragment{ + FragmentKey: fragmentKey, + Data: data[start:end], + Index: i, + } + } + return fragments, nil +} + +// getFragmentKeys returns the keys for all fragments of a file. +func getFragmentKeys(fileKey string, prefixLength int, fragmentCount int) ([]string, error) { + keys := make([]string, fragmentCount) + for i := 0; i < fragmentCount; i++ { + fragmentKey, err := getFragmentKey(fileKey, prefixLength, fragmentCount, i) + if err != nil { + return nil, err + } + keys[i] = fragmentKey + } + return keys, nil +} + +// recombineFragments recombines fragments into a single file. +// Returns an error if any fragments are missing. +func recombineFragments(fragments []*Fragment) ([]byte, error) { + + if len(fragments) == 0 { + return nil, fmt.Errorf("no fragments") + } + + // Sort the fragments by index + sort.Slice(fragments, func(i, j int) bool { + return fragments[i].Index < fragments[j].Index + }) + + // Make sure there aren't any gaps in the fragment indices + dataSize := 0 + for i, fragment := range fragments { + if fragment.Index != i { + return nil, fmt.Errorf("missing fragment with index %d", i) + } + dataSize += len(fragment.Data) + } + + // Make sure we have the last fragment + if !strings.HasSuffix(fragments[len(fragments)-1].FragmentKey, "f") { + return nil, fmt.Errorf("missing final fragment") + } + + fragmentSize := len(fragments[0].Data) + + // Concatenate the data + result := make([]byte, dataSize) + for _, fragment := range fragments { + copy(result[fragment.Index*fragmentSize:], fragment.Data) + } + + return result, nil +} diff --git a/common/aws/s3/fragment_test.go b/common/aws/s3/fragment_test.go new file mode 100644 index 000000000..04271ce8e --- /dev/null +++ b/common/aws/s3/fragment_test.go @@ -0,0 +1,333 @@ +package s3 + +import ( + "fmt" + tu "github.com/Layr-Labs/eigenda/common/testutils" + "github.com/stretchr/testify/assert" + "math/rand" + "strings" + "testing" +) + +func TestGetFragmentCount(t *testing.T) { + tu.InitializeRandom() + + // Test a file smaller than a fragment + fileSize := rand.Intn(100) + 100 + fragmentSize := fileSize * 2 + fragmentCount := getFragmentCount(fileSize, fragmentSize) + assert.Equal(t, 1, fragmentCount) + + // Test a file that can fit in a single fragment + fileSize = rand.Intn(100) + 100 + fragmentSize = fileSize + fragmentCount = getFragmentCount(fileSize, fragmentSize) + assert.Equal(t, 1, fragmentCount) + + // Test a file that is one byte larger than a fragment + fileSize = rand.Intn(100) + 100 + fragmentSize = fileSize - 1 + fragmentCount = getFragmentCount(fileSize, fragmentSize) + assert.Equal(t, 2, fragmentCount) + + // Test a file that is one less than a multiple of the fragment size + fragmentSize = rand.Intn(100) + 100 + expectedFragmentCount := rand.Intn(10) + 1 + fileSize = fragmentSize*expectedFragmentCount - 1 + fragmentCount = getFragmentCount(fileSize, fragmentSize) + assert.Equal(t, expectedFragmentCount, fragmentCount) + + // Test a file that is a multiple of the fragment size + fragmentSize = rand.Intn(100) + 100 + expectedFragmentCount = rand.Intn(10) + 1 + fileSize = fragmentSize * expectedFragmentCount + fragmentCount = getFragmentCount(fileSize, fragmentSize) + assert.Equal(t, expectedFragmentCount, fragmentCount) + + // Test a file that is one more than a multiple of the fragment size + fragmentSize = rand.Intn(100) + 100 + expectedFragmentCount = rand.Intn(10) + 2 + fileSize = fragmentSize*(expectedFragmentCount-1) + 1 + fragmentCount = getFragmentCount(fileSize, fragmentSize) + assert.Equal(t, expectedFragmentCount, fragmentCount) +} + +// Fragment keys take the form of "prefix/body-index[f]". Verify the prefix part of the key. +func TestPrefix(t *testing.T) { + tu.InitializeRandom() + + keyLength := rand.Intn(10) + 10 + key := tu.RandomString(keyLength) + + for i := 0; i < keyLength*2; i++ { + fragmentCount := rand.Intn(10) + 10 + fragmentIndex := rand.Intn(fragmentCount) + fragmentKey, err := getFragmentKey(key, i, fragmentCount, fragmentIndex) + assert.NoError(t, err) + + parts := strings.Split(fragmentKey, "/") + assert.Equal(t, 2, len(parts)) + prefix := parts[0] + + if i >= keyLength { + assert.Equal(t, key, prefix) + } else { + assert.Equal(t, key[:i], prefix) + } + } +} + +// Fragment keys take the form of "prefix/body-index[f]". Verify the body part of the key. +func TestKeyBody(t *testing.T) { + tu.InitializeRandom() + + for i := 0; i < 10; i++ { + keyLength := rand.Intn(10) + 10 + key := tu.RandomString(keyLength) + fragmentCount := rand.Intn(10) + 10 + fragmentIndex := rand.Intn(fragmentCount) + fragmentKey, err := getFragmentKey(key, rand.Intn(10), fragmentCount, fragmentIndex) + assert.NoError(t, err) + + parts := strings.Split(fragmentKey, "/") + assert.Equal(t, 2, len(parts)) + parts = strings.Split(parts[1], "-") + assert.Equal(t, 2, len(parts)) + body := parts[0] + + assert.Equal(t, key, body) + } +} + +// Fragment keys take the form of "prefix/body-index[f]". Verify the index part of the key. +func TestKeyIndex(t *testing.T) { + tu.InitializeRandom() + + for i := 0; i < 10; i++ { + fragmentCount := rand.Intn(10) + 10 + index := rand.Intn(fragmentCount) + fragmentKey, err := getFragmentKey(tu.RandomString(10), rand.Intn(10), fragmentCount, index) + assert.NoError(t, err) + + parts := strings.Split(fragmentKey, "/") + assert.Equal(t, 2, len(parts)) + parts = strings.Split(parts[1], "-") + assert.Equal(t, 2, len(parts)) + indexStr := parts[1] + assert.True(t, strings.HasPrefix(indexStr, fmt.Sprintf("%d", index))) + } +} + +// Fragment keys take the form of "prefix/body-index[f]". +// Verify the postfix part of the key, which should be "f" for the last fragment. +func TestKeyPostfix(t *testing.T) { + tu.InitializeRandom() + + segmentCount := rand.Intn(10) + 10 + + for i := 0; i < segmentCount; i++ { + fragmentKey, err := getFragmentKey(tu.RandomString(10), rand.Intn(10), segmentCount, i) + assert.NoError(t, err) + + if i == segmentCount-1 { + assert.True(t, strings.HasSuffix(fragmentKey, "f")) + } else { + assert.False(t, strings.HasSuffix(fragmentKey, "f")) + } + } +} + +// TestExampleInGodoc tests the example provided in the documentation for getFragmentKey(). +// +// Example: fileKey="abc123", prefixLength=2, fragmentCount=3 +// The keys will be "ab/abc123-0", "ab/abc123-1", "ab/abc123-2f" +func TestExampleInGodoc(t *testing.T) { + fileKey := "abc123" + prefixLength := 2 + fragmentCount := 3 + fragmentKeys, err := getFragmentKeys(fileKey, prefixLength, fragmentCount) + assert.NoError(t, err) + assert.Equal(t, 3, len(fragmentKeys)) + assert.Equal(t, "ab/abc123-0", fragmentKeys[0]) + assert.Equal(t, "ab/abc123-1", fragmentKeys[1]) + assert.Equal(t, "ab/abc123-2f", fragmentKeys[2]) +} + +func TestGetFragmentKeys(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + prefixLength := rand.Intn(3) + 1 + fragmentCount := rand.Intn(10) + 10 + + fragmentKeys, err := getFragmentKeys(fileKey, prefixLength, fragmentCount) + assert.NoError(t, err) + assert.Equal(t, fragmentCount, len(fragmentKeys)) + + for i := 0; i < fragmentCount; i++ { + expectedKey, err := getFragmentKey(fileKey, prefixLength, fragmentCount, i) + assert.NoError(t, err) + assert.Equal(t, expectedKey, fragmentKeys[i]) + + parts := strings.Split(fragmentKeys[i], "/") + assert.Equal(t, 2, len(parts)) + parsedPrefix := parts[0] + assert.Equal(t, fileKey[:prefixLength], parsedPrefix) + parts = strings.Split(parts[1], "-") + assert.Equal(t, 2, len(parts)) + parsedKey := parts[0] + assert.Equal(t, fileKey, parsedKey) + index := parts[1] + + if i == fragmentCount-1 { + assert.Equal(t, fmt.Sprintf("%d", i)+"f", index) + } else { + assert.Equal(t, fmt.Sprintf("%d", i), index) + } + } +} + +func TestGetFragments(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + data := tu.RandomBytes(1000) + prefixLength := rand.Intn(3) + 1 + fragmentSize := rand.Intn(100) + 100 + + fragments, err := breakIntoFragments(fileKey, data, prefixLength, fragmentSize) + assert.NoError(t, err) + assert.Equal(t, getFragmentCount(len(data), fragmentSize), len(fragments)) + + totalSize := 0 + + for i, fragment := range fragments { + fragmentKey, err := getFragmentKey(fileKey, prefixLength, len(fragments), i) + assert.NoError(t, err) + assert.Equal(t, fragmentKey, fragment.FragmentKey) + + start := i * fragmentSize + end := start + fragmentSize + if end > len(data) { + end = len(data) + } + assert.Equal(t, data[start:end], fragment.Data) + assert.Equal(t, i, fragment.Index) + totalSize += len(fragment.Data) + } + + assert.Equal(t, len(data), totalSize) +} + +func TestGetFragmentsSmallFile(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + data := tu.RandomBytes(10) + prefixLength := rand.Intn(3) + 1 + fragmentSize := rand.Intn(100) + 100 + + fragments, err := breakIntoFragments(fileKey, data, prefixLength, fragmentSize) + assert.NoError(t, err) + assert.Equal(t, 1, len(fragments)) + + fragmentKey, err := getFragmentKey(fileKey, prefixLength, 1, 0) + assert.NoError(t, err) + assert.Equal(t, fragmentKey, fragments[0].FragmentKey) + assert.Equal(t, data, fragments[0].Data) + assert.Equal(t, 0, fragments[0].Index) +} + +func TestGetFragmentsExactlyOnePerfectlySizedFile(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + fragmentSize := rand.Intn(100) + 100 + data := tu.RandomBytes(fragmentSize) + prefixLength := rand.Intn(3) + 1 + + fragments, err := breakIntoFragments(fileKey, data, prefixLength, fragmentSize) + assert.NoError(t, err) + assert.Equal(t, 1, len(fragments)) + + fragmentKey, err := getFragmentKey(fileKey, prefixLength, 1, 0) + assert.NoError(t, err) + assert.Equal(t, fragmentKey, fragments[0].FragmentKey) + assert.Equal(t, data, fragments[0].Data) + assert.Equal(t, 0, fragments[0].Index) +} + +func TestRecombineFragments(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + data := tu.RandomBytes(1000) + prefixLength := rand.Intn(3) + 1 + fragmentSize := rand.Intn(100) + 100 + + fragments, err := breakIntoFragments(fileKey, data, prefixLength, fragmentSize) + assert.NoError(t, err) + recombinedData, err := recombineFragments(fragments) + assert.NoError(t, err) + assert.Equal(t, data, recombinedData) + + // Shuffle the fragments + for i := range fragments { + j := rand.Intn(i + 1) + fragments[i], fragments[j] = fragments[j], fragments[i] + } + + recombinedData, err = recombineFragments(fragments) + assert.NoError(t, err) + assert.Equal(t, data, recombinedData) +} + +func TestRecombineFragmentsSmallFile(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + data := tu.RandomBytes(10) + prefixLength := rand.Intn(3) + 1 + fragmentSize := rand.Intn(100) + 100 + + fragments, err := breakIntoFragments(fileKey, data, prefixLength, fragmentSize) + assert.NoError(t, err) + assert.Equal(t, 1, len(fragments)) + recombinedData, err := recombineFragments(fragments) + assert.NoError(t, err) + assert.Equal(t, data, recombinedData) +} + +func TestMissingFragment(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + data := tu.RandomBytes(1000) + prefixLength := rand.Intn(3) + 1 + fragmentSize := rand.Intn(100) + 100 + + fragments, err := breakIntoFragments(fileKey, data, prefixLength, fragmentSize) + assert.NoError(t, err) + + fragmentIndexToSkip := rand.Intn(len(fragments)) + fragments = append(fragments[:fragmentIndexToSkip], fragments[fragmentIndexToSkip+1:]...) + + _, err = recombineFragments(fragments[:len(fragments)-1]) + assert.Error(t, err) +} + +func TestMissingFinalFragment(t *testing.T) { + tu.InitializeRandom() + + fileKey := tu.RandomString(10) + data := tu.RandomBytes(1000) + prefixLength := rand.Intn(3) + 1 + fragmentSize := rand.Intn(100) + 100 + + fragments, err := breakIntoFragments(fileKey, data, prefixLength, fragmentSize) + assert.NoError(t, err) + fragments = fragments[:len(fragments)-1] + + _, err = recombineFragments(fragments) + assert.Error(t, err) +} diff --git a/common/aws/s3/s3.go b/common/aws/s3/s3.go index 475f68c94..d96172dbc 100644 --- a/common/aws/s3/s3.go +++ b/common/aws/s3/s3.go @@ -2,10 +2,51 @@ package s3 import "context" +// Client encapsulates the functionality of an S3 client. type Client interface { - CreateBucket(ctx context.Context, bucket string) error + + // DownloadObject downloads an object from S3. DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) + + // UploadObject uploads an object to S3. UploadObject(ctx context.Context, bucket string, key string, data []byte) error + + // DeleteObject deletes an object from S3. DeleteObject(ctx context.Context, bucket string, key string) error + + // ListObjects lists all objects in a bucket with the given prefix. Note that this method may return + // file fragments if the bucket contains files uploaded via FragmentedUploadObject. ListObjects(ctx context.Context, bucket string, prefix string) ([]Object, error) + + // CreateBucket creates a bucket in S3. + CreateBucket(ctx context.Context, bucket string) error + + // FragmentedUploadObject uploads a file to S3. The fragmentSize parameter specifies the maximum size of each + // file uploaded to S3. If the file is larger than fragmentSize then it will be broken into + // smaller parts and uploaded in parallel. The file will be reassembled on download. + // + // Note: if a file is uploaded with this method, only the FragmentedDownloadObject method should be used to + // download the file. It is not advised to use DeleteObject on files uploaded with this method (if such + // functionality is required, a new method to do so should be added to this interface). + // + // Note: if this operation fails partway through, some file fragments may have made it to S3 and others may not. + // In order to prevent long term accumulation of fragments, it is suggested to use this method in conjunction with + // a bucket configured to have a TTL. + FragmentedUploadObject( + ctx context.Context, + bucket string, + key string, + data []byte, + fragmentSize int) error + + // FragmentedDownloadObject downloads a file from S3, as written by Upload. The fileSize (in bytes) and fragmentSize + // must be the same as the values used in the FragmentedUploadObject call. + // + // Note: this method can only be used to download files that were uploaded with the FragmentedUploadObject method. + FragmentedDownloadObject( + ctx context.Context, + bucket string, + key string, + fileSize int, + fragmentSize int) ([]byte, error) } diff --git a/common/aws/test/client_test.go b/common/aws/test/client_test.go new file mode 100644 index 000000000..0f5bc4087 --- /dev/null +++ b/common/aws/test/client_test.go @@ -0,0 +1,176 @@ +package test + +import ( + "context" + "github.com/Layr-Labs/eigenda/common" + "github.com/Layr-Labs/eigenda/common/aws" + "github.com/Layr-Labs/eigenda/common/aws/s3" + "github.com/Layr-Labs/eigenda/common/mock" + tu "github.com/Layr-Labs/eigenda/common/testutils" + "github.com/Layr-Labs/eigenda/inabox/deploy" + "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/assert" + "math/rand" + "os" + "testing" +) + +var ( + dockertestPool *dockertest.Pool + dockertestResource *dockertest.Resource +) + +const ( + localstackPort = "4570" + localstackHost = "http://0.0.0.0:4570" + bucket = "eigen-test" +) + +type clientBuilder struct { + // This method is called at the beginning of the test. + start func() error + // This method is called to build a new client. + build func() (s3.Client, error) + // This method is called at the end of the test when all operations are done. + finish func() error +} + +var clientBuilders = []*clientBuilder{ + { + start: func() error { + return nil + }, + build: func() (s3.Client, error) { + return mock.NewS3Client(), nil + }, + finish: func() error { + return nil + }, + }, + { + start: func() error { + return setupLocalstack() + }, + build: func() (s3.Client, error) { + + logger, err := common.NewLogger(common.DefaultLoggerConfig()) + if err != nil { + return nil, err + } + + config := aws.DefaultClientConfig() + config.EndpointURL = localstackHost + config.Region = "us-east-1" + + err = os.Setenv("AWS_ACCESS_KEY_ID", "localstack") + if err != nil { + return nil, err + } + err = os.Setenv("AWS_SECRET_ACCESS_KEY", "localstack") + if err != nil { + return nil, err + } + + client, err := s3.NewClient(context.Background(), *config, logger) + if err != nil { + return nil, err + } + + err = client.CreateBucket(context.Background(), bucket) + if err != nil { + return nil, err + } + + return client, nil + }, + finish: func() error { + teardownLocalstack() + return nil + }, + }, +} + +func setupLocalstack() error { + deployLocalStack := !(os.Getenv("DEPLOY_LOCALSTACK") == "false") + + if deployLocalStack { + var err error + dockertestPool, dockertestResource, err = deploy.StartDockertestWithLocalstackContainer(localstackPort) + if err != nil && err.Error() == "container already exists" { + teardownLocalstack() + return err + } + } + return nil +} + +func teardownLocalstack() { + deployLocalStack := !(os.Getenv("DEPLOY_LOCALSTACK") == "false") + + if deployLocalStack { + deploy.PurgeDockertestResources(dockertestPool, dockertestResource) + } +} + +func RandomOperationsTest(t *testing.T, client s3.Client) { + numberToWrite := 100 + expectedData := make(map[string][]byte) + + fragmentSize := rand.Intn(1000) + 1000 + + for i := 0; i < numberToWrite; i++ { + key := tu.RandomString(10) + fragmentMultiple := rand.Float64() * 10 + dataSize := int(fragmentMultiple*float64(fragmentSize)) + 1 + data := tu.RandomBytes(dataSize) + expectedData[key] = data + + err := client.FragmentedUploadObject(context.Background(), bucket, key, data, fragmentSize) + assert.NoError(t, err) + } + + // Read back the data + for key, expected := range expectedData { + data, err := client.FragmentedDownloadObject(context.Background(), bucket, key, len(expected), fragmentSize) + assert.NoError(t, err) + assert.Equal(t, expected, data) + } +} + +func TestRandomOperations(t *testing.T) { + tu.InitializeRandom() + for _, builder := range clientBuilders { + err := builder.start() + assert.NoError(t, err) + + client, err := builder.build() + assert.NoError(t, err) + RandomOperationsTest(t, client) + + err = builder.finish() + assert.NoError(t, err) + } +} + +func ReadNonExistentValueTest(t *testing.T, client s3.Client) { + _, err := client.FragmentedDownloadObject(context.Background(), bucket, "nonexistent", 1000, 1000) + assert.Error(t, err) + randomKey := tu.RandomString(10) + _, err = client.FragmentedDownloadObject(context.Background(), bucket, randomKey, 0, 0) + assert.Error(t, err) +} + +func TestReadNonExistentValue(t *testing.T) { + tu.InitializeRandom() + for _, builder := range clientBuilders { + err := builder.start() + assert.NoError(t, err) + + client, err := builder.build() + assert.NoError(t, err) + ReadNonExistentValueTest(t, client) + + err = builder.finish() + assert.NoError(t, err) + } +} diff --git a/common/mock/s3_client.go b/common/mock/s3_client.go index d4e79645b..7f505d56a 100644 --- a/common/mock/s3_client.go +++ b/common/mock/s3_client.go @@ -17,10 +17,6 @@ func NewS3Client() *S3Client { return &S3Client{bucket: make(map[string][]byte)} } -func (s *S3Client) CreateBucket(ctx context.Context, bucket string) error { - return nil -} - func (s *S3Client) DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) { data, ok := s.bucket[key] if !ok { @@ -48,3 +44,30 @@ func (s *S3Client) ListObjects(ctx context.Context, bucket string, prefix string } return objects, nil } + +func (s *S3Client) CreateBucket(ctx context.Context, bucket string) error { + return nil +} + +func (s *S3Client) FragmentedUploadObject( + ctx context.Context, + bucket string, + key string, + data []byte, + fragmentSize int) error { + s.bucket[key] = data + return nil +} + +func (s *S3Client) FragmentedDownloadObject( + ctx context.Context, + bucket string, + key string, + fileSize int, + fragmentSize int) ([]byte, error) { + data, ok := s.bucket[key] + if !ok { + return []byte{}, s3.ErrObjectNotFound + } + return data, nil +}