From ad288fb4ded41218fca403a0ff45ccfe1e10d0c6 Mon Sep 17 00:00:00 2001 From: Tamer Sherif <69483382+tasherif-msft@users.noreply.github.com> Date: Thu, 27 Jul 2023 13:46:57 -0700 Subject: [PATCH] [AzDatalake] File Client Upload/Download Support (#21261) * Enable gocritic during linting (#20715) Enabled gocritic's evalOrder to catch dependencies on undefined behavior on return statements. Updated to latest version of golangci-lint. Fixed issue in azblob flagged by latest linter. * Cosmos DB: Enable merge support (#20716) * Adding header and value * Wiring and tests * format * Fixing value * change log * [azservicebus, azeventhubs] Stress test and logging improvement (#20710) Logging improvements: * Updating the logging to print more tracing information (per-link) in prep for the bigger release coming up. * Trimming out some of the verbose logging, seeing if I can get it a bit more reasonable. Stress tests: * Add a timestamp to the log name we generate and also default to append, not overwrite. * Use 0.5 cores, 0.5GB as our baseline. Some pods use more and I'll tune them more later. * update proxy version (#20712) Co-authored-by: Scott Beddall * Return an error when you try to send a message that's too large. (#20721) This now works just like the message batch - you'll get an ErrMessageTooLarge if you attempt to send a message that's too large for the link's configured size. NOTE: there's a patch to `internal/go-amqp/Sender.go` to match what's in go-amqp's main so it returns a programmatically useful error when the message is too large. Fixes #20647 * Changes in test that is failing in pipeline (#20693) * [azservicebus, azeventhubs] Treat 'entity full' as a fatal error (#20722) When the remote entity is full we get a resource-limit-exceeded condition. This isn't something we should keep retrying on and it's best to just abort and let the user know immediately, rather than hoping it might eventually clear out. This affected both Event Hubs and Service Bus. Fixes #20647 * [azservicebus/azeventhubs] Redirect stderr and stdout to tee (#20726) * Update changelog with latest features (#20730) * Update changelog with latest features Prepare for upcoming release. 
* bump minor version * pass along the artifact name so we can override it later (#20732) Co-authored-by: scbedd <45376673+scbedd@users.noreply.github.com> * [azeventhubs] Fixing checkpoint store race condition (#20727) The checkpoint store wasn't guarding against multiple owners claiming for the first time - fixing this by using IfNoneMatch Fixes #20717 * Fix azidentity troubleshooting guide link (#20736) * [Release] sdk/resourcemanager/paloaltonetworksngfw/armpanngfw/0.1.0 (#20437) * [Release] sdk/resourcemanager/paloaltonetworksngfw/armpanngfw/0.1.0 generation from spec commit: 85fb4ac6f8bfefd179e6c2632976a154b5c9ff04 * client factory * fix * fix * update * add sdk/resourcemanager/postgresql/armpostgresql live test (#20685) * add sdk/resourcemanager/postgresql/armpostgresql live test * update assets.json * set subscriptionId default value * format * add sdk/resourcemanager/eventhub/armeventhub live test (#20686) * add sdk/resourcemanager/eventhub/armeventhub live test * update assets * add sdk/resourcemanager/compute/armcompute live test (#20048) * add sdk/resourcemanager/compute/armcompute live test * skus filter * fix subscriptionId default value * fix * gofmt * update recording * sdk/resourcemanager/network/armnetwork live test (#20331) * sdk/resourcemanager/network/armnetwork live test * update subscriptionId default value * update recording * add sdk/resourcemanager/cosmos/armcosmos live test (#20705) * add sdk/resourcemanager/cosmos/armcosmos live test * update assets.json * update assets.json * update assets.json * update assets.json * Increment package version after release of azcore (#20740) * [azeventhubs] Improperly resetting etag in the checkpoint store (#20737) We shouldn't be resetting the etag to nil - it's what we use to enforce a "single winner" when doing ownership claims. The bug here was two-fold: I had bad logic in my previous claim ownership, which I fixed in a previous PR, but we need to reflect that same constraint properly in our in-memory checkpoint store for these tests. * Eng workflows sync and branch cleanup additions (#20743) Co-authored-by: James Suplizio * [azeventhubs] Latest start position can also be inclusive (ie, get the latest message) (#20744) * Update GitHubEventProcessor version and remove pull_request_review procesing (#20751) Co-authored-by: James Suplizio * Rename DisableAuthorityValidationAndInstanceDiscovery (#20746) * fix (#20707) * AzFile (#20739) * azfile: Fixing connection string parsing logic (#20798) * Fixing connection string parse logic * Update README * [azadmin] fix flaky test (#20758) * fix flaky test * charles suggestion * Prepare azidentity v1.3.0 for release (#20756) * Fix broken podman link (#20801) Co-authored-by: Wes Haggard * [azquery] update doc comments (#20755) * update doc comments * update statistics and visualization generation * prep-for-release * Fixed contribution section (#20752) Co-authored-by: Bob Tabor * [azeventhubs,azservicebus] Some API cleanup, renames (#20754) * Adding options to UpdateCheckpoint(), just for future potential expansion * Make Offset an int64, not a *int64 (it's not optional, it'll always come back with ReceivedEvents) * Adding more logging into the checkpoint store. * Point all imports at the production go-amqp * Add supporting features to enable distributed tracing (#20301) (#20708) * Add supporting features to enable distributed tracing This includes new internal pipeline policies and other supporting types. See the changelog for a full description. Added some missing doc comments. 
* fix linter issue * add net.peer.name trace attribute sequence custom HTTP header policy before logging policy. sequence logging policy after HTTP trace policy. keep body download policy at the end. * add span for iterating over pages * Restore ARM CAE support for azcore beta (#20657) This reverts commit 902097226ff3fe2fc6c3e7fc50d3478350253614. * Upgrade to stable azcore (#20808) * Increment package version after release of data/azcosmos (#20807) * Updating changelog (#20810) * Add fake package to azcore (#20711) * Add fake package to azcore This is the supporting infrastructure for the generated SDK fakes. * fix doc comment * Updating CHANGELOG.md (#20809) * changelog (#20811) * Increment package version after release of storage/azfile (#20813) * Update changelog (azblob) (#20815) * Updating CHANGELOG.md * Update the changelog with correct version * [azquery] migration guide (#20742) * migration guide * Charles feedback * Richard feedback --------- Co-authored-by: Charles Lowell <10964656+chlowell@users.noreply.github.com> * Increment package version after release of monitor/azquery (#20820) * [keyvault] prep for release (#20819) * prep for release * perf tests * update date * added all upload methods * added more tests for upload stream * added more tests * added downloaders * added more tests * cleanup * feedback --------- Co-authored-by: Joel Hendrix Co-authored-by: Matias Quaranta Co-authored-by: Richard Park <51494936+richardpark-msft@users.noreply.github.com> Co-authored-by: Azure SDK Bot <53356347+azure-sdk@users.noreply.github.com> Co-authored-by: Scott Beddall Co-authored-by: siminsavani-msft <77068571+siminsavani-msft@users.noreply.github.com> Co-authored-by: scbedd <45376673+scbedd@users.noreply.github.com> Co-authored-by: Charles Lowell <10964656+chlowell@users.noreply.github.com> Co-authored-by: Peng Jiahui <46921893+Alancere@users.noreply.github.com> Co-authored-by: James Suplizio Co-authored-by: Sourav Gupta <98318303+souravgupta-msft@users.noreply.github.com> Co-authored-by: gracewilcox <43627800+gracewilcox@users.noreply.github.com> Co-authored-by: Wes Haggard Co-authored-by: Bob Tabor Co-authored-by: Bob Tabor --- sdk/storage/azdatalake/file/chunkwriting.go | 193 +++++ sdk/storage/azdatalake/file/client.go | 184 +++- sdk/storage/azdatalake/file/client_test.go | 792 ++++++++++++++++++ sdk/storage/azdatalake/file/constants.go | 12 + sdk/storage/azdatalake/file/mmf_unix.go | 38 + sdk/storage/azdatalake/file/mmf_windows.go | 56 ++ sdk/storage/azdatalake/file/models.go | 413 +++++++++ sdk/storage/azdatalake/file/responses.go | 223 +++++ sdk/storage/azdatalake/file/retry_reader.go | 191 +++++ sdk/storage/azdatalake/filesystem/client.go | 4 +- .../exported/shared_key_credential.go | 2 +- .../exported/transfer_validation_option.go | 56 ++ .../azdatalake/internal/generated/models.go | 15 + .../azdatalake/internal/path/constants.go | 12 +- .../azdatalake/internal/path/models.go | 29 +- .../internal/shared/batch_transfer.go | 77 ++ .../azdatalake/internal/testcommon/common.go | 20 + sdk/storage/azdatalake/sas/service.go | 6 +- 18 files changed, 2277 insertions(+), 46 deletions(-) create mode 100644 sdk/storage/azdatalake/file/chunkwriting.go create mode 100644 sdk/storage/azdatalake/file/mmf_unix.go create mode 100644 sdk/storage/azdatalake/file/mmf_windows.go create mode 100644 sdk/storage/azdatalake/file/retry_reader.go create mode 100644 sdk/storage/azdatalake/internal/exported/transfer_validation_option.go create mode 100644 sdk/storage/azdatalake/internal/generated/models.go 
create mode 100644 sdk/storage/azdatalake/internal/shared/batch_transfer.go diff --git a/sdk/storage/azdatalake/file/chunkwriting.go b/sdk/storage/azdatalake/file/chunkwriting.go new file mode 100644 index 000000000000..289b042d1646 --- /dev/null +++ b/sdk/storage/azdatalake/file/chunkwriting.go @@ -0,0 +1,193 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package file + +import ( + "bytes" + "context" + "errors" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" + "io" + "sync" +) + +// chunkWriter provides methods to upload chunks that represent a file to a server. +// This allows us to provide a local implementation that fakes the server for hermetic testing. +type chunkWriter interface { + AppendData(context.Context, int64, io.ReadSeekCloser, *AppendDataOptions) (AppendDataResponse, error) + FlushData(context.Context, int64, *FlushDataOptions) (FlushDataResponse, error) +} + +// bufferManager provides an abstraction for the management of buffers. +// this is mostly for testing purposes, but does allow for different implementations without changing the algorithm. +type bufferManager[T ~[]byte] interface { + // Acquire returns the channel that contains the pool of buffers. + Acquire() <-chan T + + // Release releases the buffer back to the pool for reuse/cleanup. + Release(T) + + // Grow grows the number of buffers, up to the predefined max. + // It returns the total number of buffers or an error. + // No error is returned if the number of buffers has reached max. + // This is called only from the reading goroutine. + Grow() (int, error) + + // Free cleans up all buffers. + Free() +} + +// copyFromReader copies a source io.Reader to file storage using concurrent uploads. +func copyFromReader[T ~[]byte](ctx context.Context, src io.Reader, dst chunkWriter, options UploadStreamOptions, getBufferManager func(maxBuffers int, bufferSize int64) bufferManager[T]) error { + options.setDefaults() + actualSize := int64(0) + wg := sync.WaitGroup{} // Used to know when all outgoing chunks have finished processing + errCh := make(chan error, 1) // contains the first error encountered during processing + var err error + + buffers := getBufferManager(int(options.Concurrency), options.ChunkSize) + defer buffers.Free() + + // this controls the lifetime of the uploading goroutines. + // if an error is encountered, cancel() is called which will terminate all uploads. + // NOTE: the ordering is important here. cancel MUST execute before + // cleaning up the buffers so that any uploading goroutines exit first, + // releasing their buffers back to the pool for cleanup. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // This goroutine grabs a buffer, reads from the stream into the buffer, + // then creates a goroutine to upload/stage the chunk. 
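	// Each chunk is appended at offset chunkNum*ChunkSize. The first upload error
	// wins: it is recorded in errCh (capacity 1) and cancel() aborts the remaining
	// uploads, while io.EOF / io.ErrUnexpectedEOF from ReadFull end the read loop
	// normally so the staged chunks can still be flushed.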
+ for chunkNum := uint32(0); true; chunkNum++ { + var buffer T + select { + case buffer = <-buffers.Acquire(): + // got a buffer + default: + // no buffer available; allocate a new buffer if possible + if _, err := buffers.Grow(); err != nil { + return err + } + + // either grab the newly allocated buffer or wait for one to become available + buffer = <-buffers.Acquire() + } + + var n int + n, err = io.ReadFull(src, buffer) + + if n > 0 { + // some data was read, upload it + wg.Add(1) // We're posting a buffer to be sent + + // NOTE: we must pass chunkNum as an arg to our goroutine else + // it's captured by reference and can change underneath us! + go func(chunkNum uint32) { + // Upload the outgoing chunk, matching the number of bytes read + offset := int64(chunkNum) * options.ChunkSize + appendDataOpts := options.getAppendDataOptions() + actualSize += int64(len(buffer[:n])) + _, err := dst.AppendData(ctx, offset, streaming.NopCloser(bytes.NewReader(buffer[:n])), appendDataOpts) + if err != nil { + select { + case errCh <- err: + // error was set + default: + // some other error is already set + } + cancel() + } + buffers.Release(buffer) // The goroutine reading from the stream can reuse this buffer now + + // signal that the chunk has been staged. + // we MUST do this after attempting to write to errCh + // to avoid it racing with the reading goroutine. + wg.Done() + }(chunkNum) + } else { + // nothing was read so the buffer is empty, send it back for reuse/clean-up. + buffers.Release(buffer) + } + + if err != nil { // The reader is done, no more outgoing buffers + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + // these are expected errors, we don't surface those + err = nil + } else { + // some other error happened, terminate any outstanding uploads + cancel() + } + break + } + } + + wg.Wait() // Wait for all outgoing chunks to complete + + if err != nil { + // there was an error reading from src, favor this error over any error during staging + return err + } + + select { + case err = <-errCh: + // there was an error during staging + return err + default: + // no error was encountered + } + + // All chunks uploaded, return nil error + flushOpts := options.getFlushDataOptions() + _, err = dst.FlushData(ctx, actualSize, flushOpts) + return err +} + +// mmbPool implements the bufferManager interface. +// it uses anonymous memory mapped files for buffers. +// don't use this type directly, use newMMBPool() instead. 
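// Grow allocates a new memory mapped buffer only while count < max; once that
// cap is reached, copyFromReader blocks on Acquire until an in-flight upload
// returns its buffer via Release, so memory use stays bounded by max*size.
// Free unmaps every allocated buffer once the copy has finished with the pool.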
+type mmbPool struct { + buffers chan mmb + count int + max int + size int64 +} + +func newMMBPool(maxBuffers int, bufferSize int64) bufferManager[mmb] { + return &mmbPool{ + buffers: make(chan mmb, maxBuffers), + max: maxBuffers, + size: bufferSize, + } +} + +func (pool *mmbPool) Acquire() <-chan mmb { + return pool.buffers +} + +func (pool *mmbPool) Grow() (int, error) { + if pool.count < pool.max { + buffer, err := newMMB(pool.size) + if err != nil { + return 0, err + } + pool.buffers <- buffer + pool.count++ + } + return pool.count, nil +} + +func (pool *mmbPool) Release(buffer mmb) { + pool.buffers <- buffer +} + +func (pool *mmbPool) Free() { + for i := 0; i < pool.count; i++ { + buffer := <-pool.buffers + buffer.delete() + } + pool.count = 0 +} diff --git a/sdk/storage/azdatalake/file/client.go b/sdk/storage/azdatalake/file/client.go index 5909aa475aea..0fae799fb965 100644 --- a/sdk/storage/azdatalake/file/client.go +++ b/sdk/storage/azdatalake/file/client.go @@ -7,11 +7,16 @@ package file import ( + "bytes" "context" + "errors" "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/datalakeerror" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/base" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/exported" @@ -19,9 +24,12 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/path" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/sas" + "io" "net/http" "net/url" + "os" "strings" + "sync" "time" ) @@ -254,26 +262,6 @@ func (f *Client) SetExpiry(ctx context.Context, expiryType SetExpiryType, o *Set return resp, err } -//// Upload uploads data to a file. -//func (f *Client) Upload(ctx context.Context) { -// -//} -// -//// Append appends data to a file. -//func (f *Client) Append(ctx context.Context) { -// -//} -// -//// Flush flushes previous uploaded data to a file. -//func (f *Client) Flush(ctx context.Context) { -// -//} -// -//// Download downloads data from a file. -//func (f *Client) Download(ctx context.Context) { -// -//} - // SetAccessControl sets the owner, owning group, and permissions for a file or directory (dfs1). 
func (f *Client) SetAccessControl(ctx context.Context, options *SetAccessControlOptions) (SetAccessControlResponse, error) { opts, lac, mac, err := path.FormatSetAccessControlOptions(options) @@ -360,3 +348,159 @@ func (f *Client) GetSASURL(permissions sas.FilePermissions, expiry time.Time, o return endpoint, nil } + +func (f *Client) AppendData(ctx context.Context, offset int64, body io.ReadSeekCloser, options *AppendDataOptions) (AppendDataResponse, error) { + appendDataOptions, leaseAccessConditions, httpsHeaders, cpkInfo, err := options.format(offset, body) + if err != nil { + return AppendDataResponse{}, err + } + + resp, err := f.generatedFileClientWithDFS().AppendData(ctx, body, appendDataOptions, httpsHeaders, leaseAccessConditions, cpkInfo) + return resp, exported.ConvertToDFSError(err) +} + +func (f *Client) FlushData(ctx context.Context, offset int64, options *FlushDataOptions) (FlushDataResponse, error) { + flushDataOpts, modifiedAccessConditions, leaseAccessConditions, httpHeaderOpts, cpkInfoOpts, err := options.format(offset) + if err != nil { + return FlushDataResponse{}, err + } + + resp, err := f.generatedFileClientWithDFS().FlushData(ctx, flushDataOpts, httpHeaderOpts, leaseAccessConditions, modifiedAccessConditions, cpkInfoOpts) + return resp, exported.ConvertToDFSError(err) +} + +// Concurrent Upload Functions ----------------------------------------------------------------------------------------- + +// uploadFromReader uploads a buffer in chunks to an Azure file. +func (f *Client) uploadFromReader(ctx context.Context, reader io.ReaderAt, actualSize int64, o *uploadFromReaderOptions) error { + if actualSize > MaxFileSize { + return errors.New("buffer is too large to upload to a file") + } + if o.ChunkSize == 0 { + o.ChunkSize = MaxUpdateRangeBytes + } + + if log.Should(exported.EventUpload) { + urlParts, err := azdatalake.ParseURL(f.DFSURL()) + if err == nil { + log.Writef(exported.EventUpload, "file name %s actual size %v chunk-size %v chunk-count %v", + urlParts.PathName, actualSize, o.ChunkSize, ((actualSize-1)/o.ChunkSize)+1) + } + } + + progress := int64(0) + progressLock := &sync.Mutex{} + + err := shared.DoBatchTransfer(ctx, &shared.BatchTransferOptions{ + OperationName: "uploadFromReader", + TransferSize: actualSize, + ChunkSize: o.ChunkSize, + Concurrency: o.Concurrency, + Operation: func(ctx context.Context, offset int64, chunkSize int64) error { + // This function is called once per file range. + // It is passed this file's offset within the buffer and its count of bytes + // Prepare to read the proper range/section of the buffer + if chunkSize < o.ChunkSize { + // this is the last file range. Its actual size might be less + // than the calculated size due to rounding up of the payload + // size to fit in a whole number of chunks. 
+ chunkSize = actualSize - offset + } + var body io.ReadSeeker = io.NewSectionReader(reader, offset, chunkSize) + if o.Progress != nil { + chunkProgress := int64(0) + body = streaming.NewRequestProgress(streaming.NopCloser(body), + func(bytesTransferred int64) { + diff := bytesTransferred - chunkProgress + chunkProgress = bytesTransferred + progressLock.Lock() // 1 goroutine at a time gets progress report + progress += diff + o.Progress(progress) + progressLock.Unlock() + }) + } + + uploadRangeOptions := o.getAppendDataOptions() + _, err := f.AppendData(ctx, offset, streaming.NopCloser(body), uploadRangeOptions) + return exported.ConvertToDFSError(err) + }, + }) + + if err != nil { + return exported.ConvertToDFSError(err) + } + // All appends were successful, call to flush + flushOpts := o.getFlushDataOptions() + _, err = f.FlushData(ctx, actualSize, flushOpts) + return exported.ConvertToDFSError(err) +} + +// UploadBuffer uploads a buffer in chunks to an Azure file. +func (f *Client) UploadBuffer(ctx context.Context, buffer []byte, options *UploadBufferOptions) error { + uploadOptions := uploadFromReaderOptions{} + if options != nil { + uploadOptions = *options + } + return exported.ConvertToDFSError(f.uploadFromReader(ctx, bytes.NewReader(buffer), int64(len(buffer)), &uploadOptions)) +} + +// UploadFile uploads a file in chunks to an Azure file. +func (f *Client) UploadFile(ctx context.Context, file *os.File, options *UploadFileOptions) error { + stat, err := file.Stat() + if err != nil { + return err + } + uploadOptions := uploadFromReaderOptions{} + if options != nil { + uploadOptions = *options + } + return exported.ConvertToDFSError(f.uploadFromReader(ctx, file, stat.Size(), &uploadOptions)) +} + +// UploadStream copies the file held in io.Reader to the file at fileClient. +// A Context deadline or cancellation will cause this to error. +func (f *Client) UploadStream(ctx context.Context, body io.Reader, options *UploadStreamOptions) error { + if options == nil { + options = &UploadStreamOptions{} + } + + err := copyFromReader(ctx, body, f, *options, newMMBPool) + return exported.ConvertToDFSError(err) +} + +// DownloadStream reads a range of bytes from a blob. The response also includes the blob's properties and metadata. +// For more information, see https://docs.microsoft.com/rest/api/storageservices/get-blob. +func (f *Client) DownloadStream(ctx context.Context, o *DownloadStreamOptions) (DownloadStreamResponse, error) { + if o == nil { + o = &DownloadStreamOptions{} + } + opts := o.format() + resp, err := f.blobClient().DownloadStream(ctx, opts) + newResp := FormatDownloadStreamResponse(&resp) + fullResp := DownloadStreamResponse{ + client: f, + DownloadResponse: newResp, + getInfo: httpGetterInfo{Range: o.Range, ETag: newResp.ETag}, + cpkInfo: o.CPKInfo, + cpkScope: o.CPKScopeInfo, + } + + return fullResp, exported.ConvertToDFSError(err) +} + +// DownloadBuffer downloads an Azure blob to a buffer with parallel. +func (f *Client) DownloadBuffer(ctx context.Context, buffer []byte, o *DownloadBufferOptions) (int64, error) { + opts := o.format() + val, err := f.blobClient().DownloadBuffer(ctx, shared.NewBytesWriter(buffer), opts) + return val, exported.ConvertToDFSError(err) +} + +// DownloadFile downloads an Azure blob to a local file. +// The file would be truncated if the size doesn't match. 
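// Like DownloadStream and DownloadBuffer, this call delegates to the wrapped
// blob client and converts any returned error to its DFS equivalent.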
+func (f *Client) DownloadFile(ctx context.Context, file *os.File, o *DownloadFileOptions) (int64, error) { + opts := o.format() + val, err := f.blobClient().DownloadFile(ctx, file, opts) + return val, exported.ConvertToDFSError(err) +} + +// TODO: add undelete diff --git a/sdk/storage/azdatalake/file/client_test.go b/sdk/storage/azdatalake/file/client_test.go index f9b2ccdacf71..09e50c55615c 100644 --- a/sdk/storage/azdatalake/file/client_test.go +++ b/sdk/storage/azdatalake/file/client_test.go @@ -7,16 +7,25 @@ package file_test import ( + "bytes" "context" + "crypto/md5" + "encoding/binary" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/datalakeerror" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/file" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/testcommon" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/sas" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "hash/crc64" + "io" + "math/rand" "net/http" + "os" "testing" "time" ) @@ -2302,3 +2311,786 @@ func (s *RecordedTestSuite) TestRenameFileIfETagMatchFalse() { _require.NotNil(err) testcommon.ValidateErrorCode(_require, err, datalakeerror.SourceConditionNotMet) } + +func (s *RecordedTestSuite) TestFileUploadDownloadStream() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 100 * 1024 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadStream(context.Background(), streaming.NopCloser(bytes.NewReader(content)), &file.UploadStreamOptions{ + Concurrency: 5, + ChunkSize: 4 * 1024 * 1024, + }) + _require.NoError(err) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) + +} + +func (s *RecordedTestSuite) TestFileUploadDownloadSmallStream() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + 
_require.Nil(err) + + var fileSize int64 = 10 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadStream(context.Background(), streaming.NopCloser(bytes.NewReader(content)), &file.UploadStreamOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + }) + _require.NoError(err) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) +} + +func (s *RecordedTestSuite) TestFileUploadTinyStream() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 4 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadStream(context.Background(), streaming.NopCloser(bytes.NewReader(content)), &file.UploadStreamOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + }) + _require.NoError(err) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) +} + +func (s *RecordedTestSuite) TestFileUploadFile() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 100 * 1024 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + // create local 
file + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + err = os.WriteFile("testFile", content, 0644) + _require.NoError(err) + + defer func() { + err = os.Remove("testFile") + _require.NoError(err) + }() + + fh, err := os.Open("testFile") + _require.NoError(err) + + defer func(fh *os.File) { + err := fh.Close() + _require.NoError(err) + }(fh) + + hash := md5.New() + _, err = io.Copy(hash, fh) + _require.NoError(err) + contentMD5 := hash.Sum(nil) + + err = fClient.UploadFile(context.Background(), fh, &file.UploadFileOptions{ + Concurrency: 5, + ChunkSize: 4 * 1024 * 1024, + }) + _require.NoError(err) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) +} + +func (s *RecordedTestSuite) TestSmallFileUploadFile() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 10 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + // create local file + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + err = os.WriteFile("testFile", content, 0644) + _require.NoError(err) + + defer func() { + err = os.Remove("testFile") + _require.NoError(err) + }() + + fh, err := os.Open("testFile") + _require.NoError(err) + + defer func(fh *os.File) { + err := fh.Close() + _require.NoError(err) + }(fh) + + hash := md5.New() + _, err = io.Copy(hash, fh) + _require.NoError(err) + contentMD5 := hash.Sum(nil) + + err = fClient.UploadFile(context.Background(), fh, &file.UploadFileOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + }) + _require.NoError(err) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) +} + +func (s *RecordedTestSuite) TestTinyFileUploadFile() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 10 + fileName := 
testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + // create local file + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + err = os.WriteFile("testFile", content, 0644) + _require.NoError(err) + + defer func() { + err = os.Remove("testFile") + _require.NoError(err) + }() + + fh, err := os.Open("testFile") + _require.NoError(err) + + defer func(fh *os.File) { + err := fh.Close() + _require.NoError(err) + }(fh) + + hash := md5.New() + _, err = io.Copy(hash, fh) + _require.NoError(err) + contentMD5 := hash.Sum(nil) + + err = fClient.UploadFile(context.Background(), fh, &file.UploadFileOptions{ + ChunkSize: 2, + }) + _require.NoError(err) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) +} + +func (s *RecordedTestSuite) TestFileUploadBuffer() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 100 * 1024 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadBuffer(context.Background(), content, &file.UploadBufferOptions{ + Concurrency: 5, + ChunkSize: 4 * 1024 * 1024, + }) + _require.NoError(err) + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) +} + +func (s *RecordedTestSuite) TestFileUploadSmallBuffer() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 10 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, 
fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadBuffer(context.Background(), content, &file.UploadBufferOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + }) + _require.NoError(err) + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) + + dResp, err := fClient.DownloadStream(context.Background(), nil) + _require.NoError(err) + + data, err := io.ReadAll(dResp.Body) + _require.NoError(err) + + downloadedMD5Value := md5.Sum(data) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) +} + +func (s *RecordedTestSuite) TestFileAppendAndFlushData() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + srcFileName := "src" + testcommon.GenerateFileName(testName) + + srcFClient, err := testcommon.GetFileClient(filesystemName, srcFileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := srcFClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + contentSize := 1024 * 8 // 8KB + rsc, _ := testcommon.GenerateData(contentSize) + + _, err = srcFClient.AppendData(context.Background(), 0, rsc, nil) + _require.NoError(err) + + _, err = srcFClient.FlushData(context.Background(), int64(contentSize), nil) + _require.NoError(err) + + gResp2, err := srcFClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, int64(contentSize)) +} + +func (s *RecordedTestSuite) TestFileAppendAndFlushDataWithValidation() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + srcFileName := "src" + testcommon.GenerateFileName(testName) + + srcFClient, err := testcommon.GetFileClient(filesystemName, srcFileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := srcFClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + contentSize := 1024 * 8 // 8KB + content := make([]byte, contentSize) + body := bytes.NewReader(content) + rsc := streaming.NopCloser(body) + contentCRC64 := crc64.Checksum(content, shared.CRC64Table) + + opts := &file.AppendDataOptions{ + TransactionalValidation: file.TransferValidationTypeComputeCRC64(), + } + putResp, err := srcFClient.AppendData(context.Background(), 0, rsc, opts) + _require.Nil(err) + // _require.Equal(putResp.RawResponse.StatusCode, 201) + _require.NotNil(putResp.ContentCRC64) + 
_require.EqualValues(binary.LittleEndian.Uint64(putResp.ContentCRC64), contentCRC64) + + _, err = srcFClient.FlushData(context.Background(), int64(contentSize), nil) + _require.NoError(err) + + gResp2, err := srcFClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, int64(contentSize)) +} + +func (s *RecordedTestSuite) TestFileDownloadFile() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 100 * 1024 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadBuffer(context.Background(), content, &file.UploadBufferOptions{ + Concurrency: 5, + ChunkSize: 4 * 1024 * 1024, + }) + _require.NoError(err) + + destFileName := "BigFile-downloaded.bin" + destFile, err := os.Create(destFileName) + _require.NoError(err) + defer func(name string) { + err = os.Remove(name) + _require.NoError(err) + }(destFileName) + defer func(destFile *os.File) { + err = destFile.Close() + _require.NoError(err) + }(destFile) + + cnt, err := fClient.DownloadFile(context.Background(), destFile, &file.DownloadFileOptions{ + ChunkSize: 10 * 1024 * 1024, + Concurrency: 5, + }) + _require.NoError(err) + _require.Equal(cnt, fileSize) + + hash := md5.New() + _, err = io.Copy(hash, destFile) + _require.NoError(err) + downloadedContentMD5 := hash.Sum(nil) + + _require.EqualValues(downloadedContentMD5, contentMD5) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) +} + +func (s *RecordedTestSuite) TestFileUploadDownloadSmallFile() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 10 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + // create local file + _, content := testcommon.GenerateData(int(fileSize)) + srcFileName := "testFileUpload" + err = os.WriteFile(srcFileName, content, 0644) + _require.NoError(err) + defer func() { + err = os.Remove(srcFileName) + _require.NoError(err) + }() + fh, err := os.Open(srcFileName) + _require.NoError(err) + defer func(fh *os.File) { + err := fh.Close() + _require.NoError(err) + }(fh) + + srcHash := md5.New() + _, err = 
io.Copy(srcHash, fh) + _require.NoError(err) + contentMD5 := srcHash.Sum(nil) + + err = fClient.UploadFile(context.Background(), fh, &file.UploadFileOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + }) + _require.NoError(err) + + destFileName := "SmallFile-downloaded.bin" + destFile, err := os.Create(destFileName) + _require.NoError(err) + defer func(name string) { + err = os.Remove(name) + _require.NoError(err) + }(destFileName) + defer func(destFile *os.File) { + err = destFile.Close() + _require.NoError(err) + }(destFile) + + cnt, err := fClient.DownloadFile(context.Background(), destFile, &file.DownloadFileOptions{ + ChunkSize: 2 * 1024, + Concurrency: 5, + }) + _require.NoError(err) + _require.Equal(cnt, fileSize) + + destHash := md5.New() + _, err = io.Copy(destHash, destFile) + _require.NoError(err) + downloadedContentMD5 := destHash.Sum(nil) + + _require.EqualValues(downloadedContentMD5, contentMD5) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) +} + +func (s *RecordedTestSuite) TestFileUploadDownloadWithProgress() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 10 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + _, content := testcommon.GenerateData(int(fileSize)) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + bytesUploaded := int64(0) + err = fClient.UploadBuffer(context.Background(), content, &file.UploadBufferOptions{ + Concurrency: 5, + ChunkSize: 2 * 1024, + Progress: func(bytesTransferred int64) { + _require.GreaterOrEqual(bytesTransferred, bytesUploaded) + bytesUploaded = bytesTransferred + }, + }) + _require.NoError(err) + _require.Equal(bytesUploaded, fileSize) + + destBuffer := make([]byte, fileSize) + bytesDownloaded := int64(0) + cnt, err := fClient.DownloadBuffer(context.Background(), destBuffer, &file.DownloadBufferOptions{ + ChunkSize: 2 * 1024, + Concurrency: 5, + Progress: func(bytesTransferred int64) { + _require.GreaterOrEqual(bytesTransferred, bytesDownloaded) + bytesDownloaded = bytesTransferred + }, + }) + _require.NoError(err) + _require.Equal(cnt, fileSize) + _require.Equal(bytesDownloaded, fileSize) + + downloadedMD5Value := md5.Sum(destBuffer) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) +} + +func (s *RecordedTestSuite) TestFileDownloadBuffer() { + _require := require.New(s.T()) + testName := s.T().Name() + + filesystemName := testcommon.GenerateFilesystemName(testName) + fsClient, err := testcommon.GetFilesystemClient(filesystemName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + defer testcommon.DeleteFilesystem(context.Background(), _require, fsClient) + + _, err = 
fsClient.Create(context.Background(), nil) + _require.Nil(err) + + var fileSize int64 = 100 * 1024 * 1024 + fileName := testcommon.GenerateFileName(testName) + fClient, err := testcommon.GetFileClient(filesystemName, fileName, s.T(), testcommon.TestAccountDatalake, nil) + _require.NoError(err) + + resp, err := fClient.Create(context.Background(), nil) + _require.Nil(err) + _require.NotNil(resp) + + content := make([]byte, fileSize) + _, err = rand.Read(content) + _require.NoError(err) + md5Value := md5.Sum(content) + contentMD5 := md5Value[:] + + err = fClient.UploadBuffer(context.Background(), content, &file.UploadBufferOptions{ + Concurrency: 5, + ChunkSize: 4 * 1024 * 1024, + }) + _require.NoError(err) + + destBuffer := make([]byte, fileSize) + cnt, err := fClient.DownloadBuffer(context.Background(), destBuffer, &file.DownloadBufferOptions{ + ChunkSize: 10 * 1024 * 1024, + Concurrency: 5, + }) + _require.NoError(err) + _require.Equal(cnt, fileSize) + + downloadedMD5Value := md5.Sum(destBuffer) + downloadedContentMD5 := downloadedMD5Value[:] + + _require.EqualValues(downloadedContentMD5, contentMD5) + + gResp2, err := fClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.Equal(*gResp2.ContentLength, fileSize) +} + +// TODO tests all uploads/downloads with other opts diff --git a/sdk/storage/azdatalake/file/constants.go b/sdk/storage/azdatalake/file/constants.go index 2345c88d547b..7dd13f5de226 100644 --- a/sdk/storage/azdatalake/file/constants.go +++ b/sdk/storage/azdatalake/file/constants.go @@ -7,6 +7,7 @@ package file import ( + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/path" ) @@ -36,3 +37,14 @@ const ( CopyStatusTypeAborted CopyStatusType = path.CopyStatusTypeAborted CopyStatusTypeFailed CopyStatusType = path.CopyStatusTypeFailed ) + +// TransferValidationType abstracts the various mechanisms used to verify a transfer. +type TransferValidationType = exported.TransferValidationType + +// TransferValidationTypeCRC64 is a TransferValidationType used to provide a precomputed crc64. +type TransferValidationTypeCRC64 = exported.TransferValidationTypeCRC64 + +// TransferValidationTypeComputeCRC64 is a TransferValidationType that indicates a CRC64 should be computed during transfer. +func TransferValidationTypeComputeCRC64() TransferValidationType { + return exported.TransferValidationTypeComputeCRC64() +} diff --git a/sdk/storage/azdatalake/file/mmf_unix.go b/sdk/storage/azdatalake/file/mmf_unix.go new file mode 100644 index 000000000000..4c8ed223dbae --- /dev/null +++ b/sdk/storage/azdatalake/file/mmf_unix.go @@ -0,0 +1,38 @@ +//go:build go1.18 && (linux || darwin || dragonfly || freebsd || openbsd || netbsd || solaris || aix) +// +build go1.18 +// +build linux darwin dragonfly freebsd openbsd netbsd solaris aix + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. 
+ +package file + +import ( + "fmt" + "os" + "syscall" +) + +// mmb is a memory mapped buffer +type mmb []byte + +// newMMB creates a new memory mapped buffer with the specified size +func newMMB(size int64) (mmb, error) { + prot, flags := syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_ANON|syscall.MAP_PRIVATE + addr, err := syscall.Mmap(-1, 0, int(size), prot, flags) + if err != nil { + return nil, os.NewSyscallError("Mmap", err) + } + return mmb(addr), nil +} + +// delete cleans up the memory mapped buffer +func (m *mmb) delete() { + err := syscall.Munmap(*m) + *m = nil + if err != nil { + // if we get here, there is likely memory corruption. + // please open an issue https://github.com/Azure/azure-sdk-for-go/issues + panic(fmt.Sprintf("Munmap error: %v", err)) + } +} diff --git a/sdk/storage/azdatalake/file/mmf_windows.go b/sdk/storage/azdatalake/file/mmf_windows.go new file mode 100644 index 000000000000..b59e6b415776 --- /dev/null +++ b/sdk/storage/azdatalake/file/mmf_windows.go @@ -0,0 +1,56 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package file + +import ( + "fmt" + "os" + "reflect" + "syscall" + "unsafe" +) + +// mmb is a memory mapped buffer +type mmb []byte + +// newMMB creates a new memory mapped buffer with the specified size +func newMMB(size int64) (mmb, error) { + const InvalidHandleValue = ^uintptr(0) // -1 + + prot, access := uint32(syscall.PAGE_READWRITE), uint32(syscall.FILE_MAP_WRITE) + hMMF, err := syscall.CreateFileMapping(syscall.Handle(InvalidHandleValue), nil, prot, uint32(size>>32), uint32(size&0xffffffff), nil) + if err != nil { + return nil, os.NewSyscallError("CreateFileMapping", err) + } + defer func() { + _ = syscall.CloseHandle(hMMF) + }() + + addr, err := syscall.MapViewOfFile(hMMF, access, 0, 0, uintptr(size)) + if err != nil { + return nil, os.NewSyscallError("MapViewOfFile", err) + } + + m := mmb{} + h := (*reflect.SliceHeader)(unsafe.Pointer(&m)) + h.Data = addr + h.Len = int(size) + h.Cap = h.Len + return m, nil +} + +// delete cleans up the memory mapped buffer +func (m *mmb) delete() { + addr := uintptr(unsafe.Pointer(&(([]byte)(*m)[0]))) + *m = mmb{} + err := syscall.UnmapViewOfFile(addr) + if err != nil { + // if we get here, there is likely memory corruption. + // please open an issue https://github.com/Azure/azure-sdk-for-go/issues + panic(fmt.Sprintf("UnmapViewOfFile error: %v", err)) + } +} diff --git a/sdk/storage/azdatalake/file/models.go b/sdk/storage/azdatalake/file/models.go index a4f8b994ff1d..1f628860bded 100644 --- a/sdk/storage/azdatalake/file/models.go +++ b/sdk/storage/azdatalake/file/models.go @@ -7,15 +7,33 @@ package file import ( + "errors" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/generated" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/path" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/shared" + "io" "net/http" "strconv" "time" ) +const ( + _1MiB = 1024 * 1024 + CountToEnd = 0 + + // MaxUpdateRangeBytes indicates the maximum number of bytes that can be updated in a call to Client.UploadRange. + MaxUpdateRangeBytes = 4 * 1024 * 1024 // 4MiB + + // MaxFileSize indicates the maximum size of the file allowed. 
+ MaxFileSize = 4 * 1024 * 1024 * 1024 * 1024 // 4 TiB + + // DefaultDownloadChunkSize is default chunk size + DefaultDownloadChunkSize = int64(4 * 1024 * 1024) // 4MiB +) + // CreateOptions contains the optional parameters when calling the Create operation. dfs endpoint. type CreateOptions struct { // AccessConditions contains parameters for accessing the file. @@ -167,6 +185,401 @@ func (o *RemoveAccessControlOptions) format(ACL string) (*generated.PathClientSe }, mode } +type HTTPRange = exported.HTTPRange + +// uploadFromReaderOptions identifies options used by the UploadBuffer and UploadFile functions. +type uploadFromReaderOptions struct { + // ChunkSize specifies the chunk size to use in bytes; the default (and maximum size) is MaxUpdateRangeBytes. + ChunkSize int64 + // Progress is a function that is invoked periodically as bytes are sent to the FileClient. + // Note that the progress reporting is not always increasing; it can go down when retrying a request. + Progress func(bytesTransferred int64) + // Concurrency indicates the maximum number of chunks to upload in parallel (default is 5) + Concurrency uint16 + // AccessConditions contains optional parameters to access leased entity. + AccessConditions *AccessConditions + // HTTPHeaders contains the optional path HTTP headers to set when the file is created. + HTTPHeaders *HTTPHeaders + // CPKInfo contains optional parameters to perform encryption using customer-provided key. + CPKInfo *CPKInfo +} + +// UploadStreamOptions provides set of configurations for Client.UploadStream operation. +type UploadStreamOptions struct { + // ChunkSize specifies the chunk size to use in bytes; the default (and maximum size) is MaxUpdateRangeBytes. + ChunkSize int64 + // Concurrency indicates the maximum number of chunks to upload in parallel (default is 5) + Concurrency uint16 + // AccessConditions contains optional parameters to access leased entity. + AccessConditions *AccessConditions + // HTTPHeaders contains the optional path HTTP headers to set when the file is created. + HTTPHeaders *HTTPHeaders + // CPKInfo contains optional parameters to perform encryption using customer-provided key. + CPKInfo *CPKInfo +} + +// UploadBufferOptions provides set of configurations for Client.UploadBuffer operation. +type UploadBufferOptions = uploadFromReaderOptions + +// UploadFileOptions provides set of configurations for Client.UploadFile operation. +type UploadFileOptions = uploadFromReaderOptions + +// FlushDataOptions contains the optional parameters for the Client.FlushData method. +type FlushDataOptions struct { + AccessConditions *AccessConditions + // HTTPHeaders contains the optional path HTTP headers to set when the file is created. + // CPKInfo contains optional parameters to perform encryption using customer-provided key. 
+	CPKInfo               *CPKInfo
+	HTTPHeaders           *HTTPHeaders
+	Close                 *bool
+	RetainUncommittedData *bool
+}
+
+func (o *FlushDataOptions) format(offset int64) (*generated.PathClientFlushDataOptions, *generated.ModifiedAccessConditions, *generated.LeaseAccessConditions, *generated.PathHTTPHeaders, *generated.CPKInfo, error) {
+	defaultRetainUncommitted := false
+	defaultClose := false
+	contentLength := int64(0)
+
+	var httpHeaderOpts *generated.PathHTTPHeaders
+	var leaseAccessConditions *generated.LeaseAccessConditions
+	var modifiedAccessConditions *generated.ModifiedAccessConditions
+	var cpkInfoOpts *generated.CPKInfo
+	flushDataOpts := &generated.PathClientFlushDataOptions{ContentLength: &contentLength, Position: &offset}
+
+	if o == nil {
+		flushDataOpts.RetainUncommittedData = &defaultRetainUncommitted
+		flushDataOpts.Close = &defaultClose
+		return flushDataOpts, nil, nil, nil, nil, nil
+	}
+
+	if o.RetainUncommittedData == nil {
+		flushDataOpts.RetainUncommittedData = &defaultRetainUncommitted
+	} else {
+		flushDataOpts.RetainUncommittedData = o.RetainUncommittedData
+	}
+	if o.Close == nil {
+		flushDataOpts.Close = &defaultClose
+	} else {
+		flushDataOpts.Close = o.Close
+	}
+	leaseAccessConditions, modifiedAccessConditions = exported.FormatPathAccessConditions(o.AccessConditions)
+	if o.HTTPHeaders != nil {
+		httpHeaderOpts = &generated.PathHTTPHeaders{}
+		httpHeaderOpts.ContentMD5 = o.HTTPHeaders.ContentMD5
+		httpHeaderOpts.ContentType = o.HTTPHeaders.ContentType
+		httpHeaderOpts.CacheControl = o.HTTPHeaders.CacheControl
+		httpHeaderOpts.ContentDisposition = o.HTTPHeaders.ContentDisposition
+		httpHeaderOpts.ContentEncoding = o.HTTPHeaders.ContentEncoding
+	}
+	if o.CPKInfo != nil {
+		cpkInfoOpts = &generated.CPKInfo{}
+		cpkInfoOpts.EncryptionKey = o.CPKInfo.EncryptionKey
+		cpkInfoOpts.EncryptionKeySHA256 = o.CPKInfo.EncryptionKeySHA256
+		cpkInfoOpts.EncryptionAlgorithm = o.CPKInfo.EncryptionAlgorithm
+	}
+	return flushDataOpts, modifiedAccessConditions, leaseAccessConditions, httpHeaderOpts, cpkInfoOpts, nil
+}
+
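// FlushDataOptions and AppendDataOptions (below) back the low-level upload protocol: stage bytes with
// one or more AppendData calls at increasing offsets, then commit them with a single FlushData at the
// total length. A minimal sketch of that flow, illustrative only and not part of this patch; it assumes
// an existing *file.Client named client and the signatures implied by the option and response types in
// this patch (imports: bytes, context, github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming).

func appendAndFlush(ctx context.Context, client *file.Client, chunk1, chunk2 []byte) error {
	// Stage the first chunk at offset 0.
	if _, err := client.AppendData(ctx, 0, streaming.NopCloser(bytes.NewReader(chunk1)), nil); err != nil {
		return err
	}
	// Stage the second chunk immediately after the first.
	if _, err := client.AppendData(ctx, int64(len(chunk1)), streaming.NopCloser(bytes.NewReader(chunk2)), nil); err != nil {
		return err
	}
	// Commit everything staged so far; Close and RetainUncommittedData default to false (see format above).
	_, err := client.FlushData(ctx, int64(len(chunk1)+len(chunk2)), nil)
	return err
}
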
+// AppendDataOptions contains the optional parameters for the Client.AppendData method.
+type AppendDataOptions struct {
+	// TransactionalValidation specifies the transfer validation type to use.
+	// The default is nil (no transfer validation).
+	TransactionalValidation TransferValidationType
+	// LeaseAccessConditions contains optional parameters to access leased entity.
+	LeaseAccessConditions *LeaseAccessConditions
+	// HTTPHeaders contains the optional path HTTP headers to set when the file is created.
+	HTTPHeaders *HTTPHeaders
+	// CPKInfo contains optional parameters to perform encryption using customer-provided key.
+	CPKInfo *CPKInfo
+}
+
+func (o *AppendDataOptions) format(offset int64, body io.ReadSeekCloser) (*generated.PathClientAppendDataOptions, *generated.LeaseAccessConditions, *generated.PathHTTPHeaders, *generated.CPKInfo, error) {
+	if offset < 0 || body == nil {
+		return nil, nil, nil, nil, errors.New("invalid argument: offset must be >= 0 and body must not be nil")
+	}
+
+	count, err := shared.ValidateSeekableStreamAt0AndGetCount(body)
+	if err != nil {
+		return nil, nil, nil, nil, err
+	}
+
+	if count == 0 {
+		return nil, nil, nil, nil, errors.New("invalid argument: body must contain readable data whose size is > 0")
+	}
+
+	appendDataOptions := &generated.PathClientAppendDataOptions{}
+	httpRange := exported.FormatHTTPRange(HTTPRange{
+		Offset: offset,
+		Count:  count,
+	})
+	if httpRange != nil {
+		appendDataOptions.Position = &offset
+		appendDataOptions.ContentLength = &count
+	}
+
+	var leaseAccessConditions *LeaseAccessConditions
+	var httpHeaderOpts *generated.PathHTTPHeaders
+	var cpkInfoOpts *generated.CPKInfo
+
+	if o != nil {
+		leaseAccessConditions = o.LeaseAccessConditions
+		if o.HTTPHeaders != nil {
+			httpHeaderOpts = &generated.PathHTTPHeaders{}
+			httpHeaderOpts.ContentMD5 = o.HTTPHeaders.ContentMD5
+			httpHeaderOpts.ContentType = o.HTTPHeaders.ContentType
+			httpHeaderOpts.CacheControl = o.HTTPHeaders.CacheControl
+			httpHeaderOpts.ContentDisposition = o.HTTPHeaders.ContentDisposition
+			httpHeaderOpts.ContentEncoding = o.HTTPHeaders.ContentEncoding
+		}
+		if o.CPKInfo != nil {
+			cpkInfoOpts = &generated.CPKInfo{}
+			cpkInfoOpts.EncryptionKey = o.CPKInfo.EncryptionKey
+			cpkInfoOpts.EncryptionKeySHA256 = o.CPKInfo.EncryptionKeySHA256
+			cpkInfoOpts.EncryptionAlgorithm = o.CPKInfo.EncryptionAlgorithm
+		}
+	}
+	if o != nil && o.TransactionalValidation != nil {
+		_, err = o.TransactionalValidation.Apply(body, appendDataOptions)
+		if err != nil {
+			return nil, nil, nil, nil, err
+		}
+	}
+
+	return appendDataOptions, leaseAccessConditions, httpHeaderOpts, cpkInfoOpts, nil
+}
+
+func (u *UploadStreamOptions) setDefaults() {
+	if u.Concurrency == 0 {
+		u.Concurrency = 1
+	}
+
+	if u.ChunkSize < _1MiB {
+		u.ChunkSize = _1MiB
+	}
+}
+
+func (u *uploadFromReaderOptions) getAppendDataOptions() *AppendDataOptions {
+	if u == nil {
+		return nil
+	}
+	leaseAccessConditions, _ := exported.FormatPathAccessConditions(u.AccessConditions)
+	return &AppendDataOptions{
+		LeaseAccessConditions: leaseAccessConditions,
+		HTTPHeaders:           u.HTTPHeaders,
+		CPKInfo:               u.CPKInfo,
+	}
+}
+
+func (u *uploadFromReaderOptions) getFlushDataOptions() *FlushDataOptions {
+	if u == nil {
+		return nil
+	}
+	return &FlushDataOptions{
+		AccessConditions: u.AccessConditions,
+		HTTPHeaders:      u.HTTPHeaders,
+		CPKInfo:          u.CPKInfo,
+	}
+}
+
+func (u *UploadStreamOptions) getAppendDataOptions() *AppendDataOptions {
+	if u == nil {
+		return nil
+	}
+	leaseAccessConditions, _ := exported.FormatPathAccessConditions(u.AccessConditions)
+	return &AppendDataOptions{
+		LeaseAccessConditions: leaseAccessConditions,
+		HTTPHeaders:           u.HTTPHeaders,
+		CPKInfo:               u.CPKInfo,
+	}
+}
+
+func (u *UploadStreamOptions) getFlushDataOptions() *FlushDataOptions {
+	if u == nil {
+		return nil
+	}
+	return &FlushDataOptions{
+		AccessConditions: u.AccessConditions,
+		HTTPHeaders:      u.HTTPHeaders,
+		CPKInfo:          u.CPKInfo,
+	}
+}
+
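// The helpers above feed the high-level convenience methods (Client.UploadBuffer, Client.UploadFile,
// Client.UploadStream), which split the source into ChunkSize pieces, append up to Concurrency chunks in
// parallel, and then flush the result. An illustrative sketch, not part of this patch; it assumes an
// existing *file.Client named client and a method of the form UploadFile(ctx context.Context, f *os.File,
// o *file.UploadFileOptions) error, which this hunk does not show (imports: context, os).

func uploadLocalFile(ctx context.Context, client *file.Client, path string) error {
	fh, err := os.Open(path)
	if err != nil {
		return err
	}
	defer fh.Close()

	// 8 MiB chunks, at most 4 chunks in flight; Progress is invoked as bytes are appended.
	return client.UploadFile(ctx, fh, &file.UploadFileOptions{
		ChunkSize:   8 * 1024 * 1024,
		Concurrency: 4,
		Progress:    func(bytesTransferred int64) { /* e.g. update a progress indicator */ },
	})
}

+// DownloadStreamOptions contains the optional parameters for the Client.DownloadStream method.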
+type DownloadStreamOptions struct { + // When set to true and specified together with the Range, the service returns the MD5 hash for the range, as long as the + // range is less than or equal to 4 MB in size. + RangeGetContentMD5 *bool + + // Range specifies a range of bytes. The default value is all bytes. + Range *HTTPRange + + AccessConditions *AccessConditions + CPKInfo *CPKInfo + CPKScopeInfo *CPKScopeInfo +} + +func (o *DownloadStreamOptions) format() *blob.DownloadStreamOptions { + if o == nil { + return nil + } + + downloadStreamOptions := &blob.DownloadStreamOptions{} + if o.Range != nil { + downloadStreamOptions.Range = blob.HTTPRange{ + Offset: o.Range.Offset, + Count: o.Range.Count, + } + } + if o.CPKInfo != nil { + downloadStreamOptions.CPKInfo = &blob.CPKInfo{ + EncryptionKey: o.CPKInfo.EncryptionKey, + EncryptionKeySHA256: o.CPKInfo.EncryptionKeySHA256, + EncryptionAlgorithm: (*blob.EncryptionAlgorithmType)(o.CPKInfo.EncryptionAlgorithm), + } + } + + downloadStreamOptions.RangeGetContentMD5 = o.RangeGetContentMD5 + downloadStreamOptions.AccessConditions = exported.FormatBlobAccessConditions(o.AccessConditions) + downloadStreamOptions.CPKScopeInfo = (*blob.CPKScopeInfo)(o.CPKScopeInfo) + return downloadStreamOptions +} + +// DownloadBufferOptions contains the optional parameters for the DownloadBuffer method. +type DownloadBufferOptions struct { + // Range specifies a range of bytes. The default value is all bytes. + Range *HTTPRange + + // ChunkSize specifies the block size to use for each parallel download; the default size is DefaultDownloadBlockSize. + ChunkSize int64 + + // Progress is a function that is invoked periodically as bytes are received. + Progress func(bytesTransferred int64) + + // BlobAccessConditions indicates the access conditions used when making HTTP GET requests against the blob. + AccessConditions *AccessConditions + + // CPKInfo contains a group of parameters for client provided encryption key. + CPKInfo *CPKInfo + + // CPKScopeInfo contains a group of parameters for client provided encryption scope. + CPKScopeInfo *CPKScopeInfo + + // Concurrency indicates the maximum number of blocks to download in parallel (0=default). + Concurrency uint16 + + // RetryReaderOptionsPerChunk is used when downloading each block. 
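	// (Only consulted when non-nil; the format method below copies OnFailedRead, EarlyCloseAsError and
	// MaxRetries into the per-block retry options of the underlying blob download.)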
+ RetryReaderOptionsPerChunk *RetryReaderOptions +} + +func (o *DownloadBufferOptions) format() *blob.DownloadBufferOptions { + if o == nil { + return nil + } + + downloadBufferOptions := &blob.DownloadBufferOptions{} + if o.Range != nil { + downloadBufferOptions.Range = blob.HTTPRange{ + Offset: o.Range.Offset, + Count: o.Range.Count, + } + } + if o.CPKInfo != nil { + downloadBufferOptions.CPKInfo = &blob.CPKInfo{ + EncryptionKey: o.CPKInfo.EncryptionKey, + EncryptionKeySHA256: o.CPKInfo.EncryptionKeySHA256, + EncryptionAlgorithm: (*blob.EncryptionAlgorithmType)(o.CPKInfo.EncryptionAlgorithm), + } + } + + downloadBufferOptions.AccessConditions = exported.FormatBlobAccessConditions(o.AccessConditions) + downloadBufferOptions.CPKScopeInfo = (*blob.CPKScopeInfo)(o.CPKScopeInfo) + downloadBufferOptions.BlockSize = o.ChunkSize + downloadBufferOptions.Progress = o.Progress + downloadBufferOptions.Concurrency = o.Concurrency + if o.RetryReaderOptionsPerChunk != nil { + newFunc := func(failureCount int32, lastError error, rnge blob.HTTPRange, willRetry bool) { + newRange := HTTPRange{ + Offset: rnge.Offset, + Count: rnge.Count, + } + o.RetryReaderOptionsPerChunk.OnFailedRead(failureCount, lastError, newRange, willRetry) + } + downloadBufferOptions.RetryReaderOptionsPerBlock.OnFailedRead = newFunc + downloadBufferOptions.RetryReaderOptionsPerBlock.EarlyCloseAsError = o.RetryReaderOptionsPerChunk.EarlyCloseAsError + downloadBufferOptions.RetryReaderOptionsPerBlock.MaxRetries = o.RetryReaderOptionsPerChunk.MaxRetries + } + + return downloadBufferOptions +} + +// DownloadFileOptions contains the optional parameters for the Client.DownloadFile method. +type DownloadFileOptions struct { + // Range specifies a range of bytes. The default value is all bytes. + Range *HTTPRange + + // ChunkSize specifies the block size to use for each parallel download; the default size is DefaultDownloadBlockSize. + ChunkSize int64 + + // Progress is a function that is invoked periodically as bytes are received. + Progress func(bytesTransferred int64) + + // BlobAccessConditions indicates the access conditions used when making HTTP GET requests against the blob. + AccessConditions *AccessConditions + + // ClientProvidedKeyOptions indicates the client provided key by name and/or by value to encrypt/decrypt data. + CPKInfo *CPKInfo + CPKScopeInfo *CPKScopeInfo + + // Concurrency indicates the maximum number of blocks to download in parallel. The default value is 5. + Concurrency uint16 + + // RetryReaderOptionsPerChunk is used when downloading each block. 
+ RetryReaderOptionsPerChunk *RetryReaderOptions +} + +func (o *DownloadFileOptions) format() *blob.DownloadFileOptions { + if o == nil { + return nil + } + + downloadFileOptions := &blob.DownloadFileOptions{} + if o.Range != nil { + downloadFileOptions.Range = blob.HTTPRange{ + Offset: o.Range.Offset, + Count: o.Range.Count, + } + } + if o.CPKInfo != nil { + downloadFileOptions.CPKInfo = &blob.CPKInfo{ + EncryptionKey: o.CPKInfo.EncryptionKey, + EncryptionKeySHA256: o.CPKInfo.EncryptionKeySHA256, + EncryptionAlgorithm: (*blob.EncryptionAlgorithmType)(o.CPKInfo.EncryptionAlgorithm), + } + } + + downloadFileOptions.AccessConditions = exported.FormatBlobAccessConditions(o.AccessConditions) + downloadFileOptions.CPKScopeInfo = (*blob.CPKScopeInfo)(o.CPKScopeInfo) + downloadFileOptions.BlockSize = o.ChunkSize + downloadFileOptions.Progress = o.Progress + downloadFileOptions.Concurrency = o.Concurrency + if o.RetryReaderOptionsPerChunk != nil { + newFunc := func(failureCount int32, lastError error, rnge blob.HTTPRange, willRetry bool) { + newRange := HTTPRange{ + Offset: rnge.Offset, + Count: rnge.Count, + } + o.RetryReaderOptionsPerChunk.OnFailedRead(failureCount, lastError, newRange, willRetry) + } + downloadFileOptions.RetryReaderOptionsPerBlock.OnFailedRead = newFunc + downloadFileOptions.RetryReaderOptionsPerBlock.EarlyCloseAsError = o.RetryReaderOptionsPerChunk.EarlyCloseAsError + downloadFileOptions.RetryReaderOptionsPerBlock.MaxRetries = o.RetryReaderOptionsPerChunk.MaxRetries + } + + return downloadFileOptions +} + // CreationExpiryType defines values for Create() ExpiryType type CreationExpiryType interface { Format() (generated.ExpiryOptions, *string) diff --git a/sdk/storage/azdatalake/file/responses.go b/sdk/storage/azdatalake/file/responses.go index 3116edab7355..a518ee58f376 100644 --- a/sdk/storage/azdatalake/file/responses.go +++ b/sdk/storage/azdatalake/file/responses.go @@ -7,8 +7,14 @@ package file import ( + "context" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/generated" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/path" + "io" + "time" ) // SetExpiryResponse contains the response fields for the SetExpiry operation. @@ -26,12 +32,229 @@ type UpdateAccessControlResponse = generated.PathClientSetAccessControlRecursive // RemoveAccessControlResponse contains the response fields for the RemoveAccessControlRecursive operation. type RemoveAccessControlResponse = generated.PathClientSetAccessControlRecursiveResponse +// AppendDataResponse contains the response from method Client.AppendData. +type AppendDataResponse = generated.PathClientAppendDataResponse + +// FlushDataResponse contains the response from method Client.FlushData. +type FlushDataResponse = generated.PathClientFlushDataResponse + // RenameResponse contains the response fields for the Create operation. type RenameResponse struct { Response generated.PathClientCreateResponse NewFileClient *Client } +// DownloadStreamResponse contains the response from the DownloadStream method. +// To read from the stream, read from the Body field, or call the NewRetryReader method. +type DownloadStreamResponse struct { + DownloadResponse + client *Client + getInfo httpGetterInfo + cpkInfo *CPKInfo + cpkScope *CPKScopeInfo +} + +// NewRetryReader constructs new RetryReader stream for reading data. 
If a connection fails while +// reading, it will make additional requests to reestablish a connection and continue reading. +// Pass nil for options to accept the default options. +// Callers of this method should not access the DownloadStreamResponse.Body field. +func (r *DownloadStreamResponse) NewRetryReader(ctx context.Context, options *RetryReaderOptions) *RetryReader { + if options == nil { + options = &RetryReaderOptions{} + } + + return newRetryReader(ctx, r.Body, r.getInfo, func(ctx context.Context, getInfo httpGetterInfo) (io.ReadCloser, error) { + accessConditions := &AccessConditions{ + ModifiedAccessConditions: &ModifiedAccessConditions{IfMatch: getInfo.ETag}, + } + options := DownloadStreamOptions{ + Range: getInfo.Range, + AccessConditions: accessConditions, + CPKInfo: r.cpkInfo, + CPKScopeInfo: r.cpkScope, + } + resp, err := r.client.DownloadStream(ctx, &options) + if err != nil { + return nil, err + } + return resp.Body, err + }, *options) +} + +// DownloadResponse contains the response fields for the GetProperties operation. +type DownloadResponse struct { + // AcceptRanges contains the information returned from the Accept-Ranges header response. + AcceptRanges *string + + // Body contains the streaming response. + Body io.ReadCloser + + // CacheControl contains the information returned from the Cache-Control header response. + CacheControl *string + + // ClientRequestID contains the information returned from the x-ms-client-request-id header response. + ClientRequestID *string + + // ContentCRC64 contains the information returned from the x-ms-content-crc64 header response. + ContentCRC64 []byte + + // ContentDisposition contains the information returned from the Content-Disposition header response. + ContentDisposition *string + + // ContentEncoding contains the information returned from the Content-Encoding header response. + ContentEncoding *string + + // ContentLanguage contains the information returned from the Content-Language header response. + ContentLanguage *string + + // ContentLength contains the information returned from the Content-Length header response. + ContentLength *int64 + + // ContentMD5 contains the information returned from the Content-MD5 header response. + ContentMD5 []byte + + // ContentRange contains the information returned from the Content-Range header response. + ContentRange *string + + // ContentType contains the information returned from the Content-Type header response. + ContentType *string + + // CopyCompletionTime contains the information returned from the x-ms-copy-completion-time header response. + CopyCompletionTime *time.Time + + // CopyID contains the information returned from the x-ms-copy-id header response. + CopyID *string + + // CopyProgress contains the information returned from the x-ms-copy-progress header response. + CopyProgress *string + + // CopySource contains the information returned from the x-ms-copy-source header response. + CopySource *string + + // CopyStatus contains the information returned from the x-ms-copy-status header response. + CopyStatus *CopyStatusType + + // CopyStatusDescription contains the information returned from the x-ms-copy-status-description header response. + CopyStatusDescription *string + + // Date contains the information returned from the Date header response. + Date *time.Time + + // ETag contains the information returned from the ETag header response. + ETag *azcore.ETag + + // EncryptionKeySHA256 contains the information returned from the x-ms-encryption-key-sha256 header response. 
+ EncryptionKeySHA256 *string + + // EncryptionScope contains the information returned from the x-ms-encryption-scope header response. + EncryptionScope *string + + // ErrorCode contains the information returned from the x-ms-error-code header response. + ErrorCode *string + + // ImmutabilityPolicyExpiresOn contains the information returned from the x-ms-immutability-policy-until-date header response. + ImmutabilityPolicyExpiresOn *time.Time + + // ImmutabilityPolicyMode contains the information returned from the x-ms-immutability-policy-mode header response. + ImmutabilityPolicyMode *ImmutabilityPolicyMode + + // IsCurrentVersion contains the information returned from the x-ms-is-current-version header response. + IsCurrentVersion *bool + + // IsSealed contains the information returned from the x-ms-blob-sealed header response. + IsSealed *bool + + // IsServerEncrypted contains the information returned from the x-ms-server-encrypted header response. + IsServerEncrypted *bool + + // LastAccessed contains the information returned from the x-ms-last-access-time header response. + LastAccessed *time.Time + + // LastModified contains the information returned from the Last-Modified header response. + LastModified *time.Time + + // LeaseDuration contains the information returned from the x-ms-lease-duration header response. + LeaseDuration *azdatalake.DurationType + + // LeaseState contains the information returned from the x-ms-lease-state header response. + LeaseState *azdatalake.StateType + + // LeaseStatus contains the information returned from the x-ms-lease-status header response. + LeaseStatus *azdatalake.StatusType + + // LegalHold contains the information returned from the x-ms-legal-hold header response. + LegalHold *bool + + // Metadata contains the information returned from the x-ms-meta header response. + Metadata map[string]*string + + // ObjectReplicationPolicyID contains the information returned from the x-ms-or-policy-id header response. + ObjectReplicationPolicyID *string + + // ObjectReplicationRules contains the information returned from the x-ms-or header response. + ObjectReplicationRules map[string]*string + + // RequestID contains the information returned from the x-ms-request-id header response. + RequestID *string + + // TagCount contains the information returned from the x-ms-tag-count header response. + TagCount *int64 + + // Version contains the information returned from the x-ms-version header response. + Version *string + + // VersionID contains the information returned from the x-ms-version-id header response. 
+ VersionID *string +} + +func FormatDownloadStreamResponse(r *blob.DownloadStreamResponse) DownloadResponse { + newResp := DownloadResponse{} + if r != nil { + newResp.AcceptRanges = r.AcceptRanges + newResp.Body = r.Body + newResp.ContentCRC64 = r.ContentCRC64 + newResp.ContentRange = r.ContentRange + newResp.CacheControl = r.CacheControl + newResp.ErrorCode = r.ErrorCode + newResp.ClientRequestID = r.ClientRequestID + newResp.ContentDisposition = r.ContentDisposition + newResp.ContentEncoding = r.ContentEncoding + newResp.ContentLanguage = r.ContentLanguage + newResp.ContentLength = r.ContentLength + newResp.ContentMD5 = r.ContentMD5 + newResp.ContentType = r.ContentType + newResp.CopyCompletionTime = r.CopyCompletionTime + newResp.CopyID = r.CopyID + newResp.CopyProgress = r.CopyProgress + newResp.CopySource = r.CopySource + newResp.CopyStatus = r.CopyStatus + newResp.CopyStatusDescription = r.CopyStatusDescription + newResp.Date = r.Date + newResp.ETag = r.ETag + newResp.EncryptionKeySHA256 = r.EncryptionKeySHA256 + newResp.EncryptionScope = r.EncryptionScope + newResp.ImmutabilityPolicyExpiresOn = r.ImmutabilityPolicyExpiresOn + newResp.ImmutabilityPolicyMode = r.ImmutabilityPolicyMode + newResp.IsCurrentVersion = r.IsCurrentVersion + newResp.IsSealed = r.IsSealed + newResp.IsServerEncrypted = r.IsServerEncrypted + newResp.LastAccessed = r.LastAccessed + newResp.LastModified = r.LastModified + newResp.LeaseDuration = r.LeaseDuration + newResp.LeaseState = r.LeaseState + newResp.LeaseStatus = r.LeaseStatus + newResp.LegalHold = r.LegalHold + newResp.Metadata = r.Metadata + newResp.ObjectReplicationPolicyID = r.ObjectReplicationPolicyID + newResp.ObjectReplicationRules = r.DownloadResponse.ObjectReplicationRules + newResp.RequestID = r.RequestID + newResp.TagCount = r.TagCount + newResp.Version = r.Version + newResp.VersionID = r.VersionID + } + return newResp +} + // ========================================== path imports =========================================================== // SetAccessControlResponse contains the response fields for the SetAccessControl operation. diff --git a/sdk/storage/azdatalake/file/retry_reader.go b/sdk/storage/azdatalake/file/retry_reader.go new file mode 100644 index 000000000000..66e3f35edf0d --- /dev/null +++ b/sdk/storage/azdatalake/file/retry_reader.go @@ -0,0 +1,191 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package file + +import ( + "context" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "io" + "net" + "strings" + "sync" +) + +// HTTPGetter is a function type that refers to a method that performs an HTTP GET operation. +type httpGetter func(ctx context.Context, i httpGetterInfo) (io.ReadCloser, error) + +// HTTPGetterInfo is passed to an HTTPGetter function passing it parameters +// that should be used to make an HTTP GET request. +type httpGetterInfo struct { + Range *HTTPRange + + // ETag specifies the resource's etag that should be used when creating + // the HTTP GET request's If-Match header + ETag *azcore.ETag +} + +// RetryReaderOptions configures the retry reader's behavior. +// Zero-value fields will have their specified default values applied during use. +// This allows for modification of a subset of fields. +type RetryReaderOptions struct { + // MaxRetries specifies the maximum number of attempts a failed read will be retried + // before producing an error. 
+ // The default value is three. + MaxRetries int32 + + // OnFailedRead, when non-nil, is called after any failure to read. Expected usage is diagnostic logging. + OnFailedRead func(failureCount int32, lastError error, rnge HTTPRange, willRetry bool) + + // EarlyCloseAsError can be set to true to prevent retries after "read on closed response body". By default, + // retryReader has the following special behaviour: closing the response body before it is all read is treated as a + // retryable error. This is to allow callers to force a retry by closing the body from another goroutine (e.g. if the = + // read is too slow, caller may want to force a retry in the hope that the retry will be quicker). If + // TreatEarlyCloseAsError is true, then retryReader's special behaviour is suppressed, and "read on closed body" is instead + // treated as a fatal (non-retryable) error. + // Note that setting TreatEarlyCloseAsError only guarantees that Closing will produce a fatal error if the Close happens + // from the same "thread" (goroutine) as Read. Concurrent Close calls from other goroutines may instead produce network errors + // which will be retried. + // The default value is false. + EarlyCloseAsError bool + + doInjectError bool + doInjectErrorRound int32 + injectedError error +} + +// RetryReader attempts to read from response, and if there is a retry-able network error +// returned during reading, it will retry according to retry reader option through executing +// user defined action with provided data to get a new response, and continue the overall reading process +// through reading from the new response. +// RetryReader implements the io.ReadCloser interface. +type RetryReader struct { + ctx context.Context + info httpGetterInfo + retryReaderOptions RetryReaderOptions + getter httpGetter + countWasBounded bool + + // we support Close-ing during Reads (from other goroutines), so we protect the shared state, which is response + responseMu *sync.Mutex + response io.ReadCloser +} + +// newRetryReader creates a retry reader. +func newRetryReader(ctx context.Context, initialResponse io.ReadCloser, info httpGetterInfo, getter httpGetter, o RetryReaderOptions) *RetryReader { + if o.MaxRetries < 1 { + o.MaxRetries = 3 + } + return &RetryReader{ + ctx: ctx, + getter: getter, + info: info, + countWasBounded: info.Range.Count != CountToEnd, + response: initialResponse, + responseMu: &sync.Mutex{}, + retryReaderOptions: o, + } +} + +// setResponse function +func (s *RetryReader) setResponse(r io.ReadCloser) { + s.responseMu.Lock() + defer s.responseMu.Unlock() + s.response = r +} + +// Read from retry reader +func (s *RetryReader) Read(p []byte) (n int, err error) { + for try := int32(0); ; try++ { + //fmt.Println(try) // Comment out for debugging. + if s.countWasBounded && s.info.Range.Count == CountToEnd { + // User specified an original count and the remaining bytes are 0, return 0, EOF + return 0, io.EOF + } + + s.responseMu.Lock() + resp := s.response + s.responseMu.Unlock() + if resp == nil { // We don't have a response stream to read from, try to get one. + newResponse, err := s.getter(s.ctx, s.info) + if err != nil { + return 0, err + } + // Successful GET; this is the network stream we'll read from. + s.setResponse(newResponse) + resp = newResponse + } + n, err := resp.Read(p) // Read from the stream (this will return non-nil err if forceRetry is called, from another goroutine, while it is running) + + // Injection mechanism for testing. 
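		// (doInjectError, doInjectErrorRound and injectedError are package-private test hooks: on the
		// matching attempt the read result is replaced with injectedError, or with a temporary DNS error
		// when none is set, so the retry path below can be exercised without a real network fault.)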
+		if s.retryReaderOptions.doInjectError && try == s.retryReaderOptions.doInjectErrorRound {
+			if s.retryReaderOptions.injectedError != nil {
+				err = s.retryReaderOptions.injectedError
+			} else {
+				err = &net.DNSError{IsTemporary: true}
+			}
+		}
+
+		// We successfully read data or reached EOF.
+		if err == nil || err == io.EOF {
+			s.info.Range.Offset += int64(n) // Increment the start offset in case we need to make a new HTTP request in the future
+			if s.info.Range.Count != CountToEnd {
+				s.info.Range.Count -= int64(n) // Decrement the count in case we need to make a new HTTP request in the future
+			}
+			return n, err // Return the result to the caller
+		}
+		_ = s.Close()
+
+		s.setResponse(nil) // Our stream is no longer good
+
+		// Check the retry count and error code, and decide whether to retry.
+		retriesExhausted := try >= s.retryReaderOptions.MaxRetries
+		_, isNetError := err.(net.Error)
+		isUnexpectedEOF := err == io.ErrUnexpectedEOF
+		willRetry := (isNetError || isUnexpectedEOF || s.wasRetryableEarlyClose(err)) && !retriesExhausted
+
+		// Notify, for logging purposes, of any failures
+		if s.retryReaderOptions.OnFailedRead != nil {
+			failureCount := try + 1 // because try is zero-based
+			s.retryReaderOptions.OnFailedRead(failureCount, err, *s.info.Range, willRetry)
+		}
+
+		if willRetry {
+			continue
+			// Loop around and try to get and read from a new stream.
+		}
+		return n, err // Not retryable, or retries exhausted, so just return
+	}
+}
+
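// A consumer never constructs RetryReader directly; it is obtained from a download response. An
// illustrative sketch, not part of this patch; it assumes an existing *file.Client named client and the
// DownloadStream/NewRetryReader shapes shown elsewhere in this patch (imports: context, io, os).

func downloadWithRetries(ctx context.Context, client *file.Client, dst *os.File) error {
	resp, err := client.DownloadStream(ctx, nil)
	if err != nil {
		return err
	}

	// Wrap the body so transient network failures are retried (up to 5 times here) instead of
	// surfacing to the copy loop below.
	body := resp.NewRetryReader(ctx, &file.RetryReaderOptions{MaxRetries: 5})
	defer body.Close()

	_, err = io.Copy(dst, body)
	return err
}

+// By default, we allow early Closing, from another concurrent goroutine, to be used to force a retry.
+// Is this safe, to close early from another goroutine? Early close ultimately ends up calling
+// net.Conn.Close, and that is documented as "Any blocked Read or Write operations will be unblocked and return errors"
+// which is exactly the behaviour we want.
+// NOTE: if the caller has forced an early Close from a separate goroutine (separate from the Read),
+// then there are two different types of error that may happen - either the one we check for here,
+// or a net.Error (due to closure of connection). Which one happens depends on timing. We only need this routine
+// to check for one, since the other is a net.Error, which our main Read retry loop is already handling.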
+func (s *RetryReader) wasRetryableEarlyClose(err error) bool { + if s.retryReaderOptions.EarlyCloseAsError { + return false // user wants all early closes to be errors, and so not retryable + } + // unfortunately, http.errReadOnClosedResBody is private, so the best we can do here is to check for its text + return strings.HasSuffix(err.Error(), ReadOnClosedBodyMessage) +} + +// ReadOnClosedBodyMessage of retry reader +const ReadOnClosedBodyMessage = "read on closed response body" + +// Close retry reader +func (s *RetryReader) Close() error { + s.responseMu.Lock() + defer s.responseMu.Unlock() + if s.response != nil { + return s.response.Close() + } + return nil +} diff --git a/sdk/storage/azdatalake/filesystem/client.go b/sdk/storage/azdatalake/filesystem/client.go index c0a8146de2ad..c0c7ee5b8086 100644 --- a/sdk/storage/azdatalake/filesystem/client.go +++ b/sdk/storage/azdatalake/filesystem/client.go @@ -162,8 +162,8 @@ func (fs *Client) containerClient() *container.Client { return containerClient } -func (f *Client) identityCredential() *azcore.TokenCredential { - return base.IdentityCredentialComposite((*base.CompositeClient[generated.FileSystemClient, generated.FileSystemClient, container.Client])(f)) +func (fs *Client) identityCredential() *azcore.TokenCredential { + return base.IdentityCredentialComposite((*base.CompositeClient[generated.FileSystemClient, generated.FileSystemClient, container.Client])(fs)) } func (fs *Client) sharedKey() *exported.SharedKeyCredential { diff --git a/sdk/storage/azdatalake/internal/exported/shared_key_credential.go b/sdk/storage/azdatalake/internal/exported/shared_key_credential.go index 63539ea0b10a..d54cdc3a0b76 100644 --- a/sdk/storage/azdatalake/internal/exported/shared_key_credential.go +++ b/sdk/storage/azdatalake/internal/exported/shared_key_credential.go @@ -182,7 +182,7 @@ func (c *SharedKeyCredential) buildCanonicalizedResource(u *url.URL) (string, er // Join the sorted key values separated by ',' // Then prepend "keyName:"; then add this string to the buffer - cr.WriteString("\n" + paramName + ":" + strings.Join(paramValues, ",")) + cr.WriteString("\n" + strings.ToLower(paramName) + ":" + strings.Join(paramValues, ",")) } } return cr.String(), nil diff --git a/sdk/storage/azdatalake/internal/exported/transfer_validation_option.go b/sdk/storage/azdatalake/internal/exported/transfer_validation_option.go new file mode 100644 index 000000000000..85430ebd1c7e --- /dev/null +++ b/sdk/storage/azdatalake/internal/exported/transfer_validation_option.go @@ -0,0 +1,56 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package exported + +import ( + "bytes" + "encoding/binary" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/generated" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/shared" + "hash/crc64" + "io" +) + +// TransferValidationType abstracts the various mechanisms used to verify a transfer. +type TransferValidationType interface { + Apply(io.ReadSeekCloser, generated.TransactionalContentSetter) (io.ReadSeekCloser, error) + notPubliclyImplementable() +} + +// TransferValidationTypeCRC64 is a TransferValidationType used to provide a precomputed CRC64. 
+type TransferValidationTypeCRC64 uint64 + +func (c TransferValidationTypeCRC64) Apply(rsc io.ReadSeekCloser, cfg generated.TransactionalContentSetter) (io.ReadSeekCloser, error) { + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, uint64(c)) + cfg.SetCRC64(buf) + return rsc, nil +} + +func (TransferValidationTypeCRC64) notPubliclyImplementable() {} + +// TransferValidationTypeComputeCRC64 is a TransferValidationType that indicates a CRC64 should be computed during transfer. +func TransferValidationTypeComputeCRC64() TransferValidationType { + return transferValidationTypeFn(func(rsc io.ReadSeekCloser, cfg generated.TransactionalContentSetter) (io.ReadSeekCloser, error) { + buf, err := io.ReadAll(rsc) + if err != nil { + return nil, err + } + + crc := crc64.Checksum(buf, shared.CRC64Table) + return TransferValidationTypeCRC64(crc).Apply(streaming.NopCloser(bytes.NewReader(buf)), cfg) + }) +} + +type transferValidationTypeFn func(io.ReadSeekCloser, generated.TransactionalContentSetter) (io.ReadSeekCloser, error) + +func (t transferValidationTypeFn) Apply(rsc io.ReadSeekCloser, cfg generated.TransactionalContentSetter) (io.ReadSeekCloser, error) { + return t(rsc, cfg) +} + +func (transferValidationTypeFn) notPubliclyImplementable() {} diff --git a/sdk/storage/azdatalake/internal/generated/models.go b/sdk/storage/azdatalake/internal/generated/models.go new file mode 100644 index 000000000000..b3f86d5973cb --- /dev/null +++ b/sdk/storage/azdatalake/internal/generated/models.go @@ -0,0 +1,15 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package generated + +type TransactionalContentSetter interface { + SetCRC64([]byte) +} + +func (a *PathClientAppendDataOptions) SetCRC64(v []byte) { + a.TransactionalContentCRC64 = v +} diff --git a/sdk/storage/azdatalake/internal/path/constants.go b/sdk/storage/azdatalake/internal/path/constants.go index 7dd11049e38e..ce070f694d23 100644 --- a/sdk/storage/azdatalake/internal/path/constants.go +++ b/sdk/storage/azdatalake/internal/path/constants.go @@ -6,14 +6,16 @@ package path -import "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" +import ( + "github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/internal/generated" +) -// EncryptionAlgorithmType defines values for EncryptionAlgorithmType. 
-type EncryptionAlgorithmType = blob.EncryptionAlgorithmType +type EncryptionAlgorithmType = generated.EncryptionAlgorithmType const ( - EncryptionAlgorithmTypeNone EncryptionAlgorithmType = blob.EncryptionAlgorithmTypeNone - EncryptionAlgorithmTypeAES256 EncryptionAlgorithmType = blob.EncryptionAlgorithmTypeAES256 + EncryptionAlgorithmTypeNone EncryptionAlgorithmType = generated.EncryptionAlgorithmTypeNone + EncryptionAlgorithmTypeAES256 EncryptionAlgorithmType = generated.EncryptionAlgorithmTypeAES256 ) type ImmutabilityPolicyMode = blob.ImmutabilityPolicyMode diff --git a/sdk/storage/azdatalake/internal/path/models.go b/sdk/storage/azdatalake/internal/path/models.go index 893bc9d40d19..f476fcae518f 100644 --- a/sdk/storage/azdatalake/internal/path/models.go +++ b/sdk/storage/azdatalake/internal/path/models.go @@ -30,7 +30,7 @@ func FormatGetPropertiesOptions(o *GetPropertiesOptions) *blob.GetPropertiesOpti AccessConditions: accessConditions, CPKInfo: &blob.CPKInfo{ EncryptionKey: o.CPKInfo.EncryptionKey, - EncryptionAlgorithm: o.CPKInfo.EncryptionAlgorithm, + EncryptionAlgorithm: (*blob.EncryptionAlgorithmType)(o.CPKInfo.EncryptionAlgorithm), EncryptionKeySHA256: o.CPKInfo.EncryptionKeySHA256, }, } @@ -159,19 +159,18 @@ type HTTPHeaders struct { ContentType *string } -// -//func (o HTTPHeaders) formatBlobHTTPHeaders() blob.HTTPHeaders { -// -// opts := blob.HTTPHeaders{ -// BlobCacheControl: o.CacheControl, -// BlobContentDisposition: o.ContentDisposition, -// BlobContentEncoding: o.ContentEncoding, -// BlobContentLanguage: o.ContentLanguage, -// BlobContentMD5: o.ContentMD5, -// BlobContentType: o.ContentType, -// } -// return opts -//} +func FormatBlobHTTPHeaders(o *HTTPHeaders) *blob.HTTPHeaders { + + opts := &blob.HTTPHeaders{ + BlobCacheControl: o.CacheControl, + BlobContentDisposition: o.ContentDisposition, + BlobContentEncoding: o.ContentEncoding, + BlobContentLanguage: o.ContentLanguage, + BlobContentMD5: o.ContentMD5, + BlobContentType: o.ContentType, + } + return opts +} func FormatPathHTTPHeaders(o *HTTPHeaders) *generated.PathHTTPHeaders { // TODO: will be used for file related ops, like append @@ -209,7 +208,7 @@ func FormatSetMetadataOptions(o *SetMetadataOptions) (*blob.SetMetadataOptions, if o.CPKInfo != nil { opts.CPKInfo = &blob.CPKInfo{ EncryptionKey: o.CPKInfo.EncryptionKey, - EncryptionAlgorithm: o.CPKInfo.EncryptionAlgorithm, + EncryptionAlgorithm: (*blob.EncryptionAlgorithmType)(o.CPKInfo.EncryptionAlgorithm), EncryptionKeySHA256: o.CPKInfo.EncryptionKeySHA256, } } diff --git a/sdk/storage/azdatalake/internal/shared/batch_transfer.go b/sdk/storage/azdatalake/internal/shared/batch_transfer.go new file mode 100644 index 000000000000..ec5541bfbb13 --- /dev/null +++ b/sdk/storage/azdatalake/internal/shared/batch_transfer.go @@ -0,0 +1,77 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package shared + +import ( + "context" + "errors" +) + +// BatchTransferOptions identifies options used by doBatchTransfer. +type BatchTransferOptions struct { + TransferSize int64 + ChunkSize int64 + Concurrency uint16 + Operation func(ctx context.Context, offset int64, chunkSize int64) error + OperationName string +} + +// DoBatchTransfer helps to execute operations in a batch manner. 
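// (As the function body below shows, each chunk becomes one Operation callback, a pool of Concurrency
// goroutines drains the work channel, and the first failure cancels the shared context so the remaining
// chunks stop early; the chunked upload and download helpers added in this patch are presumably its callers.)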
+// Can be used by users to customize batch works (for other scenarios that the SDK does not provide) +func DoBatchTransfer(ctx context.Context, o *BatchTransferOptions) error { + if o.ChunkSize == 0 { + return errors.New("ChunkSize cannot be 0") + } + + if o.Concurrency == 0 { + o.Concurrency = 5 // default concurrency + } + + // Prepare and do parallel operations. + numChunks := uint16(((o.TransferSize - 1) / o.ChunkSize) + 1) + operationChannel := make(chan func() error, o.Concurrency) // Create the channel that release 'concurrency' goroutines concurrently + operationResponseChannel := make(chan error, numChunks) // Holds each response + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Create the goroutines that process each operation (in parallel). + for g := uint16(0); g < o.Concurrency; g++ { + //grIndex := g + go func() { + for f := range operationChannel { + err := f() + operationResponseChannel <- err + } + }() + } + + // Add each chunk's operation to the channel. + for chunkNum := uint16(0); chunkNum < numChunks; chunkNum++ { + curChunkSize := o.ChunkSize + + if chunkNum == numChunks-1 { // Last chunk + curChunkSize = o.TransferSize - (int64(chunkNum) * o.ChunkSize) // Remove size of all transferred chunks from total + } + offset := int64(chunkNum) * o.ChunkSize + operationChannel <- func() error { + return o.Operation(ctx, offset, curChunkSize) + } + } + close(operationChannel) + + // Wait for the operations to complete. + var firstErr error = nil + for chunkNum := uint16(0); chunkNum < numChunks; chunkNum++ { + responseError := <-operationResponseChannel + // record the first error (the original error which should cause the other chunks to fail with canceled context) + if responseError != nil && firstErr == nil { + cancel() // As soon as any operation fails, cancel all remaining operation calls + firstErr = responseError + } + } + return firstErr +} diff --git a/sdk/storage/azdatalake/internal/testcommon/common.go b/sdk/storage/azdatalake/internal/testcommon/common.go index 1314309c5ac2..36af75cc92a7 100644 --- a/sdk/storage/azdatalake/internal/testcommon/common.go +++ b/sdk/storage/azdatalake/internal/testcommon/common.go @@ -1,11 +1,14 @@ package testcommon import ( + "bytes" "errors" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" "github.com/Azure/azure-sdk-for-go/sdk/storage/azdatalake/datalakeerror" "github.com/stretchr/testify/require" + "io" "os" "strings" "testing" @@ -80,3 +83,20 @@ func ValidateErrorCode(_require *require.Assertions, err error, code datalakeerr func GetRelativeTimeFromAnchor(anchorTime *time.Time, amount time.Duration) time.Time { return anchorTime.Add(amount * time.Second) } + +const random64BString string = "2SDgZj6RkKYzJpu04sweQek4uWHO8ndPnYlZ0tnFS61hjnFZ5IkvIGGY44eKABov" + +func GenerateData(sizeInBytes int) (io.ReadSeekCloser, []byte) { + data := make([]byte, sizeInBytes) + _len := len(random64BString) + if sizeInBytes > _len { + count := sizeInBytes / _len + if sizeInBytes%_len != 0 { + count = count + 1 + } + copy(data[:], strings.Repeat(random64BString, count)) + } else { + copy(data[:], random64BString) + } + return streaming.NopCloser(bytes.NewReader(data)), data +} diff --git a/sdk/storage/azdatalake/sas/service.go b/sdk/storage/azdatalake/sas/service.go index 86a292028276..92ccaa8101a3 100644 --- a/sdk/storage/azdatalake/sas/service.go +++ b/sdk/storage/azdatalake/sas/service.go @@ -55,7 +55,7 @@ func 
getDirectoryDepth(path string) string {
 
 // SignWithSharedKey uses an account's SharedKeyCredential to sign this signature values to produce the proper SAS query parameters.
 func (v DatalakeSignatureValues) SignWithSharedKey(sharedKeyCredential *SharedKeyCredential) (QueryParameters, error) {
-	if v.ExpiryTime.IsZero() || v.Permissions == "" {
+	if v.Identifier == "" && (v.ExpiryTime.IsZero() || v.Permissions == "") {
 		return QueryParameters{}, errors.New("service SAS is missing at least one of these: ExpiryTime or Permissions")
 	}
 
@@ -118,7 +118,6 @@ func (v DatalakeSignatureValues) SignWithSharedKe
 
 		// Container/Blob-specific SAS parameters
 		resource:             resource,
-		identifier:           v.Identifier,
 		cacheControl:         v.CacheControl,
 		contentDisposition:   v.ContentDisposition,
 		contentEncoding:      v.ContentEncoding,
@@ -129,7 +128,8 @@ func (v DatalakeSignatureValues) SignWithSharedKe
 		unauthorizedObjectID: v.UnauthorizedObjectID,
 		correlationID:        v.CorrelationID,
 		// Calculated SAS signature
-		signature: signature,
+		signature:  signature,
+		identifier: signedIdentifier,
 	}
 	return p, nil
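// With the added Identifier check, a SAS that references a stored access policy may omit ExpiryTime and
// Permissions, because the policy supplies them. A sketch of the validation rule the condition above
// expresses; illustrative only, with the field names taken from the hunk and the time package assumed as
// the only import.

func serviceSASInputsValid(identifier string, expiry time.Time, permissions string) bool {
	// Either a stored access policy is referenced, or the caller must supply both an expiry and permissions.
	return identifier != "" || (!expiry.IsZero() && permissions != "")
}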