Store source token and pass to other threads (#1996)
* Store source token

* testing

* failing pipe

* cleanup

* test logger

* fix test failure

* fix 2

* fix

* sync fix

* cleanup check
tasherif-msft authored Jan 20, 2023
1 parent 4779674 commit 16ca699
Showing 12 changed files with 102 additions and 76 deletions.
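
In outline, the commit resolves the source OAuth credential once on the front end, stores it on common.CredentialInfo as SourceBlobToken, and threads it through the job part order into the storage engine, so worker threads share one credential instead of each creating (and refreshing) its own. A minimal sketch of that reuse pattern, using simplified stand-in types rather than the real AzCopy ones:

```go
package main

import (
	"context"
	"fmt"
)

// Credential is a simplified stand-in for azblob.Credential.
type Credential interface{ Name() string }

type tokenCredential struct{ token string }

func (t tokenCredential) Name() string { return "token:" + t.token }

// CredentialInfo mirrors the shape of common.CredentialInfo after this
// commit: the resolved source credential is cached on the struct.
type CredentialInfo struct {
	SourceBlobToken Credential // set once by the front end, reused everywhere else
}

// createBlobCredential models the new fallback logic: reuse the stored
// source token if present, otherwise mint a fresh one.
func createBlobCredential(ctx context.Context, info CredentialInfo) Credential {
	if info.SourceBlobToken != nil {
		return info.SourceBlobToken // no second token fetch, no second refresher
	}
	return tokenCredential{token: "freshly-fetched"} // placeholder for azblob.NewTokenCredential(...)
}

func main() {
	ctx := context.Background()

	// Front end: resolve the credential once and store it.
	info := CredentialInfo{}
	info.SourceBlobToken = createBlobCredential(ctx, info)

	// Worker threads: later calls all reuse the same stored credential.
	for i := 0; i < 3; i++ {
		fmt.Println(createBlobCredential(ctx, info).Name())
	}
}
```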
48 changes: 46 additions & 2 deletions cmd/copy.go
@@ -1380,6 +1380,49 @@ func (cca *CookedCopyCmdArgs) processRedirectionUpload(blobResource common.Resou
return err
}

// get the source credential - if there is an OAuth token, it will be passed along our pipeline
func (cca *CookedCopyCmdArgs) getSrcCredential(ctx context.Context, jpo *common.CopyJobPartOrderRequest) (common.CredentialInfo, error) {
srcCredInfo := common.CredentialInfo{}
var err error
var isPublic bool

if srcCredInfo, isPublic, err = GetCredentialInfoForLocation(ctx, cca.FromTo.From(), cca.Source.Value, cca.Source.SAS, true, cca.CpkOptions); err != nil {
return srcCredInfo, err
// If S2S and source takes OAuthToken as its cred type (OR) source takes anonymous as its cred type, but it's not public and there's no SAS
} else if cca.FromTo.IsS2S() &&
((srcCredInfo.CredentialType == common.ECredentialType.OAuthToken() && cca.FromTo.To() != common.ELocation.Blob()) || // Blob can forward OAuth tokens
(srcCredInfo.CredentialType == common.ECredentialType.Anonymous() && !isPublic && cca.Source.SAS == "")) {
return srcCredInfo, errors.New("a SAS token (or S3 access key) is required as a part of the source in S2S transfers, unless the source is a public resource, or the destination is blob storage")
}

if cca.Source.SAS != "" && cca.FromTo.IsS2S() && jpo.CredentialInfo.CredentialType == common.ECredentialType.OAuthToken() {
//glcm.Info("Authentication: If the source and destination accounts are in the same AAD tenant & the user/spn/msi has appropriate permissions on both, the source SAS token is not required and OAuth can be used round-trip.")
}

if cca.FromTo.IsS2S() {
jpo.S2SSourceCredentialType = srcCredInfo.CredentialType

if jpo.S2SSourceCredentialType.IsAzureOAuth() {
uotm := GetUserOAuthTokenManagerInstance()
// get token from env var or cache
if tokenInfo, err := uotm.GetTokenInfo(ctx); err != nil {
return srcCredInfo, err
} else {
cca.credentialInfo.OAuthTokenInfo = *tokenInfo
jpo.CredentialInfo.OAuthTokenInfo = *tokenInfo
}
// If the source is not local, store the credential token (if it was OAuth) to avoid constant refreshing
jpo.CredentialInfo.SourceBlobToken = common.CreateBlobCredential(ctx, srcCredInfo, common.CredentialOpOptions{
// LogInfo: glcm.Info, //Comment out for debugging
LogError: glcm.Info,
})
cca.credentialInfo.SourceBlobToken = jpo.CredentialInfo.SourceBlobToken
srcCredInfo.SourceBlobToken = jpo.CredentialInfo.SourceBlobToken
}
}
return srcCredInfo, nil
}
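
The "avoid constant refreshing" comment above is the motivation: with azure-storage-blob-go, every azblob.NewTokenCredential call installs its own refresh callback, so building a fresh credential per pipeline multiplies background token refreshes. A hedged sketch of that refresher wiring — fetchToken is a hypothetical stand-in for an AAD token request, not an AzCopy helper:

```go
package main

import (
	"fmt"
	"time"

	"github.com/Azure/azure-storage-blob-go/azblob"
)

// fetchToken is a hypothetical stand-in for requesting a fresh AAD token.
func fetchToken() (token string, expiresIn time.Duration) {
	return "renewed-token", time.Hour
}

// newRefreshingCredential shows the pattern CreateBlobCredential uses: the
// callback is rescheduled inside the credential, so every credential
// instance carries its own recurring refresh work.
func newRefreshingCredential(initialToken string) azblob.TokenCredential {
	return azblob.NewTokenCredential(initialToken, func(tc azblob.TokenCredential) time.Duration {
		token, expiresIn := fetchToken()
		tc.SetToken(token)
		return expiresIn - 5*time.Minute // run again shortly before expiry
	})
}

func main() {
	cred := newRefreshingCredential("initial-token")
	fmt.Println(cred.Token()) // storing this one credential avoids N refresh loops
}
```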

// handles the copy command
// dispatches the job order (in parts) to the storage engine
func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() (err error) {
@@ -1492,11 +1535,12 @@ func (cca *CookedCopyCmdArgs) processCopyJobPartOrders() {
common.EFromTo.BenchmarkFile():

var e *CopyEnumerator
e, err = cca.initEnumerator(jobPartOrder, ctx)
srcCredInfo, err := cca.getSrcCredential(ctx, &jobPartOrder)
if err != nil {
return fmt.Errorf("failed to get source credential: %w", err)
}

e, err = cca.initEnumerator(jobPartOrder, srcCredInfo, ctx)
if err != nil {
return fmt.Errorf("failed to initialize enumerator: %w", err)
}

err = e.enumerate()
case common.EFromTo.BlobTrash(), common.EFromTo.FileTrash():
e, createErr := newRemoveEnumerator(cca)
Expand Down
38 changes: 3 additions & 35 deletions cmd/copyEnumeratorInit.go
@@ -45,41 +45,9 @@ func (cca *CookedCopyCmdArgs) validateSourceDir(traverser ResourceTraverser) err
return nil
}

func (cca *CookedCopyCmdArgs) initEnumerator(jobPartOrder common.CopyJobPartOrderRequest, ctx context.Context) (*CopyEnumerator, error) {
func (cca *CookedCopyCmdArgs) initEnumerator(jobPartOrder common.CopyJobPartOrderRequest, srcCredInfo common.CredentialInfo, ctx context.Context) (*CopyEnumerator, error) {
var traverser ResourceTraverser

srcCredInfo := common.CredentialInfo{}
var isPublic bool
var err error

if srcCredInfo, isPublic, err = GetCredentialInfoForLocation(ctx, cca.FromTo.From(), cca.Source.Value, cca.Source.SAS, true, cca.CpkOptions); err != nil {
return nil, err
// If S2S and source takes OAuthToken as its cred type (OR) source takes anonymous as its cred type, but it's not public and there's no SAS
} else if cca.FromTo.IsS2S() &&
((srcCredInfo.CredentialType == common.ECredentialType.OAuthToken() && cca.FromTo.To() != common.ELocation.Blob()) || // Blob can forward OAuth tokens
(srcCredInfo.CredentialType == common.ECredentialType.Anonymous() && !isPublic && cca.Source.SAS == "")) {
return nil, errors.New("a SAS token (or S3 access key) is required as a part of the source in S2S transfers, unless the source is a public resource, or the destination is blob storage")
}

if cca.Source.SAS != "" && cca.FromTo.IsS2S() && jobPartOrder.CredentialInfo.CredentialType == common.ECredentialType.OAuthToken() {
glcm.Info("Authentication: If the source and destination accounts are in the same AAD tenant & the user/spn/msi has appropriate permissions on both, the source SAS token is not required and OAuth can be used round-trip.")
}

if cca.FromTo.IsS2S() {
jobPartOrder.S2SSourceCredentialType = srcCredInfo.CredentialType

if jobPartOrder.S2SSourceCredentialType.IsAzureOAuth() {
uotm := GetUserOAuthTokenManagerInstance()
// get token from env var or cache
if tokenInfo, err := uotm.GetTokenInfo(ctx); err != nil {
return nil, err
} else {
cca.credentialInfo.OAuthTokenInfo = *tokenInfo
jobPartOrder.CredentialInfo.OAuthTokenInfo = *tokenInfo
}
}
}

jobPartOrder.CpkOptions = cca.CpkOptions
jobPartOrder.PreserveSMBPermissions = cca.preservePermissions
jobPartOrder.PreserveSMBInfo = cca.preserveSMBInfo
@@ -90,7 +58,7 @@ func (cca *CookedCopyCmdArgs) initEnumerator(jobPartOrde
// If preserve properties is enabled, but get properties in backend is disabled, turn it on
// If source change validation is enabled on files to remote, turn it on (consider a separate flag entirely?)
getRemoteProperties := cca.ForceWrite == common.EOverwriteOption.IfSourceNewer() ||
(cca.FromTo.From() == common.ELocation.File() && !cca.FromTo.To().IsRemote()) || // If download, we still need LMT and MD5 from files.
(cca.FromTo.From() == common.ELocation.File() && !cca.FromTo.To().IsRemote()) || // If it's a download, we still need LMT and MD5 from files.
(cca.FromTo.From() == common.ELocation.File() && cca.FromTo.To().IsRemote() && (cca.s2sSourceChangeValidation || cca.IncludeAfter != nil || cca.IncludeBefore != nil)) || // If S2S from File to *, and sourceChangeValidation is enabled, we get properties so that we have LMTs. Likewise, if we are using includeAfter or includeBefore, which require LMTs.
(cca.FromTo.From().IsRemote() && cca.FromTo.To().IsRemote() && cca.s2sPreserveProperties && !cca.s2sGetPropertiesInBackend) // If S2S and preserve properties AND get properties in backend is on, turn this off, as properties will be obtained in the backend.
jobPartOrder.S2SGetPropertiesInBackend = cca.s2sPreserveProperties && !getRemoteProperties && cca.s2sGetPropertiesInBackend // Infer GetProperties if GetPropertiesInBackend is enabled.
@@ -473,7 +441,7 @@ func (cca *CookedCopyCmdArgs) createDstContainer(containerName string, dstWithSA
if dstCredInfo, _, err = GetCredentialInfoForLocation(ctx, cca.FromTo.To(), cca.Destination.Value, cca.Destination.SAS, false, cca.CpkOptions); err != nil {
return err
}

// TODO: we can pass cred here as well
dstPipeline, err := InitPipeline(ctx, cca.FromTo.To(), dstCredInfo, logLevel.ToPipelineLogLevel())
if err != nil {
return
13 changes: 8 additions & 5 deletions cmd/credentialUtil.go
@@ -655,11 +655,14 @@ func getCredentialType(ctx context.Context, raw rawFromToInfo, cpkOptions common
// pipeline factory methods
// ==============================================================================================
func createBlobPipeline(ctx context.Context, credInfo common.CredentialInfo, logLevel pipeline.LogLevel) (pipeline.Pipeline, error) {
credential := common.CreateBlobCredential(ctx, credInfo, common.CredentialOpOptions{
// LogInfo: glcm.Info, //Comment out for debugging
LogError: glcm.Info,
})

// reuse the stored source token if there is one (nil when this pipeline is for the destination)
credential := credInfo.SourceBlobToken
if credential == nil {
credential = common.CreateBlobCredential(ctx, credInfo, common.CredentialOpOptions{
// LogInfo: glcm.Info, //Comment out for debugging
LogError: glcm.Info,
})
}
logOption := pipeline.LogOptions{}
if azcopyScanningLogger != nil {
logOption = pipeline.LogOptions{
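
One Go subtlety in the nil check above: an interface value holding a typed nil pointer compares as non-nil, so the fallback only works because SourceBlobToken is left entirely unset when there is no stored token. A small illustration of the general language behavior (not AzCopy code):

```go
package main

import "fmt"

type credential interface{ token() string }

type blobToken struct{}

func (b *blobToken) token() string { return "secret" }

func main() {
	var typedNil *blobToken         // nil pointer
	var iface credential = typedNil // interface now holds (*blobToken, nil)

	fmt.Println(typedNil == nil) // true
	fmt.Println(iface == nil)    // false: assigning a typed nil would defeat
	//                              an `if credential == nil` fallback
}
```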
17 changes: 9 additions & 8 deletions cmd/syncProcessor.go
@@ -230,14 +230,15 @@ type localFileDeleter struct {
// As at version 10.4.0, we intentionally don't delete directories in sync,
// even if our folder properties option suggests we should.
// Why? The key difficulties are as follows, and it's the third one that we don't currently have a solution for.
//  1. Timing (solvable in theory with FolderDeletionManager)
//  2. Identifying which folders should be removed when the source does not have a concept of folders (e.g. Blob).
//     The probable solution is to just respect the folder properties option setting (which we already do in our delete processors)
//  3. In the Azure Files case (and to a lesser extent on local disks) users may have ACLs or other properties
//     set on the directories, and wish to retain those even though the directories are empty. (Perhaps less of an issue
//     when syncing from folder-aware sources that DO NOT HAVE the directory. But still an issue when syncing from
//     blob. E.g. we delete a folder because there's nothing in it right now, but really the user wanted it there,
//     and has set up custom ACLs on it for future use. If we delete, they lose the custom ACL setup.)
//
// TODO: shall we add folder deletion support at some stage? (In cases where folderPropertiesOption says that folders should be processed)
func shouldSyncRemoveFolders() bool {
return false
1 change: 0 additions & 1 deletion cmd/zc_enumerator.go
@@ -350,7 +350,6 @@ func InitResourceTraverser(resource common.ResourceString, location common.Locat
// Initialize the pipeline if creds and ctx is provided
if ctx != nil && credential != nil {
tmppipe, err := InitPipeline(*ctx, location, *credential, logLevel)

if err != nil {
return nil, err
}
1 change: 0 additions & 1 deletion cmd/zc_pipeline_init.go
@@ -3,7 +3,6 @@ package cmd
import (
"context"
"fmt"

"github.com/Azure/azure-pipeline-go/pipeline"

"github.com/Azure/azure-storage-azcopy/v10/common"
14 changes: 9 additions & 5 deletions common/credentialFactory.go
@@ -105,11 +105,15 @@ func CreateBlobCredential(ctx context.Context, credInfo CredentialInfo, options
}

// Create TokenCredential with refresher.
if credInfo.SourceBlobToken != nil {
return credInfo.SourceBlobToken
} else {
return azblob.NewTokenCredential(
credInfo.OAuthTokenInfo.AccessToken,
func(credential azblob.TokenCredential) time.Duration {
return refreshBlobToken(ctx, credInfo.OAuthTokenInfo, credential, options)
})
}
}

return credential
1 change: 1 addition & 0 deletions common/rpc-models.go
@@ -164,6 +164,7 @@ type CredentialInfo struct {
OAuthTokenInfo OAuthTokenInfo
S3CredentialInfo S3CredentialInfo
GCPCredentialInfo GCPCredentialInfo
SourceBlobToken azblob.Credential
}

func (c CredentialInfo) WithType(credentialType CredentialType) CredentialInfo {
11 changes: 6 additions & 5 deletions jobsAdmin/JobsAdmin.go
@@ -24,6 +24,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/Azure/azure-storage-blob-go/azblob"
"os"
"path/filepath"
"runtime"
@@ -74,7 +75,7 @@ var JobsAdmin interface {

// JobMgr returns the specified JobID's JobMgr
JobMgr(jobID common.JobID) (ste.IJobMgr, bool)
JobMgrEnsureExists(jobID common.JobID, level common.LogLevel, commandString string) ste.IJobMgr
JobMgrEnsureExists(jobID common.JobID, level common.LogLevel, commandString string, sourceBlobToken azblob.Credential) ste.IJobMgr

// AddJobPartMgr associates the specified JobPartMgr with the Jobs Administrator
//AddJobPartMgr(appContext context.Context, planFile JobPartPlanFileName) IJobPartMgr
@@ -293,12 +294,12 @@ func (ja *jobsAdmin) AppPathFolder() string {
// JobMgrEnsureExists returns the specified JobID's IJobMgr if it exists, or creates it if it doesn't already exist
// If it does exist, then the appCtx argument is ignored.
func (ja *jobsAdmin) JobMgrEnsureExists(jobID common.JobID,
level common.LogLevel, commandString string) ste.IJobMgr {
level common.LogLevel, commandString string, sourceBlobToken azblob.Credential) ste.IJobMgr {

return ja.jobIDToJobMgr.EnsureExists(jobID,
func() ste.IJobMgr {
// Return existing or new IJobMgr to caller
return ste.NewJobMgr(ja.concurrency, jobID, ja.appCtx, ja.cpuMonitor, level, commandString, ja.logDir, ja.concurrencyTuner, ja.pacer, ja.slicePool, ja.cacheLimiter, ja.fileCountLimiter, ja.jobLogger, false)
return ste.NewJobMgr(ja.concurrency, jobID, ja.appCtx, ja.cpuMonitor, level, commandString, ja.logDir, ja.concurrencyTuner, ja.pacer, ja.slicePool, ja.cacheLimiter, ja.fileCountLimiter, ja.jobLogger, false, sourceBlobToken)
})
}

@@ -387,7 +388,7 @@ func (ja *jobsAdmin) ResurrectJob(jobId common.JobID, sourceSAS string, destinat
continue
}
mmf := planFile.Map()
jm := ja.JobMgrEnsureExists(jobID, mmf.Plan().LogLevel, "")
jm := ja.JobMgrEnsureExists(jobID, mmf.Plan().LogLevel, "", nil)
jm.AddJobPart(partNum, planFile, mmf, sourceSAS, destinationSAS, false, nil)
}

@@ -421,7 +422,7 @@ func (ja *jobsAdmin) ResurrectJobParts() {
}
mmf := planFile.Map()
//todo : call the compute transfer function here for each job.
jm := ja.JobMgrEnsureExists(jobID, mmf.Plan().LogLevel, "")
jm := ja.JobMgrEnsureExists(jobID, mmf.Plan().LogLevel, "", nil)
jm.AddJobPart(partNum, planFile, mmf, EMPTY_SAS_STRING, EMPTY_SAS_STRING, false, nil)
}
}
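
JobMgrEnsureExists leans on a get-or-create map keyed by JobID, so the new sourceBlobToken parameter simply rides along in the constructor closure. A minimal sketch of that get-or-create pattern, with simplified locking and types rather than the real ste implementation:

```go
package main

import (
	"fmt"
	"sync"
)

type jobMgr struct{ sourceToken string }

type jobMgrMap struct {
	mu sync.Mutex
	m  map[string]*jobMgr
}

// EnsureExists returns the existing manager for id, or builds one with newJM.
// Capturing per-job state (like sourceBlobToken) in the closure keeps the
// map's own API unchanged.
func (j *jobMgrMap) EnsureExists(id string, newJM func() *jobMgr) *jobMgr {
	j.mu.Lock()
	defer j.mu.Unlock()
	if jm, ok := j.m[id]; ok {
		return jm // for an existing job the constructor (and its token) is ignored
	}
	jm := newJM()
	j.m[id] = jm
	return jm
}

func main() {
	jobs := &jobMgrMap{m: map[string]*jobMgr{}}
	token := "source-blob-token" // stand-in for an azblob.Credential
	jm := jobs.EnsureExists("job-1", func() *jobMgr { return &jobMgr{sourceToken: token} })
	fmt.Println(jm.sourceToken)
}
```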
4 changes: 2 additions & 2 deletions jobsAdmin/init.go
@@ -163,8 +163,8 @@ func MainSTE(concurrency ste.ConcurrencySettings, targetRateInMegaBitsPerSec flo
func ExecuteNewCopyJobPartOrder(order common.CopyJobPartOrderRequest) common.CopyJobPartOrderResponse {
// Get the file name for this Job Part's Plan
jppfn := JobsAdmin.NewJobPartPlanFileName(order.JobID, order.PartNum)
jppfn.Create(order) // Convert the order to a plan file
jm := JobsAdmin.JobMgrEnsureExists(order.JobID, order.LogLevel, order.CommandString) // Get this job part's job manager (create it if it doesn't exist)
jm := JobsAdmin.JobMgrEnsureExists(order.JobID, order.LogLevel, order.CommandString, order.CredentialInfo.SourceBlobToken) // Get this job part's job manager (create it if it doesn't exist)

if len(order.Transfers.List) == 0 && order.IsFinalPart {
/*
11 changes: 7 additions & 4 deletions ste/mgr-JobMgr.go
@@ -23,6 +23,7 @@ package ste
import (
"context"
"fmt"
"github.com/Azure/azure-storage-blob-go/azblob"
"net/http"
"runtime"
"strings"
@@ -111,7 +112,7 @@ type IJobMgr interface {
func NewJobMgr(concurrency ConcurrencySettings, jobID common.JobID, appCtx context.Context, cpuMon common.CPUMonitor, level common.LogLevel,
commandString string, logFileFolder string, tuner ConcurrencyTuner,
pacer PacerAdmin, slicePool common.ByteSlicePooler, cacheLimiter common.CacheLimiter, fileCountLimiter common.CacheLimiter,
jobLogger common.ILoggerResetable, daemonMode bool) IJobMgr {
jobLogger common.ILoggerResetable, daemonMode bool, sourceBlobToken azblob.Credential) IJobMgr {
const channelSize = 100000
// PartsChannelSize defines the number of JobParts which can be placed into the
// parts channel. Any JobPart which comes from FE and partChannel is full,
@@ -187,6 +188,7 @@
cpuMon: cpuMon,
jstm: &jstm,
isDaemon: daemonMode,
sourceBlobToken: sourceBlobToken,
/*Other fields remain zero-value until this job is scheduled */}
jm.Reset(appCtx, commandString)
// One routine constantly monitors the partsChannel. It takes the JobPartManager from
Expand Down Expand Up @@ -338,7 +340,8 @@ type jobMgr struct {
fileCountLimiter common.CacheLimiter
jstm *jobStatusManager

isDaemon bool /* is it running as service */
sourceBlobToken azblob.Credential
}

// //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -713,7 +716,7 @@ func (jm *jobMgr) CloseLog() {
// DeferredCleanupJobMgr cleanup all the jobMgr resources.
// Warning: DeferredCleanupJobMgr should be called from JobMgrCleanup().
//
// As this function neither threadsafe nor idempotient. So if DeferredCleanupJobMgr called
// As this function neither thread safe nor idempotent. So if DeferredCleanupJobMgr called
// multiple times, it may get stuck because the receiving channel is already closed, whereas JobMgrCleanup()
// is safe in the sense that it will do the cleanup only once.
//
@@ -963,7 +966,7 @@ func (jm *jobMgr) scheduleJobParts() {
go jm.poolSizer()
startedPoolSizer = true
}
jobPart.ScheduleTransfers(jm.Context())
jobPart.ScheduleTransfers(jm.Context(), jm.sourceBlobToken)
}
}
}
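
This final hunk is where the title's "pass to other threads" happens: the token stored on jobMgr is handed to each part's ScheduleTransfers rather than being re-derived per worker. A rough sketch of that fan-out, assuming a simplified ScheduleTransfers shaped like the signature in the diff:

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

type credential string // stand-in for azblob.Credential

type jobPart struct{ id int }

// ScheduleTransfers mirrors the new signature in the diff: each part
// receives the shared source token instead of minting its own.
func (p jobPart) ScheduleTransfers(ctx context.Context, sourceBlobToken credential) {
	fmt.Printf("part %d using %s\n", p.id, sourceBlobToken)
}

func main() {
	token := credential("shared-source-token") // created once by the front end
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(p jobPart) { // every scheduling goroutine reuses the same token
			defer wg.Done()
			p.ScheduleTransfers(context.Background(), token)
		}(jobPart{id: i})
	}
	wg.Wait()
}
```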