diff --git a/pkg/sql/colexec/colexecjoin/mergejoiner.go b/pkg/sql/colexec/colexecjoin/mergejoiner.go index 385688a5d518..ef67f3bef4a6 100644 --- a/pkg/sql/colexec/colexecjoin/mergejoiner.go +++ b/pkg/sql/colexec/colexecjoin/mergejoiner.go @@ -552,6 +552,8 @@ type mergeJoinBase struct { left mergeJoinInput right mergeJoinInput + cancelChecker colexecutils.CancelChecker + // Output buffer definition. output coldata.Batch outputCapacity int @@ -574,6 +576,7 @@ var _ colexecop.Closer = &mergeJoinBase{} func (o *mergeJoinBase) Reset(ctx context.Context) { o.TwoInputInitHelper.Reset(ctx) + o.cancelChecker.Init(ctx) o.state = mjEntry o.bufferedGroup.helper.Reset(ctx) o.proberState.lBatch = nil @@ -597,6 +600,7 @@ func (o *mergeJoinBase) Init(ctx context.Context) { o.memoryLimit, o.diskQueueCfg, o.fdSemaphore, o.diskAcc, o.diskQueueMemAcc, ) o.bufferedGroup.helper.init(o.Ctx) + o.cancelChecker.Init(o.Ctx) o.builderState.lGroups = make([]group, 1) o.builderState.rGroups = make([]group, 1) @@ -739,6 +743,7 @@ func (o *mergeJoinBase) sourceFinished() bool { // and updates the probing and buffered group states accordingly. func (o *mergeJoinBase) continueLeftBufferedGroup() { // Get the next batch from the left. + o.cancelChecker.CheckEveryCall() o.proberState.lIdx, o.proberState.lBatch = 0, o.InputOne.Next() o.proberState.lLength = o.proberState.lBatch.Length() o.bufferedGroup.leftGroupStartIdx = 0 @@ -805,6 +810,7 @@ func (o *mergeJoinBase) finishRightBufferedGroup() { // (or until we exhaust the right input). func (o *mergeJoinBase) completeRightBufferedGroup() { // Get the next batch from the right. + o.cancelChecker.CheckEveryCall() o.proberState.rIdx, o.proberState.rBatch = 0, o.InputTwo.Next() o.proberState.rLength = o.proberState.rBatch.Length() // The right input has been fully exhausted. @@ -870,6 +876,7 @@ func (o *mergeJoinBase) completeRightBufferedGroup() { // The buffered group is still not complete which means that we have // just appended all the tuples from batch to it, so we need to get a // fresh batch from the input. + o.cancelChecker.CheckEveryCall() o.proberState.rIdx, o.proberState.rBatch = 0, o.InputTwo.Next() o.proberState.rLength = o.proberState.rBatch.Length() if o.proberState.rLength == 0 { diff --git a/pkg/sql/colexec/colexecjoin/mergejoiner_exceptall.eg.go b/pkg/sql/colexec/colexecjoin/mergejoiner_exceptall.eg.go index 7ef6f4ac14c1..0b1e8d0c33c7 100644 --- a/pkg/sql/colexec/colexecjoin/mergejoiner_exceptall.eg.go +++ b/pkg/sql/colexec/colexecjoin/mergejoiner_exceptall.eg.go @@ -14236,10 +14236,12 @@ func (o *mergeJoinExceptAllOp) Next() coldata.Batch { // If this is the first batch or we're done with the current batch, // get the next batch. if o.proberState.lBatch == nil || (o.proberState.lLength != 0 && o.proberState.lIdx == o.proberState.lLength) { + o.cancelChecker.CheckEveryCall() o.proberState.lIdx, o.proberState.lBatch = 0, o.InputOne.Next() o.proberState.lLength = o.proberState.lBatch.Length() } if o.proberState.rBatch == nil || (o.proberState.rLength != 0 && o.proberState.rIdx == o.proberState.rLength) { + o.cancelChecker.CheckEveryCall() o.proberState.rIdx, o.proberState.rBatch = 0, o.InputTwo.Next() o.proberState.rLength = o.proberState.rBatch.Length() } diff --git a/pkg/sql/colexec/colexecjoin/mergejoiner_fullouter.eg.go b/pkg/sql/colexec/colexecjoin/mergejoiner_fullouter.eg.go index da4e855199cb..32ae33880ae4 100644 --- a/pkg/sql/colexec/colexecjoin/mergejoiner_fullouter.eg.go +++ b/pkg/sql/colexec/colexecjoin/mergejoiner_fullouter.eg.go @@ -15383,10 +15383,12 @@ func (o *mergeJoinFullOuterOp) Next() coldata.Batch { // If this is the first batch or we're done with the current batch, // get the next batch. if o.proberState.lBatch == nil || (o.proberState.lLength != 0 && o.proberState.lIdx == o.proberState.lLength) { + o.cancelChecker.CheckEveryCall() o.proberState.lIdx, o.proberState.lBatch = 0, o.InputOne.Next() o.proberState.lLength = o.proberState.lBatch.Length() } if o.proberState.rBatch == nil || (o.proberState.rLength != 0 && o.proberState.rIdx == o.proberState.rLength) { + o.cancelChecker.CheckEveryCall() o.proberState.rIdx, o.proberState.rBatch = 0, o.InputTwo.Next() o.proberState.rLength = o.proberState.rBatch.Length() } diff --git a/pkg/sql/colexec/colexecjoin/mergejoiner_inner.eg.go b/pkg/sql/colexec/colexecjoin/mergejoiner_inner.eg.go index 980d517c85c2..79ba29cbc26b 100644 --- a/pkg/sql/colexec/colexecjoin/mergejoiner_inner.eg.go +++ b/pkg/sql/colexec/colexecjoin/mergejoiner_inner.eg.go @@ -10917,10 +10917,12 @@ func (o *mergeJoinInnerOp) Next() coldata.Batch { // If this is the first batch or we're done with the current batch, // get the next batch. if o.proberState.lBatch == nil || (o.proberState.lLength != 0 && o.proberState.lIdx == o.proberState.lLength) { + o.cancelChecker.CheckEveryCall() o.proberState.lIdx, o.proberState.lBatch = 0, o.InputOne.Next() o.proberState.lLength = o.proberState.lBatch.Length() } if o.proberState.rBatch == nil || (o.proberState.rLength != 0 && o.proberState.rIdx == o.proberState.rLength) { + o.cancelChecker.CheckEveryCall() o.proberState.rIdx, o.proberState.rBatch = 0, o.InputTwo.Next() o.proberState.rLength = o.proberState.rBatch.Length() } diff --git a/pkg/sql/colexec/colexecjoin/mergejoiner_intersectall.eg.go b/pkg/sql/colexec/colexecjoin/mergejoiner_intersectall.eg.go index 4b923b4924ed..f6a686d53f05 100644 --- a/pkg/sql/colexec/colexecjoin/mergejoiner_intersectall.eg.go +++ b/pkg/sql/colexec/colexecjoin/mergejoiner_intersectall.eg.go @@ -11627,10 +11627,12 @@ func (o *mergeJoinIntersectAllOp) Next() coldata.Batch { // If this is the first batch or we're done with the current batch, // get the next batch. if o.proberState.lBatch == nil || (o.proberState.lLength != 0 && o.proberState.lIdx == o.proberState.lLength) { + o.cancelChecker.CheckEveryCall() o.proberState.lIdx, o.proberState.lBatch = 0, o.InputOne.Next() o.proberState.lLength = o.proberState.lBatch.Length() } if o.proberState.rBatch == nil || (o.proberState.rLength != 0 && o.proberState.rIdx == o.proberState.rLength) { + o.cancelChecker.CheckEveryCall() o.proberState.rIdx, o.proberState.rBatch = 0, o.InputTwo.Next() o.proberState.rLength = o.proberState.rBatch.Length() } diff --git a/pkg/sql/colexec/colexecjoin/mergejoiner_leftanti.eg.go b/pkg/sql/colexec/colexecjoin/mergejoiner_leftanti.eg.go index 78e6081ef207..2cba92c660fd 100644 --- a/pkg/sql/colexec/colexecjoin/mergejoiner_leftanti.eg.go +++ b/pkg/sql/colexec/colexecjoin/mergejoiner_leftanti.eg.go @@ -13146,10 +13146,12 @@ func (o *mergeJoinLeftAntiOp) Next() coldata.Batch { // If this is the first batch or we're done with the current batch, // get the next batch. if o.proberState.lBatch == nil || (o.proberState.lLength != 0 && o.proberState.lIdx == o.proberState.lLength) { + o.cancelChecker.CheckEveryCall() o.proberState.lIdx, o.proberState.lBatch = 0, o.InputOne.Next() o.proberState.lLength = o.proberState.lBatch.Length() } if o.proberState.rBatch == nil || (o.proberState.rLength != 0 && o.proberState.rIdx == o.proberState.rLength) { + o.cancelChecker.CheckEveryCall() o.proberState.rIdx, o.proberState.rBatch = 0, o.InputTwo.Next() o.proberState.rLength = o.proberState.rBatch.Length() } diff --git a/pkg/sql/colexec/colexecjoin/mergejoiner_leftouter.eg.go b/pkg/sql/colexec/colexecjoin/mergejoiner_leftouter.eg.go index a25970e78466..26a138a64d89 100644 --- a/pkg/sql/colexec/colexecjoin/mergejoiner_leftouter.eg.go +++ b/pkg/sql/colexec/colexecjoin/mergejoiner_leftouter.eg.go @@ -13173,10 +13173,12 @@ func (o *mergeJoinLeftOuterOp) Next() coldata.Batch { // If this is the first batch or we're done with the current batch, // get the next batch. if o.proberState.lBatch == nil || (o.proberState.lLength != 0 && o.proberState.lIdx == o.proberState.lLength) { + o.cancelChecker.CheckEveryCall() o.proberState.lIdx, o.proberState.lBatch = 0, o.InputOne.Next() o.proberState.lLength = o.proberState.lBatch.Length() } if o.proberState.rBatch == nil || (o.proberState.rLength != 0 && o.proberState.rIdx == o.proberState.rLength) { + o.cancelChecker.CheckEveryCall() o.proberState.rIdx, o.proberState.rBatch = 0, o.InputTwo.Next() o.proberState.rLength = o.proberState.rBatch.Length() } diff --git a/pkg/sql/colexec/colexecjoin/mergejoiner_leftsemi.eg.go b/pkg/sql/colexec/colexecjoin/mergejoiner_leftsemi.eg.go index 22fc187eed10..0090ad6530be 100644 --- a/pkg/sql/colexec/colexecjoin/mergejoiner_leftsemi.eg.go +++ b/pkg/sql/colexec/colexecjoin/mergejoiner_leftsemi.eg.go @@ -10870,10 +10870,12 @@ func (o *mergeJoinLeftSemiOp) Next() coldata.Batch { // If this is the first batch or we're done with the current batch, // get the next batch. if o.proberState.lBatch == nil || (o.proberState.lLength != 0 && o.proberState.lIdx == o.proberState.lLength) { + o.cancelChecker.CheckEveryCall() o.proberState.lIdx, o.proberState.lBatch = 0, o.InputOne.Next() o.proberState.lLength = o.proberState.lBatch.Length() } if o.proberState.rBatch == nil || (o.proberState.rLength != 0 && o.proberState.rIdx == o.proberState.rLength) { + o.cancelChecker.CheckEveryCall() o.proberState.rIdx, o.proberState.rBatch = 0, o.InputTwo.Next() o.proberState.rLength = o.proberState.rBatch.Length() } diff --git a/pkg/sql/colexec/colexecjoin/mergejoiner_rightanti.eg.go b/pkg/sql/colexec/colexecjoin/mergejoiner_rightanti.eg.go index 1e84de733f4b..00eaa60c2ed4 100644 --- a/pkg/sql/colexec/colexecjoin/mergejoiner_rightanti.eg.go +++ b/pkg/sql/colexec/colexecjoin/mergejoiner_rightanti.eg.go @@ -13077,10 +13077,12 @@ func (o *mergeJoinRightAntiOp) Next() coldata.Batch { // If this is the first batch or we're done with the current batch, // get the next batch. if o.proberState.lBatch == nil || (o.proberState.lLength != 0 && o.proberState.lIdx == o.proberState.lLength) { + o.cancelChecker.CheckEveryCall() o.proberState.lIdx, o.proberState.lBatch = 0, o.InputOne.Next() o.proberState.lLength = o.proberState.lBatch.Length() } if o.proberState.rBatch == nil || (o.proberState.rLength != 0 && o.proberState.rIdx == o.proberState.rLength) { + o.cancelChecker.CheckEveryCall() o.proberState.rIdx, o.proberState.rBatch = 0, o.InputTwo.Next() o.proberState.rLength = o.proberState.rBatch.Length() } diff --git a/pkg/sql/colexec/colexecjoin/mergejoiner_rightouter.eg.go b/pkg/sql/colexec/colexecjoin/mergejoiner_rightouter.eg.go index 774714de5e29..4f41afd37107 100644 --- a/pkg/sql/colexec/colexecjoin/mergejoiner_rightouter.eg.go +++ b/pkg/sql/colexec/colexecjoin/mergejoiner_rightouter.eg.go @@ -13127,10 +13127,12 @@ func (o *mergeJoinRightOuterOp) Next() coldata.Batch { // If this is the first batch or we're done with the current batch, // get the next batch. if o.proberState.lBatch == nil || (o.proberState.lLength != 0 && o.proberState.lIdx == o.proberState.lLength) { + o.cancelChecker.CheckEveryCall() o.proberState.lIdx, o.proberState.lBatch = 0, o.InputOne.Next() o.proberState.lLength = o.proberState.lBatch.Length() } if o.proberState.rBatch == nil || (o.proberState.rLength != 0 && o.proberState.rIdx == o.proberState.rLength) { + o.cancelChecker.CheckEveryCall() o.proberState.rIdx, o.proberState.rBatch = 0, o.InputTwo.Next() o.proberState.rLength = o.proberState.rBatch.Length() } diff --git a/pkg/sql/colexec/colexecjoin/mergejoiner_rightsemi.eg.go b/pkg/sql/colexec/colexecjoin/mergejoiner_rightsemi.eg.go index cca723c753e0..caaad1c71c9b 100644 --- a/pkg/sql/colexec/colexecjoin/mergejoiner_rightsemi.eg.go +++ b/pkg/sql/colexec/colexecjoin/mergejoiner_rightsemi.eg.go @@ -10830,10 +10830,12 @@ func (o *mergeJoinRightSemiOp) Next() coldata.Batch { // If this is the first batch or we're done with the current batch, // get the next batch. if o.proberState.lBatch == nil || (o.proberState.lLength != 0 && o.proberState.lIdx == o.proberState.lLength) { + o.cancelChecker.CheckEveryCall() o.proberState.lIdx, o.proberState.lBatch = 0, o.InputOne.Next() o.proberState.lLength = o.proberState.lBatch.Length() } if o.proberState.rBatch == nil || (o.proberState.rLength != 0 && o.proberState.rIdx == o.proberState.rLength) { + o.cancelChecker.CheckEveryCall() o.proberState.rIdx, o.proberState.rBatch = 0, o.InputTwo.Next() o.proberState.rLength = o.proberState.rBatch.Length() } diff --git a/pkg/sql/colexec/colexecjoin/mergejoiner_tmpl.go b/pkg/sql/colexec/colexecjoin/mergejoiner_tmpl.go index b6209aed11de..274601c080ae 100644 --- a/pkg/sql/colexec/colexecjoin/mergejoiner_tmpl.go +++ b/pkg/sql/colexec/colexecjoin/mergejoiner_tmpl.go @@ -1366,10 +1366,12 @@ func (o *mergeJoin_JOIN_TYPE_STRINGOp) Next() coldata.Batch { // If this is the first batch or we're done with the current batch, // get the next batch. if o.proberState.lBatch == nil || (o.proberState.lLength != 0 && o.proberState.lIdx == o.proberState.lLength) { + o.cancelChecker.CheckEveryCall() o.proberState.lIdx, o.proberState.lBatch = 0, o.InputOne.Next() o.proberState.lLength = o.proberState.lBatch.Length() } if o.proberState.rBatch == nil || (o.proberState.rLength != 0 && o.proberState.rIdx == o.proberState.rLength) { + o.cancelChecker.CheckEveryCall() o.proberState.rIdx, o.proberState.rBatch = 0, o.InputTwo.Next() o.proberState.rLength = o.proberState.rBatch.Length() } diff --git a/pkg/sql/colexec/sorttopk.eg.go b/pkg/sql/colexec/sorttopk.eg.go index 64a76e026ac9..a370119b2874 100644 --- a/pkg/sql/colexec/sorttopk.eg.go +++ b/pkg/sql/colexec/sorttopk.eg.go @@ -114,6 +114,7 @@ func spool_true(t *topKSorter) { // or more distinct and complete groups, and then use a K-N size heap to find // the remaining top K-N rows. { + t.cancelChecker.CheckEveryCall() t.inputBatch = t.Input.Next() t.orderState.distincterInput.SetBatch(t.inputBatch) t.orderState.distincter.Next() @@ -176,6 +177,7 @@ func spool_true(t *topKSorter) { remainingRows -= uint64(fromLength) if fromLength == t.inputBatch.Length() { { + t.cancelChecker.CheckEveryCall() t.inputBatch = t.Input.Next() t.orderState.distincterInput.SetBatch(t.inputBatch) t.orderState.distincter.Next() @@ -276,6 +278,7 @@ func spool_true(t *topKSorter) { break } { + t.cancelChecker.CheckEveryCall() t.inputBatch = t.Input.Next() t.orderState.distincterInput.SetBatch(t.inputBatch) t.orderState.distincter.Next() @@ -313,6 +316,7 @@ func spool_false(t *topKSorter) { // or more distinct and complete groups, and then use a K-N size heap to find // the remaining top K-N rows. { + t.cancelChecker.CheckEveryCall() t.inputBatch = t.Input.Next() t.firstUnprocessedTupleIdx = 0 } @@ -332,6 +336,7 @@ func spool_false(t *topKSorter) { remainingRows -= uint64(fromLength) if fromLength == t.inputBatch.Length() { { + t.cancelChecker.CheckEveryCall() t.inputBatch = t.Input.Next() t.firstUnprocessedTupleIdx = 0 } @@ -390,6 +395,7 @@ func spool_false(t *topKSorter) { }, ) { + t.cancelChecker.CheckEveryCall() t.inputBatch = t.Input.Next() t.firstUnprocessedTupleIdx = 0 } diff --git a/pkg/sql/colexec/sorttopk.go b/pkg/sql/colexec/sorttopk.go index 9c2644d2ac35..b29141952def 100644 --- a/pkg/sql/colexec/sorttopk.go +++ b/pkg/sql/colexec/sorttopk.go @@ -117,6 +117,8 @@ type topKSorter struct { emitted int output coldata.Batch + cancelChecker colexecutils.CancelChecker + exportedFromTopK int exportedFromBatch int windowedBatch coldata.Batch @@ -150,6 +152,7 @@ func (t *topKSorter) Init(ctx context.Context) { t.orderState.distincter.Init(t.Ctx) t.orderState.group = make([]int, t.k) } + t.cancelChecker.Init(t.Ctx) } func (t *topKSorter) Next() coldata.Batch { diff --git a/pkg/sql/colexec/sorttopk_tmpl.go b/pkg/sql/colexec/sorttopk_tmpl.go index 2241787b4d4c..e57f87c4a63a 100644 --- a/pkg/sql/colexec/sorttopk_tmpl.go +++ b/pkg/sql/colexec/sorttopk_tmpl.go @@ -27,6 +27,7 @@ import ( // execgen:template // execgen:inline func nextBatch(t *topKSorter, partialOrder bool) { + t.cancelChecker.CheckEveryCall() t.inputBatch = t.Input.Next() if partialOrder { t.orderState.distincterInput.SetBatch(t.inputBatch)