diff --git a/pkg/sql/importer/import_processor_planning.go b/pkg/sql/importer/import_processor_planning.go index bca3aab66abe..011238078201 100644 --- a/pkg/sql/importer/import_processor_planning.go +++ b/pkg/sql/importer/import_processor_planning.go @@ -32,9 +32,12 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/sem/tree" "github.com/cockroachdb/cockroach/pkg/sql/types" "github.com/cockroachdb/cockroach/pkg/util/ctxgroup" + "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/protoutil" "github.com/cockroachdb/cockroach/pkg/util/syncutil" + "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/tracing" + "github.com/cockroachdb/logtags" ) var replanThreshold = settings.RegisterFloatSetting( @@ -217,8 +220,9 @@ func distImport( } } + flowCtx, watcher := newCancelWatcher(ctx, 3*progressUpdateInterval) recv := sql.MakeDistSQLReceiver( - ctx, + flowCtx, sql.NewMetadataCallbackWriter(rowResultWriter, metaFn), tree.Rows, nil, /* rangeCache */ @@ -238,6 +242,10 @@ func distImport( stopProgress := make(chan struct{}) g := ctxgroup.WithContext(ctx) + g.GoCtx(func(ctx context.Context) error { + watcher.watch(ctx, recv) + return nil + }) g.GoCtx(func(ctx context.Context) error { tick := time.NewTicker(time.Second * 10) defer tick.Stop() @@ -260,6 +268,13 @@ func distImport( g.GoCtx(func(ctx context.Context) error { defer cancelReplanner() defer close(stopProgress) + defer watcher.stop() + + if testingKnobs.beforeRunDSP != nil { + if err := testingKnobs.beforeRunDSP(); err != nil { + return err + } + } if testingKnobs.beforeRunDSP != nil { if err := testingKnobs.beforeRunDSP(); err != nil { @@ -269,7 +284,7 @@ func distImport( // Copy the evalCtx, as dsp.Run() might change it. evalCtxCopy := *evalCtx - dsp.Run(ctx, planCtx, nil, p, recv, &evalCtxCopy, testingKnobs.onSetupFinish) + dsp.Run(flowCtx, planCtx, nil, p, recv, &evalCtxCopy, testingKnobs.onSetupFinish) return rowResultWriter.Err() }) @@ -282,6 +297,79 @@ func distImport( return res, nil } +// cancelWatcher is used to handle job PAUSE and CANCEL +// gracefully. +// +// When the a job is canceled or paused, the context passed to the +// Resumer is canceled by the job system. Rather than passing the +// Resumer's context directly to our DistSQL receiver and planner, we +// instead construct a new context for the distsql machinery and watch +// the original context for cancelation using this cancelWatcher. +// +// When the cancelWatcher sees a cancelation on the watched context, +// it informs the DistSQL receiver passed to watch by calling +// SetError. This gives the DistSQL flow a chance to drain gracefully. +// +// However, we do not want to wait forever. After SetError is called, +// we wait up to `timeout` before canceling the context returned by +// newCancelWatcher. +type cancelWatcher struct { + watchedCtx context.Context + timeout time.Duration + + done chan struct{} + cancel context.CancelFunc +} + +// newCancelWatcher constructs a cancelWatcher. To start the watcher +// call watch. The context passed to newCancelWatcher will be watched +// for cancelation. The returned context should be used for the +// DistSQL receiver and DistSQLPlanner. +func newCancelWatcher( + ctxToWatch context.Context, timeout time.Duration, +) (context.Context, *cancelWatcher) { + ctx, cancel := context.WithCancel( + logtags.AddTags( + context.Background(), + logtags.FromContext(ctxToWatch))) + return ctx, &cancelWatcher{ + watchedCtx: ctxToWatch, + timeout: timeout, + + done: make(chan struct{}), + cancel: cancel, + } +} + +// watch starts watching the context passed to newCancelWatcher for +// cancellation and notifies the given DistSQLReceiver when a +// cancellation occurs. +// +// After cancellation, if the watcher is not stopped before the +// configured timeout, the context returned from the constructor is +// cancelled. +func (c *cancelWatcher) watch(ctx context.Context, recv *sql.DistSQLReceiver) { + select { + case <-c.watchedCtx.Done(): + recv.SetError(c.watchedCtx.Err()) + timer := timeutil.NewTimer() + defer timer.Stop() + timer.Reset(c.timeout) + select { + case <-c.done: + case <-timer.C: + timer.Read = true + log.Warningf(ctx, "watcher not stopped after %s, canceling flow context", c.timeout) + c.cancel() + } + case <-c.done: + } +} + +func (c *cancelWatcher) stop() { + close(c.done) +} + func getLastImportSummary(job *jobs.Job) roachpb.BulkOpSummary { progress := job.Progress() importProgress := progress.GetImport()