diff --git a/.gitattributes b/.gitattributes
index 59188db426cff..25f18a922512d 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -4,5 +4,7 @@
# Declare files that will always have LF line endings on checkout.
*.y text eol=lf
+*.result diff
+
util/collate/unicode_0*_ci.go linguist-generated=true
util/collate/ucadata/unicode_*_data.go linguist-generated=true
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index f67985fd41669..6ba37f8ef01d6 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -33,7 +33,9 @@ Tests
- [ ] Unit test
- [ ] Integration test
- [ ] Manual test (add detailed scripts or steps below)
-- [ ] No code
+- [ ] No need to test
+ > - [ ] I checked and no code files have been changed.
+ >
Side effects
diff --git a/DEPS.bzl b/DEPS.bzl
index 15e84d288b60d..23fcba7faea41 100644
--- a/DEPS.bzl
+++ b/DEPS.bzl
@@ -1642,6 +1642,32 @@ def go_deps():
"https://storage.googleapis.com/pingcapmirror/gomod/github.com/docker/go-units/com_github_docker_go_units-v0.4.0.zip",
],
)
+ go_repository(
+ name = "com_github_dolthub_maphash",
+ build_file_proto_mode = "disable_global",
+ importpath = "github.com/dolthub/maphash",
+ sha256 = "ba69ef526a9613cb059c8490c1a4f032649879c316a1c4305e2355815eb32e41",
+ strip_prefix = "github.com/dolthub/maphash@v0.1.0",
+ urls = [
+ "http://bazel-cache.pingcap.net:8080/gomod/github.com/dolthub/maphash/com_github_dolthub_maphash-v0.1.0.zip",
+ "http://ats.apps.svc/gomod/github.com/dolthub/maphash/com_github_dolthub_maphash-v0.1.0.zip",
+ "https://cache.hawkingrei.com/gomod/github.com/dolthub/maphash/com_github_dolthub_maphash-v0.1.0.zip",
+ "https://storage.googleapis.com/pingcapmirror/gomod/github.com/dolthub/maphash/com_github_dolthub_maphash-v0.1.0.zip",
+ ],
+ )
+ go_repository(
+ name = "com_github_dolthub_swiss",
+ build_file_proto_mode = "disable_global",
+ importpath = "github.com/dolthub/swiss",
+ sha256 = "e911b7cea9aaed1255544fb8b53c19780f91b713e6d0fc71fb310232e4800dcc",
+ strip_prefix = "github.com/dolthub/swiss@v0.2.1",
+ urls = [
+ "http://bazel-cache.pingcap.net:8080/gomod/github.com/dolthub/swiss/com_github_dolthub_swiss-v0.2.1.zip",
+ "http://ats.apps.svc/gomod/github.com/dolthub/swiss/com_github_dolthub_swiss-v0.2.1.zip",
+ "https://cache.hawkingrei.com/gomod/github.com/dolthub/swiss/com_github_dolthub_swiss-v0.2.1.zip",
+ "https://storage.googleapis.com/pingcapmirror/gomod/github.com/dolthub/swiss/com_github_dolthub_swiss-v0.2.1.zip",
+ ],
+ )
go_repository(
name = "com_github_dustin_go_humanize",
build_file_proto_mode = "disable_global",
@@ -4363,13 +4389,13 @@ def go_deps():
name = "com_github_klauspost_compress",
build_file_proto_mode = "disable_global",
importpath = "github.com/klauspost/compress",
- sha256 = "7e004bb6b71535508bfa9c57256cfb2ca23f09ea281dbecafea217796b712fcd",
- strip_prefix = "github.com/klauspost/compress@v1.16.5",
+ sha256 = "fa94794543608ad4f600c67994a317173b4e72c1159b8a84ab46a846c7643587",
+ strip_prefix = "github.com/klauspost/compress@v1.17.0",
urls = [
- "http://bazel-cache.pingcap.net:8080/gomod/github.com/klauspost/compress/com_github_klauspost_compress-v1.16.5.zip",
- "http://ats.apps.svc/gomod/github.com/klauspost/compress/com_github_klauspost_compress-v1.16.5.zip",
- "https://cache.hawkingrei.com/gomod/github.com/klauspost/compress/com_github_klauspost_compress-v1.16.5.zip",
- "https://storage.googleapis.com/pingcapmirror/gomod/github.com/klauspost/compress/com_github_klauspost_compress-v1.16.5.zip",
+ "http://bazel-cache.pingcap.net:8080/gomod/github.com/klauspost/compress/com_github_klauspost_compress-v1.17.0.zip",
+ "http://ats.apps.svc/gomod/github.com/klauspost/compress/com_github_klauspost_compress-v1.17.0.zip",
+ "https://cache.hawkingrei.com/gomod/github.com/klauspost/compress/com_github_klauspost_compress-v1.17.0.zip",
+ "https://storage.googleapis.com/pingcapmirror/gomod/github.com/klauspost/compress/com_github_klauspost_compress-v1.17.0.zip",
],
)
go_repository(
@@ -6976,13 +7002,13 @@ def go_deps():
name = "com_github_tikv_client_go_v2",
build_file_proto_mode = "disable_global",
importpath = "github.com/tikv/client-go/v2",
- sha256 = "663c08693ab25489aa2d65b5e1afc34cf51920fb5db86cc591c22cea8f1b0979",
- strip_prefix = "github.com/tikv/client-go/v2@v2.0.8-0.20230911065915-f9e28714c62c",
+ sha256 = "e3e864f9d839dbbaf45fbdaf63e89fe15f8d668d0b6dd9dcf4acaa52d99af75f",
+ strip_prefix = "github.com/tikv/client-go/v2@v2.0.8-0.20230919031511-be2b4c78a910",
urls = [
- "http://bazel-cache.pingcap.net:8080/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20230911065915-f9e28714c62c.zip",
- "http://ats.apps.svc/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20230911065915-f9e28714c62c.zip",
- "https://cache.hawkingrei.com/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20230911065915-f9e28714c62c.zip",
- "https://storage.googleapis.com/pingcapmirror/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20230911065915-f9e28714c62c.zip",
+ "http://bazel-cache.pingcap.net:8080/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20230919031511-be2b4c78a910.zip",
+ "http://ats.apps.svc/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20230919031511-be2b4c78a910.zip",
+ "https://cache.hawkingrei.com/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20230919031511-be2b4c78a910.zip",
+ "https://storage.googleapis.com/pingcapmirror/gomod/github.com/tikv/client-go/v2/com_github_tikv_client_go_v2-v2.0.8-0.20230919031511-be2b4c78a910.zip",
],
)
go_repository(
@@ -10074,13 +10100,13 @@ def go_deps():
name = "org_golang_x_crypto",
build_file_proto_mode = "disable_global",
importpath = "golang.org/x/crypto",
- sha256 = "29b788bd8f1229214af831bf99412a09d19096dea3c62bc3281656b64093d12d",
- strip_prefix = "golang.org/x/crypto@v0.12.0",
+ sha256 = "b58d902f48a7f595a28589b6ed4be8b5e2ee1a3496eca477039e80eb6e7eba57",
+ strip_prefix = "golang.org/x/crypto@v0.13.0",
urls = [
- "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.12.0.zip",
- "http://ats.apps.svc/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.12.0.zip",
- "https://cache.hawkingrei.com/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.12.0.zip",
- "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.12.0.zip",
+ "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.13.0.zip",
+ "http://ats.apps.svc/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.13.0.zip",
+ "https://cache.hawkingrei.com/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.13.0.zip",
+ "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/crypto/org_golang_x_crypto-v0.13.0.zip",
],
)
go_repository(
@@ -10165,13 +10191,13 @@ def go_deps():
name = "org_golang_x_net",
build_file_proto_mode = "disable_global",
importpath = "golang.org/x/net",
- sha256 = "fdd5ca5653644b65d9062705ecae70b156660547f96d4606659960fe0c053871",
- strip_prefix = "golang.org/x/net@v0.14.0",
+ sha256 = "e891941f0a83dfc85f82990e29cbf1939dca5952d04241666c8a227d419fded3",
+ strip_prefix = "golang.org/x/net@v0.15.0",
urls = [
- "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/net/org_golang_x_net-v0.14.0.zip",
- "http://ats.apps.svc/gomod/golang.org/x/net/org_golang_x_net-v0.14.0.zip",
- "https://cache.hawkingrei.com/gomod/golang.org/x/net/org_golang_x_net-v0.14.0.zip",
- "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/net/org_golang_x_net-v0.14.0.zip",
+ "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/net/org_golang_x_net-v0.15.0.zip",
+ "http://ats.apps.svc/gomod/golang.org/x/net/org_golang_x_net-v0.15.0.zip",
+ "https://cache.hawkingrei.com/gomod/golang.org/x/net/org_golang_x_net-v0.15.0.zip",
+ "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/net/org_golang_x_net-v0.15.0.zip",
],
)
go_repository(
@@ -10204,39 +10230,39 @@ def go_deps():
name = "org_golang_x_sys",
build_file_proto_mode = "disable_global",
importpath = "golang.org/x/sys",
- sha256 = "0d03f4d1aa3b28ec80e2005ec8301004e2153e3154911baebf57b9cfa900f993",
- strip_prefix = "golang.org/x/sys@v0.11.0",
+ sha256 = "89225d9e6603c090ffd93286b7ca124849fadfe4320c3b18a6bdccc4ac08672c",
+ strip_prefix = "golang.org/x/sys@v0.12.0",
urls = [
- "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/sys/org_golang_x_sys-v0.11.0.zip",
- "http://ats.apps.svc/gomod/golang.org/x/sys/org_golang_x_sys-v0.11.0.zip",
- "https://cache.hawkingrei.com/gomod/golang.org/x/sys/org_golang_x_sys-v0.11.0.zip",
- "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/sys/org_golang_x_sys-v0.11.0.zip",
+ "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/sys/org_golang_x_sys-v0.12.0.zip",
+ "http://ats.apps.svc/gomod/golang.org/x/sys/org_golang_x_sys-v0.12.0.zip",
+ "https://cache.hawkingrei.com/gomod/golang.org/x/sys/org_golang_x_sys-v0.12.0.zip",
+ "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/sys/org_golang_x_sys-v0.12.0.zip",
],
)
go_repository(
name = "org_golang_x_term",
build_file_proto_mode = "disable_global",
importpath = "golang.org/x/term",
- sha256 = "d9a79ecfb908333f03d7f3f4b597551a6916462c6c5d040528c9887df956600e",
- strip_prefix = "golang.org/x/term@v0.11.0",
+ sha256 = "f4bbc4baa0c9b053f7d252b06e4e8baabd686a9a87d82025b341796e29f39c60",
+ strip_prefix = "golang.org/x/term@v0.12.0",
urls = [
- "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/term/org_golang_x_term-v0.11.0.zip",
- "http://ats.apps.svc/gomod/golang.org/x/term/org_golang_x_term-v0.11.0.zip",
- "https://cache.hawkingrei.com/gomod/golang.org/x/term/org_golang_x_term-v0.11.0.zip",
- "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/term/org_golang_x_term-v0.11.0.zip",
+ "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/term/org_golang_x_term-v0.12.0.zip",
+ "http://ats.apps.svc/gomod/golang.org/x/term/org_golang_x_term-v0.12.0.zip",
+ "https://cache.hawkingrei.com/gomod/golang.org/x/term/org_golang_x_term-v0.12.0.zip",
+ "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/term/org_golang_x_term-v0.12.0.zip",
],
)
go_repository(
name = "org_golang_x_text",
build_file_proto_mode = "disable_global",
importpath = "golang.org/x/text",
- sha256 = "437a787c7f92bcb8b2f2ab97fcd74ce88b5e7a5b21aa299e90f5c5dd28a7b66f",
- strip_prefix = "golang.org/x/text@v0.12.0",
+ sha256 = "ed544fb017e967c053892df7b068612fce707ba32b57f35824cb041e31c6ae0f",
+ strip_prefix = "golang.org/x/text@v0.13.0",
urls = [
- "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/text/org_golang_x_text-v0.12.0.zip",
- "http://ats.apps.svc/gomod/golang.org/x/text/org_golang_x_text-v0.12.0.zip",
- "https://cache.hawkingrei.com/gomod/golang.org/x/text/org_golang_x_text-v0.12.0.zip",
- "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/text/org_golang_x_text-v0.12.0.zip",
+ "http://bazel-cache.pingcap.net:8080/gomod/golang.org/x/text/org_golang_x_text-v0.13.0.zip",
+ "http://ats.apps.svc/gomod/golang.org/x/text/org_golang_x_text-v0.13.0.zip",
+ "https://cache.hawkingrei.com/gomod/golang.org/x/text/org_golang_x_text-v0.13.0.zip",
+ "https://storage.googleapis.com/pingcapmirror/gomod/golang.org/x/text/org_golang_x_text-v0.13.0.zip",
],
)
go_repository(
diff --git a/Makefile b/Makefile
index c9df25116be41..321c05fd636c7 100644
--- a/Makefile
+++ b/Makefile
@@ -385,6 +385,7 @@ mock_lightning: tools/bin/mockgen
gen_mock: tools/bin/mockgen
tools/bin/mockgen -package mock github.com/pingcap/tidb/disttask/framework/scheduler TaskTable,Pool,Scheduler,Extension > disttask/framework/mock/scheduler_mock.go
+ tools/bin/mockgen -package mock github.com/pingcap/tidb/disttask/framework/dispatcher Dispatcher > disttask/framework/mock/dispatcher_mock.go
tools/bin/mockgen -package execute github.com/pingcap/tidb/disttask/framework/scheduler/execute SubtaskExecutor > disttask/framework/mock/execute/execute_mock.go
tools/bin/mockgen -package mock github.com/pingcap/tidb/disttask/importinto MiniTaskExecutor > disttask/importinto/mock/import_mock.go
tools/bin/mockgen -package mock github.com/pingcap/tidb/disttask/framework/planner LogicalPlan,PipelineSpec > disttask/framework/mock/plan_mock.go
diff --git a/OWNERS b/OWNERS
index 120b4c4041ec3..ab1e74778890a 100644
--- a/OWNERS
+++ b/OWNERS
@@ -114,6 +114,7 @@ reviewers:
- Benjamin2037
- bobotu
- BornChanger
+ - CabinfeverB
- charleszheng44
- ChenPeng2013
- dhysum
diff --git a/OWNERS_ALIASES b/OWNERS_ALIASES
index 1fd180a0eec8e..e3adc8dade8cc 100644
--- a/OWNERS_ALIASES
+++ b/OWNERS_ALIASES
@@ -11,3 +11,15 @@ aliases:
sig-approvers-autoid-service: # approvers for auto-id service
- bb7133
- tiancaiamao
+ sig-approvers-distsql: # approvers for distsql pkg
+ - windtalker
+ - XuHuaiyu
+ - zanmato1984
+ sig-approvers-executor: # approvers for executor pkg
+ - windtalker
+ - XuHuaiyu
+ - zanmato1984
+ sig-approvers-expression: # approvers for expression pkg
+ - windtalker
+ - XuHuaiyu
+ - zanmato1984
diff --git a/README.md b/README.md
index 05d2ad0cdd36e..62b9858d915d5 100644
--- a/README.md
+++ b/README.md
@@ -98,6 +98,13 @@ Contributions are welcomed and greatly appreciated. You can get started with one
Every contributor is welcome to claim your contribution swag by filling in and submitting this [form](https://forms.pingcap.com/f/tidb-contribution-swag).
+
+
+
+
## Case studies
- [Case studies in English](https://www.pingcap.com/customers/)
diff --git a/SECURITY.md b/SECURITY.md
index ff0e7844edeef..cbf47548a526b 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -4,7 +4,7 @@ TiDB is a fast-growing open source database. To ensure its security, a security
The primary goal of this process is to reduce the total exposure time of users to publicly known vulnerabilities. To quickly fix vulnerabilities of TiDB products, the security team is responsible for the entire vulnerability management process, including internal communication and external disclosure.
-If you find a vulnerability or encounter a security incident involving vulnerabilities of TiDB products, please report it as soon as possible to the TiDB security team (security@tidb.io).
+If you find a vulnerability or encounter a security incident involving vulnerabilities of TiDB products, please report it as soon as possible to the TiDB security team (security@pingcap.com).
Please kindly help provide as much vulnerability information as possible in the following format:
diff --git a/bindinfo/capture_test.go b/bindinfo/capture_test.go
index f5112cad9e160..333e9b8c63cff 100644
--- a/bindinfo/capture_test.go
+++ b/bindinfo/capture_test.go
@@ -1005,7 +1005,6 @@ func TestCaptureHints(t *testing.T) {
// runtime hints
{"select /*+ memory_quota(1024 MB) */ * from t", "memory_quota(1024 mb)"},
{"select /*+ max_execution_time(1000) */ * from t", "max_execution_time(1000)"},
- {"select /*+ tidb_kv_read_timeout(1000) */ * from t", "tidb_kv_read_timeout(1000)"},
// storage hints
{"select /*+ read_from_storage(tikv[t]) */ * from t", "read_from_storage(tikv[`t`])"},
// others
diff --git a/bindinfo/session_handle_test.go b/bindinfo/session_handle_test.go
index ba7b72eab61d1..f90aae0afcac7 100644
--- a/bindinfo/session_handle_test.go
+++ b/bindinfo/session_handle_test.go
@@ -43,14 +43,14 @@ func TestGlobalAndSessionBindingBothExist(t *testing.T) {
tk.MustExec("drop table if exists t2")
tk.MustExec("create table t1(id int)")
tk.MustExec("create table t2(id int)")
- require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin"))
- require.True(t, tk.HasPlan("SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id", "MergeJoin"))
+ tk.MustHavePlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin")
+ tk.MustHavePlan("SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id", "MergeJoin")
tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id")
// Test bindingUsage, which indicates how many times the binding is used.
metrics.BindUsageCounter.Reset()
- require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin"))
+ tk.MustHavePlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin")
pb := &dto.Metric{}
err := metrics.BindUsageCounter.WithLabelValues(metrics.ScopeGlobal).Write(pb)
require.NoError(t, err)
@@ -58,30 +58,30 @@ func TestGlobalAndSessionBindingBothExist(t *testing.T) {
// Test 'tidb_use_plan_baselines'
tk.MustExec("set @@tidb_use_plan_baselines = 0")
- require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin"))
+ tk.MustHavePlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin")
tk.MustExec("set @@tidb_use_plan_baselines = 1")
// Test 'drop global binding'
- require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin"))
+ tk.MustHavePlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin")
tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id")
- require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin"))
+ tk.MustHavePlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin")
// Test the case when global and session binding both exist
// PART1 : session binding should totally cover global binding
// use merge join as session binding here since the optimizer will choose hash join for this stmt in default
tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_HJ(t1, t2) */ * from t1,t2 where t1.id = t2.id")
- require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin"))
+ tk.MustHavePlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin")
tk.MustExec("create binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id")
- require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin"))
+ tk.MustHavePlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin")
tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id")
- require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin"))
+ tk.MustHavePlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin")
// PART2 : the dropped session binding should continue to block the effect of global binding
tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id")
tk.MustExec("drop binding for SELECT * from t1,t2 where t1.id = t2.id")
- require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin"))
+ tk.MustHavePlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin")
tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id")
- require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin"))
+ tk.MustHavePlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin")
}
func TestSessionBinding(t *testing.T) {
diff --git a/bindinfo/tests/bind_test.go b/bindinfo/tests/bind_test.go
index 399051dd10a2c..ac5e485fa8f9f 100644
--- a/bindinfo/tests/bind_test.go
+++ b/bindinfo/tests/bind_test.go
@@ -383,23 +383,23 @@ func TestExplain(t *testing.T) {
tk.MustExec("create table t1(id int)")
tk.MustExec("create table t2(id int)")
- require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin"))
- require.True(t, tk.HasPlan("SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id", "MergeJoin"))
+ tk.MustHavePlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin")
+ tk.MustHavePlan("SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id", "MergeJoin")
tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id")
- require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin"))
+ tk.MustHavePlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin")
tk.MustExec("drop global binding for SELECT * from t1,t2 where t1.id = t2.id")
// Add test for SetOprStmt
tk.MustExec("create index index_id on t1(id)")
- require.True(t, tk.HasPlan("SELECT * from t1 union SELECT * from t1", "IndexReader"))
- require.True(t, tk.HasPlan("SELECT * from t1 use index(index_id) union SELECT * from t1", "IndexReader"))
+ tk.MustHavePlan("SELECT * from t1 union SELECT * from t1", "IndexReader")
+ tk.MustHavePlan("SELECT * from t1 use index(index_id) union SELECT * from t1", "IndexReader")
tk.MustExec("create global binding for SELECT * from t1 union SELECT * from t1 using SELECT * from t1 use index(index_id) union SELECT * from t1")
- require.True(t, tk.HasPlan("SELECT * from t1 union SELECT * from t1", "IndexReader"))
+ tk.MustHavePlan("SELECT * from t1 union SELECT * from t1", "IndexReader")
tk.MustExec("drop global binding for SELECT * from t1 union SELECT * from t1")
}
@@ -433,8 +433,8 @@ func TestBindCTEMerge(t *testing.T) {
tk.MustExec("use test")
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t1(id int)")
- require.True(t, tk.HasPlan("with cte as (select * from t1) select * from cte a, cte b", "CTEFullScan"))
- require.False(t, tk.HasPlan("with cte as (select /*+ MERGE() */ * from t1) select * from cte a, cte b", "CTEFullScan"))
+ tk.MustHavePlan("with cte as (select * from t1) select * from cte a, cte b", "CTEFullScan")
+ tk.MustNotHavePlan("with cte as (select /*+ MERGE() */ * from t1) select * from cte a, cte b", "CTEFullScan")
tk.MustExec(`
create global binding for
with cte as (select * from t1) select * from cte
@@ -442,7 +442,7 @@ using
with cte as (select /*+ MERGE() */ * from t1) select * from cte
`)
- require.False(t, tk.HasPlan("with cte as (select * from t1) select * from cte", "CTEFullScan"))
+ tk.MustNotHavePlan("with cte as (select * from t1) select * from cte", "CTEFullScan")
}
func TestBindNoDecorrelate(t *testing.T) {
@@ -454,8 +454,8 @@ func TestBindNoDecorrelate(t *testing.T) {
tk.MustExec("drop table if exists t2")
tk.MustExec("create table t1(a int, b int)")
tk.MustExec("create table t2(a int, b int)")
- require.False(t, tk.HasPlan("select exists (select t2.b from t2 where t2.a = t1.b limit 2) from t1", "Apply"))
- require.True(t, tk.HasPlan("select exists (select /*+ no_decorrelate() */ t2.b from t2 where t2.a = t1.b limit 2) from t1", "Apply"))
+ tk.MustNotHavePlan("select exists (select t2.b from t2 where t2.a = t1.b limit 2) from t1", "Apply")
+ tk.MustHavePlan("select exists (select /*+ no_decorrelate() */ t2.b from t2 where t2.a = t1.b limit 2) from t1", "Apply")
tk.MustExec(`
create global binding for
@@ -464,7 +464,7 @@ using
select exists (select /*+ no_decorrelate() */ t2.b from t2 where t2.a = t1.b limit 2) from t1
`)
- require.True(t, tk.HasPlan("select exists (select t2.b from t2 where t2.a = t1.b limit 2) from t1", "Apply"))
+ tk.MustHavePlan("select exists (select t2.b from t2 where t2.a = t1.b limit 2) from t1", "Apply")
}
// TestBindingSymbolList tests sql with "?, ?, ?, ?", fixes #13871
@@ -566,9 +566,9 @@ func TestDMLSQLBind(t *testing.T) {
require.Equal(t, "t1:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0])
require.True(t, tk.MustUseIndex("delete from t1 where b = 1 and c > 1", "idx_c(c)"))
- require.True(t, tk.HasPlan("delete t1, t2 from t1 inner join t2 on t1.b = t2.b", "HashJoin"))
+ tk.MustHavePlan("delete t1, t2 from t1 inner join t2 on t1.b = t2.b", "HashJoin")
tk.MustExec("create global binding for delete t1, t2 from t1 inner join t2 on t1.b = t2.b using delete /*+ inl_join(t1) */ t1, t2 from t1 inner join t2 on t1.b = t2.b")
- require.True(t, tk.HasPlan("delete t1, t2 from t1 inner join t2 on t1.b = t2.b", "IndexJoin"))
+ tk.MustHavePlan("delete t1, t2 from t1 inner join t2 on t1.b = t2.b", "IndexJoin")
tk.MustExec("update t1 set a = 1 where b = 1 and c > 1")
require.Equal(t, "t1:idx_b", tk.Session().GetSessionVars().StmtCtx.IndexNames[0])
@@ -578,9 +578,9 @@ func TestDMLSQLBind(t *testing.T) {
require.Equal(t, "t1:idx_c", tk.Session().GetSessionVars().StmtCtx.IndexNames[0])
require.True(t, tk.MustUseIndex("update t1 set a = 1 where b = 1 and c > 1", "idx_c(c)"))
- require.True(t, tk.HasPlan("update t1, t2 set t1.a = 1 where t1.b = t2.b", "HashJoin"))
+ tk.MustHavePlan("update t1, t2 set t1.a = 1 where t1.b = t2.b", "HashJoin")
tk.MustExec("create global binding for update t1, t2 set t1.a = 1 where t1.b = t2.b using update /*+ inl_join(t1) */ t1, t2 set t1.a = 1 where t1.b = t2.b")
- require.True(t, tk.HasPlan("update t1, t2 set t1.a = 1 where t1.b = t2.b", "IndexJoin"))
+ tk.MustHavePlan("update t1, t2 set t1.a = 1 where t1.b = t2.b", "IndexJoin")
tk.MustExec("insert into t1 select * from t2 where t2.b = 2 and t2.c > 2")
require.Equal(t, "t2:idx_b", tk.Session().GetSessionVars().StmtCtx.IndexNames[0])
@@ -802,11 +802,11 @@ func TestRuntimeHintsInEvolveTasks(t *testing.T) {
tk.MustExec("create table t(a int, b int, c int, index idx_a(a), index idx_b(b), index idx_c(c))")
tk.MustExec("create global binding for select * from t where a >= 1 and b >= 1 and c = 0 using select * from t use index(idx_a) where a >= 1 and b >= 1 and c = 0")
- tk.MustQuery("select /*+ MAX_EXECUTION_TIME(5000), TIDB_KV_READ_TIMEOUT(20) */ * from t where a >= 4 and b >= 1 and c = 0")
+ tk.MustQuery("select /*+ MAX_EXECUTION_TIME(5000), SET_VAR(TIKV_CLIENT_READ_TIMEOUT=20) */ * from t where a >= 4 and b >= 1 and c = 0")
tk.MustExec("admin flush bindings")
rows := tk.MustQuery("show global bindings").Rows()
require.Len(t, rows, 2)
- require.Equal(t, "SELECT /*+ use_index(@`sel_1` `test`.`t` `idx_c`), no_order_index(@`sel_1` `test`.`t` `idx_c`), max_execution_time(5000), tidb_kv_read_timeout(20)*/ * FROM `test`.`t` WHERE `a` >= 4 AND `b` >= 1 AND `c` = 0", rows[0][1])
+ require.Equal(t, "SELECT /*+ use_index(@`sel_1` `test`.`t` `idx_c`), no_order_index(@`sel_1` `test`.`t` `idx_c`), max_execution_time(5000), set_var(tikv_client_read_timeout = 20)*/ * FROM `test`.`t` WHERE `a` >= 4 AND `b` >= 1 AND `c` = 0", rows[0][1])
}
func TestDefaultSessionVars(t *testing.T) {
@@ -866,15 +866,16 @@ func TestStmtHints(t *testing.T) {
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int, b int, index idx(a))")
- tk.MustExec("create global binding for select * from t using select /*+ MAX_EXECUTION_TIME(100), TIDB_KV_READ_TIMEOUT(20), MEMORY_QUOTA(2 GB) */ * from t use index(idx)")
+ tk.MustExec("create global binding for select * from t using select /*+ MAX_EXECUTION_TIME(100), SET_VAR(TIKV_CLIENT_READ_TIMEOUT=20), MEMORY_QUOTA(2 GB) */ * from t use index(idx)")
tk.MustQuery("select * from t")
require.Equal(t, int64(2147483648), tk.Session().GetSessionVars().MemTracker.GetBytesLimit())
require.Equal(t, uint64(100), tk.Session().GetSessionVars().StmtCtx.MaxExecutionTime)
- require.Equal(t, uint64(20), tk.Session().GetSessionVars().StmtCtx.TidbKvReadTimeout)
+ require.Equal(t, uint64(20), tk.Session().GetSessionVars().GetTiKVClientReadTimeout())
tk.MustQuery("select a, b from t")
require.Equal(t, int64(1073741824), tk.Session().GetSessionVars().MemTracker.GetBytesLimit())
require.Equal(t, uint64(0), tk.Session().GetSessionVars().StmtCtx.MaxExecutionTime)
- require.Equal(t, uint64(0), tk.Session().GetSessionVars().StmtCtx.TidbKvReadTimeout)
+ // TODO(crazycs520): Fix me.
+ //require.Equal(t, uint64(0), tk.Session().GetSessionVars().GetTiKVClientReadTimeout())
}
func TestPrivileges(t *testing.T) {
@@ -1153,14 +1154,14 @@ func TestSPMHitInfo(t *testing.T) {
tk.MustExec("create table t1(id int)")
tk.MustExec("create table t2(id int)")
- require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin"))
- require.True(t, tk.HasPlan("SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id", "MergeJoin"))
+ tk.MustHavePlan("SELECT * from t1,t2 where t1.id = t2.id", "HashJoin")
+ tk.MustHavePlan("SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id", "MergeJoin")
tk.MustExec("SELECT * from t1,t2 where t1.id = t2.id")
tk.MustQuery(`select @@last_plan_from_binding;`).Check(testkit.Rows("0"))
tk.MustExec("create global binding for SELECT * from t1,t2 where t1.id = t2.id using SELECT /*+ TIDB_SMJ(t1, t2) */ * from t1,t2 where t1.id = t2.id")
- require.True(t, tk.HasPlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin"))
+ tk.MustHavePlan("SELECT * from t1,t2 where t1.id = t2.id", "MergeJoin")
tk.MustExec("SELECT * from t1,t2 where t1.id = t2.id")
tk.MustQuery(`select @@last_plan_from_binding;`).Check(testkit.Rows("1"))
tk.MustExec("set binding disabled for SELECT * from t1,t2 where t1.id = t2.id")
@@ -1407,7 +1408,7 @@ func TestBindSQLDigest(t *testing.T) {
// runtime hints
{"select * from t", "select /*+ memory_quota(1024 MB) */ * from t"},
{"select * from t", "select /*+ max_execution_time(1000) */ * from t"},
- {"select * from t", "select /*+ tidb_kv_read_timeout(1000) */ * from t"},
+ {"select * from t", "select /*+ set_var(tikv_client_read_timeout=1000) */ * from t"},
// storage hints
{"select * from t", "select /*+ read_from_storage(tikv[t]) */ * from t"},
// others
@@ -1469,7 +1470,7 @@ func TestDropBindBySQLDigest(t *testing.T) {
// runtime hints
{"select * from t", "select /*+ memory_quota(1024 MB) */ * from t"},
{"select * from t", "select /*+ max_execution_time(1000) */ * from t"},
- {"select * from t", "select /*+ tidb_kv_read_timeout(1000) */ * from t"},
+ {"select * from t", "select /*+ set_var(tikv_client_read_timeout=1000) */ * from t"},
// storage hints
{"select * from t", "select /*+ read_from_storage(tikv[t]) */ * from t"},
// others
diff --git a/br/cmd/tidb-lightning-ctl/main.go b/br/cmd/tidb-lightning-ctl/main.go
index ea1a48b095298..fd36380ff1252 100644
--- a/br/cmd/tidb-lightning-ctl/main.go
+++ b/br/cmd/tidb-lightning-ctl/main.go
@@ -200,7 +200,7 @@ func checkpointErrorDestroy(ctx context.Context, cfg *config.Config, tls *common
for _, table := range targetTables {
for engineID := table.MinEngineID; engineID <= table.MaxEngineID; engineID++ {
fmt.Fprintln(os.Stderr, "Closing and cleaning up engine:", table.TableName, engineID)
- _, eID := backend.MakeUUID(table.TableName, engineID)
+ _, eID := backend.MakeUUID(table.TableName, int64(engineID))
engine := local.Engine{UUID: eID}
err := engine.Cleanup(cfg.TikvImporter.SortedKVDir)
if err != nil {
diff --git a/br/pkg/lightning/backend/backend.go b/br/pkg/lightning/backend/backend.go
index be55a88fc4a96..d50731dac713a 100644
--- a/br/pkg/lightning/backend/backend.go
+++ b/br/pkg/lightning/backend/backend.go
@@ -36,7 +36,7 @@ const (
importMaxRetryTimes = 3 // tikv-importer has done retry internally. so we don't retry many times.
)
-func makeTag(tableName string, engineID int32) string {
+func makeTag(tableName string, engineID int64) string {
return fmt.Sprintf("%s:%d", tableName, engineID)
}
@@ -48,7 +48,7 @@ func makeLogger(logger log.Logger, tag string, engineUUID uuid.UUID) log.Logger
}
// MakeUUID generates a UUID for the engine and a tag for the engine.
-func MakeUUID(tableName string, engineID int32) (string, uuid.UUID) {
+func MakeUUID(tableName string, engineID int64) (string, uuid.UUID) {
tag := makeTag(tableName, engineID)
engineUUID := uuid.NewSHA1(engineNamespace, []byte(tag))
return tag, engineUUID
@@ -229,7 +229,7 @@ func MakeEngineManager(ab Backend) EngineManager {
// OpenEngine opens an engine with the given table name and engine ID.
func (be EngineManager) OpenEngine(ctx context.Context, config *EngineConfig,
tableName string, engineID int32) (*OpenedEngine, error) {
- tag, engineUUID := MakeUUID(tableName, engineID)
+ tag, engineUUID := MakeUUID(tableName, int64(engineID))
logger := makeLogger(log.FromContext(ctx), tag, engineUUID)
if err := be.backend.OpenEngine(ctx, config, engineUUID); err != nil {
@@ -298,7 +298,7 @@ func (engine *OpenedEngine) LocalWriter(ctx context.Context, cfg *LocalWriterCon
// resuming from a checkpoint.
func (be EngineManager) UnsafeCloseEngine(ctx context.Context, cfg *EngineConfig,
tableName string, engineID int32) (*ClosedEngine, error) {
- tag, engineUUID := MakeUUID(tableName, engineID)
+ tag, engineUUID := MakeUUID(tableName, int64(engineID))
return be.UnsafeCloseEngineWithUUID(ctx, cfg, tag, engineUUID, engineID)
}
diff --git a/br/pkg/lightning/backend/external/BUILD.bazel b/br/pkg/lightning/backend/external/BUILD.bazel
index bf1bdd5a8f6d0..e33c4de6ce698 100644
--- a/br/pkg/lightning/backend/external/BUILD.bazel
+++ b/br/pkg/lightning/backend/external/BUILD.bazel
@@ -10,6 +10,7 @@ go_library(
"file.go",
"iter.go",
"kv_reader.go",
+ "merge.go",
"split.go",
"stat_reader.go",
"util.go",
@@ -18,6 +19,9 @@ go_library(
importpath = "github.com/pingcap/tidb/br/pkg/lightning/backend/external",
visibility = ["//visibility:public"],
deps = [
+ "//br/pkg/lightning/backend",
+ "//br/pkg/lightning/backend/encode",
+ "//br/pkg/lightning/backend/kv",
"//br/pkg/lightning/common",
"//br/pkg/lightning/log",
"//br/pkg/membuf",
@@ -51,12 +55,14 @@ go_test(
],
embed = [":external"],
flaky = True,
- shard_count = 32,
+ shard_count = 36,
deps = [
+ "//br/pkg/lightning/backend/kv",
"//br/pkg/lightning/common",
"//br/pkg/storage",
"//kv",
"//util/codec",
+ "//util/size",
"@com_github_aws_aws_sdk_go//aws",
"@com_github_aws_aws_sdk_go//aws/credentials",
"@com_github_aws_aws_sdk_go//aws/session",
diff --git a/br/pkg/lightning/backend/external/byte_reader.go b/br/pkg/lightning/backend/external/byte_reader.go
index 526fe240df8dc..cc22e5a0639a6 100644
--- a/br/pkg/lightning/backend/external/byte_reader.go
+++ b/br/pkg/lightning/backend/external/byte_reader.go
@@ -39,7 +39,7 @@ type byteReader struct {
retPointers []*[]byte
- useConcurrentReaderCurrent atomic.Bool
+ useConcurrentReaderCurrent bool
useConcurrentReader atomic.Bool
currFileOffset int64
@@ -63,14 +63,25 @@ func openStoreReaderAndSeek(
return storageReader, nil
}
-// newByteReader wraps readNBytes functionality to storageReader. It will not
-// close storageReader when meet error.
-func newByteReader(ctx context.Context, storageReader storage.ExternalFileReader, bufSize int, st storage.ExternalStorage, name string, defaultUseConcurrency bool) (*byteReader, error) {
+// newByteReader wraps readNBytes functionality around storageReader.
+func newByteReader(
+ ctx context.Context,
+ storageReader storage.ExternalFileReader,
+ bufSize int,
+ st storage.ExternalStorage,
+ name string,
+ defaultUseConcurrency bool,
+) (r *byteReader, err error) {
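+ // if any later step fails after r has been constructed, release its resources
+ // before returning the error.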
+ defer func() {
+ if err != nil && r != nil {
+ _ = r.Close()
+ }
+ }()
conReader, err := newSingeFileReader(ctx, st, name, 8, ConcurrentReaderBufferSize)
if err != nil {
return nil, err
}
- r := &byteReader{
+ r = &byteReader{
ctx: ctx,
storageReader: storageReader,
buf: make([]byte, bufSize),
@@ -98,13 +109,13 @@ func (r *byteReader) switchToConcurrentReaderImpl() error {
r.conReader.currentFileOffset = currOffset
r.conReader.bufferReadOffset = 0
- r.useConcurrentReaderCurrent.Store(true)
r.conReader.buffer = make([]byte, r.conReader.concurrency*r.conReader.readBufferSize)
+ r.useConcurrentReaderCurrent = true
return nil
}
func (r *byteReader) switchToNormalReaderImpl() error {
- r.useConcurrentReaderCurrent.Store(false)
+ r.useConcurrentReaderCurrent = false
r.currFileOffset = r.conReader.currentFileOffset
r.conReader.buffer = nil
_, err := r.storageReader.Seek(r.currFileOffset, io.SeekStart)
@@ -165,7 +176,7 @@ func (r *byteReader) cloneSlices() {
}
func (r *byteReader) next(n int) []byte {
- if r.useConcurrentReaderCurrent.Load() {
+ if r.useConcurrentReaderCurrent {
return r.conReader.next(n)
}
end := mathutil.Min(r.bufOffset+n, len(r.buf))
@@ -176,8 +187,9 @@ func (r *byteReader) next(n int) []byte {
func (r *byteReader) reload() error {
to := r.useConcurrentReader.Load()
- now := r.useConcurrentReaderCurrent.Load()
+ now := r.useConcurrentReaderCurrent
if to != now {
+ logutil.Logger(r.ctx).Info("switch reader mode", zap.Bool("use concurrent mode", to))
if to {
err := r.switchToConcurrentReaderImpl()
if err != nil {
diff --git a/br/pkg/lightning/backend/external/iter.go b/br/pkg/lightning/backend/external/iter.go
index 65d95c1b9a981..2c465403d35d3 100644
--- a/br/pkg/lightning/backend/external/iter.go
+++ b/br/pkg/lightning/backend/external/iter.go
@@ -33,6 +33,7 @@ type heapElem interface {
type sortedReader[T heapElem] interface {
path() string
next() (T, error)
+ setReadMode(useConcurrency bool)
close() error
}
@@ -68,11 +69,13 @@ func (h *mergeHeap[T]) Pop() interface{} {
}
type mergeIter[T heapElem, R sortedReader[T]] struct {
- h mergeHeap[T]
- readers []*R
- curr T
- lastReaderIdx int
- err error
+ h mergeHeap[T]
+ readers []*R
+ curr T
+ lastReaderIdx int
+ err error
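+ // hotspotMap records, for each reader index, how many recently popped elements
+ // came from that reader; checkHotspotCnt counts pops since the last check.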
+ hotspotMap map[int]int
+ checkHotspotCnt int
logger *zap.Logger
}
@@ -131,6 +134,7 @@ func newMergeIter[
h: make(mergeHeap[T], 0, len(readers)),
readers: readers,
lastReaderIdx: -1,
+ hotspotMap: make(map[int]int),
logger: logger,
}
for j := range i.readers {
@@ -198,6 +202,19 @@ func (i *mergeIter[T, R]) next() bool {
var zeroT T
i.curr = zeroT
if i.lastReaderIdx >= 0 {
+ i.hotspotMap[i.lastReaderIdx] = i.hotspotMap[i.lastReaderIdx] + 1
+ i.checkHotspotCnt++
+
+ checkPeriod := 1000
+ // check for hotspot readers every checkPeriod iterations
+ if i.checkHotspotCnt == checkPeriod {
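+ // a reader that supplied more than half of the last checkPeriod elements is
+ // treated as a hotspot and switched to concurrent read mode.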
+ for idx, cnt := range i.hotspotMap {
+ (*i.readers[idx]).setReadMode(cnt > (checkPeriod / 2))
+ }
+ i.checkHotspotCnt = 0
+ i.hotspotMap = make(map[int]int)
+ }
+
rd := *i.readers[i.lastReaderIdx]
e, err := rd.next()
switch err {
@@ -211,6 +228,7 @@ func (i *mergeIter[T, R]) next() bool {
zap.Error(closeErr))
}
i.readers[i.lastReaderIdx] = nil
+ delete(i.hotspotMap, i.lastReaderIdx)
default:
i.err = err
return false
@@ -255,6 +273,10 @@ func (p kvReaderProxy) next() (kvPair, error) {
return kvPair{key: k, value: v}, nil
}
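+// setReadMode forwards the hotspot hint to the underlying byteReader, letting it
+// switch between the normal reader and the concurrent reader.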
+func (p kvReaderProxy) setReadMode(useConcurrency bool) {
+ p.r.byteReader.switchReaderMode(useConcurrency)
+}
+
func (p kvReaderProxy) close() error {
return p.r.Close()
}
@@ -333,6 +355,10 @@ func (p statReaderProxy) next() (*rangeProperty, error) {
return p.r.nextProp()
}
+func (p statReaderProxy) setReadMode(useConcurrency bool) {
+ p.r.byteReader.switchReaderMode(useConcurrency)
+}
+
func (p statReaderProxy) close() error {
return p.r.Close()
}
diff --git a/br/pkg/lightning/backend/external/iter_test.go b/br/pkg/lightning/backend/external/iter_test.go
index e855581ee2d19..18dbf512b4c38 100644
--- a/br/pkg/lightning/backend/external/iter_test.go
+++ b/br/pkg/lightning/backend/external/iter_test.go
@@ -16,13 +16,17 @@ package external
import (
"context"
+ "encoding/binary"
"fmt"
"io"
"testing"
+ "time"
+ "github.com/pingcap/tidb/br/pkg/lightning/common"
"github.com/pingcap/tidb/br/pkg/storage"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
+ "golang.org/x/exp/rand"
)
type trackOpenMemStorage struct {
@@ -314,6 +318,10 @@ func (p kvReaderPointerProxy) next() (*kvPair, error) {
return &kvPair{key: k, value: v}, nil
}
+func (p kvReaderPointerProxy) setReadMode(useConcurrency bool) {
+ p.r.byteReader.switchReaderMode(useConcurrency)
+}
+
func (p kvReaderPointerProxy) close() error {
return p.r.Close()
}
@@ -337,3 +345,73 @@ func BenchmarkPointerT(b *testing.B) {
}
}
}
+
+func TestMergeIterSwitchMode(t *testing.T) {
+ seed := time.Now().Unix()
+ rand.Seed(uint64(seed))
+ t.Logf("seed: %d", seed)
+
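+ // exercise three key patterns: fully random keys, keys prefixed with an
+ // increasing big-endian counter, and keys alternating between ordered and
+ // random blocks of 100000.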
+ testMergeIterSwitchMode(t, func(key []byte, i int) []byte {
+ _, err := rand.Read(key)
+ require.NoError(t, err)
+ return key
+ })
+ testMergeIterSwitchMode(t, func(key []byte, i int) []byte {
+ _, err := rand.Read(key)
+ require.NoError(t, err)
+ binary.BigEndian.PutUint64(key, uint64(i))
+ return key
+ })
+ testMergeIterSwitchMode(t, func(key []byte, i int) []byte {
+ _, err := rand.Read(key)
+ require.NoError(t, err)
+ if (i/100000)%2 == 0 {
+ binary.BigEndian.PutUint64(key, uint64(i)<<40)
+ }
+ return key
+ })
+}
+
+func testMergeIterSwitchMode(t *testing.T, f func([]byte, int) []byte) {
+ st, clean := NewS3WithBucketAndPrefix(t, "test", "prefix/")
+ defer clean()
+
+ // Prepare
+ writer := NewWriterBuilder().
+ SetPropKeysDistance(100).
+ SetMemorySizeLimit(512*1024).
+ Build(st, "testprefix", "0")
+
+ ConcurrentReaderBufferSize = 4 * 1024
+
+ kvCount := 500000
+ keySize := 100
+ valueSize := 10
+ kvs := make([]common.KvPair, 1)
+ kvs[0] = common.KvPair{
+ Key: make([]byte, keySize),
+ Val: make([]byte, valueSize),
+ }
+ for i := 0; i < kvCount; i++ {
+ kvs[0].Key = f(kvs[0].Key, i)
+ _, err := rand.Read(kvs[0].Val[0:])
+ require.NoError(t, err)
+ err = writer.WriteRow(context.Background(), kvs[0].Key, kvs[0].Val, nil)
+ require.NoError(t, err)
+ }
+ err := writer.Close(context.Background())
+ require.NoError(t, err)
+
+ dataNames, _, err := GetAllFileNames(context.Background(), st, "")
+ require.NoError(t, err)
+
+ offsets := make([]uint64, len(dataNames))
+
+ iter, err := NewMergeKVIter(context.Background(), dataNames, offsets, st, 2048)
+ require.NoError(t, err)
+
+ for iter.Next() {
+ }
+ err = iter.Close()
+ require.NoError(t, err)
+}
diff --git a/br/pkg/lightning/backend/external/kv_reader.go b/br/pkg/lightning/backend/external/kv_reader.go
index 53a284c55c41c..fe2e0915d3c85 100644
--- a/br/pkg/lightning/backend/external/kv_reader.go
+++ b/br/pkg/lightning/backend/external/kv_reader.go
@@ -42,7 +42,6 @@ func newKVReader(
}
br, err := newByteReader(ctx, sr, bufSize, store, name, false)
if err != nil {
- br.Close()
return nil, err
}
return &kvReader{
diff --git a/br/pkg/lightning/backend/external/merge.go b/br/pkg/lightning/backend/external/merge.go
new file mode 100644
index 0000000000000..2e74159aac5b1
--- /dev/null
+++ b/br/pkg/lightning/backend/external/merge.go
@@ -0,0 +1,53 @@
+package external
+
+import (
+ "context"
+
+ "github.com/pingcap/tidb/br/pkg/storage"
+)
+
+// MergeOverlappingFiles reads from given files whose key range may overlap
+// and writes to new sorted, nonoverlapping files.
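+// The merged output is produced by a new Writer configured with memSizeLimit,
+// writeBatchCount, propSizeDist and propKeysDist, and is written under
+// newFilePrefix/writerID; onClose is invoked with the writer summary on close.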
+func MergeOverlappingFiles(
+ ctx context.Context,
+ paths []string,
+ store storage.ExternalStorage,
+ readBufferSize int,
+ newFilePrefix string,
+ writerID string,
+ memSizeLimit uint64,
+ writeBatchCount uint64,
+ propSizeDist uint64,
+ propKeysDist uint64,
+ onClose OnCloseFunc,
+) error {
+ zeroOffsets := make([]uint64, len(paths))
+ iter, err := NewMergeKVIter(ctx, paths, zeroOffsets, store, readBufferSize)
+ if err != nil {
+ return err
+ }
+ defer iter.Close()
+
+ writer := NewWriterBuilder().
+ SetMemorySizeLimit(memSizeLimit).
+ SetWriterBatchCount(writeBatchCount).
+ SetPropKeysDistance(propKeysDist).
+ SetPropSizeDistance(propSizeDist).
+ SetOnCloseFunc(onClose).
+ Build(store, newFilePrefix, writerID)
+
+ // currently we use the same goroutine to do both read and write. The main
+ // advantage is that there is no KV copy and the iter can reuse its buffer.
+
+ for iter.Next() {
+ err = writer.WriteRow(ctx, iter.Key(), iter.Value(), nil)
+ if err != nil {
+ return err
+ }
+ }
+ err = iter.Error()
+ if err != nil {
+ return err
+ }
+ return writer.Close(ctx)
+}
diff --git a/br/pkg/lightning/backend/external/util.go b/br/pkg/lightning/backend/external/util.go
index a3d4607d0162a..3a193c48d4d3c 100644
--- a/br/pkg/lightning/backend/external/util.go
+++ b/br/pkg/lightning/backend/external/util.go
@@ -211,33 +211,105 @@ const (
type Endpoint struct {
Key []byte
Tp EndpointTp
- Weight uint64 // all EndpointTp use positive weight
+ Weight int64 // all EndpointTp use positive weight
}
// GetMaxOverlapping returns the maximum overlapping weight treating given
// `points` as endpoints of intervals. `points` are not required to be sorted,
// and will be sorted in-place in this function.
-func GetMaxOverlapping(points []Endpoint) int {
+func GetMaxOverlapping(points []Endpoint) int64 {
slices.SortFunc(points, func(i, j Endpoint) int {
if cmp := bytes.Compare(i.Key, j.Key); cmp != 0 {
return cmp
}
return int(i.Tp) - int(j.Tp)
})
- var maxWeight uint64
- var curWeight uint64
+ var maxWeight int64
+ var curWeight int64
for _, p := range points {
switch p.Tp {
case InclusiveStart:
curWeight += p.Weight
- case ExclusiveEnd:
- curWeight -= p.Weight
- case InclusiveEnd:
+ case ExclusiveEnd, InclusiveEnd:
curWeight -= p.Weight
}
if curWeight > maxWeight {
maxWeight = curWeight
}
}
- return int(maxWeight)
+ return maxWeight
+}
+
+// SortedKVMeta is the meta of sorted kv.
+type SortedKVMeta struct {
+ MinKey []byte `json:"min-key"`
+ MaxKey []byte `json:"max-key"`
+ TotalKVSize uint64 `json:"total-kv-size"`
+ // these two fields seem to always be generated from MultipleFilesStats;
+ // maybe remove them later.
+ DataFiles []string `json:"data-files"`
+ StatFiles []string `json:"stat-files"`
+ MultipleFilesStats []MultipleFilesStat `json:"multiple-files-stats"`
+}
+
+// NewSortedKVMeta creates a SortedKVMeta from a WriterSummary.
+func NewSortedKVMeta(summary *WriterSummary) *SortedKVMeta {
+ meta := &SortedKVMeta{
+ MinKey: summary.Min.Clone(),
+ MaxKey: summary.Max.Clone(),
+ TotalKVSize: summary.TotalSize,
+ MultipleFilesStats: summary.MultipleFilesStats,
+ }
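+ // each Filenames entry is a [dataFile, statFile] pair; flatten the pairs into
+ // the DataFiles and StatFiles lists.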
+ for _, f := range summary.MultipleFilesStats {
+ for _, filename := range f.Filenames {
+ meta.DataFiles = append(meta.DataFiles, filename[0])
+ meta.StatFiles = append(meta.StatFiles, filename[1])
+ }
+ }
+ return meta
+}
+
+// Merge merges the other SortedKVMeta into this one.
+func (m *SortedKVMeta) Merge(other *SortedKVMeta) {
+ m.MinKey = NotNilMin(m.MinKey, other.MinKey)
+ m.MaxKey = NotNilMax(m.MaxKey, other.MaxKey)
+ m.TotalKVSize += other.TotalKVSize
+
+ m.DataFiles = append(m.DataFiles, other.DataFiles...)
+ m.StatFiles = append(m.StatFiles, other.StatFiles...)
+
+ m.MultipleFilesStats = append(m.MultipleFilesStats, other.MultipleFilesStats...)
+}
+
+// MergeSummary merges the WriterSummary into this SortedKVMeta.
+func (m *SortedKVMeta) MergeSummary(summary *WriterSummary) {
+ m.Merge(NewSortedKVMeta(summary))
+}
+
+// NotNilMin returns the smallest of a and b, ignoring nil values.
+func NotNilMin(a, b []byte) []byte {
+ if len(a) == 0 {
+ return b
+ }
+ if len(b) == 0 {
+ return a
+ }
+ if bytes.Compare(a, b) < 0 {
+ return a
+ }
+ return b
+}
+
+// NotNilMax returns the largest of a and b, ignoring nil values.
+func NotNilMax(a, b []byte) []byte {
+ if len(a) == 0 {
+ return b
+ }
+ if len(b) == 0 {
+ return a
+ }
+ if bytes.Compare(a, b) > 0 {
+ return a
+ }
+ return b
}
diff --git a/br/pkg/lightning/backend/external/util_test.go b/br/pkg/lightning/backend/external/util_test.go
index e8090fc0e32f8..a031c4b373457 100644
--- a/br/pkg/lightning/backend/external/util_test.go
+++ b/br/pkg/lightning/backend/external/util_test.go
@@ -225,7 +225,7 @@ func TestGetMaxOverlapping(t *testing.T) {
{Key: []byte{2}, Tp: InclusiveStart, Weight: 1},
{Key: []byte{4}, Tp: ExclusiveEnd, Weight: 1},
}
- require.Equal(t, 2, GetMaxOverlapping(points))
+ require.EqualValues(t, 2, GetMaxOverlapping(points))
// [1, 3), [2, 4), [3, 5)
points = []Endpoint{
{Key: []byte{1}, Tp: InclusiveStart, Weight: 1},
@@ -235,7 +235,7 @@ func TestGetMaxOverlapping(t *testing.T) {
{Key: []byte{3}, Tp: InclusiveStart, Weight: 1},
{Key: []byte{5}, Tp: ExclusiveEnd, Weight: 1},
}
- require.Equal(t, 2, GetMaxOverlapping(points))
+ require.EqualValues(t, 2, GetMaxOverlapping(points))
// [1, 3], [2, 4], [3, 5]
points = []Endpoint{
{Key: []byte{1}, Tp: InclusiveStart, Weight: 1},
@@ -245,5 +245,84 @@ func TestGetMaxOverlapping(t *testing.T) {
{Key: []byte{3}, Tp: InclusiveStart, Weight: 1},
{Key: []byte{5}, Tp: InclusiveEnd, Weight: 1},
}
- require.Equal(t, 3, GetMaxOverlapping(points))
+ require.EqualValues(t, 3, GetMaxOverlapping(points))
+}
+
+func TestSortedKVMeta(t *testing.T) {
+ summary := []*WriterSummary{
+ {
+ Min: []byte("a"),
+ Max: []byte("b"),
+ TotalSize: 123,
+ MultipleFilesStats: []MultipleFilesStat{
+ {
+ Filenames: [][2]string{
+ {"f1", "stat1"},
+ {"f2", "stat2"},
+ },
+ },
+ },
+ },
+ {
+ Min: []byte("x"),
+ Max: []byte("y"),
+ TotalSize: 177,
+ MultipleFilesStats: []MultipleFilesStat{
+ {
+ Filenames: [][2]string{
+ {"f3", "stat3"},
+ {"f4", "stat4"},
+ },
+ },
+ },
+ },
+ }
+ meta0 := NewSortedKVMeta(summary[0])
+ require.Equal(t, []byte("a"), meta0.MinKey)
+ require.Equal(t, []byte("b"), meta0.MaxKey)
+ require.Equal(t, uint64(123), meta0.TotalKVSize)
+ require.Equal(t, []string{"f1", "f2"}, meta0.DataFiles)
+ require.Equal(t, []string{"stat1", "stat2"}, meta0.StatFiles)
+ require.Equal(t, summary[0].MultipleFilesStats, meta0.MultipleFilesStats)
+ meta1 := NewSortedKVMeta(summary[1])
+ require.Equal(t, []byte("x"), meta1.MinKey)
+ require.Equal(t, []byte("y"), meta1.MaxKey)
+ require.Equal(t, uint64(177), meta1.TotalKVSize)
+ require.Equal(t, []string{"f3", "f4"}, meta1.DataFiles)
+ require.Equal(t, []string{"stat3", "stat4"}, meta1.StatFiles)
+ require.Equal(t, summary[1].MultipleFilesStats, meta1.MultipleFilesStats)
+
+ meta0.MergeSummary(summary[1])
+ require.Equal(t, []byte("a"), meta0.MinKey)
+ require.Equal(t, []byte("y"), meta0.MaxKey)
+ require.Equal(t, uint64(300), meta0.TotalKVSize)
+ require.Equal(t, []string{"f1", "f2", "f3", "f4"}, meta0.DataFiles)
+ require.Equal(t, []string{"stat1", "stat2", "stat3", "stat4"}, meta0.StatFiles)
+ mergedStats := append([]MultipleFilesStat{}, summary[0].MultipleFilesStats...)
+ mergedStats = append(mergedStats, summary[1].MultipleFilesStats...)
+ require.Equal(t, mergedStats, meta0.MultipleFilesStats)
+
+ meta00 := NewSortedKVMeta(summary[0])
+ meta00.Merge(meta1)
+ require.Equal(t, meta0, meta00)
+}
+
+func TestKeyMinMax(t *testing.T) {
+ require.Equal(t, []byte(nil), NotNilMin(nil, nil))
+ require.Equal(t, []byte{}, NotNilMin(nil, []byte{}))
+ require.Equal(t, []byte(nil), NotNilMin([]byte{}, nil))
+ require.Equal(t, []byte("a"), NotNilMin([]byte("a"), nil))
+ require.Equal(t, []byte("a"), NotNilMin([]byte("a"), []byte{}))
+ require.Equal(t, []byte("a"), NotNilMin(nil, []byte("a")))
+ require.Equal(t, []byte("a"), NotNilMin([]byte("a"), []byte("b")))
+ require.Equal(t, []byte("a"), NotNilMin([]byte("b"), []byte("a")))
+
+ require.Equal(t, []byte(nil), NotNilMax(nil, nil))
+ require.Equal(t, []byte{}, NotNilMax(nil, []byte{}))
+ require.Equal(t, []byte(nil), NotNilMax([]byte{}, nil))
+ require.Equal(t, []byte("a"), NotNilMax([]byte("a"), nil))
+ require.Equal(t, []byte("a"), NotNilMax([]byte("a"), []byte{}))
+ require.Equal(t, []byte("a"), NotNilMax(nil, []byte("a")))
+ require.Equal(t, []byte("b"), NotNilMax([]byte("a"), []byte("b")))
+ require.Equal(t, []byte("b"), NotNilMax([]byte("b"), []byte("a")))
}
diff --git a/br/pkg/lightning/backend/external/writer.go b/br/pkg/lightning/backend/external/writer.go
index 7ebcc8d6cf71a..a28f1aa66df7b 100644
--- a/br/pkg/lightning/backend/external/writer.go
+++ b/br/pkg/lightning/backend/external/writer.go
@@ -24,6 +24,9 @@ import (
"time"
"github.com/pingcap/errors"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/encode"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/kv"
"github.com/pingcap/tidb/br/pkg/lightning/common"
"github.com/pingcap/tidb/br/pkg/membuf"
"github.com/pingcap/tidb/br/pkg/storage"
@@ -33,7 +36,20 @@ import (
"go.uber.org/zap"
)
-var multiFileStatNum = 500
+var (
+ multiFileStatNum = 500
+
+ // MergeSortOverlapThreshold is the threshold of overlap between sorted kv files.
+ // If the overlap ratio is greater than this threshold, we will merge the files.
+ MergeSortOverlapThreshold int64 = 1000
+ // MergeSortFileCountStep is the step of file count when we split the sorted kv files.
+ MergeSortFileCountStep = 1000
+)
+
+const (
+ // DefaultMemSizeLimit is the default memory size limit for writer.
+ DefaultMemSizeLimit = 256 * size.MB
+)
// rangePropertiesCollector collects range properties for each range. The zero
// value of rangePropertiesCollector is not ready to use, should call reset()
@@ -58,8 +74,11 @@ func (rc *rangePropertiesCollector) encode() []byte {
// WriterSummary is the summary of a writer.
type WriterSummary struct {
- WriterID string
- Seq int
+ WriterID string
+ Seq int
+ // Min and Max are the min and max keys written by this writer; both are
+ // inclusive, i.e. [Min, Max].
+ // They will be empty if no key is written.
Min tidbkv.Key
Max tidbkv.Key
TotalSize uint64
@@ -69,17 +88,17 @@ type WriterSummary struct {
// OnCloseFunc is the callback function when a writer is closed.
type OnCloseFunc func(summary *WriterSummary)
-// DummyOnCloseFunc is a dummy OnCloseFunc.
-func DummyOnCloseFunc(*WriterSummary) {}
+// dummyOnCloseFunc is a dummy OnCloseFunc.
+func dummyOnCloseFunc(*WriterSummary) {}
// WriterBuilder builds a new Writer.
type WriterBuilder struct {
- memSizeLimit uint64
- writeBatchCount uint64
- propSizeDist uint64
- propKeysDist uint64
- onClose OnCloseFunc
- dupeDetectEnabled bool
+ memSizeLimit uint64
+ writeBatchCount uint64
+ propSizeDist uint64
+ propKeysDist uint64
+ onClose OnCloseFunc
+ keyDupeEncoding bool
bufferPool *membuf.Pool
}
@@ -87,11 +106,11 @@ type WriterBuilder struct {
// NewWriterBuilder creates a WriterBuilder.
func NewWriterBuilder() *WriterBuilder {
return &WriterBuilder{
- memSizeLimit: 256 * size.MB,
+ memSizeLimit: DefaultMemSizeLimit,
writeBatchCount: 8 * 1024,
propSizeDist: 1 * size.MB,
propKeysDist: 8 * 1024,
- onClose: DummyOnCloseFunc,
+ onClose: dummyOnCloseFunc,
}
}
@@ -123,6 +142,9 @@ func (b *WriterBuilder) SetPropKeysDistance(dist uint64) *WriterBuilder {
// SetOnCloseFunc sets the callback function when a writer is closed.
func (b *WriterBuilder) SetOnCloseFunc(onClose OnCloseFunc) *WriterBuilder {
+ if onClose == nil {
+ onClose = dummyOnCloseFunc
+ }
b.onClose = onClose
return b
}
@@ -133,9 +155,9 @@ func (b *WriterBuilder) SetBufferPool(bufferPool *membuf.Pool) *WriterBuilder {
return b
}
-// EnableDuplicationDetection enables the duplication detection of the writer.
-func (b *WriterBuilder) EnableDuplicationDetection() *WriterBuilder {
- b.dupeDetectEnabled = true
+// SetKeyDuplicationEncoding sets whether the writer can distinguish duplicate keys.
+func (b *WriterBuilder) SetKeyDuplicationEncoding(val bool) *WriterBuilder {
+ b.keyDupeEncoding = val
return b
}
@@ -152,7 +174,7 @@ func (b *WriterBuilder) Build(
}
filenamePrefix := filepath.Join(prefix, writerID)
keyAdapter := common.KeyAdapter(common.NoopKeyAdapter{})
- if b.dupeDetectEnabled {
+ if b.keyDupeEncoding {
keyAdapter = common.DupDetectKeyAdapter{}
}
ret := &Writer{
@@ -185,10 +207,10 @@ func (b *WriterBuilder) Build(
// every 500 files). It is used to estimate the data overlapping, and per-file
// statistic information may be too big to be loaded into memory.
type MultipleFilesStat struct {
- MinKey tidbkv.Key
- MaxKey tidbkv.Key
- Filenames [][2]string // [dataFile, statFile]
- MaxOverlappingNum int
+ MinKey tidbkv.Key `json:"min-key"`
+ MaxKey tidbkv.Key `json:"max-key"`
+ Filenames [][2]string `json:"filenames"` // [dataFile, statFile]
+ MaxOverlappingNum int64 `json:"max-overlapping-num"`
}
func (m *MultipleFilesStat) build(startKeys, endKeys []tidbkv.Key) {
@@ -216,6 +238,20 @@ func (m *MultipleFilesStat) build(startKeys, endKeys []tidbkv.Key) {
m.MaxOverlappingNum = GetMaxOverlapping(points)
}
+// GetMaxOverlappingTotal assumes the most overlapping case from the given stats
+// and returns the overlapping level.
+func GetMaxOverlappingTotal(stats []MultipleFilesStat) int64 {
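+ // treat each stat as an interval [MinKey, MaxKey] weighted by its own
+ // MaxOverlappingNum, and take the maximum total weight of overlapping intervals.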
+ points := make([]Endpoint, 0, len(stats)*2)
+ for _, stat := range stats {
+ points = append(points, Endpoint{Key: stat.MinKey, Tp: InclusiveStart, Weight: stat.MaxOverlappingNum})
+ }
+ for _, stat := range stats {
+ points = append(points, Endpoint{Key: stat.MaxKey, Tp: InclusiveEnd, Weight: stat.MaxOverlappingNum})
+ }
+
+ return GetMaxOverlapping(points)
+}
+
// Writer is used to write data into external storage.
type Writer struct {
store storage.ExternalStorage
@@ -420,7 +456,44 @@ func (w *Writer) createStorageWriter(ctx context.Context) (
statPath := filepath.Join(w.filenamePrefix+statSuffix, strconv.Itoa(w.currentSeq))
statsWriter, err := w.store.Create(ctx, statPath, &storage.WriterOption{Concurrency: 20})
if err != nil {
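+ // the data writer was already created; close it so it is not leaked when
+ // creating the stats writer fails.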
+ _ = dataWriter.Close(ctx)
return "", "", nil, nil, err
}
return dataPath, statPath, dataWriter, statsWriter, nil
}
+
+// EngineWriter implements backend.EngineWriter interface.
+type EngineWriter struct {
+ w *Writer
+}
+
+// NewEngineWriter creates a new EngineWriter.
+func NewEngineWriter(w *Writer) *EngineWriter {
+ return &EngineWriter{w: w}
+}
+
+// AppendRows implements backend.EngineWriter interface.
+func (e *EngineWriter) AppendRows(ctx context.Context, _ []string, rows encode.Rows) error {
+ kvs := kv.Rows2KvPairs(rows)
+ if len(kvs) == 0 {
+ return nil
+ }
+ for _, item := range kvs {
+ err := e.w.WriteRow(ctx, item.Key, item.Val, nil)
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// IsSynced implements backend.EngineWriter interface.
+func (e *EngineWriter) IsSynced() bool {
+ // only used when saving checkpoint
+ return true
+}
+
+// Close implements backend.EngineWriter interface.
+func (e *EngineWriter) Close(ctx context.Context) (backend.ChunkFlushStatus, error) {
+ return nil, e.w.Close(ctx)
+}
diff --git a/br/pkg/lightning/backend/external/writer_test.go b/br/pkg/lightning/backend/external/writer_test.go
index 2608324a4f809..c32876e46198c 100644
--- a/br/pkg/lightning/backend/external/writer_test.go
+++ b/br/pkg/lightning/backend/external/writer_test.go
@@ -26,9 +26,11 @@ import (
"time"
"github.com/cockroachdb/pebble"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/kv"
"github.com/pingcap/tidb/br/pkg/lightning/common"
"github.com/pingcap/tidb/br/pkg/storage"
dbkv "github.com/pingcap/tidb/kv"
+ "github.com/pingcap/tidb/util/size"
"github.com/stretchr/testify/require"
"golang.org/x/exp/rand"
)
@@ -40,11 +42,13 @@ func TestWriter(t *testing.T) {
ctx := context.Background()
memStore := storage.NewMemStorage()
- writer := NewWriterBuilder().
+ w := NewWriterBuilder().
SetPropSizeDistance(100).
SetPropKeysDistance(2).
Build(memStore, "/test", "0")
+ writer := NewEngineWriter(w)
+
kvCnt := rand.Intn(10) + 10
kvs := make([]common.KvPair, kvCnt)
for i := 0; i < kvCnt; i++ {
@@ -57,12 +61,9 @@ func TestWriter(t *testing.T) {
_, err = rand.Read(kvs[i].Val)
require.NoError(t, err)
}
- for _, pair := range kvs {
- err := writer.WriteRow(ctx, pair.Key, pair.Val, nil)
- require.NoError(t, err)
- }
- err := writer.Close(ctx)
+ require.NoError(t, writer.AppendRows(ctx, nil, kv.MakeRowsFromKvPairs(kvs)))
+ _, err := writer.Close(ctx)
require.NoError(t, err)
slices.SortFunc(kvs, func(i, j common.KvPair) int {
@@ -152,7 +153,7 @@ func TestWriterDuplicateDetect(t *testing.T) {
writer := NewWriterBuilder().
SetPropKeysDistance(2).
SetMemorySizeLimit(1000).
- EnableDuplicationDetection().
+ SetKeyDuplicationEncoding(true).
Build(memStore, "/test", "0")
kvCount := 20
for i := 0; i < kvCount; i++ {
@@ -167,10 +168,26 @@ func TestWriterDuplicateDetect(t *testing.T) {
err := writer.Close(ctx)
require.NoError(t, err)
+ // test that MergeOverlappingFiles does not change the duplicate detection functionality.
+ err = MergeOverlappingFiles(
+ ctx,
+ []string{"/test/0/0"},
+ memStore,
+ 100,
+ "/test2",
+ "mergeID",
+ 1000,
+ 8*1024,
+ 1*size.MB,
+ 2,
+ nil,
+ )
+ require.NoError(t, err)
+
keys := make([][]byte, 0, kvCount)
values := make([][]byte, 0, kvCount)
- kvReader, err := newKVReader(ctx, "/test/0/0", memStore, 0, 100)
+ kvReader, err := newKVReader(ctx, "/test2/mergeID/0", memStore, 0, 100)
require.NoError(t, err)
for i := 0; i < kvCount; i++ {
key, value, err := kvReader.nextKV()
@@ -218,6 +235,19 @@ func TestMultiFileStat(t *testing.T) {
require.EqualValues(t, 3, s.MaxOverlappingNum)
}
+func TestMultiFileStatOverlap(t *testing.T) {
+ s1 := MultipleFilesStat{MinKey: dbkv.Key{1}, MaxKey: dbkv.Key{100}, MaxOverlappingNum: 100}
+ s2 := MultipleFilesStat{MinKey: dbkv.Key{5}, MaxKey: dbkv.Key{102}, MaxOverlappingNum: 90}
+ s3 := MultipleFilesStat{MinKey: dbkv.Key{111}, MaxKey: dbkv.Key{200}, MaxOverlappingNum: 200}
+ require.EqualValues(t, 200, GetMaxOverlappingTotal([]MultipleFilesStat{s1, s2, s3}))
+
+ s3.MaxOverlappingNum = 70
+ require.EqualValues(t, 190, GetMaxOverlappingTotal([]MultipleFilesStat{s1, s2, s3}))
+
+ s3.MinKey = dbkv.Key{0}
+ require.EqualValues(t, 260, GetMaxOverlappingTotal([]MultipleFilesStat{s1, s2, s3}))
+}
+
func TestWriterMultiFileStat(t *testing.T) {
oldMultiFileStatNum := multiFileStatNum
t.Cleanup(func() {
@@ -228,13 +258,14 @@ func TestWriterMultiFileStat(t *testing.T) {
ctx := context.Background()
memStore := storage.NewMemStorage()
var summary *WriterSummary
+ closeFn := func(s *WriterSummary) {
+ summary = s
+ }
writer := NewWriterBuilder().
SetPropKeysDistance(2).
SetMemorySizeLimit(20). // 2 KV pair will trigger flush
- SetOnCloseFunc(func(s *WriterSummary) {
- summary = s
- }).
+ SetOnCloseFunc(closeFn).
Build(memStore, "/test", "0")
kvs := make([]common.KvPair, 0, 18)
@@ -326,4 +357,60 @@ func TestWriterMultiFileStat(t *testing.T) {
require.Equal(t, expected, summary.MultipleFilesStats[2])
require.EqualValues(t, "key01", summary.Min)
require.EqualValues(t, "key24", summary.Max)
+
+ allDataFiles := make([]string, 9)
+ for i := range allDataFiles {
+ allDataFiles[i] = fmt.Sprintf("/test/0/%d", i)
+ }
+
+ err = MergeOverlappingFiles(
+ ctx,
+ allDataFiles,
+ memStore,
+ 100,
+ "/test2",
+ "mergeID",
+ 20,
+ 8*1024,
+ 1*size.MB,
+ 2,
+ closeFn,
+ )
+ require.NoError(t, err)
+ require.Equal(t, 3, len(summary.MultipleFilesStats))
+ expected = MultipleFilesStat{
+ MinKey: []byte("key01"),
+ MaxKey: []byte("key06"),
+ Filenames: [][2]string{
+ {"/test2/mergeID/0", "/test2/mergeID_stat/0"},
+ {"/test2/mergeID/1", "/test2/mergeID_stat/1"},
+ {"/test2/mergeID/2", "/test2/mergeID_stat/2"},
+ },
+ MaxOverlappingNum: 1,
+ }
+ require.Equal(t, expected, summary.MultipleFilesStats[0])
+ expected = MultipleFilesStat{
+ MinKey: []byte("key11"),
+ MaxKey: []byte("key16"),
+ Filenames: [][2]string{
+ {"/test2/mergeID/3", "/test2/mergeID_stat/3"},
+ {"/test2/mergeID/4", "/test2/mergeID_stat/4"},
+ {"/test2/mergeID/5", "/test2/mergeID_stat/5"},
+ },
+ MaxOverlappingNum: 1,
+ }
+ require.Equal(t, expected, summary.MultipleFilesStats[1])
+ expected = MultipleFilesStat{
+ MinKey: []byte("key20"),
+ MaxKey: []byte("key24"),
+ Filenames: [][2]string{
+ {"/test2/mergeID/6", "/test2/mergeID_stat/6"},
+ {"/test2/mergeID/7", "/test2/mergeID_stat/7"},
+ {"/test2/mergeID/8", "/test2/mergeID_stat/8"},
+ },
+ MaxOverlappingNum: 1,
+ }
+ require.Equal(t, expected, summary.MultipleFilesStats[2])
+ require.EqualValues(t, "key01", summary.Min)
+ require.EqualValues(t, "key24", summary.Max)
}
diff --git a/br/pkg/lightning/backend/local/BUILD.bazel b/br/pkg/lightning/backend/local/BUILD.bazel
index d275cace49baf..33fbc53f058eb 100644
--- a/br/pkg/lightning/backend/local/BUILD.bazel
+++ b/br/pkg/lightning/backend/local/BUILD.bazel
@@ -55,8 +55,10 @@ go_library(
"//tablecodec",
"//types",
"//util/codec",
+ "//util/compress",
"//util/engine",
"//util/hack",
+ "//util/intest",
"//util/mathutil",
"//util/ranger",
"@com_github_cockroachdb_pebble//:pebble",
diff --git a/br/pkg/lightning/backend/local/compress.go b/br/pkg/lightning/backend/local/compress.go
index c9315d33b3225..4e5375973feea 100644
--- a/br/pkg/lightning/backend/local/compress.go
+++ b/br/pkg/lightning/backend/local/compress.go
@@ -16,9 +16,9 @@ package local
import (
"io"
- "sync"
"github.com/klauspost/compress/gzip" // faster than stdlib
+ "github.com/pingcap/tidb/util/compress"
"google.golang.org/grpc"
)
@@ -29,15 +29,9 @@ var (
type gzipCompressor struct{}
-var gzipWriterPool = sync.Pool{
- New: func() any {
- return gzip.NewWriter(io.Discard)
- },
-}
-
func (*gzipCompressor) Do(w io.Writer, p []byte) error {
- z := gzipWriterPool.Get().(*gzip.Writer)
- defer gzipWriterPool.Put(z)
+ z := compress.GzipWriterPool.Get().(*gzip.Writer)
+ defer compress.GzipWriterPool.Put(z)
z.Reset(w)
if _, err := z.Write(p); err != nil {
return err
@@ -51,22 +45,16 @@ func (*gzipCompressor) Type() string {
type gzipDecompressor struct{}
-var gzipReaderPool = sync.Pool{
- New: func() any {
- return &gzip.Reader{}
- },
-}
-
func (*gzipDecompressor) Do(r io.Reader) ([]byte, error) {
- z := gzipReaderPool.Get().(*gzip.Reader)
+ z := compress.GzipReaderPool.Get().(*gzip.Reader)
if err := z.Reset(r); err != nil {
- gzipReaderPool.Put(z)
+ compress.GzipReaderPool.Put(z)
return nil, err
}
defer func() {
_ = z.Close()
- gzipReaderPool.Put(z)
+ compress.GzipReaderPool.Put(z)
}()
return io.ReadAll(z)
}
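
The shared pools referenced in this hunk come from the new util/compress package; presumably they simply re-home the sync.Pool definitions removed above. A sketch of such a package under that assumption:

// Package compress is a sketch of the shared pools assumed by the change
// above; the real util/compress package may differ in detail.
package compress

import (
	"io"
	"sync"

	"github.com/klauspost/compress/gzip"
)

// GzipWriterPool reuses gzip writers; callers must Reset them before use
// and return them with Put when done.
var GzipWriterPool = sync.Pool{
	New: func() any {
		return gzip.NewWriter(io.Discard)
	},
}

// GzipReaderPool reuses gzip readers; callers must Reset them before use
// and return them with Put when done.
var GzipReaderPool = sync.Pool{
	New: func() any {
		return &gzip.Reader{}
	},
}
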
diff --git a/br/pkg/lightning/backend/local/local.go b/br/pkg/lightning/backend/local/local.go
index b8dafdaeee342..efb7652307d40 100644
--- a/br/pkg/lightning/backend/local/local.go
+++ b/br/pkg/lightning/backend/local/local.go
@@ -59,6 +59,7 @@ import (
"github.com/pingcap/tidb/tablecodec"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/engine"
+ "github.com/pingcap/tidb/util/intest"
"github.com/pingcap/tidb/util/mathutil"
"github.com/tikv/client-go/v2/oracle"
tikvclient "github.com/tikv/client-go/v2/tikv"
@@ -939,7 +940,11 @@ func (local *Backend) CloseEngine(ctx context.Context, cfg *backend.EngineConfig
if err != nil {
return err
}
- store, err := storage.New(ctx, storeBackend, nil)
+ opt := &storage.ExternalStorageOptions{}
+ if intest.InTest {
+ opt.NoCredentials = true
+ }
+ store, err := storage.New(ctx, storeBackend, opt)
if err != nil {
return err
}
@@ -1469,7 +1474,7 @@ func (local *Backend) ImportEngine(
log.FromContext(ctx).Info("engine contains no kv, skip import", zap.Stringer("engine", engineUUID))
return nil
}
- kvRegionSplitSize, kvRegionSplitKeys, err := getRegionSplitSizeKeys(ctx, local.pdCtl.GetPDClient(), local.tls)
+ kvRegionSplitSize, kvRegionSplitKeys, err := GetRegionSplitSizeKeys(ctx, local.pdCtl.GetPDClient(), local.tls)
if err == nil {
if kvRegionSplitSize > regionSplitSize {
regionSplitSize = kvRegionSplitSize
@@ -1549,7 +1554,7 @@ func (local *Backend) ImportEngine(
// GetRegionSplitSizeKeys gets the region split size and keys from PD.
func (local *Backend) GetRegionSplitSizeKeys(ctx context.Context) (finalSize int64, finalKeys int64, err error) {
- return getRegionSplitSizeKeys(ctx, local.pdCtl.GetPDClient(), local.tls)
+ return GetRegionSplitSizeKeys(ctx, local.pdCtl.GetPDClient(), local.tls)
}
// expose these variables to unit test.
@@ -1689,6 +1694,16 @@ func (local *Backend) GetImportedKVCount(engineUUID uuid.UUID) int64 {
return e.importedKVCount.Load()
}
+// GetExternalEngineKVStatistics returns the KV statistics of the given engine.
+func (local *Backend) GetExternalEngineKVStatistics(engineUUID uuid.UUID) (
+ totalKVSize int64, totalKVCount int64) {
+ v, ok := local.externalEngine[engineUUID]
+ if !ok {
+ return 0, 0
+ }
+ return v.KVStatistics()
+}
+
// ResetEngine reset the engine and reclaim the space.
func (local *Backend) ResetEngine(ctx context.Context, engineUUID uuid.UUID) error {
// the only way to reset the engine + reclaim the space is to delete and reopen it 🤷
@@ -1927,8 +1942,8 @@ func getSplitConfFromStore(ctx context.Context, host string, tls *common.TLS) (
return splitSize, nested.Coprocessor.RegionSplitKeys, nil
}
-// return region split size, region split keys, error
-func getRegionSplitSizeKeys(ctx context.Context, cli pd.Client, tls *common.TLS) (
+// GetRegionSplitSizeKeys returns the region split size, the region split keys, and an error.
+func GetRegionSplitSizeKeys(ctx context.Context, cli pd.Client, tls *common.TLS) (
regionSplitSize int64, regionSplitKeys int64, err error) {
stores, err := cli.GetAllStores(ctx, pd.WithExcludeTombstone())
if err != nil {
diff --git a/br/pkg/lightning/backend/local/local_test.go b/br/pkg/lightning/backend/local/local_test.go
index 1976d97e00f52..d1ae8846087f3 100644
--- a/br/pkg/lightning/backend/local/local_test.go
+++ b/br/pkg/lightning/backend/local/local_test.go
@@ -1123,7 +1123,7 @@ func TestGetRegionSplitSizeKeys(t *testing.T) {
}
return 0, 0, errors.New("invalid connection")
}
- splitSize, splitKeys, err := getRegionSplitSizeKeys(ctx, cli, nil)
+ splitSize, splitKeys, err := GetRegionSplitSizeKeys(ctx, cli, nil)
require.NoError(t, err)
require.Equal(t, int64(1), splitSize)
require.Equal(t, int64(2), splitKeys)
@@ -2255,3 +2255,13 @@ func TestExternalEngine(t *testing.T) {
}
require.Equal(t, 100, kvIdx)
}
+
+func TestGetExternalEngineKVStatistics(t *testing.T) {
+ b := Backend{
+ externalEngine: map[uuid.UUID]common.Engine{},
+ }
+ // non-existent UUID
+ size, count := b.GetExternalEngineKVStatistics(uuid.New())
+ require.Zero(t, size)
+ require.Zero(t, count)
+}
diff --git a/br/pkg/lightning/importer/import.go b/br/pkg/lightning/importer/import.go
index 9ae64c784403a..2f41331b3ce1a 100644
--- a/br/pkg/lightning/importer/import.go
+++ b/br/pkg/lightning/importer/import.go
@@ -978,7 +978,7 @@ func verifyLocalFile(ctx context.Context, cpdb checkpoints.DB, dir string) error
}
for tableName, engineIDs := range targetTables {
for _, engineID := range engineIDs {
- _, eID := backend.MakeUUID(tableName, engineID)
+ _, eID := backend.MakeUUID(tableName, int64(engineID))
file := local.Engine{UUID: eID}
err := file.Exist(dir)
if err != nil {
diff --git a/br/pkg/storage/gcs.go b/br/pkg/storage/gcs.go
index f32d4344d9a83..7c6deb7aaac3f 100644
--- a/br/pkg/storage/gcs.go
+++ b/br/pkg/storage/gcs.go
@@ -296,6 +296,9 @@ func NewGCSStorage(ctx context.Context, gcs *backuppb.GCS, opts *ExternalStorage
if gcs.Endpoint != "" {
clientOps = append(clientOps, option.WithEndpoint(gcs.Endpoint))
}
+ // The HTTPClient should carry credentials, but currently it only has the http.Transport,
+ // so we drop the HTTPClient in storage.New().
+ // Issue: https://github.com/pingcap/tidb/issues/47022
if opts.HTTPClient != nil {
clientOps = append(clientOps, option.WithHTTPClient(opts.HTTPClient))
}
diff --git a/br/pkg/storage/storage.go b/br/pkg/storage/storage.go
index e4624e2ed475e..7ed15ce2d16bf 100644
--- a/br/pkg/storage/storage.go
+++ b/br/pkg/storage/storage.go
@@ -144,7 +144,9 @@ type ExternalStorageOptions struct {
NoCredentials bool
// HTTPClient to use. The created storage may ignore this field if it is not
- // directly using HTTP (e.g. the local storage).
+ // directly using HTTP (e.g. the local storage), or if it uses a self-designed HTTP client
+ // with credentials (e.g. GCS).
+ // NOTICE: the HTTPClient is only used by the S3 storage and the Azure Blob storage.
HTTPClient *http.Client
// CheckPermissions check the given permission in New() function.
@@ -197,6 +199,9 @@ func New(ctx context.Context, backend *backuppb.StorageBackend, opts *ExternalSt
if backend.Gcs == nil {
return nil, errors.Annotate(berrors.ErrStorageInvalidConfig, "GCS config not found")
}
+ // The HTTPClient should carry credentials, but currently it only has the http.Transport.
+ // Issue: https://github.com/pingcap/tidb/issues/47022
+ opts.HTTPClient = nil
return NewGCSStorage(ctx, backend.Gcs, opts)
case *backuppb.StorageBackend_AzureBlobStorage:
return newAzureBlobStorage(ctx, backend.AzureBlobStorage, opts)
diff --git a/build/nogo_config.json b/build/nogo_config.json
index d962e816e8ae1..18c0685ab77aa 100644
--- a/build/nogo_config.json
+++ b/build/nogo_config.json
@@ -627,6 +627,7 @@
"br/pkg/backup/": "br/pkg/backup/ coded",
"planer/core/casetest/binaryplan": "planer/core/casetest/binaryplan",
"planer/core/casetest/cbotest": "planer/core/casetest/cbotest",
+ "planer/core/casetest/dag": "planer/core/casetest/dag",
"planer/core/casetest/enforcempp": "planer/core/casetest/enforcempp",
"planer/core/casetest/flatplan": "planer/core/casetest/flatplan",
"planer/core/casetest/hint": "planer/core/casetest/hint",
diff --git a/cmd/mirror/mirror.go b/cmd/mirror/mirror.go
index 38b335e7be860..a8612d0a79af8 100644
--- a/cmd/mirror/mirror.go
+++ b/cmd/mirror/mirror.go
@@ -197,7 +197,7 @@ func downloadZips(
cmd := exec.Command(gobin, downloadArgs...)
cmd.Dir = tmpdir
env := os.Environ()
- env = append(env, fmt.Sprintf("GOPROXY=%s", "https://mirrors.aliyun.com/goproxy/,https://proxy.golang.org,direct"))
+ env = append(env, fmt.Sprintf("GOPROXY=%s", "https://mirrors.aliyun.com/goproxy/,http://goproxy.apps.svc,https://proxy.golang.org,direct"))
env = append(env, fmt.Sprintf("GOSUMDB=%s", "sum.golang.org"))
cmd.Env = env
jsonBytes, err := cmd.Output()
@@ -228,7 +228,7 @@ func listAllModules(tmpdir string) (map[string]listedModule, error) {
cmd := exec.Command(gobin, "list", "-mod=readonly", "-m", "-json", "all")
cmd.Dir = tmpdir
env := os.Environ()
- env = append(env, fmt.Sprintf("GOPROXY=%s", "https://mirrors.aliyun.com/goproxy/,https://proxy.golang.org,direct"))
+ env = append(env, fmt.Sprintf("GOPROXY=%s", "https://mirrors.aliyun.com/goproxy/,http://goproxy.apps.svc,https://proxy.golang.org,direct"))
env = append(env, fmt.Sprintf("GOSUMDB=%s", "sum.golang.org"))
cmd.Env = env
jsonBytes, err := cmd.Output()
diff --git a/ddl/BUILD.bazel b/ddl/BUILD.bazel
index f7fda21cb115e..b35206b6be329 100644
--- a/ddl/BUILD.bazel
+++ b/ddl/BUILD.bazel
@@ -13,9 +13,11 @@ go_library(
srcs = [
"backfilling.go",
"backfilling_dispatcher.go",
+ "backfilling_dist_scheduler.go",
"backfilling_import_cloud.go",
"backfilling_import_local.go",
- "backfilling_operator.go",
+ "backfilling_merge_sort.go",
+ "backfilling_operators.go",
"backfilling_read_index.go",
"backfilling_scheduler.go",
"callback.go",
@@ -50,7 +52,6 @@ go_library(
"schema.go",
"sequence.go",
"split_region.go",
- "stage_scheduler.go",
"stat.go",
"table.go",
"table_lock.go",
@@ -124,6 +125,7 @@ go_library(
"//util/codec",
"//util/collate",
"//util/dbterror",
+ "//util/disttask",
"//util/domainutil",
"//util/filter",
"//util/gcutil",
@@ -138,6 +140,7 @@ go_library(
"//util/rowDecoder",
"//util/rowcodec",
"//util/set",
+ "//util/size",
"//util/slice",
"//util/sqlexec",
"//util/stringutil",
diff --git a/ddl/backfilling_dispatcher.go b/ddl/backfilling_dispatcher.go
index a49852718f561..fa3e226c8894b 100644
--- a/ddl/backfilling_dispatcher.go
+++ b/ddl/backfilling_dispatcher.go
@@ -37,13 +37,16 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/store/helper"
"github.com/pingcap/tidb/table"
+ disttaskutil "github.com/pingcap/tidb/util/disttask"
+ "github.com/pingcap/tidb/util/intest"
"github.com/pingcap/tidb/util/logutil"
"github.com/tikv/client-go/v2/tikv"
"go.uber.org/zap"
)
type backfillingDispatcherExt struct {
- d *ddl
+ d *ddl
+ previousSchedulerIDs []string
}
var _ dispatcher.Extension = (*backfillingDispatcherExt)(nil)
@@ -63,68 +66,91 @@ func NewBackfillingDispatcherExt(d DDL) (dispatcher.Extension, error) {
func (*backfillingDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
-// OnNextSubtasksBatch generate batch of next stage's plan.
-func (h *backfillingDispatcherExt) OnNextSubtasksBatch(ctx context.Context,
+// OnNextSubtasksBatch generates a batch of the next step's plan.
+func (h *backfillingDispatcherExt) OnNextSubtasksBatch(
+ ctx context.Context,
taskHandle dispatcher.TaskHandle,
gTask *proto.Task,
+ step int64,
) (taskMeta [][]byte, err error) {
var gTaskMeta BackfillGlobalMeta
if err := json.Unmarshal(gTask.Meta, &gTaskMeta); err != nil {
return nil, err
}
+ job := &gTaskMeta.Job
useExtStore := len(gTaskMeta.CloudStorageURI) > 0
defer func() {
// Only redact when the task is complete.
if len(taskMeta) == 0 && useExtStore {
redactCloudStorageURI(ctx, gTask, &gTaskMeta)
- if err := taskHandle.UpdateTask(gTask.State, nil, dispatcher.RetrySQLTimes); err != nil {
- logutil.Logger(ctx).Error("failed to UpdateTask", zap.Error(err))
- }
}
}()
- job := &gTaskMeta.Job
tblInfo, err := getTblInfo(h.d, job)
if err != nil {
return nil, err
}
- // generate partition table's plan.
- if tblInfo.Partition != nil {
- switch gTask.Step {
- case proto.StepInit:
+
+ switch step {
+ case proto.StepOne:
+ if tblInfo.Partition != nil {
return generatePartitionPlan(tblInfo)
- case proto.StepOne:
- if useExtStore {
- return generateMergeSortPlan(ctx, taskHandle, gTask, job.ID, gTaskMeta.CloudStorageURI)
- }
- return nil, nil
- default:
- return nil, nil
}
- }
- // generate non-partition table's plan.
- switch gTask.Step {
- case proto.StepInit:
return generateNonPartitionPlan(h.d, tblInfo, job)
- case proto.StepOne:
+ case proto.StepTwo:
+ return generateMergePlan(taskHandle, gTask)
+ case proto.StepThree:
if useExtStore {
return generateMergeSortPlan(ctx, taskHandle, gTask, job.ID, gTaskMeta.CloudStorageURI)
}
- return generateIngestTaskPlan(ctx)
+ if tblInfo.Partition != nil {
+ return nil, nil
+ }
+ return generateIngestTaskPlan(ctx, h, taskHandle, gTask)
default:
return nil, nil
}
}
-// StageFinished check if current stage finished.
-func (*backfillingDispatcherExt) StageFinished(_ *proto.Task) bool {
- return true
-}
-
-// Finished check if current task finished.
-func (*backfillingDispatcherExt) Finished(task *proto.Task) bool {
- return task.Step == proto.StepOne
+func (*backfillingDispatcherExt) GetNextStep(
+ taskHandle dispatcher.TaskHandle,
+ task *proto.Task,
+) int64 {
+ switch task.Step {
+ case proto.StepInit:
+ return proto.StepOne
+ case proto.StepOne:
+ // taskHandle is only nil in tests.
+ if taskHandle == nil {
+ return proto.StepThree
+ }
+ // If the data files overlap too much, we need a merge step.
+ subTaskMetas, err := taskHandle.GetPreviousSubtaskMetas(task.ID, proto.StepInit)
+ if err != nil {
+ // TODO(lance6716): should we return error?
+ return proto.StepTwo
+ }
+ multiStats := make([]external.MultipleFilesStat, 0, 100)
+ for _, bs := range subTaskMetas {
+ var subtask BackfillSubTaskMeta
+ err = json.Unmarshal(bs, &subtask)
+ if err != nil {
+ // TODO(lance6716): should we return error?
+ return proto.StepThree
+ }
+ multiStats = append(multiStats, subtask.MultipleFilesStats...)
+ }
+ if external.GetMaxOverlappingTotal(multiStats) > external.MergeSortOverlapThreshold {
+ return proto.StepTwo
+ }
+ return proto.StepThree
+ case proto.StepTwo:
+ return proto.StepThree
+ default:
+ // current step should be proto.StepThree
+ return proto.StepDone
+ }
}
// OnErrStage generates the error handling stage's plan.
@@ -136,8 +162,22 @@ func (*backfillingDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.Task
return nil, nil
}
-func (*backfillingDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
- return dispatcher.GenerateSchedulerNodes(ctx)
+func (h *backfillingDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
+ serverInfos, err := dispatcher.GenerateSchedulerNodes(ctx)
+ if err != nil {
+ return nil, err
+ }
+ if len(h.previousSchedulerIDs) > 0 {
+ // Only the nodes that executed step one can execute step two.
+ involvedServerInfos := make([]*infosync.ServerInfo, 0, len(serverInfos))
+ for _, id := range h.previousSchedulerIDs {
+ if idx := disttaskutil.FindServerInfo(serverInfos, id); idx >= 0 {
+ involvedServerInfos = append(involvedServerInfos, serverInfos[idx])
+ }
+ }
+ return involvedServerInfos, nil
+ }
+ return serverInfos, nil
}
// IsRetryableErr implements TaskFlowHandle.IsRetryableErr interface.
@@ -243,20 +283,37 @@ func generateNonPartitionPlan(d *ddl, tblInfo *model.TableInfo, job *model.Job)
return subTaskMetas, nil
}
-func generateIngestTaskPlan(ctx context.Context) ([][]byte, error) {
+func generateIngestTaskPlan(
+ ctx context.Context,
+ h *backfillingDispatcherExt,
+ taskHandle dispatcher.TaskHandle,
+ gTask *proto.Task,
+) ([][]byte, error) {
// We dispatch dummy subtasks because the remaining data in the local engine will be imported
// during the initialization of the subtask executor.
- serverNodes, err := dispatcher.GenerateSchedulerNodes(ctx)
- if err != nil {
- return nil, err
+ var ingestSubtaskCnt int
+ if intest.InTest && taskHandle == nil {
+ serverNodes, err := dispatcher.GenerateSchedulerNodes(ctx)
+ if err != nil {
+ return nil, err
+ }
+ ingestSubtaskCnt = len(serverNodes)
+ } else {
+ schedulerIDs, err := taskHandle.GetPreviousSchedulerIDs(ctx, gTask.ID, gTask.Step)
+ if err != nil {
+ return nil, err
+ }
+ h.previousSchedulerIDs = schedulerIDs
+ ingestSubtaskCnt = len(schedulerIDs)
}
- subTaskMetas := make([][]byte, 0, len(serverNodes))
+
+ subTaskMetas := make([][]byte, 0, ingestSubtaskCnt)
dummyMeta := &BackfillSubTaskMeta{}
metaBytes, err := json.Marshal(dummyMeta)
if err != nil {
return nil, err
}
- for range serverNodes {
+ for i := 0; i < ingestSubtaskCnt; i++ {
subTaskMetas = append(subTaskMetas, metaBytes)
}
return subTaskMetas, nil
@@ -310,12 +367,14 @@ func generateMergeSortPlan(
hex.EncodeToString(startKey), hex.EncodeToString(endKey))
}
m := &BackfillSubTaskMeta{
- MinKey: startKey,
- MaxKey: endKey,
- DataFiles: dataFiles,
- StatFiles: statFiles,
+ SortedKVMeta: external.SortedKVMeta{
+ MinKey: startKey,
+ MaxKey: endKey,
+ DataFiles: dataFiles,
+ StatFiles: statFiles,
+ TotalKVSize: totalSize / uint64(len(instanceIDs)),
+ },
RangeSplitKeys: rangeSplitKeys,
- TotalKVSize: totalSize / uint64(len(instanceIDs)),
}
metaBytes, err := json.Marshal(m)
if err != nil {
@@ -329,6 +388,39 @@ func generateMergeSortPlan(
}
}
+func generateMergePlan(
+ taskHandle dispatcher.TaskHandle,
+ task *proto.Task,
+) ([][]byte, error) {
+ _, _, _, dataFiles, _, err := getSummaryFromLastStep(taskHandle, task.ID)
+ if err != nil {
+ return nil, err
+ }
+
+ start := 0
+ step := external.MergeSortFileCountStep
+ metaArr := make([][]byte, 0, 16)
+ for start < len(dataFiles) {
+ end := start + step
+ if end > len(dataFiles) {
+ end = len(dataFiles)
+ }
+ m := &BackfillSubTaskMeta{
+ SortedKVMeta: external.SortedKVMeta{
+ DataFiles: dataFiles[start:end],
+ },
+ }
+ metaBytes, err := json.Marshal(m)
+ if err != nil {
+ return nil, err
+ }
+ metaArr = append(metaArr, metaBytes)
+
+ start = end
+ }
+ return metaArr, nil
+}
+
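
generateMergePlan slices the previous step's data files into batches of external.MergeSortFileCountStep so that each merge subtask works on one slice. A standalone sketch of the same batching loop, with an illustrative step of 10; the real constant lives in the external package:

package main

import "fmt"

func main() {
	// Pretend the previous step produced 25 data files and the batch step is 10.
	dataFiles := make([]string, 25)
	for i := range dataFiles {
		dataFiles[i] = fmt.Sprintf("file-%02d", i)
	}
	const step = 10

	start := 0
	for start < len(dataFiles) {
		end := start + step
		if end > len(dataFiles) {
			end = len(dataFiles)
		}
		// Each slice becomes one merge subtask meta.
		fmt.Printf("subtask covers files [%d, %d)\n", start, end)
		start = end
	}
	// Output:
	// subtask covers files [0, 10)
	// subtask covers files [10, 20)
	// subtask covers files [20, 25)
}
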
func getRangeSplitter(
ctx context.Context,
cloudStorageURI string,
@@ -372,7 +464,7 @@ func getSummaryFromLastStep(
taskHandle dispatcher.TaskHandle,
gTaskID int64,
) (min, max kv.Key, totalKVSize uint64, dataFiles, statFiles []string, err error) {
- subTaskMetas, err := taskHandle.GetPreviousSubtaskMetas(gTaskID, proto.StepInit)
+ subTaskMetas, err := taskHandle.GetPreviousSubtaskMetas(gTaskID, proto.StepOne)
if err != nil {
return nil, nil, 0, nil, nil, errors.Trace(err)
}
@@ -387,12 +479,16 @@ func getSummaryFromLastStep(
}
// Skip empty subtask.MinKey/MaxKey because it means
// no records need to be written in this subtask.
- minKey = notNilMin(minKey, subtask.MinKey)
- maxKey = notNilMax(maxKey, subtask.MaxKey)
+ minKey = external.NotNilMin(minKey, subtask.MinKey)
+ maxKey = external.NotNilMax(maxKey, subtask.MaxKey)
totalKVSize += subtask.TotalKVSize
- allDataFiles = append(allDataFiles, subtask.DataFiles...)
- allStatFiles = append(allStatFiles, subtask.StatFiles...)
+ for _, stat := range subtask.MultipleFilesStats {
+ for i := range stat.Filenames {
+ allDataFiles = append(allDataFiles, stat.Filenames[i][0])
+ allStatFiles = append(allStatFiles, stat.Filenames[i][1])
+ }
+ }
}
return minKey, maxKey, totalKVSize, allDataFiles, allStatFiles, nil
}
@@ -410,31 +506,3 @@ func redactCloudStorageURI(
}
gTask.Meta = metaBytes
}
-
-// notNilMin returns the smaller of a and b, ignoring nil values.
-func notNilMin(a, b []byte) []byte {
- if len(a) == 0 {
- return b
- }
- if len(b) == 0 {
- return a
- }
- if bytes.Compare(a, b) < 0 {
- return a
- }
- return b
-}
-
-// notNilMax returns the larger of a and b, ignoring nil values.
-func notNilMax(a, b []byte) []byte {
- if len(a) == 0 {
- return b
- }
- if len(b) == 0 {
- return a
- }
- if bytes.Compare(a, b) > 0 {
- return a
- }
- return b
-}
diff --git a/ddl/backfilling_dispatcher_test.go b/ddl/backfilling_dispatcher_test.go
index 832b80b34b496..b9f9bb7284047 100644
--- a/ddl/backfilling_dispatcher_test.go
+++ b/ddl/backfilling_dispatcher_test.go
@@ -50,7 +50,9 @@ func TestBackfillingDispatcher(t *testing.T) {
tblInfo := tbl.Meta()
// 1.1 OnNextSubtasksBatch
- metas, err := dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
+ gTask.Step = dsp.GetNextStep(nil, gTask)
+ require.Equal(t, proto.StepOne, gTask.Step)
+ metas, err := dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
require.NoError(t, err)
require.Equal(t, len(tblInfo.Partition.Definitions), len(metas))
for i, par := range tblInfo.Partition.Definitions {
@@ -61,10 +63,17 @@ func TestBackfillingDispatcher(t *testing.T) {
// 1.2 test partition table OnNextSubtasksBatch after StepInit finished.
gTask.State = proto.TaskStateRunning
- gTask.Step++
- metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
+ gTask.Step = dsp.GetNextStep(nil, gTask)
+ require.Equal(t, proto.StepThree, gTask.Step)
+ // for a partition table, we do not generate subtasks for StepThree.
+ metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
require.NoError(t, err)
- require.Equal(t, 0, len(metas))
+ require.Len(t, metas, 0)
+ gTask.Step = dsp.GetNextStep(nil, gTask)
+ require.Equal(t, proto.StepDone, gTask.Step)
+ metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
+ require.NoError(t, err)
+ require.Len(t, metas, 0)
// 1.3 test partition table OnErrStage.
errMeta, err := dsp.OnErrStage(context.Background(), nil, gTask, []error{errors.New("mockErr")})
@@ -79,7 +88,7 @@ func TestBackfillingDispatcher(t *testing.T) {
// 2.1 empty table
tk.MustExec("create table t1(id int primary key, v int)")
gTask = createAddIndexGlobalTask(t, dom, "test", "t1", ddl.BackfillTaskType)
- metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
+ metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
require.NoError(t, err)
require.Equal(t, 0, len(metas))
// 2.2 non empty table.
@@ -90,15 +99,23 @@ func TestBackfillingDispatcher(t *testing.T) {
tk.MustExec("insert into t2 values (), (), (), (), (), ()")
gTask = createAddIndexGlobalTask(t, dom, "test", "t2", ddl.BackfillTaskType)
// 2.2.1 stepInit
- metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
+ gTask.Step = dsp.GetNextStep(nil, gTask)
+ metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
require.NoError(t, err)
require.Equal(t, 1, len(metas))
+ require.Equal(t, proto.StepOne, gTask.Step)
// 2.2.2 stepOne
- gTask.Step++
gTask.State = proto.TaskStateRunning
- metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
+ gTask.Step = dsp.GetNextStep(nil, gTask)
+ require.Equal(t, proto.StepThree, gTask.Step)
+ metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
require.NoError(t, err)
require.Equal(t, 1, len(metas))
+ gTask.Step = dsp.GetNextStep(nil, gTask)
+ require.Equal(t, proto.StepDone, gTask.Step)
+ metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask, gTask.Step)
+ require.NoError(t, err)
+ require.Equal(t, 0, len(metas))
}
func createAddIndexGlobalTask(t *testing.T, dom *domain.Domain, dbName, tblName string, taskType string) *proto.Task {
diff --git a/ddl/stage_scheduler.go b/ddl/backfilling_dist_scheduler.go
similarity index 90%
rename from ddl/stage_scheduler.go
rename to ddl/backfilling_dist_scheduler.go
index 833750e6ef50a..e40a8296a2953 100644
--- a/ddl/stage_scheduler.go
+++ b/ddl/backfilling_dist_scheduler.go
@@ -19,6 +19,7 @@ import (
"encoding/json"
"github.com/pingcap/errors"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/external"
"github.com/pingcap/tidb/ddl/ingest"
"github.com/pingcap/tidb/disttask/framework/proto"
"github.com/pingcap/tidb/disttask/framework/scheduler"
@@ -44,12 +45,8 @@ type BackfillSubTaskMeta struct {
StartKey []byte `json:"start_key"`
EndKey []byte `json:"end_key"`
- DataFiles []string `json:"data_files"`
- StatFiles []string `json:"stat_files"`
- RangeSplitKeys [][]byte `json:"range_split_keys"`
- MinKey []byte `json:"min_key"`
- MaxKey []byte `json:"max_key"`
- TotalKVSize uint64 `json:"total_kv_size"`
+ RangeSplitKeys [][]byte `json:"range_split_keys"`
+ external.SortedKVMeta `json:",inline"`
}
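
A note on the `json:",inline"` tag above: encoding/json has no "inline" option, but because the embedded field's tag carries no name, the SortedKVMeta fields are promoted to the top level of the marshaled subtask meta anyway, which presumably keeps the JSON shape compatible with the explicit fields removed here. A small demonstration of that promotion with stand-in types (the field sets are illustrative assumptions, not the real definitions):

package main

import (
	"encoding/json"
	"fmt"
)

// Stand-ins for external.SortedKVMeta and BackfillSubTaskMeta.
type SortedKVMeta struct {
	MinKey      []byte   `json:"min_key"`
	DataFiles   []string `json:"data_files"`
	TotalKVSize uint64   `json:"total_kv_size"`
}

type SubTaskMeta struct {
	RangeSplitKeys [][]byte `json:"range_split_keys"`
	SortedKVMeta   `json:",inline"`
}

func main() {
	m := SubTaskMeta{
		SortedKVMeta: SortedKVMeta{DataFiles: []string{"/jobs/1/0"}, TotalKVSize: 42},
	}
	out, _ := json.Marshal(m)
	// The embedded fields are promoted to the top level:
	// {"range_split_keys":null,"min_key":null,"data_files":["/jobs/1/0"],"total_kv_size":42}
	fmt.Println(string(out))
}
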
// NewBackfillSubtaskExecutor creates a new backfill subtask executor.
@@ -74,15 +71,17 @@ func NewBackfillSubtaskExecutor(_ context.Context, taskMeta []byte, d *ddl,
}
switch stage {
- case proto.StepInit:
+ case proto.StepOne:
jc := d.jobContext(jobMeta.ID, jobMeta.ReorgMeta)
d.setDDLLabelForTopSQL(jobMeta.ID, jobMeta.Query)
d.setDDLSourceForDiagnosis(jobMeta.ID, jobMeta.Type)
return newReadIndexExecutor(
d, &bgm.Job, indexInfo, tbl.(table.PhysicalTable), jc, bc, summary, bgm.CloudStorageURI), nil
- case proto.StepOne:
+ case proto.StepTwo:
+ return newMergeSortExecutor(jobMeta.ID, indexInfo, tbl.(table.PhysicalTable), bc, bgm.CloudStorageURI)
+ case proto.StepThree:
if len(bgm.CloudStorageURI) > 0 {
- return newCloudImportExecutor(jobMeta.ID, indexInfo, tbl.(table.PhysicalTable), bc, bgm.CloudStorageURI)
+ return newCloudImportExecutor(&bgm.Job, jobMeta.ID, indexInfo, tbl.(table.PhysicalTable), bc, bgm.CloudStorageURI)
}
return newImportFromLocalStepExecutor(jobMeta.ID, indexInfo, tbl.(table.PhysicalTable), bc), nil
default:
@@ -145,7 +144,7 @@ func (s *backfillDistScheduler) Init(ctx context.Context) error {
func (s *backfillDistScheduler) GetSubtaskExecutor(ctx context.Context, task *proto.Task, summary *execute.Summary) (execute.SubtaskExecutor, error) {
switch task.Step {
- case proto.StepInit, proto.StepOne:
+ case proto.StepOne, proto.StepTwo, proto.StepThree:
return NewBackfillSubtaskExecutor(ctx, task.Meta, s.d, s.backendCtx, task.Step, summary)
default:
return nil, errors.Errorf("unknown backfill step %d for task %d", task.Step, task.ID)
diff --git a/ddl/backfilling_import_cloud.go b/ddl/backfilling_import_cloud.go
index 693af13232f10..e658803bde5c6 100644
--- a/ddl/backfilling_import_cloud.go
+++ b/ddl/backfilling_import_cloud.go
@@ -30,6 +30,7 @@ import (
)
type cloudImportExecutor struct {
+ job *model.Job
jobID int64
index *model.IndexInfo
ptbl table.PhysicalTable
@@ -38,6 +39,7 @@ type cloudImportExecutor struct {
}
func newCloudImportExecutor(
+ job *model.Job,
jobID int64,
index *model.IndexInfo,
ptbl table.PhysicalTable,
@@ -45,6 +47,7 @@ func newCloudImportExecutor(
cloudStoreURI string,
) (*cloudImportExecutor, error) {
return &cloudImportExecutor{
+ job: job,
jobID: jobID,
index: index,
ptbl: ptbl,
@@ -54,12 +57,12 @@ func newCloudImportExecutor(
}
func (*cloudImportExecutor) Init(ctx context.Context) error {
- logutil.Logger(ctx).Info("merge sort stage init subtask exec env")
+ logutil.Logger(ctx).Info("cloud import executor init subtask exec env")
return nil
}
func (m *cloudImportExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) error {
- logutil.Logger(ctx).Info("merge sort stage split subtask")
+ logutil.Logger(ctx).Info("cloud import executor run subtask")
sm := &BackfillSubTaskMeta{}
err := json.Unmarshal(subtask.Meta, sm)
@@ -74,7 +77,7 @@ func (m *cloudImportExecutor) RunSubtask(ctx context.Context, subtask *proto.Sub
if local == nil {
return errors.Errorf("local backend not found")
}
- _, engineUUID := backend.MakeUUID(m.ptbl.Meta().Name.L, int32(m.index.ID))
+ _, engineUUID := backend.MakeUUID(m.ptbl.Meta().Name.L, m.index.ID)
err = local.CloseEngine(ctx, &backend.EngineConfig{
External: &backend.ExternalEngineConfig{
StorageURI: m.cloudStoreURI,
@@ -95,16 +98,16 @@ func (m *cloudImportExecutor) RunSubtask(ctx context.Context, subtask *proto.Sub
}
func (*cloudImportExecutor) Cleanup(ctx context.Context) error {
- logutil.Logger(ctx).Info("merge sort stage clean up subtask env")
+ logutil.Logger(ctx).Info("cloud import executor clean up subtask env")
return nil
}
func (*cloudImportExecutor) OnFinished(ctx context.Context, _ *proto.Subtask) error {
- logutil.Logger(ctx).Info("merge sort stage finish subtask")
+ logutil.Logger(ctx).Info("cloud import executor finish subtask")
return nil
}
func (*cloudImportExecutor) Rollback(ctx context.Context) error {
- logutil.Logger(ctx).Info("merge sort stage rollback subtask")
+ logutil.Logger(ctx).Info("cloud import executor rollback subtask")
return nil
}
diff --git a/ddl/backfilling_import_local.go b/ddl/backfilling_import_local.go
index 40f5b8a8a6b3f..8f79971279dc2 100644
--- a/ddl/backfilling_import_local.go
+++ b/ddl/backfilling_import_local.go
@@ -48,7 +48,7 @@ func newImportFromLocalStepExecutor(
}
func (i *localImportExecutor) Init(ctx context.Context) error {
- logutil.Logger(ctx).Info("ingest index stage init subtask exec env")
+ logutil.Logger(ctx).Info("local import executor init subtask exec env")
_, _, err := i.bc.Flush(i.index.ID, ingest.FlushModeForceGlobal)
if err != nil {
if common.ErrFoundDuplicateKeys.Equal(err) {
@@ -62,21 +62,21 @@ func (i *localImportExecutor) Init(ctx context.Context) error {
}
func (*localImportExecutor) RunSubtask(ctx context.Context, _ *proto.Subtask) error {
- logutil.Logger(ctx).Info("ingest index stage split subtask")
+ logutil.Logger(ctx).Info("local import executor run subtask")
return nil
}
func (*localImportExecutor) Cleanup(ctx context.Context) error {
- logutil.Logger(ctx).Info("ingest index stage cleanup subtask exec env")
+ logutil.Logger(ctx).Info("local import executor cleanup subtask exec env")
return nil
}
func (*localImportExecutor) OnFinished(ctx context.Context, _ *proto.Subtask) error {
- logutil.Logger(ctx).Info("ingest index stage finish subtask")
+ logutil.Logger(ctx).Info("local import executor finish subtask")
return nil
}
func (*localImportExecutor) Rollback(ctx context.Context) error {
- logutil.Logger(ctx).Info("ingest index stage rollback backfill add index task")
+ logutil.Logger(ctx).Info("local import executor rollback subtask")
return nil
}
diff --git a/ddl/backfilling_merge_sort.go b/ddl/backfilling_merge_sort.go
new file mode 100644
index 0000000000000..2862f0141c309
--- /dev/null
+++ b/ddl/backfilling_merge_sort.go
@@ -0,0 +1,131 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package ddl
+
+import (
+ "context"
+ "encoding/json"
+ "path"
+ "strconv"
+ "sync"
+
+ "github.com/google/uuid"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/external"
+ "github.com/pingcap/tidb/br/pkg/storage"
+ "github.com/pingcap/tidb/ddl/ingest"
+ "github.com/pingcap/tidb/disttask/framework/proto"
+ "github.com/pingcap/tidb/parser/model"
+ "github.com/pingcap/tidb/table"
+ "github.com/pingcap/tidb/util/logutil"
+ "github.com/pingcap/tidb/util/size"
+ "go.uber.org/zap"
+)
+
+type mergeSortExecutor struct {
+ jobID int64
+ index *model.IndexInfo
+ ptbl table.PhysicalTable
+ bc ingest.BackendCtx
+ cloudStoreURI string
+ dataFiles []string
+ statFiles []string
+ mu sync.Mutex
+}
+
+func newMergeSortExecutor(
+ jobID int64,
+ index *model.IndexInfo,
+ ptbl table.PhysicalTable,
+ bc ingest.BackendCtx,
+ cloudStoreURI string,
+) (*mergeSortExecutor, error) {
+ return &mergeSortExecutor{
+ jobID: jobID,
+ index: index,
+ ptbl: ptbl,
+ bc: bc,
+ cloudStoreURI: cloudStoreURI,
+ }, nil
+}
+
+func (*mergeSortExecutor) Init(ctx context.Context) error {
+ logutil.Logger(ctx).Info("merge sort executor init subtask exec env")
+ return nil
+}
+
+func (m *mergeSortExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) error {
+ logutil.Logger(ctx).Info("merge sort executor run subtask")
+
+ sm := &BackfillSubTaskMeta{}
+ err := json.Unmarshal(subtask.Meta, sm)
+ if err != nil {
+ logutil.BgLogger().Error("unmarshal error",
+ zap.String("category", "ddl"),
+ zap.Error(err))
+ return err
+ }
+
+ m.mu.Lock()
+ onClose := func(summary *external.WriterSummary) {
+ for _, f := range summary.MultipleFilesStats {
+ for _, filename := range f.Filenames {
+ m.dataFiles = append(m.dataFiles, filename[0])
+ m.statFiles = append(m.statFiles, filename[1])
+ }
+ }
+ m.mu.Unlock()
+ }
+
+ storeBackend, err := storage.ParseBackend(m.cloudStoreURI, nil)
+ if err != nil {
+ return err
+ }
+ store, err := storage.New(ctx, storeBackend, nil)
+ if err != nil {
+ return err
+ }
+
+ writerID := uuid.New().String()
+ prefix := path.Join(strconv.Itoa(int(m.jobID)), strconv.Itoa(int(subtask.ID)))
+
+ // TODO: config generated by plan.
+ return external.MergeOverlappingFiles(
+ ctx,
+ sm.DataFiles,
+ store,
+ 64*1024,
+ prefix,
+ writerID,
+ 256*size.MB,
+ 8*1024,
+ 1*size.MB,
+ 8*1024,
+ onClose)
+}
+
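
The MergeOverlappingFiles call above is a long list of positional arguments. Comparing it with the call added in writer_test.go earlier in this diff, the positions appear to map as annotated in the sketch below; the per-argument notes are inferred assumptions, not the function's declared parameter names:

package sketch

import (
	"context"

	"github.com/pingcap/tidb/br/pkg/lightning/backend/external"
	"github.com/pingcap/tidb/br/pkg/storage"
	"github.com/pingcap/tidb/util/size"
)

// mergeStep restates the MergeOverlappingFiles call from RunSubtask above with
// one comment per argument. The meanings are inferred from the two call sites
// in this diff and are assumptions, not the declared parameter names.
func mergeStep(
	ctx context.Context,
	dataFiles []string,
	store storage.ExternalStorage,
	prefix, writerID string,
	onClose external.OnCloseFunc,
) error {
	return external.MergeOverlappingFiles(
		ctx,
		dataFiles,   // sorted KV data files produced by the previous step
		store,       // cloud storage that holds both the inputs and the outputs
		64*1024,     // read buffer size
		prefix,      // output path prefix, "<jobID>/<subtaskID>"
		writerID,    // unique writer ID (a fresh UUID)
		256*size.MB, // writer memory size limit before a new file is flushed
		8*1024,      // block size
		1*size.MB,   // size distance between adjacent range properties
		8*1024,      // key-count distance between adjacent range properties
		onClose,     // callback that records the produced data and stat files
	)
}
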
+func (*mergeSortExecutor) Cleanup(ctx context.Context) error {
+ logutil.Logger(ctx).Info("merge sort executor cleanup subtask exec env")
+ return nil
+}
+
+func (*mergeSortExecutor) OnFinished(ctx context.Context, _ *proto.Subtask) error {
+ logutil.Logger(ctx).Info("merge sort executor finish subtask")
+ return nil
+}
+
+func (*mergeSortExecutor) Rollback(ctx context.Context) error {
+ logutil.Logger(ctx).Info("merge sort executor rollback backfill add index task")
+ return nil
+}
diff --git a/ddl/backfilling_operator.go b/ddl/backfilling_operators.go
similarity index 99%
rename from ddl/backfilling_operator.go
rename to ddl/backfilling_operators.go
index 8724f112e0ba3..cf02ecd1ea897 100644
--- a/ddl/backfilling_operator.go
+++ b/ddl/backfilling_operators.go
@@ -160,6 +160,7 @@ func NewWriteIndexToExternalStoragePipeline(
idxInfo *model.IndexInfo,
startKey, endKey kv.Key,
totalRowCount *atomic.Int64,
+ metricCounter prometheus.Counter,
onClose external.OnCloseFunc,
) (*operator.AsyncPipeline, error) {
index := tables.NewIndex(tbl.GetPhysicalID(), tbl.Meta(), idxInfo)
@@ -188,7 +189,7 @@ func NewWriteIndexToExternalStoragePipeline(
scanOp := NewTableScanOperator(ctx, sessPool, copCtx, srcChkPool, readerCnt)
writeOp := NewWriteExternalStoreOperator(
ctx, copCtx, sessPool, jobID, subtaskID, tbl, index, extStore, srcChkPool, writerCnt, onClose)
- sinkOp := newIndexWriteResultSink(ctx, nil, tbl, index, totalRowCount, nil)
+ sinkOp := newIndexWriteResultSink(ctx, nil, tbl, index, totalRowCount, metricCounter)
operator.Compose[TableScanTask](srcOp, scanOp)
operator.Compose[IndexRecordChunk](scanOp, writeOp)
@@ -471,12 +472,9 @@ func NewWriteExternalStoreOperator(
concurrency,
func() workerpool.Worker[IndexRecordChunk, IndexWriteResult] {
builder := external.NewWriterBuilder().
- SetOnCloseFunc(onClose)
- if index.Meta().Unique {
- builder = builder.EnableDuplicationDetection()
- }
+ SetOnCloseFunc(onClose).
+ SetKeyDuplicationEncoding(index.Meta().Unique)
writerID := uuid.New().String()
-
prefix := path.Join(strconv.Itoa(int(jobID)), strconv.Itoa(int(subtaskID)))
writer := builder.Build(store, prefix, writerID)
diff --git a/ddl/backfilling_read_index.go b/ddl/backfilling_read_index.go
index 1d86f9a83f737..2140c86b84aa4 100644
--- a/ddl/backfilling_read_index.go
+++ b/ddl/backfilling_read_index.go
@@ -58,6 +58,7 @@ type readIndexSummary struct {
totalSize uint64
dataFiles []string
statFiles []string
+ stats []external.MultipleFilesStat
mu sync.Mutex
}
@@ -84,14 +85,15 @@ func newReadIndexExecutor(
}
func (*readIndexExecutor) Init(_ context.Context) error {
- logutil.BgLogger().Info("read index stage init subtask exec env",
+ logutil.BgLogger().Info("read index executor init subtask exec env",
zap.String("category", "ddl"))
return nil
}
func (r *readIndexExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) error {
- logutil.BgLogger().Info("read index stage run subtask",
- zap.String("category", "ddl"))
+ logutil.BgLogger().Info("read index executor run subtask",
+ zap.String("category", "ddl"),
+ zap.Bool("use cloud", len(r.cloudStorageURI) > 0))
r.subtaskSummary.Store(subtask.ID, &readIndexSummary{})
@@ -147,7 +149,7 @@ func (r *readIndexExecutor) RunSubtask(ctx context.Context, subtask *proto.Subta
}
func (*readIndexExecutor) Cleanup(ctx context.Context) error {
- logutil.Logger(ctx).Info("read index stage cleanup subtask exec env",
+ logutil.Logger(ctx).Info("read index executor cleanup subtask exec env",
zap.String("category", "ddl"))
return nil
}
@@ -178,6 +180,7 @@ func (r *readIndexExecutor) OnFinished(ctx context.Context, subtask *proto.Subta
subtaskMeta.TotalKVSize = s.totalSize
subtaskMeta.DataFiles = s.dataFiles
subtaskMeta.StatFiles = s.statFiles
+ subtaskMeta.MultipleFilesStats = s.stats
logutil.Logger(ctx).Info("get key boundary on subtask finished",
zap.String("min", hex.EncodeToString(s.minKey)),
zap.String("max", hex.EncodeToString(s.maxKey)),
@@ -192,7 +195,7 @@ func (r *readIndexExecutor) OnFinished(ctx context.Context, subtask *proto.Subta
}
func (r *readIndexExecutor) Rollback(ctx context.Context) error {
- logutil.Logger(ctx).Info("read index stage rollback backfill add index task",
+ logutil.Logger(ctx).Info("read index executor rollback backfill add index task",
zap.String("category", "ddl"), zap.Int64("jobID", r.job.ID))
return nil
}
@@ -260,6 +263,7 @@ func (r *readIndexExecutor) buildExternalStorePipeline(
s.maxKey = summary.Max.Clone()
}
s.totalSize += summary.TotalSize
+ s.stats = append(s.stats, summary.MultipleFilesStats...)
for _, f := range summary.MultipleFilesStats {
for _, filename := range f.Filenames {
s.dataFiles = append(s.dataFiles, filename[0])
@@ -268,7 +272,9 @@ func (r *readIndexExecutor) buildExternalStorePipeline(
}
s.mu.Unlock()
}
+ counter := metrics.BackfillTotalCounter.WithLabelValues(
+ metrics.GenerateReorgLabel("add_idx_rate", r.job.SchemaName, tbl.Meta().Name.O))
return NewWriteIndexToExternalStoragePipeline(
opCtx, d.store, r.cloudStorageURI, r.d.sessPool, sessCtx, r.job.ID, subtaskID,
- tbl, r.index, start, end, totalRowCount, onClose)
+ tbl, r.index, start, end, totalRowCount, counter, onClose)
}
diff --git a/ddl/db_rename_test.go b/ddl/db_rename_test.go
index d2a33c2556734..0bbe0eaea7e3e 100644
--- a/ddl/db_rename_test.go
+++ b/ddl/db_rename_test.go
@@ -293,3 +293,15 @@ func TestRenameMultiTables(t *testing.T) {
tk.MustExec("drop database test1")
tk.MustExec("drop database test")
}
+
+func TestRenameMultiTablesIssue47064(t *testing.T) {
+ store := testkit.CreateMockStore(t, mockstore.WithDDLChecker())
+
+ tk := testkit.NewTestKit(t, store)
+ tk.MustExec("use test")
+ tk.MustExec("create table t1(a int)")
+ tk.MustExec("create table t2(a int)")
+ tk.MustExec("create database test1")
+ tk.MustExec("rename table test.t1 to test1.t1, test.t2 to test1.t2")
+ tk.MustQuery("select column_name from information_schema.columns where table_name = 't1'").Check(testkit.Rows("a"))
+}
diff --git a/ddl/ddl.go b/ddl/ddl.go
index 0b0e9619be899..b3e155bf746ed 100644
--- a/ddl/ddl.go
+++ b/ddl/ddl.go
@@ -154,6 +154,7 @@ const (
OnExistReplace
jobRecordCapacity = 16
+ jobOnceCapacity = 1000
)
var (
@@ -289,14 +290,14 @@ type waitSchemaSyncedController struct {
mu sync.RWMutex
job map[int64]struct{}
- // true if this node is elected to the DDL owner, we should wait 2 * lease before it runs the first DDL job.
- once *atomicutil.Bool
+ // Used to check whether the DDL job is on its first run on this owner.
+ onceMap map[int64]struct{}
}
func newWaitSchemaSyncedController() *waitSchemaSyncedController {
return &waitSchemaSyncedController{
- job: make(map[int64]struct{}, jobRecordCapacity),
- once: atomicutil.NewBool(true),
+ job: make(map[int64]struct{}, jobRecordCapacity),
+ onceMap: make(map[int64]struct{}, jobOnceCapacity),
}
}
@@ -319,6 +320,25 @@ func (w *waitSchemaSyncedController) synced(job *model.Job) {
delete(w.job, job.ID)
}
+// maybeAlreadyRunOnce returns true if the job has already run at least once on this owner.
+// Returning false means the job may be running for the first time on this owner.
+func (w *waitSchemaSyncedController) maybeAlreadyRunOnce(id int64) bool {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ _, ok := w.onceMap[id]
+ return ok
+}
+
+func (w *waitSchemaSyncedController) setAlreadyRunOnce(id int64) {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ if len(w.onceMap) > jobOnceCapacity {
+ // If the map is too large, we reset it. The evicted jobs may need to check whether the schema is synced again, but that is OK.
+ w.onceMap = make(map[int64]struct{}, jobRecordCapacity)
+ }
+ w.onceMap[id] = struct{}{}
+}
+
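
Taken together, the two helpers implement a capacity-bounded "seen" set: the dispatcher checks maybeAlreadyRunOnce before deciding whether to wait for schema sync, and marks the job with setAlreadyRunOnce afterwards (see the job_table.go hunk later in this diff). A minimal self-contained sketch of the same pattern, with illustrative names:

package main

import "fmt"

// seenJobs is a capacity-bounded set of job IDs, mirroring the onceMap idea:
// membership means the job was already handled once; eviction only costs an
// extra schema-sync check later.
type seenJobs struct {
	capacity int
	ids      map[int64]struct{}
}

func newSeenJobs(capacity int) *seenJobs {
	return &seenJobs{capacity: capacity, ids: make(map[int64]struct{}, capacity)}
}

func (s *seenJobs) seen(id int64) bool {
	_, ok := s.ids[id]
	return ok
}

func (s *seenJobs) mark(id int64) {
	if len(s.ids) > s.capacity {
		// Reset instead of evicting one entry; simple and good enough here.
		s.ids = make(map[int64]struct{}, s.capacity)
	}
	s.ids[id] = struct{}{}
}

func main() {
	s := newSeenJobs(2)
	for _, id := range []int64{1, 2, 3, 1} {
		if !s.seen(id) {
			fmt.Printf("job %d: wait for schema sync, then mark\n", id)
			s.mark(id)
		} else {
			fmt.Printf("job %d: already handled on this owner, skip waiting\n", id)
		}
	}
}
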
// ddlCtx is the context when we use worker to handle DDL jobs.
type ddlCtx struct {
ctx context.Context
diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go
index efb41f6285751..aacb4e1e76aaa 100644
--- a/ddl/ddl_api.go
+++ b/ddl/ddl_api.go
@@ -7158,6 +7158,12 @@ func BuildHiddenColumnInfo(ctx sessionctx.Context, indexPartSpecifications []*as
colInfo.FieldType.SetDecimal(types.MaxFsp)
}
}
+ // For an array, the collation is set to "binary". The collation has no effect on the array itself (as it's usually
+ // regarded as JSON), but it influences how TiKV handles the index value.
+ if colInfo.FieldType.IsArray() {
+ colInfo.SetCharset("binary")
+ colInfo.SetCollate("binary")
+ }
checkDependencies := make(map[string]struct{})
for _, colName := range FindColumnNamesInExpr(idxPart.Expr) {
colInfo.Dependences[colName.Name.L] = struct{}{}
diff --git a/ddl/ddl_worker.go b/ddl/ddl_worker.go
index e99601a6d81d7..dbd519b79f116 100644
--- a/ddl/ddl_worker.go
+++ b/ddl/ddl_worker.go
@@ -835,17 +835,14 @@ func (w *worker) HandleDDLJobTable(d *ddlCtx, job *model.Job) (int64, error) {
}
w.registerSync(job)
- if runJobErr != nil {
+ if runJobErr != nil && !dbterror.ErrPausedDDLJob.Equal(runJobErr) {
// Omit the ErrPausedDDLJob
- if !dbterror.ErrPausedDDLJob.Equal(runJobErr) {
+ w.jobLogger(job).Info("run DDL job failed, sleeps a while then retries it.",
+ zap.Duration("waitTime", GetWaitTimeWhenErrorOccurred()), zap.Error(runJobErr))
+ // In tests, if the job is being cancelled, we can skip the sleep.
+ if !(intest.InTest && job.IsCancelling()) {
// wait a while to retry again. If we don't wait here, DDL will retry this job immediately,
// which may act like a deadlock.
- w.jobLogger(job).Info("run DDL job failed, sleeps a while then retries it.",
- zap.Duration("waitTime", GetWaitTimeWhenErrorOccurred()), zap.Error(runJobErr))
- }
-
- // In test and job is cancelling we can ignore the sleep
- if !(intest.InTest && job.IsCancelling()) {
time.Sleep(GetWaitTimeWhenErrorOccurred())
}
}
@@ -1190,14 +1187,14 @@ func toTError(err error) *terror.Error {
return dbterror.ClassDDL.Synthesize(terror.CodeUnknown, err.Error())
}
-// waitSchemaChanged waits for the completion of updating all servers' schema. In order to make sure that happens,
+// waitSchemaChanged waits until all servers' schemas are updated or the MDL is synced. In order to make sure that happens,
// we wait at most 2 * lease time(sessionTTL, 90 seconds).
-func waitSchemaChanged(d *ddlCtx, waitTime time.Duration, latestSchemaVersion int64, job *model.Job) {
+func waitSchemaChanged(d *ddlCtx, waitTime time.Duration, latestSchemaVersion int64, job *model.Job) error {
if !job.IsRunning() && !job.IsRollingback() && !job.IsDone() && !job.IsRollbackDone() {
- return
+ return nil
}
if waitTime == 0 {
- return
+ return nil
}
timeStart := time.Now()
@@ -1208,16 +1205,19 @@ func waitSchemaChanged(d *ddlCtx, waitTime time.Duration, latestSchemaVersion in
if latestSchemaVersion == 0 {
logutil.Logger(d.ctx).Info("schema version doesn't change", zap.String("category", "ddl"))
- return
+ return nil
}
err = d.schemaSyncer.OwnerUpdateGlobalVersion(d.ctx, latestSchemaVersion)
if err != nil {
logutil.Logger(d.ctx).Info("update latest schema version failed", zap.String("category", "ddl"), zap.Int64("ver", latestSchemaVersion), zap.Error(err))
+ if variable.EnableMDL.Load() {
+ return err
+ }
if terror.ErrorEqual(err, context.DeadlineExceeded) {
// If err is context.DeadlineExceeded, it means waitTime (2 * lease) has elapsed, so all the schemas are synced by the ticker.
// There is no need to use etcd to sync. The function returns directly.
- return
+ return nil
}
}
@@ -1225,12 +1225,13 @@ func waitSchemaChanged(d *ddlCtx, waitTime time.Duration, latestSchemaVersion in
err = d.schemaSyncer.OwnerCheckAllVersions(d.ctx, job.ID, latestSchemaVersion)
if err != nil {
logutil.Logger(d.ctx).Info("wait latest schema version encounter error", zap.String("category", "ddl"), zap.Int64("ver", latestSchemaVersion), zap.Error(err))
- return
+ return err
}
logutil.Logger(d.ctx).Info("wait latest schema version changed(get the metadata lock if tidb_enable_metadata_lock is true)", zap.String("category", "ddl"),
zap.Int64("ver", latestSchemaVersion),
zap.Duration("take time", time.Since(timeStart)),
zap.String("job", job.String()))
+ return nil
}
// waitSchemaSyncedForMDL is like waitSchemaSynced, but it waits to get the metadata lock of the latest version of this DDL.
@@ -1289,8 +1290,7 @@ func waitSchemaSynced(d *ddlCtx, job *model.Job, waitTime time.Duration) error {
}
})
- waitSchemaChanged(d, waitTime, latestSchemaVersion, job)
- return nil
+ return waitSchemaChanged(d, waitTime, latestSchemaVersion, job)
}
func buildPlacementAffects(oldIDs []int64, newIDs []int64) []*model.AffectedOption {
@@ -1377,9 +1377,13 @@ func updateSchemaVersion(d *ddlCtx, t *meta.Meta, job *model.Job, multiInfos ...
if err != nil {
return 0, errors.Trace(err)
}
- affects := make([]*model.AffectedOption, len(newSchemaIDs))
+ affects := make([]*model.AffectedOption, len(newSchemaIDs)-1)
for i, newSchemaID := range newSchemaIDs {
- affects[i] = &model.AffectedOption{
+ // Do not add the first table to AffectedOpts. Related issue tidb#47064.
+ if i == 0 {
+ continue
+ }
+ affects[i-1] = &model.AffectedOption{
SchemaID: newSchemaID,
TableID: tableIDs[i],
OldTableID: tableIDs[i],
@@ -1420,6 +1424,10 @@ func updateSchemaVersion(d *ddlCtx, t *meta.Meta, job *model.Job, multiInfos ...
// Keep this as the schema ID of the non-partitioned table
// to avoid triggering an early rename in TiFlash
diff.AffectedOpts[0].SchemaID = job.SchemaID
+ // Need to reload the partition table; use diff.AffectedOpts[0].OldSchemaID to mark it.
+ if len(multiInfos) > 0 {
+ diff.AffectedOpts[0].OldSchemaID = ptSchemaID
+ }
} else {
// Swap
diff.TableID = ptDefID
diff --git a/ddl/index.go b/ddl/index.go
index 1bb5fb6747ad6..95ce17f6ad1fe 100644
--- a/ddl/index.go
+++ b/ddl/index.go
@@ -1959,7 +1959,7 @@ func (w *worker) updateJobRowCount(taskKey string, jobID int64) {
logutil.BgLogger().Warn("cannot get global task", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
return
}
- rowCount, err := taskMgr.GetSubtaskRowCount(gTask.ID, proto.StepInit)
+ rowCount, err := taskMgr.GetSubtaskRowCount(gTask.ID, proto.StepOne)
if err != nil {
logutil.BgLogger().Warn("cannot get subtask row count", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
return
@@ -2206,8 +2206,9 @@ func (w *worker) updateReorgInfoForPartitions(t table.PartitionedTable, reorg *r
if i == len(partitionIDs)-1 {
return true, nil
}
+ pid = partitionIDs[i+1]
+ break
}
- pid = partitionIDs[i+1]
}
currentVer, err := getValidCurrentVersion(reorg.d.store)
diff --git a/ddl/job_table.go b/ddl/job_table.go
index e42b377737f86..b3ea152d694ed 100644
--- a/ddl/job_table.go
+++ b/ddl/job_table.go
@@ -277,7 +277,7 @@ func (d *ddl) startDispatchLoop() {
}
if !d.isOwner() {
isOnce = true
- d.once.Store(true)
+ d.onceMap = make(map[int64]struct{}, jobOnceCapacity)
time.Sleep(dispatchLoopWaitingDuration)
continue
}
@@ -378,7 +378,7 @@ func (d *ddl) delivery2worker(wk *worker, pool *workerPool, job *model.Job) {
metrics.DDLRunningJobCount.WithLabelValues(pool.tp().String()).Dec()
}()
// check if this ddl job is synced to all servers.
- if !d.isSynced(job) || d.once.Load() {
+ if !job.NotStarted() && (!d.isSynced(job) || !d.maybeAlreadyRunOnce(job.ID)) {
if variable.EnableMDL.Load() {
exist, version, err := checkMDLInfo(job.ID, d.sessPool)
if err != nil {
@@ -393,7 +393,7 @@ func (d *ddl) delivery2worker(wk *worker, pool *workerPool, job *model.Job) {
if err != nil {
return
}
- d.once.Store(false)
+ d.setAlreadyRunOnce(job.ID)
cleanMDLInfo(d.sessPool, job.ID, d.etcdCli)
// Don't have a worker now.
return
@@ -407,7 +407,7 @@ func (d *ddl) delivery2worker(wk *worker, pool *workerPool, job *model.Job) {
pool.put(wk)
return
}
- d.once.Store(false)
+ d.setAlreadyRunOnce(job.ID)
}
}
@@ -426,9 +426,14 @@ func (d *ddl) delivery2worker(wk *worker, pool *workerPool, job *model.Job) {
})
// Here means the job enters another state (delete only, write only, public, etc...) or is cancelled.
- // If the job is done or still running or rolling back, we will wait 2 * lease time to guarantee other servers to update
+ // If the job is done or still running or rolling back, we will wait 2 * lease time or until the MDL is synced to guarantee that other servers apply
// the newest schema.
- waitSchemaChanged(d.ddlCtx, d.lease*2, schemaVer, job)
+ err := waitSchemaChanged(d.ddlCtx, d.lease*2, schemaVer, job)
+ if err != nil {
+ // May be caused by the server closing; we shouldn't clean the MDL info.
+ logutil.BgLogger().Info("wait latest schema version error", zap.String("category", "ddl"), zap.Error(err))
+ return
+ }
cleanMDLInfo(d.sessPool, job.ID, d.etcdCli)
d.synced(job)
diff --git a/ddl/partition.go b/ddl/partition.go
index 28622385532f2..8d70224656e6d 100644
--- a/ddl/partition.go
+++ b/ddl/partition.go
@@ -2458,16 +2458,27 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo
return ver, errors.Trace(err)
}
}
+ var ptInfo []schemaIDAndTableInfo
+ if len(nt.Constraints) > 0 {
+ pt.ExchangePartitionInfo = &model.ExchangePartitionInfo{
+ ExchangePartitionTableID: nt.ID,
+ ExchangePartitionDefID: defID,
+ }
+ ptInfo = append(ptInfo, schemaIDAndTableInfo{
+ schemaID: ptSchemaID,
+ tblInfo: pt,
+ })
+ }
nt.ExchangePartitionInfo = &model.ExchangePartitionInfo{
- ExchangePartitionID: ptID,
- ExchangePartitionDefID: defID,
+ ExchangePartitionTableID: ptID,
+ ExchangePartitionDefID: defID,
}
// We need an interim schema version,
// so there are no non-matching rows inserted
// into the table using the schema version
// before the exchange is made.
job.SchemaState = model.StateWriteOnly
- return updateVersionAndTableInfoWithCheck(d, t, job, nt, true)
+ return updateVersionAndTableInfoWithCheck(d, t, job, nt, true, ptInfo...)
}
// From now on, nt (the non-partitioned table) has
// ExchangePartitionInfo set, meaning it is restricted
@@ -2527,6 +2538,7 @@ func (w *worker) onExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Jo
}
// exchange table meta id
+ pt.ExchangePartitionInfo = nil
partDef.ID, nt.ID = nt.ID, partDef.ID
err = t.UpdateTable(ptSchemaID, pt)
diff --git a/ddl/rollingback.go b/ddl/rollingback.go
index 47e1a3735d5cb..c72be206529e4 100644
--- a/ddl/rollingback.go
+++ b/ddl/rollingback.go
@@ -264,11 +264,35 @@ func needNotifyAndStopReorgWorker(job *model.Job) bool {
// rollbackExchangeTablePartition will clear the non-partitioned
// table's ExchangePartitionInfo state.
-func rollbackExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job, tblInfo *model.TableInfo) (int64, error) {
+func rollbackExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job, tblInfo *model.TableInfo) (ver int64, err error) {
tblInfo.ExchangePartitionInfo = nil
job.State = model.JobStateRollbackDone
job.SchemaState = model.StatePublic
- return updateVersionAndTableInfo(d, t, job, tblInfo, true)
+ if len(tblInfo.Constraints) == 0 {
+ return updateVersionAndTableInfo(d, t, job, tblInfo, true)
+ }
+ var (
+ defID int64
+ ptSchemaID int64
+ ptID int64
+ partName string
+ withValidation bool
+ )
+ if err = job.DecodeArgs(&defID, &ptSchemaID, &ptID, &partName, &withValidation); err != nil {
+ return ver, errors.Trace(err)
+ }
+ pt, err := getTableInfo(t, ptID, ptSchemaID)
+ if err != nil {
+ return ver, errors.Trace(err)
+ }
+ pt.ExchangePartitionInfo = nil
+ var ptInfo []schemaIDAndTableInfo
+ ptInfo = append(ptInfo, schemaIDAndTableInfo{
+ schemaID: ptSchemaID,
+ tblInfo: pt,
+ })
+ ver, err = updateVersionAndTableInfo(d, t, job, tblInfo, true, ptInfo...)
+ return ver, errors.Trace(err)
}
func rollingbackExchangeTablePartition(d *ddlCtx, t *meta.Meta, job *model.Job) (ver int64, err error) {
diff --git a/ddl/tests/fk/foreign_key_test.go b/ddl/tests/fk/foreign_key_test.go
index 8efbc8d6abba7..f2744097cf096 100644
--- a/ddl/tests/fk/foreign_key_test.go
+++ b/ddl/tests/fk/foreign_key_test.go
@@ -1566,7 +1566,7 @@ func TestRenameTablesWithForeignKey(t *testing.T) {
// check the schema diff
diff := getLatestSchemaDiff(t, tk)
require.Equal(t, model.ActionRenameTables, diff.Type)
- require.Equal(t, 3, len(diff.AffectedOpts))
+ require.Equal(t, 2, len(diff.AffectedOpts))
// check referred foreign key information.
t1ReferredFKs := getTableInfoReferredForeignKeys(t, dom, "test", "t1")
diff --git a/ddl/tests/multivaluedindex/BUILD.bazel b/ddl/tests/multivaluedindex/BUILD.bazel
new file mode 100644
index 0000000000000..25df7f9079d69
--- /dev/null
+++ b/ddl/tests/multivaluedindex/BUILD.bazel
@@ -0,0 +1,19 @@
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
+go_test(
+ name = "multivaluedindex_test",
+ timeout = "short",
+ srcs = [
+ "main_test.go",
+ "multi_valued_index_test.go",
+ ],
+ flaky = True,
+ deps = [
+ "//infoschema",
+ "//parser/model",
+ "//testkit",
+ "//testkit/testsetup",
+ "@com_github_stretchr_testify//require",
+ "@org_uber_go_goleak//:goleak",
+ ],
+)
diff --git a/ddl/tests/multivaluedindex/main_test.go b/ddl/tests/multivaluedindex/main_test.go
new file mode 100644
index 0000000000000..17eda6ca0900b
--- /dev/null
+++ b/ddl/tests/multivaluedindex/main_test.go
@@ -0,0 +1,35 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package multivaluedindex
+
+import (
+ "testing"
+
+ "github.com/pingcap/tidb/testkit/testsetup"
+ "go.uber.org/goleak"
+)
+
+func TestMain(m *testing.M) {
+ testsetup.SetupForCommonTest()
+
+ opts := []goleak.Option{
+ goleak.IgnoreTopFunction("github.com/golang/glog.(*fileSink).flushDaemon"),
+ goleak.IgnoreTopFunction("github.com/lestrrat-go/httprc.runFetchWorker"),
+ goleak.IgnoreTopFunction("go.etcd.io/etcd/client/pkg/v3/logutil.(*MergeLogger).outputLoop"),
+ goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
+ }
+
+ goleak.VerifyTestMain(m, opts...)
+}
diff --git a/ddl/tests/multivaluedindex/multi_valued_index_test.go b/ddl/tests/multivaluedindex/multi_valued_index_test.go
new file mode 100644
index 0000000000000..6442c8df40445
--- /dev/null
+++ b/ddl/tests/multivaluedindex/multi_valued_index_test.go
@@ -0,0 +1,47 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package multivaluedindex
+
+import (
+ "testing"
+
+ "github.com/pingcap/tidb/infoschema"
+ "github.com/pingcap/tidb/parser/model"
+ "github.com/pingcap/tidb/testkit"
+ "github.com/stretchr/testify/require"
+)
+
+func TestCreateMultiValuedIndexHasBinaryCollation(t *testing.T) {
+ store := testkit.CreateMockStore(t)
+ tk := testkit.NewTestKit(t, store)
+
+ tk.MustExec("create table test.t (pk varchar(4) primary key clustered, j json, str varchar(255), value int, key idx((cast(j as char(100) array)), str));")
+ is := tk.Session().GetDomainInfoSchema().(infoschema.InfoSchema)
+ require.NotNil(t, is)
+
+ tbl, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t"))
+ require.NoError(t, err)
+
+ foundIndex := false
+ for _, c := range tbl.Cols() {
+ if c.Hidden {
+ foundIndex = true
+ require.True(t, c.FieldType.IsArray())
+ require.Equal(t, c.FieldType.GetCharset(), "binary")
+ require.Equal(t, c.FieldType.GetCollate(), "binary")
+ }
+ }
+ require.True(t, foundIndex)
+}
diff --git a/ddl/tests/partition/db_partition_test.go b/ddl/tests/partition/db_partition_test.go
index 8ca4877989cd9..93caa59015fec 100644
--- a/ddl/tests/partition/db_partition_test.go
+++ b/ddl/tests/partition/db_partition_test.go
@@ -2373,6 +2373,44 @@ func TestDropPartitionWithGlobalIndex(t *testing.T) {
require.Equal(t, 2, cnt)
}
+func TestDropMultiPartitionWithGlobalIndex(t *testing.T) {
+ defer config.RestoreFunc()()
+ config.UpdateGlobal(func(conf *config.Config) {
+ conf.EnableGlobalIndex = true
+ })
+ store := testkit.CreateMockStore(t)
+ tk := testkit.NewTestKit(t, store)
+ tk.MustExec("use test")
+ tk.MustExec("drop table if exists test_global")
+ tk.MustExec(`create table test_global ( a int, b int, c int)
+ partition by range( a ) (
+ partition p1 values less than (10),
+ partition p2 values less than (20),
+ partition p3 values less than (30)
+ );`)
+ tt := external.GetTableByName(t, tk, "test", "test_global")
+ pid := tt.Meta().Partition.Definitions[1].ID
+
+ tk.MustExec("Alter Table test_global Add Unique Index idx_b (b);")
+ tk.MustExec("Alter Table test_global Add Unique Index idx_c (c);")
+ tk.MustExec(`INSERT INTO test_global VALUES (1, 1, 1), (2, 2, 2), (11, 3, 3), (12, 4, 4), (21, 21, 21), (29, 29, 29)`)
+
+ tk.MustExec("alter table test_global drop partition p1, p2;")
+ result := tk.MustQuery("select * from test_global;")
+ result.Sort().Check(testkit.Rows("21 21 21", "29 29 29"))
+
+ tt = external.GetTableByName(t, tk, "test", "test_global")
+ idxInfo := tt.Meta().FindIndexByName("idx_b")
+ require.NotNil(t, idxInfo)
+ cnt := checkGlobalIndexCleanUpDone(t, tk.Session(), tt.Meta(), idxInfo, pid)
+ require.Equal(t, 2, cnt)
+
+ idxInfo = tt.Meta().FindIndexByName("idx_c")
+ require.NotNil(t, idxInfo)
+ cnt = checkGlobalIndexCleanUpDone(t, tk.Session(), tt.Meta(), idxInfo, pid)
+ require.Equal(t, 2, cnt)
+}
+
func TestGlobalIndexInsertInDropPartition(t *testing.T) {
defer config.RestoreFunc()()
config.UpdateGlobal(func(conf *config.Config) {
diff --git a/distsql/OWNERS b/distsql/OWNERS
new file mode 100644
index 0000000000000..11ce88846bb13
--- /dev/null
+++ b/distsql/OWNERS
@@ -0,0 +1,5 @@
+# See the OWNERS docs at https://go.k8s.io/owners
+options:
+ no_parent_owners: true
+approvers:
+ - sig-approvers-distsql
diff --git a/distsql/request_builder.go b/distsql/request_builder.go
index d165edc8edc74..981eefe11711a 100644
--- a/distsql/request_builder.go
+++ b/distsql/request_builder.go
@@ -310,7 +310,7 @@ func (builder *RequestBuilder) SetFromSessionVars(sv *variable.SessionVars) *Req
builder.Request.ResourceGroupName = sv.ResourceGroupName
builder.Request.StoreBusyThreshold = sv.LoadBasedReplicaReadThreshold
builder.Request.RunawayChecker = sv.StmtCtx.RunawayChecker
- builder.Request.TidbKvReadTimeout = sv.GetTidbKvReadTimeout()
+ builder.Request.TiKVClientReadTimeout = sv.GetTiKVClientReadTimeout()
return builder
}
@@ -488,7 +488,7 @@ func encodeHandleKey(ran *ranger.Range) ([]byte, []byte) {
// 1. signedRanges is less or equal than MaxInt64
// 2. unsignedRanges is greater than MaxInt64
//
-// We do this because every key of tikv is encoded as an int64. As a result, MaxUInt64 is small than zero when
+// We do this because every key of tikv is encoded as an int64. As a result, MaxUInt64 is smaller than zero when
// interpreted as an int64 variable.
//
// This function does the following:
@@ -496,7 +496,7 @@ func encodeHandleKey(ran *ranger.Range) ([]byte, []byte) {
// 2. if there's a range that straddles the int64 boundary, split it into two ranges, which results in one smaller and
// one greater than MaxInt64.
//
-// if `KeepOrder` is false, we merge the two groups of ranges into one group, to save an rpc call later
+// if `KeepOrder` is false, we merge the two groups of ranges into one group, to save a rpc call later
// if `desc` is false, return signed ranges first, vice versa.
func SplitRangesAcrossInt64Boundary(ranges []*ranger.Range, keepOrder bool, desc bool, isCommonHandle bool) ([]*ranger.Range, []*ranger.Range) {
if isCommonHandle || len(ranges) == 0 || ranges[0].LowVal[0].Kind() == types.KindInt64 {
diff --git a/distsql/request_builder_test.go b/distsql/request_builder_test.go
index da10f5827156a..ad51f661e5ee8 100644
--- a/distsql/request_builder_test.go
+++ b/distsql/request_builder_test.go
@@ -603,26 +603,26 @@ func TestRequestBuilder8(t *testing.T) {
require.Equal(t, expect, actual)
}
-func TestRequestBuilderTidbKvReadTimeout(t *testing.T) {
+func TestRequestBuilderTiKVClientReadTimeout(t *testing.T) {
sv := variable.NewSessionVars(nil)
- sv.TidbKvReadTimeout = 100
+ sv.TiKVClientReadTimeout = 100
actual, err := (&RequestBuilder{}).
SetFromSessionVars(sv).
Build()
require.NoError(t, err)
expect := &kv.Request{
- Tp: 0,
- StartTs: 0x0,
- Data: []uint8(nil),
- KeyRanges: kv.NewNonParitionedKeyRanges(nil),
- Concurrency: variable.DefDistSQLScanConcurrency,
- IsolationLevel: 0,
- Priority: 0,
- MemTracker: (*memory.Tracker)(nil),
- SchemaVar: 0,
- ReadReplicaScope: kv.GlobalReplicaScope,
- TidbKvReadTimeout: 100,
- ResourceGroupName: resourcegroup.DefaultResourceGroupName,
+ Tp: 0,
+ StartTs: 0x0,
+ Data: []uint8(nil),
+ KeyRanges: kv.NewNonParitionedKeyRanges(nil),
+ Concurrency: variable.DefDistSQLScanConcurrency,
+ IsolationLevel: 0,
+ Priority: 0,
+ MemTracker: (*memory.Tracker)(nil),
+ SchemaVar: 0,
+ ReadReplicaScope: kv.GlobalReplicaScope,
+ TiKVClientReadTimeout: 100,
+ ResourceGroupName: resourcegroup.DefaultResourceGroupName,
}
expect.Paging.MinPagingSize = paging.MinPagingSize
expect.Paging.MaxPagingSize = paging.MaxPagingSize
diff --git a/disttask/framework/BUILD.bazel b/disttask/framework/BUILD.bazel
index 574b0524e5a77..5ef3fd327004d 100644
--- a/disttask/framework/BUILD.bazel
+++ b/disttask/framework/BUILD.bazel
@@ -12,7 +12,7 @@ go_test(
],
flaky = True,
race = "off",
- shard_count = 27,
+ shard_count = 29,
deps = [
"//disttask/framework/dispatcher",
"//disttask/framework/mock",
diff --git a/disttask/framework/dispatcher/BUILD.bazel b/disttask/framework/dispatcher/BUILD.bazel
index a4c4aedf87845..edc610b3d2f1c 100644
--- a/disttask/framework/dispatcher/BUILD.bazel
+++ b/disttask/framework/dispatcher/BUILD.bazel
@@ -37,8 +37,9 @@ go_test(
embed = [":dispatcher"],
flaky = True,
race = "off",
- shard_count = 8,
+ shard_count = 11,
deps = [
+ "//disttask/framework/mock",
"//disttask/framework/proto",
"//disttask/framework/storage",
"//domain/infosync",
@@ -52,5 +53,6 @@ go_test(
"@com_github_stretchr_testify//require",
"@com_github_tikv_client_go_v2//util",
"@org_uber_go_goleak//:goleak",
+ "@org_uber_go_mock//gomock",
],
)
diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go
index 4c3272c15c88b..9d15de58550b7 100644
--- a/disttask/framework/dispatcher/dispatcher.go
+++ b/disttask/framework/dispatcher/dispatcher.go
@@ -57,19 +57,23 @@ var (
// TaskHandle provides the interface for operations needed by Dispatcher.
// Then we can use dispatcher's function in Dispatcher interface.
type TaskHandle interface {
+ // GetPreviousSchedulerIDs gets previous scheduler IDs.
+ GetPreviousSchedulerIDs(_ context.Context, taskID int64, step int64) ([]string, error)
// GetPreviousSubtaskMetas gets previous subtask metas.
GetPreviousSubtaskMetas(taskID int64, step int64) ([][]byte, error)
- // UpdateTask update the task in tidb_global_task table.
- UpdateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) error
storage.SessionExecutor
}
// Dispatcher manages the lifetime of a task
// including submitting subtasks and updating the status of a task.
type Dispatcher interface {
+ // Init initializes the dispatcher; it should be called before ExecuteTask.
+ // If Init returns an error, the dispatcher manager will fail the task directly,
+ // so the returned error should be a fatal error.
+ Init() error
+ // ExecuteTask starts to schedule a task.
ExecuteTask()
- // Close closes the dispatcher, not routine-safe, and should be called
- // after ExecuteTask finished.
+ // Close closes the dispatcher; it should only be called if Init returns nil.
Close()
}
@@ -116,7 +120,12 @@ func NewBaseDispatcher(ctx context.Context, taskMgr *storage.TaskManager, server
}
}
-// ExecuteTask start to schedule a task.
+// Init implements the Dispatcher interface.
+func (*BaseDispatcher) Init() error {
+ return nil
+}
+
+// ExecuteTask implements the Dispatcher interface.
func (d *BaseDispatcher) ExecuteTask() {
logutil.Logger(d.logCtx).Info("execute one task",
zap.String("state", d.Task.State), zap.Uint64("concurrency", d.Task.Concurrency))
@@ -206,7 +215,7 @@ func (d *BaseDispatcher) onReverting() error {
if cnt == 0 {
// Finish the rollback step.
logutil.Logger(d.logCtx).Info("all reverting tasks finished, update the task to reverted state")
- return d.UpdateTask(proto.TaskStateReverted, nil, RetrySQLTimes)
+ return d.updateTask(proto.TaskStateReverted, nil, RetrySQLTimes)
}
// Wait all subtasks in this stage finished.
d.OnTick(d.ctx, d.Task)
@@ -241,18 +250,6 @@ func (d *BaseDispatcher) onRunning() error {
}
if cnt == 0 {
- logutil.Logger(d.logCtx).Info("previous subtasks finished, generate dist plan", zap.Int64("stage", d.Task.Step))
- // When all subtasks dispatched and processed, mark task as succeed.
- if d.Finished(d.Task) {
- d.Task.StateUpdateTime = time.Now().UTC()
- logutil.Logger(d.logCtx).Info("all subtasks dispatched and processed, finish the task")
- err := d.UpdateTask(proto.TaskStateSucceed, nil, RetrySQLTimes)
- if err != nil {
- logutil.Logger(d.logCtx).Warn("update task failed", zap.Error(err))
- return err
- }
- return nil
- }
return d.onNextStage()
}
// Check if any node are down.
@@ -328,29 +325,8 @@ func (d *BaseDispatcher) replaceDeadNodesIfAny() error {
return nil
}
-func (d *BaseDispatcher) addSubtasks(subtasks []*proto.Subtask) (err error) {
- for i := 0; i < RetrySQLTimes; i++ {
- err = d.taskMgr.AddSubTasks(d.Task, subtasks)
- if err == nil {
- break
- }
- if i%10 == 0 {
- logutil.Logger(d.logCtx).Warn("addSubtasks failed", zap.String("state", d.Task.State), zap.Int64("step", d.Task.Step),
- zap.Int("subtask cnt", len(subtasks)),
- zap.Int("retry times", i), zap.Error(err))
- }
- time.Sleep(RetrySQLInterval)
- }
- if err != nil {
- logutil.Logger(d.logCtx).Warn("addSubtasks failed", zap.String("state", d.Task.State), zap.Int64("step", d.Task.Step),
- zap.Int("subtask cnt", len(subtasks)),
- zap.Int("retry times", RetrySQLTimes), zap.Error(err))
- }
- return err
-}
-
-// UpdateTask update the task in tidb_global_task table.
-func (d *BaseDispatcher) UpdateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) {
+// updateTask update the task in tidb_global_task table.
+func (d *BaseDispatcher) updateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) {
prevState := d.Task.State
d.Task.State = taskState
if !VerifyTaskStateTransform(prevState, taskState) {
@@ -404,17 +380,29 @@ func (d *BaseDispatcher) dispatchSubTask4Revert(meta []byte) error {
subTasks := make([]*proto.Subtask, 0, len(instanceIDs))
for _, id := range instanceIDs {
- subTasks = append(subTasks, proto.NewSubtask(d.Task.ID, d.Task.Type, id, meta))
+ // reverting subtasks belong to the same step as the current active step.
+ subTasks = append(subTasks, proto.NewSubtask(d.Task.Step, d.Task.ID, d.Task.Type, id, meta))
}
- return d.UpdateTask(proto.TaskStateReverting, subTasks, RetrySQLTimes)
+ return d.updateTask(proto.TaskStateReverting, subTasks, RetrySQLTimes)
+}
+
+func (*BaseDispatcher) nextStepSubtaskDispatched(*proto.Task) bool {
+ // TODO: implement it when we support dispatching subtasks in batches,
+ // since subtask metas might be too large to save in one transaction.
+ return true
}
-func (d *BaseDispatcher) onNextStage() error {
+func (d *BaseDispatcher) onNextStage() (err error) {
/// dynamic dispatch subtasks.
failpoint.Inject("mockDynamicDispatchErr", func() {
failpoint.Return(errors.New("mockDynamicDispatchErr"))
})
+ nextStep := d.GetNextStep(d, d.Task)
+ logutil.Logger(d.logCtx).Info("onNextStage",
+ zap.Int64("current-step", d.Task.Step),
+ zap.Int64("next-step", nextStep))
+
// 1. Adjust the global task's concurrency.
if d.Task.State == proto.TaskStatePending {
if d.Task.Concurrency == 0 {
@@ -423,24 +411,40 @@ func (d *BaseDispatcher) onNextStage() error {
if d.Task.Concurrency > MaxSubtaskConcurrency {
d.Task.Concurrency = MaxSubtaskConcurrency
}
- d.Task.StateUpdateTime = time.Now().UTC()
- if err := d.UpdateTask(proto.TaskStateRunning, nil, RetrySQLTimes); err != nil {
- return err
- }
- } else if d.StageFinished(d.Task) {
- // 2. when previous stage finished, update to next stage.
- d.Task.Step++
- logutil.Logger(d.logCtx).Info("previous stage finished, run into next stage", zap.Int64("from", d.Task.Step-1), zap.Int64("to", d.Task.Step))
- d.Task.StateUpdateTime = time.Now().UTC()
- err := d.UpdateTask(proto.TaskStateRunning, nil, RetrySQLTimes)
+ }
+ defer func() {
if err != nil {
- return err
+ return
}
- }
+ // invariant: task.Step always means the most recent step for which all
+ // corresponding subtasks have been saved to the system table.
+ //
+ // when all subtasks of task.Step are finished, we call OnNextSubtasksBatch
+ // to generate subtasks of the next step. after all subtasks of the next step
+ // are saved to the system table, we update task.Step to the next step, so the
+ // invariant holds.
+ // see nextStepSubtaskDispatched for why we don't update task and subtasks
+ // in a single transaction.
+ if d.nextStepSubtaskDispatched(d.Task) {
+ currStep := d.Task.Step
+ d.Task.Step = nextStep
+ // When all subtasks dispatched and processed, mark task as succeed.
+ taskState := proto.TaskStateRunning
+ if d.Task.Step == proto.StepDone {
+ taskState = proto.TaskStateSucceed
+ logutil.Logger(d.logCtx).Info("all subtasks dispatched and processed, finish the task")
+ } else {
+ logutil.Logger(d.logCtx).Info("move to next stage",
+ zap.Int64("from", currStep), zap.Int64("to", d.Task.Step))
+ }
+ d.Task.StateUpdateTime = time.Now().UTC()
+ err = d.updateTask(taskState, nil, RetrySQLTimes)
+ }
+ }()
for {
// 3. generate a batch of subtasks.
- metas, err := d.OnNextSubtasksBatch(d.ctx, d, d.Task)
+ metas, err := d.OnNextSubtasksBatch(d.ctx, d, d.Task, nextStep)
if err != nil {
logutil.Logger(d.logCtx).Warn("generate part of subtasks failed", zap.Error(err))
return d.handlePlanErr(err)
@@ -451,12 +455,12 @@ func (d *BaseDispatcher) onNextStage() error {
})
// 4. dispatch batch of subtasks to EligibleInstances.
- err = d.dispatchSubTask(metas)
+ err = d.dispatchSubTask(nextStep, metas)
if err != nil {
return err
}
- if d.StageFinished(d.Task) {
+ if d.nextStepSubtaskDispatched(d.Task) {
break
}
@@ -467,7 +471,7 @@ func (d *BaseDispatcher) onNextStage() error {
return nil
}
-func (d *BaseDispatcher) dispatchSubTask(metas [][]byte) error {
+func (d *BaseDispatcher) dispatchSubTask(subtaskStep int64, metas [][]byte) error {
logutil.Logger(d.logCtx).Info("dispatch subtasks", zap.String("state", d.Task.State), zap.Int64("step", d.Task.Step), zap.Uint64("concurrency", d.Task.Concurrency), zap.Int("subtasks", len(metas)))
// select all available TiDB nodes for task.
@@ -499,9 +503,9 @@ func (d *BaseDispatcher) dispatchSubTask(metas [][]byte) error {
pos := i % len(serverNodes)
instanceID := disttaskutil.GenerateExecID(serverNodes[pos].IP, serverNodes[pos].Port)
logutil.Logger(d.logCtx).Debug("create subtasks", zap.String("instanceID", instanceID))
- subTasks = append(subTasks, proto.NewSubtask(d.Task.ID, d.Task.Type, instanceID, meta))
+ subTasks = append(subTasks, proto.NewSubtask(subtaskStep, d.Task.ID, d.Task.Type, instanceID, meta))
}
- return d.addSubtasks(subTasks)
+ return d.updateTask(d.Task.State, subTasks, RetrySQLTimes)
}
func (d *BaseDispatcher) handlePlanErr(err error) error {
@@ -511,7 +515,7 @@ func (d *BaseDispatcher) handlePlanErr(err error) error {
}
d.Task.Error = err
// state transform: pending -> failed.
- return d.UpdateTask(proto.TaskStateFailed, nil, RetrySQLTimes)
+ return d.updateTask(proto.TaskStateFailed, nil, RetrySQLTimes)
}
// GenerateSchedulerNodes generate a eligible TiDB nodes.
@@ -598,6 +602,11 @@ func (d *BaseDispatcher) GetPreviousSubtaskMetas(taskID int64, step int64) ([][]
return previousSubtaskMetas, nil
}
+// GetPreviousSchedulerIDs gets the scheduler IDs that ran the previous step.
+func (d *BaseDispatcher) GetPreviousSchedulerIDs(_ context.Context, taskID int64, step int64) ([]string, error) {
+ return d.taskMgr.GetSchedulerIDsByTaskIDAndStep(taskID, step)
+}
+
// WithNewSession executes the function with a new session.
func (d *BaseDispatcher) WithNewSession(fn func(se sessionctx.Context) error) error {
return d.taskMgr.WithNewSession(fn)
diff --git a/disttask/framework/dispatcher/dispatcher_manager.go b/disttask/framework/dispatcher/dispatcher_manager.go
index e13e769d9e109..e15fb1def2d54 100644
--- a/disttask/framework/dispatcher/dispatcher_manager.go
+++ b/disttask/framework/dispatcher/dispatcher_manager.go
@@ -176,12 +176,7 @@ func (dm *Manager) dispatchTaskLoop() {
if GetDispatcherFactory(task.Type) == nil {
logutil.BgLogger().Warn("unknown task type", zap.Int64("task-id", task.ID),
zap.String("task-type", task.Type))
- prevState := task.State
- task.State = proto.TaskStateFailed
- task.Error = errors.New("unknown task type")
- if _, err2 := dm.taskMgr.UpdateGlobalTaskAndAddSubTasks(task, nil, prevState); err2 != nil {
- logutil.BgLogger().Warn("update task state of unknown type failed", zap.Error(err2))
- }
+ dm.failTask(task, errors.New("unknown task type"))
continue
}
// the task is not in runningTasks set when:
@@ -201,6 +196,16 @@ func (dm *Manager) dispatchTaskLoop() {
}
}
+func (dm *Manager) failTask(task *proto.Task, err error) {
+ prevState := task.State
+ task.State = proto.TaskStateFailed
+ task.Error = err
+ if _, err2 := dm.taskMgr.UpdateGlobalTaskAndAddSubTasks(task, nil, prevState); err2 != nil {
+ logutil.BgLogger().Warn("failed to update task state to failed",
+ zap.Int64("task-id", task.ID), zap.Error(err2))
+ }
+}
+
func (dm *Manager) gcSubtaskHistoryTable() {
historySubtaskTableGcInterval := defaultHistorySubtaskTableGcInterval
failpoint.Inject("historySubtaskTableGcInterval", func(val failpoint.Value) {
@@ -244,6 +249,11 @@ func (dm *Manager) startDispatcher(task *proto.Task) {
_ = dm.gPool.Run(func() {
dispatcherFactory := GetDispatcherFactory(task.Type)
dispatcher := dispatcherFactory(dm.ctx, dm.taskMgr, dm.serverID, task)
+ if err := dispatcher.Init(); err != nil {
+ logutil.BgLogger().Error("init dispatcher failed", zap.Error(err))
+ dm.failTask(task, err)
+ return
+ }
defer dispatcher.Close()
dm.setRunningTask(task, dispatcher)
dispatcher.ExecuteTask()
diff --git a/disttask/framework/dispatcher/dispatcher_test.go b/disttask/framework/dispatcher/dispatcher_test.go
index 8a3993cf6ee01..93dfc1e956893 100644
--- a/disttask/framework/dispatcher/dispatcher_test.go
+++ b/disttask/framework/dispatcher/dispatcher_test.go
@@ -17,6 +17,7 @@ package dispatcher_test
import (
"context"
"fmt"
+ "strings"
"testing"
"time"
@@ -24,6 +25,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/disttask/framework/dispatcher"
+ "github.com/pingcap/tidb/disttask/framework/mock"
"github.com/pingcap/tidb/disttask/framework/proto"
"github.com/pingcap/tidb/disttask/framework/storage"
"github.com/pingcap/tidb/domain/infosync"
@@ -32,6 +34,7 @@ import (
"github.com/pingcap/tidb/util/logutil"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/util"
+ "go.uber.org/mock/gomock"
)
var (
@@ -48,7 +51,7 @@ type testDispatcherExt struct{}
func (*testDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
-func (*testDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) (metas [][]byte, err error) {
+func (*testDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ int64) (metas [][]byte, err error) {
return nil, nil
}
@@ -66,12 +69,8 @@ func (*testDispatcherExt) IsRetryableErr(error) bool {
return true
}
-func (dsp *testDispatcherExt) StageFinished(task *proto.Task) bool {
- return true
-}
-
-func (dsp *testDispatcherExt) Finished(task *proto.Task) bool {
- return false
+func (*testDispatcherExt) GetNextStep(dispatcher.TaskHandle, *proto.Task) int64 {
+ return proto.StepDone
}
type numberExampleDispatcherExt struct{}
@@ -79,7 +78,7 @@ type numberExampleDispatcherExt struct{}
func (*numberExampleDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
-func (n *numberExampleDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) (metas [][]byte, err error) {
+func (n *numberExampleDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task, _ int64) (metas [][]byte, err error) {
switch task.Step {
case proto.StepInit:
for i := 0; i < subtaskCnt; i++ {
@@ -108,12 +107,13 @@ func (*numberExampleDispatcherExt) IsRetryableErr(error) bool {
return true
}
-func (*numberExampleDispatcherExt) StageFinished(task *proto.Task) bool {
- return true
-}
-
-func (*numberExampleDispatcherExt) Finished(task *proto.Task) bool {
- return task.Step == proto.StepTwo
+func (*numberExampleDispatcherExt) GetNextStep(_ dispatcher.TaskHandle, task *proto.Task) int64 {
+ switch task.Step {
+ case proto.StepInit:
+ return proto.StepOne
+ default:
+ return proto.StepDone
+ }
}
func MockDispatcherManager(t *testing.T, pool *pools.ResourcePool) (*dispatcher.Manager, *storage.TaskManager) {
@@ -203,7 +203,49 @@ func TestGetInstance(t *testing.T) {
require.ElementsMatch(t, instanceIDs, serverIDs)
}
-func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) {
+func TestTaskFailInManager(t *testing.T) {
+ store := testkit.CreateMockStore(t)
+ gtk := testkit.NewTestKit(t, store)
+ pool := pools.NewResourcePool(func() (pools.Resource, error) {
+ return gtk.Session(), nil
+ }, 1, 1, time.Second)
+ defer pool.Close()
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+
+ mockDispatcher := mock.NewMockDispatcher(ctrl)
+ mockDispatcher.EXPECT().Init().Return(errors.New("mock dispatcher init error"))
+
+ dspManager, mgr := MockDispatcherManager(t, pool)
+ dispatcher.RegisterDispatcherFactory(proto.TaskTypeExample,
+ func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher {
+ return mockDispatcher
+ })
+ dspManager.Start()
+ defer dspManager.Stop()
+
+ // unknown task type
+ taskID, err := mgr.AddNewGlobalTask("test", "test-type", 1, nil)
+ require.NoError(t, err)
+ require.Eventually(t, func() bool {
+ task, err := mgr.GetGlobalTaskByID(taskID)
+ require.NoError(t, err)
+ return task.State == proto.TaskStateFailed &&
+ strings.Contains(task.Error.Error(), "unknown task type")
+ }, time.Second*10, time.Millisecond*300)
+
+ // dispatcher init error
+ taskID, err = mgr.AddNewGlobalTask("test2", proto.TaskTypeExample, 1, nil)
+ require.NoError(t, err)
+ require.Eventually(t, func() bool {
+ task, err := mgr.GetGlobalTaskByID(taskID)
+ require.NoError(t, err)
+ return task.State == proto.TaskStateFailed &&
+ strings.Contains(task.Error.Error(), "mock dispatcher init error")
+ }, time.Second*10, time.Millisecond*300)
+}
+
+func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel bool) {
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/domain/MockDisableDistTask", "return(true)"))
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/domain/MockDisableDistTask"))
@@ -329,10 +371,19 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) {
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/storage/MockUpdateTaskErr"))
}()
- // Mock a subtask fails.
- for i := 1; i <= subtaskCnt*taskCnt; i += subtaskCnt {
- err = mgr.UpdateSubtaskStateAndError(int64(i), proto.TaskStateFailed, nil)
- require.NoError(t, err)
+
+ if isSubtaskCancel {
+ // Mock a subtask canceled
+ for i := 1; i <= subtaskCnt*taskCnt; i += subtaskCnt {
+ err = mgr.UpdateSubtaskStateAndError(int64(i), proto.TaskStateCanceled, nil)
+ require.NoError(t, err)
+ }
+ } else {
+ // Mock a subtask fails.
+ for i := 1; i <= subtaskCnt*taskCnt; i += subtaskCnt {
+ err = mgr.UpdateSubtaskStateAndError(int64(i), proto.TaskStateFailed, nil)
+ require.NoError(t, err)
+ }
}
}
@@ -350,27 +401,35 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) {
}
func TestSimple(t *testing.T) {
- checkDispatch(t, 1, true, false)
+ checkDispatch(t, 1, true, false, false)
}
func TestSimpleErrStage(t *testing.T) {
- checkDispatch(t, 1, false, false)
+ checkDispatch(t, 1, false, false, false)
}
func TestSimpleCancel(t *testing.T) {
- checkDispatch(t, 1, false, true)
+ checkDispatch(t, 1, false, true, false)
+}
+
+func TestSimpleSubtaskCancel(t *testing.T) {
+ checkDispatch(t, 1, false, false, true)
}
func TestParallel(t *testing.T) {
- checkDispatch(t, 3, true, false)
+ checkDispatch(t, 3, true, false, false)
}
func TestParallelErrStage(t *testing.T) {
- checkDispatch(t, 3, false, false)
+ checkDispatch(t, 3, false, false, false)
}
func TestParallelCancel(t *testing.T) {
- checkDispatch(t, 3, false, true)
+ checkDispatch(t, 3, false, true, false)
+}
+
+func TestParallelSubtaskCancel(t *testing.T) {
+ checkDispatch(t, 3, false, false, true)
}
func TestVerifyTaskStateTransform(t *testing.T) {
diff --git a/disttask/framework/dispatcher/interface.go b/disttask/framework/dispatcher/interface.go
index 758da0c52da9d..ebbd1c3e2f167 100644
--- a/disttask/framework/dispatcher/interface.go
+++ b/disttask/framework/dispatcher/interface.go
@@ -33,12 +33,13 @@ type Extension interface {
// the event is generated every checkTaskRunningInterval, and only when the task NOT FINISHED and NO ERROR.
OnTick(ctx context.Context, task *proto.Task)
- // OnNextSubtasksBatch is used to generate batch of subtasks for current stage
+ // OnNextSubtasksBatch is used to generate a batch of subtasks for the next stage.
// NOTE: don't change gTask.State inside, framework will manage it.
// it's called when:
// 1. task is pending and entering it's first step.
// 2. subtasks dispatched has all finished with no error.
- OnNextSubtasksBatch(ctx context.Context, h TaskHandle, task *proto.Task) (subtaskMetas [][]byte, err error)
+ // when the next step is StepDone, it should return nil, nil.
+ OnNextSubtasksBatch(ctx context.Context, h TaskHandle, task *proto.Task, step int64) (subtaskMetas [][]byte, err error)
// OnErrStage is called when:
// 1. subtask is finished with error.
@@ -52,14 +53,10 @@ type Extension interface {
// IsRetryableErr is used to check whether the error occurred in dispatcher is retryable.
IsRetryableErr(err error) bool
- // StageFinished is used to check if all subtasks in current stage are dispatched and processed.
- // StageFinished is called before generating batch of subtasks.
- StageFinished(task *proto.Task) bool
-
- // Finished is used to check if all subtasks for the task are dispatched and processed.
- // Finished is called before generating batch of subtasks.
- // Once Finished return true, mark the task as succeed.
- Finished(task *proto.Task) bool
+ // GetNextStep is used to get the next step for the task.
+ // if the task runs successfully, it should go from StepInit through the business steps,
+ // then to StepDone, at which point the dispatcher marks it as finished.
+ GetNextStep(h TaskHandle, task *proto.Task) int64
}
// FactoryFn is used to create a dispatcher.
diff --git a/disttask/framework/framework_dynamic_dispatch_test.go b/disttask/framework/framework_dynamic_dispatch_test.go
index 29f5331e3cd06..42960c1e0985e 100644
--- a/disttask/framework/framework_dynamic_dispatch_test.go
+++ b/disttask/framework/framework_dynamic_dispatch_test.go
@@ -37,9 +37,9 @@ var _ dispatcher.Extension = (*testDynamicDispatcherExt)(nil)
func (*testDynamicDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {}
-func (dsp *testDynamicDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
+func (dsp *testDynamicDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ int64) (metas [][]byte, err error) {
// step1
- if gTask.Step == proto.StepInit && dsp.cnt < 3 {
+ if gTask.Step == proto.StepInit {
dsp.cnt++
return [][]byte{
[]byte(fmt.Sprintf("task%d", dsp.cnt)),
@@ -48,7 +48,7 @@ func (dsp *testDynamicDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ di
}
// step2
- if gTask.Step == proto.StepOne && dsp.cnt < 4 {
+ if gTask.Step == proto.StepOne {
dsp.cnt++
return [][]byte{
[]byte(fmt.Sprintf("task%d", dsp.cnt)),
@@ -61,18 +61,15 @@ func (*testDynamicDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.Task
return nil, nil
}
-func (dsp *testDynamicDispatcherExt) StageFinished(task *proto.Task) bool {
- if task.Step == proto.StepInit && dsp.cnt >= 3 {
- return true
+func (dsp *testDynamicDispatcherExt) GetNextStep(_ dispatcher.TaskHandle, task *proto.Task) int64 {
+ switch task.Step {
+ case proto.StepInit:
+ return proto.StepOne
+ case proto.StepOne:
+ return proto.StepTwo
+ default:
+ return proto.StepDone
}
- if task.Step == proto.StepOne && dsp.cnt >= 4 {
- return true
- }
- return false
-}
-
-func (dsp *testDynamicDispatcherExt) Finished(task *proto.Task) bool {
- return task.Step == proto.StepOne && dsp.cnt >= 4
}
func (*testDynamicDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
diff --git a/disttask/framework/framework_err_handling_test.go b/disttask/framework/framework_err_handling_test.go
index 6479c3e7ae98a..90a3510263c4b 100644
--- a/disttask/framework/framework_err_handling_test.go
+++ b/disttask/framework/framework_err_handling_test.go
@@ -40,7 +40,7 @@ var (
func (*planErrDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
-func (p *planErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
+func (p *planErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ int64) (metas [][]byte, err error) {
if gTask.Step == proto.StepInit {
if p.callTime == 0 {
p.callTime++
@@ -78,21 +78,15 @@ func (*planErrDispatcherExt) IsRetryableErr(error) bool {
return true
}
-func (p *planErrDispatcherExt) StageFinished(task *proto.Task) bool {
- if task.Step == proto.StepInit && p.cnt == 3 {
- return true
+func (p *planErrDispatcherExt) GetNextStep(_ dispatcher.TaskHandle, task *proto.Task) int64 {
+ switch task.Step {
+ case proto.StepInit:
+ return proto.StepOne
+ case proto.StepOne:
+ return proto.StepTwo
+ default:
+ return proto.StepDone
}
- if task.Step == proto.StepOne && p.cnt == 4 {
- return true
- }
- return false
-}
-
-func (p *planErrDispatcherExt) Finished(task *proto.Task) bool {
- if task.Step == proto.StepOne && p.cnt == 4 {
- return true
- }
- return false
}
type planNotRetryableErrDispatcherExt struct {
@@ -102,7 +96,7 @@ type planNotRetryableErrDispatcherExt struct {
func (*planNotRetryableErrDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
-func (p *planNotRetryableErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
+func (p *planNotRetryableErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ int64) (metas [][]byte, err error) {
return nil, errors.New("not retryable err")
}
@@ -118,18 +112,8 @@ func (*planNotRetryableErrDispatcherExt) IsRetryableErr(error) bool {
return false
}
-func (p *planNotRetryableErrDispatcherExt) StageFinished(task *proto.Task) bool {
- if task.Step == proto.StepInit && p.cnt >= 3 {
- return true
- }
- if task.Step == proto.StepOne && p.cnt >= 4 {
- return true
- }
- return false
-}
-
-func (p *planNotRetryableErrDispatcherExt) Finished(task *proto.Task) bool {
- return task.Step == proto.StepOne && p.cnt >= 4
+func (p *planNotRetryableErrDispatcherExt) GetNextStep(dispatcher.TaskHandle, *proto.Task) int64 {
+ return proto.StepDone
}
func TestPlanErr(t *testing.T) {
diff --git a/disttask/framework/framework_ha_test.go b/disttask/framework/framework_ha_test.go
index 7cc43d93f2918..29487f13d01c5 100644
--- a/disttask/framework/framework_ha_test.go
+++ b/disttask/framework/framework_ha_test.go
@@ -37,7 +37,7 @@ var _ dispatcher.Extension = (*haTestDispatcherExt)(nil)
func (*haTestDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
-func (dsp *haTestDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
+func (dsp *haTestDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ int64) (metas [][]byte, err error) {
if gTask.Step == proto.StepInit {
dsp.cnt = 10
return [][]byte{
@@ -78,18 +78,15 @@ func (*haTestDispatcherExt) IsRetryableErr(error) bool {
return true
}
-func (dsp *haTestDispatcherExt) StageFinished(task *proto.Task) bool {
- if task.Step == proto.StepInit && dsp.cnt >= 10 {
- return true
+func (dsp *haTestDispatcherExt) GetNextStep(_ dispatcher.TaskHandle, task *proto.Task) int64 {
+ switch task.Step {
+ case proto.StepInit:
+ return proto.StepOne
+ case proto.StepOne:
+ return proto.StepTwo
+ default:
+ return proto.StepDone
}
- if task.Step == proto.StepOne && dsp.cnt >= 15 {
- return true
- }
- return false
-}
-
-func (dsp *haTestDispatcherExt) Finished(task *proto.Task) bool {
- return task.Step == proto.StepOne && dsp.cnt >= 15
}
func TestHABasic(t *testing.T) {
diff --git a/disttask/framework/framework_rollback_test.go b/disttask/framework/framework_rollback_test.go
index 85f9c4500fb6d..1ed99440e8c49 100644
--- a/disttask/framework/framework_rollback_test.go
+++ b/disttask/framework/framework_rollback_test.go
@@ -41,7 +41,7 @@ var rollbackCnt atomic.Int32
func (*rollbackDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
-func (dsp *rollbackDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
+func (dsp *rollbackDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ int64) (metas [][]byte, err error) {
if gTask.Step == proto.StepInit {
dsp.cnt = 3
return [][]byte{
@@ -65,12 +65,13 @@ func (*rollbackDispatcherExt) IsRetryableErr(error) bool {
return true
}
-func (dsp *rollbackDispatcherExt) StageFinished(task *proto.Task) bool {
- return task.Step == proto.StepInit && dsp.cnt >= 3
-}
-
-func (dsp *rollbackDispatcherExt) Finished(task *proto.Task) bool {
- return task.Step == proto.StepInit && dsp.cnt >= 3
+func (dsp *rollbackDispatcherExt) GetNextStep(_ dispatcher.TaskHandle, task *proto.Task) int64 {
+ switch task.Step {
+ case proto.StepInit:
+ return proto.StepOne
+ default:
+ return proto.StepDone
+ }
}
func registerRollbackTaskMeta(t *testing.T, ctrl *gomock.Controller, m *sync.Map) {
diff --git a/disttask/framework/framework_test.go b/disttask/framework/framework_test.go
index 340d0e347532b..65dde1c9816c7 100644
--- a/disttask/framework/framework_test.go
+++ b/disttask/framework/framework_test.go
@@ -44,7 +44,7 @@ var _ dispatcher.Extension = (*testDispatcherExt)(nil)
func (*testDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
-func (dsp *testDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
+func (dsp *testDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, _ int64) (metas [][]byte, err error) {
if gTask.Step == proto.StepInit {
dsp.cnt = 3
return [][]byte{
@@ -66,18 +66,15 @@ func (*testDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle,
return nil, nil
}
-func (dsp *testDispatcherExt) StageFinished(task *proto.Task) bool {
- if task.Step == proto.StepInit && dsp.cnt >= 3 {
- return true
- }
- if task.Step == proto.StepOne && dsp.cnt >= 4 {
- return true
+func (dsp *testDispatcherExt) GetNextStep(_ dispatcher.TaskHandle, task *proto.Task) int64 {
+ switch task.Step {
+ case proto.StepInit:
+ return proto.StepOne
+ case proto.StepOne:
+ return proto.StepTwo
+ default:
+ return proto.StepDone
}
- return false
-}
-
-func (dsp *testDispatcherExt) Finished(task *proto.Task) bool {
- return task.Step == proto.StepOne && dsp.cnt >= 4
}
func generateSchedulerNodes4Test() ([]*infosync.ServerInfo, error) {
@@ -116,9 +113,9 @@ func RegisterTaskMeta(t *testing.T, ctrl *gomock.Controller, m *sync.Map, dispat
mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx context.Context, subtask *proto.Subtask) error {
switch subtask.Step {
- case proto.StepInit:
- m.Store("0", "0")
case proto.StepOne:
+ m.Store("0", "0")
+ case proto.StepTwo:
m.Store("1", "1")
default:
panic("invalid step")
@@ -155,9 +152,9 @@ func RegisterTaskMetaForExample2(t *testing.T, ctrl *gomock.Controller, m *sync.
mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx context.Context, subtask *proto.Subtask) error {
switch subtask.Step {
- case proto.StepInit:
- m.Store("2", "2")
case proto.StepOne:
+ m.Store("2", "2")
+ case proto.StepTwo:
m.Store("3", "3")
default:
panic("invalid step")
@@ -174,9 +171,9 @@ func RegisterTaskMetaForExample3(t *testing.T, ctrl *gomock.Controller, m *sync.
mockSubtaskExecutor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx context.Context, subtask *proto.Subtask) error {
switch subtask.Step {
- case proto.StepInit:
- m.Store("4", "4")
case proto.StepOne:
+ m.Store("4", "4")
+ case proto.StepTwo:
m.Store("5", "5")
default:
panic("invalid step")
@@ -245,6 +242,7 @@ func DispatchTaskAndCheckState(taskKey string, t *testing.T, m *sync.Map, state
return true
})
}
+
func DispatchMultiTasksAndOneFail(t *testing.T, num int, m []sync.Map) []*proto.Task {
var tasks []*proto.Task
var taskID []int64
@@ -609,3 +607,33 @@ func TestGC(t *testing.T) {
distContext.Close()
}
+
+func TestFrameworkSubtaskFinishedCancel(t *testing.T) {
+ var m sync.Map
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ RegisterTaskMeta(t, ctrl, &m, &testDispatcherExt{})
+ distContext := testkit.NewDistExecutionContext(t, 3)
+ err := failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/MockSubtaskFinishedCancel", "1*return(true)")
+ require.NoError(t, err)
+ defer func() {
+ require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/MockSubtaskFinishedCancel"))
+ }()
+ DispatchTaskAndCheckState("key1", t, &m, proto.TaskStateReverted)
+ distContext.Close()
+}
+
+func TestFrameworkRunSubtaskCancel(t *testing.T) {
+ var m sync.Map
+
+ ctrl := gomock.NewController(t)
+ defer ctrl.Finish()
+ RegisterTaskMeta(t, ctrl, &m, &testDispatcherExt{})
+ distContext := testkit.NewDistExecutionContext(t, 3)
+ err := failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/MockRunSubtaskCancel", "1*return(true)")
+ require.NoError(t, err)
+ DispatchTaskAndCheckState("key1", t, &m, proto.TaskStateReverted)
+ require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/MockRunSubtaskCancel"))
+ distContext.Close()
+}
diff --git a/disttask/framework/mock/BUILD.bazel b/disttask/framework/mock/BUILD.bazel
index d16d12502c432..1e02ac8e55499 100644
--- a/disttask/framework/mock/BUILD.bazel
+++ b/disttask/framework/mock/BUILD.bazel
@@ -3,6 +3,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "mock",
srcs = [
+ "dispatcher_mock.go",
"plan_mock.go",
"scheduler_mock.go",
],
diff --git a/disttask/framework/mock/dispatcher_mock.go b/disttask/framework/mock/dispatcher_mock.go
new file mode 100644
index 0000000000000..abb3ad16df85f
--- /dev/null
+++ b/disttask/framework/mock/dispatcher_mock.go
@@ -0,0 +1,72 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: github.com/pingcap/tidb/disttask/framework/dispatcher (interfaces: Dispatcher)
+
+// Package mock is a generated GoMock package.
+package mock
+
+import (
+ reflect "reflect"
+
+ gomock "go.uber.org/mock/gomock"
+)
+
+// MockDispatcher is a mock of Dispatcher interface.
+type MockDispatcher struct {
+ ctrl *gomock.Controller
+ recorder *MockDispatcherMockRecorder
+}
+
+// MockDispatcherMockRecorder is the mock recorder for MockDispatcher.
+type MockDispatcherMockRecorder struct {
+ mock *MockDispatcher
+}
+
+// NewMockDispatcher creates a new mock instance.
+func NewMockDispatcher(ctrl *gomock.Controller) *MockDispatcher {
+ mock := &MockDispatcher{ctrl: ctrl}
+ mock.recorder = &MockDispatcherMockRecorder{mock}
+ return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockDispatcher) EXPECT() *MockDispatcherMockRecorder {
+ return m.recorder
+}
+
+// Close mocks base method.
+func (m *MockDispatcher) Close() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "Close")
+}
+
+// Close indicates an expected call of Close.
+func (mr *MockDispatcherMockRecorder) Close() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDispatcher)(nil).Close))
+}
+
+// ExecuteTask mocks base method.
+func (m *MockDispatcher) ExecuteTask() {
+ m.ctrl.T.Helper()
+ m.ctrl.Call(m, "ExecuteTask")
+}
+
+// ExecuteTask indicates an expected call of ExecuteTask.
+func (mr *MockDispatcherMockRecorder) ExecuteTask() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExecuteTask", reflect.TypeOf((*MockDispatcher)(nil).ExecuteTask))
+}
+
+// Init mocks base method.
+func (m *MockDispatcher) Init() error {
+ m.ctrl.T.Helper()
+ ret := m.ctrl.Call(m, "Init")
+ ret0, _ := ret[0].(error)
+ return ret0
+}
+
+// Init indicates an expected call of Init.
+func (mr *MockDispatcherMockRecorder) Init() *gomock.Call {
+ mr.mock.ctrl.T.Helper()
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockDispatcher)(nil).Init))
+}
diff --git a/disttask/framework/planner/plan.go b/disttask/framework/planner/plan.go
index 2e20be064826d..08b8e1695d98e 100644
--- a/disttask/framework/planner/plan.go
+++ b/disttask/framework/planner/plan.go
@@ -26,13 +26,16 @@ type PlanCtx struct {
// integrate with current distribute framework
SessionCtx sessionctx.Context
+ TaskID int64
TaskKey string
TaskType string
ThreadCnt int
- // PreviousSubtaskMetas is a list of subtask metas from previous step.
+ // PreviousSubtaskMetas stores the subtask metas of previous steps, keyed by step.
// We can remove this field if we find a better way to pass the result between steps.
- PreviousSubtaskMetas [][]byte
+ PreviousSubtaskMetas map[int64][][]byte
+ GlobalSort bool
+ NextTaskStep int64
}
// LogicalPlan represents a logical plan in distribute framework.
diff --git a/disttask/framework/proto/BUILD.bazel b/disttask/framework/proto/BUILD.bazel
index 21bab39839933..c79a7260ad2c3 100644
--- a/disttask/framework/proto/BUILD.bazel
+++ b/disttask/framework/proto/BUILD.bazel
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "proto",
@@ -6,3 +6,12 @@ go_library(
importpath = "github.com/pingcap/tidb/disttask/framework/proto",
visibility = ["//visibility:public"],
)
+
+go_test(
+ name = "proto_test",
+ timeout = "short",
+ srcs = ["task_test.go"],
+ embed = [":proto"],
+ flaky = True,
+ deps = ["@com_github_stretchr_testify//require"],
+)
diff --git a/disttask/framework/proto/task.go b/disttask/framework/proto/task.go
index c1ab89f1a4d34..2403f76304586 100644
--- a/disttask/framework/proto/task.go
+++ b/disttask/framework/proto/task.go
@@ -47,10 +47,14 @@ const (
)
// TaskStep is the step of task.
+// DO NOT change the value of these constants; doing so will break backward compatibility.
+// A successful task MUST go from StepInit to the business steps, then to StepDone.
const (
- StepInit int64 = 0
- StepOne int64 = 1
- StepTwo int64 = 2
+ StepInit int64 = -1
+ StepDone int64 = -2
+ StepOne int64 = 1
+ StepTwo int64 = 2
+ StepThree int64 = 3
)
// TaskIDLabelName is the label name of task id.
@@ -101,8 +105,9 @@ type Subtask struct {
}
// NewSubtask create a new subtask.
-func NewSubtask(taskID int64, tp, schedulerID string, meta []byte) *Subtask {
+func NewSubtask(step int64, taskID int64, tp, schedulerID string, meta []byte) *Subtask {
return &Subtask{
+ Step: step,
Type: tp,
TaskID: taskID,
SchedulerID: schedulerID,
diff --git a/disttask/framework/proto/task_test.go b/disttask/framework/proto/task_test.go
new file mode 100644
index 0000000000000..8824a6bc79853
--- /dev/null
+++ b/disttask/framework/proto/task_test.go
@@ -0,0 +1,27 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proto
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestTaskStep(t *testing.T) {
+ // make sure we don't change the values of StepInit and StepDone accidentally
+ require.Equal(t, int64(-1), StepInit)
+ require.Equal(t, int64(-2), StepDone)
+}
diff --git a/disttask/framework/scheduler/scheduler.go b/disttask/framework/scheduler/scheduler.go
index 96a41e6d350b7..997289e49d511 100644
--- a/disttask/framework/scheduler/scheduler.go
+++ b/disttask/framework/scheduler/scheduler.go
@@ -209,10 +209,15 @@ func (s *BaseScheduler) run(ctx context.Context, task *proto.Task) error {
func (s *BaseScheduler) runSubtask(ctx context.Context, scheduler execute.SubtaskExecutor, subtask *proto.Subtask) {
err := scheduler.RunSubtask(ctx, subtask)
+ failpoint.Inject("MockRunSubtaskCancel", func(val failpoint.Value) {
+ if val.(bool) {
+ err = context.Canceled
+ }
+ })
if err != nil {
s.onError(err)
if errors.Cause(err) == context.Canceled {
- s.updateSubtaskStateAndError(subtask.ID, proto.TaskStateCanceled, nil)
+ s.updateSubtaskStateAndError(subtask.ID, proto.TaskStateCanceled, s.getError())
} else {
s.updateSubtaskStateAndError(subtask.ID, proto.TaskStateFailed, s.getError())
}
@@ -254,6 +259,7 @@ func (s *BaseScheduler) runSubtask(ctx context.Context, scheduler execute.Subtas
time.Sleep(20 * time.Second)
}
})
+
failpoint.Inject("MockExecutorRunErr", func(val failpoint.Value) {
if val.(bool) {
s.onError(errors.New("MockExecutorRunErr"))
@@ -281,6 +287,11 @@ func (s *BaseScheduler) onSubtaskFinished(ctx context.Context, scheduler execute
s.onError(err)
}
}
+ failpoint.Inject("MockSubtaskFinishedCancel", func(val failpoint.Value) {
+ if val.(bool) {
+ s.onError(context.Canceled)
+ }
+ })
if err := s.getError(); err != nil {
if errors.Cause(err) == context.Canceled {
s.updateSubtaskStateAndError(subtask.ID, proto.TaskStateCanceled, nil)
@@ -391,7 +402,8 @@ func (s *BaseScheduler) onError(err error) {
if err == nil {
return
}
-
+ err = errors.WithStack(err)
+ logutil.Logger(s.logCtx).Error("onError", zap.Error(err))
s.mu.Lock()
defer s.mu.Unlock()
diff --git a/disttask/framework/storage/BUILD.bazel b/disttask/framework/storage/BUILD.bazel
index 31a063a42834b..fded479bf075f 100644
--- a/disttask/framework/storage/BUILD.bazel
+++ b/disttask/framework/storage/BUILD.bazel
@@ -14,6 +14,7 @@ go_library(
"//parser/terror",
"//sessionctx",
"//util/chunk",
+ "//util/intest",
"//util/logutil",
"//util/sqlexec",
"@com_github_ngaut_pools//:pools",
diff --git a/disttask/framework/storage/table_test.go b/disttask/framework/storage/table_test.go
index dd1f56c0d6331..33cb7b138f63c 100644
--- a/disttask/framework/storage/table_test.go
+++ b/disttask/framework/storage/table_test.go
@@ -80,6 +80,7 @@ func TestGlobalTaskTable(t *testing.T) {
require.NoError(t, err)
require.Len(t, task4, 1)
require.Equal(t, task, task4[0])
+ require.GreaterOrEqual(t, task4[0].StateUpdateTime, task.StateUpdateTime)
prevState := task.State
task.State = proto.TaskStateRunning
@@ -90,11 +91,12 @@ func TestGlobalTaskTable(t *testing.T) {
task5, err := gm.GetGlobalTasksInStates(proto.TaskStateRunning)
require.NoError(t, err)
require.Len(t, task5, 1)
- require.Equal(t, task, task5[0])
+ require.Equal(t, task.State, task5[0].State)
task6, err := gm.GetGlobalTaskByKey("key1")
require.NoError(t, err)
- require.Equal(t, task, task6)
+ require.Len(t, task5, 1)
+ require.Equal(t, task.State, task6.State)
// test cannot insert task with dup key
_, err = gm.AddNewGlobalTask("key1", "test2", 4, []byte("test2"))
@@ -170,9 +172,6 @@ func TestSubTaskTable(t *testing.T) {
require.NoError(t, err)
require.True(t, ok)
- err = sm.UpdateSubtaskHeartbeat("tidb1", 1, time.Now())
- require.NoError(t, err)
-
ts := time.Now()
time.Sleep(time.Second)
require.NoError(t, sm.StartSubtask(1))
@@ -323,11 +322,13 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) {
task.State = proto.TaskStateRunning
subTasks := []*proto.Subtask{
{
+ Step: proto.StepInit,
Type: proto.TaskTypeExample,
SchedulerID: "instance1",
Meta: []byte("m1"),
},
{
+ Step: proto.StepInit,
Type: proto.TaskTypeExample,
SchedulerID: "instance2",
Meta: []byte("m2"),
@@ -362,11 +363,13 @@ func TestBothGlobalAndSubTaskTable(t *testing.T) {
task.State = proto.TaskStateReverting
subTasks = []*proto.Subtask{
{
+ Step: proto.StepInit,
Type: proto.TaskTypeExample,
SchedulerID: "instance3",
Meta: []byte("m3"),
},
{
+ Step: proto.StepInit,
Type: proto.TaskTypeExample,
SchedulerID: "instance4",
Meta: []byte("m4"),
diff --git a/disttask/framework/storage/task_table.go b/disttask/framework/storage/task_table.go
index 6524f47744e77..4fec66cb711d5 100644
--- a/disttask/framework/storage/task_table.go
+++ b/disttask/framework/storage/task_table.go
@@ -30,6 +30,7 @@ import (
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
+ "github.com/pingcap/tidb/util/intest"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/sqlexec"
"github.com/tikv/client-go/v2/util"
@@ -428,12 +429,12 @@ func (stm *TaskManager) GetSubtaskInStatesCnt(taskID int64, states ...interface{
// CollectSubTaskError collects the subtask error.
func (stm *TaskManager) CollectSubTaskError(taskID int64) ([]error, error) {
- rs, err := stm.executeSQLWithNewSession(stm.ctx, `select error from mysql.tidb_background_subtask
- where task_key = %? AND state = %?`, taskID, proto.TaskStateFailed)
+ rs, err := stm.executeSQLWithNewSession(stm.ctx,
+ `select error from mysql.tidb_background_subtask
+ where task_key = %? AND state in (%?, %?)`, taskID, proto.TaskStateFailed, proto.TaskStateCanceled)
if err != nil {
return nil, err
}
-
subTaskErrors := make([]error, 0, len(rs))
for _, row := range rs {
if row.IsNull(0) {
@@ -502,16 +503,6 @@ func (stm *TaskManager) FinishSubtask(id int64, meta []byte) error {
return err
}
-// UpdateSubtaskHeartbeat updates the heartbeat of the subtask.
-// not used now.
-// TODO: not sure whether we really need this method, don't update state_update_time now,
-func (stm *TaskManager) UpdateSubtaskHeartbeat(instanceID string, taskID int64, heartbeat time.Time) error {
- _, err := stm.executeSQLWithNewSession(stm.ctx, `update mysql.tidb_background_subtask
- set exec_expired = %? where exec_id = %? and task_key = %?`,
- heartbeat.String(), instanceID, taskID)
- return err
-}
-
// DeleteSubtasksByTaskID deletes the subtask of the given global task ID.
func (stm *TaskManager) DeleteSubtasksByTaskID(taskID int64) error {
_, err := stm.executeSQLWithNewSession(stm.ctx, `delete from mysql.tidb_background_subtask
@@ -622,7 +613,7 @@ func (stm *TaskManager) AddSubTasks(task *proto.Task, subtasks []*proto.Subtask)
_, err := ExecSQL(stm.ctx, se, `insert into mysql.tidb_background_subtask
(step, task_key, exec_id, meta, state, type, checkpoint, summary)
values (%?, %?, %?, %?, %?, %?, %?, %?)`,
- task.Step, task.ID, subtask.SchedulerID, subtask.Meta, subtaskState, proto.Type2Int(subtask.Type), []byte{}, "{}")
+ subtask.Step, task.ID, subtask.SchedulerID, subtask.Meta, subtaskState, proto.Type2Int(subtask.Type), []byte{}, "{}")
if err != nil {
return err
}
@@ -636,14 +627,38 @@ func (stm *TaskManager) AddSubTasks(task *proto.Task, subtasks []*proto.Subtask)
func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtasks []*proto.Subtask, prevState string) (bool, error) {
retryable := true
err := stm.WithNewTxn(stm.ctx, func(se sessionctx.Context) error {
- _, err := ExecSQL(stm.ctx, se, "update mysql.tidb_global_task set state = %?, dispatcher_id = %?, step = %?, state_update_time = %?, concurrency = %?, meta = %?, error = %? where id = %? and state = %?",
- gTask.State, gTask.DispatcherID, gTask.Step, gTask.StateUpdateTime.UTC().String(), gTask.Concurrency, gTask.Meta, serializeErr(gTask.Error), gTask.ID, prevState)
+ _, err := ExecSQL(stm.ctx, se, "update mysql.tidb_global_task "+
+ "set state = %?, dispatcher_id = %?, step = %?, concurrency = %?, meta = %?, error = %?, state_update_time = CURRENT_TIMESTAMP()"+
+ "where id = %? and state = %?",
+ gTask.State, gTask.DispatcherID, gTask.Step, gTask.Concurrency, gTask.Meta, serializeErr(gTask.Error), gTask.ID, prevState)
if err != nil {
return err
}
+ // When AffectedRows == 0, it means another admin command has changed the task state, so it's illegal to dispatch subtasks.
if se.GetSessionVars().StmtCtx.AffectedRows() == 0 {
- retryable = false
- return errors.New("invalid task state transform, state already changed")
+ if !intest.InTest {
+ // the task state has been changed by another admin command
+ retryable = false
+ return errors.New("invalid task state transform, state already changed")
+ }
+ // TODO: remove this branch. When OnNextSubtasksBatch returns subtasks, just insert them without updating tidb_global_task.
+ // Currently the business logic running on the distributed task framework updates proto.Task in OnNextSubtasksBatch,
+ // so when dispatching subtasks, the framework needs to update the global task and insert the subtasks in one txn.
+ //
+ // In the future we should restrict changes to the task in OnNextSubtasksBatch.
+ // If OnNextSubtasksBatch doesn't update any fields of proto.Task, we can insert the subtasks only.
+ //
+ // For now, UT's OnNextSubtasksBatch updates nothing in proto.Task, so AffectedRows will be 0 and the UT isn't fully
+ // compatible with the current UpdateGlobalTaskAndAddSubTasks implementation.
+ rs, err := ExecSQL(stm.ctx, se, "select id from mysql.tidb_global_task where id = %? and state = %?", gTask.ID, prevState)
+ if err != nil {
+ return err
+ }
+ // the state has changed.
+ if len(rs) == 0 {
+ retryable = false
+ return errors.New("invalid task state transform, state already changed")
+ }
}
failpoint.Inject("MockUpdateTaskErr", func(val failpoint.Value) {
@@ -651,23 +666,23 @@ func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtas
failpoint.Return(errors.New("updateTaskErr"))
}
})
+ if len(subtasks) > 0 {
+ subtaskState := proto.TaskStatePending
+ if gTask.State == proto.TaskStateReverting {
+ subtaskState = proto.TaskStateRevertPending
+ }
- subtaskState := proto.TaskStatePending
- if gTask.State == proto.TaskStateReverting {
- subtaskState = proto.TaskStateRevertPending
- }
-
- for _, subtask := range subtasks {
- // TODO: insert subtasks in batch
- _, err = ExecSQL(stm.ctx, se, `insert into mysql.tidb_background_subtask
+ for _, subtask := range subtasks {
+ // TODO: insert subtasks in batch
+ _, err = ExecSQL(stm.ctx, se, `insert into mysql.tidb_background_subtask
(step, task_key, exec_id, meta, state, type, checkpoint, summary)
values (%?, %?, %?, %?, %?, %?, %?, %?)`,
- gTask.Step, gTask.ID, subtask.SchedulerID, subtask.Meta, subtaskState, proto.Type2Int(subtask.Type), []byte{}, "{}")
- if err != nil {
- return err
+ subtask.Step, gTask.ID, subtask.SchedulerID, subtask.Meta, subtaskState, proto.Type2Int(subtask.Type), []byte{}, "{}")
+ if err != nil {
+ return err
+ }
}
}
-
return nil
})
@@ -692,7 +707,9 @@ func serializeErr(err error) []byte {
// CancelGlobalTask cancels global task.
func (stm *TaskManager) CancelGlobalTask(taskID int64) error {
- _, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_global_task set state=%? where id=%? and state in (%?, %?)",
+ _, err := stm.executeSQLWithNewSession(stm.ctx,
+ "update mysql.tidb_global_task set state=%?, state_update_time = CURRENT_TIMESTAMP() "+
+ "where id=%? and state in (%?, %?)",
proto.TaskStateCancelling, taskID, proto.TaskStatePending, proto.TaskStateRunning,
)
return err
@@ -700,7 +717,9 @@ func (stm *TaskManager) CancelGlobalTask(taskID int64) error {
// CancelGlobalTaskByKeySession cancels global task by key using input session.
func (stm *TaskManager) CancelGlobalTaskByKeySession(se sessionctx.Context, taskKey string) error {
- _, err := ExecSQL(stm.ctx, se, "update mysql.tidb_global_task set state=%? where task_key=%? and state in (%?, %?)",
+ _, err := ExecSQL(stm.ctx, se,
+ "update mysql.tidb_global_task set state=%?, state_update_time = CURRENT_TIMESTAMP() "+
+ "where task_key=%? and state in (%?, %?)",
proto.TaskStateCancelling, taskKey, proto.TaskStatePending, proto.TaskStateRunning)
return err
}
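[editor's illustrative note, not part of the patch] UpdateGlobalTaskAndAddSubTasks above relies on a compare-and-set style UPDATE: the state transition succeeds only if the row is still in prevState, and zero affected rows means another session won the race. A minimal sketch of that pattern with the standard database/sql API (the table name, columns, and `?` placeholders are simplifications of the internal `%?` style used in the patch):

package main

import (
	"context"
	"database/sql"
	"errors"
)

// moveTaskState updates the task only if it is still in prevState; zero affected
// rows is treated as "someone else already changed the state".
func moveTaskState(ctx context.Context, db *sql.DB, id int64, prevState, newState string) error {
	res, err := db.ExecContext(ctx,
		"UPDATE tidb_global_task SET state = ?, state_update_time = CURRENT_TIMESTAMP() WHERE id = ? AND state = ?",
		newState, id, prevState)
	if err != nil {
		return err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return errors.New("invalid task state transform, state already changed")
	}
	return nil
}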
diff --git a/disttask/importinto/BUILD.bazel b/disttask/importinto/BUILD.bazel
index 782bfd16f9bda..1133667f891c9 100644
--- a/disttask/importinto/BUILD.bazel
+++ b/disttask/importinto/BUILD.bazel
@@ -17,6 +17,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//br/pkg/lightning/backend",
+ "//br/pkg/lightning/backend/external",
"//br/pkg/lightning/backend/kv",
"//br/pkg/lightning/backend/local",
"//br/pkg/lightning/checkpoints",
@@ -26,6 +27,7 @@ go_library(
"//br/pkg/lightning/metric",
"//br/pkg/lightning/mydump",
"//br/pkg/lightning/verification",
+ "//br/pkg/storage",
"//br/pkg/utils",
"//config",
"//disttask/framework/dispatcher",
@@ -44,6 +46,7 @@ go_library(
"//meta/autoid",
"//metrics",
"//parser/ast",
+ "//parser/model",
"//parser/mysql",
"//resourcemanager/pool/workerpool",
"//resourcemanager/util",
@@ -57,7 +60,9 @@ go_library(
"//util/logutil",
"//util/mathutil",
"//util/promutil",
+ "//util/size",
"//util/sqlexec",
+ "@com_github_docker_go_units//:go-units",
"@com_github_go_sql_driver_mysql//:mysql",
"@com_github_google_uuid//:uuid",
"@com_github_pingcap_errors//:errors",
@@ -66,6 +71,7 @@ go_library(
"@com_github_tikv_client_go_v2//util",
"@org_uber_go_atomic//:atomic",
"@org_uber_go_zap//:zap",
+ "@org_uber_go_zap//zapcore",
],
)
@@ -76,6 +82,7 @@ go_test(
"dispatcher_test.go",
"dispatcher_testkit_test.go",
"encode_and_sort_operator_test.go",
+ "job_testkit_test.go",
"metrics_test.go",
"planner_test.go",
"subtask_executor_test.go",
@@ -84,11 +91,14 @@ go_test(
embed = [":importinto"],
flaky = True,
race = "on",
- shard_count = 8,
+ shard_count = 13,
deps = [
+ "//br/pkg/lightning/backend",
+ "//br/pkg/lightning/backend/external",
"//br/pkg/lightning/checkpoints",
"//br/pkg/lightning/mydump",
"//br/pkg/lightning/verification",
+ "//ddl",
"//disttask/framework/dispatcher",
"//disttask/framework/planner",
"//disttask/framework/proto",
@@ -98,10 +108,12 @@ go_test(
"//domain/infosync",
"//executor/importer",
"//meta/autoid",
+ "//parser",
+ "//parser/ast",
"//parser/model",
- "//parser/mysql",
"//testkit",
"//util/logutil",
+ "//util/mock",
"//util/sqlexec",
"@com_github_ngaut_pools//:pools",
"@com_github_pingcap_errors//:errors",
diff --git a/disttask/importinto/dispatcher.go b/disttask/importinto/dispatcher.go
index 0da29143cd187..cf821336708f1 100644
--- a/disttask/importinto/dispatcher.go
+++ b/disttask/importinto/dispatcher.go
@@ -25,9 +25,11 @@ import (
dmysql "github.com/go-sql-driver/mysql"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/external"
"github.com/pingcap/tidb/br/pkg/lightning/checkpoints"
"github.com/pingcap/tidb/br/pkg/lightning/common"
"github.com/pingcap/tidb/br/pkg/lightning/config"
+ "github.com/pingcap/tidb/br/pkg/lightning/log"
"github.com/pingcap/tidb/br/pkg/lightning/metric"
"github.com/pingcap/tidb/br/pkg/utils"
"github.com/pingcap/tidb/disttask/framework/dispatcher"
@@ -122,7 +124,8 @@ func (t *taskInfo) close(ctx context.Context) {
// ImportDispatcherExt is an extension of ImportDispatcher, exported for test.
type ImportDispatcherExt struct {
- mu sync.RWMutex
+ GlobalSort bool
+ mu sync.RWMutex
// NOTE: there's no need to sync for below 2 fields actually, since we add a restriction that only one
// task can be running at a time. but we might support task queuing in the future, leave it for now.
// the last time we switch TiKV into IMPORT mode, this is a global operation, do it for one task makes
@@ -192,23 +195,28 @@ func (dsp *ImportDispatcherExt) unregisterTask(ctx context.Context, task *proto.
}
// OnNextSubtasksBatch generate batch of next stage's plan.
-func (dsp *ImportDispatcherExt) OnNextSubtasksBatch(ctx context.Context, taskHandle dispatcher.TaskHandle, gTask *proto.Task) (
+func (dsp *ImportDispatcherExt) OnNextSubtasksBatch(
+ ctx context.Context,
+ taskHandle dispatcher.TaskHandle,
+ gTask *proto.Task,
+ nextStep int64,
+) (
resSubtaskMeta [][]byte, err error) {
logger := logutil.BgLogger().With(
zap.String("type", gTask.Type),
zap.Int64("task-id", gTask.ID),
- zap.String("step", stepStr(gTask.Step)),
+ zap.String("curr-step", stepStr(gTask.Step)),
+ zap.String("next-step", stepStr(nextStep)),
)
taskMeta := &TaskMeta{}
err = json.Unmarshal(gTask.Meta, taskMeta)
if err != nil {
- return nil, err
+ return nil, errors.Trace(err)
}
logger.Info("on next subtasks batch")
defer func() {
- // currently, framework will take the task as finished when err is not nil or resSubtaskMeta is empty.
- taskFinished := err == nil && len(resSubtaskMeta) == 0
+ taskFinished := err == nil && nextStep == proto.StepDone
if taskFinished {
// todo: we're not running in a transaction with task update
if err2 := dsp.finishJob(ctx, logger, taskHandle, gTask, taskMeta); err2 != nil {
@@ -223,15 +231,42 @@ func (dsp *ImportDispatcherExt) OnNextSubtasksBatch(ctx context.Context, taskHan
}
}()
- switch gTask.Step {
- case StepImport:
+ previousSubtaskMetas := make(map[int64][][]byte, 1)
+ switch nextStep {
+ case StepImport, StepEncodeAndSort:
if metrics, ok := metric.GetCommonMetric(ctx); ok {
metrics.BytesCounter.WithLabelValues(metric.StateTotalRestore).Add(float64(taskMeta.Plan.TotalFileSize))
}
- if err := preProcess(ctx, taskHandle, gTask, taskMeta, logger); err != nil {
+ jobStep := importer.JobStepImporting
+ if dsp.GlobalSort {
+ jobStep = importer.JobStepGlobalSorting
+ }
+ if err = startJob(ctx, logger, taskHandle, taskMeta, jobStep); err != nil {
+ return nil, err
+ }
+ case StepMergeSort:
+ sortAndEncodeMeta, err := taskHandle.GetPreviousSubtaskMetas(gTask.ID, StepEncodeAndSort)
+ if err != nil {
+ return nil, err
+ }
+ previousSubtaskMetas[StepEncodeAndSort] = sortAndEncodeMeta
+ case StepWriteAndIngest:
+ failpoint.Inject("failWhenDispatchWriteIngestSubtask", func() {
+ failpoint.Return(nil, errors.New("injected error"))
+ })
+ // merge sort might be skipped for some kv groups, so we need to get all
+ // subtask metas of the StepEncodeAndSort step too.
+ encodeAndSortMetas, err := taskHandle.GetPreviousSubtaskMetas(gTask.ID, StepEncodeAndSort)
+ if err != nil {
+ return nil, err
+ }
+ mergeSortMetas, err := taskHandle.GetPreviousSubtaskMetas(gTask.ID, StepMergeSort)
+ if err != nil {
return nil, err
}
- if err = startJob(ctx, logger, taskHandle, taskMeta); err != nil {
+ previousSubtaskMetas[StepEncodeAndSort] = encodeAndSortMetas
+ previousSubtaskMetas[StepMergeSort] = mergeSortMetas
+ if err = job2Step(ctx, logger, taskMeta, importer.JobStepImporting); err != nil {
return nil, err
}
case StepPostProcess:
@@ -245,24 +280,30 @@ func (dsp *ImportDispatcherExt) OnNextSubtasksBatch(ctx context.Context, taskHan
failpoint.Inject("failWhenDispatchPostProcessSubtask", func() {
failpoint.Return(nil, errors.New("injected error after StepImport"))
})
- if err := updateResult(taskHandle, gTask, taskMeta); err != nil {
+ // we need to get the metas where the checksum is stored.
+ if err := updateResult(taskHandle, gTask, taskMeta, dsp.GlobalSort); err != nil {
return nil, err
}
- if err := taskHandle.UpdateTask(gTask.State, nil, dispatcher.RetrySQLTimes); err != nil {
+ step := getStepOfEncode(dsp.GlobalSort)
+ metas, err := taskHandle.GetPreviousSubtaskMetas(gTask.ID, step)
+ if err != nil {
return nil, err
}
+ previousSubtaskMetas[step] = metas
logger.Info("move to post-process step ", zap.Any("result", taskMeta.Result))
- case StepPostProcess + 1:
+ case proto.StepDone:
return nil, nil
default:
return nil, errors.Errorf("unknown step %d", gTask.Step)
}
- previousSubtaskMetas, err := taskHandle.GetPreviousSubtaskMetas(gTask.ID, gTask.Step-1)
- if err != nil {
- return nil, err
+ planCtx := planner.PlanCtx{
+ Ctx: ctx,
+ TaskID: gTask.ID,
+ PreviousSubtaskMetas: previousSubtaskMetas,
+ GlobalSort: dsp.GlobalSort,
+ NextTaskStep: nextStep,
}
- planCtx := planner.PlanCtx{Ctx: ctx, PreviousSubtaskMetas: previousSubtaskMetas}
logicalPlan := &LogicalPlan{}
if err := logicalPlan.FromTaskMeta(gTask.Meta); err != nil {
return nil, err
@@ -271,7 +312,7 @@ func (dsp *ImportDispatcherExt) OnNextSubtasksBatch(ctx context.Context, taskHan
if err != nil {
return nil, err
}
- metaBytes, err := physicalPlan.ToSubtaskMetas(planCtx, gTask.Step)
+ metaBytes, err := physicalPlan.ToSubtaskMetas(planCtx, nextStep)
if err != nil {
return nil, err
}
@@ -290,7 +331,7 @@ func (dsp *ImportDispatcherExt) OnErrStage(ctx context.Context, handle dispatche
taskMeta := &TaskMeta{}
err := json.Unmarshal(gTask.Meta, taskMeta)
if err != nil {
- return nil, err
+ return nil, errors.Trace(err)
}
errStrs := make([]string, 0, len(receiveErrs))
for _, receiveErr := range receiveErrs {
@@ -323,7 +364,7 @@ func (*ImportDispatcherExt) GetEligibleInstances(ctx context.Context, gTask *pro
taskMeta := &TaskMeta{}
err := json.Unmarshal(gTask.Meta, taskMeta)
if err != nil {
- return nil, err
+ return nil, errors.Trace(err)
}
if len(taskMeta.EligibleInstances) > 0 {
return taskMeta.EligibleInstances, nil
@@ -337,14 +378,24 @@ func (*ImportDispatcherExt) IsRetryableErr(error) bool {
return false
}
-// StageFinished check if current stage finished.
-func (*ImportDispatcherExt) StageFinished(_ *proto.Task) bool {
- return true
-}
-
-// Finished check if current task finished.
-func (*ImportDispatcherExt) Finished(task *proto.Task) bool {
- return task.Step == StepPostProcess+1
+// GetNextStep implements dispatcher.Extension interface.
+func (dsp *ImportDispatcherExt) GetNextStep(_ dispatcher.TaskHandle, task *proto.Task) int64 {
+ switch task.Step {
+ case proto.StepInit:
+ if dsp.GlobalSort {
+ return StepEncodeAndSort
+ }
+ return StepImport
+ case StepEncodeAndSort:
+ return StepMergeSort
+ case StepMergeSort:
+ return StepWriteAndIngest
+ case StepImport, StepWriteAndIngest:
+ return StepPostProcess
+ default:
+ // current step must be StepPostProcess
+ return proto.StepDone
+ }
}
func (dsp *ImportDispatcherExt) switchTiKV2NormalMode(ctx context.Context, task *proto.Task, logger *zap.Logger) {
@@ -386,11 +437,28 @@ func newImportDispatcher(ctx context.Context, taskMgr *storage.TaskManager,
serverID string, task *proto.Task) dispatcher.Dispatcher {
metrics := metricsManager.getOrCreateMetrics(task.ID)
subCtx := metric.WithCommonMetric(ctx, metrics)
- dis := importDispatcher{
+ dsp := importDispatcher{
BaseDispatcher: dispatcher.NewBaseDispatcher(subCtx, taskMgr, serverID, task),
}
- dis.BaseDispatcher.Extension = &ImportDispatcherExt{}
- return &dis
+ return &dsp
+}
+
+func (dsp *importDispatcher) Init() (err error) {
+ defer func() {
+ if err != nil {
+ // if Init failed, Close is not called, so we need to unregister here.
+ metricsManager.unregister(dsp.Task.ID)
+ }
+ }()
+ taskMeta := &TaskMeta{}
+ if err = json.Unmarshal(dsp.BaseDispatcher.Task.Meta, taskMeta); err != nil {
+ return errors.Annotate(err, "unmarshal task meta failed")
+ }
+
+ dsp.BaseDispatcher.Extension = &ImportDispatcherExt{
+ GlobalSort: taskMeta.Plan.CloudStorageURI != "",
+ }
+ return dsp.BaseDispatcher.Init()
}
func (dsp *importDispatcher) Close() {
@@ -398,16 +466,6 @@ func (dsp *importDispatcher) Close() {
dsp.BaseDispatcher.Close()
}
-// preProcess does the pre-processing for the task.
-func preProcess(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta, logger *zap.Logger) error {
- logger.Info("pre process")
- // TODO: drop table indexes depends on the option.
- // if err := dropTableIndexes(ctx, handle, taskMeta, logger); err != nil {
- // return err
- // }
- return updateMeta(gTask, taskMeta)
-}
-
// nolint:deadcode
func dropTableIndexes(ctx context.Context, handle dispatcher.TaskHandle, taskMeta *TaskMeta, logger *zap.Logger) error {
tblInfo := taskMeta.Plan.TableInfo
@@ -479,7 +537,7 @@ func executeSQL(ctx context.Context, executor storage.SessionExecutor, logger *z
func updateMeta(gTask *proto.Task, taskMeta *TaskMeta) error {
bs, err := json.Marshal(taskMeta)
if err != nil {
- return err
+ return errors.Trace(err)
}
gTask.Meta = bs
@@ -498,9 +556,17 @@ func toChunkMap(engineCheckpoints map[int32]*checkpoints.EngineCheckpoint) map[i
return chunkMap
}
+func getStepOfEncode(globalSort bool) int64 {
+ if globalSort {
+ return StepEncodeAndSort
+ }
+ return StepImport
+}
+
// we will update taskMeta in place and make gTask.Meta point to the new taskMeta.
-func updateResult(handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta) error {
- metas, err := handle.GetPreviousSubtaskMetas(gTask.ID, gTask.Step-1)
+func updateResult(handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta, globalSort bool) error {
+ stepOfEncode := getStepOfEncode(globalSort)
+ metas, err := handle.GetPreviousSubtaskMetas(gTask.ID, stepOfEncode)
if err != nil {
return err
}
@@ -509,7 +575,7 @@ func updateResult(handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *Tas
for _, bs := range metas {
var subtaskMeta ImportStepMeta
if err := json.Unmarshal(bs, &subtaskMeta); err != nil {
- return err
+ return errors.Trace(err)
}
subtaskMetas = append(subtaskMetas, &subtaskMeta)
}
@@ -522,10 +588,35 @@ func updateResult(handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *Tas
}
}
taskMeta.Result.ColSizeMap = columnSizeMap
+
+ if globalSort {
+ taskMeta.Result.LoadedRowCnt, err = getLoadedRowCountOnGlobalSort(handle, gTask)
+ if err != nil {
+ return err
+ }
+ }
+
return updateMeta(gTask, taskMeta)
}
-func startJob(ctx context.Context, logger *zap.Logger, taskHandle dispatcher.TaskHandle, taskMeta *TaskMeta) error {
+func getLoadedRowCountOnGlobalSort(handle dispatcher.TaskHandle, gTask *proto.Task) (uint64, error) {
+ metas, err := handle.GetPreviousSubtaskMetas(gTask.ID, StepWriteAndIngest)
+ if err != nil {
+ return 0, err
+ }
+
+ var loadedRowCount uint64
+ for _, bs := range metas {
+ var subtaskMeta WriteIngestStepMeta
+ if err = json.Unmarshal(bs, &subtaskMeta); err != nil {
+ return 0, errors.Trace(err)
+ }
+ loadedRowCount += subtaskMeta.Result.LoadedRowCnt
+ }
+ return loadedRowCount, nil
+}
+
+func startJob(ctx context.Context, logger *zap.Logger, taskHandle dispatcher.TaskHandle, taskMeta *TaskMeta, jobStep string) error {
failpoint.Inject("syncBeforeJobStarted", func() {
TestSyncChan <- struct{}{}
<-TestSyncChan
@@ -539,7 +630,7 @@ func startJob(ctx context.Context, logger *zap.Logger, taskHandle dispatcher.Tas
func(ctx context.Context) (bool, error) {
return true, taskHandle.WithNewSession(func(se sessionctx.Context) error {
exec := se.(sqlexec.SQLExecutor)
- return importer.StartJob(ctx, exec, taskMeta.JobID)
+ return importer.StartJob(ctx, exec, taskMeta.JobID, jobStep)
})
},
)
@@ -571,10 +662,10 @@ func job2Step(ctx context.Context, logger *zap.Logger, taskMeta *TaskMeta, step
func (dsp *ImportDispatcherExt) finishJob(ctx context.Context, logger *zap.Logger,
taskHandle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta) error {
dsp.unregisterTask(ctx, gTask)
- redactSensitiveInfo(gTask, taskMeta)
- if err := taskHandle.UpdateTask(gTask.State, nil, dispatcher.RetrySQLTimes); err != nil {
- return err
+ if dsp.GlobalSort {
+ cleanUpGlobalSortedData(ctx, gTask, taskMeta)
}
+ redactSensitiveInfo(gTask, taskMeta)
summary := &importer.JobSummary{ImportedRows: taskMeta.Result.LoadedRowCnt}
// retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes
backoffer := backoff.NewExponential(dispatcher.RetrySQLInterval, 2, dispatcher.RetrySQLMaxInterval)
@@ -592,10 +683,10 @@ func (dsp *ImportDispatcherExt) failJob(ctx context.Context, taskHandle dispatch
taskMeta *TaskMeta, logger *zap.Logger, errorMsg string) error {
dsp.switchTiKV2NormalMode(ctx, gTask, logger)
dsp.unregisterTask(ctx, gTask)
- redactSensitiveInfo(gTask, taskMeta)
- if err := taskHandle.UpdateTask(gTask.State, nil, dispatcher.RetrySQLTimes); err != nil {
- return err
+ if dsp.GlobalSort {
+ cleanUpGlobalSortedData(ctx, gTask, taskMeta)
}
+ redactSensitiveInfo(gTask, taskMeta)
// retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes
backoffer := backoff.NewExponential(dispatcher.RetrySQLInterval, 2, dispatcher.RetrySQLMaxInterval)
return handle.RunWithRetry(ctx, dispatcher.RetrySQLTimes, backoffer, logger,
@@ -608,9 +699,35 @@ func (dsp *ImportDispatcherExt) failJob(ctx context.Context, taskHandle dispatch
)
}
+func cleanUpGlobalSortedData(ctx context.Context, gTask *proto.Task, taskMeta *TaskMeta) {
+ // We can only clean up files after all write&ingest subtasks are finished,
+ // since they might share the same file.
+ // We don't return an error here: the task is already done, and we should
+ // still report success if the task succeeded.
+ // TODO: maybe add a way to notify the user that there are files left in the global sorted storage.
+ logger := logutil.BgLogger().With(zap.Int64("task-id", gTask.ID))
+ callLog := log.BeginTask(logger, "cleanup global sorted data")
+ defer callLog.End(zap.InfoLevel, nil)
+
+ controller, err := buildController(&taskMeta.Plan, taskMeta.Stmt)
+ if err != nil {
+ logger.Warn("failed to build controller", zap.Error(err))
+ }
+ if err = controller.InitDataStore(ctx); err != nil {
+ logger.Warn("failed to init data store", zap.Error(err))
+ }
+ if err = external.CleanUpFiles(ctx, controller.GlobalSortStore,
+ strconv.Itoa(int(gTask.ID)), uint(taskMeta.Plan.ThreadCnt)); err != nil {
+ logger.Warn("failed to clean up files of task", zap.Error(err))
+ }
+}
+
func redactSensitiveInfo(gTask *proto.Task, taskMeta *TaskMeta) {
taskMeta.Stmt = ""
taskMeta.Plan.Path = ast.RedactURL(taskMeta.Plan.Path)
+ if taskMeta.Plan.CloudStorageURI != "" {
+ taskMeta.Plan.CloudStorageURI = ast.RedactURL(taskMeta.Plan.CloudStorageURI)
+ }
if err := updateMeta(gTask, taskMeta); err != nil {
// marshal failed, should not happen
logutil.BgLogger().Warn("failed to update task meta", zap.Error(err))
@@ -627,7 +744,7 @@ func rollback(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Ta
taskMeta := &TaskMeta{}
err = json.Unmarshal(gTask.Meta, taskMeta)
if err != nil {
- return err
+ return errors.Trace(err)
}
logger.Info("rollback")
@@ -646,10 +763,20 @@ func rollback(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Ta
func stepStr(step int64) string {
switch step {
+ case proto.StepInit:
+ return "init"
case StepImport:
return "import"
case StepPostProcess:
- return "postprocess"
+ return "post-process"
+ case StepEncodeAndSort:
+ return "encode&sort"
+ case StepMergeSort:
+ return "merge-sort"
+ case StepWriteAndIngest:
+ return "write&ingest"
+ case proto.StepDone:
+ return "done"
default:
return "unknown"
}
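[editor's illustrative note, not part of the patch] GetNextStep above replaces the removed StageFinished/Finished callbacks with an explicit step state machine. A self-contained sketch of the two resulting step chains, using local stand-in constants rather than the real proto.Step*/importinto.Step* values:

package main

import "fmt"

// Stand-in constants; the real values live in the proto and importinto packages.
const (
	stepInit int64 = iota
	stepImport
	stepEncodeAndSort
	stepMergeSort
	stepWriteAndIngest
	stepPostProcess
	stepDone
)

var stepName = map[int64]string{
	stepInit: "init", stepImport: "import", stepEncodeAndSort: "encode&sort",
	stepMergeSort: "merge-sort", stepWriteAndIngest: "write&ingest",
	stepPostProcess: "post-process", stepDone: "done",
}

// nextStep mirrors the branching in GetNextStep: global-sort imports pass through
// encode&sort -> merge-sort -> write&ingest, local-sort imports go straight to import.
func nextStep(cur int64, globalSort bool) int64 {
	switch cur {
	case stepInit:
		if globalSort {
			return stepEncodeAndSort
		}
		return stepImport
	case stepEncodeAndSort:
		return stepMergeSort
	case stepMergeSort:
		return stepWriteAndIngest
	case stepImport, stepWriteAndIngest:
		return stepPostProcess
	default: // stepPostProcess and anything else
		return stepDone
	}
}

func main() {
	for _, globalSort := range []bool{false, true} {
		step := stepInit
		fmt.Printf("globalSort=%v: %s", globalSort, stepName[step])
		for step != stepDone {
			step = nextStep(step, globalSort)
			fmt.Printf(" -> %s", stepName[step])
		}
		fmt.Println()
	}
}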
diff --git a/disttask/importinto/dispatcher_test.go b/disttask/importinto/dispatcher_test.go
index 86fd23938f1c4..fa9fc11e0c0dd 100644
--- a/disttask/importinto/dispatcher_test.go
+++ b/disttask/importinto/dispatcher_test.go
@@ -21,6 +21,7 @@ import (
"testing"
"github.com/pingcap/failpoint"
+ "github.com/pingcap/tidb/disttask/framework/dispatcher"
"github.com/pingcap/tidb/disttask/framework/proto"
"github.com/pingcap/tidb/domain/infosync"
"github.com/pingcap/tidb/executor/importer"
@@ -106,3 +107,70 @@ func (s *importIntoSuite) TestUpdateCurrentTask() {
require.Equal(s.T(), int64(1), dsp.currTaskID.Load())
require.True(s.T(), dsp.disableTiKVImportMode.Load())
}
+
+func (s *importIntoSuite) TestDispatcherInit() {
+ meta := TaskMeta{
+ Plan: importer.Plan{
+ CloudStorageURI: "",
+ },
+ }
+ bytes, err := json.Marshal(meta)
+ s.NoError(err)
+ dsp := importDispatcher{
+ BaseDispatcher: &dispatcher.BaseDispatcher{
+ Task: &proto.Task{
+ Meta: bytes,
+ },
+ },
+ }
+ s.NoError(dsp.Init())
+ s.False(dsp.Extension.(*ImportDispatcherExt).GlobalSort)
+
+ meta.Plan.CloudStorageURI = "s3://test"
+ bytes, err = json.Marshal(meta)
+ s.NoError(err)
+ dsp = importDispatcher{
+ BaseDispatcher: &dispatcher.BaseDispatcher{
+ Task: &proto.Task{
+ Meta: bytes,
+ },
+ },
+ }
+ s.NoError(dsp.Init())
+ s.True(dsp.Extension.(*ImportDispatcherExt).GlobalSort)
+}
+
+func (s *importIntoSuite) TestGetNextStep() {
+ task := &proto.Task{
+ Step: proto.StepInit,
+ }
+ ext := &ImportDispatcherExt{}
+ for _, nextStep := range []int64{StepImport, StepPostProcess, proto.StepDone} {
+ s.Equal(nextStep, ext.GetNextStep(nil, task))
+ task.Step = nextStep
+ }
+
+ task.Step = proto.StepInit
+ ext = &ImportDispatcherExt{GlobalSort: true}
+ for _, nextStep := range []int64{StepEncodeAndSort, StepMergeSort,
+ StepWriteAndIngest, StepPostProcess, proto.StepDone} {
+ s.Equal(nextStep, ext.GetNextStep(nil, task))
+ task.Step = nextStep
+ }
+}
+
+func (s *importIntoSuite) TestStr() {
+ s.Equal("init", stepStr(proto.StepInit))
+ s.Equal("import", stepStr(StepImport))
+ s.Equal("post-process", stepStr(StepPostProcess))
+ s.Equal("merge-sort", stepStr(StepMergeSort))
+ s.Equal("encode&sort", stepStr(StepEncodeAndSort))
+ s.Equal("write&ingest", stepStr(StepWriteAndIngest))
+ s.Equal("done", stepStr(proto.StepDone))
+ s.Equal("unknown", stepStr(111))
+}
+
+func (s *importIntoSuite) TestGetStepOfEncode() {
+ s.Equal(StepImport, getStepOfEncode(false))
+ s.Equal(StepEncodeAndSort, getStepOfEncode(true))
+}
diff --git a/disttask/importinto/dispatcher_testkit_test.go b/disttask/importinto/dispatcher_testkit_test.go
index e359fbe0853a2..c779d70f524ac 100644
--- a/disttask/importinto/dispatcher_testkit_test.go
+++ b/disttask/importinto/dispatcher_testkit_test.go
@@ -22,6 +22,8 @@ import (
"github.com/ngaut/pools"
"github.com/pingcap/errors"
+ "github.com/pingcap/failpoint"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/external"
"github.com/pingcap/tidb/disttask/framework/dispatcher"
"github.com/pingcap/tidb/disttask/framework/proto"
"github.com/pingcap/tidb/disttask/framework/storage"
@@ -35,7 +37,7 @@ import (
"github.com/tikv/client-go/v2/util"
)
-func TestDispatcherExt(t *testing.T) {
+func TestDispatcherExtLocalSort(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
pool := pools.NewResourcePool(func() (pools.Resource, error) {
@@ -89,9 +91,10 @@ func TestDispatcherExt(t *testing.T) {
// to import stage, job should be running
d := dsp.MockDispatcher(task)
ext := importinto.ImportDispatcherExt{}
- subtaskMetas, err := ext.OnNextSubtasksBatch(ctx, d, task)
+ subtaskMetas, err := ext.OnNextSubtasksBatch(ctx, d, task, ext.GetNextStep(d, task))
require.NoError(t, err)
require.Len(t, subtaskMetas, 1)
+ task.Step = ext.GetNextStep(d, task)
require.Equal(t, importinto.StepImport, task.Step)
gotJobInfo, err = importer.GetJob(ctx, conn, jobID, "root", true)
require.NoError(t, err)
@@ -99,7 +102,7 @@ func TestDispatcherExt(t *testing.T) {
// update task/subtask, and finish subtask, so we can go to next stage
subtasks := make([]*proto.Subtask, 0, len(subtaskMetas))
for _, m := range subtaskMetas {
- subtasks = append(subtasks, proto.NewSubtask(task.ID, task.Type, "", m))
+ subtasks = append(subtasks, proto.NewSubtask(task.Step, task.ID, task.Type, "", m))
}
_, err = manager.UpdateGlobalTaskAndAddSubTasks(task, subtasks, proto.TaskStatePending)
require.NoError(t, err)
@@ -109,20 +112,21 @@ func TestDispatcherExt(t *testing.T) {
require.NoError(t, manager.FinishSubtask(s.ID, []byte("{}")))
}
// to post-process stage, job should be running and in validating step
- task.Step++
- subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task)
+ subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, ext.GetNextStep(d, task))
require.NoError(t, err)
require.Len(t, subtaskMetas, 1)
+ task.Step = ext.GetNextStep(d, task)
require.Equal(t, importinto.StepPostProcess, task.Step)
gotJobInfo, err = importer.GetJob(ctx, conn, jobID, "root", true)
require.NoError(t, err)
require.Equal(t, "running", gotJobInfo.Status)
require.Equal(t, "validating", gotJobInfo.Step)
// on next stage, job should be finished
- task.Step++
- subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task)
+ subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, ext.GetNextStep(d, task))
require.NoError(t, err)
require.Len(t, subtaskMetas, 0)
+ task.Step = ext.GetNextStep(d, task)
+ require.Equal(t, proto.StepDone, task.Step)
gotJobInfo, err = importer.GetJob(ctx, conn, jobID, "root", true)
require.NoError(t, err)
require.Equal(t, "finished", gotJobInfo.Status)
@@ -137,10 +141,211 @@ func TestDispatcherExt(t *testing.T) {
task.Meta = bs
// Set step to StepPostProcess to skip the rollback sql.
task.Step = importinto.StepPostProcess
- require.NoError(t, importer.StartJob(ctx, conn, jobID))
+ require.NoError(t, importer.StartJob(ctx, conn, jobID, importer.JobStepImporting))
_, err = ext.OnErrStage(ctx, d, task, []error{errors.New("test")})
require.NoError(t, err)
gotJobInfo, err = importer.GetJob(ctx, conn, jobID, "root", true)
require.NoError(t, err)
require.Equal(t, "failed", gotJobInfo.Status)
}
+
+func TestDispatcherExtGlobalSort(t *testing.T) {
+ store := testkit.CreateMockStore(t)
+ tk := testkit.NewTestKit(t, store)
+ pool := pools.NewResourcePool(func() (pools.Resource, error) {
+ return tk.Session(), nil
+ }, 1, 1, time.Second)
+ defer pool.Close()
+ ctx := context.WithValue(context.Background(), "etcd", true)
+ mgr := storage.NewTaskManager(util.WithInternalSourceType(ctx, "taskManager"), pool)
+ storage.SetTaskManager(mgr)
+ dsp, err := dispatcher.NewManager(util.WithInternalSourceType(ctx, "dispatcher"), mgr, "host:port")
+ require.NoError(t, err)
+
+ // create job
+ conn := tk.Session().(sqlexec.SQLExecutor)
+ jobID, err := importer.CreateJob(ctx, conn, "test", "t", 1,
+ "root", &importer.ImportParameters{}, 123)
+ require.NoError(t, err)
+ gotJobInfo, err := importer.GetJob(ctx, conn, jobID, "root", true)
+ require.NoError(t, err)
+ require.Equal(t, "pending", gotJobInfo.Status)
+ logicalPlan := &importinto.LogicalPlan{
+ JobID: jobID,
+ Plan: importer.Plan{
+ Path: "gs://test-load/*.csv",
+ Format: "csv",
+ DBName: "test",
+ TableInfo: &model.TableInfo{
+ Name: model.NewCIStr("t"),
+ State: model.StatePublic,
+ },
+ DisableTiKVImportMode: true,
+ CloudStorageURI: "gs://sort-bucket",
+ InImportInto: true,
+ },
+ Stmt: `IMPORT INTO db.tb FROM 'gs://test-load/*.csv?endpoint=xxx'`,
+ EligibleInstances: []*infosync.ServerInfo{{ID: "1"}},
+ ChunkMap: map[int32][]importinto.Chunk{
+ 1: {{Path: "gs://test-load/1.csv"}},
+ 2: {{Path: "gs://test-load/2.csv"}},
+ },
+ }
+ bs, err := logicalPlan.ToTaskMeta()
+ require.NoError(t, err)
+ task := &proto.Task{
+ Type: proto.ImportInto,
+ Meta: bs,
+ Step: proto.StepInit,
+ State: proto.TaskStatePending,
+ StateUpdateTime: time.Now(),
+ }
+ manager, err := storage.GetTaskManager()
+ require.NoError(t, err)
+ taskMeta, err := json.Marshal(task)
+ require.NoError(t, err)
+ taskID, err := manager.AddNewGlobalTask(importinto.TaskKey(jobID), proto.ImportInto, 1, taskMeta)
+ require.NoError(t, err)
+ task.ID = taskID
+
+ // to encode-sort stage, job should be running
+ d := dsp.MockDispatcher(task)
+ ext := importinto.ImportDispatcherExt{
+ GlobalSort: true,
+ }
+ subtaskMetas, err := ext.OnNextSubtasksBatch(ctx, d, task, ext.GetNextStep(nil, task))
+ require.NoError(t, err)
+ require.Len(t, subtaskMetas, 2)
+ task.Step = ext.GetNextStep(nil, task)
+ require.Equal(t, importinto.StepEncodeAndSort, task.Step)
+ gotJobInfo, err = importer.GetJob(ctx, conn, jobID, "root", true)
+ require.NoError(t, err)
+ require.Equal(t, "running", gotJobInfo.Status)
+ require.Equal(t, "global-sorting", gotJobInfo.Step)
+ // update task/subtask, and finish subtask, so we can go to next stage
+ subtasks := make([]*proto.Subtask, 0, len(subtaskMetas))
+ for _, m := range subtaskMetas {
+ subtasks = append(subtasks, proto.NewSubtask(task.Step, task.ID, task.Type, "", m))
+ }
+ _, err = manager.UpdateGlobalTaskAndAddSubTasks(task, subtasks, proto.TaskStatePending)
+ require.NoError(t, err)
+ gotSubtasks, err := manager.GetSubtasksForImportInto(taskID, task.Step)
+ require.NoError(t, err)
+ sortStepMeta := &importinto.ImportStepMeta{
+ SortedDataMeta: &external.SortedKVMeta{
+ MinKey: []byte("ta"),
+ MaxKey: []byte("tc"),
+ TotalKVSize: 12,
+ DataFiles: []string{"gs://sort-bucket/data/1"},
+ StatFiles: []string{"gs://sort-bucket/data/1.stat"},
+ MultipleFilesStats: []external.MultipleFilesStat{
+ {
+ Filenames: [][2]string{
+ {"gs://sort-bucket/data/1", "gs://sort-bucket/data/1.stat"},
+ },
+ },
+ },
+ },
+ SortedIndexMetas: map[int64]*external.SortedKVMeta{
+ 1: {
+ MinKey: []byte("ia"),
+ MaxKey: []byte("ic"),
+ TotalKVSize: 12,
+ DataFiles: []string{"gs://sort-bucket/index/1"},
+ StatFiles: []string{"gs://sort-bucket/index/1.stat"},
+ MultipleFilesStats: []external.MultipleFilesStat{
+ {
+ Filenames: [][2]string{
+ {"gs://sort-bucket/index/1", "gs://sort-bucket/index/1.stat"},
+ },
+ },
+ },
+ },
+ },
+ }
+ sortStepMetaBytes, err := json.Marshal(sortStepMeta)
+ require.NoError(t, err)
+ for _, s := range gotSubtasks {
+ require.NoError(t, manager.FinishSubtask(s.ID, sortStepMetaBytes))
+ }
+
+ // to merge-sort stage
+ require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/importinto/forceMergeSort", `return("data")`))
+ t.Cleanup(func() {
+ require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/forceMergeSort"))
+ })
+ subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, ext.GetNextStep(nil, task))
+ require.NoError(t, err)
+ require.Len(t, subtaskMetas, 1)
+ task.Step = ext.GetNextStep(nil, task)
+ require.Equal(t, importinto.StepMergeSort, task.Step)
+ gotJobInfo, err = importer.GetJob(ctx, conn, jobID, "root", true)
+ require.NoError(t, err)
+ require.Equal(t, "running", gotJobInfo.Status)
+ require.Equal(t, "global-sorting", gotJobInfo.Step)
+ // update task/subtask, and finish subtask, so we can go to next stage
+ subtasks = make([]*proto.Subtask, 0, len(subtaskMetas))
+ for _, m := range subtaskMetas {
+ subtasks = append(subtasks, proto.NewSubtask(task.Step, task.ID, task.Type, "", m))
+ }
+ _, err = manager.UpdateGlobalTaskAndAddSubTasks(task, subtasks, proto.TaskStatePending)
+ require.NoError(t, err)
+ gotSubtasks, err = manager.GetSubtasksForImportInto(taskID, task.Step)
+ require.NoError(t, err)
+ mergeSortStepMeta := &importinto.MergeSortStepMeta{
+ KVGroup: "data",
+ SortedKVMeta: external.SortedKVMeta{
+ MinKey: []byte("ta"),
+ MaxKey: []byte("tc"),
+ TotalKVSize: 12,
+ DataFiles: []string{"gs://sort-bucket/data/1"},
+ StatFiles: []string{"gs://sort-bucket/data/1.stat"},
+ MultipleFilesStats: []external.MultipleFilesStat{
+ {
+ Filenames: [][2]string{
+ {"gs://sort-bucket/data/1", "gs://sort-bucket/data/1.stat"},
+ },
+ },
+ },
+ },
+ }
+ mergeSortStepMetaBytes, err := json.Marshal(mergeSortStepMeta)
+ require.NoError(t, err)
+ for _, s := range gotSubtasks {
+ require.NoError(t, manager.FinishSubtask(s.ID, mergeSortStepMetaBytes))
+ }
+
+ // to write-and-ingest stage
+ require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/importinto/mockWriteIngestSpecs", "return(true)"))
+ t.Cleanup(func() {
+ require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/mockWriteIngestSpecs"))
+ })
+ subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, ext.GetNextStep(nil, task))
+ require.NoError(t, err)
+ require.Len(t, subtaskMetas, 2)
+ task.Step = ext.GetNextStep(nil, task)
+ require.Equal(t, importinto.StepWriteAndIngest, task.Step)
+ gotJobInfo, err = importer.GetJob(ctx, conn, jobID, "root", true)
+ require.NoError(t, err)
+ require.Equal(t, "running", gotJobInfo.Status)
+ require.Equal(t, "importing", gotJobInfo.Step)
+ // on next stage, to post-process stage
+ subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, ext.GetNextStep(nil, task))
+ require.NoError(t, err)
+ require.Len(t, subtaskMetas, 1)
+ task.Step = ext.GetNextStep(nil, task)
+ require.Equal(t, importinto.StepPostProcess, task.Step)
+ gotJobInfo, err = importer.GetJob(ctx, conn, jobID, "root", true)
+ require.NoError(t, err)
+ require.Equal(t, "running", gotJobInfo.Status)
+ require.Equal(t, "validating", gotJobInfo.Step)
+ // next stage, done
+ subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task, ext.GetNextStep(nil, task))
+ require.NoError(t, err)
+ require.Len(t, subtaskMetas, 0)
+ task.Step = ext.GetNextStep(nil, task)
+ require.Equal(t, proto.StepDone, task.Step)
+ gotJobInfo, err = importer.GetJob(ctx, conn, jobID, "root", true)
+ require.NoError(t, err)
+ require.Equal(t, "finished", gotJobInfo.Status)
+}
diff --git a/disttask/importinto/encode_and_sort_operator.go b/disttask/importinto/encode_and_sort_operator.go
index fa5e32e7b48f5..164b2477a4856 100644
--- a/disttask/importinto/encode_and_sort_operator.go
+++ b/disttask/importinto/encode_and_sort_operator.go
@@ -16,16 +16,35 @@ package importinto
import (
"context"
+ "path"
+ "strconv"
+ "time"
+ "github.com/google/uuid"
"github.com/pingcap/errors"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/external"
"github.com/pingcap/tidb/disttask/operator"
+ "github.com/pingcap/tidb/executor/importer"
+ "github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/resourcemanager/pool/workerpool"
"github.com/pingcap/tidb/resourcemanager/util"
tidbutil "github.com/pingcap/tidb/util"
+ "github.com/pingcap/tidb/util/size"
"go.uber.org/atomic"
"go.uber.org/zap"
)
+const (
+ maxWaitDuration = 30 * time.Second
+
+ // We limit the memory usage of KV delivery to 1 GB per concurrency: data
+ // KV delivery gets external.DefaultMemSizeLimit, and the rest of the memory
+ // is shared by all index KV delivery.
+ // Note: this size is the memory taken by the KV pairs themselves, not the memory
+ // taken by Go; each KV has an additional 24*2 bytes of overhead for the Go slices.
+ indexKVTotalBufSize = size.GB - external.DefaultMemSizeLimit
+)
+
// encodeAndSortOperator is an operator that encodes and sorts data.
// this operator process data of a subtask, i.e. one engine, it contains a lot
// of data chunks, each chunk is a data file or part of it.
@@ -39,31 +58,36 @@ type encodeAndSortOperator struct {
ctx context.Context
cancel context.CancelFunc
- logger *zap.Logger
- errCh chan error
+ taskID, subtaskID int64
+ tableImporter *importer.TableImporter
+ sharedVars *SharedVars
+ logger *zap.Logger
+ errCh chan error
}
var _ operator.Operator = (*encodeAndSortOperator)(nil)
var _ operator.WithSource[*importStepMinimalTask] = (*encodeAndSortOperator)(nil)
var _ operator.WithSink[workerpool.None] = (*encodeAndSortOperator)(nil)
-func newEncodeAndSortOperator(ctx context.Context, concurrency int, logger *zap.Logger) *encodeAndSortOperator {
+func newEncodeAndSortOperator(ctx context.Context, executor *importStepExecutor,
+ sharedVars *SharedVars, subtaskID int64, indexMemorySizeLimit uint64) *encodeAndSortOperator {
subCtx, cancel := context.WithCancel(ctx)
op := &encodeAndSortOperator{
- ctx: subCtx,
- cancel: cancel,
- logger: logger,
- errCh: make(chan error),
+ ctx: subCtx,
+ cancel: cancel,
+ taskID: executor.taskID,
+ subtaskID: subtaskID,
+ tableImporter: executor.tableImporter,
+ sharedVars: sharedVars,
+ logger: executor.logger,
+ errCh: make(chan error),
}
pool := workerpool.NewWorkerPool(
"encodeAndSortOperator",
util.ImportInto,
- concurrency,
+ int(executor.taskMeta.Plan.ThreadCnt),
func() workerpool.Worker[*importStepMinimalTask, workerpool.None] {
- return &chunkWorker{
- ctx: subCtx,
- op: op,
- }
+ return newChunkWorker(ctx, op, indexMemorySizeLimit)
},
)
op.AsyncOperator = operator.NewAsyncOperator(subCtx, pool)
@@ -116,6 +140,44 @@ func (op *encodeAndSortOperator) Done() <-chan struct{} {
type chunkWorker struct {
ctx context.Context
op *encodeAndSortOperator
+
+ dataWriter *external.EngineWriter
+ indexWriter *importer.IndexRouteWriter
+}
+
+func newChunkWorker(ctx context.Context, op *encodeAndSortOperator, indexMemorySizeLimit uint64) *chunkWorker {
+ w := &chunkWorker{
+ ctx: ctx,
+ op: op,
+ }
+ if op.tableImporter.IsGlobalSort() {
+ // in case of a network partition, 2 nodes might run the same subtask.
+ workerUUID := uuid.New().String()
+ // sorted index kv storage path: /{taskID}/{subtaskID}/index/{indexID}/{workerID}
+ indexWriterFn := func(indexID int64) *external.Writer {
+ builder := external.NewWriterBuilder().
+ SetOnCloseFunc(func(summary *external.WriterSummary) {
+ op.sharedVars.mergeIndexSummary(indexID, summary)
+ }).SetMemorySizeLimit(indexMemorySizeLimit)
+ prefix := subtaskPrefix(op.taskID, op.subtaskID)
+ // writer id for index: index/{indexID}/{workerID}
+ writerID := path.Join("index", strconv.Itoa(int(indexID)), workerUUID)
+ writer := builder.Build(op.tableImporter.GlobalSortStore, prefix, writerID)
+ return writer
+ }
+
+ // sorted data kv storage path: /{taskID}/{subtaskID}/data/{workerID}
+ builder := external.NewWriterBuilder().
+ SetOnCloseFunc(op.sharedVars.mergeDataSummary)
+ prefix := subtaskPrefix(op.taskID, op.subtaskID)
+ // writer id for data: data/{workerID}
+ writerID := path.Join("data", workerUUID)
+ writer := builder.Build(op.tableImporter.GlobalSortStore, prefix, writerID)
+ w.dataWriter = external.NewEngineWriter(writer)
+
+ w.indexWriter = importer.NewIndexRouteWriter(op.logger, indexWriterFn)
+ }
+ return w
}
func (w *chunkWorker) HandleTask(task *importStepMinimalTask, _ func(workerpool.None)) {
@@ -125,10 +187,66 @@ func (w *chunkWorker) HandleTask(task *importStepMinimalTask, _ func(workerpool.
// we don't use the input send function, it makes workflow more complex
// we send result to errCh and handle it here.
executor := newImportMinimalTaskExecutor(task)
- if err := executor.Run(w.ctx); err != nil {
+ if err := executor.Run(w.ctx, w.dataWriter, w.indexWriter); err != nil {
w.op.onError(err)
}
}
-func (*chunkWorker) Close() {
+func (w *chunkWorker) Close() {
+ closeCtx := w.ctx
+ if closeCtx.Err() != nil {
+ // in case the context is canceled, we need to create a new context to close the writers.
+ newCtx, cancel := context.WithTimeout(context.Background(), maxWaitDuration)
+ closeCtx = newCtx
+ defer cancel()
+ }
+ if w.dataWriter != nil {
+ // Note: we cannot ignore the close error since we're writing to S3 or GCS;
+ // ignoring it might cause data loss. Same below.
+ if _, err := w.dataWriter.Close(closeCtx); err != nil {
+ w.op.onError(errors.Trace(err))
+ }
+ }
+ if w.indexWriter != nil {
+ if _, err := w.indexWriter.Close(closeCtx); err != nil {
+ w.op.onError(errors.Trace(err))
+ }
+ }
+}
+
+func subtaskPrefix(taskID, subtaskID int64) string {
+ return path.Join(strconv.Itoa(int(taskID)), strconv.Itoa(int(subtaskID)))
+}
+
+func getWriterMemorySizeLimit(plan *importer.Plan) uint64 {
+ // min(external.DefaultMemSizeLimit, indexKVTotalBufSize / num-of-index-that-gen-kv)
+ cnt := getNumOfIndexGenKV(plan.DesiredTableInfo)
+ limit := indexKVTotalBufSize
+ if cnt > 0 {
+ limit = limit / uint64(cnt)
+ }
+ if limit > external.DefaultMemSizeLimit {
+ limit = external.DefaultMemSizeLimit
+ }
+ return limit
+}
+
+func getNumOfIndexGenKV(tblInfo *model.TableInfo) int {
+ var count int
+ var nonClusteredPK bool
+ for _, idxInfo := range tblInfo.Indices {
+ // every public non-primary index generates index KVs
+ if idxInfo.State != model.StatePublic {
+ continue
+ }
+ if idxInfo.Primary && !tblInfo.HasClusteredIndex() {
+ nonClusteredPK = true
+ continue
+ }
+ count++
+ }
+ if nonClusteredPK {
+ count++
+ }
+ return count
}
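[editor's illustrative note, not part of the patch] getWriterMemorySizeLimit above caps each index writer at min(external.DefaultMemSizeLimit, indexKVTotalBufSize / numOfIndexGenKV). A worked example of that budget; the 256 MiB value for external.DefaultMemSizeLimit is an assumption made only for this illustration:

package main

import "fmt"

const (
	gb                  = uint64(1) << 30
	defaultMemSizeLimit = uint64(256) << 20 // assumed stand-in for external.DefaultMemSizeLimit
	indexKVTotalBufSize = gb - defaultMemSizeLimit
)

// writerMemorySizeLimit mirrors getWriterMemorySizeLimit: split the index budget
// across the indexes that generate KVs, but never exceed the per-writer default.
func writerMemorySizeLimit(numIndexGenKV int) uint64 {
	limit := indexKVTotalBufSize
	if numIndexGenKV > 0 {
		limit /= uint64(numIndexGenKV)
	}
	if limit > defaultMemSizeLimit {
		limit = defaultMemSizeLimit
	}
	return limit
}

func main() {
	for _, n := range []int{0, 1, 4, 5} {
		fmt.Printf("indexes generating KV: %d -> per-writer limit: %d MiB\n",
			n, writerMemorySizeLimit(n)/(1<<20))
	}
}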
diff --git a/disttask/importinto/encode_and_sort_operator_test.go b/disttask/importinto/encode_and_sort_operator_test.go
index 3a11043750bf8..3aa2ee0377732 100644
--- a/disttask/importinto/encode_and_sort_operator_test.go
+++ b/disttask/importinto/encode_and_sort_operator_test.go
@@ -24,8 +24,16 @@ import (
"time"
"github.com/pingcap/errors"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/external"
+ "github.com/pingcap/tidb/ddl"
"github.com/pingcap/tidb/disttask/importinto/mock"
"github.com/pingcap/tidb/disttask/operator"
+ "github.com/pingcap/tidb/executor/importer"
+ "github.com/pingcap/tidb/parser"
+ "github.com/pingcap/tidb/parser/ast"
+ "github.com/pingcap/tidb/parser/model"
+ utilmock "github.com/pingcap/tidb/util/mock"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"go.uber.org/zap"
@@ -54,15 +62,32 @@ func TestEncodeAndSortOperator(t *testing.T) {
return executor
}
+ executorForParam := &importStepExecutor{
+ taskID: 1,
+ taskMeta: &TaskMeta{
+ Plan: importer.Plan{
+ ThreadCnt: 2,
+ },
+ },
+ tableImporter: &importer.TableImporter{
+ LoadDataController: &importer.LoadDataController{
+ Plan: &importer.Plan{
+ CloudStorageURI: "",
+ },
+ },
+ },
+ logger: logger,
+ }
+
source := operator.NewSimpleDataChannel(make(chan *importStepMinimalTask))
- op := newEncodeAndSortOperator(context.Background(), 3, logger)
+ op := newEncodeAndSortOperator(context.Background(), executorForParam, nil, 3, 0)
op.SetSource(source)
require.NoError(t, op.Open())
require.Greater(t, len(op.String()), 0)
// cancel on error
mockErr := errors.New("mock err")
- executor.EXPECT().Run(gomock.Any()).Return(mockErr)
+ executor.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockErr)
source.Channel() <- &importStepMinimalTask{}
require.Eventually(t, func() bool {
return op.hasError()
@@ -75,7 +100,7 @@ func TestEncodeAndSortOperator(t *testing.T) {
// cancel on error and log other errors
mockErr2 := errors.New("mock err 2")
source = operator.NewSimpleDataChannel(make(chan *importStepMinimalTask))
- op = newEncodeAndSortOperator(context.Background(), 2, logger)
+ op = newEncodeAndSortOperator(context.Background(), executorForParam, nil, 2, 0)
op.SetSource(source)
executor1 := mock.NewMockMiniTaskExecutor(ctrl)
executor2 := mock.NewMockMiniTaskExecutor(ctrl)
@@ -89,20 +114,22 @@ func TestEncodeAndSortOperator(t *testing.T) {
var wg sync.WaitGroup
wg.Add(2)
// wait until 2 executor start running, else workerpool will be cancelled.
- executor1.EXPECT().Run(gomock.Any()).DoAndReturn(func(context.Context) error {
- wg.Done()
- wg.Wait()
- return mockErr2
- })
- executor2.EXPECT().Run(gomock.Any()).DoAndReturn(func(context.Context) error {
- wg.Done()
- wg.Wait()
- // wait error in executor1 has been processed
- require.Eventually(t, func() bool {
- return op.hasError()
- }, 3*time.Second, 300*time.Millisecond)
- return errors.New("mock error should be logged")
- })
+ executor1.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(context.Context, backend.EngineWriter, backend.EngineWriter) error {
+ wg.Done()
+ wg.Wait()
+ return mockErr2
+ })
+ executor2.EXPECT().Run(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
+ func(context.Context, backend.EngineWriter, backend.EngineWriter) error {
+ wg.Done()
+ wg.Wait()
+ // wait until the error from executor1 has been processed
+ require.Eventually(t, func() bool {
+ return op.hasError()
+ }, 3*time.Second, 300*time.Millisecond)
+ return errors.New("mock error should be logged")
+ })
require.NoError(t, op.Open())
// send 2 tasks
source.Channel() <- &importStepMinimalTask{}
@@ -115,3 +142,62 @@ func TestEncodeAndSortOperator(t *testing.T) {
require.NoError(t, err)
require.Contains(t, string(content), "mock error should be logged")
}
+
+func TestGetWriterMemorySizeLimit(t *testing.T) {
+ cases := []struct {
+ createSQL string
+ numOfIndexGenKV int
+ writerMemorySizeLimit uint64
+ }{
+ {
+ createSQL: "create table t (a int)",
+ numOfIndexGenKV: 0,
+ writerMemorySizeLimit: external.DefaultMemSizeLimit,
+ },
+ {
+ createSQL: "create table t (a int primary key clustered)",
+ numOfIndexGenKV: 0,
+ writerMemorySizeLimit: external.DefaultMemSizeLimit,
+ },
+ {
+ createSQL: "create table t (a int primary key nonclustered)",
+ numOfIndexGenKV: 1,
+ writerMemorySizeLimit: external.DefaultMemSizeLimit,
+ },
+ {
+ createSQL: "create table t (a int primary key clustered, b int, key(b))",
+ numOfIndexGenKV: 1,
+ writerMemorySizeLimit: external.DefaultMemSizeLimit,
+ },
+ {
+ createSQL: "create table t (a int primary key clustered, b int, key(b), key(a,b))",
+ numOfIndexGenKV: 2,
+ writerMemorySizeLimit: external.DefaultMemSizeLimit,
+ },
+ {
+ createSQL: "create table t (a int primary key clustered, b int, c int, key(b,c), unique(b), unique(c), key(a,b))",
+ numOfIndexGenKV: 4,
+ writerMemorySizeLimit: indexKVTotalBufSize / 4,
+ },
+ {
+ createSQL: "create table t (a int, b int, c int, primary key(a,b,c) nonclustered, key(b,c), unique(b), unique(c), key(a,b))",
+ numOfIndexGenKV: 5,
+ writerMemorySizeLimit: indexKVTotalBufSize / 5,
+ },
+ }
+
+ for _, c := range cases {
+ p := parser.New()
+ node, err := p.ParseOneStmt(c.createSQL, "", "")
+ require.NoError(t, err)
+ sctx := utilmock.NewContext()
+ info, err := ddl.MockTableInfo(sctx, node.(*ast.CreateTableStmt), 1)
+ require.NoError(t, err)
+ info.State = model.StatePublic
+
+ require.Equal(t, c.numOfIndexGenKV, getNumOfIndexGenKV(info), c.createSQL)
+ require.Equal(t, c.writerMemorySizeLimit, getWriterMemorySizeLimit(&importer.Plan{
+ DesiredTableInfo: info,
+ }), c.createSQL)
+ }
+}
diff --git a/disttask/importinto/job.go b/disttask/importinto/job.go
index 2f6917cb4bba8..e82524462a83a 100644
--- a/disttask/importinto/job.go
+++ b/disttask/importinto/job.go
@@ -245,7 +245,7 @@ func getTaskMeta(jobID int64) (*TaskMeta, error) {
}
var taskMeta TaskMeta
if err := json.Unmarshal(globalTask.Meta, &taskMeta); err != nil {
- return nil, err
+ return nil, errors.Trace(err)
}
return &taskMeta, nil
}
@@ -258,24 +258,42 @@ func GetTaskImportedRows(jobID int64) (uint64, error) {
return 0, err
}
taskKey := TaskKey(jobID)
- globalTask, err := globalTaskManager.GetGlobalTaskByKey(taskKey)
+ task, err := globalTaskManager.GetGlobalTaskByKey(taskKey)
if err != nil {
return 0, err
}
- if globalTask == nil {
+ if task == nil {
return 0, errors.Errorf("cannot find global task with key %s", taskKey)
}
- subtasks, err := globalTaskManager.GetSubtasksForImportInto(globalTask.ID, StepImport)
- if err != nil {
- return 0, err
+ taskMeta := TaskMeta{}
+ if err = json.Unmarshal(task.Meta, &taskMeta); err != nil {
+ return 0, errors.Trace(err)
}
var importedRows uint64
- for _, subtask := range subtasks {
- var subtaskMeta ImportStepMeta
- if err2 := json.Unmarshal(subtask.Meta, &subtaskMeta); err2 != nil {
- return 0, err2
+ if taskMeta.Plan.CloudStorageURI == "" {
+ subtasks, err := globalTaskManager.GetSubtasksForImportInto(task.ID, StepImport)
+ if err != nil {
+ return 0, err
+ }
+ for _, subtask := range subtasks {
+ var subtaskMeta ImportStepMeta
+ if err2 := json.Unmarshal(subtask.Meta, &subtaskMeta); err2 != nil {
+ return 0, errors.Trace(err2)
+ }
+ importedRows += subtaskMeta.Result.LoadedRowCnt
+ }
+ } else {
+ subtasks, err := globalTaskManager.GetSubtasksForImportInto(task.ID, StepWriteAndIngest)
+ if err != nil {
+ return 0, err
+ }
+ for _, subtask := range subtasks {
+ var subtaskMeta WriteIngestStepMeta
+ if err2 := json.Unmarshal(subtask.Meta, &subtaskMeta); err2 != nil {
+ return 0, errors.Trace(err2)
+ }
+ importedRows += subtaskMeta.Result.LoadedRowCnt
}
- importedRows += subtaskMeta.Result.LoadedRowCnt
}
return importedRows, nil
}
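[editor's illustrative note, not part of the patch] GetTaskImportedRows above now picks where to read the row count from based on the sort mode: local-sort tasks store it on the import-step subtasks, global-sort tasks on the write&ingest-step subtasks. A simplified sketch of that decision with stand-in types (the numbers match the new job_testkit test below):

package main

import "fmt"

// result is a stand-in for the Result field shared by ImportStepMeta and WriteIngestStepMeta.
type result struct{ LoadedRowCnt uint64 }

// importedRows mirrors the branching in GetTaskImportedRows: a non-empty
// CloudStorageURI means global sort, so the write&ingest subtask metas hold the count.
func importedRows(cloudStorageURI string, importStep, writeIngestStep []result) uint64 {
	metas := importStep
	if cloudStorageURI != "" {
		metas = writeIngestStep
	}
	var total uint64
	for _, m := range metas {
		total += m.LoadedRowCnt
	}
	return total
}

func main() {
	fmt.Println(importedRows("", []result{{1}, {2}}, nil))              // local sort: 3
	fmt.Println(importedRows("s3://bucket", nil, []result{{11}, {22}})) // global sort: 33
}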
diff --git a/disttask/importinto/job_testkit_test.go b/disttask/importinto/job_testkit_test.go
new file mode 100644
index 0000000000000..3638feb2d59d6
--- /dev/null
+++ b/disttask/importinto/job_testkit_test.go
@@ -0,0 +1,107 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package importinto_test
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+ "time"
+
+ "github.com/ngaut/pools"
+ "github.com/pingcap/tidb/disttask/framework/proto"
+ "github.com/pingcap/tidb/disttask/framework/storage"
+ "github.com/pingcap/tidb/disttask/importinto"
+ "github.com/pingcap/tidb/executor/importer"
+ "github.com/pingcap/tidb/testkit"
+ "github.com/stretchr/testify/require"
+ "github.com/tikv/client-go/v2/util"
+)
+
+func TestGetTaskImportedRows(t *testing.T) {
+ store := testkit.CreateMockStore(t)
+ tk := testkit.NewTestKit(t, store)
+ pool := pools.NewResourcePool(func() (pools.Resource, error) {
+ return tk.Session(), nil
+ }, 1, 1, time.Second)
+ defer pool.Close()
+ ctx := context.WithValue(context.Background(), "etcd", true)
+ mgr := storage.NewTaskManager(util.WithInternalSourceType(ctx, "taskManager"), pool)
+ storage.SetTaskManager(mgr)
+ manager, err := storage.GetTaskManager()
+ require.NoError(t, err)
+
+ // local sort
+ taskMeta := importinto.TaskMeta{
+ Plan: importer.Plan{},
+ }
+ bytes, err := json.Marshal(taskMeta)
+ require.NoError(t, err)
+ taskID, err := manager.AddNewGlobalTask(importinto.TaskKey(111), proto.ImportInto, 1, bytes)
+ require.NoError(t, err)
+ importStepMetas := []*importinto.ImportStepMeta{
+ {
+ Result: importinto.Result{
+ LoadedRowCnt: 1,
+ },
+ },
+ {
+ Result: importinto.Result{
+ LoadedRowCnt: 2,
+ },
+ },
+ }
+ for _, m := range importStepMetas {
+ bytes, err := json.Marshal(m)
+ require.NoError(t, err)
+ require.NoError(t, manager.AddNewSubTask(taskID, importinto.StepImport,
+ "", bytes, proto.ImportInto, false))
+ }
+ rows, err := importinto.GetTaskImportedRows(111)
+ require.NoError(t, err)
+ require.Equal(t, uint64(3), rows)
+
+ // global sort
+ taskMeta = importinto.TaskMeta{
+ Plan: importer.Plan{
+ CloudStorageURI: "s3://test-bucket/test-path",
+ },
+ }
+ bytes, err = json.Marshal(taskMeta)
+ require.NoError(t, err)
+ taskID, err = manager.AddNewGlobalTask(importinto.TaskKey(222), proto.ImportInto, 1, bytes)
+ require.NoError(t, err)
+ ingestStepMetas := []*importinto.WriteIngestStepMeta{
+ {
+ Result: importinto.Result{
+ LoadedRowCnt: 11,
+ },
+ },
+ {
+ Result: importinto.Result{
+ LoadedRowCnt: 22,
+ },
+ },
+ }
+ for _, m := range ingestStepMetas {
+ bytes, err := json.Marshal(m)
+ require.NoError(t, err)
+ require.NoError(t, manager.AddNewSubTask(taskID, importinto.StepWriteAndIngest,
+ "", bytes, proto.ImportInto, false))
+ }
+ rows, err = importinto.GetTaskImportedRows(222)
+ require.NoError(t, err)
+ require.Equal(t, uint64(33), rows)
+}
diff --git a/disttask/importinto/mock/BUILD.bazel b/disttask/importinto/mock/BUILD.bazel
index 9a780c155376d..902ed4332b31a 100644
--- a/disttask/importinto/mock/BUILD.bazel
+++ b/disttask/importinto/mock/BUILD.bazel
@@ -5,5 +5,8 @@ go_library(
srcs = ["import_mock.go"],
importpath = "github.com/pingcap/tidb/disttask/importinto/mock",
visibility = ["//visibility:public"],
- deps = ["@org_uber_go_mock//gomock"],
+ deps = [
+ "//br/pkg/lightning/backend",
+ "@org_uber_go_mock//gomock",
+ ],
)
diff --git a/disttask/importinto/mock/import_mock.go b/disttask/importinto/mock/import_mock.go
index 63cce354f2db7..e5db685538d32 100644
--- a/disttask/importinto/mock/import_mock.go
+++ b/disttask/importinto/mock/import_mock.go
@@ -8,6 +8,7 @@ import (
context "context"
reflect "reflect"
+ backend "github.com/pingcap/tidb/br/pkg/lightning/backend"
gomock "go.uber.org/mock/gomock"
)
@@ -35,15 +36,15 @@ func (m *MockMiniTaskExecutor) EXPECT() *MockMiniTaskExecutorMockRecorder {
}
// Run mocks base method.
-func (m *MockMiniTaskExecutor) Run(arg0 context.Context) error {
+func (m *MockMiniTaskExecutor) Run(arg0 context.Context, arg1, arg2 backend.EngineWriter) error {
m.ctrl.T.Helper()
- ret := m.ctrl.Call(m, "Run", arg0)
+ ret := m.ctrl.Call(m, "Run", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// Run indicates an expected call of Run.
-func (mr *MockMiniTaskExecutorMockRecorder) Run(arg0 interface{}) *gomock.Call {
+func (mr *MockMiniTaskExecutorMockRecorder) Run(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
- return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockMiniTaskExecutor)(nil).Run), arg0)
+ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockMiniTaskExecutor)(nil).Run), arg0, arg1, arg2)
}
diff --git a/disttask/importinto/planner.go b/disttask/importinto/planner.go
index 5eb2483dbf188..6713070d165ef 100644
--- a/disttask/importinto/planner.go
+++ b/disttask/importinto/planner.go
@@ -16,17 +16,28 @@ package importinto
import (
"context"
+ "encoding/hex"
"encoding/json"
+ "math"
+ "strconv"
+ "github.com/pingcap/errors"
+ "github.com/pingcap/failpoint"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/external"
"github.com/pingcap/tidb/br/pkg/lightning/backend/kv"
"github.com/pingcap/tidb/br/pkg/lightning/common"
+ "github.com/pingcap/tidb/br/pkg/lightning/config"
verify "github.com/pingcap/tidb/br/pkg/lightning/verification"
+ "github.com/pingcap/tidb/br/pkg/storage"
"github.com/pingcap/tidb/disttask/framework/planner"
"github.com/pingcap/tidb/domain/infosync"
"github.com/pingcap/tidb/executor/importer"
+ tidbkv "github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/meta/autoid"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/table/tables"
+ "github.com/pingcap/tidb/util/logutil"
+ "go.uber.org/zap"
)
var (
@@ -60,7 +71,7 @@ func (p *LogicalPlan) ToTaskMeta() ([]byte, error) {
func (p *LogicalPlan) FromTaskMeta(bs []byte) error {
var taskMeta TaskMeta
if err := json.Unmarshal(bs, &taskMeta); err != nil {
- return err
+ return errors.Trace(err)
}
p.JobID = taskMeta.JobID
p.Plan = taskMeta.Plan
@@ -74,50 +85,68 @@ func (p *LogicalPlan) FromTaskMeta(bs []byte) error {
func (p *LogicalPlan) ToPhysicalPlan(planCtx planner.PlanCtx) (*planner.PhysicalPlan, error) {
physicalPlan := &planner.PhysicalPlan{}
inputLinks := make([]planner.LinkSpec, 0)
- // physical plan only needs to be generated once.
- // However, our current implementation requires generating it for each step.
- // Only the first step needs to generate import specs.
- // This is a fast path to bypass generating import spec multiple times (as we need to access the source data).
- if len(planCtx.PreviousSubtaskMetas) == 0 {
- importSpecs, err := generateImportSpecs(planCtx.Ctx, p)
- if err != nil {
- return nil, err
- }
-
- for i, importSpec := range importSpecs {
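+	// addSpecs adds one processor per pipeline spec, each outputting to the
+	// processor with ID len(specs), and collects a link to each of them into
+	// inputLinks.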
+ addSpecs := func(specs []planner.PipelineSpec) {
+ for i, spec := range specs {
physicalPlan.AddProcessor(planner.ProcessorSpec{
ID: i,
- Pipeline: importSpec,
+ Pipeline: spec,
Output: planner.OutputSpec{
Links: []planner.LinkSpec{
{
- ProcessorID: len(importSpecs),
+ ProcessorID: len(specs),
},
},
},
- Step: StepImport,
+ Step: planCtx.NextTaskStep,
})
inputLinks = append(inputLinks, planner.LinkSpec{
ProcessorID: i,
})
}
}
+ // physical plan only needs to be generated once.
+ // However, our current implementation requires generating it for each step.
+	// we only generate the plans needed for the next step.
+ switch planCtx.NextTaskStep {
+ case StepImport, StepEncodeAndSort:
+ specs, err := generateImportSpecs(planCtx.Ctx, p)
+ if err != nil {
+ return nil, err
+ }
+
+ addSpecs(specs)
+ case StepMergeSort:
+ specs, err := generateMergeSortSpecs(planCtx)
+ if err != nil {
+ return nil, err
+ }
- physicalPlan.AddProcessor(planner.ProcessorSpec{
- ID: len(inputLinks),
- Input: planner.InputSpec{
- ColumnTypes: []byte{
- // Checksum_crc64_xor, Total_kvs, Total_bytes, ReadRowCnt, LoadedRowCnt, ColSizeMap
- mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeJSON,
+ addSpecs(specs)
+ case StepWriteAndIngest:
+ specs, err := generateWriteIngestSpecs(planCtx, p)
+ if err != nil {
+ return nil, err
+ }
+
+ addSpecs(specs)
+ case StepPostProcess:
+ physicalPlan.AddProcessor(planner.ProcessorSpec{
+ ID: len(inputLinks),
+ Input: planner.InputSpec{
+ ColumnTypes: []byte{
+ // Checksum_crc64_xor, Total_kvs, Total_bytes, ReadRowCnt, LoadedRowCnt, ColSizeMap
+ mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeJSON,
+ },
+ Links: inputLinks,
},
- Links: inputLinks,
- },
- Pipeline: &PostProcessSpec{
- Schema: p.Plan.DBName,
- Table: p.Plan.TableInfo.Name.L,
- },
- Step: StepPostProcess,
- })
+ Pipeline: &PostProcessSpec{
+ Schema: p.Plan.DBName,
+ Table: p.Plan.TableInfo.Name.L,
+ },
+ Step: planCtx.NextTaskStep,
+ })
+ }
+
return physicalPlan, nil
}
@@ -137,6 +166,26 @@ func (s *ImportSpec) ToSubtaskMeta(planner.PlanCtx) ([]byte, error) {
return json.Marshal(importStepMeta)
}
+// WriteIngestSpec is the specification of a write-ingest pipeline.
+type WriteIngestSpec struct {
+ *WriteIngestStepMeta
+}
+
+// ToSubtaskMeta converts the write-ingest spec to subtask meta.
+func (s *WriteIngestSpec) ToSubtaskMeta(planner.PlanCtx) ([]byte, error) {
+ return json.Marshal(s.WriteIngestStepMeta)
+}
+
+// MergeSortSpec is the specification of a merge-sort pipeline.
+type MergeSortSpec struct {
+ *MergeSortStepMeta
+}
+
+// ToSubtaskMeta converts the merge-sort spec to subtask meta.
+func (s *MergeSortSpec) ToSubtaskMeta(planner.PlanCtx) ([]byte, error) {
+ return json.Marshal(s.MergeSortStepMeta)
+}
+
// PostProcessSpec is the specification of a post process pipeline.
type PostProcessSpec struct {
// for checksum request
@@ -146,11 +195,12 @@ type PostProcessSpec struct {
// ToSubtaskMeta converts the post process spec to subtask meta.
func (*PostProcessSpec) ToSubtaskMeta(planCtx planner.PlanCtx) ([]byte, error) {
+ encodeStep := getStepOfEncode(planCtx.GlobalSort)
subtaskMetas := make([]*ImportStepMeta, 0, len(planCtx.PreviousSubtaskMetas))
- for _, bs := range planCtx.PreviousSubtaskMetas {
+ for _, bs := range planCtx.PreviousSubtaskMetas[encodeStep] {
var subtaskMeta ImportStepMeta
if err := json.Unmarshal(bs, &subtaskMeta); err != nil {
- return nil, err
+ return nil, errors.Trace(err)
}
subtaskMetas = append(subtaskMetas, &subtaskMeta)
}
@@ -177,30 +227,34 @@ func (*PostProcessSpec) ToSubtaskMeta(planCtx planner.PlanCtx) ([]byte, error) {
return json.Marshal(postProcessStepMeta)
}
-func buildController(p *LogicalPlan) (*importer.LoadDataController, error) {
+func buildControllerForPlan(p *LogicalPlan) (*importer.LoadDataController, error) {
+ return buildController(&p.Plan, p.Stmt)
+}
+
+func buildController(plan *importer.Plan, stmt string) (*importer.LoadDataController, error) {
idAlloc := kv.NewPanickingAllocators(0)
- tbl, err := tables.TableFromMeta(idAlloc, p.Plan.TableInfo)
+ tbl, err := tables.TableFromMeta(idAlloc, plan.TableInfo)
if err != nil {
return nil, err
}
- astArgs, err := importer.ASTArgsFromStmt(p.Stmt)
+ astArgs, err := importer.ASTArgsFromStmt(stmt)
if err != nil {
return nil, err
}
- controller, err := importer.NewLoadDataController(&p.Plan, tbl, astArgs)
+ controller, err := importer.NewLoadDataController(plan, tbl, astArgs)
if err != nil {
return nil, err
}
return controller, nil
}
-func generateImportSpecs(ctx context.Context, p *LogicalPlan) ([]*ImportSpec, error) {
+func generateImportSpecs(ctx context.Context, p *LogicalPlan) ([]planner.PipelineSpec, error) {
var chunkMap map[int32][]Chunk
if len(p.ChunkMap) > 0 {
chunkMap = p.ChunkMap
} else {
- controller, err2 := buildController(p)
+ controller, err2 := buildControllerForPlan(p)
if err2 != nil {
return nil, err2
}
@@ -214,7 +268,7 @@ func generateImportSpecs(ctx context.Context, p *LogicalPlan) ([]*ImportSpec, er
}
chunkMap = toChunkMap(engineCheckpoints)
}
- importSpecs := make([]*ImportSpec, 0, len(chunkMap))
+ importSpecs := make([]planner.PipelineSpec, 0, len(chunkMap))
for id := range chunkMap {
if id == common.IndexEngineID {
continue
@@ -228,3 +282,224 @@ func generateImportSpecs(ctx context.Context, p *LogicalPlan) ([]*ImportSpec, er
}
return importSpecs, nil
}
+
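+// skipMergeSort returns true when the sorted kv files of the group overlap so
+// little (no more than MergeSortOverlapThreshold) that StepWriteAndIngest can
+// consume them directly; the forceMergeSort failpoint can override it in tests.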
+func skipMergeSort(kvGroup string, stats []external.MultipleFilesStat) bool {
+ failpoint.Inject("forceMergeSort", func(val failpoint.Value) {
+ in := val.(string)
+ if in == kvGroup || in == "*" {
+ failpoint.Return(false)
+ }
+ })
+ return external.GetMaxOverlappingTotal(stats) <= external.MergeSortOverlapThreshold
+}
+
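+// generateMergeSortSpecs creates one merge-sort subtask per batch of
+// MergeSortFileCountStep data files, for every kv group that cannot skip
+// merge sort.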
+func generateMergeSortSpecs(planCtx planner.PlanCtx) ([]planner.PipelineSpec, error) {
+ step := external.MergeSortFileCountStep
+ result := make([]planner.PipelineSpec, 0, 16)
+ kvMetas, err := getSortedKVMetasOfEncodeStep(planCtx.PreviousSubtaskMetas[StepEncodeAndSort])
+ if err != nil {
+ return nil, err
+ }
+ for kvGroup, kvMeta := range kvMetas {
+ length := len(kvMeta.DataFiles)
+ if skipMergeSort(kvGroup, kvMeta.MultipleFilesStats) {
+ logutil.Logger(planCtx.Ctx).Info("skip merge sort for kv group",
+ zap.Int64("task-id", planCtx.TaskID),
+ zap.String("kv-group", kvGroup))
+ continue
+ }
+ for start := 0; start < length; start += step {
+ end := start + step
+ if end > length {
+ end = length
+ }
+ result = append(result, &MergeSortSpec{
+ MergeSortStepMeta: &MergeSortStepMeta{
+ KVGroup: kvGroup,
+ DataFiles: kvMeta.DataFiles[start:end],
+ },
+ })
+ }
+ }
+ return result, nil
+}
+
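+// generateWriteIngestSpecs splits the sorted kv of each group into range
+// groups using a RangeSplitter, and creates one write-and-ingest subtask per
+// range group.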
+func generateWriteIngestSpecs(planCtx planner.PlanCtx, p *LogicalPlan) ([]planner.PipelineSpec, error) {
+ ctx := planCtx.Ctx
+ controller, err2 := buildControllerForPlan(p)
+ if err2 != nil {
+ return nil, err2
+ }
+ if err2 = controller.InitDataStore(ctx); err2 != nil {
+ return nil, err2
+ }
+ // kvMetas contains data kv meta and all index kv metas.
+	// each kvMeta will be split into multiple range groups individually,
+ // i.e. data and index kv will NOT be in the same subtask.
+ kvMetas, err := getSortedKVMetasForIngest(planCtx)
+ if err != nil {
+ return nil, err
+ }
+ failpoint.Inject("mockWriteIngestSpecs", func() {
+ failpoint.Return([]planner.PipelineSpec{
+ &WriteIngestSpec{
+ WriteIngestStepMeta: &WriteIngestStepMeta{
+ KVGroup: dataKVGroup,
+ },
+ },
+ &WriteIngestSpec{
+ WriteIngestStepMeta: &WriteIngestStepMeta{
+ KVGroup: "1",
+ },
+ },
+ }, nil)
+ })
+ specs := make([]planner.PipelineSpec, 0, 16)
+ for kvGroup, kvMeta := range kvMetas {
+ splitter, err1 := getRangeSplitter(ctx, controller.GlobalSortStore, kvMeta)
+ if err1 != nil {
+ return nil, err1
+ }
+
+ err1 = func() error {
+ defer func() {
+ err2 := splitter.Close()
+ if err2 != nil {
+ logutil.Logger(ctx).Warn("close range splitter failed", zap.Error(err2))
+ }
+ }()
+ startKey := tidbkv.Key(kvMeta.MinKey)
+ var endKey tidbkv.Key
+ for {
+ endKeyOfGroup, dataFiles, statFiles, rangeSplitKeys, err2 := splitter.SplitOneRangesGroup()
+ if err2 != nil {
+ return err2
+ }
+ if len(endKeyOfGroup) == 0 {
+ endKey = tidbkv.Key(kvMeta.MaxKey).Next()
+ } else {
+ endKey = tidbkv.Key(endKeyOfGroup).Clone()
+ }
+ logutil.Logger(ctx).Info("kv range as subtask",
+ zap.String("startKey", hex.EncodeToString(startKey)),
+ zap.String("endKey", hex.EncodeToString(endKey)))
+ if startKey.Cmp(endKey) >= 0 {
+ return errors.Errorf("invalid kv range, startKey: %s, endKey: %s",
+ hex.EncodeToString(startKey), hex.EncodeToString(endKey))
+ }
+ // each subtask will write and ingest one range group
+ m := &WriteIngestStepMeta{
+ KVGroup: kvGroup,
+ SortedKVMeta: external.SortedKVMeta{
+ MinKey: startKey,
+ MaxKey: endKey,
+ DataFiles: dataFiles,
+ StatFiles: statFiles,
+ // this is actually an estimate, we don't know the exact size of the data
+ TotalKVSize: uint64(config.DefaultBatchSize),
+ },
+ RangeSplitKeys: rangeSplitKeys,
+ }
+ specs = append(specs, &WriteIngestSpec{m})
+
+ startKey = endKey
+ if len(endKeyOfGroup) == 0 {
+ break
+ }
+ }
+ return nil
+ }()
+ if err1 != nil {
+ return nil, err1
+ }
+ }
+ return specs, nil
+}
+
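+// getSortedKVMetasOfEncodeStep merges the sorted kv metas of all encode
+// subtasks, keyed by kv group: dataKVGroup for data kv, and the index id
+// (as a string) for index kv.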
+func getSortedKVMetasOfEncodeStep(subTaskMetas [][]byte) (map[string]*external.SortedKVMeta, error) {
+ dataKVMeta := &external.SortedKVMeta{}
+ indexKVMetas := make(map[int64]*external.SortedKVMeta)
+ for _, subTaskMeta := range subTaskMetas {
+ var stepMeta ImportStepMeta
+ err := json.Unmarshal(subTaskMeta, &stepMeta)
+ if err != nil {
+ return nil, errors.Trace(err)
+ }
+ dataKVMeta.Merge(stepMeta.SortedDataMeta)
+ for indexID, sortedIndexMeta := range stepMeta.SortedIndexMetas {
+ if item, ok := indexKVMetas[indexID]; !ok {
+ indexKVMetas[indexID] = sortedIndexMeta
+ } else {
+ item.Merge(sortedIndexMeta)
+ }
+ }
+ }
+ res := make(map[string]*external.SortedKVMeta, 1+len(indexKVMetas))
+ res[dataKVGroup] = dataKVMeta
+ for indexID, item := range indexKVMetas {
+ res[strconv.Itoa(int(indexID))] = item
+ }
+ return res, nil
+}
+
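+// getSortedKVMetasOfMergeStep merges the sorted kv metas of all merge-sort
+// subtasks, keyed by kv group.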
+func getSortedKVMetasOfMergeStep(subTaskMetas [][]byte) (map[string]*external.SortedKVMeta, error) {
+ result := make(map[string]*external.SortedKVMeta, len(subTaskMetas))
+ for _, subTaskMeta := range subTaskMetas {
+ var stepMeta MergeSortStepMeta
+ err := json.Unmarshal(subTaskMeta, &stepMeta)
+ if err != nil {
+ return nil, errors.Trace(err)
+ }
+ meta, ok := result[stepMeta.KVGroup]
+ if !ok {
+ result[stepMeta.KVGroup] = &stepMeta.SortedKVMeta
+ continue
+ }
+ meta.Merge(&stepMeta.SortedKVMeta)
+ }
+ return result, nil
+}
+
+func getSortedKVMetasForIngest(planCtx planner.PlanCtx) (map[string]*external.SortedKVMeta, error) {
+ kvMetasOfMergeSort, err := getSortedKVMetasOfMergeStep(planCtx.PreviousSubtaskMetas[StepMergeSort])
+ if err != nil {
+ return nil, err
+ }
+ kvMetasOfEncodeStep, err := getSortedKVMetasOfEncodeStep(planCtx.PreviousSubtaskMetas[StepEncodeAndSort])
+ if err != nil {
+ return nil, err
+ }
+ for kvGroup, kvMeta := range kvMetasOfEncodeStep {
+		// only part of the kv files are merge sorted. we need to add the kv metas
+		// that were not merge sorted into kvMetasOfMergeSort.
+ if skipMergeSort(kvGroup, kvMeta.MultipleFilesStats) {
+ if _, ok := kvMetasOfMergeSort[kvGroup]; ok {
+ // this should not happen, because we only generate merge sort
+ // subtasks for those kv groups with MaxOverlappingTotal > MergeSortOverlapThreshold
+ logutil.Logger(planCtx.Ctx).Error("kv group of encode step conflict with merge sort step")
+ return nil, errors.New("kv group of encode step conflict with merge sort step")
+ }
+ kvMetasOfMergeSort[kvGroup] = kvMeta
+ }
+ }
+ return kvMetasOfMergeSort, nil
+}
+
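+// getRangeSplitter creates a RangeSplitter over the sorted kv files; region
+// split size/keys come from the cluster and are floored at the default
+// split-region config.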
+func getRangeSplitter(ctx context.Context, store storage.ExternalStorage, kvMeta *external.SortedKVMeta) (
+ *external.RangeSplitter, error) {
+ regionSplitSize, regionSplitKeys, err := importer.GetRegionSplitSizeKeys(ctx)
+ if err != nil {
+ logutil.Logger(ctx).Warn("fail to get region split size and keys", zap.Error(err))
+ }
+ regionSplitSize = max(regionSplitSize, int64(config.SplitRegionSize))
+ regionSplitKeys = max(regionSplitKeys, int64(config.SplitRegionKeys))
+ logutil.Logger(ctx).Info("split kv range with split size and keys",
+ zap.Int64("region-split-size", regionSplitSize),
+ zap.Int64("region-split-keys", regionSplitKeys))
+
+ return external.NewRangeSplitter(
+ ctx, kvMeta.DataFiles, kvMeta.StatFiles, store,
+ int64(config.DefaultBatchSize), int64(math.MaxInt64),
+ regionSplitSize, regionSplitKeys,
+ )
+}
diff --git a/disttask/importinto/planner_test.go b/disttask/importinto/planner_test.go
index 7ce4733b4d2ff..95655b8d794bd 100644
--- a/disttask/importinto/planner_test.go
+++ b/disttask/importinto/planner_test.go
@@ -15,15 +15,18 @@
package importinto
import (
+ "context"
"encoding/json"
+ "fmt"
"testing"
+ "github.com/pingcap/failpoint"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/external"
"github.com/pingcap/tidb/disttask/framework/planner"
"github.com/pingcap/tidb/domain/infosync"
"github.com/pingcap/tidb/executor/importer"
"github.com/pingcap/tidb/meta/autoid"
"github.com/pingcap/tidb/parser/model"
- "github.com/pingcap/tidb/parser/mysql"
"github.com/stretchr/testify/require"
)
@@ -56,7 +59,9 @@ func TestToPhysicalPlan(t *testing.T) {
EligibleInstances: []*infosync.ServerInfo{{ID: "1"}},
ChunkMap: map[int32][]Chunk{chunkID: {{Path: "gs://test-load/1.csv"}}},
}
- planCtx := planner.PlanCtx{}
+ planCtx := planner.PlanCtx{
+ NextTaskStep: StepImport,
+ }
physicalPlan, err := logicalPlan.ToPhysicalPlan(planCtx)
require.NoError(t, err)
plan := &planner.PhysicalPlan{
@@ -77,24 +82,6 @@ func TestToPhysicalPlan(t *testing.T) {
},
Step: StepImport,
},
- {
- ID: 1,
- Input: planner.InputSpec{
- ColumnTypes: []byte{
- mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeLonglong, mysql.TypeJSON,
- },
- Links: []planner.LinkSpec{
- {
- ProcessorID: 0,
- },
- },
- },
- Pipeline: &PostProcessSpec{
- Schema: "db",
- Table: "tb",
- },
- Step: StepPostProcess,
- },
},
}
require.Equal(t, plan, physicalPlan)
@@ -112,8 +99,15 @@ func TestToPhysicalPlan(t *testing.T) {
subtaskMeta1.Checksum = Checksum{Size: 1, KVs: 2, Sum: 3}
bs, err = json.Marshal(subtaskMeta1)
require.NoError(t, err)
+ planCtx = planner.PlanCtx{
+ NextTaskStep: StepPostProcess,
+ }
+ physicalPlan, err = logicalPlan.ToPhysicalPlan(planCtx)
+ require.NoError(t, err)
subtaskMetas2, err := physicalPlan.ToSubtaskMetas(planner.PlanCtx{
- PreviousSubtaskMetas: [][]byte{bs},
+ PreviousSubtaskMetas: map[int64][][]byte{
+ StepImport: {bs},
+ },
}, StepPostProcess)
require.NoError(t, err)
subtaskMeta2 := PostProcessStepMeta{
@@ -124,3 +118,145 @@ func TestToPhysicalPlan(t *testing.T) {
require.NoError(t, err)
require.Equal(t, [][]byte{bs}, subtaskMetas2)
}
+
+func genEncodeStepMetas(t *testing.T, cnt int) [][]byte {
+ stepMetaBytes := make([][]byte, 0, cnt)
+ for i := 0; i < cnt; i++ {
+ prefix := fmt.Sprintf("d_%d_", i)
+ idxPrefix := fmt.Sprintf("i1_%d_", i)
+ meta := &ImportStepMeta{
+ SortedDataMeta: &external.SortedKVMeta{
+ MinKey: []byte(prefix + "a"),
+ MaxKey: []byte(prefix + "c"),
+ TotalKVSize: 12,
+ DataFiles: []string{prefix + "/1"},
+ StatFiles: []string{prefix + "/1.stat"},
+ MultipleFilesStats: []external.MultipleFilesStat{
+ {
+ Filenames: [][2]string{
+ {prefix + "/1", prefix + "/1.stat"},
+ },
+ },
+ },
+ },
+ SortedIndexMetas: map[int64]*external.SortedKVMeta{
+ 1: {
+ MinKey: []byte(idxPrefix + "a"),
+ MaxKey: []byte(idxPrefix + "c"),
+ TotalKVSize: 12,
+ DataFiles: []string{idxPrefix + "/1"},
+ StatFiles: []string{idxPrefix + "/1.stat"},
+ MultipleFilesStats: []external.MultipleFilesStat{
+ {
+ Filenames: [][2]string{
+ {idxPrefix + "/1", idxPrefix + "/1.stat"},
+ },
+ },
+ },
+ },
+ },
+ }
+ bytes, err := json.Marshal(meta)
+ require.NoError(t, err)
+ stepMetaBytes = append(stepMetaBytes, bytes)
+ }
+ return stepMetaBytes
+}
+
+func TestGenerateMergeSortSpecs(t *testing.T) {
+ stepBak := external.MergeSortFileCountStep
+ external.MergeSortFileCountStep = 2
+ t.Cleanup(func() {
+ external.MergeSortFileCountStep = stepBak
+ })
+ // force merge sort for data kv
+ require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/importinto/forceMergeSort", `return("data")`))
+ t.Cleanup(func() {
+ require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/forceMergeSort"))
+ })
+ encodeStepMetaBytes := genEncodeStepMetas(t, 3)
+ planCtx := planner.PlanCtx{
+ Ctx: context.Background(),
+ TaskID: 1,
+ PreviousSubtaskMetas: map[int64][][]byte{
+ StepEncodeAndSort: encodeStepMetaBytes,
+ },
+ }
+ specs, err := generateMergeSortSpecs(planCtx)
+ require.NoError(t, err)
+ require.Len(t, specs, 2)
+ require.Len(t, specs[0].(*MergeSortSpec).DataFiles, 2)
+ require.Equal(t, "data", specs[0].(*MergeSortSpec).KVGroup)
+ require.Equal(t, "d_0_/1", specs[0].(*MergeSortSpec).DataFiles[0])
+ require.Equal(t, "d_1_/1", specs[0].(*MergeSortSpec).DataFiles[1])
+ require.Equal(t, "data", specs[1].(*MergeSortSpec).KVGroup)
+ require.Len(t, specs[1].(*MergeSortSpec).DataFiles, 1)
+ require.Equal(t, "d_2_/1", specs[1].(*MergeSortSpec).DataFiles[0])
+}
+
+func genMergeStepMetas(t *testing.T, cnt int) [][]byte {
+ stepMetaBytes := make([][]byte, 0, cnt)
+ for i := 0; i < cnt; i++ {
+ prefix := fmt.Sprintf("x_%d_", i)
+ meta := &MergeSortStepMeta{
+ KVGroup: "data",
+ SortedKVMeta: external.SortedKVMeta{
+ MinKey: []byte(prefix + "a"),
+ MaxKey: []byte(prefix + "c"),
+ TotalKVSize: 12,
+ DataFiles: []string{prefix + "/1"},
+ StatFiles: []string{prefix + "/1.stat"},
+ MultipleFilesStats: []external.MultipleFilesStat{
+ {
+ Filenames: [][2]string{
+ {prefix + "/1", prefix + "/1.stat"},
+ },
+ },
+ },
+ },
+ }
+ bytes, err := json.Marshal(meta)
+ require.NoError(t, err)
+ stepMetaBytes = append(stepMetaBytes, bytes)
+ }
+ return stepMetaBytes
+}
+
+func TestGetSortedKVMetas(t *testing.T) {
+ encodeStepMetaBytes := genEncodeStepMetas(t, 3)
+ kvMetas, err := getSortedKVMetasOfEncodeStep(encodeStepMetaBytes)
+ require.NoError(t, err)
+ require.Len(t, kvMetas, 2)
+ require.Contains(t, kvMetas, "data")
+ require.Contains(t, kvMetas, "1")
+ // just check meta is merged, won't check all fields
+ require.Equal(t, []byte("d_0_a"), kvMetas["data"].MinKey)
+ require.Equal(t, []byte("d_2_c"), kvMetas["data"].MaxKey)
+ require.Equal(t, []byte("i1_0_a"), kvMetas["1"].MinKey)
+ require.Equal(t, []byte("i1_2_c"), kvMetas["1"].MaxKey)
+
+ mergeStepMetas := genMergeStepMetas(t, 3)
+ kvMetas2, err := getSortedKVMetasOfMergeStep(mergeStepMetas)
+ require.NoError(t, err)
+ require.Len(t, kvMetas2, 1)
+ require.Equal(t, []byte("x_0_a"), kvMetas2["data"].MinKey)
+ require.Equal(t, []byte("x_2_c"), kvMetas2["data"].MaxKey)
+
+ // force merge sort for data kv
+ require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/importinto/forceMergeSort", `return("data")`))
+ t.Cleanup(func() {
+ require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/importinto/forceMergeSort"))
+ })
+ allKVMetas, err := getSortedKVMetasForIngest(planner.PlanCtx{
+ PreviousSubtaskMetas: map[int64][][]byte{
+ StepEncodeAndSort: encodeStepMetaBytes,
+ StepMergeSort: mergeStepMetas,
+ },
+ })
+ require.NoError(t, err)
+ require.Len(t, allKVMetas, 2)
+ require.Equal(t, []byte("x_0_a"), allKVMetas["data"].MinKey)
+ require.Equal(t, []byte("x_2_c"), allKVMetas["data"].MaxKey)
+ require.Equal(t, []byte("i1_0_a"), allKVMetas["1"].MinKey)
+ require.Equal(t, []byte("i1_2_c"), allKVMetas["1"].MaxKey)
+}
diff --git a/disttask/importinto/proto.go b/disttask/importinto/proto.go
index 337ddbd70eb8f..0979e58824478 100644
--- a/disttask/importinto/proto.go
+++ b/disttask/importinto/proto.go
@@ -19,6 +19,7 @@ import (
"sync"
"github.com/pingcap/tidb/br/pkg/lightning/backend"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/external"
"github.com/pingcap/tidb/br/pkg/lightning/mydump"
"github.com/pingcap/tidb/br/pkg/lightning/verification"
"github.com/pingcap/tidb/domain/infosync"
@@ -29,13 +30,25 @@ import (
// Steps of IMPORT INTO, each step is represented by one or multiple subtasks.
// the initial step is StepInit(-1)
-// steps are processed in the following order: StepInit -> StepImport -> StepPostProcess
+// steps are processed in the following order:
+// - local sort: StepInit -> StepImport -> StepPostProcess -> StepDone
+// - global sort:
+// StepInit -> StepEncodeAndSort -> StepMergeSort -> StepWriteAndIngest
+// -> StepPostProcess -> StepDone
const (
// StepImport we sort source data and ingest it into TiKV in this step.
- StepImport int64 = 0
+ StepImport int64 = 1
// StepPostProcess we verify checksum and add index in this step.
- // TODO: Might split into StepValidate and StepAddIndex later.
- StepPostProcess int64 = 1
+ StepPostProcess int64 = 2
+	// StepEncodeAndSort encodes source data and writes sorted kv into global storage.
+ StepEncodeAndSort int64 = 3
+	// StepMergeSort merges sorted kv from global storage, so we can have better
+	// read performance during StepWriteAndIngest.
+	// depending on how many kv files overlap, there might be 0 subtasks
+	// in this step.
+ StepMergeSort int64 = 4
+	// StepWriteAndIngest writes sorted kv into TiKV and ingests it.
+ StepWriteAndIngest int64 = 5
)
// TaskMeta is the task of IMPORT INTO.
@@ -70,6 +83,34 @@ type ImportStepMeta struct {
// the max id is same among all allocator types for now, since we're using same base, see
// NewPanickingAllocators for more info.
MaxIDs map[autoid.AllocatorType]int64
+
+ SortedDataMeta *external.SortedKVMeta
+ // SortedIndexMetas is a map from index id to its sorted kv meta.
+ SortedIndexMetas map[int64]*external.SortedKVMeta
+}
+
+const (
+ // dataKVGroup is the group name of the sorted kv for data.
+	// index kv will be stored in a group named after the index id.
+ dataKVGroup = "data"
+)
+
+// MergeSortStepMeta is the meta of merge sort step.
+type MergeSortStepMeta struct {
+ // KVGroup is the group name of the sorted kv, either dataKVGroup or index-id.
+ KVGroup string `json:"kv-group"`
+ DataFiles []string `json:"data-files"`
+ external.SortedKVMeta `json:"sorted-kv-meta"`
+}
+
+// WriteIngestStepMeta is the meta of write and ingest step.
+// only used when global sort is enabled.
+type WriteIngestStepMeta struct {
+ KVGroup string `json:"kv-group"`
+ external.SortedKVMeta `json:"sorted-kv-meta"`
+ RangeSplitKeys [][]byte `json:"range-split-keys"`
+
+ Result Result
}
// PostProcessStepMeta is the meta of post process step.
@@ -80,7 +121,7 @@ type PostProcessStepMeta struct {
MaxIDs map[autoid.AllocatorType]int64
}
-// SharedVars is the shared variables between subtask and minimal tasks.
+// SharedVars holds the variables shared by all minimal tasks in a subtask.
// This is because subtasks cannot directly obtain the results of the minimal subtask.
// All the fields should be concurrent safe.
type SharedVars struct {
@@ -91,6 +132,28 @@ type SharedVars struct {
mu sync.Mutex
Checksum *verification.KVChecksum
+
+ SortedDataMeta *external.SortedKVMeta
+ // SortedIndexMetas is a map from index id to its sorted kv meta.
+ SortedIndexMetas map[int64]*external.SortedKVMeta
+}
+
+func (sv *SharedVars) mergeDataSummary(summary *external.WriterSummary) {
+ sv.mu.Lock()
+ defer sv.mu.Unlock()
+ sv.SortedDataMeta.MergeSummary(summary)
+}
+
+func (sv *SharedVars) mergeIndexSummary(indexID int64, summary *external.WriterSummary) {
+ sv.mu.Lock()
+ defer sv.mu.Unlock()
+ meta, ok := sv.SortedIndexMetas[indexID]
+ if !ok {
+ meta = external.NewSortedKVMeta(summary)
+ sv.SortedIndexMetas[indexID] = meta
+ return
+ }
+ meta.MergeSummary(summary)
}
// importStepMinimalTask is the minimal task of IMPORT INTO.
diff --git a/disttask/importinto/scheduler.go b/disttask/importinto/scheduler.go
index ef65e1524fe4c..0ba3eda154ad6 100644
--- a/disttask/importinto/scheduler.go
+++ b/disttask/importinto/scheduler.go
@@ -20,10 +20,16 @@ import (
"sync"
"time"
+ "github.com/docker/go-units"
+ "github.com/google/uuid"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/external"
"github.com/pingcap/tidb/br/pkg/lightning/backend/kv"
"github.com/pingcap/tidb/br/pkg/lightning/common"
+ "github.com/pingcap/tidb/br/pkg/lightning/config"
+ "github.com/pingcap/tidb/br/pkg/lightning/log"
"github.com/pingcap/tidb/br/pkg/lightning/metric"
"github.com/pingcap/tidb/br/pkg/lightning/verification"
"github.com/pingcap/tidb/disttask/framework/proto"
@@ -35,7 +41,9 @@ import (
"github.com/pingcap/tidb/meta/autoid"
"github.com/pingcap/tidb/table/tables"
"github.com/pingcap/tidb/util/logutil"
+ "github.com/pingcap/tidb/util/size"
"go.uber.org/zap"
+ "go.uber.org/zap/zapcore"
)
// importStepExecutor is a executor for import step.
@@ -47,37 +55,41 @@ type importStepExecutor struct {
sharedVars sync.Map
logger *zap.Logger
+ indexMemorySizeLimit uint64
+
importCtx context.Context
importCancel context.CancelFunc
wg sync.WaitGroup
}
-func (s *importStepExecutor) Init(ctx context.Context) error {
- s.logger.Info("init subtask env")
-
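+// getTableImporter rebuilds the table and load-data controller from the task
+// meta, initializes the data store, and creates a TableImporter for the task.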
+func getTableImporter(ctx context.Context, taskID int64, taskMeta *TaskMeta) (*importer.TableImporter, error) {
idAlloc := kv.NewPanickingAllocators(0)
- tbl, err := tables.TableFromMeta(idAlloc, s.taskMeta.Plan.TableInfo)
+ tbl, err := tables.TableFromMeta(idAlloc, taskMeta.Plan.TableInfo)
if err != nil {
- return err
+ return nil, err
}
- astArgs, err := importer.ASTArgsFromStmt(s.taskMeta.Stmt)
+ astArgs, err := importer.ASTArgsFromStmt(taskMeta.Stmt)
if err != nil {
- return err
+ return nil, err
}
- controller, err := importer.NewLoadDataController(&s.taskMeta.Plan, tbl, astArgs)
+ controller, err := importer.NewLoadDataController(&taskMeta.Plan, tbl, astArgs)
if err != nil {
- return err
+ return nil, err
}
- // todo: this method will load all files, but we only import files related to current subtask.
- if err := controller.InitDataFiles(ctx); err != nil {
- return err
+ if err = controller.InitDataStore(ctx); err != nil {
+ return nil, err
}
- tableImporter, err := importer.NewTableImporter(&importer.JobImportParam{
+ return importer.NewTableImporter(&importer.JobImportParam{
GroupCtx: ctx,
Progress: asyncloaddata.NewProgress(false),
Job: &asyncloaddata.Job{},
- }, controller, s.taskID)
+ }, controller, taskID)
+}
+
+func (s *importStepExecutor) Init(ctx context.Context) error {
+ s.logger.Info("init subtask env")
+ tableImporter, err := getTableImporter(ctx, s.taskID, s.taskMeta)
if err != nil {
return err
}
@@ -86,48 +98,64 @@ func (s *importStepExecutor) Init(ctx context.Context) error {
// we need this sub context since Cleanup which wait on this routine is called
// before parent context is canceled in normal flow.
s.importCtx, s.importCancel = context.WithCancel(ctx)
- s.wg.Add(1)
- go func() {
- defer s.wg.Done()
- s.tableImporter.CheckDiskQuota(s.importCtx)
- }()
+ // only need to check disk quota when we are using local sort.
+ if s.tableImporter.IsLocalSort() {
+ s.wg.Add(1)
+ go func() {
+ defer s.wg.Done()
+ s.tableImporter.CheckDiskQuota(s.importCtx)
+ }()
+ }
+ s.indexMemorySizeLimit = getWriterMemorySizeLimit(s.tableImporter.Plan)
+ s.logger.Info("index writer memory size limit",
+ zap.String("limit", units.BytesSize(float64(s.indexMemorySizeLimit))))
return nil
}
-func (s *importStepExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) error {
+func (s *importStepExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) (err error) {
+ logger := s.logger.With(zap.Int64("subtask-id", subtask.ID))
+ task := log.BeginTask(logger, "run subtask")
+ defer func() {
+ task.End(zapcore.ErrorLevel, err)
+ }()
bs := subtask.Meta
var subtaskMeta ImportStepMeta
- err := json.Unmarshal(bs, &subtaskMeta)
+ err = json.Unmarshal(bs, &subtaskMeta)
if err != nil {
- return err
+ return errors.Trace(err)
}
- s.logger.Info("split and run subtask", zap.Int32("engine-id", subtaskMeta.ID))
- dataEngine, err := s.tableImporter.OpenDataEngine(ctx, subtaskMeta.ID)
- if err != nil {
- return err
- }
- // Unlike in Lightning, we start an index engine for each subtask, whereas previously there was only a single index engine globally.
- // This is because the scheduler currently does not have a post-processing mechanism.
- // If we import the index in `cleanupSubtaskEnv`, the dispatcher will not wait for the import to complete.
- // Multiple index engines may suffer performance degradation due to range overlap.
- // These issues will be alleviated after we integrate s3 sorter.
- // engineID = -1, -2, -3, ...
- indexEngine, err := s.tableImporter.OpenIndexEngine(ctx, common.IndexEngineID-subtaskMeta.ID)
- if err != nil {
- return err
+ var dataEngine, indexEngine *backend.OpenedEngine
+ if s.tableImporter.IsLocalSort() {
+ dataEngine, err = s.tableImporter.OpenDataEngine(ctx, subtaskMeta.ID)
+ if err != nil {
+ return err
+ }
+ // Unlike in Lightning, we start an index engine for each subtask,
+ // whereas previously there was only a single index engine globally.
+ // This is because the scheduler currently does not have a post-processing mechanism.
+ // If we import the index in `cleanupSubtaskEnv`, the dispatcher will not wait for the import to complete.
+ // Multiple index engines may suffer performance degradation due to range overlap.
+ // These issues will be alleviated after we integrate s3 sorter.
+ // engineID = -1, -2, -3, ...
+ indexEngine, err = s.tableImporter.OpenIndexEngine(ctx, common.IndexEngineID-subtaskMeta.ID)
+ if err != nil {
+ return err
+ }
}
sharedVars := &SharedVars{
- TableImporter: s.tableImporter,
- DataEngine: dataEngine,
- IndexEngine: indexEngine,
- Progress: asyncloaddata.NewProgress(false),
- Checksum: &verification.KVChecksum{},
+ TableImporter: s.tableImporter,
+ DataEngine: dataEngine,
+ IndexEngine: indexEngine,
+ Progress: asyncloaddata.NewProgress(false),
+ Checksum: &verification.KVChecksum{},
+ SortedDataMeta: &external.SortedKVMeta{},
+ SortedIndexMetas: make(map[int64]*external.SortedKVMeta),
}
s.sharedVars.Store(subtaskMeta.ID, sharedVars)
source := operator.NewSimpleDataChannel(make(chan *importStepMinimalTask))
- op := newEncodeAndSortOperator(ctx, int(s.taskMeta.Plan.ThreadCnt), s.logger)
+ op := newEncodeAndSortOperator(ctx, s, sharedVars, subtask.ID, s.indexMemorySizeLimit)
op.SetSource(source)
pipeline := operator.NewAsyncPipeline(op)
if err = pipeline.Execute(); err != nil {
@@ -156,7 +184,7 @@ outer:
func (s *importStepExecutor) OnFinished(ctx context.Context, subtask *proto.Subtask) error {
var subtaskMeta ImportStepMeta
if err := json.Unmarshal(subtask.Meta, &subtaskMeta); err != nil {
- return err
+ return errors.Trace(err)
}
s.logger.Info("on subtask finished", zap.Int32("engine-id", subtaskMeta.ID))
@@ -169,23 +197,27 @@ func (s *importStepExecutor) OnFinished(ctx context.Context, subtask *proto.Subt
return errors.Errorf("sharedVars %d not found", subtaskMeta.ID)
}
- // TODO: we should close and cleanup engine in all case, since there's no checkpoint.
- s.logger.Info("import data engine", zap.Int32("engine-id", subtaskMeta.ID))
- closedDataEngine, err := sharedVars.DataEngine.Close(ctx)
- if err != nil {
- return err
- }
- dataKVCount, err := s.tableImporter.ImportAndCleanup(ctx, closedDataEngine)
- if err != nil {
- return err
- }
+ var dataKVCount int64
+ if s.tableImporter.IsLocalSort() {
+ // TODO: we should close and cleanup engine in all case, since there's no checkpoint.
+ s.logger.Info("import data engine", zap.Int32("engine-id", subtaskMeta.ID))
+ closedDataEngine, err := sharedVars.DataEngine.Close(ctx)
+ if err != nil {
+ return err
+ }
+ dataKVCount, err = s.tableImporter.ImportAndCleanup(ctx, closedDataEngine)
+ if err != nil {
+ return err
+ }
- s.logger.Info("import index engine", zap.Int32("engine-id", subtaskMeta.ID))
- if closedEngine, err := sharedVars.IndexEngine.Close(ctx); err != nil {
- return err
- } else if _, err := s.tableImporter.ImportAndCleanup(ctx, closedEngine); err != nil {
- return err
+ s.logger.Info("import index engine", zap.Int32("engine-id", subtaskMeta.ID))
+ if closedEngine, err := sharedVars.IndexEngine.Close(ctx); err != nil {
+ return err
+ } else if _, err := s.tableImporter.ImportAndCleanup(ctx, closedEngine); err != nil {
+ return err
+ }
}
+	// there's no imported dataKVCount at this stage when using global sort.
sharedVars.mu.Lock()
defer sharedVars.mu.Unlock()
@@ -203,10 +235,12 @@ func (s *importStepExecutor) OnFinished(ctx context.Context, subtask *proto.Subt
autoid.AutoIncrementType: allocators.Get(autoid.AutoIncrementType).Base(),
autoid.AutoRandomType: allocators.Get(autoid.AutoRandomType).Base(),
}
+ subtaskMeta.SortedDataMeta = sharedVars.SortedDataMeta
+ subtaskMeta.SortedIndexMetas = sharedVars.SortedIndexMetas
s.sharedVars.Delete(subtaskMeta.ID)
newMeta, err := json.Marshal(subtaskMeta)
if err != nil {
- return err
+ return errors.Trace(err)
}
subtask.Meta = newMeta
return nil
@@ -225,6 +259,169 @@ func (s *importStepExecutor) Rollback(context.Context) error {
return nil
}
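+// mergeSortStepExecutor merges the overlapping sorted kv files produced by
+// StepEncodeAndSort into fewer, less overlapping files on the global storage.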
+type mergeSortStepExecutor struct {
+ scheduler.EmptySubtaskExecutor
+ taskID int64
+ taskMeta *TaskMeta
+ logger *zap.Logger
+ controller *importer.LoadDataController
+	// subtasks of a task are run in serial now, so we don't need a lock here.
+	// change to SyncMap when we support parallel subtasks in the future.
+ subtaskSortedKVMeta *external.SortedKVMeta
+}
+
+var _ execute.SubtaskExecutor = &mergeSortStepExecutor{}
+
+func (m *mergeSortStepExecutor) Init(ctx context.Context) error {
+ controller, err := buildController(&m.taskMeta.Plan, m.taskMeta.Stmt)
+ if err != nil {
+ return err
+ }
+ if err = controller.InitDataStore(ctx); err != nil {
+ return err
+ }
+ m.controller = controller
+ return nil
+}
+
+func (m *mergeSortStepExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) (err error) {
+ logger := m.logger.With(zap.Int64("subtask-id", subtask.ID))
+ task := log.BeginTask(logger, "run subtask")
+ defer func() {
+ task.End(zapcore.ErrorLevel, err)
+ }()
+
+ sm := &MergeSortStepMeta{}
+ err = json.Unmarshal(subtask.Meta, sm)
+ if err != nil {
+ return errors.Trace(err)
+ }
+
+ var mu sync.Mutex
+ m.subtaskSortedKVMeta = &external.SortedKVMeta{}
+ onClose := func(summary *external.WriterSummary) {
+ mu.Lock()
+ defer mu.Unlock()
+ m.subtaskSortedKVMeta.MergeSummary(summary)
+ }
+
+ writerID := uuid.New().String()
+ prefix := subtaskPrefix(m.taskID, subtask.ID)
+
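+	// merge the data files of this kv group into larger sorted files under the
+	// subtask prefix; onClose collects the sorted kv meta of the merged output.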
+ return external.MergeOverlappingFiles(
+ ctx,
+ sm.DataFiles,
+ m.controller.GlobalSortStore,
+ 64*1024,
+ prefix,
+ writerID,
+ 256*size.MB,
+ 8*1024,
+ 1*size.MB,
+ 8*1024,
+ onClose)
+}
+
+func (m *mergeSortStepExecutor) OnFinished(_ context.Context, subtask *proto.Subtask) error {
+ var subtaskMeta MergeSortStepMeta
+ if err := json.Unmarshal(subtask.Meta, &subtaskMeta); err != nil {
+ return errors.Trace(err)
+ }
+ subtaskMeta.SortedKVMeta = *m.subtaskSortedKVMeta
+ m.subtaskSortedKVMeta = nil
+ newMeta, err := json.Marshal(subtaskMeta)
+ if err != nil {
+ return errors.Trace(err)
+ }
+ subtask.Meta = newMeta
+ return nil
+}
+
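+// writeAndIngestStepExecutor registers the sorted kv files of one range group
+// as an external engine in the local backend and ingests them into TiKV.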
+type writeAndIngestStepExecutor struct {
+ taskID int64
+ taskMeta *TaskMeta
+ logger *zap.Logger
+ tableImporter *importer.TableImporter
+}
+
+var _ execute.SubtaskExecutor = &writeAndIngestStepExecutor{}
+
+func (e *writeAndIngestStepExecutor) Init(ctx context.Context) error {
+ tableImporter, err := getTableImporter(ctx, e.taskID, e.taskMeta)
+ if err != nil {
+ return err
+ }
+ e.tableImporter = tableImporter
+ return nil
+}
+
+func (e *writeAndIngestStepExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) (err error) {
+ sm := &WriteIngestStepMeta{}
+ err = json.Unmarshal(subtask.Meta, sm)
+ if err != nil {
+ return errors.Trace(err)
+ }
+
+ logger := e.logger.With(zap.Int64("subtask-id", subtask.ID),
+ zap.String("kv-group", sm.KVGroup))
+ task := log.BeginTask(logger, "run subtask")
+ defer func() {
+ task.End(zapcore.ErrorLevel, err)
+ }()
+
+ _, engineUUID := backend.MakeUUID("", subtask.ID)
+ localBackend := e.tableImporter.Backend()
+ err = localBackend.CloseEngine(ctx, &backend.EngineConfig{
+ External: &backend.ExternalEngineConfig{
+ StorageURI: e.taskMeta.Plan.CloudStorageURI,
+ DataFiles: sm.DataFiles,
+ StatFiles: sm.StatFiles,
+ MinKey: sm.MinKey,
+ MaxKey: sm.MaxKey,
+ SplitKeys: sm.RangeSplitKeys,
+ TotalFileSize: int64(sm.TotalKVSize),
+ TotalKVCount: 0,
+ },
+ }, engineUUID)
+ if err != nil {
+ return err
+ }
+ return localBackend.ImportEngine(ctx, engineUUID, int64(config.SplitRegionSize), int64(config.SplitRegionKeys))
+}
+
+func (e *writeAndIngestStepExecutor) OnFinished(_ context.Context, subtask *proto.Subtask) error {
+ var subtaskMeta WriteIngestStepMeta
+ if err := json.Unmarshal(subtask.Meta, &subtaskMeta); err != nil {
+ return errors.Trace(err)
+ }
+ if subtaskMeta.KVGroup != dataKVGroup {
+ return nil
+ }
+
+ // only data kv group has loaded row count
+ _, engineUUID := backend.MakeUUID("", subtask.ID)
+ localBackend := e.tableImporter.Backend()
+ _, kvCount := localBackend.GetExternalEngineKVStatistics(engineUUID)
+ subtaskMeta.Result.LoadedRowCnt = uint64(kvCount)
+
+ newMeta, err := json.Marshal(subtaskMeta)
+ if err != nil {
+ return errors.Trace(err)
+ }
+ subtask.Meta = newMeta
+ return nil
+}
+
+func (e *writeAndIngestStepExecutor) Cleanup(_ context.Context) (err error) {
+ e.logger.Info("cleanup subtask env")
+ return e.tableImporter.Close()
+}
+
+func (e *writeAndIngestStepExecutor) Rollback(context.Context) error {
+ e.logger.Info("rollback")
+ return nil
+}
+
type postStepExecutor struct {
scheduler.EmptySubtaskExecutor
taskID int64
@@ -234,15 +431,20 @@ type postStepExecutor struct {
var _ execute.SubtaskExecutor = &postStepExecutor{}
-func (p *postStepExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) error {
+func (p *postStepExecutor) RunSubtask(ctx context.Context, subtask *proto.Subtask) (err error) {
+ logger := p.logger.With(zap.Int64("subtask-id", subtask.ID))
+ task := log.BeginTask(logger, "run subtask")
+ defer func() {
+ task.End(zapcore.ErrorLevel, err)
+ }()
stepMeta := PostProcessStepMeta{}
- if err := json.Unmarshal(subtask.Meta, &stepMeta); err != nil {
- return err
+ if err = json.Unmarshal(subtask.Meta, &stepMeta); err != nil {
+ return errors.Trace(err)
}
failpoint.Inject("waitBeforePostProcess", func() {
time.Sleep(5 * time.Second)
})
- return postProcess(ctx, p.taskMeta, &stepMeta, p.logger)
+ return postProcess(ctx, p.taskMeta, &stepMeta, logger)
}
type importScheduler struct {
@@ -267,7 +469,7 @@ func (s *importScheduler) Run(ctx context.Context, task *proto.Task) error {
func (*importScheduler) GetSubtaskExecutor(_ context.Context, task *proto.Task, _ *execute.Summary) (execute.SubtaskExecutor, error) {
taskMeta := TaskMeta{}
if err := json.Unmarshal(task.Meta, &taskMeta); err != nil {
- return nil, err
+ return nil, errors.Trace(err)
}
logger := logutil.BgLogger().With(
zap.String("type", proto.ImportInto),
@@ -277,12 +479,24 @@ func (*importScheduler) GetSubtaskExecutor(_ context.Context, task *proto.Task,
logger.Info("create step scheduler")
switch task.Step {
- case StepImport:
+ case StepImport, StepEncodeAndSort:
return &importStepExecutor{
taskID: task.ID,
taskMeta: &taskMeta,
logger: logger,
}, nil
+ case StepMergeSort:
+ return &mergeSortStepExecutor{
+ taskID: task.ID,
+ taskMeta: &taskMeta,
+ logger: logger,
+ }, nil
+ case StepWriteAndIngest:
+ return &writeAndIngestStepExecutor{
+ taskID: task.ID,
+ taskMeta: &taskMeta,
+ logger: logger,
+ }, nil
case StepPostProcess:
return &postStepExecutor{
taskID: task.ID,
diff --git a/disttask/importinto/subtask_executor.go b/disttask/importinto/subtask_executor.go
index b24e3ab72706c..4a4d7b497bfe8 100644
--- a/disttask/importinto/subtask_executor.go
+++ b/disttask/importinto/subtask_executor.go
@@ -22,6 +22,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend"
"github.com/pingcap/tidb/br/pkg/lightning/backend/local"
"github.com/pingcap/tidb/br/pkg/lightning/common"
"github.com/pingcap/tidb/br/pkg/lightning/config"
@@ -46,7 +47,7 @@ var TestSyncChan = make(chan struct{})
// MiniTaskExecutor is the interface for a minimal task executor.
// exported for testing.
type MiniTaskExecutor interface {
- Run(ctx context.Context) error
+ Run(ctx context.Context, dataWriter, indexWriter backend.EngineWriter) error
}
// importMinimalTaskExecutor is a minimal task executor for IMPORT INTO.
@@ -62,7 +63,7 @@ func newImportMinimalTaskExecutor0(t *importStepMinimalTask) MiniTaskExecutor {
}
}
-func (e *importMinimalTaskExecutor) Run(ctx context.Context) error {
+func (e *importMinimalTaskExecutor) Run(ctx context.Context, dataWriter, indexWriter backend.EngineWriter) error {
logger := logutil.BgLogger().With(zap.String("type", proto.ImportInto), zap.Int64("table-id", e.mTtask.Plan.TableInfo.ID))
logger.Info("run minimal task")
failpoint.Inject("waitBeforeSortChunk", func() {
@@ -77,8 +78,14 @@ func (e *importMinimalTaskExecutor) Run(ctx context.Context) error {
})
chunkCheckpoint := toChunkCheckpoint(e.mTtask.Chunk)
sharedVars := e.mTtask.SharedVars
- if err := importer.ProcessChunk(ctx, &chunkCheckpoint, sharedVars.TableImporter, sharedVars.DataEngine, sharedVars.IndexEngine, sharedVars.Progress, logger); err != nil {
- return err
+ if sharedVars.TableImporter.IsLocalSort() {
+ if err := importer.ProcessChunk(ctx, &chunkCheckpoint, sharedVars.TableImporter, sharedVars.DataEngine, sharedVars.IndexEngine, sharedVars.Progress, logger); err != nil {
+ return err
+ }
+ } else {
+ if err := importer.ProcessChunkWith(ctx, &chunkCheckpoint, sharedVars.TableImporter, dataWriter, indexWriter, sharedVars.Progress, logger); err != nil {
+ return err
+ }
}
sharedVars.mu.Lock()
diff --git a/disttask/operator/operator.go b/disttask/operator/operator.go
index 2e3bb5a68cc25..c1a3d43bb805c 100644
--- a/disttask/operator/operator.go
+++ b/disttask/operator/operator.go
@@ -70,7 +70,7 @@ func (c *AsyncOperator[T, R]) Open() error {
func (c *AsyncOperator[T, R]) Close() error {
// Wait all tasks done.
// We don't need to close the task channel because
- // it is closed by the workerpool.
+ // it is maintained outside this operator, see SetSource.
c.pool.Wait()
c.pool.Release()
return nil
diff --git a/docs/design/2023-06-30-configurable-kv-timeout.md b/docs/design/2023-06-30-configurable-kv-timeout.md
index c21700be6c5dc..83e5896a7f8df 100644
--- a/docs/design/2023-06-30-configurable-kv-timeout.md
+++ b/docs/design/2023-06-30-configurable-kv-timeout.md
@@ -21,7 +21,7 @@ encountered.
For latency sensitive applications, providing predictable sub-second read latency by set fast retrying
is valuable. queries that are usually very fast (such as point-select query), setting the value
-of `tidb_kv_read_timeout` to short value like `500ms`, the TiDB cluster would be more tolerable
+of `tikv_client_read_timeout` to a short value like `500ms`, the TiDB cluster would be more tolerant
for network latency or io latency jitter for a single storage node, because the retry are more quickly.
This would be helpful if the requests could be processed by more than 1 candidate targets, for example
@@ -31,11 +31,11 @@ follower or stale read requests that could be handled by multiple available peer
A possible improvement suggested by [#44771](https://github.com/pingcap/tidb/issues/44771) is to make the
timeout values of specific KV requests configurable. For example:
-- Adding a session variable `tidb_kv_read_timeout`, which is used to control the timeout for a single
+- Adding a session variable `tikv_client_read_timeout`, which is used to control the timeout for a single
TiKV read RPC request. When the user sets the value of this variable, all read RPC request timeouts will use this value.
The default value of this variable is 0, and the timeout of TiKV read RPC requests is still the original
value of `ReadTimeoutShort` and `ReadTimeoutMedium`.
-- Adding statement level hint like `SELECT /*+ tidb_kv_read_timeout(500ms) */ * FROM t where id = ?;` to
+- Supporting a statement level hint like `SELECT /*+ set_var(tikv_client_read_timeout=500) */ * FROM t where id = ?;` to
set the timeout value of the KV requests of this single query to the certain value.
### Example Usage
@@ -48,17 +48,17 @@ set @@tidb_read_staleness=-5;
set @@tidb_tikv_tidb_timeout=500;
select * from t where id = 1;
# The unit is miliseconds. The query hint usage.
-select /*+ tidb_kv_read_timeout(500ms) */ * FROM t where id = 1;
+select /*+ set_var(tikv_client_read_timeout=500) */ * FROM t where id = 1;
```
### Problems
-- Setting the variable `tidb_kv_read_timeout ` may not be easy if it affects the timeout for all
+- Setting the variable `tikv_client_read_timeout` may not be easy if it affects the timeout for all
TiKV read requests, such as Get, BatchGet, Cop in this session.A timeout of 1 second may be sufficient for GET requests,
but may be small for COP requests. Some large COP requests may keep timing out and could not be processed properly.
-- If the value of the variable or query hint `tidb_kv_read_timeout` is set too small, more retries will occur,
+- If the value of the variable or query hint `tikv_client_read_timeout` is set too small, more retries will occur,
increasing the load pressure on the TiDB cluster. In the worst case the query would not return until the
-max backoff timeout is reached if the `tidb_kv_read_timeout` is set to a value which none of the replicas
+max backoff timeout is reached if the `tikv_client_read_timeout` is set to a value which none of the replicas
could finish processing within that time.
@@ -66,18 +66,6 @@ could finish processing within that time.
The query hint usage would be more flexible and safer as the impact is limited to a single query.
-### Add Hint Support For `tidb_kv_read_timeout`
-
-- Add related field in the `StatementContext` struct like
-```go
-type StmtHints struct {
- ...
- KVReadTimeout: Duration
-}
-```
-- Support `tidb_kv_read_timeout` processing in `ExtractTableHintsFromStmtNode` and `handleStmtHints` functions,
-convert the user input hint value into correspond field of `StmtContext` so it could be used later.
-
### Support Timeout Configuration For Get And Batch-Get
- Add related filed in the `KVSnapshot` struct like
@@ -237,7 +225,7 @@ There’s already a query level timeout configuration `max_execution_time` which
variables and query hints.
If the timeout of RPC or TiKV requests could derive the timeout values from the `max_execution_time` in an
-intelligent way, it’s not needed to expose another variable or usage like `tidb_kv_read_timeout`.
+intelligent way, there is no need to expose another variable or usage like `tikv_client_read_timeout`.
For example, consider the strategy:
- The `max_execution_time` is configured to `1s` on the query level
diff --git a/domain/domain.go b/domain/domain.go
index 4890eb67b74b8..419732c096855 100644
--- a/domain/domain.go
+++ b/domain/domain.go
@@ -2329,10 +2329,6 @@ func (do *Domain) loadStatsWorker() {
for {
select {
case <-loadTicker.C:
- err = statsHandle.RefreshVars()
- if err != nil {
- logutil.BgLogger().Debug("refresh variables failed", zap.Error(err))
- }
err = statsHandle.Update(do.InfoSchema())
if err != nil {
logutil.BgLogger().Debug("update stats info failed", zap.Error(err))
diff --git a/dumpling/export/config.go b/dumpling/export/config.go
index e48edbb3ce1f6..b22e11ed1798c 100644
--- a/dumpling/export/config.go
+++ b/dumpling/export/config.go
@@ -68,6 +68,7 @@ const (
flagKey = "key"
flagCsvSeparator = "csv-separator"
flagCsvDelimiter = "csv-delimiter"
+ flagCsvLineTerminator = "csv-line-terminator"
flagOutputFilenameTemplate = "output-filename-template"
flagCompleteInsert = "complete-insert"
flagParams = "params"
@@ -113,18 +114,19 @@ type Config struct {
SSLKeyBytes []byte `json:"-"`
}
- LogLevel string
- LogFile string
- LogFormat string
- OutputDirPath string
- StatusAddr string
- Snapshot string
- Consistency string
- CsvNullValue string
- SQL string
- CsvSeparator string
- CsvDelimiter string
- Databases []string
+ LogLevel string
+ LogFile string
+ LogFormat string
+ OutputDirPath string
+ StatusAddr string
+ Snapshot string
+ Consistency string
+ CsvNullValue string
+ SQL string
+ CsvSeparator string
+ CsvDelimiter string
+ CsvLineTerminator string
+ Databases []string
TableFilter filter.Filter `json:"-"`
Where string
@@ -191,6 +193,7 @@ func DefaultConfig() *Config {
DumpEmptyDatabase: true,
CsvDelimiter: "\"",
CsvSeparator: ",",
+ CsvLineTerminator: "\r\n",
SessionParams: make(map[string]interface{}),
OutputFileTemplate: DefaultOutputFileTemplate,
PosAfterConnect: false,
@@ -300,6 +303,7 @@ func (*Config) DefineFlags(flags *pflag.FlagSet) {
flags.String(flagKey, "", "The path name to the client private key file for TLS connection")
flags.String(flagCsvSeparator, ",", "The separator for csv files, default ','")
flags.String(flagCsvDelimiter, "\"", "The delimiter for values in csv files, default '\"'")
+ flags.String(flagCsvLineTerminator, "\r\n", "The line terminator for csv files, default '\\r\\n'")
flags.String(flagOutputFilenameTemplate, "", "The output filename template (without file extension)")
flags.Bool(flagCompleteInsert, false, "Use complete INSERT statements that include column names")
flags.StringToString(flagParams, nil, `Extra session variables used while dumping, accepted format: --params "character_set_client=latin1,character_set_connection=latin1"`)
@@ -447,6 +451,10 @@ func (conf *Config) ParseFromFlags(flags *pflag.FlagSet) error {
if err != nil {
return errors.Trace(err)
}
+ conf.CsvLineTerminator, err = flags.GetString(flagCsvLineTerminator)
+ if err != nil {
+ return errors.Trace(err)
+ }
conf.CompleteInsert, err = flags.GetBool(flagCompleteInsert)
if err != nil {
return errors.Trace(err)
diff --git a/dumpling/export/writer.go b/dumpling/export/writer.go
index b4396d2279607..6e52f2384b787 100644
--- a/dumpling/export/writer.go
+++ b/dumpling/export/writer.go
@@ -304,9 +304,10 @@ type outputFileNamer struct {
}
type csvOption struct {
- nullValue string
- separator []byte
- delimiter []byte
+ nullValue string
+ separator []byte
+ delimiter []byte
+ lineTerminator []byte
}
func newOutputFileNamer(meta TableMeta, chunkIdx int, rows, fileSize bool) *outputFileNamer {
diff --git a/dumpling/export/writer_serial_test.go b/dumpling/export/writer_serial_test.go
index 2290ca86cdfa2..4029b08ef9ce3 100644
--- a/dumpling/export/writer_serial_test.go
+++ b/dumpling/export/writer_serial_test.go
@@ -120,7 +120,7 @@ func TestWriteInsertInCsv(t *testing.T) {
bf := storage.NewBufferWriter()
// test nullValue
- opt := &csvOption{separator: []byte(","), delimiter: []byte{'"'}, nullValue: "\\N"}
+ opt := &csvOption{separator: []byte(","), delimiter: []byte{'"'}, nullValue: "\\N", lineTerminator: []byte("\r\n")}
conf := configForWriteCSV(cfg, true, opt)
m := newMetrics(cfg.PromFactory, conf.Labels)
n, err := WriteInsertInCsv(tcontext.Background(), conf, tableIR, tableIR, bf, m)
@@ -171,10 +171,29 @@ func TestWriteInsertInCsv(t *testing.T) {
require.Equal(t, float64(len(data)), ReadGauge(m.finishedRowsGauge))
require.Equal(t, float64(len(expected)), ReadGauge(m.finishedSizeGauge))
+ // test line terminator
+ bf.Reset()
+ opt.lineTerminator = []byte("\n")
+ tableIR = newMockTableIR("test", "employee", data, nil, colTypes)
+ conf = configForWriteCSV(cfg, true, opt)
+ m = newMetrics(conf.PromFactory, conf.Labels)
+ n, err = WriteInsertInCsv(tcontext.Background(), conf, tableIR, tableIR, bf, m)
+ require.Equal(t, uint64(4), n)
+ require.NoError(t, err)
+
+ expected = "1;'male';'bob@mail.com';'020-1234';\\N\n" +
+ "2;'female';'sarah@mail.com';'020-1253';'healthy'\n" +
+ "3;'male';'john@mail.com';'020-1256';'healthy'\n" +
+ "4;'female';'sarah@mail.com';'020-1235';'healthy'\n"
+ require.Equal(t, expected, bf.String())
+ require.Equal(t, float64(len(data)), ReadGauge(m.finishedRowsGauge))
+ require.Equal(t, float64(len(expected)), ReadGauge(m.finishedSizeGauge))
+
// test delimiter that included in values
bf.Reset()
opt.separator = []byte("&;,?")
opt.delimiter = []byte("ma")
+ opt.lineTerminator = []byte("\r\n")
tableIR = newMockTableIR("test", "employee", data, nil, colTypes)
tableIR.colNames = []string{"id", "gender", "email", "phone_number", "status"}
conf = configForWriteCSV(cfg, false, opt)
@@ -211,7 +230,7 @@ func TestWriteInsertInCsvReturnsError(t *testing.T) {
bf := storage.NewBufferWriter()
// test nullValue
- opt := &csvOption{separator: []byte(","), delimiter: []byte{'"'}, nullValue: "\\N"}
+ opt := &csvOption{separator: []byte(","), delimiter: []byte{'"'}, nullValue: "\\N", lineTerminator: []byte("\r\n")}
conf := configForWriteCSV(cfg, true, opt)
m := newMetrics(conf.PromFactory, conf.Labels)
n, err := WriteInsertInCsv(tcontext.Background(), conf, tableIR, tableIR, bf, m)
@@ -294,6 +313,7 @@ func configForWriteCSV(config *Config, noHeader bool, opt *csvOption) *Config {
cfg.CsvNullValue = opt.nullValue
cfg.CsvDelimiter = string(opt.delimiter)
cfg.CsvSeparator = string(opt.separator)
+ cfg.CsvLineTerminator = string(opt.lineTerminator)
cfg.FileSize = UnspecifiedSize
return cfg
}
diff --git a/dumpling/export/writer_util.go b/dumpling/export/writer_util.go
index dc34992d9adc1..1d4e328703336 100644
--- a/dumpling/export/writer_util.go
+++ b/dumpling/export/writer_util.go
@@ -309,9 +309,10 @@ func WriteInsertInCsv(
wp := newWriterPipe(w, cfg.FileSize, UnspecifiedSize, metrics, cfg.Labels)
opt := &csvOption{
- nullValue: cfg.CsvNullValue,
- separator: []byte(cfg.CsvSeparator),
- delimiter: []byte(cfg.CsvDelimiter),
+ nullValue: cfg.CsvNullValue,
+ separator: []byte(cfg.CsvSeparator),
+ delimiter: []byte(cfg.CsvDelimiter),
+ lineTerminator: []byte(cfg.CsvLineTerminator),
}
// use context.Background here to make sure writerPipe can deplete all the chunks in pipeline
@@ -365,8 +366,7 @@ func WriteInsertInCsv(
bf.Write(opt.separator)
}
}
- bf.WriteByte('\r')
- bf.WriteByte('\n')
+ bf.Write(opt.lineTerminator)
}
wp.currentFileSize += uint64(bf.Len())
@@ -381,8 +381,7 @@ func WriteInsertInCsv(
counter++
wp.currentFileSize += uint64(bf.Len()-lastBfSize) + 1 // 1 is for "\n"
- bf.WriteByte('\r')
- bf.WriteByte('\n')
+ bf.Write(opt.lineTerminator)
if bf.Len() >= lengthLimit {
select {
case <-pCtx.Done():
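Note: the two hunks above make the row terminator data-driven rather than hard-coding "\r\n". A minimal, standalone sketch of the same idea (illustrative names, not the dumpling code itself): the terminator lives in the option struct and is written once per row.

package main

import (
	"bytes"
	"fmt"
)

// csvOpt mirrors csvOption: the terminator is configuration, not a constant.
type csvOpt struct {
	separator      []byte
	delimiter      []byte
	lineTerminator []byte
}

// writeCSVRow writes one row, quoting every field and ending with the
// configured terminator instead of a hard-coded "\r\n".
func writeCSVRow(buf *bytes.Buffer, fields []string, opt *csvOpt) {
	for i, f := range fields {
		if i > 0 {
			buf.Write(opt.separator)
		}
		buf.Write(opt.delimiter)
		buf.WriteString(f)
		buf.Write(opt.delimiter)
	}
	buf.Write(opt.lineTerminator)
}

func main() {
	var buf bytes.Buffer
	opt := &csvOpt{separator: []byte(","), delimiter: []byte(`"`), lineTerminator: []byte("\n")}
	writeCSVRow(&buf, []string{"1", "bob@mail.com"}, opt)
	fmt.Printf("%q\n", buf.String()) // "\"1\",\"bob@mail.com\"\n"
}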
diff --git a/dumpling/tests/e2e_csv/run.sh b/dumpling/tests/e2e_csv/run.sh
index 9c5afaca469d7..93a0855a3232d 100644
--- a/dumpling/tests/e2e_csv/run.sh
+++ b/dumpling/tests/e2e_csv/run.sh
@@ -72,6 +72,7 @@ run() {
escape_backslash_arr="true false"
csv_delimiter_arr="\" '"
csv_separator_arr=', a aa |*|'
+csv_line_terminator_arr='\n \r\n'
compress_arr='space gzip snappy zstd'
for compress in $compress_arr
@@ -80,14 +81,17 @@ do
do
for csv_separator in $csv_separator_arr
do
- for csv_delimiter in $csv_delimiter_arr
+ for csv_line_terminator in $csv_line_terminator_arr
do
- run
+ for csv_delimiter in $csv_delimiter_arr
+ do
+ run
+ done
+ if [ "$escape_backslash" = "true" ]; then
+ csv_delimiter=""
+ run
+ fi
done
- if [ "$escape_backslash" = "true" ]; then
- csv_delimiter=""
- run
- fi
done
done
done
diff --git a/errno/errname.go b/errno/errname.go
index 648310616d4e8..bb1fa58be9ef3 100644
--- a/errno/errname.go
+++ b/errno/errname.go
@@ -1059,8 +1059,8 @@ var MySQLErrName = map[uint16]*mysql.ErrMessage{
ErrLoadParquetFromLocal: mysql.Message("Do not support loading parquet files from local. Please try to load the parquet files from the cloud storage", nil),
ErrLoadDataEmptyPath: mysql.Message("The value of INFILE must not be empty when LOAD DATA from LOCAL", nil),
ErrLoadDataUnsupportedFormat: mysql.Message("The FORMAT '%s' is not supported", nil),
- ErrLoadDataInvalidURI: mysql.Message("The URI of file location is invalid. Reason: %s. Please provide a valid URI, such as 's3://import/test.csv?access_key_id={your_access_key_id ID}&secret_access_key={your_secret_access_key}&session_token={your_session_token}'", nil),
- ErrLoadDataCantAccess: mysql.Message("Access to the source file has been denied. Reason: %s. Please check the URI, access key and secret access key are correct", nil),
+ ErrLoadDataInvalidURI: mysql.Message("The URI of %s is invalid. Reason: %s. Please provide a valid URI, such as 's3://import/test.csv?access_key_id={your_access_key_id ID}&secret_access_key={your_secret_access_key}&session_token={your_session_token}'", nil),
+ ErrLoadDataCantAccess: mysql.Message("Access to the %s has been denied. Reason: %s. Please check the URI, access key and secret access key are correct", nil),
ErrLoadDataCantRead: mysql.Message("Failed to read source files. Reason: %s. %s", nil),
ErrLoadDataWrongFormatConfig: mysql.Message("", nil),
ErrUnknownOption: mysql.Message("Unknown option %s", nil),
diff --git a/errors.toml b/errors.toml
index 0ffcb48e051a7..8d37b77363cbc 100644
--- a/errors.toml
+++ b/errors.toml
@@ -1803,12 +1803,12 @@ The FORMAT '%s' is not supported
["executor:8158"]
error = '''
-The URI of file location is invalid. Reason: %s. Please provide a valid URI, such as 's3://import/test.csv?access_key_id={your_access_key_id ID}&secret_access_key={your_secret_access_key}&session_token={your_session_token}'
+The URI of %s is invalid. Reason: %s. Please provide a valid URI, such as 's3://import/test.csv?access_key_id={your_access_key_id ID}&secret_access_key={your_secret_access_key}&session_token={your_session_token}'
'''
["executor:8159"]
error = '''
-Access to the source file has been denied. Reason: %s. Please check the URI, access key and secret access key are correct
+Access to the %s has been denied. Reason: %s. Please check the URI, access key and secret access key are correct
'''
["executor:8160"]
diff --git a/executor/BUILD.bazel b/executor/BUILD.bazel
index b2189d60742ee..57f64bfc4ea7b 100644
--- a/executor/BUILD.bazel
+++ b/executor/BUILD.bazel
@@ -275,6 +275,7 @@ go_library(
"@org_golang_google_grpc//credentials",
"@org_golang_google_grpc//credentials/insecure",
"@org_golang_google_grpc//status",
+ "@org_golang_x_exp//maps",
"@org_golang_x_sync//errgroup",
"@org_uber_go_atomic//:atomic",
"@org_uber_go_zap//:zap",
@@ -410,6 +411,7 @@ go_test(
"//sessiontxn/staleread",
"//statistics",
"//statistics/handle",
+ "//statistics/handle/globalstats",
"//store/copr",
"//store/driver/error",
"//store/helper",
diff --git a/executor/OWNERS b/executor/OWNERS
new file mode 100644
index 0000000000000..fdb555c5f9a4d
--- /dev/null
+++ b/executor/OWNERS
@@ -0,0 +1,5 @@
+# See the OWNERS docs at https://go.k8s.io/owners
+options:
+ no_parent_owners: true
+approvers:
+ - sig-approvers-executor
diff --git a/executor/adapter.go b/executor/adapter.go
index a1169af54d117..a281c9913dc8c 100644
--- a/executor/adapter.go
+++ b/executor/adapter.go
@@ -1782,7 +1782,7 @@ func getEncodedPlan(stmtCtx *stmtctx.StatementContext, genHint bool) (encodedPla
// so we have to iterate all hints from the customer and keep some other necessary hints.
switch tableHint.HintName.L {
case plannercore.HintMemoryQuota, plannercore.HintUseToja, plannercore.HintNoIndexMerge,
- plannercore.HintMaxExecutionTime, plannercore.HintTidbKvReadTimeout,
+ plannercore.HintMaxExecutionTime,
plannercore.HintIgnoreIndex, plannercore.HintReadFromStorage, plannercore.HintMerge,
plannercore.HintSemiJoinRewrite, plannercore.HintNoDecorrelate:
hints = append(hints, tableHint)
diff --git a/executor/analyze.go b/executor/analyze.go
index a3ac2bc8f7a8c..32c7b6a3df599 100644
--- a/executor/analyze.go
+++ b/executor/analyze.go
@@ -151,11 +151,12 @@ func (e *AnalyzeExec) Next(ctx context.Context, _ *chunk.Chunk) error {
dom := domain.GetDomain(e.Ctx())
dom.SysProcTracker().KillSysProcess(dom.GetAutoAnalyzeProcID())
})
-
// If we enabled dynamic prune mode, then we need to generate global stats here for partition tables.
- err = e.handleGlobalStats(ctx, needGlobalStats, globalStatsMap)
- if err != nil {
- return err
+ if needGlobalStats {
+ err = e.handleGlobalStats(ctx, globalStatsMap)
+ if err != nil {
+ return err
+ }
}
// Update analyze options to mysql.analyze_options for auto analyze.
@@ -172,46 +173,80 @@ func filterAndCollectTasks(tasks []*analyzeTask, statsHandle *handle.Handle, inf
filteredTasks []*analyzeTask
skippedTables []string
needAnalyzeTableCnt uint
- tids = make([]int64, 0, len(tasks))
// tidMap is used to deduplicate table IDs.
// In stats v1, analyze for each index is a single task, and they have the same table id.
- tidMap = make(map[int64]struct{}, len(tasks))
+ tidAndPidsMap = make(map[int64]struct{}, len(tasks))
)
- // Check the locked tables in one transaction.
- for _, task := range tasks {
- tableID := getTableIDFromTask(task)
- tids = append(tids, tableID)
- }
- lockedTables, err := statsHandle.GetLockedTables(tids...)
+ lockedTableAndPartitionIDs, err := getLockedTableAndPartitionIDs(statsHandle, tasks)
if err != nil {
return nil, 0, nil, err
}
for _, task := range tasks {
+ // Check if the table or partition is locked.
tableID := getTableIDFromTask(task)
- _, isLocked := lockedTables[tableID]
+ _, isLocked := lockedTableAndPartitionIDs[tableID.TableID]
+ // If the whole table is not locked, we should check whether the partition is locked.
+ if !isLocked && tableID.IsPartitionTable() {
+ _, isLocked = lockedTableAndPartitionIDs[tableID.PartitionID]
+ }
+
+ // Only analyze the tables that are not locked.
if !isLocked {
filteredTasks = append(filteredTasks, task)
}
- if _, ok := tidMap[tableID]; !ok {
+
+ // Get the physical table ID.
+ physicalTableID := tableID.TableID
+ if tableID.IsPartitionTable() {
+ physicalTableID = tableID.PartitionID
+ }
+ if _, ok := tidAndPidsMap[physicalTableID]; !ok {
if isLocked {
- tbl, ok := infoSchema.TableByID(tableID)
- if !ok {
- logutil.BgLogger().Warn("Unknown table ID in analyze task", zap.Int64("tid", tableID))
+ if tableID.IsPartitionTable() {
+ tbl, _, def := infoSchema.FindTableByPartitionID(tableID.PartitionID)
+ if def == nil {
+ logutil.BgLogger().Warn("Unknown partition ID in analyze task", zap.Int64("pid", tableID.PartitionID))
+ } else {
+ schema, _ := infoSchema.SchemaByTable(tbl.Meta())
+ skippedTables = append(skippedTables, fmt.Sprintf("%s.%s partition (%s)", schema.Name, tbl.Meta().Name.O, def.Name.O))
+ }
} else {
- skippedTables = append(skippedTables, tbl.Meta().Name.L)
+ tbl, ok := infoSchema.TableByID(physicalTableID)
+ if !ok {
+ logutil.BgLogger().Warn("Unknown table ID in analyze task", zap.Int64("tid", physicalTableID))
+ } else {
+ schema, _ := infoSchema.SchemaByTable(tbl.Meta())
+ skippedTables = append(skippedTables, fmt.Sprintf("%s.%s", schema.Name, tbl.Meta().Name.O))
+ }
}
} else {
needAnalyzeTableCnt++
}
- tidMap[tableID] = struct{}{}
+ tidAndPidsMap[physicalTableID] = struct{}{}
}
}
return filteredTasks, needAnalyzeTableCnt, skippedTables, nil
}
+// getLockedTableAndPartitionIDs queries the locked tables and partitions.
+func getLockedTableAndPartitionIDs(statsHandle *handle.Handle, tasks []*analyzeTask) (map[int64]struct{}, error) {
+ tidAndPids := make([]int64, 0, len(tasks))
+ // Check the locked tables in one transaction.
+ // We need to check all tables and their partitions,
+ // because if the whole table is locked, all of its partitions should be skipped.
+ for _, task := range tasks {
+ tableID := getTableIDFromTask(task)
+ tidAndPids = append(tidAndPids, tableID.TableID)
+ if tableID.IsPartitionTable() {
+ tidAndPids = append(tidAndPids, tableID.PartitionID)
+ }
+ }
+ return statsHandle.GetLockedTables(tidAndPids...)
+}
+
// warnLockedTableMsg warns the locked table IDs.
func warnLockedTableMsg(sessionVars *variable.SessionVars, needAnalyzeTableCnt uint, skippedTables []string) {
if len(skippedTables) > 0 {
@@ -229,23 +264,21 @@ func warnLockedTableMsg(sessionVars *variable.SessionVars, needAnalyzeTableCnt u
}
}
-func getTableIDFromTask(task *analyzeTask) int64 {
- var tableID statistics.AnalyzeTableID
-
+func getTableIDFromTask(task *analyzeTask) statistics.AnalyzeTableID {
switch task.taskType {
case colTask:
- tableID = task.colExec.tableID
+ return task.colExec.tableID
case idxTask:
- tableID = task.idxExec.tableID
+ return task.idxExec.tableID
case fastTask:
- tableID = task.fastExec.tableID
+ return task.fastExec.tableID
case pkIncrementalTask:
- tableID = task.colIncrementalExec.tableID
+ return task.colIncrementalExec.tableID
case idxIncrementalTask:
- tableID = task.idxIncrementalExec.tableID
+ return task.idxIncrementalExec.tableID
}
- return tableID.TableID
+ panic("unreachable")
}
func (e *AnalyzeExec) saveV2AnalyzeOpts() error {
@@ -310,10 +343,15 @@ func recordHistoricalStats(sctx sessionctx.Context, tableID int64) error {
}
// handleResultsError will handle the error fetch from resultsCh and record it in log
-func (e *AnalyzeExec) handleResultsError(ctx context.Context, concurrency int, needGlobalStats bool,
- globalStatsMap globalStatsMap, resultsCh <-chan *statistics.AnalyzeResults) error {
+func (e *AnalyzeExec) handleResultsError(
+ ctx context.Context,
+ concurrency int,
+ needGlobalStats bool,
+ globalStatsMap globalStatsMap,
+ resultsCh <-chan *statistics.AnalyzeResults,
+) error {
partitionStatsConcurrency := e.Ctx().GetSessionVars().AnalyzePartitionConcurrency
- // If 'partitionStatsConcurrency' > 1, we will try to demand extra session from Domain to save Analyze results in concurrency.
+ // If partitionStatsConcurrency > 1, we will try to request extra sessions from the Domain to save analyze results concurrently.
// If there is no extra session we can use, we will save analyze results in single-thread.
if partitionStatsConcurrency > 1 {
dom := domain.GetDomain(e.Ctx())
@@ -523,6 +561,14 @@ func StartAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob) {
if err != nil {
logutil.BgLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("%s->%s", statistics.AnalyzePending, statistics.AnalyzeRunning)), zap.Error(err))
}
+ failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) {
+ if val.(bool) {
+ logutil.BgLogger().Info("StartAnalyzeJob",
+ zap.Time("start_time", job.StartTime),
+ zap.Uint64("job id", *job.ID),
+ )
+ }
+ })
}
// UpdateAnalyzeJob updates count of the processed rows when increment reaches a threshold.
@@ -541,6 +587,14 @@ func UpdateAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, rowCo
if err != nil {
logutil.BgLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("process %v rows", delta)), zap.Error(err))
}
+ failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) {
+ if val.(bool) {
+ logutil.BgLogger().Info("UpdateAnalyzeJob",
+ zap.Int64("increase processed_rows", delta),
+ zap.Uint64("job id", *job.ID),
+ )
+ }
+ })
}
// FinishAnalyzeMergeJob finishes analyze merge job
@@ -548,6 +602,7 @@ func FinishAnalyzeMergeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob,
if job == nil || job.ID == nil {
return
}
+
job.EndTime = time.Now()
var sql string
var args []interface{}
@@ -575,6 +630,14 @@ func FinishAnalyzeMergeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob,
}
logutil.BgLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("%s->%s", statistics.AnalyzeRunning, state)), zap.Error(err))
}
+ failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) {
+ if val.(bool) {
+ logutil.BgLogger().Info("FinishAnalyzeMergeJob",
+ zap.Time("end_time", job.EndTime),
+ zap.Uint64("job id", *job.ID),
+ )
+ }
+ })
}
// FinishAnalyzeJob updates the state of the analyze job to finished/failed according to `meetError` and sets the end time.
@@ -611,6 +674,16 @@ func FinishAnalyzeJob(sctx sessionctx.Context, job *statistics.AnalyzeJob, analy
}
logutil.BgLogger().Warn("failed to update analyze job", zap.String("update", fmt.Sprintf("%s->%s", statistics.AnalyzeRunning, state)), zap.Error(err))
}
+ failpoint.Inject("DebugAnalyzeJobOperations", func(val failpoint.Value) {
+ if val.(bool) {
+ logutil.BgLogger().Info("FinishAnalyzeJob",
+ zap.Int64("increase processed_rows", job.Progress.GetDeltaCount()),
+ zap.Time("end_time", job.EndTime),
+ zap.Uint64("job id", *job.ID),
+ zap.Error(analyzeErr),
+ )
+ }
+ })
}
func finishJobWithLog(sctx sessionctx.Context, job *statistics.AnalyzeJob, analyzeErr error) {
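Note: filterAndCollectTasks above treats a task as locked when either the partition ID or its parent table ID appears in the set fetched by one batch query. A reduced, hypothetical sketch of that check (stand-in types, not the statistics handle API):

package main

import "fmt"

// analyzeID is a stand-in for statistics.AnalyzeTableID.
type analyzeID struct {
	TableID     int64
	PartitionID int64 // negative when the task targets a non-partitioned table
}

func (id analyzeID) isPartition() bool { return id.PartitionID >= 0 }

// isLocked mirrors the check in filterAndCollectTasks: a locked table implies
// every one of its partitions is locked.
func isLocked(locked map[int64]struct{}, id analyzeID) bool {
	if _, ok := locked[id.TableID]; ok {
		return true
	}
	if id.isPartition() {
		_, ok := locked[id.PartitionID]
		return ok
	}
	return false
}

func main() {
	locked := map[int64]struct{}{100: {}} // the whole table 100 is locked
	fmt.Println(isLocked(locked, analyzeID{TableID: 100, PartitionID: 101})) // true, via the parent table
	fmt.Println(isLocked(locked, analyzeID{TableID: 200, PartitionID: 201})) // false
}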
diff --git a/executor/analyze_col_v2.go b/executor/analyze_col_v2.go
index 7501cd6d4c0d0..930d45a08f93a 100644
--- a/executor/analyze_col_v2.go
+++ b/executor/analyze_col_v2.go
@@ -55,17 +55,16 @@ type AnalyzeColumnsExecV2 struct {
func (e *AnalyzeColumnsExecV2) analyzeColumnsPushDownWithRetryV2() *statistics.AnalyzeResults {
analyzeResult := e.analyzeColumnsPushDownV2()
- // do not retry if succeed / not oom error / not auto-analyze / samplerate not set
- if analyzeResult.Err == nil || analyzeResult.Err != errAnalyzeOOM ||
- !e.ctx.GetSessionVars().InRestrictedSQL ||
- e.analyzePB.ColReq == nil || *e.analyzePB.ColReq.SampleRate <= 0 {
+ if e.notRetryable(analyzeResult) {
return analyzeResult
}
+
finishJobWithLog(e.ctx, analyzeResult.Job, analyzeResult.Err)
statsHandle := domain.GetDomain(e.ctx).StatsHandle()
if statsHandle == nil {
return analyzeResult
}
+
var statsTbl *statistics.Table
tid := e.tableID.GetStatisticsID()
if tid == e.tableInfo.ID {
@@ -76,6 +75,7 @@ func (e *AnalyzeColumnsExecV2) analyzeColumnsPushDownWithRetryV2() *statistics.A
if statsTbl == nil || statsTbl.RealtimeCount <= 0 {
return analyzeResult
}
+
newSampleRate := math.Min(1, float64(config.DefRowsForSampleRate)/float64(statsTbl.RealtimeCount))
if newSampleRate >= *e.analyzePB.ColReq.SampleRate {
return analyzeResult
@@ -87,6 +87,13 @@ func (e *AnalyzeColumnsExecV2) analyzeColumnsPushDownWithRetryV2() *statistics.A
return e.analyzeColumnsPushDownV2()
}
+// Do **not** retry if the analyze succeeded / the error is not an OOM error / it is not auto-analyze / the sample rate is not set.
+func (e *AnalyzeColumnsExecV2) notRetryable(analyzeResult *statistics.AnalyzeResults) bool {
+ return analyzeResult.Err == nil || analyzeResult.Err != errAnalyzeOOM ||
+ !e.ctx.GetSessionVars().InRestrictedSQL ||
+ e.analyzePB.ColReq == nil || *e.analyzePB.ColReq.SampleRate <= 0
+}
+
func (e *AnalyzeColumnsExecV2) analyzeColumnsPushDownV2() *statistics.AnalyzeResults {
var ranges []*ranger.Range
if hc := e.handleCols; hc != nil {
@@ -98,6 +105,7 @@ func (e *AnalyzeColumnsExecV2) analyzeColumnsPushDownV2() *statistics.AnalyzeRes
} else {
ranges = ranger.FullIntRange(false)
}
+
collExtStats := e.ctx.GetSessionVars().EnableExtendedStats
specialIndexes := make([]*model.IndexInfo, 0, len(e.indexes))
specialIndexesOffsets := make([]int, 0, len(e.indexes))
@@ -126,7 +134,8 @@ func (e *AnalyzeColumnsExecV2) analyzeColumnsPushDownV2() *statistics.AnalyzeRes
e.handleNDVForSpecialIndexes(specialIndexes, idxNDVPushDownCh)
})
defer wg.Wait()
- count, hists, topns, fmSketches, extStats, err := e.buildSamplingStats(ranges, collExtStats, specialIndexesOffsets, idxNDVPushDownCh)
+
+ count, hists, topNs, fmSketches, extStats, err := e.buildSamplingStats(ranges, collExtStats, specialIndexesOffsets, idxNDVPushDownCh)
if err != nil {
e.memTracker.Release(e.memTracker.BytesConsumed())
return &statistics.AnalyzeResults{Err: err, Job: e.job}
@@ -134,7 +143,7 @@ func (e *AnalyzeColumnsExecV2) analyzeColumnsPushDownV2() *statistics.AnalyzeRes
cLen := len(e.analyzePB.ColReq.ColumnsInfo)
colGroupResult := &statistics.AnalyzeResult{
Hist: hists[cLen:],
- TopNs: topns[cLen:],
+ TopNs: topNs[cLen:],
Fms: fmSketches[cLen:],
IsIndex: 1,
}
@@ -148,9 +157,10 @@ func (e *AnalyzeColumnsExecV2) analyzeColumnsPushDownV2() *statistics.AnalyzeRes
}
colResult := &statistics.AnalyzeResult{
Hist: hists[:cLen],
- TopNs: topns[:cLen],
+ TopNs: topNs[:cLen],
Fms: fmSketches[:cLen],
}
+
return &statistics.AnalyzeResults{
TableID: e.tableID,
Ars: []*statistics.AnalyzeResult{colResult, colGroupResult},
@@ -235,6 +245,7 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
extStats *statistics.ExtendedStatsColl,
err error,
) {
+ // Open memory tracker and resultHandler.
if err = e.open(ranges); err != nil {
return 0, nil, nil, nil, nil, err
}
@@ -243,16 +254,20 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
err = err1
}
}()
+
l := len(e.analyzePB.ColReq.ColumnsInfo) + len(e.analyzePB.ColReq.ColumnGroups)
rootRowCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l)
for i := 0; i < l; i++ {
rootRowCollector.Base().FMSketches = append(rootRowCollector.Base().FMSketches, statistics.NewFMSketch(maxSketchSize))
}
+
sc := e.ctx.GetSessionVars().StmtCtx
statsConcurrency, err := getBuildStatsConcurrency(e.ctx)
if err != nil {
return 0, nil, nil, nil, nil, err
}
+
+ // Start workers to merge the result from collectors.
mergeResultCh := make(chan *samplingMergeResult, statsConcurrency)
mergeTaskCh := make(chan []byte, statsConcurrency)
e.samplingMergeWg = &util.WaitGroupWrapper{}
@@ -260,10 +275,13 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
for i := 0; i < statsConcurrency; i++ {
go e.subMergeWorker(mergeResultCh, mergeTaskCh, l, i)
}
+
+ // Start reading data from resultHandler and sending it to mergeTaskCh.
if err = readDataAndSendTask(e.ctx, e.resultHandler, mergeTaskCh, e.memTracker); err != nil {
return 0, nil, nil, nil, nil, getAnalyzePanicErr(err)
}
+ // Merge the result from collectors.
mergeWorkerPanicCnt := 0
for mergeWorkerPanicCnt < statsConcurrency {
mergeResult, ok := <-mergeResultCh
@@ -279,6 +297,7 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
}
oldRootCollectorSize := rootRowCollector.Base().MemSize
oldRootCollectorCount := rootRowCollector.Base().Count
+ // Merge the result from sub-collectors.
rootRowCollector.MergeCollector(mergeResult.collector)
newRootCollectorCount := rootRowCollector.Base().Count
printAnalyzeMergeCollectorLog(oldRootCollectorCount, newRootCollectorCount,
@@ -291,7 +310,7 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
return 0, nil, nil, nil, nil, err
}
- // handling virtual columns
+ // Decode the data from sample collectors.
virtualColIdx := buildVirtualColumnIndex(e.schemaForVirtualColEval, e.colsInfo)
if len(virtualColIdx) > 0 {
fieldTps := make([]*types.FieldType, 0, len(virtualColIdx))
@@ -314,16 +333,14 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
}
}
+ // Calculate handle from the row data for each row. It will be used to sort the samples.
for _, sample := range rootRowCollector.Base().Samples {
- // Calculate handle from the row data for each row. It will be used to sort the samples.
sample.Handle, err = e.handleCols.BuildHandleByDatums(sample.Columns)
if err != nil {
return 0, nil, nil, nil, nil, err
}
}
-
colLen := len(e.colsInfo)
-
// The order of the samples are broken when merging samples from sub-collectors.
// So now we need to sort the samples according to the handle in order to calculate correlation.
sort.Slice(rootRowCollector.Base().Samples, func(i, j int) bool {
@@ -343,11 +360,14 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
sampleCollectors := make([]*statistics.SampleCollector, len(e.colsInfo))
exitCh := make(chan struct{})
e.samplingBuilderWg.Add(statsConcurrency)
+
+ // Start workers to build stats.
for i := 0; i < statsConcurrency; i++ {
e.samplingBuilderWg.Run(func() {
e.subBuildWorker(buildResultChan, buildTaskChan, hists, topns, sampleCollectors, exitCh)
})
}
+ // Generate tasks for building stats.
for i, col := range e.colsInfo {
buildTaskChan <- &samplingBuildTask{
id: col.ID,
@@ -371,7 +391,7 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
rootRowCollector.Base().FMSketches[colLen+offset] = ret.Ars[0].Fms[0]
}
- // build index stats
+ // Generate tasks for building stats for indexes.
for i, idx := range e.indexes {
buildTaskChan <- &samplingBuildTask{
id: idx.ID,
@@ -383,6 +403,7 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
fmSketches = append(fmSketches, rootRowCollector.Base().FMSketches[colLen+i])
}
close(buildTaskChan)
+
panicCnt := 0
for panicCnt < statsConcurrency {
err1, ok := <-buildResultChan
@@ -409,6 +430,7 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
if err != nil {
return 0, nil, nil, nil, nil, err
}
+
count = rootRowCollector.Base().Count
if needExtStats {
statsHandle := domain.GetDomain(e.ctx).StatsHandle()
@@ -417,6 +439,7 @@ func (e *AnalyzeColumnsExecV2) buildSamplingStats(
return 0, nil, nil, nil, nil, err
}
}
+
return
}
@@ -573,7 +596,8 @@ func (e *AnalyzeColumnsExecV2) buildSubIndexJobForSpecialIndex(indexInfos []*mod
}
func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResult, taskCh <-chan []byte, l int, index int) {
- isClosedChanThread := index == 0
+ // Only close the resultCh in the first worker.
+ closeTheResultCh := index == 0
defer func() {
if r := recover(); r != nil {
logutil.BgLogger().Error("analyze worker panicked", zap.Any("recover", r), zap.Stack("stack"))
@@ -588,7 +612,7 @@ func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResu
}
}
e.samplingMergeWg.Done()
- if isClosedChanThread {
+ if closeTheResultCh {
e.samplingMergeWg.Wait()
close(resultCh)
}
@@ -612,6 +636,8 @@ func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResu
if !ok {
break
}
+
+ // Unmarshal the data.
dataSize := int64(cap(data))
colResp := &tipb.AnalyzeColumnsResp{}
err := colResp.Unmarshal(data)
@@ -619,11 +645,16 @@ func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResu
resultCh <- &samplingMergeResult{err: err}
return
}
+ // Consume the memory of the data.
colRespSize := int64(colResp.Size())
e.memTracker.Consume(colRespSize)
+
+ // Update processed rows.
subCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l)
subCollector.Base().FromProto(colResp.RowCollector, e.memTracker)
UpdateAnalyzeJob(e.ctx, e.job, subCollector.Base().Count)
+
+ // Print the collector merge log.
oldRetCollectorSize := retCollector.Base().MemSize
oldRetCollectorCount := retCollector.Base().Count
retCollector.MergeCollector(subCollector)
@@ -631,11 +662,14 @@ func (e *AnalyzeColumnsExecV2) subMergeWorker(resultCh chan<- *samplingMergeResu
printAnalyzeMergeCollectorLog(oldRetCollectorCount, newRetCollectorCount, subCollector.Base().Count,
e.tableID.TableID, e.tableID.PartitionID, e.TableID.IsPartitionTable(),
"merge subCollector in concurrency in AnalyzeColumnsExecV2", index)
+
+ // Consume the memory of the result.
newRetCollectorSize := retCollector.Base().MemSize
subCollectorSize := subCollector.Base().MemSize
e.memTracker.Consume(newRetCollectorSize - oldRetCollectorSize - subCollectorSize)
e.memTracker.Release(dataSize + colRespSize)
}
+
resultCh <- &samplingMergeResult{collector: retCollector}
}
@@ -650,11 +684,13 @@ func (e *AnalyzeColumnsExecV2) subBuildWorker(resultCh chan error, taskCh chan *
failpoint.Inject("mockAnalyzeSamplingBuildWorkerPanic", func() {
panic("failpoint triggered")
})
+
colLen := len(e.colsInfo)
bufferedMemSize := int64(0)
bufferedReleaseSize := int64(0)
defer e.memTracker.Consume(bufferedMemSize)
defer e.memTracker.Release(bufferedReleaseSize)
+
workLoop:
for {
select {
@@ -813,6 +849,7 @@ type samplingBuildTask struct {
}
func readDataAndSendTask(ctx sessionctx.Context, handler *tableResultHandler, mergeTaskCh chan []byte, memTracker *memory.Tracker) error {
+ // After all tasks are sent, close mergeTaskCh to notify the merge workers that no more tasks will arrive.
defer close(mergeTaskCh)
for {
failpoint.Inject("mockKillRunningV2AnalyzeJob", func() {
@@ -825,6 +862,7 @@ func readDataAndSendTask(ctx sessionctx.Context, handler *tableResultHandler, me
failpoint.Inject("mockSlowAnalyzeV2", func() {
time.Sleep(1000 * time.Second)
})
+
data, err := handler.nextRaw(context.TODO())
if err != nil {
return errors.Trace(err)
@@ -832,8 +870,10 @@ func readDataAndSendTask(ctx sessionctx.Context, handler *tableResultHandler, me
if data == nil {
break
}
+
memTracker.Consume(int64(cap(data)))
mergeTaskCh <- data
}
+
return nil
}
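Note: subMergeWorker keeps the existing pattern where only the worker with index 0 closes the result channel after waiting for the whole group, so the consumer's receive loop terminates exactly once. A generic sketch of the pattern, assuming a buffered result channel:

package main

import (
	"fmt"
	"sync"
)

func worker(index int, wg *sync.WaitGroup, resultCh chan<- int) {
	closeTheResultCh := index == 0 // exactly one worker is responsible for closing
	defer func() {
		wg.Done()
		if closeTheResultCh {
			wg.Wait() // wait for every producer before closing
			close(resultCh)
		}
	}()
	resultCh <- index * 10
}

func main() {
	const workers = 4
	var wg sync.WaitGroup
	resultCh := make(chan int, workers) // buffered, so producers never block on the consumer
	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go worker(i, &wg, resultCh)
	}
	for r := range resultCh { // ends once worker 0 has closed the channel
		fmt.Println(r)
	}
}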
diff --git a/executor/analyze_global_stats.go b/executor/analyze_global_stats.go
index 5c87f0cebc897..16c6a1fb901a1 100644
--- a/executor/analyze_global_stats.go
+++ b/executor/analyze_global_stats.go
@@ -25,6 +25,7 @@ import (
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/logutil"
"go.uber.org/zap"
+ "golang.org/x/exp/maps"
)
type globalStatsKey struct {
@@ -45,23 +46,24 @@ type globalStatsInfo struct {
// The meaning of value in map is some additional information needed to build global-level stats.
type globalStatsMap map[globalStatsKey]globalStatsInfo
-func (e *AnalyzeExec) handleGlobalStats(ctx context.Context, needGlobalStats bool, globalStatsMap globalStatsMap) error {
- if !needGlobalStats {
- return nil
- }
- globalStatsTableIDs := make(map[int64]struct{})
+func (e *AnalyzeExec) handleGlobalStats(ctx context.Context, globalStatsMap globalStatsMap) error {
+ globalStatsTableIDs := make(map[int64]struct{}, len(globalStatsMap))
for globalStatsID := range globalStatsMap {
globalStatsTableIDs[globalStatsID.tableID] = struct{}{}
}
+
statsHandle := domain.GetDomain(e.Ctx()).StatsHandle()
- tableIDs := map[int64]struct{}{}
+ tableIDs := make(map[int64]struct{}, len(globalStatsTableIDs))
+ tableAllPartitionStats := make(map[int64]*statistics.Table)
for tableID := range globalStatsTableIDs {
tableIDs[tableID] = struct{}{}
- tableAllPartitionStats := make(map[int64]*statistics.Table)
+ maps.Clear(tableAllPartitionStats)
+
for globalStatsID, info := range globalStatsMap {
if globalStatsID.tableID != tableID {
continue
}
+
job := e.newAnalyzeHandleGlobalStatsJob(globalStatsID)
if job == nil {
logutil.BgLogger().Warn("cannot find the partitioned table, skip merging global stats", zap.Int64("tableID", globalStatsID.tableID))
@@ -69,6 +71,7 @@ func (e *AnalyzeExec) handleGlobalStats(ctx context.Context, needGlobalStats boo
}
AddNewAnalyzeJob(e.Ctx(), job)
StartAnalyzeJob(e.Ctx(), job)
+
mergeStatsErr := func() error {
globalOpts := e.opts
if e.OptionsMap != nil {
@@ -76,9 +79,15 @@ func (e *AnalyzeExec) handleGlobalStats(ctx context.Context, needGlobalStats boo
globalOpts = v2Options.FilledOpts
}
}
- globalStats, err := statsHandle.MergePartitionStats2GlobalStatsByTableID(e.Ctx(), globalOpts, e.Ctx().GetInfoSchema().(infoschema.InfoSchema),
- globalStatsID.tableID, info.isIndex, info.histIDs,
- tableAllPartitionStats)
+
+ globalStats, err := statsHandle.MergePartitionStats2GlobalStatsByTableID(
+ e.Ctx(),
+ globalOpts, e.Ctx().GetInfoSchema().(infoschema.InfoSchema),
+ globalStatsID.tableID,
+ info.isIndex == 1,
+ info.histIDs,
+ tableAllPartitionStats,
+ )
if err != nil {
logutil.BgLogger().Warn("merge global stats failed",
zap.String("info", job.JobInfo), zap.Error(err), zap.Int64("tableID", tableID))
@@ -88,6 +97,7 @@ func (e *AnalyzeExec) handleGlobalStats(ctx context.Context, needGlobalStats boo
}
return err
}
+ // Dump global-level stats to kv.
for i := 0; i < globalStats.Num; i++ {
hg, cms, topN := globalStats.Hg[i], globalStats.Cms[i], globalStats.TopN[i]
if hg == nil {
@@ -114,15 +124,22 @@ func (e *AnalyzeExec) handleGlobalStats(ctx context.Context, needGlobalStats boo
}
return err
}()
+
FinishAnalyzeMergeJob(e.Ctx(), job, mergeStatsErr)
}
+
+ for _, value := range tableAllPartitionStats {
+ value.ReleaseAndPutToPool()
+ }
}
+
for tableID := range tableIDs {
// Dump stats to historical storage.
if err := recordHistoricalStats(e.Ctx(), tableID); err != nil {
logutil.BgLogger().Error("record historical stats failed", zap.Error(err))
}
}
+
return nil
}
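Note: handleGlobalStats now allocates tableAllPartitionStats once, clears it with maps.Clear between tables, and releases the cached partition stats after each merge. A reduced sketch of the reuse pattern (partStats and releaseAndPutToPool are stand-ins for the real statistics.Table pooling):

package main

import (
	"fmt"

	"golang.org/x/exp/maps"
)

// partStats stands in for *statistics.Table held per partition.
type partStats struct{ id int64 }

func (s *partStats) releaseAndPutToPool() {} // the real code returns buffers to a pool

func main() {
	tables := [][]int64{{1, 2}, {3}}
	// One map reused across iterations instead of re-allocating it per table.
	cache := make(map[int64]*partStats)
	for _, partitions := range tables {
		maps.Clear(cache) // drop the previous table's entries, keep the allocation
		for _, pid := range partitions {
			cache[pid] = &partStats{id: pid}
		}
		fmt.Println("merged global stats from", len(cache), "partitions")
		for _, s := range cache {
			s.releaseAndPutToPool() // hand the cached partition stats back once merged
		}
	}
}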
diff --git a/executor/batch_point_get_test.go b/executor/batch_point_get_test.go
index 4b1eda8abe3d9..3309afd3be6b6 100644
--- a/executor/batch_point_get_test.go
+++ b/executor/batch_point_get_test.go
@@ -380,10 +380,10 @@ func TestBatchPointGetIssue46779(t *testing.T) {
tk.MustExec("CREATE TABLE t1 (id int, c varchar(128), primary key (id)) PARTITION BY HASH (id) PARTITIONS 3;")
tk.MustExec(`insert into t1 values (1, "a"), (11, "b"), (21, "c")`)
query := "select * from t1 where id in (1, 1, 11)"
- require.True(t, tk.HasPlan(query, "Batch_Point_Get")) // check if BatchPointGet is used
+ tk.MustHavePlan(query, "Batch_Point_Get") // check if BatchPointGet is used
tk.MustQuery(query).Sort().Check(testkit.Rows("1 a", "11 b"))
query = "select * from t1 where id in (1, 11, 11, 21)"
- require.True(t, tk.HasPlan(query, "Batch_Point_Get")) // check if BatchPointGet is used
+ tk.MustHavePlan(query, "Batch_Point_Get") // check if BatchPointGet is used
tk.MustQuery(query).Sort().Check(testkit.Rows("1 a", "11 b", "21 c"))
tk.MustExec("drop table if exists t2")
@@ -393,9 +393,9 @@ func TestBatchPointGetIssue46779(t *testing.T) {
partition p2 values less than (30));`)
tk.MustExec(`insert into t2 values (1, "a"), (11, "b"), (21, "c")`)
query = "select * from t2 where id in (1, 1, 11)"
- require.True(t, tk.HasPlan(query, "Batch_Point_Get")) // check if BatchPointGet is used
+ tk.MustHavePlan(query, "Batch_Point_Get") // check if BatchPointGet is used
tk.MustQuery(query).Sort().Check(testkit.Rows("1 a", "11 b"))
- require.True(t, tk.HasPlan(query, "Batch_Point_Get")) // check if BatchPointGet is used
+ tk.MustHavePlan(query, "Batch_Point_Get") // check if BatchPointGet is used
query = "select * from t2 where id in (1, 11, 11, 21)"
tk.MustQuery(query).Sort().Check(testkit.Rows("1 a", "11 b", "21 c"))
}
diff --git a/executor/builder.go b/executor/builder.go
index 27166a216008c..a3fb8c0235653 100644
--- a/executor/builder.go
+++ b/executor/builder.go
@@ -1940,7 +1940,7 @@ func (b *executorBuilder) getSnapshot() (kv.Snapshot, error) {
replicaReadType := sessVars.GetReplicaRead()
snapshot.SetOption(kv.ReadReplicaScope, b.readReplicaScope)
snapshot.SetOption(kv.TaskID, sessVars.StmtCtx.TaskID)
- snapshot.SetOption(kv.TidbKvReadTimeout, sessVars.GetTidbKvReadTimeout())
+ snapshot.SetOption(kv.TiKVClientReadTimeout, sessVars.GetTiKVClientReadTimeout())
snapshot.SetOption(kv.ResourceGroupName, sessVars.ResourceGroupName)
snapshot.SetOption(kv.ExplicitRequestSourceType, sessVars.ExplicitRequestSourceType)
@@ -2670,7 +2670,11 @@ func (b *executorBuilder) buildAnalyzeIndexIncremental(task plannercore.AnalyzeI
return analyzeTask
}
-func (b *executorBuilder) buildAnalyzeSamplingPushdown(task plannercore.AnalyzeColumnsTask, opts map[ast.AnalyzeOptionType]uint64, schemaForVirtualColEval *expression.Schema) *analyzeTask {
+func (b *executorBuilder) buildAnalyzeSamplingPushdown(
+ task plannercore.AnalyzeColumnsTask,
+ opts map[ast.AnalyzeOptionType]uint64,
+ schemaForVirtualColEval *expression.Schema,
+) *analyzeTask {
if task.V2Options != nil {
opts = task.V2Options.FilledOpts
}
@@ -2844,7 +2848,12 @@ func (b *executorBuilder) getApproximateTableCountFromStorage(tid int64, task pl
return pdhelper.GlobalPDHelper.GetApproximateTableCountFromStorage(b.ctx, tid, task.DBName, task.TableName, task.PartitionName)
}
-func (b *executorBuilder) buildAnalyzeColumnsPushdown(task plannercore.AnalyzeColumnsTask, opts map[ast.AnalyzeOptionType]uint64, autoAnalyze string, schemaForVirtualColEval *expression.Schema) *analyzeTask {
+func (b *executorBuilder) buildAnalyzeColumnsPushdown(
+ task plannercore.AnalyzeColumnsTask,
+ opts map[ast.AnalyzeOptionType]uint64,
+ autoAnalyze string,
+ schemaForVirtualColEval *expression.Schema,
+) *analyzeTask {
if task.StatsVersion == statistics.Version2 {
return b.buildAnalyzeSamplingPushdown(task, opts, schemaForVirtualColEval)
}
@@ -3085,7 +3094,13 @@ func (b *executorBuilder) buildAnalyze(v *plannercore.Analyze) exec.Executor {
if enableFastAnalyze {
b.buildAnalyzeFastColumn(e, task, v.Opts)
} else {
- columns, _, err := expression.ColumnInfos2ColumnsAndNames(b.ctx, model.NewCIStr(task.AnalyzeInfo.DBName), task.TblInfo.Name, task.ColsInfo, task.TblInfo)
+ columns, _, err := expression.ColumnInfos2ColumnsAndNames(
+ b.ctx,
+ model.NewCIStr(task.AnalyzeInfo.DBName),
+ task.TblInfo.Name,
+ task.ColsInfo,
+ task.TblInfo,
+ )
if err != nil {
b.err = err
return nil
@@ -3094,6 +3109,7 @@ func (b *executorBuilder) buildAnalyze(v *plannercore.Analyze) exec.Executor {
e.tasks = append(e.tasks, b.buildAnalyzeColumnsPushdown(task, v.Opts, autoAnalyze, schema))
}
}
+ // Other functions may set b.err, so we need to check it here.
if b.err != nil {
return nil
}
@@ -4946,6 +4962,8 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) exec.Execut
exec.orderByCols = orderByCols
exec.expectedCmpResult = cmpResult
exec.isRangeFrame = true
+ exec.start.InitCompareCols(b.ctx, exec.orderByCols)
+ exec.end.InitCompareCols(b.ctx, exec.orderByCols)
}
}
return exec
@@ -4968,7 +4986,7 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) exec.Execut
if len(v.OrderBy) > 0 && v.OrderBy[0].Desc {
cmpResult = 1
}
- processor = &rangeFrameWindowProcessor{
+ tmpProcessor := &rangeFrameWindowProcessor{
windowFuncs: windowFuncs,
partialResults: partialResults,
start: v.Frame.Start,
@@ -4976,6 +4994,11 @@ func (b *executorBuilder) buildWindow(v *plannercore.PhysicalWindow) exec.Execut
orderByCols: orderByCols,
expectedCmpResult: cmpResult,
}
+
+ tmpProcessor.start.InitCompareCols(b.ctx, orderByCols)
+ tmpProcessor.end.InitCompareCols(b.ctx, orderByCols)
+
+ processor = tmpProcessor
}
return &WindowExec{BaseExecutor: base,
processor: processor,
diff --git a/executor/distsql_test.go b/executor/distsql_test.go
index c00289ccb5c8d..ce02b2c1b3369 100644
--- a/executor/distsql_test.go
+++ b/executor/distsql_test.go
@@ -690,11 +690,11 @@ func TestIndexLookUpWithSelectForUpdateOnPartitionTable(t *testing.T) {
tk.MustExec("use test")
tk.MustExec("create table t(a int, b int, index k(b)) PARTITION BY HASH(a) partitions 4")
tk.MustExec("insert into t(a, b) values (1,1),(2,2),(3,3),(4,4),(5,5),(6,6),(7,7),(8,8)")
- tk.HasPlan("select b from t use index(k) where b > 2 order by b limit 1 for update", "UnionScan")
- tk.HasPlan("select b from t use index(k) where b > 2 order by b limit 1 for update", "IndexLookUp")
+ tk.MustHavePlan("select b from t use index(k) where b > 2 order by b limit 1 for update", "PartitionUnion")
+ tk.MustHavePlan("select b from t use index(k) where b > 2 order by b limit 1 for update", "IndexLookUp")
tk.MustQuery("select b from t use index(k) where b > 2 order by b limit 1 for update").Check(testkit.Rows("3"))
tk.MustExec("analyze table t")
- tk.HasPlan("select b from t use index(k) where b > 2 order by b limit 1 for update", "IndexLookUp")
+ tk.MustHavePlan("select b from t use index(k) where b > 2 order by b limit 1 for update", "IndexLookUp")
tk.MustQuery("select b from t use index(k) where b > 2 order by b limit 1 for update").Check(testkit.Rows("3"))
}
diff --git a/executor/executor_failpoint_test.go b/executor/executor_failpoint_test.go
index 2784cf1727ced..ecec2a35aa38e 100644
--- a/executor/executor_failpoint_test.go
+++ b/executor/executor_failpoint_test.go
@@ -556,7 +556,7 @@ func TestDeadlocksTable(t *testing.T) {
))
}
-func TestTidbKvReadTimeout(t *testing.T) {
+func TestTiKVClientReadTimeout(t *testing.T) {
if *testkit.WithTiKV != "" {
t.Skip("skip test since it's only work for unistore")
}
@@ -569,19 +569,19 @@ func TestTidbKvReadTimeout(t *testing.T) {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/store/mockstore/unistore/unistoreRPCDeadlineExceeded"))
}()
// Test for point_get request
- rows := tk.MustQuery("explain analyze select /*+ tidb_kv_read_timeout(1) */ * from t where a = 1").Rows()
+ rows := tk.MustQuery("explain analyze select /*+ set_var(tikv_client_read_timeout=1) */ * from t where a = 1").Rows()
require.Len(t, rows, 1)
explain := fmt.Sprintf("%v", rows[0])
require.Regexp(t, ".*Point_Get.* Get:{num_rpc:2, total_time:.*", explain)
// Test for batch_point_get request
- rows = tk.MustQuery("explain analyze select /*+ tidb_kv_read_timeout(1) */ * from t where a in (1,2)").Rows()
+ rows = tk.MustQuery("explain analyze select /*+ set_var(tikv_client_read_timeout=1) */ * from t where a in (1,2)").Rows()
require.Len(t, rows, 1)
explain = fmt.Sprintf("%v", rows[0])
require.Regexp(t, ".*Batch_Point_Get.* BatchGet:{num_rpc:2, total_time:.*", explain)
// Test for cop request
- rows = tk.MustQuery("explain analyze select /*+ tidb_kv_read_timeout(1) */ * from t where b > 1").Rows()
+ rows = tk.MustQuery("explain analyze select /*+ set_var(tikv_client_read_timeout=1) */ * from t where b > 1").Rows()
require.Len(t, rows, 3)
explain = fmt.Sprintf("%v", rows[0])
require.Regexp(t, ".*TableReader.* root time:.*, loops:.* cop_task: {num: 1, .* rpc_num: 2.*", explain)
@@ -589,13 +589,13 @@ func TestTidbKvReadTimeout(t *testing.T) {
// Test for stale read.
tk.MustExec("set @a=now(6);")
tk.MustExec("set @@tidb_replica_read='closest-replicas';")
- rows = tk.MustQuery("explain analyze select /*+ tidb_kv_read_timeout(1) */ * from t as of timestamp(@a) where b > 1").Rows()
+ rows = tk.MustQuery("explain analyze select /*+ set_var(tikv_client_read_timeout=1) */ * from t as of timestamp(@a) where b > 1").Rows()
require.Len(t, rows, 3)
explain = fmt.Sprintf("%v", rows[0])
require.Regexp(t, ".*TableReader.* root time:.*, loops:.* cop_task: {num: 1, .* rpc_num: 2.*", explain)
- // Test for tidb_kv_read_timeout session variable.
- tk.MustExec("set @@tidb_kv_read_timeout=1;")
+ // Test for tikv_client_read_timeout session variable.
+ tk.MustExec("set @@tikv_client_read_timeout=1;")
// Test for point_get request
rows = tk.MustQuery("explain analyze select * from t where a = 1").Rows()
require.Len(t, rows, 1)
diff --git a/executor/historical_stats_test.go b/executor/historical_stats_test.go
index 7288a8958b062..df449107c2d52 100644
--- a/executor/historical_stats_test.go
+++ b/executor/historical_stats_test.go
@@ -25,6 +25,7 @@ import (
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/statistics/handle"
+ "github.com/pingcap/tidb/statistics/handle/globalstats"
"github.com/pingcap/tidb/testkit"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/oracle"
@@ -343,7 +344,7 @@ PARTITION p0 VALUES LESS THAN (6)
require.NotNil(t, jsTable)
// only has p0 stats
require.NotNil(t, jsTable.Partitions["p0"])
- require.Nil(t, jsTable.Partitions[handle.TiDBGlobalStats])
+ require.Nil(t, jsTable.Partitions[globalstats.TiDBGlobalStats])
// change static to dynamic then assert
tk.MustExec("set @@tidb_partition_prune_mode='dynamic'")
@@ -365,7 +366,7 @@ PARTITION p0 VALUES LESS THAN (6)
require.NotNil(t, jsTable)
// has both global and p0 stats
require.NotNil(t, jsTable.Partitions["p0"])
- require.NotNil(t, jsTable.Partitions[handle.TiDBGlobalStats])
+ require.NotNil(t, jsTable.Partitions[globalstats.TiDBGlobalStats])
}
func TestDumpHistoricalStatsFallback(t *testing.T) {
diff --git a/executor/import_into.go b/executor/import_into.go
index b63d0f14a4137..a02a3bb12d232 100644
--- a/executor/import_into.go
+++ b/executor/import_into.go
@@ -194,7 +194,7 @@ func (e *ImportIntoExec) getJobImporter(ctx context.Context, param *importer.Job
importFromServer, err := storage.IsLocalPath(e.controller.Path)
if err != nil {
// since we have checked this during creating controller, this should not happen.
- return nil, exeerrors.ErrLoadDataInvalidURI.FastGenByArgs(err.Error())
+ return nil, exeerrors.ErrLoadDataInvalidURI.FastGenByArgs(plannercore.ImportIntoDataSource, err.Error())
}
logutil.Logger(ctx).Info("get job importer", zap.Stringer("param", e.controller.Parameters),
zap.Bool("dist-task-enabled", variable.EnableDistTask.Load()))
@@ -295,7 +295,6 @@ func flushStats(ctx context.Context, se sessionctx.Context, tableID int64, resul
func cancelImportJob(ctx context.Context, manager *fstorage.TaskManager, jobID int64) error {
// todo: cancel is async operation, we don't wait here now, maybe add a wait syntax later.
// todo: after CANCEL, user can see the job status is Canceled immediately, but the job might still running.
- // and the state of framework task might became finished since framework don't force state change DAG when update task.
// todo: add a CANCELLING status?
return manager.WithNewTxn(ctx, func(se sessionctx.Context) error {
exec := se.(sqlexec.SQLExecutor)
diff --git a/executor/import_into_test.go b/executor/import_into_test.go
index 81812add2269b..53ffdded23176 100644
--- a/executor/import_into_test.go
+++ b/executor/import_into_test.go
@@ -132,6 +132,11 @@ func TestImportIntoOptionsNegativeCase(t *testing.T) {
{OptionStr: "record_errors=-123", Err: exeerrors.ErrInvalidOptionVal},
{OptionStr: "record_errors=null", Err: exeerrors.ErrInvalidOptionVal},
{OptionStr: "record_errors=true", Err: exeerrors.ErrInvalidOptionVal},
+
+ {OptionStr: "cloud_storage_uri=123", Err: exeerrors.ErrInvalidOptionVal},
+ {OptionStr: "cloud_storage_uri=':'", Err: exeerrors.ErrInvalidOptionVal},
+ {OptionStr: "cloud_storage_uri='sdsd'", Err: exeerrors.ErrInvalidOptionVal},
+ {OptionStr: "cloud_storage_uri='http://sdsd'", Err: exeerrors.ErrInvalidOptionVal},
}
sqlTemplate := "import into t from '/file.csv' with %s"
diff --git a/executor/importer/BUILD.bazel b/executor/importer/BUILD.bazel
index 88693c1ad8206..652e1deeee855 100644
--- a/executor/importer/BUILD.bazel
+++ b/executor/importer/BUILD.bazel
@@ -16,6 +16,7 @@ go_library(
deps = [
"//br/pkg/lightning/backend",
"//br/pkg/lightning/backend/encode",
+ "//br/pkg/lightning/backend/external",
"//br/pkg/lightning/backend/kv",
"//br/pkg/lightning/backend/local",
"//br/pkg/lightning/checkpoints",
@@ -85,7 +86,7 @@ go_test(
embed = [":importer"],
flaky = True,
race = "on",
- shard_count = 14,
+ shard_count = 15,
deps = [
"//br/pkg/errors",
"//br/pkg/lightning/config",
@@ -100,7 +101,9 @@ go_test(
"//parser/ast",
"//parser/model",
"//planner/core",
+ "//sessionctx/variable",
"//testkit",
+ "//types",
"//util/dbterror/exeerrors",
"//util/etcd",
"//util/logutil",
diff --git a/executor/importer/chunk_process.go b/executor/importer/chunk_process.go
index 8f5b92c0565af..de28449d4933b 100644
--- a/executor/importer/chunk_process.go
+++ b/executor/importer/chunk_process.go
@@ -22,6 +22,8 @@ import (
"github.com/docker/go-units"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/br/pkg/lightning/backend"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/encode"
+ "github.com/pingcap/tidb/br/pkg/lightning/backend/external"
"github.com/pingcap/tidb/br/pkg/lightning/backend/kv"
"github.com/pingcap/tidb/br/pkg/lightning/checkpoints"
"github.com/pingcap/tidb/br/pkg/lightning/common"
@@ -335,3 +337,68 @@ func (p *chunkProcessor) deliverLoop(ctx context.Context) error {
return nil
}
+
+// IndexRouteWriter is a writer for indexes when using global sort.
+// We route KVs of different indexes to different writers to make the
+// merge sort easier; otherwise the KV data of all subtasks would overlap.
+//
+// The drawback is that the number of writers we need to open becomes
+// index-count * encode-concurrency. When the table has many indexes, and each
+// writer takes a 256MiB buffer by default,
+// this can consume a lot of memory, or even cause OOM.
+type IndexRouteWriter struct {
+ writers map[int64]*external.Writer
+ logger *zap.Logger
+ writerFactory func(int64) *external.Writer
+}
+
+// NewIndexRouteWriter creates a new IndexRouteWriter.
+func NewIndexRouteWriter(logger *zap.Logger, writerFactory func(int64) *external.Writer) *IndexRouteWriter {
+ return &IndexRouteWriter{
+ writers: make(map[int64]*external.Writer),
+ logger: logger,
+ writerFactory: writerFactory,
+ }
+}
+
+// AppendRows implements backend.EngineWriter interface.
+func (w *IndexRouteWriter) AppendRows(ctx context.Context, _ []string, rows encode.Rows) error {
+ kvs := kv.Rows2KvPairs(rows)
+ if len(kvs) == 0 {
+ return nil
+ }
+ for _, item := range kvs {
+ indexID, err := tablecodec.DecodeIndexID(item.Key)
+ if err != nil {
+ return errors.Trace(err)
+ }
+ writer, ok := w.writers[indexID]
+ if !ok {
+ writer = w.writerFactory(indexID)
+ w.writers[indexID] = writer
+ }
+ if err = writer.WriteRow(ctx, item.Key, item.Val, nil); err != nil {
+ return errors.Trace(err)
+ }
+ }
+ return nil
+}
+
+// IsSynced implements backend.EngineWriter interface.
+func (*IndexRouteWriter) IsSynced() bool {
+ return true
+}
+
+// Close implements backend.EngineWriter interface.
+func (w *IndexRouteWriter) Close(ctx context.Context) (backend.ChunkFlushStatus, error) {
+ var firstErr error
+ for _, writer := range w.writers {
+ if err := writer.Close(ctx); err != nil {
+ if firstErr == nil {
+ firstErr = err
+ }
+ w.logger.Error("close index writer failed", zap.Error(err))
+ }
+ }
+ return nil, firstErr
+}
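Note: a hedged usage sketch of the routing idea behind IndexRouteWriter — one downstream writer per index ID, created lazily through a factory so KV pairs of the same index stay in one sorted stream. external.Writer and tablecodec key decoding are elided; sortedWriter below is a stand-in:

package main

import "fmt"

// sortedWriter stands in for *external.Writer.
type sortedWriter struct{ indexID int64 }

func (w *sortedWriter) WriteRow(key, val []byte) error {
	fmt.Printf("index %d: %q -> %q\n", w.indexID, key, val)
	return nil
}

// routeWriter mirrors IndexRouteWriter: one writer per index ID, created lazily.
type routeWriter struct {
	writers map[int64]*sortedWriter
	factory func(indexID int64) *sortedWriter
}

func newRouteWriter(factory func(int64) *sortedWriter) *routeWriter {
	return &routeWriter{writers: make(map[int64]*sortedWriter), factory: factory}
}

func (r *routeWriter) write(indexID int64, key, val []byte) error {
	w, ok := r.writers[indexID]
	if !ok {
		w = r.factory(indexID)
		r.writers[indexID] = w
	}
	return w.WriteRow(key, val)
}

func main() {
	rw := newRouteWriter(func(id int64) *sortedWriter { return &sortedWriter{indexID: id} })
	_ = rw.write(1, []byte("k1"), []byte("v1"))
	_ = rw.write(2, []byte("k2"), []byte("v2"))
	_ = rw.write(1, []byte("k3"), []byte("v3")) // reuses the writer for index 1
}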
diff --git a/executor/importer/engine_process.go b/executor/importer/engine_process.go
index 58ffca7e6c880..b5c44de1c762a 100644
--- a/executor/importer/engine_process.go
+++ b/executor/importer/engine_process.go
@@ -47,40 +47,53 @@ func ProcessChunk(
dataWriterCfg := &backend.LocalWriterConfig{
IsKVSorted: hasAutoIncrementAutoID,
}
- parser, err := tableImporter.getParser(ctx, chunk)
+ dataWriter, err := dataEngine.LocalWriter(ctx, dataWriterCfg)
if err != nil {
return err
}
defer func() {
- if err2 := parser.Close(); err2 != nil {
- logger.Warn("close parser failed", zap.Error(err2))
+ if _, err2 := dataWriter.Close(ctx); err2 != nil {
+ logger.Warn("close data writer failed", zap.Error(err2))
}
}()
- encoder, err := tableImporter.getKVEncoder(chunk)
+ indexWriter, err := indexEngine.LocalWriter(ctx, &backend.LocalWriterConfig{})
if err != nil {
return err
}
defer func() {
- if err2 := encoder.Close(); err2 != nil {
- logger.Warn("close encoder failed", zap.Error(err2))
+ if _, err2 := indexWriter.Close(ctx); err2 != nil {
+ logger.Warn("close index writer failed", zap.Error(err2))
}
}()
- dataWriter, err := dataEngine.LocalWriter(ctx, dataWriterCfg)
+
+ return ProcessChunkWith(ctx, chunk, tableImporter, dataWriter, indexWriter, progress, logger)
+}
+
+// ProcessChunkWith processes a chunk, and writes kv pairs to dataWriter and indexWriter.
+func ProcessChunkWith(
+ ctx context.Context,
+ chunk *checkpoints.ChunkCheckpoint,
+ tableImporter *TableImporter,
+ dataWriter, indexWriter backend.EngineWriter,
+ progress *asyncloaddata.Progress,
+ logger *zap.Logger,
+) error {
+ parser, err := tableImporter.getParser(ctx, chunk)
if err != nil {
return err
}
defer func() {
- if _, err2 := dataWriter.Close(ctx); err2 != nil {
- logger.Warn("close data writer failed", zap.Error(err2))
+ if err2 := parser.Close(); err2 != nil {
+ logger.Warn("close parser failed", zap.Error(err2))
}
}()
- indexWriter, err := indexEngine.LocalWriter(ctx, &backend.LocalWriterConfig{})
+ encoder, err := tableImporter.getKVEncoder(chunk)
if err != nil {
return err
}
defer func() {
- if _, err2 := indexWriter.Close(ctx); err2 != nil {
- logger.Warn("close index writer failed", zap.Error(err2))
+ if err2 := encoder.Close(); err2 != nil {
+ logger.Warn("close encoder failed", zap.Error(err2))
}
}()
diff --git a/executor/importer/import.go b/executor/importer/import.go
index e2ee585ce61cd..a441b6e9dae0a 100644
--- a/executor/importer/import.go
+++ b/executor/importer/import.go
@@ -18,6 +18,7 @@ import (
"context"
"io"
"math"
+ "net/url"
"os"
"path/filepath"
"runtime"
@@ -92,6 +93,7 @@ const (
recordErrorsOption = "record_errors"
detachedOption = "detached"
disableTiKVImportModeOption = "disable_tikv_import_mode"
+ cloudStorageURIOption = "cloud_storage_uri"
// used for test
maxEngineSizeOption = "__max_engine_size"
)
@@ -115,6 +117,7 @@ var (
detachedOption: false,
disableTiKVImportModeOption: false,
maxEngineSizeOption: true,
+ cloudStorageURIOption: true,
}
csvOnlyOptions = map[string]struct{}{
@@ -195,6 +198,7 @@ type Plan struct {
Detached bool
DisableTiKVImportMode bool
MaxEngineSize config.ByteSize
+ CloudStorageURI string
// used for checksum in physical mode
DistSQLScanConcurrency int
@@ -248,6 +252,8 @@ type LoadDataController struct {
logger *zap.Logger
dataStore storage.ExternalStorage
dataFiles []*mydump.SourceFileMeta
+ // GlobalSortStore is used to store sorted data when using global sort.
+ GlobalSortStore storage.ExternalStorage
}
func getImportantSysVars(sctx sessionctx.Context) map[string]string {
@@ -497,6 +503,7 @@ func (p *Plan) initDefaultOptions() {
p.Detached = false
p.DisableTiKVImportMode = false
p.MaxEngineSize = config.ByteSize(defaultMaxEngineSize)
+ p.CloudStorageURI = variable.CloudStorageURI.Load()
v := "utf8mb4"
p.Charset = &v
@@ -654,6 +661,25 @@ func (p *Plan) initOptions(seCtx sessionctx.Context, options []*plannercore.Load
if _, ok := specifiedOptions[disableTiKVImportModeOption]; ok {
p.DisableTiKVImportMode = true
}
+ if opt, ok := specifiedOptions[cloudStorageURIOption]; ok {
+ v, err := optAsString(opt)
+ if err != nil {
+ return exeerrors.ErrInvalidOptionVal.FastGenByArgs(opt.Name)
+ }
+ // Setting cloud_storage_uri to an empty string forces local sort even when
+ // the global variable is set.
+ if v != "" {
+ b, err := storage.ParseBackend(v, nil)
+ if err != nil {
+ return exeerrors.ErrInvalidOptionVal.FastGenByArgs(opt.Name)
+ }
+ // only support s3 and gcs now.
+ if b.GetS3() == nil && b.GetGcs() == nil {
+ return exeerrors.ErrInvalidOptionVal.FastGenByArgs(opt.Name)
+ }
+ }
+ p.CloudStorageURI = v
+ }
if opt, ok := specifiedOptions[maxEngineSizeOption]; ok {
v, err := optAsString(opt)
if err != nil {
@@ -715,7 +741,11 @@ func (p *Plan) initParameters(plan *plannercore.ImportInto) error {
optionMap := make(map[string]interface{}, len(plan.Options))
for _, opt := range plan.Options {
if opt.Value != nil {
- optionMap[opt.Name] = opt.Value.String()
+ val := opt.Value.String()
+ if opt.Name == cloudStorageURIOption {
+ val = ast.RedactURL(val)
+ }
+ optionMap[opt.Name] = val
} else {
optionMap[opt.Name] = nil
}
@@ -863,11 +893,64 @@ func (e *LoadDataController) GenerateCSVConfig() *config.CSVConfig {
return csvConfig
}
+// InitDataStore initializes the data store.
+func (e *LoadDataController) InitDataStore(ctx context.Context) error {
+ u, err2 := storage.ParseRawURL(e.Path)
+ if err2 != nil {
+ return exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs(plannercore.ImportIntoDataSource,
+ err2.Error())
+ }
+
+ if storage.IsLocal(u) {
+ u.Path = filepath.Dir(e.Path)
+ } else {
+ u.Path = ""
+ }
+ s, err := e.initExternalStore(ctx, u, plannercore.ImportIntoDataSource)
+ if err != nil {
+ return err
+ }
+ e.dataStore = s
+
+ if e.IsGlobalSort() {
+ target := "cloud storage"
+ cloudStorageURL, err3 := storage.ParseRawURL(e.Plan.CloudStorageURI)
+ if err3 != nil {
+ return exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs(target,
+ err3.Error())
+ }
+ s, err = e.initExternalStore(ctx, cloudStorageURL, target)
+ if err != nil {
+ return err
+ }
+ e.GlobalSortStore = s
+ }
+ return nil
+}
+func (*LoadDataController) initExternalStore(ctx context.Context, u *url.URL, target string) (storage.ExternalStorage, error) {
+ b, err2 := storage.ParseBackendFromURL(u, nil)
+ if err2 != nil {
+ return nil, exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs(target, GetMsgFromBRError(err2))
+ }
+
+ opt := &storage.ExternalStorageOptions{}
+ if intest.InTest {
+ opt.NoCredentials = true
+ }
+ s, err := storage.New(ctx, b, opt)
+ if err != nil {
+ return nil, exeerrors.ErrLoadDataCantAccess.GenWithStackByArgs(target, GetMsgFromBRError(err))
+ }
+ return s, nil
+}
+
// InitDataFiles initializes the data store and files.
+// It calls InitDataStore internally.
func (e *LoadDataController) InitDataFiles(ctx context.Context) error {
u, err2 := storage.ParseRawURL(e.Path)
if err2 != nil {
- return exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs(err2.Error())
+ return exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs(plannercore.ImportIntoDataSource,
+ err2.Error())
}
var fileNameKey string
@@ -878,45 +961,39 @@ func (e *LoadDataController) InitDataFiles(ctx context.Context) error {
}
if !filepath.IsAbs(e.Path) {
- return exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs("file location should be absolute path when import from server disk")
+ return exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs(plannercore.ImportIntoDataSource,
+ "file location should be absolute path when import from server disk")
}
// we add this check for security, we don't want user import any sensitive system files,
// most of which is readable text file and don't have a suffix, such as /etc/passwd
if !slices.Contains([]string{".csv", ".sql", ".parquet"}, strings.ToLower(filepath.Ext(e.Path))) {
- return exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs("the file suffix is not supported when import from server disk")
+ return exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs(plannercore.ImportIntoDataSource,
+ "the file suffix is not supported when import from server disk")
}
dir := filepath.Dir(e.Path)
_, err := os.Stat(dir)
if err != nil {
// permission denied / file not exist error, etc.
- return exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs(err.Error())
+ return exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs(plannercore.ImportIntoDataSource,
+ err.Error())
}
fileNameKey = filepath.Base(e.Path)
- u.Path = dir
} else {
fileNameKey = strings.Trim(u.Path, "/")
- u.Path = ""
- }
- b, err2 := storage.ParseBackendFromURL(u, nil)
- if err2 != nil {
- return exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs(GetMsgFromBRError(err2))
}
// try to find pattern error in advance
_, err2 = filepath.Match(stringutil.EscapeGlobExceptAsterisk(fileNameKey), "")
if err2 != nil {
- return exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs("Glob pattern error: " + err2.Error())
+ return exeerrors.ErrLoadDataInvalidURI.GenWithStackByArgs(plannercore.ImportIntoDataSource,
+ "Glob pattern error: "+err2.Error())
}
- opt := &storage.ExternalStorageOptions{}
- if intest.InTest {
- opt.NoCredentials = true
- }
- s, err := storage.New(ctx, b, opt)
- if err != nil {
- return exeerrors.ErrLoadDataCantAccess.GenWithStackByArgs(GetMsgFromBRError(err))
+ if err2 = e.InitDataStore(ctx); err2 != nil {
+ return err2
}
+ s := e.dataStore
var totalSize int64
dataFiles := []*mydump.SourceFileMeta{}
idx := strings.IndexByte(fileNameKey, '*')
@@ -955,7 +1032,7 @@ func (e *LoadDataController) InitDataFiles(ctx context.Context) error {
// access, else walkDir will fail
// we only support '*', in order to reuse glob library manually escape the path
escapedPath := stringutil.EscapeGlobExceptAsterisk(fileNameKey)
- err = s.WalkDir(ctx, &storage.WalkOption{ObjPrefix: commonPrefix, SkipSubDir: true},
+ err := s.WalkDir(ctx, &storage.WalkOption{ObjPrefix: commonPrefix, SkipSubDir: true},
func(remotePath string, size int64) error {
// we have checked in LoadDataExec.Next
//nolint: errcheck
@@ -980,7 +1057,6 @@ func (e *LoadDataController) InitDataFiles(ctx context.Context) error {
}
}
- e.dataStore = s
e.dataFiles = dataFiles
e.TotalFileSize = totalSize
return nil
@@ -1131,6 +1207,16 @@ func (e *LoadDataController) toMyDumpFiles() []mydump.FileInfo {
return res
}
+// IsLocalSort returns true if we sort data on local disk.
+func (e *LoadDataController) IsLocalSort() bool {
+ return e.Plan.CloudStorageURI == ""
+}
+
+// IsGlobalSort returns true if we sort data on global storage.
+func (e *LoadDataController) IsGlobalSort() bool {
+ return !e.IsLocalSort()
+}
+
// CreateColAssignExprs creates the column assignment expressions using session context.
// RewriteAstExpr will write ast node in place(due to xxNode.Accept), but it doesn't change node content,
// so we sync it.
diff --git a/executor/importer/import_test.go b/executor/importer/import_test.go
index b98592be88595..d141d61629416 100644
--- a/executor/importer/import_test.go
+++ b/executor/importer/import_test.go
@@ -17,6 +17,7 @@ package importer
import (
"context"
"fmt"
+ "net/url"
"runtime"
"testing"
@@ -30,6 +31,8 @@ import (
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/parser/ast"
plannercore "github.com/pingcap/tidb/planner/core"
+ "github.com/pingcap/tidb/sessionctx/variable"
+ "github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/mock"
"github.com/stretchr/testify/require"
@@ -38,6 +41,10 @@ import (
func TestInitDefaultOptions(t *testing.T) {
plan := &Plan{}
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/importer/mockNumCpu", "return(1)"))
+ variable.CloudStorageURI.Store("s3://bucket/path")
+ t.Cleanup(func() {
+ variable.CloudStorageURI.Store("")
+ })
plan.initDefaultOptions()
require.Equal(t, config.ByteSize(0), plan.DiskQuota)
require.Equal(t, config.OpLevelRequired, plan.Checksum)
@@ -49,6 +56,7 @@ func TestInitDefaultOptions(t *testing.T) {
require.Equal(t, "utf8mb4", *plan.Charset)
require.Equal(t, false, plan.DisableTiKVImportMode)
require.Equal(t, config.ByteSize(defaultMaxEngineSize), plan.MaxEngineSize)
+ require.Equal(t, "s3://bucket/path", plan.CloudStorageURI)
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/importer/mockNumCpu", "return(10)"))
plan.initDefaultOptions()
@@ -76,7 +84,6 @@ func TestInitOptionsPositiveCase(t *testing.T) {
sqlTemplate := "import into t from '/file.csv' with %s"
p := parser.New()
- plan := &Plan{Format: DataFormatCSV}
sql := fmt.Sprintf(sqlTemplate, characterSetOption+"='utf8', "+
fieldsTerminatedByOption+"='aaa', "+
fieldsEnclosedByOption+"='|', "+
@@ -96,6 +103,7 @@ func TestInitOptionsPositiveCase(t *testing.T) {
)
stmt, err := p.ParseOneStmt(sql, "", "")
require.NoError(t, err, sql)
+ plan := &Plan{Format: DataFormatCSV}
err = plan.initOptions(ctx, convertOptions(stmt.(*ast.ImportIntoStmt).Options))
require.NoError(t, err, sql)
require.Equal(t, "utf8", *plan.Charset, sql)
@@ -114,6 +122,42 @@ func TestInitOptionsPositiveCase(t *testing.T) {
require.True(t, plan.Detached, sql)
require.True(t, plan.DisableTiKVImportMode, sql)
require.Equal(t, config.ByteSize(100<<30), plan.MaxEngineSize, sql)
+ require.Empty(t, plan.CloudStorageURI, sql)
+
+ // set cloud storage uri
+ variable.CloudStorageURI.Store("s3://bucket/path")
+ t.Cleanup(func() {
+ variable.CloudStorageURI.Store("")
+ })
+ plan = &Plan{Format: DataFormatCSV}
+ err = plan.initOptions(ctx, convertOptions(stmt.(*ast.ImportIntoStmt).Options))
+ require.NoError(t, err, sql)
+ require.Equal(t, "s3://bucket/path", plan.CloudStorageURI, sql)
+
+ // override cloud storage uri using option
+ sql2 := sql + ", " + cloudStorageURIOption + "='s3://bucket/path2'"
+ stmt, err = p.ParseOneStmt(sql2, "", "")
+ require.NoError(t, err, sql2)
+ plan = &Plan{Format: DataFormatCSV}
+ err = plan.initOptions(ctx, convertOptions(stmt.(*ast.ImportIntoStmt).Options))
+ require.NoError(t, err, sql2)
+ require.Equal(t, "s3://bucket/path2", plan.CloudStorageURI, sql2)
+ // override with gs
+ sql3 := sql + ", " + cloudStorageURIOption + "='gs://bucket/path2'"
+ stmt, err = p.ParseOneStmt(sql3, "", "")
+ require.NoError(t, err, sql3)
+ plan = &Plan{Format: DataFormatCSV}
+ err = plan.initOptions(ctx, convertOptions(stmt.(*ast.ImportIntoStmt).Options))
+ require.NoError(t, err, sql3)
+ require.Equal(t, "gs://bucket/path2", plan.CloudStorageURI, sql3)
+ // override with an empty string to force local sort
+ sql4 := sql + ", " + cloudStorageURIOption + "=''"
+ stmt, err = p.ParseOneStmt(sql4, "", "")
+ require.NoError(t, err, sql4)
+ plan = &Plan{Format: DataFormatCSV}
+ err = plan.initOptions(ctx, convertOptions(stmt.(*ast.ImportIntoStmt).Options))
+ require.NoError(t, err, sql4)
+ require.Equal(t, "", plan.CloudStorageURI, sql4)
}
func TestAdjustOptions(t *testing.T) {
@@ -176,3 +220,54 @@ func TestGetFileRealSize(t *testing.T) {
require.NoError(t, err)
require.Equal(t, int64(100), c.getFileRealSize(context.Background(), fileMeta, nil))
}
+
+func urlEqual(t *testing.T, expected, actual string) {
+ urlExpected, err := url.Parse(expected)
+ require.NoError(t, err)
+ urlGot, err := url.Parse(actual)
+ require.NoError(t, err)
+ // order of query parameters might change
+ require.Equal(t, urlExpected.Query(), urlGot.Query())
+ urlExpected.RawQuery, urlGot.RawQuery = "", ""
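+ // compare the rest of the URL with the query string stripped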
+ require.Equal(t, urlExpected.String(), urlGot.String())
+}
+
+func TestInitParameters(t *testing.T) {
+ // test that access keys are redacted
+ p := &Plan{
+ Format: DataFormatCSV,
+ Path: "s3://bucket/path?access-key=111111&secret-access-key=222222",
+ }
+ require.NoError(t, p.initParameters(&plannercore.ImportInto{
+ Options: []*plannercore.LoadDataOpt{
+ {
+ Name: cloudStorageURIOption,
+ Value: &expression.Constant{
+ Value: types.NewStringDatum("s3://this-is-for-storage/path?access-key=aaaaaa&secret-access-key=bbbbbb"),
+ },
+ },
+ },
+ }))
+ urlEqual(t, "s3://bucket/path?access-key=xxxxxx&secret-access-key=xxxxxx", p.Parameters.FileLocation)
+ require.Len(t, p.Parameters.Options, 1)
+ urlEqual(t, "s3://this-is-for-storage/path?access-key=xxxxxx&secret-access-key=xxxxxx",
+ p.Parameters.Options[cloudStorageURIOption].(string))
+
+ // test other options
+ require.NoError(t, p.initParameters(&plannercore.ImportInto{
+ Options: []*plannercore.LoadDataOpt{
+ {
+ Name: detachedOption,
+ },
+ {
+ Name: threadOption,
+ Value: &expression.Constant{
+ Value: types.NewIntDatum(3),
+ },
+ },
+ },
+ }))
+ require.Len(t, p.Parameters.Options, 2)
+ require.Contains(t, p.Parameters.Options, detachedOption)
+ require.Equal(t, "3", p.Parameters.Options[threadOption])
+}
diff --git a/executor/importer/job.go b/executor/importer/job.go
index b35e427d0e9d6..b658d42efb067 100644
--- a/executor/importer/job.go
+++ b/executor/importer/job.go
@@ -58,7 +58,14 @@ const (
jobStatusFinished = "finished"
// when the job is finished, step will be set to none.
- jobStepNone = ""
+ jobStepNone = ""
+ // JobStepGlobalSorting is the first step when using global sort,
+ // step goes from none -> global-sorting -> importing -> validating -> none.
+ JobStepGlobalSorting = "global-sorting"
+ // JobStepImporting is the first step when using local sort,
+ // step goes from none -> importing -> validating -> none.
+ // When used in global sort, it means importing the sorted data.
+ // When used in local sort, it means encoding and sorting the data and then importing it.
JobStepImporting = "importing"
JobStepValidating = "validating"
@@ -216,14 +223,14 @@ func CreateJob(
return rows[0].GetInt64(0), nil
}
-// StartJob tries to start a pending job with jobID, change its status/step to running/importing.
+// StartJob tries to start a pending job with jobID, changing its status to running and its step to the given step.
// It will not return error when there's no matched job or the job has already started.
-func StartJob(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64) error {
+func StartJob(ctx context.Context, conn sqlexec.SQLExecutor, jobID int64, step string) error {
ctx = util.WithInternalSourceType(ctx, kv.InternalImportInto)
_, err := conn.ExecuteInternal(ctx, `UPDATE mysql.tidb_import_jobs
SET update_time = CURRENT_TIMESTAMP(6), start_time = CURRENT_TIMESTAMP(6), status = %?, step = %?
WHERE id = %? AND status = %?;`,
- JobStatusRunning, JobStepImporting, jobID, jobStatusPending)
+ JobStatusRunning, step, jobID, jobStatusPending)
return err
}
diff --git a/executor/importer/job_test.go b/executor/importer/job_test.go
index 3bccceb864a54..1b1b4b002f312 100644
--- a/executor/importer/job_test.go
+++ b/executor/importer/job_test.go
@@ -105,7 +105,7 @@ func TestJobHappyPath(t *testing.T) {
jobInfoEqual(t, jobInfo, gotJobInfo)
// start job
- require.NoError(t, importer.StartJob(ctx, conn, jobID))
+ require.NoError(t, importer.StartJob(ctx, conn, jobID, importer.JobStepImporting))
gotJobInfo, err = importer.GetJob(ctx, conn, jobID, jobInfo.CreatedBy, false)
require.NoError(t, err)
require.False(t, gotJobInfo.CreateTime.IsZero())
@@ -222,7 +222,7 @@ func TestGetAndCancelJob(t *testing.T) {
jobInfoEqual(t, jobInfo, gotJobInfo)
// start job
- require.NoError(t, importer.StartJob(ctx, conn, jobID2))
+ require.NoError(t, importer.StartJob(ctx, conn, jobID2, importer.JobStepImporting))
gotJobInfo, err = importer.GetJob(ctx, conn, jobID2, jobInfo.CreatedBy, false)
require.NoError(t, err)
require.False(t, gotJobInfo.CreateTime.IsZero())
@@ -306,7 +306,7 @@ func TestGetJobInfoNullField(t *testing.T) {
jobID1, err := importer.CreateJob(ctx, conn, jobInfo.TableSchema, jobInfo.TableName, jobInfo.TableID,
jobInfo.CreatedBy, &jobInfo.Parameters, jobInfo.SourceFileSize)
require.NoError(t, err)
- require.NoError(t, importer.StartJob(ctx, conn, jobID1))
+ require.NoError(t, importer.StartJob(ctx, conn, jobID1, importer.JobStepImporting))
require.NoError(t, importer.FailJob(ctx, conn, jobID1, "failed"))
jobID2, err := importer.CreateJob(ctx, conn, jobInfo.TableSchema, jobInfo.TableName, jobInfo.TableID,
jobInfo.CreatedBy, &jobInfo.Parameters, jobInfo.SourceFileSize)
diff --git a/executor/importer/table_import.go b/executor/importer/table_import.go
index 9cf3ffb9bc539..cf441346c67db 100644
--- a/executor/importer/table_import.go
+++ b/executor/importer/table_import.go
@@ -147,6 +147,28 @@ func GetCachedKVStoreFrom(pdAddr string, tls *common.TLS) (tidbkv.Storage, error
return kvStore, nil
}
+// GetRegionSplitSizeKeys gets the region split size and keys from PD.
+func GetRegionSplitSizeKeys(ctx context.Context) (regionSplitSize int64, regionSplitKeys int64, err error) {
+ tidbCfg := tidb.GetGlobalConfig()
+ tls, err := common.NewTLS(
+ tidbCfg.Security.ClusterSSLCA,
+ tidbCfg.Security.ClusterSSLCert,
+ tidbCfg.Security.ClusterSSLKey,
+ "",
+ nil, nil, nil,
+ )
+ if err != nil {
+ return 0, 0, err
+ }
+ tlsOpt := tls.ToPDSecurityOption()
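+ // create a temporary PD client (closed below) to query the region split settings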
+ pdCli, err := pd.NewClientWithContext(ctx, []string{tidbCfg.Path}, tlsOpt)
+ if err != nil {
+ return 0, 0, errors.Trace(err)
+ }
+ defer pdCli.Close()
+ return local.GetRegionSplitSizeKeys(ctx, pdCli, tls)
+}
+
// NewTableImporter creates a new table importer.
func NewTableImporter(param *JobImportParam, e *LoadDataController, taskID int64) (ti *TableImporter, err error) {
idAlloc := kv.NewPanickingAllocators(0)
@@ -223,7 +245,6 @@ func NewTableImporter(param *JobImportParam, e *LoadDataController, taskID int64
},
encTable: tbl,
dbID: e.DBID,
- store: e.dataStore,
kvStore: kvStore,
logger: e.logger,
// this is the value we use for 50TiB data parallel import.
@@ -246,7 +267,6 @@ type TableImporter struct {
encTable table.Table
dbID int64
- store storage.ExternalStorage
// the kv store we get is a cached store, so we can't close it.
kvStore tidbkv.Storage
logger *zap.Logger
@@ -475,6 +495,11 @@ func (ti *TableImporter) fullTableName() string {
return common.UniqueTable(ti.DBName, ti.Table.Meta().Name.O)
}
+// Backend returns the backend of the importer.
+func (ti *TableImporter) Backend() *local.Backend {
+ return ti.backend
+}
+
// Close implements the io.Closer interface.
func (ti *TableImporter) Close() error {
ti.backend.Close()
diff --git a/executor/index_lookup_join_test.go b/executor/index_lookup_join_test.go
index 969fa9dee950f..8a72e796fd20e 100644
--- a/executor/index_lookup_join_test.go
+++ b/executor/index_lookup_join_test.go
@@ -192,7 +192,7 @@ func TestInapplicableIndexJoinHint(t *testing.T) {
query := `select /*+ tidb_inlj(bb) */ aa.* from (select * from t1) as aa left join
(select t2.a, t2.a*2 as a2 from t2) as bb on aa.a=bb.a;`
- tk.HasPlan(query, "IndexJoin")
+ tk.MustHavePlan(query, "IndexJoin")
}
func TestIndexJoinOverflow(t *testing.T) {
diff --git a/executor/infoschema_reader_test.go b/executor/infoschema_reader_test.go
index 597e6e8f7be5a..ce1285c976027 100644
--- a/executor/infoschema_reader_test.go
+++ b/executor/infoschema_reader_test.go
@@ -433,7 +433,6 @@ func TestPartitionsTable(t *testing.T) {
tk := testkit.NewTestKit(t, store)
tk.MustExec("USE test;")
testkit.WithPruneMode(tk, variable.Static, func() {
- require.NoError(t, h.RefreshVars())
tk.MustExec("DROP TABLE IF EXISTS `test_partitions`;")
tk.MustExec(`CREATE TABLE test_partitions (a int, b int, c varchar(5), primary key(a), index idx(c)) PARTITION BY RANGE (a) (PARTITION p0 VALUES LESS THAN (6), PARTITION p1 VALUES LESS THAN (11), PARTITION p2 VALUES LESS THAN (16));`)
require.NoError(t, h.HandleDDLEvent(<-h.DDLEventCh()))
diff --git a/executor/insert_common.go b/executor/insert_common.go
index 49efe4095984b..f222e21427102 100644
--- a/executor/insert_common.go
+++ b/executor/insert_common.go
@@ -25,7 +25,6 @@ import (
"github.com/pingcap/tidb/ddl"
"github.com/pingcap/tidb/executor/internal/exec"
"github.com/pingcap/tidb/expression"
- "github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/meta/autoid"
"github.com/pingcap/tidb/parser/ast"
@@ -690,28 +689,15 @@ func (e *InsertValues) fillRow(ctx context.Context, row []types.Datum, hasValue
}
}
}
- tbl := e.Table.Meta()
+
// Handle exchange partition
- if tbl.ExchangePartitionInfo != nil {
- is := e.Ctx().GetDomainInfoSchema().(infoschema.InfoSchema)
- pt, tableFound := is.TableByID(tbl.ExchangePartitionInfo.ExchangePartitionID)
- if !tableFound {
- return nil, errors.Errorf("exchange partition process table by id failed")
- }
- p, ok := pt.(table.PartitionedTable)
- if !ok {
- return nil, errors.Errorf("exchange partition process assert table partition failed")
- }
- err := p.CheckForExchangePartition(
- e.Ctx(),
- pt.Meta().Partition,
- row,
- tbl.ExchangePartitionInfo.ExchangePartitionDefID,
- )
- if err != nil {
+ tbl := e.Table.Meta()
+ if tbl.ExchangePartitionInfo != nil && tbl.GetPartitionInfo() == nil {
+ if err := checkRowForExchangePartition(e.Ctx(), row, tbl); err != nil {
return nil, err
}
}
+
sc := e.Ctx().GetSessionVars().StmtCtx
warnCnt := int(sc.WarningCount())
for i, gCol := range gCols {
diff --git a/executor/internal/mpp/local_mpp_coordinator.go b/executor/internal/mpp/local_mpp_coordinator.go
index 8251a732b313b..84f969d5b4f5b 100644
--- a/executor/internal/mpp/local_mpp_coordinator.go
+++ b/executor/internal/mpp/local_mpp_coordinator.go
@@ -229,6 +229,7 @@ func (c *localMppCoordinator) appendMPPDispatchReq(pf *plannercore.Fragment) err
CoordinatorAddress: c.coordinatorAddr,
ReportExecutionSummary: c.reportExecutionInfo,
State: kv.MppTaskReady,
+ ResourceGroupName: c.sessionCtx.GetSessionVars().ResourceGroupName,
}
c.reqMap[req.ID] = &mppRequestReport{mppReq: req, receivedReport: false, errMsg: "", executionSummaries: nil}
c.mppReqs = append(c.mppReqs, req)
diff --git a/executor/lockstats/BUILD.bazel b/executor/lockstats/BUILD.bazel
index d0df7ae3106b2..6bbdc51990004 100644
--- a/executor/lockstats/BUILD.bazel
+++ b/executor/lockstats/BUILD.bazel
@@ -1,4 +1,4 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "lockstats",
@@ -13,7 +13,23 @@ go_library(
"//executor/internal/exec",
"//infoschema",
"//parser/ast",
+ "//parser/model",
+ "//table/tables",
"//util/chunk",
"@com_github_pingcap_errors//:errors",
],
)
+
+go_test(
+ name = "lockstats_test",
+ timeout = "short",
+ srcs = ["lock_stats_executor_test.go"],
+ embed = [":lockstats"],
+ flaky = True,
+ deps = [
+ "//infoschema",
+ "//parser/ast",
+ "//parser/model",
+ "@com_github_stretchr_testify//require",
+ ],
+)
diff --git a/executor/lockstats/lock_stats_executor.go b/executor/lockstats/lock_stats_executor.go
index 33a71f4d5eb3f..ef441b06a6375 100644
--- a/executor/lockstats/lock_stats_executor.go
+++ b/executor/lockstats/lock_stats_executor.go
@@ -16,12 +16,15 @@ package lockstats
import (
"context"
+ "fmt"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/executor/internal/exec"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/parser/ast"
+ "github.com/pingcap/tidb/parser/model"
+ "github.com/pingcap/tidb/table/tables"
"github.com/pingcap/tidb/util/chunk"
)
@@ -30,6 +33,9 @@ var _ exec.Executor = &LockExec{}
// LockExec represents a lock statistic executor.
type LockExec struct {
exec.BaseExecutor
+ // Tables is the list of tables to be locked.
+ // It might contain partition names if we are locking partitions.
+ // When locking partitions, Tables will only contain one table name.
Tables []*ast.TableName
}
@@ -45,22 +51,43 @@ func (e *LockExec) Next(_ context.Context, _ *chunk.Chunk) error {
}
is := do.InfoSchema()
- tids, pids, err := populateTableAndPartitionIDs(e.Tables, is)
- if err != nil {
- return err
- }
+ if e.onlyLockPartitions() {
+ table := e.Tables[0]
+ tid, pidNames, err := populatePartitionIDAndNames(table, table.PartitionNames, is)
+ if err != nil {
+ return err
+ }
- msg, err := h.AddLockedTables(tids, pids, e.Tables)
- if err != nil {
- return err
- }
- if msg != "" {
- e.Ctx().GetSessionVars().StmtCtx.AppendWarning(errors.New(msg))
+ tableName := fmt.Sprintf("%s.%s", table.Schema.L, table.Name.L)
+ msg, err := h.LockPartitions(tid, tableName, pidNames)
+ if err != nil {
+ return err
+ }
+ if msg != "" {
+ e.Ctx().GetSessionVars().StmtCtx.AppendWarning(errors.New(msg))
+ }
+ } else {
+ tidAndNames, pidAndNames, err := populateTableAndPartitionIDs(e.Tables, is)
+ if err != nil {
+ return err
+ }
+
+ msg, err := h.LockTables(tidAndNames, pidAndNames)
+ if err != nil {
+ return err
+ }
+ if msg != "" {
+ e.Ctx().GetSessionVars().StmtCtx.AppendWarning(errors.New(msg))
+ }
}
return nil
}
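+// onlyLockPartitions returns true when only partitions of a single table are being locked.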
+func (e *LockExec) onlyLockPartitions() bool {
+ return len(e.Tables) == 1 && len(e.Tables[0].PartitionNames) > 0
+}
+
// Close implements the Executor Close interface.
func (*LockExec) Close() error {
return nil
@@ -71,26 +98,69 @@ func (*LockExec) Open(context.Context) error {
return nil
}
+// populatePartitionIDAndNames returns the table ID and a map from partition ID to partition name for the given table and partition names.
+func populatePartitionIDAndNames(
+ table *ast.TableName,
+ partitionNames []model.CIStr,
+ is infoschema.InfoSchema,
+) (int64, map[int64]string, error) {
+ if len(partitionNames) == 0 {
+ return 0, nil, errors.New("partition list should not be empty")
+ }
+ tbl, err := is.TableByName(table.Schema, table.Name)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ pi := tbl.Meta().GetPartitionInfo()
+ if pi == nil {
+ return 0, nil, errors.Errorf("table %s is not a partition table",
+ fmt.Sprintf("%s.%s", table.Schema.L, table.Name.L))
+ }
+
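+ // map each requested partition ID to its partition name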
+ pidNames := make(map[int64]string, len(partitionNames))
+ for _, partitionName := range partitionNames {
+ pid, err := tables.FindPartitionByName(tbl.Meta(), partitionName.L)
+ if err != nil {
+ return 0, nil, err
+ }
+ pidNames[pid] = partitionName.L
+ }
+
+ return tbl.Meta().ID, pidNames, nil
+}
+
// populateTableAndPartitionIDs returns table IDs and partition IDs for the given table names.
-func populateTableAndPartitionIDs(tables []*ast.TableName, is infoschema.InfoSchema) ([]int64, []int64, error) {
- tids := make([]int64, 0, len(tables))
- pids := make([]int64, 0)
+func populateTableAndPartitionIDs(
+ tables []*ast.TableName,
+ is infoschema.InfoSchema,
+) (map[int64]string, map[int64]string, error) {
+ if len(tables) == 0 {
+ return nil, nil, errors.New("table list should not be empty")
+ }
+
+ tidAndNames := make(map[int64]string, len(tables))
+ pidAndNames := make(map[int64]string, len(tables))
for _, table := range tables {
tbl, err := is.TableByName(table.Schema, table.Name)
if err != nil {
return nil, nil, err
}
- tids = append(tids, tbl.Meta().ID)
+ tidAndNames[tbl.Meta().ID] = fmt.Sprintf("%s.%s", table.Schema.L, table.Name.L)
pi := tbl.Meta().GetPartitionInfo()
if pi == nil {
continue
}
for _, p := range pi.Definitions {
- pids = append(pids, p.ID)
+ pidAndNames[p.ID] = genFullPartitionName(table, p.Name.L)
}
}
- return tids, pids, nil
+ return tidAndNames, pidAndNames, nil
+}
+
+func genFullPartitionName(table *ast.TableName, partitionName string) string {
+ return fmt.Sprintf("%s.%s partition (%s)", table.Schema.L, table.Name.L, partitionName)
}
diff --git a/executor/lockstats/lock_stats_executor_test.go b/executor/lockstats/lock_stats_executor_test.go
new file mode 100644
index 0000000000000..6a15b8bbe322c
--- /dev/null
+++ b/executor/lockstats/lock_stats_executor_test.go
@@ -0,0 +1,108 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package lockstats
+
+import (
+ "testing"
+
+ "github.com/pingcap/tidb/infoschema"
+ "github.com/pingcap/tidb/parser/ast"
+ "github.com/pingcap/tidb/parser/model"
+ "github.com/stretchr/testify/require"
+)
+
+func TestPopulatePartitionIDAndNames(t *testing.T) {
+ fakeInfo := infoschema.MockInfoSchema([]*model.TableInfo{
+ tInfo(1, "t1", "p1", "p2"),
+ })
+
+ table := &ast.TableName{
+ Schema: model.NewCIStr("test"),
+ Name: model.NewCIStr("t1"),
+ PartitionNames: []model.CIStr{
+ model.NewCIStr("p1"),
+ model.NewCIStr("p2"),
+ },
+ }
+
+ gotTID, gotPIDNames, err := populatePartitionIDAndNames(table, table.PartitionNames, fakeInfo)
+ require.NoError(t, err)
+ require.Equal(t, int64(1), gotTID)
+ require.Equal(t, map[int64]string{
+ 2: "p1",
+ 3: "p2",
+ }, gotPIDNames)
+
+ // Empty partition names.
+ _, _, err = populatePartitionIDAndNames(nil, nil, fakeInfo)
+ require.Error(t, err)
+}
+
+func TestPopulateTableAndPartitionIDs(t *testing.T) {
+ fakeInfo := infoschema.MockInfoSchema([]*model.TableInfo{
+ tInfo(1, "t1", "p1", "p2"),
+ tInfo(4, "t2"),
+ })
+
+ tables := []*ast.TableName{
+ {
+ Schema: model.NewCIStr("test"),
+ Name: model.NewCIStr("t1"),
+ PartitionNames: []model.CIStr{
+ model.NewCIStr("p1"),
+ model.NewCIStr("p2"),
+ },
+ },
+ {
+ Schema: model.NewCIStr("test"),
+ Name: model.NewCIStr("t2"),
+ },
+ }
+
+ gotTIDAndNames, gotPIDAndNames, err := populateTableAndPartitionIDs(tables, fakeInfo)
+ require.NoError(t, err)
+ require.Equal(t, map[int64]string{
+ 1: "test.t1",
+ 4: "test.t2",
+ }, gotTIDAndNames)
+ require.Equal(t, map[int64]string{
+ 2: "test.t1 partition (p1)",
+ 3: "test.t1 partition (p2)",
+ }, gotPIDAndNames)
+
+ // Empty table list.
+ _, _, err = populateTableAndPartitionIDs(nil, fakeInfo)
+ require.Error(t, err)
+}
+
+func tInfo(id int, tableName string, partitionNames ...string) *model.TableInfo {
+ tbl := &model.TableInfo{
+ ID: int64(id),
+ Name: model.NewCIStr(tableName),
+ }
+ if len(partitionNames) > 0 {
+ tbl.Partition = &model.PartitionInfo{
+ Enable: true,
+ }
+ for i, partitionName := range partitionNames {
+ tbl.Partition.Definitions = append(tbl.Partition.Definitions, model.PartitionDefinition{
+ ID: int64(id + 1 + i),
+ Name: model.NewCIStr(partitionName),
+ })
+ }
+ }
+
+ return tbl
+}
diff --git a/executor/lockstats/unlock_stats_executor.go b/executor/lockstats/unlock_stats_executor.go
index 804f4301be701..91f3a4edd511d 100644
--- a/executor/lockstats/unlock_stats_executor.go
+++ b/executor/lockstats/unlock_stats_executor.go
@@ -16,6 +16,7 @@ package lockstats
import (
"context"
+ "fmt"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/domain"
@@ -29,6 +30,9 @@ var _ exec.Executor = &UnlockExec{}
// UnlockExec represents a unlock statistic executor.
type UnlockExec struct {
exec.BaseExecutor
+ // Tables is the list of tables to be unlocked.
+ // It might contain partition names if we are unlocking partitions.
+ // When unlocking partitions, Tables will only contain one table name.
Tables []*ast.TableName
}
@@ -44,22 +48,41 @@ func (e *UnlockExec) Next(context.Context, *chunk.Chunk) error {
}
is := do.InfoSchema()
- tids, pids, err := populateTableAndPartitionIDs(e.Tables, is)
- if err != nil {
- return err
- }
-
- msg, err := h.RemoveLockedTables(tids, pids, e.Tables)
- if err != nil {
- return err
- }
- if msg != "" {
- e.Ctx().GetSessionVars().StmtCtx.AppendWarning(errors.New(msg))
+ if e.onlyUnlockPartitions() {
+ table := e.Tables[0]
+ tid, pidNames, err := populatePartitionIDAndNames(table, table.PartitionNames, is)
+ if err != nil {
+ return err
+ }
+ tableName := fmt.Sprintf("%s.%s", table.Schema.O, table.Name.O)
+ msg, err := h.RemoveLockedPartitions(tid, tableName, pidNames)
+ if err != nil {
+ return err
+ }
+ if msg != "" {
+ e.Ctx().GetSessionVars().StmtCtx.AppendWarning(errors.New(msg))
+ }
+ } else {
+ tidAndNames, pidAndNames, err := populateTableAndPartitionIDs(e.Tables, is)
+ if err != nil {
+ return err
+ }
+ msg, err := h.RemoveLockedTables(tidAndNames, pidAndNames)
+ if err != nil {
+ return err
+ }
+ if msg != "" {
+ e.Ctx().GetSessionVars().StmtCtx.AppendWarning(errors.New(msg))
+ }
}
return nil
}
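+// onlyUnlockPartitions returns true when only partitions of a single table are being unlocked.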
+func (e *UnlockExec) onlyUnlockPartitions() bool {
+ return len(e.Tables) == 1 && len(e.Tables[0].PartitionNames) > 0
+}
+
// Close implements the Executor Close interface.
func (*UnlockExec) Close() error {
return nil
diff --git a/executor/partition_table_test.go b/executor/partition_table_test.go
index c6b3608aedac6..18d7683b55327 100644
--- a/executor/partition_table_test.go
+++ b/executor/partition_table_test.go
@@ -206,30 +206,30 @@ func TestPointGetwithRangeAndListPartitionTable(t *testing.T) {
// select a from t where a={x}; // the result is {x}
x := rand.Intn(100) + 1
queryRange1 := fmt.Sprintf("select a from trange1 where a=%v", x)
- require.True(t, tk.HasPlan(queryRange1, "Point_Get")) // check if PointGet is used
+ tk.MustHavePlan(queryRange1, "Point_Get") // check if PointGet is used
tk.MustQuery(queryRange1).Check(testkit.Rows(fmt.Sprintf("%v", x)))
queryRange2 := fmt.Sprintf("select a from trange1 where a=%v", x)
- require.True(t, tk.HasPlan(queryRange2, "Point_Get")) // check if PointGet is used
+ tk.MustHavePlan(queryRange2, "Point_Get") // check if PointGet is used
tk.MustQuery(queryRange2).Check(testkit.Rows(fmt.Sprintf("%v", x)))
y := rand.Intn(12) + 1
queryList := fmt.Sprintf("select a from tlist where a=%v", y)
- require.True(t, tk.HasPlan(queryList, "Point_Get")) // check if PointGet is used
+ tk.MustHavePlan(queryList, "Point_Get") // check if PointGet is used
tk.MustQuery(queryList).Check(testkit.Rows(fmt.Sprintf("%v", y)))
}
// test table dual
queryRange1 := "select a from trange1 where a=200"
- require.True(t, tk.HasPlan(queryRange1, "TableDual")) // check if TableDual is used
+ tk.MustHavePlan(queryRange1, "TableDual") // check if TableDual is used
tk.MustQuery(queryRange1).Check(testkit.Rows())
queryRange2 := "select a from trange2 where a=200"
- require.True(t, tk.HasPlan(queryRange2, "TableDual")) // check if TableDual is used
+ tk.MustHavePlan(queryRange2, "TableDual") // check if TableDual is used
tk.MustQuery(queryRange2).Check(testkit.Rows())
queryList := "select a from tlist where a=200"
- require.True(t, tk.HasPlan(queryList, "TableDual")) // check if TableDual is used
+ tk.MustHavePlan(queryList, "TableDual") // check if TableDual is used
tk.MustQuery(queryList).Check(testkit.Rows())
// test PointGet for one partition
@@ -237,14 +237,14 @@ func TestPointGetwithRangeAndListPartitionTable(t *testing.T) {
tk.MustExec("create table t(a int primary key, b int) PARTITION BY RANGE (a) (partition p0 values less than(1))")
tk.MustExec("insert into t values (-1, 1), (-2, 1)")
tk.MustExec("analyze table t")
- require.True(t, tk.HasPlan(queryOnePartition, "Point_Get"))
+ tk.MustHavePlan(queryOnePartition, "Point_Get")
tk.MustQuery(queryOnePartition).Check(testkit.Rows(fmt.Sprintf("%v", -1)))
tk.MustExec("drop table t")
tk.MustExec("create table t(a int primary key, b int) PARTITION BY list (a) (partition p0 values in (-1, -2))")
tk.MustExec("insert into t values (-1, 1), (-2, 1)")
tk.MustExec("analyze table t")
- require.True(t, tk.HasPlan(queryOnePartition, "Point_Get"))
+ tk.MustHavePlan(queryOnePartition, "Point_Get")
tk.MustQuery(queryOnePartition).Check(testkit.Rows(fmt.Sprintf("%v", -1)))
}
@@ -543,7 +543,7 @@ func TestOrderByAndLimit(t *testing.T) {
y := rand.Intn(500) + 1
queryPartition := fmt.Sprintf("select * from trange use index(idx_a) where a > %v order by a, b limit %v;", x, y)
queryRegular := fmt.Sprintf("select * from tregular use index(idx_a) where a > %v order by a, b limit %v;", x, y)
- require.True(t, tk.HasPlan(queryPartition, "IndexLookUp")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition, "IndexLookUp") // check if IndexLookUp is used
tk.MustQuery(queryPartition).Check(tk.MustQuery(queryRegular).Rows())
}
@@ -568,17 +568,17 @@ func TestOrderByAndLimit(t *testing.T) {
regularResult := tk.MustQuery(queryRegular).Sort().Rows()
if len(regularResult) > 0 {
- require.True(t, tk.HasPlan(queryRangePartitionWithLimitHint, "Limit"))
- require.True(t, tk.HasPlan(queryRangePartitionWithLimitHint, "IndexLookUp"))
- require.True(t, tk.HasPlan(queryHashPartitionWithLimitHint, "Limit"))
- require.True(t, tk.HasPlan(queryHashPartitionWithLimitHint, "IndexLookUp"))
- require.True(t, tk.HasPlan(queryListPartitionWithLimitHint, "Limit"))
- require.True(t, tk.HasPlan(queryListPartitionWithLimitHint, "IndexLookUp"))
+ tk.MustHavePlan(queryRangePartitionWithLimitHint, "Limit")
+ tk.MustHavePlan(queryRangePartitionWithLimitHint, "IndexLookUp")
+ tk.MustHavePlan(queryHashPartitionWithLimitHint, "Limit")
+ tk.MustHavePlan(queryHashPartitionWithLimitHint, "IndexLookUp")
+ tk.MustHavePlan(queryListPartitionWithLimitHint, "Limit")
+ tk.MustHavePlan(queryListPartitionWithLimitHint, "IndexLookUp")
}
if i%2 != 0 {
- require.False(t, tk.HasPlan(queryRangePartitionWithLimitHint, "TopN")) // fully pushed
- require.False(t, tk.HasPlan(queryHashPartitionWithLimitHint, "TopN"))
- require.False(t, tk.HasPlan(queryListPartitionWithLimitHint, "TopN"))
+ tk.MustNotHavePlan(queryRangePartitionWithLimitHint, "TopN") // fully pushed
+ tk.MustNotHavePlan(queryHashPartitionWithLimitHint, "TopN")
+ tk.MustNotHavePlan(queryListPartitionWithLimitHint, "TopN")
}
tk.MustQuery(queryRangePartitionWithLimitHint).Sort().Check(regularResult)
tk.MustQuery(queryHashPartitionWithLimitHint).Sort().Check(regularResult)
@@ -604,17 +604,17 @@ func TestOrderByAndLimit(t *testing.T) {
regularResult := tk.MustQuery(queryRegular).Sort().Rows()
if len(regularResult) > 0 {
- require.True(t, tk.HasPlan(queryRangePartitionWithLimitHint, "Limit"))
- require.True(t, tk.HasPlan(queryRangePartitionWithLimitHint, "IndexLookUp"))
- require.True(t, tk.HasPlan(queryHashPartitionWithLimitHint, "Limit"))
- require.True(t, tk.HasPlan(queryHashPartitionWithLimitHint, "IndexLookUp"))
- require.True(t, tk.HasPlan(queryListPartitionWithLimitHint, "Limit"))
- require.True(t, tk.HasPlan(queryListPartitionWithLimitHint, "IndexLookUp"))
+ tk.MustHavePlan(queryRangePartitionWithLimitHint, "Limit")
+ tk.MustHavePlan(queryRangePartitionWithLimitHint, "IndexLookUp")
+ tk.MustHavePlan(queryHashPartitionWithLimitHint, "Limit")
+ tk.MustHavePlan(queryHashPartitionWithLimitHint, "IndexLookUp")
+ tk.MustHavePlan(queryListPartitionWithLimitHint, "Limit")
+ tk.MustHavePlan(queryListPartitionWithLimitHint, "IndexLookUp")
}
if i%2 != 0 {
- require.False(t, tk.HasPlan(queryRangePartitionWithLimitHint, "TopN")) // fully pushed
- require.False(t, tk.HasPlan(queryHashPartitionWithLimitHint, "TopN"))
- require.False(t, tk.HasPlan(queryListPartitionWithLimitHint, "TopN"))
+ tk.MustNotHavePlan(queryRangePartitionWithLimitHint, "TopN") // fully pushed
+ tk.MustNotHavePlan(queryHashPartitionWithLimitHint, "TopN")
+ tk.MustNotHavePlan(queryListPartitionWithLimitHint, "TopN")
}
tk.MustQuery(queryRangePartitionWithLimitHint).Sort().Check(regularResult)
tk.MustQuery(queryHashPartitionWithLimitHint).Sort().Check(regularResult)
@@ -631,7 +631,7 @@ func TestOrderByAndLimit(t *testing.T) {
y := rand.Intn(500) + 1
queryPartition := fmt.Sprintf("select * from trange ignore index(idx_a, idx_ab) where a > %v order by a, b limit %v;", x, y)
queryRegular := fmt.Sprintf("select * from tregular ignore index(idx_a, idx_ab) where a > %v order by a, b limit %v;", x, y)
- require.True(t, tk.HasPlan(queryPartition, "TableReader")) // check if tableReader is used
+ tk.MustHavePlan(queryPartition, "TableReader") // check if tableReader is used
tk.MustQuery(queryPartition).Check(tk.MustQuery(queryRegular).Rows())
}
@@ -645,12 +645,12 @@ func TestOrderByAndLimit(t *testing.T) {
queryHashPartition := fmt.Sprintf("select /*+ LIMIT_TO_COP() */ * from thash ignore index(idx_a, idx_ab) where a > %v order by a, b limit %v;", x, y)
queryListPartition := fmt.Sprintf("select /*+ LIMIT_TO_COP() */ * from tlist ignore index(idx_a, idx_ab) where a > %v order by a, b limit %v;", x, y)
queryRegular := fmt.Sprintf("select * from tregular ignore index(idx_a) where a > %v order by a, b limit %v;", x, y)
- require.True(t, tk.HasPlan(queryRangePartition, "TableReader")) // check if tableReader is used
- require.True(t, tk.HasPlan(queryHashPartition, "TableReader"))
- require.True(t, tk.HasPlan(queryListPartition, "TableReader"))
- require.False(t, tk.HasPlan(queryRangePartition, "Limit")) // check if order property is not pushed
- require.False(t, tk.HasPlan(queryHashPartition, "Limit"))
- require.False(t, tk.HasPlan(queryListPartition, "Limit"))
+ tk.MustHavePlan(queryRangePartition, "TableReader") // check if tableReader is used
+ tk.MustHavePlan(queryHashPartition, "TableReader")
+ tk.MustHavePlan(queryListPartition, "TableReader")
+ tk.MustNotHavePlan(queryRangePartition, "Limit") // check if order property is not pushed
+ tk.MustNotHavePlan(queryHashPartition, "Limit")
+ tk.MustNotHavePlan(queryListPartition, "Limit")
regularResult := tk.MustQuery(queryRegular).Rows()
tk.MustQuery(queryRangePartition).Check(regularResult)
tk.MustQuery(queryHashPartition).Check(regularResult)
@@ -662,15 +662,15 @@ func TestOrderByAndLimit(t *testing.T) {
queryHashPartition = fmt.Sprintf("select /*+ LIMIT_TO_COP() */ a from thash_intpk use index(primary) where a > %v order by a limit %v", x, y)
queryListPartition = fmt.Sprintf("select /*+ LIMIT_TO_COP() */ a from tlist_intpk use index(primary) where a > %v order by a limit %v", x, y)
queryRegular = fmt.Sprintf("select a from tregular_intpk where a > %v order by a limit %v", x, y)
- require.True(t, tk.HasPlan(queryRangePartition, "TableReader"))
- require.True(t, tk.HasPlan(queryHashPartition, "TableReader"))
- require.True(t, tk.HasPlan(queryListPartition, "TableReader"))
- require.True(t, tk.HasPlan(queryRangePartition, "Limit")) // check if order property is pushed
- require.False(t, tk.HasPlan(queryRangePartition, "TopN")) // and is fully pushed
- require.True(t, tk.HasPlan(queryHashPartition, "Limit"))
- require.False(t, tk.HasPlan(queryHashPartition, "TopN"))
- require.True(t, tk.HasPlan(queryListPartition, "Limit"))
- require.False(t, tk.HasPlan(queryListPartition, "TopN"))
+ tk.MustHavePlan(queryRangePartition, "TableReader")
+ tk.MustHavePlan(queryHashPartition, "TableReader")
+ tk.MustHavePlan(queryListPartition, "TableReader")
+ tk.MustHavePlan(queryRangePartition, "Limit") // check if order property is pushed
+ tk.MustNotHavePlan(queryRangePartition, "TopN") // and is fully pushed
+ tk.MustHavePlan(queryHashPartition, "Limit")
+ tk.MustNotHavePlan(queryHashPartition, "TopN")
+ tk.MustHavePlan(queryListPartition, "Limit")
+ tk.MustNotHavePlan(queryListPartition, "TopN")
regularResult = tk.MustQuery(queryRegular).Rows()
tk.MustQuery(queryRangePartition).Check(regularResult)
tk.MustQuery(queryHashPartition).Check(regularResult)
@@ -681,15 +681,15 @@ func TestOrderByAndLimit(t *testing.T) {
queryHashPartition = fmt.Sprintf("select /*+ LIMIT_TO_COP() */ * from thash_clustered use index(primary) where a > %v order by a, b limit %v;", x, y)
queryListPartition = fmt.Sprintf("select /*+ LIMIT_TO_COP() */ * from tlist_clustered use index(primary) where a > %v order by a, b limit %v;", x, y)
queryRegular = fmt.Sprintf("select * from tregular_clustered where a > %v order by a, b limit %v;", x, y)
- require.True(t, tk.HasPlan(queryRangePartition, "TableReader")) // check if tableReader is used
- require.True(t, tk.HasPlan(queryHashPartition, "TableReader"))
- require.True(t, tk.HasPlan(queryListPartition, "TableReader"))
- require.True(t, tk.HasPlan(queryRangePartition, "Limit")) // check if order property is pushed
- require.True(t, tk.HasPlan(queryHashPartition, "Limit"))
- require.True(t, tk.HasPlan(queryListPartition, "Limit"))
- require.False(t, tk.HasPlan(queryRangePartition, "TopN")) // could fully pushed for TableScan executor
- require.False(t, tk.HasPlan(queryHashPartition, "TopN"))
- require.False(t, tk.HasPlan(queryListPartition, "TopN"))
+ tk.MustHavePlan(queryRangePartition, "TableReader") // check if tableReader is used
+ tk.MustHavePlan(queryHashPartition, "TableReader")
+ tk.MustHavePlan(queryListPartition, "TableReader")
+ tk.MustHavePlan(queryRangePartition, "Limit") // check if order property is pushed
+ tk.MustHavePlan(queryHashPartition, "Limit")
+ tk.MustHavePlan(queryListPartition, "Limit")
+ tk.MustNotHavePlan(queryRangePartition, "TopN") // could fully pushed for TableScan executor
+ tk.MustNotHavePlan(queryHashPartition, "TopN")
+ tk.MustNotHavePlan(queryListPartition, "TopN")
regularResult = tk.MustQuery(queryRegular).Rows()
tk.MustQuery(queryRangePartition).Check(regularResult)
tk.MustQuery(queryHashPartition).Check(regularResult)
@@ -701,12 +701,12 @@ func TestOrderByAndLimit(t *testing.T) {
// check if tiflash is used
require.True(t, tk.HasTiFlashPlan(queryPartitionWithTiFlash), fmt.Sprintf("%v", tk.MustQuery("explain "+queryPartitionWithTiFlash).Rows()))
// but order is not pushed
- require.False(t, tk.HasPlan(queryPartitionWithTiFlash, "Limit"), fmt.Sprintf("%v", tk.MustQuery("explain "+queryPartitionWithTiFlash).Rows()))
+ tk.MustNotHavePlan(queryPartitionWithTiFlash, "Limit")
queryPartitionWithTiFlash = fmt.Sprintf("select /*+ read_from_storage(tiflash[trange_intpk]) */ /*+ LIMIT_TO_COP() */ * from trange_intpk where a > %v order by a limit %v", x, y)
// check if tiflash is used
require.True(t, tk.HasTiFlashPlan(queryPartitionWithTiFlash), fmt.Sprintf("%v", tk.MustQuery("explain "+queryPartitionWithTiFlash).Rows()))
// but order is not pushed
- require.False(t, tk.HasPlan(queryPartitionWithTiFlash, "Limit"), fmt.Sprintf("%v", tk.MustQuery("explain "+queryPartitionWithTiFlash).Rows()))
+ tk.MustNotHavePlan(queryPartitionWithTiFlash, "Limit")
queryPartitionWithTiFlash = fmt.Sprintf("select /*+ read_from_storage(tiflash[trange_clustered]) */ * from trange_clustered where a > %v order by a limit %v", x, y)
// check if tiflash is used
require.True(t, tk.HasTiFlashPlan(queryPartitionWithTiFlash), fmt.Sprintf("%v", tk.MustQuery("explain "+queryPartitionWithTiFlash).Rows()))
@@ -714,7 +714,7 @@ func TestOrderByAndLimit(t *testing.T) {
// check if tiflash is used
require.True(t, tk.HasTiFlashPlan(queryPartitionWithTiFlash))
// but order is not pushed
- require.False(t, tk.HasPlan(queryPartitionWithTiFlash, "Limit"), fmt.Sprintf("%v", tk.MustQuery("explain "+queryPartitionWithTiFlash).Rows()))
+ tk.MustNotHavePlan(queryPartitionWithTiFlash, "Limit")
queryPartitionWithTiFlash = fmt.Sprintf("select /*+ read_from_storage(tiflash[thash_intpk]) */ * from thash_intpk where a > %v order by a limit %v", x, y)
// check if tiflash is used
require.True(t, tk.HasTiFlashPlan(queryPartitionWithTiFlash), fmt.Sprintf("%v", tk.MustQuery("explain "+queryPartitionWithTiFlash).Rows()))
@@ -722,7 +722,7 @@ func TestOrderByAndLimit(t *testing.T) {
// check if tiflash is used
require.True(t, tk.HasTiFlashPlan(queryPartitionWithTiFlash))
// but order is not pushed
- require.False(t, tk.HasPlan(queryPartitionWithTiFlash, "Limit"), fmt.Sprintf("%v", tk.MustQuery("explain "+queryPartitionWithTiFlash).Rows()))
+ tk.MustNotHavePlan(queryPartitionWithTiFlash, "Limit")
queryPartitionWithTiFlash = fmt.Sprintf("select /*+ read_from_storage(tiflash[thash_clustered]) */ * from thash_clustered where a > %v order by a limit %v", x, y)
// check if tiflash is used
require.True(t, tk.HasTiFlashPlan(queryPartitionWithTiFlash), fmt.Sprintf("%v", tk.MustQuery("explain "+queryPartitionWithTiFlash).Rows()))
@@ -730,7 +730,7 @@ func TestOrderByAndLimit(t *testing.T) {
// check if tiflash is used
require.True(t, tk.HasTiFlashPlan(queryPartitionWithTiFlash))
// but order is not pushed
- require.False(t, tk.HasPlan(queryPartitionWithTiFlash, "Limit"), fmt.Sprintf("%v", tk.MustQuery("explain "+queryPartitionWithTiFlash).Rows()))
+ tk.MustNotHavePlan(queryPartitionWithTiFlash, "Limit")
queryPartitionWithTiFlash = fmt.Sprintf("select /*+ read_from_storage(tiflash[tlist_intpk]) */ * from tlist_intpk where a > %v order by a limit %v", x, y)
// check if tiflash is used
require.True(t, tk.HasTiFlashPlan(queryPartitionWithTiFlash), fmt.Sprintf("%v", tk.MustQuery("explain "+queryPartitionWithTiFlash).Rows()))
@@ -738,7 +738,7 @@ func TestOrderByAndLimit(t *testing.T) {
// check if tiflash is used
require.True(t, tk.HasTiFlashPlan(queryPartitionWithTiFlash))
// but order is not pushed
- require.False(t, tk.HasPlan(queryPartitionWithTiFlash, "Limit"), fmt.Sprintf("%v", tk.MustQuery("explain "+queryPartitionWithTiFlash).Rows()))
+ tk.MustNotHavePlan(queryPartitionWithTiFlash, "Limit")
queryPartitionWithTiFlash = fmt.Sprintf("select /*+ read_from_storage(tiflash[tlist_clustered]) */ * from tlist_clustered where a > %v order by a limit %v", x, y)
// check if tiflash is used
require.True(t, tk.HasTiFlashPlan(queryPartitionWithTiFlash), fmt.Sprintf("%v", tk.MustQuery("explain "+queryPartitionWithTiFlash).Rows()))
@@ -746,7 +746,7 @@ func TestOrderByAndLimit(t *testing.T) {
// check if tiflash is used
require.True(t, tk.HasTiFlashPlan(queryPartitionWithTiFlash))
// but order is not pushed
- require.False(t, tk.HasPlan(queryPartitionWithTiFlash, "Limit"), fmt.Sprintf("%v", tk.MustQuery("explain "+queryPartitionWithTiFlash).Rows()))
+ tk.MustNotHavePlan(queryPartitionWithTiFlash, "Limit")
tk.MustExec(" set @@tidb_allow_mpp=0;")
tk.MustExec("set @@session.tidb_isolation_read_engines=\"tikv\"")
}
@@ -759,7 +759,7 @@ func TestOrderByAndLimit(t *testing.T) {
y := rand.Intn(500) + 1
queryPartition := fmt.Sprintf("select a from trange use index(idx_a) where a > %v order by a limit %v;", x, y)
queryRegular := fmt.Sprintf("select a from tregular use index(idx_a) where a > %v order by a limit %v;", x, y)
- require.True(t, tk.HasPlan(queryPartition, "IndexReader")) // check if indexReader is used
+ tk.MustHavePlan(queryPartition, "IndexReader") // check if indexReader is used
tk.MustQuery(queryPartition).Check(tk.MustQuery(queryRegular).Rows())
}
@@ -772,12 +772,12 @@ func TestOrderByAndLimit(t *testing.T) {
queryRangePartition := fmt.Sprintf("select /*+ LIMIT_TO_COP() */ a from trange use index(idx_a) where a > %v order by a limit %v;", x, y)
queryHashPartition := fmt.Sprintf("select /*+ LIMIT_TO_COP() */ a from thash use index(idx_a) where a > %v order by a limit %v;", x, y)
queryRegular := fmt.Sprintf("select a from tregular use index(idx_a) where a > %v order by a limit %v;", x, y)
- require.True(t, tk.HasPlan(queryRangePartition, "IndexReader")) // check if indexReader is used
- require.True(t, tk.HasPlan(queryHashPartition, "IndexReader"))
- require.True(t, tk.HasPlan(queryRangePartition, "Limit")) // check if order property is pushed
- require.True(t, tk.HasPlan(queryHashPartition, "Limit"))
- require.False(t, tk.HasPlan(queryRangePartition, "TopN")) // fully pushed limit
- require.False(t, tk.HasPlan(queryHashPartition, "TopN"))
+ tk.MustHavePlan(queryRangePartition, "IndexReader") // check if indexReader is used
+ tk.MustHavePlan(queryHashPartition, "IndexReader")
+ tk.MustHavePlan(queryRangePartition, "Limit") // check if order property is pushed
+ tk.MustHavePlan(queryHashPartition, "Limit")
+ tk.MustNotHavePlan(queryRangePartition, "TopN") // fully pushed limit
+ tk.MustNotHavePlan(queryHashPartition, "TopN")
regularResult := tk.MustQuery(queryRegular).Rows()
tk.MustQuery(queryRangePartition).Check(regularResult)
tk.MustQuery(queryHashPartition).Check(regularResult)
@@ -791,15 +791,15 @@ func TestOrderByAndLimit(t *testing.T) {
queryHashPartition := fmt.Sprintf("select /*+ LIMIT_TO_COP() */ a from thash use index(idx_ab) where a = %v order by b limit %v;", x, y)
queryListPartition := fmt.Sprintf("select /*+ LIMIT_TO_COP() */ a from tlist use index(idx_ab) where a = %v order by b limit %v;", x, y)
queryRegular := fmt.Sprintf("select a from tregular use index(idx_ab) where a = %v order by b limit %v;", x, y)
- require.True(t, tk.HasPlan(queryRangePartition, "IndexReader")) // check if indexReader is used
- require.True(t, tk.HasPlan(queryHashPartition, "IndexReader"))
- require.True(t, tk.HasPlan(queryListPartition, "IndexReader"))
- require.True(t, tk.HasPlan(queryRangePartition, "Limit")) // check if order property is pushed
- require.True(t, tk.HasPlan(queryHashPartition, "Limit"))
- require.True(t, tk.HasPlan(queryListPartition, "Limit"))
- require.False(t, tk.HasPlan(queryRangePartition, "TopN")) // fully pushed limit
- require.False(t, tk.HasPlan(queryHashPartition, "TopN"))
- require.False(t, tk.HasPlan(queryListPartition, "TopN"))
+ tk.MustHavePlan(queryRangePartition, "IndexReader") // check if indexReader is used
+ tk.MustHavePlan(queryHashPartition, "IndexReader")
+ tk.MustHavePlan(queryListPartition, "IndexReader")
+ tk.MustHavePlan(queryRangePartition, "Limit") // check if order property is pushed
+ tk.MustHavePlan(queryHashPartition, "Limit")
+ tk.MustHavePlan(queryListPartition, "Limit")
+ tk.MustNotHavePlan(queryRangePartition, "TopN") // fully pushed limit
+ tk.MustNotHavePlan(queryHashPartition, "TopN")
+ tk.MustNotHavePlan(queryListPartition, "TopN")
regularResult := tk.MustQuery(queryRegular).Rows()
tk.MustQuery(queryRangePartition).Check(regularResult)
tk.MustQuery(queryHashPartition).Check(regularResult)
@@ -813,7 +813,7 @@ func TestOrderByAndLimit(t *testing.T) {
y := rand.Intn(500) + 1
queryHashPartition := fmt.Sprintf("select /*+ use_index_merge(thash) */ * from thash where a > 2 or b < 5 order by a, b limit %v;", y)
queryRegular := fmt.Sprintf("select * from tregular where a > 2 or b < 5 order by a, b limit %v;", y)
- require.True(t, tk.HasPlan(queryHashPartition, "IndexMerge")) // check if indexMerge is used
+ tk.MustHavePlan(queryHashPartition, "IndexMerge") // check if indexMerge is used
tk.MustQuery(queryHashPartition).Check(tk.MustQuery(queryRegular).Rows())
}
@@ -971,13 +971,13 @@ func TestBatchGetandPointGetwithHashPartition(t *testing.T) {
x := rand.Intn(100) + 1
queryHash := fmt.Sprintf("select a from thash where a=%v", x)
queryRegular := fmt.Sprintf("select a from tregular where a=%v", x)
- require.True(t, tk.HasPlan(queryHash, "Point_Get")) // check if PointGet is used
+ tk.MustHavePlan(queryHash, "Point_Get") // check if PointGet is used
tk.MustQuery(queryHash).Check(tk.MustQuery(queryRegular).Rows())
}
// test empty PointGet
queryHash := "select a from thash where a=200"
- require.True(t, tk.HasPlan(queryHash, "Point_Get")) // check if PointGet is used
+ tk.MustHavePlan(queryHash, "Point_Get") // check if PointGet is used
tk.MustQuery(queryHash).Check(testkit.Rows())
// test BatchGet
@@ -992,7 +992,7 @@ func TestBatchGetandPointGetwithHashPartition(t *testing.T) {
queryHash := fmt.Sprintf("select a from thash where a in (%v)", strings.Join(points, ","))
queryRegular := fmt.Sprintf("select a from tregular where a in (%v)", strings.Join(points, ","))
- require.True(t, tk.HasPlan(queryHash, "Point_Get")) // check if PointGet is used
+ tk.MustHavePlan(queryHash, "Point_Get") // check if PointGet is used
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
}
}
@@ -1361,11 +1361,11 @@ func TestBatchGetforRangeandListPartitionTable(t *testing.T) {
queryRegular1 := fmt.Sprintf("select a from tregular1 where a in (%v)", strings.Join(points, ","))
queryHash := fmt.Sprintf("select a from thash where a in (%v)", strings.Join(points, ","))
- require.True(t, tk.HasPlan(queryHash, "Batch_Point_Get")) // check if BatchGet is used
+ tk.MustHavePlan(queryHash, "Batch_Point_Get") // check if BatchGet is used
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular1).Sort().Rows())
queryRange := fmt.Sprintf("select a from trange where a in (%v)", strings.Join(points, ","))
- require.True(t, tk.HasPlan(queryRange, "Batch_Point_Get")) // check if BatchGet is used
+ tk.MustHavePlan(queryRange, "Batch_Point_Get") // check if BatchGet is used
tk.MustQuery(queryRange).Sort().Check(tk.MustQuery(queryRegular1).Sort().Rows())
points = make([]string, 0, 10)
@@ -1375,7 +1375,7 @@ func TestBatchGetforRangeandListPartitionTable(t *testing.T) {
}
queryRegular2 := fmt.Sprintf("select a from tregular2 where a in (%v)", strings.Join(points, ","))
queryList := fmt.Sprintf("select a from tlist where a in (%v)", strings.Join(points, ","))
- require.True(t, tk.HasPlan(queryList, "Batch_Point_Get")) // check if BatchGet is used
+ tk.MustHavePlan(queryList, "Batch_Point_Get") // check if BatchGet is used
tk.MustQuery(queryList).Sort().Check(tk.MustQuery(queryRegular2).Sort().Rows())
}
@@ -1405,7 +1405,7 @@ func TestBatchGetforRangeandListPartitionTable(t *testing.T) {
}
queryRegular := fmt.Sprintf("select a from tregular3 where a in (%v)", strings.Join(points, ","))
queryRange := fmt.Sprintf("select a from trange3 where a in (%v)", strings.Join(points, ","))
- require.True(t, tk.HasPlan(queryRange, "Batch_Point_Get")) // check if BatchGet is used
+ tk.MustHavePlan(queryRange, "Batch_Point_Get") // check if BatchGet is used
tk.MustQuery(queryRange).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
}
@@ -1453,17 +1453,17 @@ func TestGlobalStatsAndSQLBinding(t *testing.T) {
tk.MustExec("insert into tlist values " + strings.Join(listVals, ","))
// before analyzing, the planner will choose TableScan to access the 1% of records
- require.True(t, tk.HasPlan("select * from thash where a<100", "TableFullScan"))
- require.True(t, tk.HasPlan("select * from trange where a<100", "TableFullScan"))
- require.True(t, tk.HasPlan("select * from tlist where a<1", "TableFullScan"))
+ tk.MustHavePlan("select * from thash where a<100", "TableFullScan")
+ tk.MustHavePlan("select * from trange where a<100", "TableFullScan")
+ tk.MustHavePlan("select * from tlist where a<1", "TableFullScan")
tk.MustExec("analyze table thash")
tk.MustExec("analyze table trange")
tk.MustExec("analyze table tlist")
- require.True(t, tk.HasPlan("select * from thash where a<100", "TableFullScan"))
- require.True(t, tk.HasPlan("select * from trange where a<100", "TableFullScan"))
- require.True(t, tk.HasPlan("select * from tlist where a<1", "TableFullScan"))
+ tk.MustHavePlan("select * from thash where a<100", "TableFullScan")
+ tk.MustHavePlan("select * from trange where a<100", "TableFullScan")
+ tk.MustHavePlan("select * from tlist where a<1", "TableFullScan")
// create SQL bindings
tk.MustExec("create session binding for select * from thash where a<100 using select * from thash ignore index(a) where a<100")
@@ -1471,18 +1471,18 @@ func TestGlobalStatsAndSQLBinding(t *testing.T) {
tk.MustExec("create session binding for select * from tlist where a<100 using select * from tlist ignore index(a) where a<100")
// use TableScan again since the Index(a) is ignored
- require.True(t, tk.HasPlan("select * from thash where a<100", "TableFullScan"))
- require.True(t, tk.HasPlan("select * from trange where a<100", "TableFullScan"))
- require.True(t, tk.HasPlan("select * from tlist where a<1", "TableFullScan"))
+ tk.MustHavePlan("select * from thash where a<100", "TableFullScan")
+ tk.MustHavePlan("select * from trange where a<100", "TableFullScan")
+ tk.MustHavePlan("select * from tlist where a<1", "TableFullScan")
// drop SQL bindings
tk.MustExec("drop session binding for select * from thash where a<100")
tk.MustExec("drop session binding for select * from trange where a<100")
tk.MustExec("drop session binding for select * from tlist where a<100")
- require.True(t, tk.HasPlan("select * from thash where a<100", "TableFullScan"))
- require.True(t, tk.HasPlan("select * from trange where a<100", "TableFullScan"))
- require.True(t, tk.HasPlan("select * from tlist where a<1", "TableFullScan"))
+ tk.MustHavePlan("select * from thash where a<100", "TableFullScan")
+ tk.MustHavePlan("select * from trange where a<100", "TableFullScan")
+ tk.MustHavePlan("select * from tlist where a<1", "TableFullScan")
}
func TestPartitionTableWithDifferentJoin(t *testing.T) {
@@ -1532,78 +1532,78 @@ func TestPartitionTableWithDifferentJoin(t *testing.T) {
// hash_join range partition and hash partition
queryHash := fmt.Sprintf("select /*+ hash_join(trange, thash) */ * from trange, thash where trange.b=thash.b and thash.a = %v and trange.a > %v;", x1, x2)
queryRegular := fmt.Sprintf("select /*+ hash_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.b=tregular1.b and tregular1.a = %v and tregular2.a > %v;", x1, x2)
- require.True(t, tk.HasPlan(queryHash, "HashJoin"))
+ tk.MustHavePlan(queryHash, "HashJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ hash_join(trange, thash) */ * from trange, thash where trange.a=thash.a and thash.a > %v;", x1)
queryRegular = fmt.Sprintf("select /*+ hash_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a=tregular1.a and tregular1.a > %v;", x1)
- require.True(t, tk.HasPlan(queryHash, "HashJoin"))
+ tk.MustHavePlan(queryHash, "HashJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ hash_join(trange, thash) */ * from trange, thash where trange.a=thash.a and trange.b = thash.b and thash.a > %v;", x1)
queryRegular = fmt.Sprintf("select /*+ hash_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a=tregular1.a and tregular1.b = tregular2.b and tregular1.a > %v;", x1)
- require.True(t, tk.HasPlan(queryHash, "HashJoin"))
+ tk.MustHavePlan(queryHash, "HashJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ hash_join(trange, thash) */ * from trange, thash where trange.a=thash.a and thash.a = %v;", x1)
queryRegular = fmt.Sprintf("select /*+ hash_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a=tregular1.a and tregular1.a = %v;", x1)
- require.True(t, tk.HasPlan(queryHash, "HashJoin"))
+ tk.MustHavePlan(queryHash, "HashJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
// group 2
// hash_join range partition and regular table
queryHash = fmt.Sprintf("select /*+ hash_join(trange, tregular1) */ * from trange, tregular1 where trange.a = tregular1.a and trange.a >= %v and tregular1.a > %v;", x1, x2)
queryRegular = fmt.Sprintf("select /*+ hash_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a = tregular1.a and tregular2.a >= %v and tregular1.a > %v;", x1, x2)
- require.True(t, tk.HasPlan(queryHash, "HashJoin"))
+ tk.MustHavePlan(queryHash, "HashJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ hash_join(trange, tregular1) */ * from trange, tregular1 where trange.a = tregular1.a and trange.a in (%v, %v, %v);", x1, x2, x3)
queryRegular = fmt.Sprintf("select /*+ hash_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a = tregular1.a and tregular2.a in (%v, %v, %v);", x1, x2, x3)
- require.True(t, tk.HasPlan(queryHash, "HashJoin"))
+ tk.MustHavePlan(queryHash, "HashJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ hash_join(trange, tregular1) */ * from trange, tregular1 where trange.a = tregular1.a and tregular1.a >= %v;", x1)
queryRegular = fmt.Sprintf("select /*+ hash_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a = tregular1.a and tregular1.a >= %v;", x1)
- require.True(t, tk.HasPlan(queryHash, "HashJoin"))
+ tk.MustHavePlan(queryHash, "HashJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
// group 3
// merge_join range partition and hash partition
queryHash = fmt.Sprintf("select /*+ merge_join(trange, thash) */ * from trange, thash where trange.b=thash.b and thash.a = %v and trange.a > %v;", x1, x2)
queryRegular = fmt.Sprintf("select /*+ merge_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.b=tregular1.b and tregular1.a = %v and tregular2.a > %v;", x1, x2)
- require.True(t, tk.HasPlan(queryHash, "MergeJoin"))
+ tk.MustHavePlan(queryHash, "MergeJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ merge_join(trange, thash) */ * from trange, thash where trange.a=thash.a and thash.a > %v;", x1)
queryRegular = fmt.Sprintf("select /*+ merge_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a=tregular1.a and tregular1.a > %v;", x1)
- require.True(t, tk.HasPlan(queryHash, "MergeJoin"))
+ tk.MustHavePlan(queryHash, "MergeJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ merge_join(trange, thash) */ * from trange, thash where trange.a=thash.a and trange.b = thash.b and thash.a > %v;", x1)
queryRegular = fmt.Sprintf("select /*+ merge_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a=tregular1.a and tregular1.b = tregular2.b and tregular1.a > %v;", x1)
- require.True(t, tk.HasPlan(queryHash, "MergeJoin"))
+ tk.MustHavePlan(queryHash, "MergeJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ merge_join(trange, thash) */ * from trange, thash where trange.a=thash.a and thash.a = %v;", x1)
queryRegular = fmt.Sprintf("select /*+ merge_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a=tregular1.a and tregular1.a = %v;", x1)
- require.True(t, tk.HasPlan(queryHash, "MergeJoin"))
+ tk.MustHavePlan(queryHash, "MergeJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
// group 4
// merge_join range partition and regular table
queryHash = fmt.Sprintf("select /*+ merge_join(trange, tregular1) */ * from trange, tregular1 where trange.a = tregular1.a and trange.a >= %v and tregular1.a > %v;", x1, x2)
queryRegular = fmt.Sprintf("select /*+ merge_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a = tregular1.a and tregular2.a >= %v and tregular1.a > %v;", x1, x2)
- require.True(t, tk.HasPlan(queryHash, "MergeJoin"))
+ tk.MustHavePlan(queryHash, "MergeJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ merge_join(trange, tregular1) */ * from trange, tregular1 where trange.a = tregular1.a and trange.a in (%v, %v, %v);", x1, x2, x3)
queryRegular = fmt.Sprintf("select /*+ merge_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a = tregular1.a and tregular2.a in (%v, %v, %v);", x1, x2, x3)
- require.True(t, tk.HasPlan(queryHash, "MergeJoin"))
+ tk.MustHavePlan(queryHash, "MergeJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ merge_join(trange, tregular1) */ * from trange, tregular1 where trange.a = tregular1.a and tregular1.a >= %v;", x1)
queryRegular = fmt.Sprintf("select /*+ merge_join(tregular2, tregular1) */ * from tregular2, tregular1 where tregular2.a = tregular1.a and tregular1.a >= %v;", x1)
- require.True(t, tk.HasPlan(queryHash, "MergeJoin"))
+ tk.MustHavePlan(queryHash, "MergeJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
// new table instances
@@ -1637,28 +1637,28 @@ func TestPartitionTableWithDifferentJoin(t *testing.T) {
// Currently, index merge join on two partition tables is not supported. Only test the warning.
queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, trange2) */ * from trange, trange2 where trange.a=trange2.a and trange.a > %v;", x1)
// queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v;", x1)
- // require.True(t,tk.HasPlan(queryHash, "IndexMergeJoin"))
+ // tk.MustHavePlan(queryHash, "IndexMergeJoin")
// tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
tk.MustQuery(queryHash)
tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1815|Optimizer Hint /*+ INL_MERGE_JOIN(trange, trange2) */ is inapplicable"))
queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, trange2) */ * from trange, trange2 where trange.a=trange2.a and trange.a > %v and trange2.a > %v;", x1, x2)
// queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v and tregular4.a > %v;", x1, x2)
- // require.True(t,tk.HasPlan(queryHash, "IndexMergeJoin"))
+ // tk.MustHavePlan(queryHash, "IndexMergeJoin")
// tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
tk.MustQuery(queryHash)
tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1815|Optimizer Hint /*+ INL_MERGE_JOIN(trange, trange2) */ is inapplicable"))
queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, trange2) */ * from trange, trange2 where trange.a=trange2.a and trange.a > %v and trange.b > %v;", x1, x2)
// queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v and tregular2.b > %v;", x1, x2)
- // require.True(t,tk.HasPlan(queryHash, "IndexMergeJoin"))
+ // tk.MustHavePlan(queryHash, "IndexMergeJoin")
// tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
tk.MustQuery(queryHash)
tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1815|Optimizer Hint /*+ INL_MERGE_JOIN(trange, trange2) */ is inapplicable"))
queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, trange2) */ * from trange, trange2 where trange.a=trange2.a and trange.a > %v and trange2.b > %v;", x1, x2)
// queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v and tregular4.b > %v;", x1, x2)
- // require.True(t,tk.HasPlan(queryHash, "IndexMergeJoin"))
+ // tk.MustHavePlan(queryHash, "IndexMergeJoin")
// tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
tk.MustQuery(queryHash)
tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1815|Optimizer Hint /*+ INL_MERGE_JOIN(trange, trange2) */ is inapplicable"))
@@ -1667,56 +1667,56 @@ func TestPartitionTableWithDifferentJoin(t *testing.T) {
// index_merge_join range partition and regular table
queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, tregular4) */ * from trange, tregular4 where trange.a=tregular4.a and trange.a > %v;", x1)
queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v;", x1)
- require.True(t, tk.HasPlan(queryHash, "IndexMergeJoin"))
+ tk.MustHavePlan(queryHash, "IndexMergeJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, tregular4) */ * from trange, tregular4 where trange.a=tregular4.a and trange.a > %v and tregular4.a > %v;", x1, x2)
queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v and tregular4.a > %v;", x1, x2)
- require.True(t, tk.HasPlan(queryHash, "IndexMergeJoin"))
+ tk.MustHavePlan(queryHash, "IndexMergeJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, tregular4) */ * from trange, tregular4 where trange.a=tregular4.a and trange.a > %v and trange.b > %v;", x1, x2)
queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v and tregular2.b > %v;", x1, x2)
- require.True(t, tk.HasPlan(queryHash, "IndexMergeJoin"))
+ tk.MustHavePlan(queryHash, "IndexMergeJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ inl_merge_join(trange, tregular4) */ * from trange, tregular4 where trange.a=tregular4.a and trange.a > %v and tregular4.b > %v;", x1, x2)
queryRegular = fmt.Sprintf("select /*+ inl_merge_join(tregular2, tregular4) */ * from tregular2, tregular4 where tregular2.a=tregular4.a and tregular2.a > %v and tregular4.b > %v;", x1, x2)
- require.True(t, tk.HasPlan(queryHash, "IndexMergeJoin"))
+ tk.MustHavePlan(queryHash, "IndexMergeJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
// group 7
// index_hash_join hash partition and hash partition
queryHash = fmt.Sprintf("select /*+ inl_hash_join(thash, thash2) */ * from thash, thash2 where thash.a = thash2.a and thash.a in (%v, %v);", x1, x2)
queryRegular = fmt.Sprintf("select /*+ inl_hash_join(tregular1, tregular3) */ * from tregular1, tregular3 where tregular1.a = tregular3.a and tregular1.a in (%v, %v);", x1, x2)
- require.True(t, tk.HasPlan(queryHash, "IndexHashJoin"))
+ tk.MustHavePlan(queryHash, "IndexHashJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ inl_hash_join(thash, thash2) */ * from thash, thash2 where thash.a = thash2.a and thash.a in (%v, %v) and thash2.a in (%v, %v);", x1, x2, x3, x4)
queryRegular = fmt.Sprintf("select /*+ inl_hash_join(tregular1, tregular3) */ * from tregular1, tregular3 where tregular1.a = tregular3.a and tregular1.a in (%v, %v) and tregular3.a in (%v, %v);", x1, x2, x3, x4)
- require.True(t, tk.HasPlan(queryHash, "IndexHashJoin"))
+ tk.MustHavePlan(queryHash, "IndexHashJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ inl_hash_join(thash, thash2) */ * from thash, thash2 where thash.a = thash2.a and thash.a > %v and thash2.b > %v;", x1, x2)
queryRegular = fmt.Sprintf("select /*+ inl_hash_join(tregular1, tregular3) */ * from tregular1, tregular3 where tregular1.a = tregular3.a and tregular1.a > %v and tregular3.b > %v;", x1, x2)
- require.True(t, tk.HasPlan(queryHash, "IndexHashJoin"))
+ tk.MustHavePlan(queryHash, "IndexHashJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
// group 8
// index_hash_join hash partition and regular table
queryHash = fmt.Sprintf("select /*+ inl_hash_join(thash, tregular3) */ * from thash, tregular3 where thash.a = tregular3.a and thash.a in (%v, %v);", x1, x2)
queryRegular = fmt.Sprintf("select /*+ inl_hash_join(tregular1, tregular3) */ * from tregular1, tregular3 where tregular1.a = tregular3.a and tregular1.a in (%v, %v);", x1, x2)
- require.True(t, tk.HasPlan(queryHash, "IndexHashJoin"))
+ tk.MustHavePlan(queryHash, "IndexHashJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ inl_hash_join(thash, tregular3) */ * from thash, tregular3 where thash.a = tregular3.a and thash.a in (%v, %v) and tregular3.a in (%v, %v);", x1, x2, x3, x4)
queryRegular = fmt.Sprintf("select /*+ inl_hash_join(tregular1, tregular3) */ * from tregular1, tregular3 where tregular1.a = tregular3.a and tregular1.a in (%v, %v) and tregular3.a in (%v, %v);", x1, x2, x3, x4)
- require.True(t, tk.HasPlan(queryHash, "IndexHashJoin"))
+ tk.MustHavePlan(queryHash, "IndexHashJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
queryHash = fmt.Sprintf("select /*+ inl_hash_join(thash, tregular3) */ * from thash, tregular3 where thash.a = tregular3.a and thash.a > %v and tregular3.b > %v;", x1, x2)
queryRegular = fmt.Sprintf("select /*+ inl_hash_join(tregular1, tregular3) */ * from tregular1, tregular3 where tregular1.a = tregular3.a and tregular1.a > %v and tregular3.b > %v;", x1, x2)
- require.True(t, tk.HasPlan(queryHash, "IndexHashJoin"))
+ tk.MustHavePlan(queryHash, "IndexHashJoin")
tk.MustQuery(queryHash).Sort().Check(tk.MustQuery(queryRegular).Sort().Rows())
}
@@ -2819,7 +2819,7 @@ func TestDirectReadingWithUnionScan(t *testing.T) {
var result [][]interface{}
for _, tb := range []string{`trange`, `tnormal`, `thash`} {
q := fmt.Sprintf(sql, tb)
- tk.HasPlan(q, `UnionScan`)
+ tk.MustHavePlan(q, `UnionScan`)
if result == nil {
result = tk.MustQuery(q).Sort().Rows()
} else {
@@ -2907,7 +2907,7 @@ func TestUnsignedPartitionColumn(t *testing.T) {
for tid, tbl := range []string{"tnormal_pk", "trange_pk", "thash_pk"} {
// unsigned + TableReader
scanSQL := fmt.Sprintf("select * from %v use index(primary) where %v", tbl, scanCond)
- require.True(t, tk.HasPlan(scanSQL, "TableReader"))
+ tk.MustHavePlan(scanSQL, "TableReader")
r := tk.MustQuery(scanSQL).Sort()
if tid == 0 {
rScan = r.Rows()
@@ -2927,7 +2927,7 @@ func TestUnsignedPartitionColumn(t *testing.T) {
// unsigned + BatchGet on PK
batchSQL := fmt.Sprintf("select * from %v where %v", tbl, batchCond)
- require.True(t, tk.HasPlan(batchSQL, "Batch_Point_Get"))
+ tk.MustHavePlan(batchSQL, "Batch_Point_Get")
r = tk.MustQuery(batchSQL).Sort()
if tid == 0 {
rBatch = r.Rows()
@@ -2941,7 +2941,7 @@ func TestUnsignedPartitionColumn(t *testing.T) {
for tid, tbl := range []string{"tnormal_uniq", "trange_uniq", "thash_uniq"} {
// unsigned + IndexReader
scanSQL := fmt.Sprintf("select a from %v use index(a) where %v", tbl, scanCond)
- require.True(t, tk.HasPlan(scanSQL, "IndexReader"))
+ tk.MustHavePlan(scanSQL, "IndexReader")
r := tk.MustQuery(scanSQL).Sort()
if tid == 0 {
rScan = r.Rows()
@@ -2971,7 +2971,7 @@ func TestUnsignedPartitionColumn(t *testing.T) {
// unsigned + BatchGet on UniqueIndex
batchSQL := fmt.Sprintf("select * from %v where %v", tbl, batchCond)
- require.True(t, tk.HasPlan(batchSQL, "Batch_Point_Get"))
+ tk.MustHavePlan(batchSQL, "Batch_Point_Get")
r = tk.MustQuery(batchSQL).Sort()
if tid == 0 {
rBatch = r.Rows()
@@ -3038,12 +3038,12 @@ func TestDirectReadingWithAgg(t *testing.T) {
queryPartition1 := fmt.Sprintf("select /*+ stream_agg() */ count(*), sum(b), max(b), a from trange where a > %v group by a;", x)
queryRegular1 := fmt.Sprintf("select /*+ stream_agg() */ count(*), sum(b), max(b), a from tregular1 where a > %v group by a;", x)
- require.True(t, tk.HasPlan(queryPartition1, "StreamAgg")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition1, "StreamAgg") // check if StreamAgg is used
tk.MustQuery(queryPartition1).Sort().Check(tk.MustQuery(queryRegular1).Sort().Rows())
queryPartition2 := fmt.Sprintf("select /*+ hash_agg() */ count(*), sum(b), max(b), a from trange where a > %v group by a;", x)
queryRegular2 := fmt.Sprintf("select /*+ hash_agg() */ count(*), sum(b), max(b), a from tregular1 where a > %v group by a;", x)
- require.True(t, tk.HasPlan(queryPartition2, "HashAgg")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition2, "HashAgg") // check if HashAgg is used
tk.MustQuery(queryPartition2).Sort().Check(tk.MustQuery(queryRegular2).Sort().Rows())
y := rand.Intn(1099)
@@ -3051,12 +3051,12 @@ func TestDirectReadingWithAgg(t *testing.T) {
queryPartition3 := fmt.Sprintf("select /*+ stream_agg() */ count(*), sum(b), max(b), a from trange where a in(%v, %v, %v) group by a;", x, y, z)
queryRegular3 := fmt.Sprintf("select /*+ stream_agg() */ count(*), sum(b), max(b), a from tregular1 where a in(%v, %v, %v) group by a;", x, y, z)
- require.True(t, tk.HasPlan(queryPartition3, "StreamAgg")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition3, "StreamAgg") // check if StreamAgg is used
tk.MustQuery(queryPartition3).Sort().Check(tk.MustQuery(queryRegular3).Sort().Rows())
queryPartition4 := fmt.Sprintf("select /*+ hash_agg() */ count(*), sum(b), max(b), a from trange where a in (%v, %v, %v) group by a;", x, y, z)
queryRegular4 := fmt.Sprintf("select /*+ hash_agg() */ count(*), sum(b), max(b), a from tregular1 where a in (%v, %v, %v) group by a;", x, y, z)
- require.True(t, tk.HasPlan(queryPartition4, "HashAgg")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition4, "HashAgg") // check if HashAgg is used
tk.MustQuery(queryPartition4).Sort().Check(tk.MustQuery(queryRegular4).Sort().Rows())
}
@@ -3070,12 +3070,12 @@ func TestDirectReadingWithAgg(t *testing.T) {
queryPartition1 := fmt.Sprintf("select /*+ stream_agg() */ count(*), sum(b), max(b), a from thash where a > %v group by a;", x)
queryRegular1 := fmt.Sprintf("select /*+ stream_agg() */ count(*), sum(b), max(b), a from tregular1 where a > %v group by a;", x)
- require.True(t, tk.HasPlan(queryPartition1, "StreamAgg")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition1, "StreamAgg") // check if StreamAgg is used
tk.MustQuery(queryPartition1).Sort().Check(tk.MustQuery(queryRegular1).Sort().Rows())
queryPartition2 := fmt.Sprintf("select /*+ hash_agg() */ count(*), sum(b), max(b), a from thash where a > %v group by a;", x)
queryRegular2 := fmt.Sprintf("select /*+ hash_agg() */ count(*), sum(b), max(b), a from tregular1 where a > %v group by a;", x)
- require.True(t, tk.HasPlan(queryPartition2, "HashAgg")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition2, "HashAgg") // check if HashAgg is used
tk.MustQuery(queryPartition2).Sort().Check(tk.MustQuery(queryRegular2).Sort().Rows())
y := rand.Intn(1099)
@@ -3083,12 +3083,12 @@ func TestDirectReadingWithAgg(t *testing.T) {
queryPartition3 := fmt.Sprintf("select /*+ stream_agg() */ count(*), sum(b), max(b), a from thash where a in(%v, %v, %v) group by a;", x, y, z)
queryRegular3 := fmt.Sprintf("select /*+ stream_agg() */ count(*), sum(b), max(b), a from tregular1 where a in(%v, %v, %v) group by a;", x, y, z)
- require.True(t, tk.HasPlan(queryPartition3, "StreamAgg")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition3, "StreamAgg") // check if StreamAgg is used
tk.MustQuery(queryPartition3).Sort().Check(tk.MustQuery(queryRegular3).Sort().Rows())
queryPartition4 := fmt.Sprintf("select /*+ hash_agg() */ count(*), sum(b), max(b), a from thash where a in (%v, %v, %v) group by a;", x, y, z)
queryRegular4 := fmt.Sprintf("select /*+ hash_agg() */ count(*), sum(b), max(b), a from tregular1 where a in (%v, %v, %v) group by a;", x, y, z)
- require.True(t, tk.HasPlan(queryPartition4, "HashAgg")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition4, "HashAgg") // check if HashAgg is used
tk.MustQuery(queryPartition4).Sort().Check(tk.MustQuery(queryRegular4).Sort().Rows())
}
@@ -3102,12 +3102,12 @@ func TestDirectReadingWithAgg(t *testing.T) {
queryPartition1 := fmt.Sprintf("select /*+ stream_agg() */ count(*), sum(b), max(b), a from tlist where a > %v group by a;", x)
queryRegular1 := fmt.Sprintf("select /*+ stream_agg() */ count(*), sum(b), max(b), a from tregular2 where a > %v group by a;", x)
- require.True(t, tk.HasPlan(queryPartition1, "StreamAgg")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition1, "StreamAgg") // check if StreamAgg is used
tk.MustQuery(queryPartition1).Sort().Check(tk.MustQuery(queryRegular1).Sort().Rows())
queryPartition2 := fmt.Sprintf("select /*+ hash_agg() */ count(*), sum(b), max(b), a from tlist where a > %v group by a;", x)
queryRegular2 := fmt.Sprintf("select /*+ hash_agg() */ count(*), sum(b), max(b), a from tregular2 where a > %v group by a;", x)
- require.True(t, tk.HasPlan(queryPartition2, "HashAgg")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition2, "HashAgg") // check if HashAgg is used
tk.MustQuery(queryPartition2).Sort().Check(tk.MustQuery(queryRegular2).Sort().Rows())
y := rand.Intn(12) + 1
@@ -3115,12 +3115,12 @@ func TestDirectReadingWithAgg(t *testing.T) {
queryPartition3 := fmt.Sprintf("select /*+ stream_agg() */ count(*), sum(b), max(b), a from tlist where a in(%v, %v, %v) group by a;", x, y, z)
queryRegular3 := fmt.Sprintf("select /*+ stream_agg() */ count(*), sum(b), max(b), a from tregular2 where a in(%v, %v, %v) group by a;", x, y, z)
- require.True(t, tk.HasPlan(queryPartition3, "StreamAgg")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition3, "StreamAgg") // check if StreamAgg is used
tk.MustQuery(queryPartition3).Sort().Check(tk.MustQuery(queryRegular3).Sort().Rows())
queryPartition4 := fmt.Sprintf("select /*+ hash_agg() */ count(*), sum(b), max(b), a from tlist where a in (%v, %v, %v) group by a;", x, y, z)
queryRegular4 := fmt.Sprintf("select /*+ hash_agg() */ count(*), sum(b), max(b), a from tregular2 where a in (%v, %v, %v) group by a;", x, y, z)
- require.True(t, tk.HasPlan(queryPartition4, "HashAgg")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition4, "HashAgg") // check if HashAgg is used
tk.MustQuery(queryPartition4).Sort().Check(tk.MustQuery(queryRegular4).Sort().Rows())
}
}
@@ -3242,12 +3242,12 @@ func TestIdexMerge(t *testing.T) {
queryPartition1 := fmt.Sprintf("select /*+ use_index_merge(trange) */ * from trange where a > %v or b < %v;", x1, x2)
queryRegular1 := fmt.Sprintf("select /*+ use_index_merge(tregular1) */ * from tregular1 where a > %v or b < %v;", x1, x2)
- require.True(t, tk.HasPlan(queryPartition1, "IndexMerge")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition1, "IndexMerge") // check if IndexMerge is used
tk.MustQuery(queryPartition1).Sort().Check(tk.MustQuery(queryRegular1).Sort().Rows())
queryPartition2 := fmt.Sprintf("select /*+ use_index_merge(trange) */ * from trange where a > %v or b > %v;", x1, x2)
queryRegular2 := fmt.Sprintf("select /*+ use_index_merge(tregular1) */ * from tregular1 where a > %v or b > %v;", x1, x2)
- require.True(t, tk.HasPlan(queryPartition2, "IndexMerge")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition2, "IndexMerge") // check if IndexMerge is used
tk.MustQuery(queryPartition2).Sort().Check(tk.MustQuery(queryRegular2).Sort().Rows())
}
@@ -3258,12 +3258,12 @@ func TestIdexMerge(t *testing.T) {
queryPartition1 := fmt.Sprintf("select /*+ use_index_merge(thash) */ * from thash where a > %v or b < %v;", x1, x2)
queryRegular1 := fmt.Sprintf("select /*+ use_index_merge(tregualr1) */ * from tregular1 where a > %v or b < %v;", x1, x2)
- require.True(t, tk.HasPlan(queryPartition1, "IndexMerge")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition1, "IndexMerge") // check if IndexMerge is used
tk.MustQuery(queryPartition1).Sort().Check(tk.MustQuery(queryRegular1).Sort().Rows())
queryPartition2 := fmt.Sprintf("select /*+ use_index_merge(thash) */ * from thash where a > %v or b > %v;", x1, x2)
queryRegular2 := fmt.Sprintf("select /*+ use_index_merge(tregular1) */ * from tregular1 where a > %v or b > %v;", x1, x2)
- require.True(t, tk.HasPlan(queryPartition2, "IndexMerge")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition2, "IndexMerge") // check if IndexMerge is used
tk.MustQuery(queryPartition2).Sort().Check(tk.MustQuery(queryRegular2).Sort().Rows())
}
@@ -3273,12 +3273,12 @@ func TestIdexMerge(t *testing.T) {
x2 := rand.Intn(12) + 1
queryPartition1 := fmt.Sprintf("select /*+ use_index_merge(tlist) */ * from tlist where a > %v or b < %v;", x1, x2)
queryRegular1 := fmt.Sprintf("select /*+ use_index_merge(tregular2) */ * from tregular2 where a > %v or b < %v;", x1, x2)
- require.True(t, tk.HasPlan(queryPartition1, "IndexMerge")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition1, "IndexMerge") // check if IndexMerge is used
tk.MustQuery(queryPartition1).Sort().Check(tk.MustQuery(queryRegular1).Sort().Rows())
queryPartition2 := fmt.Sprintf("select /*+ use_index_merge(tlist) */ * from tlist where a > %v or b > %v;", x1, x2)
queryRegular2 := fmt.Sprintf("select /*+ use_index_merge(tregular2) */ * from tregular2 where a > %v or b > %v;", x1, x2)
- require.True(t, tk.HasPlan(queryPartition2, "IndexMerge")) // check if IndexLookUp is used
+ tk.MustHavePlan(queryPartition2, "IndexMerge") // check if IndexMerge is used
tk.MustQuery(queryPartition2).Sort().Check(tk.MustQuery(queryRegular2).Sort().Rows())
}
}
diff --git a/executor/pipelined_window.go b/executor/pipelined_window.go
index d24f4fc3c4f48..020ad2bec11a3 100644
--- a/executor/pipelined_window.go
+++ b/executor/pipelined_window.go
@@ -264,7 +264,7 @@ func (e *PipelinedWindowExec) getStart(ctx sessionctx.Context) (uint64, error) {
var res int64
var err error
for i := range e.orderByCols {
- res, _, err = e.start.CmpFuncs[i](ctx, e.orderByCols[i], e.start.CalcFuncs[i], e.getRow(start), e.getRow(e.curRowIdx))
+ res, _, err = e.start.CmpFuncs[i](ctx, e.start.CompareCols[i], e.start.CalcFuncs[i], e.getRow(start), e.getRow(e.curRowIdx))
if err != nil {
return 0, err
}
@@ -304,7 +304,7 @@ func (e *PipelinedWindowExec) getEnd(ctx sessionctx.Context) (uint64, error) {
var res int64
var err error
for i := range e.orderByCols {
- res, _, err = e.end.CmpFuncs[i](ctx, e.end.CalcFuncs[i], e.orderByCols[i], e.getRow(e.curRowIdx), e.getRow(end))
+ res, _, err = e.end.CmpFuncs[i](ctx, e.end.CalcFuncs[i], e.end.CompareCols[i], e.getRow(e.curRowIdx), e.getRow(end))
if err != nil {
return 0, err
}
diff --git a/executor/point_get_test.go b/executor/point_get_test.go
index 8f13675457481..3c273ac202bda 100644
--- a/executor/point_get_test.go
+++ b/executor/point_get_test.go
@@ -147,7 +147,7 @@ func TestIssue25489(t *testing.T) {
PARTITION PMX VALUES LESS THAN (MAXVALUE)
) ;`)
query := "select col1, col2 from UK_RP16939 where col1 in (116, 48, -30);"
- require.False(t, tk.HasPlan(query, "Batch_Point_Get"))
+ tk.MustNotHavePlan(query, "Batch_Point_Get")
tk.MustQuery(query).Check(testkit.Rows())
tk.MustExec("drop table if exists UK_RP16939;")
@@ -165,7 +165,7 @@ func TestIssue25489(t *testing.T) {
PARTITION P1 VALUES IN (-22, 63),
PARTITION P2 VALUES IN (75, 90)
) ;`)
- require.False(t, tk.HasPlan(query, "Batch_Point_Get"))
+ tk.MustNotHavePlan(query, "Batch_Point_Get")
tk.MustQuery(query).Check(testkit.Rows())
tk.MustExec("drop table if exists UK_RP16939;")
}
diff --git a/executor/sample_test.go b/executor/sample_test.go
index 19183c235fcc2..2dd347534924d 100644
--- a/executor/sample_test.go
+++ b/executor/sample_test.go
@@ -51,7 +51,7 @@ func TestTableSampleBasic(t *testing.T) {
tk.MustExec("alter table t add column c int as (a + 1);")
tk.MustQuery("select c from t tablesample regions();").Check(testkit.Rows("1"))
tk.MustQuery("select c, _tidb_rowid from t tablesample regions();").Check(testkit.Rows("1 1"))
- require.True(t, tk.HasPlan("select * from t tablesample regions();", "TableSample"))
+ tk.MustHavePlan("select * from t tablesample regions();", "TableSample")
tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(a BIGINT PRIMARY KEY AUTO_RANDOM(3), b int auto_increment, key(b)) pre_split_regions=8;")
diff --git a/executor/temporary_table_test.go b/executor/temporary_table_test.go
index 5174a91976a21..43ca00905690e 100644
--- a/executor/temporary_table_test.go
+++ b/executor/temporary_table_test.go
@@ -111,23 +111,23 @@ func assertTemporaryTableNoNetwork(t *testing.T, createTable func(*testkit.TestK
// Check that the temporary table does not send requests to TiKV.
// PointGet
- require.True(t, tk.HasPlan("select * from tmp_t where id=1", "Point_Get"))
+ tk.MustHavePlan("select * from tmp_t where id=1", "Point_Get")
tk.MustQuery("select * from tmp_t where id=1").Check(testkit.Rows("1 1 1"))
// BatchPointGet
- require.True(t, tk.HasPlan("select * from tmp_t where id in (1, 2)", "Batch_Point_Get"))
+ tk.MustHavePlan("select * from tmp_t where id in (1, 2)", "Batch_Point_Get")
tk.MustQuery("select * from tmp_t where id in (1, 2)").Check(testkit.Rows("1 1 1", "2 2 2"))
// Table reader
- require.True(t, tk.HasPlan("select * from tmp_t", "TableReader"))
+ tk.MustHavePlan("select * from tmp_t", "TableReader")
tk.MustQuery("select * from tmp_t").Check(testkit.Rows("1 1 1", "2 2 2"))
// Index reader
- require.True(t, tk.HasPlan("select /*+ USE_INDEX(tmp_t, a) */ a from tmp_t", "IndexReader"))
+ tk.MustHavePlan("select /*+ USE_INDEX(tmp_t, a) */ a from tmp_t", "IndexReader")
tk.MustQuery("select /*+ USE_INDEX(tmp_t, a) */ a from tmp_t").Check(testkit.Rows("1", "2"))
// Index lookup
- require.True(t, tk.HasPlan("select /*+ USE_INDEX(tmp_t, a) */ b from tmp_t where a = 1", "IndexLookUp"))
+ tk.MustHavePlan("select /*+ USE_INDEX(tmp_t, a) */ b from tmp_t where a = 1", "IndexLookUp")
tk.MustQuery("select /*+ USE_INDEX(tmp_t, a) */ b from tmp_t where a = 1").Check(testkit.Rows("1"))
tk.MustExec("rollback")
diff --git a/executor/test/analyzetest/analyze_test.go b/executor/test/analyzetest/analyze_test.go
index 912ae3466a4bc..89141e5c5f08e 100644
--- a/executor/test/analyzetest/analyze_test.go
+++ b/executor/test/analyzetest/analyze_test.go
@@ -3116,6 +3116,12 @@ func TestAutoAnalyzeSkipColumnTypes(t *testing.T) {
// TestAnalyzeMVIndex tests analyzing the mv index using some real data in the table.
// It checks the analyze jobs, async loading, and the stats content in memory.
func TestAnalyzeMVIndex(t *testing.T) {
+ require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/DebugAnalyzeJobOperations", "return(true)"))
+ require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/statistics/handle/DebugAnalyzeJobOperations", "return(true)"))
+ defer func() {
+ require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/DebugAnalyzeJobOperations"))
+ require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/statistics/handle/DebugAnalyzeJobOperations"))
+ }()
// 1. prepare the table and insert data
store, dom := testkit.CreateMockStoreAndDomain(t)
h := dom.StatsHandle()
diff --git a/executor/test/executor/executor_test.go b/executor/test/executor/executor_test.go
index 2e86189d50f3a..6ce0d2f59047a 100644
--- a/executor/test/executor/executor_test.go
+++ b/executor/test/executor/executor_test.go
@@ -312,18 +312,6 @@ func TestShow(t *testing.T) {
require.Len(t, tk.MustQuery("show table status").Rows(), 1)
}
-func TestSelectWithoutFrom(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustQuery("select 1 + 2*3;").Check(testkit.Rows("7"))
- tk.MustQuery(`select _utf8"string";`).Check(testkit.Rows("string"))
- tk.MustQuery("select 1 order by 1;").Check(testkit.Rows("1"))
- tk.MustQuery("SELECT 'a' as f1 having f1 = 'a';").Check(testkit.Rows("a"))
- tk.MustQuery("SELECT (SELECT * FROM (SELECT 'a') t) AS f1 HAVING (f1 = 'a' OR TRUE);").Check(testkit.Rows("a"))
- tk.MustQuery("SELECT (SELECT * FROM (SELECT 'a') t) + 1 AS f1 HAVING (f1 = 'a' OR TRUE)").Check(testkit.Rows("1"))
-}
-
// TestSelectBackslashN Issue 3685.
func TestSelectBackslashN(t *testing.T) {
store := testkit.CreateMockStore(t)
@@ -725,30 +713,6 @@ func TestSelectOrderBy(t *testing.T) {
tk.MustQuery("select a from t use index(b) order by b").Check(testkit.Rows("9", "8", "7", "6", "5", "4", "3", "2", "1", "0"))
}
-func TestOrderBy(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("create table t (c1 int, c2 int, c3 varchar(20))")
- tk.MustExec("insert into t values (1, 2, 'abc'), (2, 1, 'bcd')")
-
- // Fix issue https://github.com/pingcap/tidb/issues/337
- tk.MustQuery("select c1 as a, c1 as b from t order by c1").Check(testkit.Rows("1 1", "2 2"))
-
- tk.MustQuery("select c1 as a, t.c1 as a from t order by a desc").Check(testkit.Rows("2 2", "1 1"))
- tk.MustQuery("select c1 as c2 from t order by c2").Check(testkit.Rows("1", "2"))
- tk.MustQuery("select sum(c1) from t order by sum(c1)").Check(testkit.Rows("3"))
- tk.MustQuery("select c1 as c2 from t order by c2 + 1").Check(testkit.Rows("2", "1"))
-
- // Order by position.
- tk.MustQuery("select * from t order by 1").Check(testkit.Rows("1 2 abc", "2 1 bcd"))
- tk.MustQuery("select * from t order by 2").Check(testkit.Rows("2 1 bcd", "1 2 abc"))
-
- // Order by binary.
- tk.MustQuery("select c1, c3 from t order by binary c1 desc").Check(testkit.Rows("2 bcd", "1 abc"))
- tk.MustQuery("select c1, c2 from t order by binary c3").Check(testkit.Rows("1 2", "2 1"))
-}
-
func TestSelectErrorRow(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
@@ -763,22 +727,6 @@ func TestSelectErrorRow(t *testing.T) {
require.Error(t, tk.ExecToErr("select * from test having (select 1, 1);"))
}
-func TestNeighbouringProj(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("create table t1(a int, b int)")
- tk.MustExec("create table t2(a int, b int)")
- tk.MustExec("insert into t1 value(1, 1), (2, 2)")
- tk.MustExec("insert into t2 value(1, 1), (2, 2)")
- tk.MustQuery("select sum(c) from (select t1.a as a, t1.a as c, length(t1.b) from t1 union select a, b, b from t2) t;").Check(testkit.Rows("5"))
-
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a bigint, b bigint, c bigint);")
- tk.MustExec("insert into t values(1, 1, 1), (2, 2, 2), (3, 3, 3);")
- tk.MustQuery("select cast(count(a) as signed), a as another, a from t group by a order by cast(count(a) as signed), a limit 10;").Check(testkit.Rows("1 1 1", "1 2 2", "1 3 3"))
-}
-
func TestIn(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
@@ -857,33 +805,6 @@ func TestTablePKisHandleScan(t *testing.T) {
}
}
-func TestIndexReverseOrder(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a int primary key auto_increment, b int, index idx (b))")
- tk.MustExec("insert t (b) values (0), (1), (2), (3), (4), (5), (6), (7), (8), (9)")
- tk.MustQuery("select b from t order by b desc").Check(testkit.Rows("9", "8", "7", "6", "5", "4", "3", "2", "1", "0"))
- tk.MustQuery("select b from t where b <3 or (b >=6 and b < 8) order by b desc").Check(testkit.Rows("7", "6", "2", "1", "0"))
-
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a int, b int, index idx (b, a))")
- tk.MustExec("insert t values (0, 2), (1, 2), (2, 2), (0, 1), (1, 1), (2, 1), (0, 0), (1, 0), (2, 0)")
- tk.MustQuery("select b, a from t order by b, a desc").Check(testkit.Rows("0 2", "0 1", "0 0", "1 2", "1 1", "1 0", "2 2", "2 1", "2 0"))
-}
-
-func TestTableReverseOrder(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a int primary key auto_increment, b int)")
- tk.MustExec("insert t (b) values (1), (2), (3), (4), (5), (6), (7), (8), (9)")
- tk.MustQuery("select b from t order by a desc").Check(testkit.Rows("9", "8", "7", "6", "5", "4", "3", "2", "1"))
- tk.MustQuery("select a from t where a <3 or (a >=6 and a < 8) order by a desc").Check(testkit.Rows("7", "6", "2", "1"))
-}
-
func TestDefaultNull(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
@@ -901,18 +822,6 @@ func TestDefaultNull(t *testing.T) {
tk.MustQuery("select * from t").Check(testkit.Rows("1 1 "))
}
-func TestUnsignedPKColumn(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a int unsigned primary key, b int, c int, key idx_ba (b, c, a));")
- tk.MustExec("insert t values (1, 1, 1)")
- tk.MustQuery("select * from t;").Check(testkit.Rows("1 1 1"))
- tk.MustExec("update t set c=2 where a=1;")
- tk.MustQuery("select * from t where b=1;").Check(testkit.Rows("1 1 2"))
-}
-
func TestJSON(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
@@ -983,31 +892,6 @@ func TestJSON(t *testing.T) {
"1234567890123456789012345678901234567890123456789012345.12"))
}
-func TestMultiUpdate(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec(`CREATE TABLE test_mu (a int primary key, b int, c int)`)
- tk.MustExec(`INSERT INTO test_mu VALUES (1, 2, 3), (4, 5, 6), (7, 8, 9)`)
-
- // Test INSERT ... ON DUPLICATE UPDATE set_lists.
- tk.MustExec(`INSERT INTO test_mu VALUES (1, 2, 3) ON DUPLICATE KEY UPDATE b = 3, c = b`)
- tk.MustQuery(`SELECT * FROM test_mu ORDER BY a`).Check(testkit.Rows(`1 3 3`, `4 5 6`, `7 8 9`))
-
- tk.MustExec(`INSERT INTO test_mu VALUES (1, 2, 3) ON DUPLICATE KEY UPDATE c = 2, b = c+5`)
- tk.MustQuery(`SELECT * FROM test_mu ORDER BY a`).Check(testkit.Rows(`1 7 2`, `4 5 6`, `7 8 9`))
-
- // Test UPDATE ... set_lists.
- tk.MustExec(`UPDATE test_mu SET b = 0, c = b WHERE a = 4`)
- tk.MustQuery(`SELECT * FROM test_mu ORDER BY a`).Check(testkit.Rows(`1 7 2`, `4 0 5`, `7 8 9`))
-
- tk.MustExec(`UPDATE test_mu SET c = 8, b = c WHERE a = 4`)
- tk.MustQuery(`SELECT * FROM test_mu ORDER BY a`).Check(testkit.Rows(`1 7 2`, `4 5 8`, `7 8 9`))
-
- tk.MustExec(`UPDATE test_mu SET c = b, b = c WHERE a = 7`)
- tk.MustQuery(`SELECT * FROM test_mu ORDER BY a`).Check(testkit.Rows(`1 7 2`, `4 5 8`, `7 9 8`))
-}
-
func TestGeneratedColumnWrite(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
@@ -1193,77 +1077,6 @@ func TestGeneratedColumnRead(t *testing.T) {
}
}
-// TestGeneratedColumnRead tests generated columns using point get and batch point get
-func TestGeneratedColumnPointGet(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists tu")
- tk.MustExec("CREATE TABLE tu(a int, b int, c int GENERATED ALWAYS AS (a + b) VIRTUAL, d int as (a * b) stored, " +
- "e int GENERATED ALWAYS as (b * 2) VIRTUAL, PRIMARY KEY (a), UNIQUE KEY ukc (c), unique key ukd(d), key ke(e))")
- tk.MustExec("insert into tu(a, b) values(1, 2)")
- tk.MustExec("insert into tu(a, b) values(5, 6)")
- tk.MustQuery("select * from tu for update").Check(testkit.Rows("1 2 3 2 4", "5 6 11 30 12"))
- tk.MustQuery("select * from tu where a = 1").Check(testkit.Rows("1 2 3 2 4"))
- tk.MustQuery("select * from tu where a in (1, 2)").Check(testkit.Rows("1 2 3 2 4"))
- tk.MustQuery("select * from tu where c in (1, 2, 3)").Check(testkit.Rows("1 2 3 2 4"))
- tk.MustQuery("select * from tu where c = 3").Check(testkit.Rows("1 2 3 2 4"))
- tk.MustQuery("select d, e from tu where c = 3").Check(testkit.Rows("2 4"))
- tk.MustQuery("select * from tu where d in (1, 2, 3)").Check(testkit.Rows("1 2 3 2 4"))
- tk.MustQuery("select * from tu where d = 2").Check(testkit.Rows("1 2 3 2 4"))
- tk.MustQuery("select c, d from tu where d = 2").Check(testkit.Rows("3 2"))
- tk.MustQuery("select d, e from tu where e = 4").Check(testkit.Rows("2 4"))
- tk.MustQuery("select * from tu where e = 4").Check(testkit.Rows("1 2 3 2 4"))
- tk.MustExec("update tu set a = a + 1, b = b + 1 where c = 11")
- tk.MustQuery("select * from tu for update").Check(testkit.Rows("1 2 3 2 4", "6 7 13 42 14"))
- tk.MustQuery("select * from tu where a = 6").Check(testkit.Rows("6 7 13 42 14"))
- tk.MustQuery("select * from tu where c in (5, 6, 13)").Check(testkit.Rows("6 7 13 42 14"))
- tk.MustQuery("select b, c, e, d from tu where c = 13").Check(testkit.Rows("7 13 14 42"))
- tk.MustQuery("select a, e, d from tu where c in (5, 6, 13)").Check(testkit.Rows("6 14 42"))
- tk.MustExec("drop table if exists tu")
-}
-
-func TestUnionAutoSignedCast(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t1,t2")
- tk.MustExec("create table t1 (id int, i int, b bigint, d double, dd decimal)")
- tk.MustExec("create table t2 (id int, i int unsigned, b bigint unsigned, d double unsigned, dd decimal unsigned)")
- tk.MustExec("insert into t1 values(1, -1, -1, -1.1, -1)")
- tk.MustExec("insert into t2 values(2, 1, 1, 1.1, 1)")
- tk.MustQuery("select * from t1 union select * from t2 order by id").
- Check(testkit.Rows("1 -1 -1 -1.1 -1", "2 1 1 1.1 1"))
- tk.MustQuery("select id, i, b, d, dd from t2 union select id, i, b, d, dd from t1 order by id").
- Check(testkit.Rows("1 -1 -1 -1.1 -1", "2 1 1 1.1 1"))
- tk.MustQuery("select id, i from t2 union select id, cast(i as unsigned int) from t1 order by id").
- Check(testkit.Rows("1 18446744073709551615", "2 1"))
- tk.MustQuery("select dd from t2 union all select dd from t2").
- Check(testkit.Rows("1", "1"))
-
- tk.MustExec("drop table if exists t3,t4")
- tk.MustExec("create table t3 (id int, v int)")
- tk.MustExec("create table t4 (id int, v double unsigned)")
- tk.MustExec("insert into t3 values (1, -1)")
- tk.MustExec("insert into t4 values (2, 1)")
- tk.MustQuery("select id, v from t3 union select id, v from t4 order by id").
- Check(testkit.Rows("1 -1", "2 1"))
- tk.MustQuery("select id, v from t4 union select id, v from t3 order by id").
- Check(testkit.Rows("1 -1", "2 1"))
-
- tk.MustExec("drop table if exists t5,t6,t7")
- tk.MustExec("create table t5 (id int, v bigint unsigned)")
- tk.MustExec("create table t6 (id int, v decimal)")
- tk.MustExec("create table t7 (id int, v bigint)")
- tk.MustExec("insert into t5 values (1, 1)")
- tk.MustExec("insert into t6 values (2, -1)")
- tk.MustExec("insert into t7 values (3, -1)")
- tk.MustQuery("select id, v from t5 union select id, v from t6 order by id").
- Check(testkit.Rows("1 1", "2 -1"))
- tk.MustQuery("select id, v from t5 union select id, v from t7 union select id, v from t6 order by id").
- Check(testkit.Rows("1 1", "2 -1", "3 -1"))
-}
-
func TestUpdateClustered(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
@@ -1494,27 +1307,6 @@ func TestSelectPartition(t *testing.T) {
tk.MustQuery("select * from tscalar where c1 in (-1)").Check(testkit.Rows())
}
-func TestDeletePartition(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec(`use test`)
- tk.MustExec(`drop table if exists t1`)
- tk.MustExec(`create table t1 (a int) partition by range (a) (
- partition p0 values less than (10),
- partition p1 values less than (20),
- partition p2 values less than (30),
- partition p3 values less than (40),
- partition p4 values less than MAXVALUE
- )`)
- tk.MustExec("insert into t1 values (1),(11),(21),(31)")
- tk.MustExec("delete from t1 partition (p4)")
- tk.MustQuery("select * from t1 order by a").Check(testkit.Rows("1", "11", "21", "31"))
- tk.MustExec("delete from t1 partition (p0) where a > 10")
- tk.MustQuery("select * from t1 order by a").Check(testkit.Rows("1", "11", "21", "31"))
- tk.MustExec("delete from t1 partition (p0,p1,p2)")
- tk.MustQuery("select * from t1").Check(testkit.Rows("31"))
-}
-
func TestPrepareLoadData(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
@@ -1558,19 +1350,6 @@ func TestPlanReplayerDumpSingle(t *testing.T) {
}
}
-func TestAlterTableComment(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t_1")
- tk.MustExec("create table t_1 (c1 int, c2 int, c3 int default 1, index (c1)) comment = 'test table';")
- tk.MustExec("alter table `t_1` comment 'this is table comment';")
- tk.MustQuery("select table_comment from information_schema.tables where table_name = 't_1';").Check(testkit.Rows("this is table comment"))
- tk.MustExec("alter table `t_1` comment 'table t comment';")
- tk.MustQuery("select table_comment from information_schema.tables where table_name = 't_1';").Check(testkit.Rows("table t comment"))
-}
-
func TestTimezonePushDown(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
@@ -1722,69 +1501,6 @@ func TestExecutorBit(t *testing.T) {
tk.MustQuery("select * from t where c1").Check(testkit.Rows("\xff\xff\xff\xff\xff\xff\xff\xff", "12345678"))
}
-func TestExecutorEnum(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("create table t (c enum('a', 'b', 'c'))")
- tk.MustExec("insert into t values ('a'), (2), ('c')")
- tk.MustQuery("select * from t where c = 'a'").Check(testkit.Rows("a"))
-
- tk.MustQuery("select c + 1 from t where c = 2").Check(testkit.Rows("3"))
-
- tk.MustExec("delete from t")
- tk.MustExec("insert into t values ()")
- tk.MustExec("insert into t values (null), ('1')")
- tk.MustQuery("select c + 1 from t where c = 1").Check(testkit.Rows("2"))
-
- tk.MustExec("delete from t")
- tk.MustExec("insert into t values(1), (2), (3)")
- tk.MustQuery("select * from t where c").Check(testkit.Rows("a", "b", "c"))
-}
-
-func TestExecutorSet(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("create table t (c set('a', 'b', 'c'))")
- tk.MustExec("insert into t values ('a'), (2), ('c'), ('a,b'), ('b,a')")
- tk.MustQuery("select * from t where c = 'a'").Check(testkit.Rows("a"))
- tk.MustQuery("select * from t where c = 'a,b'").Check(testkit.Rows("a,b", "a,b"))
- tk.MustQuery("select c + 1 from t where c = 2").Check(testkit.Rows("3"))
- tk.MustExec("delete from t")
- tk.MustExec("insert into t values ()")
- tk.MustExec("insert into t values (null), ('1')")
- tk.MustQuery("select c + 1 from t where c = 1").Check(testkit.Rows("2"))
- tk.MustExec("delete from t")
- tk.MustExec("insert into t values(3)")
- tk.MustQuery("select * from t where c").Check(testkit.Rows("a,b"))
-}
-
-func TestSubQueryInValues(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("create table t (id int, name varchar(20))")
- tk.MustExec("create table t1 (gid int)")
- tk.MustExec("insert into t1 (gid) value (1)")
- tk.MustExec("insert into t (id, name) value ((select gid from t1) ,'asd')")
- tk.MustQuery("select * from t").Check(testkit.Rows("1 asd"))
-}
-
-func TestEnhancedRangeAccess(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("create table t (a int primary key, b int)")
- tk.MustExec("insert into t values(1, 2), (2, 1)")
- tk.MustQuery("select * from t where (a = 1 and b = 2) or (a = 2 and b = 1)").Check(testkit.Rows("1 2", "2 1"))
- tk.MustQuery("select * from t where (a = 1 and b = 1) or (a = 2 and b = 2)").Check(nil)
-}
-
// TestMaxInt64Handle Issue #4810
func TestMaxInt64Handle(t *testing.T) {
store := testkit.CreateMockStore(t)
@@ -1802,16 +1518,6 @@ func TestMaxInt64Handle(t *testing.T) {
tk.MustQuery("select * from t").Check(nil)
}
-func TestTableScanWithPointRanges(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("create table t(id int, PRIMARY KEY (id))")
- tk.MustExec("insert into t values(1), (5), (10)")
- tk.MustQuery("select * from t where id in(1, 2, 10)").Check(testkit.Rows("1", "10"))
-}
-
func TestUnsignedPk(t *testing.T) {
store := testkit.CreateMockStore(t)
@@ -1977,18 +1683,6 @@ func setColValue(t *testing.T, txn kv.Transaction, key kv.Key, v types.Datum) {
require.NoError(t, err)
}
-func TestCheckTable(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- // Test 'admin check table' when the table has a unique index with null values.
- tk.MustExec("use test")
- tk.MustExec("drop table if exists admin_test;")
- tk.MustExec("create table admin_test (c1 int, c2 int, c3 int default 1, index (c1), unique key(c2));")
- tk.MustExec("insert admin_test (c1, c2) values (1, 1), (2, 2), (NULL, NULL);")
- tk.MustExec("admin check table admin_test;")
-}
-
func TestCheckTableClusterIndex(t *testing.T) {
store := testkit.CreateMockStore(t)
@@ -2018,252 +1712,6 @@ func TestIncorrectLimitArg(t *testing.T) {
tk.MustGetErrMsg(`execute stmt2 using @a, @a;`, `[planner:1210]Incorrect arguments to LIMIT`)
}
-func TestExecutorLimit(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec(`use test;`)
- tk.MustExec(`create table t(a bigint, b bigint);`)
- tk.MustExec(`insert into t values(1, 1), (2, 2), (3, 30), (4, 40), (5, 5), (6, 6);`)
- tk.MustQuery(`select * from t order by a limit 1, 1;`).Check(testkit.Rows("2 2"))
- tk.MustQuery(`select * from t order by a limit 1, 2;`).Check(testkit.Rows("2 2", "3 30"))
- tk.MustQuery(`select * from t order by a limit 1, 3;`).Check(testkit.Rows("2 2", "3 30", "4 40"))
- tk.MustQuery(`select * from t order by a limit 1, 4;`).Check(testkit.Rows("2 2", "3 30", "4 40", "5 5"))
-
- // test inline projection
- tk.MustQuery(`select a from t where a > 0 limit 1, 1;`).Check(testkit.Rows("2"))
- tk.MustQuery(`select a from t where a > 0 limit 1, 2;`).Check(testkit.Rows("2", "3"))
- tk.MustQuery(`select b from t where a > 0 limit 1, 3;`).Check(testkit.Rows("2", "30", "40"))
- tk.MustQuery(`select b from t where a > 0 limit 1, 4;`).Check(testkit.Rows("2", "30", "40", "5"))
-
- // test @@tidb_init_chunk_size=2
- tk.MustExec(`set @@tidb_init_chunk_size=2;`)
- tk.MustQuery(`select * from t where a > 0 limit 2, 1;`).Check(testkit.Rows("3 30"))
- tk.MustQuery(`select * from t where a > 0 limit 2, 2;`).Check(testkit.Rows("3 30", "4 40"))
- tk.MustQuery(`select * from t where a > 0 limit 2, 3;`).Check(testkit.Rows("3 30", "4 40", "5 5"))
- tk.MustQuery(`select * from t where a > 0 limit 2, 4;`).Check(testkit.Rows("3 30", "4 40", "5 5", "6 6"))
-
- // test inline projection
- tk.MustQuery(`select a from t order by a limit 2, 1;`).Check(testkit.Rows("3"))
- tk.MustQuery(`select b from t order by a limit 2, 2;`).Check(testkit.Rows("30", "40"))
- tk.MustQuery(`select a from t order by a limit 2, 3;`).Check(testkit.Rows("3", "4", "5"))
- tk.MustQuery(`select b from t order by a limit 2, 4;`).Check(testkit.Rows("30", "40", "5", "6"))
-}
-
-func TestIndexScan(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("create table t (a int unique)")
- tk.MustExec("insert t values (-1), (2), (3), (5), (6), (7), (8), (9)")
- tk.MustQuery("select a from t where a < 0 or (a >= 2.1 and a < 5.1) or ( a > 5.9 and a <= 7.9) or a > '8.1'").Check(testkit.Rows("-1", "3", "5", "6", "7", "9"))
-
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a int unique)")
- tk.MustExec("insert t values (0)")
- tk.MustQuery("select NULL from t ").Check(testkit.Rows(""))
-
- // test for double read
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a int unique, b int)")
- tk.MustExec("insert t values (5, 0)")
- tk.MustExec("insert t values (4, 0)")
- tk.MustExec("insert t values (3, 0)")
- tk.MustExec("insert t values (2, 0)")
- tk.MustExec("insert t values (1, 0)")
- tk.MustExec("insert t values (0, 0)")
- tk.MustQuery("select * from t order by a limit 3").Check(testkit.Rows("0 0", "1 0", "2 0"))
-
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a int unique, b int)")
- tk.MustExec("insert t values (0, 1)")
- tk.MustExec("insert t values (1, 2)")
- tk.MustExec("insert t values (2, 1)")
- tk.MustExec("insert t values (3, 2)")
- tk.MustExec("insert t values (4, 1)")
- tk.MustExec("insert t values (5, 2)")
- tk.MustQuery("select * from t where a < 5 and b = 1 limit 2").Check(testkit.Rows("0 1", "2 1"))
-
- tk.MustExec("drop table if exists tab1")
- tk.MustExec("CREATE TABLE tab1(pk INTEGER PRIMARY KEY, col0 INTEGER, col1 FLOAT, col3 INTEGER, col4 FLOAT)")
- tk.MustExec("CREATE INDEX idx_tab1_0 on tab1 (col0)")
- tk.MustExec("CREATE INDEX idx_tab1_1 on tab1 (col1)")
- tk.MustExec("CREATE INDEX idx_tab1_3 on tab1 (col3)")
- tk.MustExec("CREATE INDEX idx_tab1_4 on tab1 (col4)")
- tk.MustExec("INSERT INTO tab1 VALUES(1,37,20.85,30,10.69)")
- tk.MustQuery("SELECT pk FROM tab1 WHERE ((col3 <= 6 OR col3 < 29 AND (col0 < 41)) OR col3 > 42) AND col1 >= 96.1 AND col3 = 30 AND col3 > 17 AND (col0 BETWEEN 36 AND 42)").Check(testkit.Rows())
-
- tk.MustExec("drop table if exists tab1")
- tk.MustExec("CREATE TABLE tab1(pk INTEGER PRIMARY KEY, a INTEGER, b INTEGER)")
- tk.MustExec("CREATE INDEX idx_tab1_0 on tab1 (a)")
- tk.MustExec("INSERT INTO tab1 VALUES(1,1,1)")
- tk.MustExec("INSERT INTO tab1 VALUES(2,2,1)")
- tk.MustExec("INSERT INTO tab1 VALUES(3,1,2)")
- tk.MustExec("INSERT INTO tab1 VALUES(4,2,2)")
- tk.MustQuery("SELECT * FROM tab1 WHERE pk <= 3 AND a = 1").Check(testkit.Rows("1 1 1", "3 1 2"))
- tk.MustQuery("SELECT * FROM tab1 WHERE pk <= 4 AND a = 1 AND b = 2").Check(testkit.Rows("3 1 2"))
-
- tk.MustExec("CREATE INDEX idx_tab1_1 on tab1 (b, a)")
- tk.MustQuery("SELECT pk FROM tab1 WHERE b > 1").Check(testkit.Rows("3", "4"))
-
- tk.MustExec("drop table if exists t")
- tk.MustExec("CREATE TABLE t (a varchar(3), index(a))")
- tk.MustExec("insert t values('aaa'), ('aab')")
- tk.MustQuery("select * from t where a >= 'aaaa' and a < 'aabb'").Check(testkit.Rows("aab"))
-
- tk.MustExec("drop table if exists t")
- tk.MustExec("CREATE TABLE t (a int primary key, b int, c int, index(c))")
- tk.MustExec("insert t values(1, 1, 1), (2, 2, 2), (4, 4, 4), (3, 3, 3), (5, 5, 5)")
- // Test for double read and top n.
- tk.MustQuery("select a from t where c >= 2 order by b desc limit 1").Check(testkit.Rows("5"))
-
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a varchar(50) primary key, b int, c int, index idx(b))")
- tk.MustExec("insert into t values('aa', 1, 1)")
- tk.MustQuery("select * from t use index(idx) where a > 'a'").Check(testkit.Rows("aa 1 1"))
-
- // fix issue9636
- tk.MustExec("drop table if exists t")
- tk.MustExec("CREATE TABLE `t` (a int, KEY (a))")
- tk.MustQuery(`SELECT * FROM (SELECT * FROM (SELECT a as d FROM t WHERE a IN ('100')) AS x WHERE x.d < "123" ) tmp_count`).Check(testkit.Rows())
-}
-
-func TestUpdateJoin(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("create table t1(k int, v int)")
- tk.MustExec("create table t2(k int, v int)")
- tk.MustExec("create table t3(id int auto_increment, k int, v int, primary key(id))")
- tk.MustExec("create table t4(k int, v int)")
- tk.MustExec("create table t5(v int, k int, primary key(k))")
- tk.MustExec("insert into t1 values (1, 1)")
- tk.MustExec("insert into t4 values (3, 3)")
- tk.MustExec("create table t6 (id int, v longtext)")
- tk.MustExec("create table t7 (x int, id int, v longtext, primary key(id))")
-
- // test the normal case that update one row for a single table.
- tk.MustExec("update t1 set v = 0 where k = 1")
- tk.MustQuery("select k, v from t1 where k = 1").Check(testkit.Rows("1 0"))
-
- // test the case that the table with auto_increment or none-null columns as the right table of left join.
- tk.MustExec("update t1 left join t3 on t1.k = t3.k set t1.v = 1")
- tk.MustQuery("select k, v from t1").Check(testkit.Rows("1 1"))
- tk.MustQuery("select id, k, v from t3").Check(testkit.Rows())
-
- // test left join and the case that the right table has no matching record but has updated the right table columns.
- tk.MustExec("update t1 left join t2 on t1.k = t2.k set t1.v = t2.v, t2.v = 3")
- tk.MustQuery("select k, v from t1").Check(testkit.Rows("1 "))
- tk.MustQuery("select k, v from t2").Check(testkit.Rows())
-
- // test the case that the update operation in the left table references data in the right table while data of the right table columns is modified.
- tk.MustExec("update t1 left join t2 on t1.k = t2.k set t2.v = 3, t1.v = t2.v")
- tk.MustQuery("select k, v from t1").Check(testkit.Rows("1 "))
- tk.MustQuery("select k, v from t2").Check(testkit.Rows())
-
- // test right join and the case that the left table has no matching record but has updated the left table columns.
- tk.MustExec("update t2 right join t1 on t2.k = t1.k set t2.v = 4, t1.v = 0")
- tk.MustQuery("select k, v from t1").Check(testkit.Rows("1 0"))
- tk.MustQuery("select k, v from t2").Check(testkit.Rows())
-
- // test the case of right join and left join at the same time.
- tk.MustExec("update t1 left join t2 on t1.k = t2.k right join t4 on t4.k = t2.k set t1.v = 4, t2.v = 4, t4.v = 4")
- tk.MustQuery("select k, v from t1").Check(testkit.Rows("1 0"))
- tk.MustQuery("select k, v from t2").Check(testkit.Rows())
- tk.MustQuery("select k, v from t4").Check(testkit.Rows("3 4"))
-
- // test normal left join and the case that the right table has matching rows.
- tk.MustExec("insert t2 values (1, 10)")
- tk.MustExec("update t1 left join t2 on t1.k = t2.k set t2.v = 11")
- tk.MustQuery("select k, v from t2").Check(testkit.Rows("1 11"))
-
- // test the case of continuously joining the same table and updating the unmatching records.
- tk.MustExec("update t1 t11 left join t2 on t11.k = t2.k left join t1 t12 on t2.v = t12.k set t12.v = 233, t11.v = 111")
- tk.MustQuery("select k, v from t1").Check(testkit.Rows("1 111"))
- tk.MustQuery("select k, v from t2").Check(testkit.Rows("1 11"))
-
- // test the left join case that the left table has records but all records are null.
- tk.MustExec("delete from t1")
- tk.MustExec("delete from t2")
- tk.MustExec("insert into t1 values (null, null)")
- tk.MustExec("update t1 left join t2 on t1.k = t2.k set t1.v = 1")
- tk.MustQuery("select k, v from t1").Check(testkit.Rows(" 1"))
-
- // test the case that the right table of left join has an primary key.
- tk.MustExec("insert t5 values(0, 0)")
- tk.MustExec("update t1 left join t5 on t1.k = t5.k set t1.v = 2")
- tk.MustQuery("select k, v from t1").Check(testkit.Rows(" 2"))
- tk.MustQuery("select k, v from t5").Check(testkit.Rows("0 0"))
-
- tk.MustExec("insert into t6 values (1, NULL)")
- tk.MustExec("insert into t7 values (5, 1, 'a')")
- tk.MustExec("update t6, t7 set t6.v = t7.v where t6.id = t7.id and t7.x = 5")
- tk.MustQuery("select v from t6").Check(testkit.Rows("a"))
-
- tk.MustExec("drop table if exists t1, t2")
- tk.MustExec("create table t1(id int primary key, v int, gv int GENERATED ALWAYS AS (v * 2) STORED)")
- tk.MustExec("create table t2(id int, v int)")
- tk.MustExec("update t1 tt1 inner join (select count(t1.id) a, t1.id from t1 left join t2 on t1.id = t2.id group by t1.id) x on tt1.id = x.id set tt1.v = tt1.v + x.a")
-}
-
-func TestScanControlSelection(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("create table t(a int primary key, b int, c int, index idx_b(b))")
- tk.MustExec("insert into t values (1, 1, 1), (2, 1, 1), (3, 1, 2), (4, 2, 3)")
- tk.MustQuery("select (select count(1) k from t s where s.b = t1.c) from t t1").Sort().Check(testkit.Rows("0", "1", "3", "3"))
-}
-
-func TestSimpleDAG(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("create table t(a int primary key, b int, c int)")
- tk.MustExec("insert into t values (1, 1, 1), (2, 1, 1), (3, 1, 2), (4, 2, 3)")
- tk.MustQuery("select a from t").Check(testkit.Rows("1", "2", "3", "4"))
- tk.MustQuery("select * from t where a = 4").Check(testkit.Rows("4 2 3"))
- tk.MustQuery("select a from t limit 1").Check(testkit.Rows("1"))
- tk.MustQuery("select a from t order by a desc").Check(testkit.Rows("4", "3", "2", "1"))
- tk.MustQuery("select a from t order by a desc limit 1").Check(testkit.Rows("4"))
- tk.MustQuery("select a from t order by b desc limit 1").Check(testkit.Rows("4"))
- tk.MustQuery("select a from t where a < 3").Check(testkit.Rows("1", "2"))
- tk.MustQuery("select a from t where b > 1").Check(testkit.Rows("4"))
- tk.MustQuery("select a from t where b > 1 and a < 3").Check(testkit.Rows())
- tk.MustQuery("select count(*) from t where b > 1 and a < 3").Check(testkit.Rows("0"))
- tk.MustQuery("select count(*) from t").Check(testkit.Rows("4"))
- tk.MustQuery("select count(*), c from t group by c order by c").Check(testkit.Rows("2 1", "1 2", "1 3"))
- tk.MustQuery("select sum(c) as s from t group by b order by s").Check(testkit.Rows("3", "4"))
- tk.MustQuery("select avg(a) as s from t group by b order by s").Check(testkit.Rows("2.0000", "4.0000"))
- tk.MustQuery("select sum(distinct c) from t group by b").Check(testkit.Rows("3", "3"))
-
- tk.MustExec("create index i on t(c,b)")
- tk.MustQuery("select a from t where c = 1").Check(testkit.Rows("1", "2"))
- tk.MustQuery("select a from t where c = 1 and a < 2").Check(testkit.Rows("1"))
- tk.MustQuery("select a from t where c = 1 order by a limit 1").Check(testkit.Rows("1"))
- tk.MustQuery("select count(*) from t where c = 1 ").Check(testkit.Rows("2"))
- tk.MustExec("create index i1 on t(b)")
- tk.MustQuery("select c from t where b = 2").Check(testkit.Rows("3"))
- tk.MustQuery("select * from t where b = 2").Check(testkit.Rows("4 2 3"))
- tk.MustQuery("select count(*) from t where b = 1").Check(testkit.Rows("3"))
- tk.MustQuery("select * from t where b = 1 and a > 1 limit 1").Check(testkit.Rows("2 1 1"))
-
- // Test time push down.
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (id int, c1 datetime);")
- tk.MustExec("insert into t values (1, '2015-06-07 12:12:12')")
- tk.MustQuery("select id from t where c1 = '2015-06-07 12:12:12'").Check(testkit.Rows("1"))
-
- // Test issue 17816
- tk.MustExec("drop table if exists t0")
- tk.MustExec("CREATE TABLE t0(c0 INT)")
- tk.MustExec("INSERT INTO t0 VALUES (100000)")
- tk.MustQuery("SELECT * FROM t0 WHERE NOT SPACE(t0.c0)").Check(testkit.Rows("100000"))
-}
-
func TestTimestampTimeZone(t *testing.T) {
store := testkit.CreateMockStore(t)
@@ -2724,17 +2172,6 @@ func TestPartitionHashCode(t *testing.T) {
wg.Wait()
}
-func TestAlterDefaultValue(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("create table t(a int, primary key(a))")
- tk.MustExec("insert into t(a) values(1)")
- tk.MustExec("alter table t add column b int default 1")
- tk.MustExec("alter table t alter b set default 2")
- tk.MustQuery("select b from t where a = 1").Check(testkit.Rows("1"))
-}
-
// this is from jira issue #5856
func TestInsertValuesWithSubQuery(t *testing.T) {
store := testkit.CreateMockStore(t)
@@ -2980,20 +2417,6 @@ func TestIndexMergeRuntimeStats(t *testing.T) {
tk.MustQuery("select /*+ use_index_merge(t1, primary, t1a) */ * from t1 where id < 2 or a > 4 order by a").Check(testkit.Rows("1 1 1 1 1", "5 5 5 5 5"))
}
-// For issue 17256
-func TestGenerateColumnReplace(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test;")
- tk.MustExec("create table t1 (a int, b int as (a + 1) virtual not null, unique index idx(b));")
- tk.MustExec("REPLACE INTO `t1` (`a`) VALUES (2);")
- tk.MustExec("REPLACE INTO `t1` (`a`) VALUES (2);")
- tk.MustQuery("select * from t1").Check(testkit.Rows("2 3"))
- tk.MustExec("insert into `t1` (`a`) VALUES (2) on duplicate key update a = 3;")
- tk.MustQuery("select * from t1").Check(testkit.Rows("3 4"))
-}
-
func TestPrevStmtDesensitization(t *testing.T) {
store := testkit.CreateMockStore(t)
@@ -3008,18 +2431,6 @@ func TestPrevStmtDesensitization(t *testing.T) {
tk.MustGetErrMsg("insert into t values (1)", `[kv:1062]Duplicate entry '?' for key 't.a'`)
}
-func TestIssue19372(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test;")
- tk.MustExec("create table t1 (c_int int, c_str varchar(40), key(c_str));")
- tk.MustExec("create table t2 like t1;")
- tk.MustExec("insert into t1 values (1, 'a'), (2, 'b'), (3, 'c');")
- tk.MustExec("insert into t2 select * from t1;")
- tk.MustQuery("select (select t2.c_str from t2 where t2.c_str <= t1.c_str and t2.c_int in (1, 2) order by t2.c_str limit 1) x from t1 order by c_int;").Check(testkit.Rows("a", "a", "a"))
-}
-
func TestIssue19148(t *testing.T) {
store := testkit.CreateMockStore(t)
@@ -3455,38 +2866,6 @@ func TestEncodingSet(t *testing.T) {
tk.MustExec("admin check table `enum-set`")
}
-func TestDeleteWithMulTbl(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
-
- // Delete multiple tables from left joined table.
- // The result of left join is (3, null, null).
- // Because rows in t2 are not matched, so no row will be deleted in t2.
- // But row in t1 is matched, so it should be deleted.
- tk.MustExec("use test;")
- tk.MustExec("drop table if exists t1, t2;")
- tk.MustExec("create table t1 (c1 int);")
- tk.MustExec("create table t2 (c1 int primary key, c2 int);")
- tk.MustExec("insert into t1 values(3);")
- tk.MustExec("insert into t2 values(2, 2);")
- tk.MustExec("insert into t2 values(0, 0);")
- tk.MustExec("delete from t1, t2 using t1 left join t2 on t1.c1 = t2.c2;")
- tk.MustQuery("select * from t1 order by c1;").Check(testkit.Rows())
- tk.MustQuery("select * from t2 order by c1;").Check(testkit.Rows("0 0", "2 2"))
-
- // Rows in both t1 and t2 are matched, so will be deleted even if it's null.
- // NOTE: The null values are not generated by join.
- tk.MustExec("drop table if exists t1, t2;")
- tk.MustExec("create table t1 (c1 int);")
- tk.MustExec("create table t2 (c2 int);")
- tk.MustExec("insert into t1 values(null);")
- tk.MustExec("insert into t2 values(null);")
- tk.MustExec("delete from t1, t2 using t1 join t2 where t1.c1 is null;")
- tk.MustQuery("select * from t1;").Check(testkit.Rows())
- tk.MustQuery("select * from t2;").Check(testkit.Rows())
-}
-
func TestOOMPanicAction(t *testing.T) {
store, dom := testkit.CreateMockStoreAndDomain(t)
@@ -4068,35 +3447,6 @@ func TestCollectDMLRuntimeStats(t *testing.T) {
tk.MustExec("rollback")
}
-func TestIssue13758(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t1, t2")
- tk.MustExec("create table t1 (pk int(11) primary key, a int(11) not null, b int(11), key idx_b(b), key idx_a(a))")
- tk.MustExec("insert into `t1` values (1,1,0),(2,7,6),(3,2,null),(4,1,null),(5,4,5)")
- tk.MustExec("create table t2 (a int)")
- tk.MustExec("insert into t2 values (1),(null)")
- tk.MustQuery("select (select a from t1 use index(idx_a) where b >= t2.a order by a limit 1) as field from t2").Check(testkit.Rows(
- "4",
- "",
- ))
-}
-
-func TestIssue20237(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t, s")
- tk.MustExec("create table t(a date, b float)")
- tk.MustExec("create table s(b float)")
- tk.MustExec(`insert into t values(NULL,-37), ("2011-11-04",105), ("2013-03-02",-22), ("2006-07-02",-56), (NULL,124), (NULL,111), ("2018-03-03",-5);`)
- tk.MustExec(`insert into s values(-37),(105),(-22),(-56),(124),(105),(111),(-5);`)
- tk.MustQuery(`select count(distinct t.a, t.b) from t join s on t.b= s.b;`).Check(testkit.Rows("4"))
-}
-
func TestIssue24933(t *testing.T) {
store := testkit.CreateMockStore(t)
@@ -4723,81 +4073,6 @@ func TestYearTypeDeleteIndex(t *testing.T) {
tk.MustExec("admin check table t")
}
-func TestToPBExpr(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a decimal(10,6), b decimal, index idx_b (b))")
- tk.MustExec("set sql_mode = ''")
- tk.MustExec("insert t values (1.1, 1.1)")
- tk.MustExec("insert t values (2.4, 2.4)")
- tk.MustExec("insert t values (3.3, 2.7)")
- result := tk.MustQuery("select * from t where a < 2.399999")
- result.Check(testkit.Rows("1.100000 1"))
- result = tk.MustQuery("select * from t where a > 1.5")
- result.Check(testkit.Rows("2.400000 2", "3.300000 3"))
- result = tk.MustQuery("select * from t where a <= 1.1")
- result.Check(testkit.Rows("1.100000 1"))
- result = tk.MustQuery("select * from t where b >= 3")
- result.Check(testkit.Rows("3.300000 3"))
- result = tk.MustQuery("select * from t where not (b = 1)")
- result.Check(testkit.Rows("2.400000 2", "3.300000 3"))
- result = tk.MustQuery("select * from t where b&1 = a|1")
- result.Check(testkit.Rows("1.100000 1"))
- result = tk.MustQuery("select * from t where b != 2 and b <=> 3")
- result.Check(testkit.Rows("3.300000 3"))
- result = tk.MustQuery("select * from t where b in (3)")
- result.Check(testkit.Rows("3.300000 3"))
- result = tk.MustQuery("select * from t where b not in (1, 2)")
- result.Check(testkit.Rows("3.300000 3"))
-
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a varchar(255), b int)")
- tk.MustExec("insert t values ('abc123', 1)")
- tk.MustExec("insert t values ('ab123', 2)")
- result = tk.MustQuery("select * from t where a like 'ab%'")
- result.Check(testkit.Rows("abc123 1", "ab123 2"))
- result = tk.MustQuery("select * from t where a like 'ab_12'")
- result.Check(nil)
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a int primary key)")
- tk.MustExec("insert t values (1)")
- tk.MustExec("insert t values (2)")
- result = tk.MustQuery("select * from t where not (a = 1)")
- result.Check(testkit.Rows("2"))
- result = tk.MustQuery("select * from t where not(not (a = 1))")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select * from t where not(a != 1 and a != 2)")
- result.Check(testkit.Rows("1", "2"))
-}
-
-func TestDatumXAPI(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a decimal(10,6), b decimal, index idx_b (b))")
- tk.MustExec("set sql_mode = ''")
- tk.MustExec("insert t values (1.1, 1.1)")
- tk.MustExec("insert t values (2.2, 2.2)")
- tk.MustExec("insert t values (3.3, 2.7)")
- result := tk.MustQuery("select * from t where a > 1.5")
- result.Check(testkit.Rows("2.200000 2", "3.300000 3"))
- result = tk.MustQuery("select * from t where b > 1.5")
- result.Check(testkit.Rows("2.200000 2", "3.300000 3"))
-
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a time(3), b time, index idx_a (a))")
- tk.MustExec("insert t values ('11:11:11', '11:11:11')")
- tk.MustExec("insert t values ('11:11:12', '11:11:12')")
- tk.MustExec("insert t values ('11:11:13', '11:11:13')")
- result = tk.MustQuery("select * from t where a > '11:11:11.5'")
- result.Check(testkit.Rows("11:11:12.000 11:11:12", "11:11:13.000 11:11:13"))
- result = tk.MustQuery("select * from t where b > '11:11:11.5'")
- result.Check(testkit.Rows("11:11:12.000 11:11:12", "11:11:13.000 11:11:13"))
-}
-
func TestSQLMode(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
@@ -4842,24 +4117,6 @@ func TestSQLMode(t *testing.T) {
tk.MustExec("set @@global.sql_mode = 'STRICT_TRANS_TABLES'")
}
-func TestTableDual(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- result := tk.MustQuery("Select 1")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("Select 1 from dual")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("Select count(*) from dual")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("Select 1 from dual where 1")
- result.Check(testkit.Rows("1"))
-
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a int primary key)")
- tk.MustQuery("select t1.* from t t1, t t2 where t1.a=t2.a and 1=0").Check(testkit.Rows())
-}
-
func TestTableScan(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
@@ -4959,82 +4216,6 @@ func TestClusteredIndexIsPointGet(t *testing.T) {
}
}
-func TestRow(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (c int, d int)")
- tk.MustExec("insert t values (1, 1)")
- tk.MustExec("insert t values (1, 3)")
- tk.MustExec("insert t values (2, 1)")
- tk.MustExec("insert t values (2, 3)")
- result := tk.MustQuery("select * from t where (c, d) < (2,2)")
- result.Check(testkit.Rows("1 1", "1 3", "2 1"))
- result = tk.MustQuery("select * from t where (1,2,3) > (3,2,1)")
- result.Check(testkit.Rows())
- result = tk.MustQuery("select * from t where row(1,2,3) > (3,2,1)")
- result.Check(testkit.Rows())
- result = tk.MustQuery("select * from t where (c, d) = (select * from t where (c,d) = (1,1))")
- result.Check(testkit.Rows("1 1"))
- result = tk.MustQuery("select * from t where (c, d) = (select * from t k where (t.c,t.d) = (c,d))")
- result.Check(testkit.Rows("1 1", "1 3", "2 1", "2 3"))
- result = tk.MustQuery("select (1, 2, 3) < (2, 3, 4)")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select (2, 3, 4) <= (2, 3, 3)")
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery("select (2, 3, 4) <= (2, 3, 4)")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select (2, 3, 4) <= (2, 1, 4)")
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery("select (2, 3, 4) >= (2, 3, 4)")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select (2, 3, 4) = (2, 3, 4)")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select (2, 3, 4) != (2, 3, 4)")
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery("select row(1, 1) in (row(1, 1))")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select row(1, 0) in (row(1, 1))")
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery("select row(1, 1) in (select 1, 1)")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select row(1, 1) > row(1, 0)")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select row(1, 1) > (select 1, 0)")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select 1 > (select 1)")
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery("select (select 1)")
- result.Check(testkit.Rows("1"))
-
- tk.MustExec("drop table if exists t1")
- tk.MustExec("create table t1 (a int, b int)")
- tk.MustExec("insert t1 values (1,2),(1,null)")
- tk.MustExec("drop table if exists t2")
- tk.MustExec("create table t2 (c int, d int)")
- tk.MustExec("insert t2 values (0,0)")
-
- tk.MustQuery("select * from t2 where (1,2) in (select * from t1)").Check(testkit.Rows("0 0"))
- tk.MustQuery("select * from t2 where (1,2) not in (select * from t1)").Check(testkit.Rows())
- tk.MustQuery("select * from t2 where (1,1) not in (select * from t1)").Check(testkit.Rows())
- tk.MustQuery("select * from t2 where (1,null) in (select * from t1)").Check(testkit.Rows())
- tk.MustQuery("select * from t2 where (null,null) in (select * from t1)").Check(testkit.Rows())
-
- tk.MustExec("delete from t1 where a=1 and b=2")
- tk.MustQuery("select (1,1) in (select * from t2) from t1").Check(testkit.Rows("0"))
- tk.MustQuery("select (1,1) not in (select * from t2) from t1").Check(testkit.Rows("1"))
- tk.MustQuery("select (1,1) in (select 1,1 from t2) from t1").Check(testkit.Rows("1"))
- tk.MustQuery("select (1,1) not in (select 1,1 from t2) from t1").Check(testkit.Rows("0"))
-
- // MySQL 5.7 returns 1 for these 2 queries, which is wrong.
- tk.MustQuery("select (1,null) not in (select 1,1 from t2) from t1").Check(testkit.Rows(""))
- tk.MustQuery("select (t1.a,null) not in (select 1,1 from t2) from t1").Check(testkit.Rows(""))
-
- tk.MustQuery("select (1,null) in (select * from t1)").Check(testkit.Rows(""))
- tk.MustQuery("select (1,null) not in (select * from t1)").Check(testkit.Rows(""))
-}
-
func TestColumnName(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
@@ -5369,54 +4550,6 @@ func TestCurrentTimestampValueSelection(t *testing.T) {
require.Equal(t, 3, len(strings.Split(d, ".")[1]))
}
-func TestStrToDateBuiltin(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustQuery(`select str_to_date('20190101','%Y%m%d%!') from dual`).Check(testkit.Rows("2019-01-01"))
- tk.MustQuery(`select str_to_date('20190101','%Y%m%d%f') from dual`).Check(testkit.Rows("2019-01-01 00:00:00.000000"))
- tk.MustQuery(`select str_to_date('20190101','%Y%m%d%H%i%s') from dual`).Check(testkit.Rows("2019-01-01 00:00:00"))
- tk.MustQuery(`select str_to_date('18/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`select str_to_date('a18/10/22','%y/%m/%d') from dual`).Check(testkit.Rows(""))
- tk.MustQuery(`select str_to_date('69/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("2069-10-22"))
- tk.MustQuery(`select str_to_date('70/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("1970-10-22"))
- tk.MustQuery(`select str_to_date('8/10/22','%y/%m/%d') from dual`).Check(testkit.Rows("2008-10-22"))
- tk.MustQuery(`select str_to_date('8/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("2008-10-22"))
- tk.MustQuery(`select str_to_date('18/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`select str_to_date('a18/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows(""))
- tk.MustQuery(`select str_to_date('69/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("2069-10-22"))
- tk.MustQuery(`select str_to_date('70/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("1970-10-22"))
- tk.MustQuery(`select str_to_date('018/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("0018-10-22"))
- tk.MustQuery(`select str_to_date('2018/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`select str_to_date('018/10/22','%y/%m/%d') from dual`).Check(testkit.Rows(""))
- tk.MustQuery(`select str_to_date('18/10/22','%y0/%m/%d') from dual`).Check(testkit.Rows(""))
- tk.MustQuery(`select str_to_date('18/10/22','%Y0/%m/%d') from dual`).Check(testkit.Rows(""))
- tk.MustQuery(`select str_to_date('18a/10/22','%y/%m/%d') from dual`).Check(testkit.Rows(""))
- tk.MustQuery(`select str_to_date('18a/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows(""))
- tk.MustQuery(`select str_to_date('20188/10/22','%Y/%m/%d') from dual`).Check(testkit.Rows(""))
- tk.MustQuery(`select str_to_date('2018510522','%Y5%m5%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`select str_to_date('2018^10^22','%Y^%m^%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`select str_to_date('2018@10@22','%Y@%m@%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`select str_to_date('2018%10%22','%Y%%m%%d') from dual`).Check(testkit.Rows(""))
- tk.MustQuery(`select str_to_date('2018(10(22','%Y(%m(%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`select str_to_date('2018\10\22','%Y\%m\%d') from dual`).Check(testkit.Rows(""))
- tk.MustQuery(`select str_to_date('2018=10=22','%Y=%m=%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`select str_to_date('2018+10+22','%Y+%m+%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`select str_to_date('2018_10_22','%Y_%m_%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`select str_to_date('69510522','%y5%m5%d') from dual`).Check(testkit.Rows("2069-10-22"))
- tk.MustQuery(`select str_to_date('69^10^22','%y^%m^%d') from dual`).Check(testkit.Rows("2069-10-22"))
- tk.MustQuery(`select str_to_date('18@10@22','%y@%m@%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`select str_to_date('18%10%22','%y%%m%%d') from dual`).Check(testkit.Rows(""))
- tk.MustQuery(`select str_to_date('18(10(22','%y(%m(%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`select str_to_date('18\10\22','%y\%m\%d') from dual`).Check(testkit.Rows(""))
- tk.MustQuery(`select str_to_date('18+10+22','%y+%m+%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`select str_to_date('18=10=22','%y=%m=%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`select str_to_date('18_10_22','%y_%m_%d') from dual`).Check(testkit.Rows("2018-10-22"))
- tk.MustQuery(`SELECT STR_TO_DATE('2020-07-04 11:22:33 PM', '%Y-%m-%d %r')`).Check(testkit.Rows("2020-07-04 23:22:33"))
- tk.MustQuery(`SELECT STR_TO_DATE('2020-07-04 12:22:33 AM', '%Y-%m-%d %r')`).Check(testkit.Rows("2020-07-04 00:22:33"))
- tk.MustQuery(`SELECT STR_TO_DATE('2020-07-04 12:22:33', '%Y-%m-%d %T')`).Check(testkit.Rows("2020-07-04 12:22:33"))
- tk.MustQuery(`SELECT STR_TO_DATE('2020-07-04 00:22:33', '%Y-%m-%d %T')`).Check(testkit.Rows("2020-07-04 00:22:33"))
-}
-
func TestAddDateBuiltinWithWarnings(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
@@ -5445,39 +4578,6 @@ func TestStrToDateBuiltinWithWarnings(t *testing.T) {
))
}
-func TestReadPartitionedTable(t *testing.T) {
- // Test three reader on partitioned table.
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists pt")
- tk.MustExec("create table pt (a int, b int, index i_b(b)) partition by range (a) (partition p1 values less than (2), partition p2 values less than (4), partition p3 values less than (6))")
- for i := 0; i < 6; i++ {
- tk.MustExec(fmt.Sprintf("insert into pt values(%d, %d)", i, i))
- }
- // Table reader
- tk.MustQuery("select * from pt order by a").Check(testkit.Rows("0 0", "1 1", "2 2", "3 3", "4 4", "5 5"))
- // Index reader
- tk.MustQuery("select b from pt where b = 3").Check(testkit.Rows("3"))
- // Index lookup
- tk.MustQuery("select a from pt where b = 3").Check(testkit.Rows("3"))
-}
-
-func TestIssue10435(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t1")
- tk.MustExec("create table t1(i int, j int, k int)")
- tk.MustExec("insert into t1 VALUES (1,1,1),(2,2,2),(3,3,3),(4,4,4)")
- tk.MustExec("INSERT INTO t1 SELECT 10*i,j,5*j FROM t1 UNION SELECT 20*i,j,5*j FROM t1 UNION SELECT 30*i,j,5*j FROM t1")
-
- tk.MustExec("set @@session.tidb_enable_window_function=1")
- tk.MustQuery("SELECT SUM(i) OVER W FROM t1 WINDOW w AS (PARTITION BY j ORDER BY i) ORDER BY 1+SUM(i) OVER w").Check(
- testkit.Rows("1", "2", "3", "4", "11", "22", "31", "33", "44", "61", "62", "93", "122", "124", "183", "244"),
- )
-}
-
func TestAdmin(t *testing.T) {
var cluster testutils.Cluster
store := testkit.CreateMockStore(t, mockstore.WithClusterInspector(func(c testutils.Cluster) {
@@ -5811,30 +4911,6 @@ func TestUnsignedDecimalOverflow(t *testing.T) {
tk.MustQuery("select a from t limit 1").Check(testkit.Rows("0.00"))
}
-func TestIndexJoinTableDualPanic(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists a")
- tk.MustExec("create table a (f1 int, f2 varchar(32), primary key (f1))")
- tk.MustExec("insert into a (f1,f2) values (1,'a'), (2,'b'), (3,'c')")
- // TODO here: index join cause the data race of txn.
- tk.MustQuery("select /*+ inl_merge_join(a) */ a.* from a inner join (select 1 as k1,'k2-1' as k2) as k on a.f1=k.k1;").
- Check(testkit.Rows("1 a"))
-}
-
-func TestSortLeftJoinWithNullColumnInRightChildPanic(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t1, t2")
- tk.MustExec("create table t1(a int)")
- tk.MustExec("create table t2(a int)")
- tk.MustExec("insert into t1(a) select 1;")
- tk.MustQuery("select b.n from t1 left join (select a as a, null as n from t2) b on b.a = t1.a order by t1.a").
- Check(testkit.Rows(""))
-}
-
func TestMaxOneRow(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
@@ -6220,37 +5296,6 @@ func TestSessionRootTrackerDetach(t *testing.T) {
require.Nil(t, tk.Session().GetSessionVars().MemTracker.GetFallbackForTest(false))
}
-func TestIssue39211(t *testing.T) {
- store := testkit.CreateMockStore(t)
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t;")
- tk.MustExec("drop table if exists s;")
-
- tk.MustExec("CREATE TABLE `t` ( `a` int(11) DEFAULT NULL, `b` int(11) DEFAULT NULL);")
- tk.MustExec("CREATE TABLE `s` ( `a` int(11) DEFAULT NULL, `b` int(11) DEFAULT NULL);")
- tk.MustExec("insert into t values(1,1),(2,2);")
- tk.MustExec("insert into t select * from t;")
- tk.MustExec("insert into t select * from t;")
- tk.MustExec("insert into t select * from t;")
- tk.MustExec("insert into t select * from t;")
- tk.MustExec("insert into t select * from t;")
- tk.MustExec("insert into t select * from t;")
- tk.MustExec("insert into t select * from t;")
- tk.MustExec("insert into t select * from t;")
-
- tk.MustExec("insert into s values(3,3),(4,4),(1,null),(2,null),(null,null);")
- tk.MustExec("insert into s select * from s;")
- tk.MustExec("insert into s select * from s;")
- tk.MustExec("insert into s select * from s;")
- tk.MustExec("insert into s select * from s;")
- tk.MustExec("insert into s select * from s;")
-
- tk.MustExec("set @@tidb_max_chunk_size=32;")
- tk.MustExec("set @@tidb_enable_null_aware_anti_join=true;")
- tk.MustQuery("select * from t where (a,b) not in (select a, b from s);").Check(testkit.Rows())
-}
-
func TestPlanReplayerDumpTPCDS(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
diff --git a/executor/test/indexmergereadtest/index_merge_reader_test.go b/executor/test/indexmergereadtest/index_merge_reader_test.go
index 6e37c7e6920b9..53930ab8a7d4d 100644
--- a/executor/test/indexmergereadtest/index_merge_reader_test.go
+++ b/executor/test/indexmergereadtest/index_merge_reader_test.go
@@ -1072,44 +1072,44 @@ func TestOrderByWithLimit(t *testing.T) {
limit := rand.Intn(10) + 1
queryHandle := fmt.Sprintf("select /*+ use_index_merge(thandle, idx_ac, idx_bc) */ * from thandle where a = %v or b = %v order by c limit %v", a, b, limit)
resHandle := tk.MustQuery(queryHandle).Rows()
- require.True(t, tk.HasPlan(queryHandle, "IndexMerge"))
- require.False(t, tk.HasPlan(queryHandle, "TopN"))
+ tk.MustHavePlan(queryHandle, "IndexMerge")
+ tk.MustNotHavePlan(queryHandle, "TopN")
queryPK := fmt.Sprintf("select /*+ use_index_merge(tpk, idx_ac, idx_bc) */ * from tpk where a = %v or b = %v order by c limit %v", a, b, limit)
resPK := tk.MustQuery(queryPK).Rows()
- require.True(t, tk.HasPlan(queryPK, "IndexMerge"))
- require.False(t, tk.HasPlan(queryPK, "TopN"))
+ tk.MustHavePlan(queryPK, "IndexMerge")
+ tk.MustNotHavePlan(queryPK, "TopN")
queryCommon := fmt.Sprintf("select /*+ use_index_merge(tcommon, idx_ac, idx_bc) */ * from tcommon where a = %v or b = %v order by c limit %v", a, b, limit)
resCommon := tk.MustQuery(queryCommon).Rows()
- require.True(t, tk.HasPlan(queryCommon, "IndexMerge"))
- require.False(t, tk.HasPlan(queryCommon, "TopN"))
+ tk.MustHavePlan(queryCommon, "IndexMerge")
+ tk.MustNotHavePlan(queryCommon, "TopN")
queryTableScan := fmt.Sprintf("select /*+ use_index_merge(tcommon, primary, idx_bc) */ * from tcommon where a = %v or b = %v order by c limit %v", a, b, limit)
resTableScan := tk.MustQuery(queryTableScan).Rows()
- require.True(t, tk.HasPlan(queryTableScan, "IndexMerge"))
- require.True(t, tk.HasPlan(queryTableScan, "TableRangeScan"))
- require.False(t, tk.HasPlan(queryTableScan, "TopN"))
+ tk.MustHavePlan(queryTableScan, "IndexMerge")
+ tk.MustHavePlan(queryTableScan, "TableRangeScan")
+ tk.MustNotHavePlan(queryTableScan, "TopN")
queryHash := fmt.Sprintf("select /*+ use_index_merge(thash, idx_ac, idx_bc) */ * from thash where a = %v or b = %v order by c limit %v", a, b, limit)
resHash := tk.MustQuery(queryHash).Rows()
- require.True(t, tk.HasPlan(queryHash, "IndexMerge"))
+ tk.MustHavePlan(queryHash, "IndexMerge")
if i%2 == 1 {
- require.False(t, tk.HasPlan(queryHash, "TopN"))
+ tk.MustNotHavePlan(queryHash, "TopN")
}
queryCommonHash := fmt.Sprintf("select /*+ use_index_merge(tcommonhash, primary, idx_bc) */ * from tcommonhash where a = %v or b = %v order by c limit %v", a, b, limit)
resCommonHash := tk.MustQuery(queryCommonHash).Rows()
- require.True(t, tk.HasPlan(queryCommonHash, "IndexMerge"))
+ tk.MustHavePlan(queryCommonHash, "IndexMerge")
if i%2 == 1 {
- require.False(t, tk.HasPlan(queryCommonHash, "TopN"))
+ tk.MustNotHavePlan(queryCommonHash, "TopN")
}
queryPKHash := fmt.Sprintf("select /*+ use_index_merge(tpkhash, idx_ac, idx_bc) */ * from tpkhash where a = %v or b = %v order by c limit %v", a, b, limit)
resPKHash := tk.MustQuery(queryPKHash).Rows()
- require.True(t, tk.HasPlan(queryPKHash, "IndexMerge"))
+ tk.MustHavePlan(queryPKHash, "IndexMerge")
if i%2 == 1 {
- require.False(t, tk.HasPlan(queryPKHash, "TopN"))
+ tk.MustNotHavePlan(queryPKHash, "TopN")
}
sliceRes := getResult(valueSlice, a, b, limit, false)
@@ -1205,9 +1205,9 @@ func TestIndexMergeLimitPushedAsIntersectionEmbeddedLimit(t *testing.T) {
valA, valB, valC, limit := rand.Intn(100), rand.Intn(100), rand.Intn(50), rand.Intn(100)+1
queryTableScan := fmt.Sprintf("select * from t use index() where a > %d and b > %d and c >= %d limit %d", valA, valB, valC, limit)
queryWithIndexMerge := fmt.Sprintf("select /*+ USE_INDEX_MERGE(t, idx, idx2) */ * from t where a > %d and b > %d and c >= %d limit %d", valA, valB, valC, limit)
- require.True(t, tk.HasPlan(queryWithIndexMerge, "IndexMerge"))
+ tk.MustHavePlan(queryWithIndexMerge, "IndexMerge")
require.True(t, tk.HasKeywordInOperatorInfo(queryWithIndexMerge, "limit embedded"))
- require.True(t, tk.HasPlan(queryTableScan, "TableFullScan"))
+ tk.MustHavePlan(queryTableScan, "TableFullScan")
// index merge with embedded limit couldn't compare the exactly results with normal plan, because limit admission control has some difference, while we can only check
// the row count is exactly the same with tableFullScan plan, in case of index pushedLimit and table pushedLimit cut down the source table rows.
require.Equal(t, len(tk.MustQuery(queryWithIndexMerge).Rows()), len(tk.MustQuery(queryTableScan).Rows()))
@@ -1232,8 +1232,8 @@ func TestIndexMergeLimitNotPushedOnPartialSideButKeepOrder(t *testing.T) {
maxEle := tk.MustQuery(fmt.Sprintf("select ifnull(max(c), 100) from (select c from t use index(idx3) where (a = %d or b = %d) and c >= %d order by c limit %d) t", valA, valB, valC, limit)).Rows()[0][0]
queryWithIndexMerge := fmt.Sprintf("select /*+ USE_INDEX_MERGE(t, idx, idx2) */ * from t where (a = %d or b = %d) and c >= %d and c < greatest(%d, %v) order by c limit %d", valA, valB, valC, valC+1, maxEle, limit)
queryWithNormalIndex := fmt.Sprintf("select * from t use index(idx3) where (a = %d or b = %d) and c >= %d and c < greatest(%d, %v) order by c limit %d", valA, valB, valC, valC+1, maxEle, limit)
- require.True(t, tk.HasPlan(queryWithIndexMerge, "IndexMerge"))
- require.True(t, tk.HasPlan(queryWithIndexMerge, "Limit"))
+ tk.MustHavePlan(queryWithIndexMerge, "IndexMerge")
+ tk.MustHavePlan(queryWithIndexMerge, "Limit")
normalResult := tk.MustQuery(queryWithNormalIndex).Sort().Rows()
tk.MustQuery(queryWithIndexMerge).Sort().Check(normalResult)
}
@@ -1242,8 +1242,8 @@ func TestIndexMergeLimitNotPushedOnPartialSideButKeepOrder(t *testing.T) {
maxEle := tk.MustQuery(fmt.Sprintf("select ifnull(max(c), 100) from (select c from t use index(idx3) where (a = %d or b = %d) and c >= %d order by c limit %d offset %d) t", valA, valB, valC, limit, offset)).Rows()[0][0]
queryWithIndexMerge := fmt.Sprintf("select /*+ USE_INDEX_MERGE(t, idx, idx2) */ c from t where (a = %d or b = %d) and c >= %d and c < greatest(%d, %v) order by c limit %d offset %d", valA, valB, valC, valC+1, maxEle, limit, offset)
queryWithNormalIndex := fmt.Sprintf("select c from t use index(idx3) where (a = %d or b = %d) and c >= %d and c < greatest(%d, %v) order by c limit %d offset %d", valA, valB, valC, valC+1, maxEle, limit, offset)
- require.True(t, tk.HasPlan(queryWithIndexMerge, "IndexMerge"))
- require.True(t, tk.HasPlan(queryWithIndexMerge, "Limit"))
+ tk.MustHavePlan(queryWithIndexMerge, "IndexMerge")
+ tk.MustHavePlan(queryWithIndexMerge, "Limit")
normalResult := tk.MustQuery(queryWithNormalIndex).Sort().Rows()
tk.MustQuery(queryWithIndexMerge).Sort().Check(normalResult)
}
@@ -1256,8 +1256,8 @@ func TestIndexMergeNoOrderLimitPushed(t *testing.T) {
tk.MustExec("create table t(a int, b int, c int, index idx(a, c), index idx2(b, c))")
tk.MustExec("insert into t values(1, 1, 1), (2, 2, 2)")
sql := "select /*+ USE_INDEX_MERGE(t, idx, idx2) */ * from t where a = 1 or b = 1 limit 1"
- require.True(t, tk.HasPlan(sql, "IndexMerge"))
- require.True(t, tk.HasPlan(sql, "Limit"))
+ tk.MustHavePlan(sql, "IndexMerge")
+ tk.MustHavePlan(sql, "Limit")
// 6 means that IndexMerge(embedded limit){Limit->PartialIndexScan, Limit->PartialIndexScan, FinalTableScan}
require.Equal(t, 6, len(tk.MustQuery("explain "+sql).Rows()))
// The result is not stable. So we just check that it can run successfully.
@@ -1273,15 +1273,15 @@ func TestIndexMergeKeepOrderDirtyRead(t *testing.T) {
tk.MustExec("begin")
tk.MustExec("insert into t values(1, 1, -3)")
querySQL := "select /*+ USE_INDEX_MERGE(t, idx1, idx2) */ * from t where a = 1 or b = 1 order by c limit 2"
- tk.HasPlan(querySQL, "Limit")
- tk.HasPlan(querySQL, "IndexMerge")
+ tk.MustHavePlan(querySQL, "Limit")
+ tk.MustHavePlan(querySQL, "IndexMerge")
tk.MustQuery(querySQL).Check(testkit.Rows("1 1 -3", "2 1 -2"))
tk.MustExec("rollback")
tk.MustExec("begin")
tk.MustExec("insert into t values(1, 2, 4)")
querySQL = "select /*+ USE_INDEX_MERGE(t, idx1, idx2) */ * from t where a = 1 or b = 1 order by c desc limit 2"
- tk.HasPlan(querySQL, "Limit")
- tk.HasPlan(querySQL, "IndexMerge")
+ tk.MustHavePlan(querySQL, "Limit")
+ tk.MustHavePlan(querySQL, "IndexMerge")
tk.MustQuery(querySQL).Check(testkit.Rows("1 2 4", "1 1 1"))
tk.MustExec("rollback")
}
diff --git a/executor/test/loadremotetest/error_test.go b/executor/test/loadremotetest/error_test.go
index 330486049a97a..beb941ed96470 100644
--- a/executor/test/loadremotetest/error_test.go
+++ b/executor/test/loadremotetest/error_test.go
@@ -54,10 +54,10 @@ func (s *mockGCSSuite) TestErrorMessage() {
checkClientErrorMessage(s.T(), err, "ERROR 1054 (42S22): Unknown column 'wrong' in 'field list'")
err = s.tk.ExecToErr("LOAD DATA INFILE 'abc://1' INTO TABLE t;")
checkClientErrorMessage(s.T(), err,
- "ERROR 8158 (HY000): The URI of file location is invalid. Reason: storage abc not support yet. Please provide a valid URI, such as 's3://import/test.csv?access_key_id={your_access_key_id ID}&secret_access_key={your_secret_access_key}&session_token={your_session_token}'")
+ "ERROR 8158 (HY000): The URI of data source is invalid. Reason: storage abc not support yet. Please provide a valid URI, such as 's3://import/test.csv?access_key_id={your_access_key_id ID}&secret_access_key={your_secret_access_key}&session_token={your_session_token}'")
err = s.tk.ExecToErr("LOAD DATA INFILE 's3://no-network' INTO TABLE t;")
checkClientErrorMessage(s.T(), err,
- "ERROR 8159 (HY000): Access to the source file has been denied. Reason: failed to get region of bucket no-network. Please check the URI, access key and secret access key are correct")
+ "ERROR 8159 (HY000): Access to the data source has been denied. Reason: failed to get region of bucket no-network. Please check the URI, access key and secret access key are correct")
err = s.tk.ExecToErr(fmt.Sprintf(`LOAD DATA INFILE 'gs://wrong-bucket/p?endpoint=%s'
INTO TABLE t;`, gcsEndpoint))
checkClientErrorMessage(s.T(), err,
diff --git a/executor/test/showtest/show_test.go b/executor/test/showtest/show_test.go
index 607806ef1c61a..16fa29789684e 100644
--- a/executor/test/showtest/show_test.go
+++ b/executor/test/showtest/show_test.go
@@ -917,8 +917,8 @@ func TestIssue10549(t *testing.T) {
require.NoError(t, tk.Session().Auth(&auth.UserIdentity{Username: "dev", Hostname: "%", AuthUsername: "dev", AuthHostname: "%"}, nil, nil, nil))
tk.MustQuery("SHOW DATABASES;").Check(testkit.Rows("INFORMATION_SCHEMA", "newdb"))
- tk.MustQuery("SHOW GRANTS;").Check(testkit.Rows("GRANT USAGE ON *.* TO 'dev'@'%'", "GRANT ALL PRIVILEGES ON newdb.* TO 'dev'@'%'", "GRANT 'app_developer'@'%' TO 'dev'@'%'"))
- tk.MustQuery("SHOW GRANTS FOR CURRENT_USER").Check(testkit.Rows("GRANT USAGE ON *.* TO 'dev'@'%'", "GRANT ALL PRIVILEGES ON newdb.* TO 'dev'@'%'", "GRANT 'app_developer'@'%' TO 'dev'@'%'"))
+ tk.MustQuery("SHOW GRANTS;").Check(testkit.Rows("GRANT USAGE ON *.* TO 'dev'@'%'", "GRANT ALL PRIVILEGES ON `newdb`.* TO 'dev'@'%'", "GRANT 'app_developer'@'%' TO 'dev'@'%'"))
+ tk.MustQuery("SHOW GRANTS FOR CURRENT_USER").Check(testkit.Rows("GRANT USAGE ON *.* TO 'dev'@'%'", "GRANT ALL PRIVILEGES ON `newdb`.* TO 'dev'@'%'", "GRANT 'app_developer'@'%' TO 'dev'@'%'"))
tk.MustQuery("SHOW GRANTS FOR dev").Check(testkit.Rows("GRANT USAGE ON *.* TO 'dev'@'%'", "GRANT 'app_developer'@'%' TO 'dev'@'%'"))
}
diff --git a/executor/test/simpletest/simple_test.go b/executor/test/simpletest/simple_test.go
index a941b4609a68b..badb1821f88b0 100644
--- a/executor/test/simpletest/simple_test.go
+++ b/executor/test/simpletest/simple_test.go
@@ -264,9 +264,9 @@ func TestPrivilegesAfterDropUser(t *testing.T) {
tk.Session().Auth(&auth.UserIdentity{Username: "root", Hostname: "%"}, nil, nil, nil)
tk.MustQuery("SHOW GRANTS FOR u1").Check(testkit.Rows(
"GRANT USAGE ON *.* TO 'u1'@'%'",
- "GRANT CREATE ON test.* TO 'u1'@'%'",
- "GRANT UPDATE ON test.t1 TO 'u1'@'%'",
- "GRANT SELECT(v), UPDATE(v) ON test.t1 TO 'u1'@'%'",
+ "GRANT CREATE ON `test`.* TO 'u1'@'%'",
+ "GRANT UPDATE ON `test`.`t1` TO 'u1'@'%'",
+ "GRANT SELECT(v), UPDATE(v) ON `test`.`t1` TO 'u1'@'%'",
"GRANT SYSTEM_VARIABLES_ADMIN ON *.* TO 'u1'@'%'",
))
diff --git a/executor/test/tiflashtest/tiflash_test.go b/executor/test/tiflashtest/tiflash_test.go
index bfc8d1caf8ff2..0615208314125 100644
--- a/executor/test/tiflashtest/tiflash_test.go
+++ b/executor/test/tiflashtest/tiflash_test.go
@@ -990,7 +990,7 @@ func TestTiFlashPartitionTableShuffledHashAggregation(t *testing.T) {
tk.MustExec(fmt.Sprintf("set @@tidb_partition_prune_mode = '%v'", mode))
for _, tbl := range []string{`thash`, `trange`, `tlist`, `tnormal`} {
q := fmt.Sprintf("select /*+ HASH_AGG() */ count(*) from %v t1 where %v", tbl, cond)
- require.True(t, tk.HasPlan(q, "HashAgg"))
+ tk.MustHavePlan(q, "HashAgg")
if res == nil {
res = tk.MustQuery(q).Sort().Rows()
} else {
diff --git a/executor/union_scan_test.go b/executor/union_scan_test.go
index efe5d1562e97f..a4b7932b49b13 100644
--- a/executor/union_scan_test.go
+++ b/executor/union_scan_test.go
@@ -531,29 +531,22 @@ func TestIssue32422(t *testing.T) {
tk.MustQuery("select id+1, c from t where c = 4;").Check(testkit.Rows("5 4"))
tk.MustExec("insert into t values (6, 6)")
// Check for the new added data.
- tk.HasPlan("select id+1, c from t where c = 6;", "UnionScan")
+ tk.MustHavePlan("select id+1, c from t where c = 6;", "UnionScan")
tk.MustQuery("select id+1, c from t where c = 6;").Check(testkit.Rows("7 6"))
require.True(t, tk.Session().GetSessionVars().StmtCtx.ReadFromTableCache)
// Check for the old data.
tk.MustQuery("select id+1, c from t where c = 4;").Check(testkit.Rows("5 4"))
require.True(t, tk.Session().GetSessionVars().StmtCtx.ReadFromTableCache)
- // Point get
- tk.HasPlan("select id+1, c from t where id = 6", "PointGet")
- tk.MustQuery("select id+1, c from t where id = 6").Check(testkit.Rows("7 6"))
- require.True(t, tk.Session().GetSessionVars().StmtCtx.ReadFromTableCache)
- tk.MustQuery("select id+1, c from t where id = 4").Check(testkit.Rows("5 4"))
- require.True(t, tk.Session().GetSessionVars().StmtCtx.ReadFromTableCache)
-
// Index Lookup
- tk.HasPlan("select id+1, c from t where id = 6", "IndexLookUp")
+ tk.MustHavePlan("select id+1, c from t where id = 6", "IndexLookUp")
tk.MustQuery("select id+1, c from t use index(id) where id = 6").Check(testkit.Rows("7 6"))
require.True(t, tk.Session().GetSessionVars().StmtCtx.ReadFromTableCache)
tk.MustQuery("select id+1, c from t use index(id) where id = 4").Check(testkit.Rows("5 4"))
require.True(t, tk.Session().GetSessionVars().StmtCtx.ReadFromTableCache)
// Index Reader
- tk.HasPlan("select id from t where id = 6", "IndexReader")
+ tk.MustHavePlan("select id from t where id = 6", "IndexReader")
tk.MustQuery("select id from t use index(id) where id = 6").Check(testkit.Rows("6"))
require.True(t, tk.Session().GetSessionVars().StmtCtx.ReadFromTableCache)
tk.MustQuery("select id from t use index(id) where id = 4").Check(testkit.Rows("4"))
@@ -624,7 +617,7 @@ func BenchmarkUnionScanIndexReadDescRead(b *testing.B) {
tk.MustExec(fmt.Sprintf("insert into t values (%d, %d, %d)", i, i, i))
}
- tk.HasPlan("select b from t use index(k) where b > 50 order by b desc", "IndexReader")
+ tk.MustHavePlan("select b from t use index(k) where b > 50 order by b desc", "IndexReader")
b.ReportAllocs()
b.ResetTimer()
@@ -646,7 +639,7 @@ func BenchmarkUnionScanTableReadDescRead(b *testing.B) {
tk.MustExec(fmt.Sprintf("insert into t values (%d, %d, %d)", i, i, i))
}
- tk.HasPlan("select * from t where a > 50 order by a desc", "TableReader")
+ tk.MustHavePlan("select * from t where a > 50 order by a desc", "TableReader")
b.ReportAllocs()
b.ResetTimer()
@@ -668,7 +661,7 @@ func BenchmarkUnionScanIndexLookUpDescRead(b *testing.B) {
tk.MustExec(fmt.Sprintf("insert into t values (%d, %d, %d)", i, i, i))
}
- tk.HasPlan("select * from t use index(k) where b > 50 order by b desc", "IndexLookUp")
+ tk.MustHavePlan("select * from t use index(k) where b > 50 order by b desc", "IndexLookUp")
b.ReportAllocs()
b.ResetTimer()
diff --git a/executor/window.go b/executor/window.go
index 31d95a69e11ed..2fb4cc76ded01 100644
--- a/executor/window.go
+++ b/executor/window.go
@@ -399,7 +399,7 @@ func (p *rangeFrameWindowProcessor) getStartOffset(ctx sessionctx.Context, rows
var res int64
var err error
for i := range p.orderByCols {
- res, _, err = p.start.CmpFuncs[i](ctx, p.orderByCols[i], p.start.CalcFuncs[i], rows[p.lastStartOffset], rows[p.curRowIdx])
+ res, _, err = p.start.CmpFuncs[i](ctx, p.start.CompareCols[i], p.start.CalcFuncs[i], rows[p.lastStartOffset], rows[p.curRowIdx])
if err != nil {
return 0, err
}
@@ -425,7 +425,7 @@ func (p *rangeFrameWindowProcessor) getEndOffset(ctx sessionctx.Context, rows []
var res int64
var err error
for i := range p.orderByCols {
- res, _, err = p.end.CmpFuncs[i](ctx, p.end.CalcFuncs[i], p.orderByCols[i], rows[p.curRowIdx], rows[p.lastEndOffset])
+ res, _, err = p.end.CmpFuncs[i](ctx, p.end.CalcFuncs[i], p.end.CompareCols[i], rows[p.curRowIdx], rows[p.lastEndOffset])
if err != nil {
return 0, err
}
diff --git a/executor/write.go b/executor/write.go
index 4b789162ac264..87ae63a0b230a 100644
--- a/executor/write.go
+++ b/executor/write.go
@@ -26,9 +26,11 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/meta/autoid"
"github.com/pingcap/tidb/parser/ast"
+ "github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
+ "github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/tablecodec"
"github.com/pingcap/tidb/types"
@@ -79,23 +81,8 @@ func updateRecord(
// Handle exchange partition
tbl := t.Meta()
- if tbl.ExchangePartitionInfo != nil {
- is := sctx.GetDomainInfoSchema().(infoschema.InfoSchema)
- pt, tableFound := is.TableByID(tbl.ExchangePartitionInfo.ExchangePartitionID)
- if !tableFound {
- return false, errors.Errorf("exchange partition process table by id failed")
- }
- p, ok := pt.(table.PartitionedTable)
- if !ok {
- return false, errors.Errorf("exchange partition process assert table partition failed")
- }
- err := p.CheckForExchangePartition(
- sctx,
- pt.Meta().Partition,
- newData,
- tbl.ExchangePartitionInfo.ExchangePartitionDefID,
- )
- if err != nil {
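+		// This table is the non-partitioned side of an in-progress EXCHANGE PARTITION;
+		// validate the new row against the target partitioned table (see checkRowForExchangePartition below).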
+ if tbl.ExchangePartitionInfo != nil && tbl.GetPartitionInfo() == nil {
+ if err := checkRowForExchangePartition(sctx, newData, tbl); err != nil {
return false, err
}
}
@@ -326,3 +313,42 @@ func resetErrDataTooLong(colName string, rowIdx int, _ error) error {
newErr := types.ErrDataTooLong.GenWithStack("Data too long for column '%v' at row %v", colName, rowIdx)
return newErr
}
+
+// checkRowForExchangePartition is only used for the non-partitioned table side of EXCHANGE PARTITION during the write-only state.
+// It checks whether the inserted or updated row violates the partition definition or the check constraints of the partitioned table.
+func checkRowForExchangePartition(sctx sessionctx.Context, row []types.Datum, tbl *model.TableInfo) error {
+ is := sctx.GetDomainInfoSchema().(infoschema.InfoSchema)
+ pt, tableFound := is.TableByID(tbl.ExchangePartitionInfo.ExchangePartitionTableID)
+ if !tableFound {
+ return errors.Errorf("exchange partition process table by id failed")
+ }
+ p, ok := pt.(table.PartitionedTable)
+ if !ok {
+ return errors.Errorf("exchange partition process assert table partition failed")
+ }
+ err := p.CheckForExchangePartition(
+ sctx,
+ pt.Meta().Partition,
+ row,
+ tbl.ExchangePartitionInfo.ExchangePartitionDefID,
+ tbl.ID,
+ )
+ if err != nil {
+ return err
+ }
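+	// When the check-constraint feature is enabled, the row must also satisfy the partitioned table's check constraints.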
+ if variable.EnableCheckConstraint.Load() {
+ type CheckConstraintTable interface {
+ CheckRowConstraint(sctx sessionctx.Context, rowToCheck []types.Datum) error
+ }
+ cc, ok := pt.(CheckConstraintTable)
+ if !ok {
+ return errors.Errorf("exchange partition process assert check constraint failed")
+ }
+ err := cc.CheckRowConstraint(sctx, row)
+ if err != nil {
+ // TODO: make error include ExchangePartition info.
+ return err
+ }
+ }
+ return nil
+}
diff --git a/expression/BUILD.bazel b/expression/BUILD.bazel
index 11da852e21c9b..6cffb6112fca0 100644
--- a/expression/BUILD.bazel
+++ b/expression/BUILD.bazel
@@ -171,7 +171,6 @@ go_test(
"builtin_vectorized_test.go",
"collation_test.go",
"column_test.go",
- "constant_fold_test.go",
"constant_test.go",
"distsql_builtin_test.go",
"evaluator_test.go",
diff --git a/expression/OWNERS b/expression/OWNERS
new file mode 100644
index 0000000000000..468d81c5e4b5e
--- /dev/null
+++ b/expression/OWNERS
@@ -0,0 +1,5 @@
+# See the OWNERS docs at https://go.k8s.io/owners
+options:
+ no_parent_owners: true
+approvers:
+ - sig-approvers-expression
diff --git a/expression/casetest/BUILD.bazel b/expression/casetest/BUILD.bazel
deleted file mode 100644
index ee76a6ea2fdd1..0000000000000
--- a/expression/casetest/BUILD.bazel
+++ /dev/null
@@ -1,23 +0,0 @@
-load("@io_bazel_rules_go//go:def.bzl", "go_test")
-
-go_test(
- name = "casetest_test",
- timeout = "short",
- srcs = [
- "constant_propagation_test.go",
- "flag_simplify_test.go",
- "main_test.go",
- ],
- data = glob(["testdata/**"]),
- flaky = True,
- deps = [
- "//config",
- "//testkit",
- "//testkit/testdata",
- "//testkit/testmain",
- "//testkit/testsetup",
- "//util/timeutil",
- "@com_github_tikv_client_go_v2//tikv",
- "@org_uber_go_goleak//:goleak",
- ],
-)
diff --git a/expression/casetest/constant_propagation_test.go b/expression/casetest/constant_propagation_test.go
deleted file mode 100644
index 41b584cd2c3c5..0000000000000
--- a/expression/casetest/constant_propagation_test.go
+++ /dev/null
@@ -1,49 +0,0 @@
-// Copyright 2018 PingCAP, Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package casetest
-
-import (
- "testing"
-
- "github.com/pingcap/tidb/testkit"
- "github.com/pingcap/tidb/testkit/testdata"
-)
-
-func TestOuterJoinPropConst(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("set tidb_cost_model_version=2")
- tk.MustExec("drop table if exists t1, t2;")
- tk.MustExec("create table t1(id bigint primary key, a int, b int);")
- tk.MustExec("create table t2(id bigint primary key, a int, b int);")
-
- var input []string
- var output []struct {
- SQL string
- Result []string
- }
-
- expressionSuiteData := GetExpressionSuiteData()
- expressionSuiteData.LoadTestCases(t, &input, &output)
- for i, tt := range input {
- testdata.OnRecord(func() {
- output[i].SQL = tt
- output[i].Result = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
- })
- tk.MustQuery(tt).Check(testkit.Rows(output[i].Result...))
- }
-}
diff --git a/expression/casetest/flag_simplify_test.go b/expression/casetest/flag_simplify_test.go
deleted file mode 100644
index 13f9048bc4f88..0000000000000
--- a/expression/casetest/flag_simplify_test.go
+++ /dev/null
@@ -1,46 +0,0 @@
-// Copyright 2020 PingCAP, Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package casetest
-
-import (
- "testing"
-
- "github.com/pingcap/tidb/testkit"
- "github.com/pingcap/tidb/testkit/testdata"
-)
-
-func TestSimplifyExpressionByFlag(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(id int primary key, a bigint unsigned not null, b bigint unsigned)")
-
- var input []string
- var output []struct {
- SQL string
- Plan []string
- }
- flagSimplifyData := GetFlagSimplifyData()
- flagSimplifyData.LoadTestCases(t, &input, &output)
- for i, tt := range input {
- testdata.OnRecord(func() {
- output[i].SQL = tt
- output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows())
- })
- tk.MustQuery(tt).Check(testkit.Rows(output[i].Plan...))
- }
-}
diff --git a/expression/casetest/main_test.go b/expression/casetest/main_test.go
deleted file mode 100644
index b027d0f1de5d6..0000000000000
--- a/expression/casetest/main_test.go
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright 2023 PingCAP, Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package casetest
-
-import (
- "testing"
-
- "github.com/pingcap/tidb/config"
- "github.com/pingcap/tidb/testkit/testdata"
- "github.com/pingcap/tidb/testkit/testmain"
- "github.com/pingcap/tidb/testkit/testsetup"
- "github.com/pingcap/tidb/util/timeutil"
- "github.com/tikv/client-go/v2/tikv"
- "go.uber.org/goleak"
-)
-
-var testDataMap = make(testdata.BookKeeper)
-
-func TestMain(m *testing.M) {
- testsetup.SetupForCommonTest()
- testmain.ShortCircuitForBench(m)
-
- config.UpdateGlobal(func(conf *config.Config) {
- conf.TiKVClient.AsyncCommit.SafeWindow = 0
- conf.TiKVClient.AsyncCommit.AllowedClockDrift = 0
- conf.Experimental.AllowsExpressionIndex = true
- })
- tikv.EnableFailpoints()
-
- // Some test depends on the values of timeutil.SystemLocation()
- // If we don't SetSystemTZ() here, the value would change unpredictable.
- // Affected by the order whether a testsuite runs before or after integration test.
- // Note, SetSystemTZ() is a sync.Once operation.
- timeutil.SetSystemTZ("system")
-
- testDataMap.LoadTestSuiteData("testdata", "flag_simplify")
- testDataMap.LoadTestSuiteData("testdata", "expression_suite")
-
- opts := []goleak.Option{
- goleak.IgnoreTopFunction("github.com/golang/glog.(*fileSink).flushDaemon"),
- goleak.IgnoreTopFunction("github.com/lestrrat-go/httprc.runFetchWorker"),
- goleak.IgnoreTopFunction("go.etcd.io/etcd/client/pkg/v3/logutil.(*MergeLogger).outputLoop"),
- goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
- }
-
- callback := func(i int) int {
- testDataMap.GenerateOutputIfNeeded()
- return i
- }
- goleak.VerifyTestMain(testmain.WrapTestingM(m, callback), opts...)
-}
-
-func GetFlagSimplifyData() testdata.TestData {
- return testDataMap["flag_simplify"]
-}
-
-func GetExpressionSuiteData() testdata.TestData {
- return testDataMap["expression_suite"]
-}
diff --git a/expression/casetest/testdata/expression_suite_in.json b/expression/casetest/testdata/expression_suite_in.json
deleted file mode 100644
index a059793e76982..0000000000000
--- a/expression/casetest/testdata/expression_suite_in.json
+++ /dev/null
@@ -1,41 +0,0 @@
-[
- {
- "name": "TestOuterJoinPropConst",
- "cases": [
- // Positive tests.
- "explain format = 'brief' select * from t1 left join t2 on t1.a > t2.a and t1.a = 1",
- "explain format = 'brief' select * from t1 left join t2 on t1.a > t2.a where t1.a = 1",
- "explain format = 'brief' select * from t1 left join t2 on t1.a = t2.a and t1.a > 1",
- "explain format = 'brief' select * from t1 left join t2 on t1.a = t2.a where t1.a > 1",
- "explain format = 'brief' select * from t1 right join t2 on t1.a > t2.a where t2.a = 1",
- "explain format = 'brief' select * from t1 right join t2 on t1.a = t2.a where t2.a > 1",
- "explain format = 'brief' select * from t1 right join t2 on t1.a = t2.a and t2.a > 1",
- "explain format = 'brief' select * from t1 right join t2 on t1.a > t2.a and t2.a = 1",
- // Negative tests.
- "explain format = 'brief' select * from t1 left join t2 on t1.a = t2.a and t2.a > 1",
- "explain format = 'brief' select * from t1 left join t2 on t1.a > t2.a and t2.a = 1",
- "explain format = 'brief' select * from t1 right join t2 on t1.a > t2.a and t1.a = 1",
- "explain format = 'brief' select * from t1 right join t2 on t1.a = t2.a and t1.a > 1",
- "explain format = 'brief' select * from t1 left join t2 on t1.a = t1.b and t1.a > 1",
- "explain format = 'brief' select * from t1 left join t2 on t2.a = t2.b and t2.a > 1",
- // Constant equal condition merge in outer join.
- "explain format = 'brief' select * from t1 left join t2 on true where t1.a = 1 and false",
- "explain format = 'brief' select * from t1 left join t2 on true where t1.a = 1 and null",
- "explain format = 'brief' select * from t1 left join t2 on true where t1.a = null",
- "explain format = 'brief' select * from t1 left join t2 on true where t1.a = 1 and t1.a = 2",
- "explain format = 'brief' select * from t1 left join t2 on true where t1.a = 1 and t1.a = 1",
- "explain format = 'brief' select * from t1 left join t2 on false",
- "explain format = 'brief' select * from t1 right join t2 on false",
- "explain format = 'brief' select * from t1 left join t2 on t1.a = 1 and t1.a = 2",
- "explain format = 'brief' select * from t1 left join t2 on t1.a =1 where t1.a = 2",
- "explain format = 'brief' select * from t1 left join t2 on t2.a = 1 and t2.a = 2",
- // Constant propagation for DNF in outer join.
- "explain format = 'brief' select * from t1 left join t2 on t1.a = 1 or (t1.a = 2 and t1.a = 3)",
- "explain format = 'brief' select * from t1 left join t2 on true where t1.a = 1 or (t1.a = 2 and t1.a = 3)",
- // Constant propagation over left outer semi join; a filter on the aux column should not be derived.
- "explain format = 'brief' select * from t1 where t1.b > 1 or t1.b in (select b from t2)",
- // Don't propagate for the control function.
- "explain format = 'brief' select * from t1 left join t2 on t1.a = t2.a where ifnull(t2.b, t1.a) = 1"
- ]
- }
-]
diff --git a/expression/casetest/testdata/expression_suite_out.json b/expression/casetest/testdata/expression_suite_out.json
deleted file mode 100644
index 164ccd7f50311..0000000000000
--- a/expression/casetest/testdata/expression_suite_out.json
+++ /dev/null
@@ -1,290 +0,0 @@
-[
- {
- "Name": "TestOuterJoinPropConst",
- "Cases": [
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on t1.a > t2.a and t1.a = 1",
- "Result": [
- "HashJoin 33233333.33 root CARTESIAN left outer join, left cond:[eq(test.t1.a, 1)]",
- "├─TableReader(Build) 3323.33 root data:Selection",
- "│ └─Selection 3323.33 cop[tikv] gt(1, test.t2.a)",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on t1.a > t2.a where t1.a = 1",
- "Result": [
- "HashJoin 33233.33 root CARTESIAN left outer join",
- "├─TableReader(Build) 10.00 root data:Selection",
- "│ └─Selection 10.00 cop[tikv] eq(test.t1.a, 1)",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 3323.33 root data:Selection",
- " └─Selection 3323.33 cop[tikv] gt(1, test.t2.a)",
- " └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on t1.a = t2.a and t1.a > 1",
- "Result": [
- "HashJoin 10000.00 root left outer join, equal:[eq(test.t1.a, test.t2.a)], left cond:[gt(test.t1.a, 1)]",
- "├─TableReader(Build) 3333.33 root data:Selection",
- "│ └─Selection 3333.33 cop[tikv] gt(test.t2.a, 1), not(isnull(test.t2.a))",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on t1.a = t2.a where t1.a > 1",
- "Result": [
- "HashJoin 4166.67 root left outer join, equal:[eq(test.t1.a, test.t2.a)]",
- "├─TableReader(Build) 3333.33 root data:Selection",
- "│ └─Selection 3333.33 cop[tikv] gt(test.t2.a, 1), not(isnull(test.t2.a))",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 3333.33 root data:Selection",
- " └─Selection 3333.33 cop[tikv] gt(test.t1.a, 1)",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 right join t2 on t1.a > t2.a where t2.a = 1",
- "Result": [
- "HashJoin 33333.33 root CARTESIAN right outer join",
- "├─TableReader(Build) 10.00 root data:Selection",
- "│ └─Selection 10.00 cop[tikv] eq(test.t2.a, 1)",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 3333.33 root data:Selection",
- " └─Selection 3333.33 cop[tikv] gt(test.t1.a, 1)",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 right join t2 on t1.a = t2.a where t2.a > 1",
- "Result": [
- "HashJoin 4166.67 root right outer join, equal:[eq(test.t1.a, test.t2.a)]",
- "├─TableReader(Build) 3333.33 root data:Selection",
- "│ └─Selection 3333.33 cop[tikv] gt(test.t2.a, 1)",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 3333.33 root data:Selection",
- " └─Selection 3333.33 cop[tikv] gt(test.t1.a, 1), not(isnull(test.t1.a))",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 right join t2 on t1.a = t2.a and t2.a > 1",
- "Result": [
- "HashJoin 10000.00 root right outer join, equal:[eq(test.t1.a, test.t2.a)], right cond:gt(test.t2.a, 1)",
- "├─TableReader(Build) 3333.33 root data:Selection",
- "│ └─Selection 3333.33 cop[tikv] gt(test.t1.a, 1), not(isnull(test.t1.a))",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 right join t2 on t1.a > t2.a and t2.a = 1",
- "Result": [
- "HashJoin 33333333.33 root CARTESIAN right outer join, right cond:eq(test.t2.a, 1)",
- "├─TableReader(Build) 3333.33 root data:Selection",
- "│ └─Selection 3333.33 cop[tikv] gt(test.t1.a, 1)",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on t1.a = t2.a and t2.a > 1",
- "Result": [
- "HashJoin 10000.00 root left outer join, equal:[eq(test.t1.a, test.t2.a)]",
- "├─TableReader(Build) 3333.33 root data:Selection",
- "│ └─Selection 3333.33 cop[tikv] gt(test.t2.a, 1), not(isnull(test.t2.a))",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on t1.a > t2.a and t2.a = 1",
- "Result": [
- "HashJoin 100000.00 root CARTESIAN left outer join, other cond:gt(test.t1.a, test.t2.a)",
- "├─TableReader(Build) 10.00 root data:Selection",
- "│ └─Selection 10.00 cop[tikv] eq(test.t2.a, 1)",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 right join t2 on t1.a > t2.a and t1.a = 1",
- "Result": [
- "HashJoin 100000.00 root CARTESIAN right outer join, other cond:gt(test.t1.a, test.t2.a)",
- "├─TableReader(Build) 10.00 root data:Selection",
- "│ └─Selection 10.00 cop[tikv] eq(test.t1.a, 1)",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 right join t2 on t1.a = t2.a and t1.a > 1",
- "Result": [
- "HashJoin 10000.00 root right outer join, equal:[eq(test.t1.a, test.t2.a)]",
- "├─TableReader(Build) 3333.33 root data:Selection",
- "│ └─Selection 3333.33 cop[tikv] gt(test.t1.a, 1), not(isnull(test.t1.a))",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on t1.a = t1.b and t1.a > 1",
- "Result": [
- "HashJoin 100000000.00 root CARTESIAN left outer join, left cond:[eq(test.t1.a, test.t1.b) gt(test.t1.a, 1)]",
- "├─TableReader(Build) 10000.00 root data:TableFullScan",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on t2.a = t2.b and t2.a > 1",
- "Result": [
- "HashJoin 8888888.89 root CARTESIAN left outer join",
- "├─TableReader(Build) 888.89 root data:Selection",
- "│ └─Selection 888.89 cop[tikv] eq(test.t2.a, test.t2.b), gt(test.t2.a, 1), gt(test.t2.b, 1)",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on true where t1.a = 1 and false",
- "Result": [
- "TableDual 0.00 root rows:0"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on true where t1.a = 1 and null",
- "Result": [
- "TableDual 0.00 root rows:0"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on true where t1.a = null",
- "Result": [
- "TableDual 0.00 root rows:0"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on true where t1.a = 1 and t1.a = 2",
- "Result": [
- "TableDual 0.00 root rows:0"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on true where t1.a = 1 and t1.a = 1",
- "Result": [
- "HashJoin 100000.00 root CARTESIAN left outer join",
- "├─TableReader(Build) 10.00 root data:Selection",
- "│ └─Selection 10.00 cop[tikv] eq(test.t1.a, 1)",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on false",
- "Result": [
- "HashJoin 10000.00 root CARTESIAN left outer join",
- "├─TableDual(Build) 0.00 root rows:0",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 right join t2 on false",
- "Result": [
- "HashJoin 10000.00 root CARTESIAN right outer join",
- "├─TableDual(Build) 0.00 root rows:0",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on t1.a = 1 and t1.a = 2",
- "Result": [
- "HashJoin 10000.00 root CARTESIAN left outer join",
- "├─TableDual(Build) 0.00 root rows:0",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on t1.a =1 where t1.a = 2",
- "Result": [
- "HashJoin 10.00 root CARTESIAN left outer join",
- "├─TableDual(Build) 0.00 root rows:0",
- "└─TableReader(Probe) 10.00 root data:Selection",
- " └─Selection 10.00 cop[tikv] eq(test.t1.a, 2)",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on t2.a = 1 and t2.a = 2",
- "Result": [
- "HashJoin 10000.00 root CARTESIAN left outer join",
- "├─TableDual(Build) 0.00 root rows:0",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on t1.a = 1 or (t1.a = 2 and t1.a = 3)",
- "Result": [
- "HashJoin 100000000.00 root CARTESIAN left outer join, left cond:[or(eq(test.t1.a, 1), 0)]",
- "├─TableReader(Build) 10000.00 root data:TableFullScan",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on true where t1.a = 1 or (t1.a = 2 and t1.a = 3)",
- "Result": [
- "HashJoin 100000.00 root CARTESIAN left outer join",
- "├─TableReader(Build) 10.00 root data:Selection",
- "│ └─Selection 10.00 cop[tikv] or(eq(test.t1.a, 1), 0)",
- "│ └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo",
- "└─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 where t1.b > 1 or t1.b in (select b from t2)",
- "Result": [
- "Projection 8000.00 root test.t1.id, test.t1.a, test.t1.b",
- "└─Selection 8000.00 root or(gt(test.t1.b, 1), Column#7)",
- " └─HashJoin 10000.00 root CARTESIAN left outer semi join, other cond:eq(test.t1.b, test.t2.b)",
- " ├─TableReader(Build) 10000.00 root data:TableFullScan",
- " │ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
- " └─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t1 left join t2 on t1.a = t2.a where ifnull(t2.b, t1.a) = 1",
- "Result": [
- "Selection 9990.00 root eq(ifnull(test.t2.b, test.t1.a), 1)",
- "└─HashJoin 12487.50 root left outer join, equal:[eq(test.t1.a, test.t2.a)]",
- " ├─TableReader(Build) 9990.00 root data:Selection",
- " │ └─Selection 9990.00 cop[tikv] not(isnull(test.t2.a))",
- " │ └─TableFullScan 10000.00 cop[tikv] table:t2 keep order:false, stats:pseudo",
- " └─TableReader(Probe) 10000.00 root data:TableFullScan",
- " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo"
- ]
- }
- ]
- }
-]
diff --git a/expression/casetest/testdata/flag_simplify_in.json b/expression/casetest/testdata/flag_simplify_in.json
deleted file mode 100644
index 96c89a8cc5bef..0000000000000
--- a/expression/casetest/testdata/flag_simplify_in.json
+++ /dev/null
@@ -1,28 +0,0 @@
-[
- {
- "name": "TestSimplifyExpressionByFlag",
- "cases": [
- "explain format = 'brief' select * from t where a is null",
- "explain format = 'brief' select * from t where a is not null",
- "explain format = 'brief' select * from t where a > -1",
- "explain format = 'brief' select * from t where a <= -1",
- "explain format = 'brief' select * from t where a < 0",
- "explain format = 'brief' select * from t where a >= 0",
- "explain format = 'brief' select * from t where a = -1",
- "explain format = 'brief' select * from t where a <=> -1",
- "explain format = 'brief' select * from t where a != -1",
- "explain format = 'brief' select * from t where 0 > a",
- "explain format = 'brief' select * from t where 0 <= a",
- "explain format = 'brief' select * from t where -1 < a",
- "explain format = 'brief' select * from t where -1 >= a",
- "explain format = 'brief' select * from t where -1 = a",
- "explain format = 'brief' select * from t where -1 <=> a",
- "explain format = 'brief' select * from t where -1 != a",
- // Tuples with null b should be filtered out.
- "explain format = 'brief' select * from t where b >= 0",
- "explain format = 'brief' select * from t where b != -1",
- // Int64 overflow corner case.
- "explain format = 'brief' select * from t where a = 0xFFFFFFFFFFFFFFFF"
- ]
- }
-]
diff --git a/expression/casetest/testdata/flag_simplify_out.json b/expression/casetest/testdata/flag_simplify_out.json
deleted file mode 100644
index 9d0986f903be8..0000000000000
--- a/expression/casetest/testdata/flag_simplify_out.json
+++ /dev/null
@@ -1,134 +0,0 @@
-[
- {
- "Name": "TestSimplifyExpressionByFlag",
- "Cases": [
- {
- "SQL": "explain format = 'brief' select * from t where a is null",
- "Plan": [
- "TableDual 0.00 root rows:0"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where a is not null",
- "Plan": [
- "TableReader 10000.00 root data:TableFullScan",
- "└─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where a > -1",
- "Plan": [
- "TableReader 10000.00 root data:TableFullScan",
- "└─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where a <= -1",
- "Plan": [
- "TableDual 0.00 root rows:0"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where a < 0",
- "Plan": [
- "TableDual 0.00 root rows:0"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where a >= 0",
- "Plan": [
- "TableReader 10000.00 root data:TableFullScan",
- "└─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where a = -1",
- "Plan": [
- "TableDual 0.00 root rows:0"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where a <=> -1",
- "Plan": [
- "TableDual 0.00 root rows:0"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where a != -1",
- "Plan": [
- "TableReader 10000.00 root data:TableFullScan",
- "└─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where 0 > a",
- "Plan": [
- "TableDual 0.00 root rows:0"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where 0 <= a",
- "Plan": [
- "TableReader 10000.00 root data:TableFullScan",
- "└─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where -1 < a",
- "Plan": [
- "TableReader 10000.00 root data:TableFullScan",
- "└─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where -1 >= a",
- "Plan": [
- "TableDual 0.00 root rows:0"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where -1 = a",
- "Plan": [
- "TableDual 0.00 root rows:0"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where -1 <=> a",
- "Plan": [
- "TableDual 0.00 root rows:0"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where -1 != a",
- "Plan": [
- "TableReader 10000.00 root data:TableFullScan",
- "└─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where b >= 0",
- "Plan": [
- "TableReader 3333.33 root data:Selection",
- "└─Selection 3333.33 cop[tikv] ge(test.t.b, 0)",
- " └─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where b != -1",
- "Plan": [
- "TableReader 3333.33 root data:Selection",
- "└─Selection 3333.33 cop[tikv] ne(test.t.b, -1)",
- " └─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"
- ]
- },
- {
- "SQL": "explain format = 'brief' select * from t where a = 0xFFFFFFFFFFFFFFFF",
- "Plan": [
- "TableReader 10.00 root data:Selection",
- "└─Selection 10.00 cop[tikv] eq(test.t.a, 18446744073709551615)",
- " └─TableFullScan 10000.00 cop[tikv] table:t keep order:false, stats:pseudo"
- ]
- }
- ]
- }
-]
diff --git a/expression/constant_fold_test.go b/expression/constant_fold_test.go
deleted file mode 100644
index 91064bb76791e..0000000000000
--- a/expression/constant_fold_test.go
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright 2019 PingCAP, Inc.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package expression_test
-
-import (
- "testing"
-
- "github.com/pingcap/tidb/testkit"
-)
-
-func TestFoldIfNull(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec(`use test;`)
- tk.MustExec(`drop table if exists t;`)
- tk.MustExec(`create table t(a bigint, b bigint);`)
- tk.MustExec(`insert into t values(1, 1);`)
- tk.MustQuery(`desc select ifnull("aaaa", a) from t;`).Check(testkit.Rows(
- `Projection_3 10000.00 root aaaa->Column#4`,
- `└─TableReader_5 10000.00 root data:TableFullScan_4`,
- ` └─TableFullScan_4 10000.00 cop[tikv] table:t keep order:false, stats:pseudo`,
- ))
- tk.MustQuery(`show warnings;`).Check(testkit.Rows())
- tk.MustQuery(`select ifnull("aaaa", a) from t;`).Check(testkit.Rows("aaaa"))
- tk.MustQuery(`show warnings;`).Check(testkit.Rows())
-}
diff --git a/expression/integration_serial_test/BUILD.bazel b/expression/integration_serial_test/BUILD.bazel
index bd5d60461c46f..0ba73c9a2fbf7 100644
--- a/expression/integration_serial_test/BUILD.bazel
+++ b/expression/integration_serial_test/BUILD.bazel
@@ -8,11 +8,10 @@ go_test(
"main_test.go",
],
flaky = True,
- shard_count = 17,
+ shard_count = 12,
deps = [
"//config",
"//expression",
- "//parser/mysql",
"//parser/terror",
"//planner/core",
"//session",
diff --git a/expression/integration_serial_test/integration_serial_test.go b/expression/integration_serial_test/integration_serial_test.go
index 49e5c6c6046c1..b834e5a30ecd8 100644
--- a/expression/integration_serial_test/integration_serial_test.go
+++ b/expression/integration_serial_test/integration_serial_test.go
@@ -27,7 +27,6 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/expression"
- "github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
plannercore "github.com/pingcap/tidb/planner/core"
"github.com/pingcap/tidb/session"
@@ -38,470 +37,6 @@ import (
"github.com/tikv/client-go/v2/oracle"
)
-func TestWeightString(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
-
- type testCase struct {
- input []string
- result []string
- resultAsChar1 []string
- resultAsChar3 []string
- resultAsBinary1 []string
- resultAsBinary5 []string
- resultExplicitCollateBin []string
- }
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (id int, a varchar(20) collate utf8mb4_general_ci)")
- cases := testCase{
- input: []string{"aAÁàãăâ", "a", "a ", "中", "中 "},
- result: []string{"\x00A\x00A\x00A\x00A\x00A\x00A\x00A", "\x00A", "\x00A", "\x4E\x2D", "\x4E\x2D"},
- resultAsChar1: []string{"\x00A", "\x00A", "\x00A", "\x4E\x2D", "\x4E\x2D"},
- resultAsChar3: []string{"\x00A\x00A\x00A", "\x00A", "\x00A", "\x4E\x2D", "\x4E\x2D"},
- resultAsBinary1: []string{"a", "a", "a", "\xE4", "\xE4"},
- resultAsBinary5: []string{"aA\xc3\x81\xc3", "a\x00\x00\x00\x00", "a \x00\x00", "中\x00\x00", "中 \x00"},
- resultExplicitCollateBin: []string{"aAÁàãăâ", "a", "a", "中", "中"},
- }
- values := make([]string, len(cases.input))
- for i, input := range cases.input {
- values[i] = fmt.Sprintf("(%d, '%s')", i, input)
- }
- tk.MustExec("insert into t values " + strings.Join(values, ","))
- rows := tk.MustQuery("select weight_string(a) from t order by id").Rows()
- for i, out := range cases.result {
- require.Equal(t, out, rows[i][0].(string))
- }
- rows = tk.MustQuery("select weight_string(a as char(1)) from t order by id").Rows()
- for i, out := range cases.resultAsChar1 {
- require.Equal(t, out, rows[i][0].(string))
- }
- rows = tk.MustQuery("select weight_string(a as char(3)) from t order by id").Rows()
- for i, out := range cases.resultAsChar3 {
- require.Equal(t, out, rows[i][0].(string))
- }
- rows = tk.MustQuery("select weight_string(a as binary(1)) from t order by id").Rows()
- for i, out := range cases.resultAsBinary1 {
- require.Equal(t, out, rows[i][0].(string))
- }
- rows = tk.MustQuery("select weight_string(a as binary(5)) from t order by id").Rows()
- for i, out := range cases.resultAsBinary5 {
- require.Equal(t, out, rows[i][0].(string))
- }
- require.Equal(t, "", tk.MustQuery("select weight_string(NULL);").Rows()[0][0])
- require.Equal(t, "", tk.MustQuery("select weight_string(7);").Rows()[0][0])
- require.Equal(t, "", tk.MustQuery("select weight_string(cast(7 as decimal(5)));").Rows()[0][0])
- require.Equal(t, "2019-08-21", tk.MustQuery("select weight_string(cast(20190821 as date));").Rows()[0][0])
- require.Equal(t, "2019-", tk.MustQuery("select weight_string(cast(20190821 as date) as binary(5));").Rows()[0][0])
- require.Equal(t, "", tk.MustQuery("select weight_string(7.0);").Rows()[0][0])
- require.Equal(t, "7\x00", tk.MustQuery("select weight_string(7 AS BINARY(2));").Rows()[0][0])
- // test explicit collation
- require.Equal(t, "\x4E\x2D", tk.MustQuery("select weight_string('中 ' collate utf8mb4_general_ci);").Rows()[0][0])
- require.Equal(t, "中", tk.MustQuery("select weight_string('中 ' collate utf8mb4_bin);").Rows()[0][0])
- require.Equal(t, "\xFB\x40\xCE\x2D", tk.MustQuery("select weight_string('中 ' collate utf8mb4_unicode_ci);").Rows()[0][0])
- require.Equal(t, "utf8mb4_general_ci", tk.MustQuery("select collation(a collate utf8mb4_general_ci) from t order by id").Rows()[0][0])
- require.Equal(t, "utf8mb4_general_ci", tk.MustQuery("select collation('中 ' collate utf8mb4_general_ci);").Rows()[0][0])
- rows = tk.MustQuery("select weight_string(a collate utf8mb4_bin) from t order by id").Rows()
- for i, out := range cases.resultExplicitCollateBin {
- require.Equal(t, out, rows[i][0].(string))
- }
- tk.MustGetErrMsg("select weight_string(a collate utf8_general_ci) from t order by id", "[ddl:1253]COLLATION 'utf8_general_ci' is not valid for CHARACTER SET 'utf8mb4'")
- tk.MustGetErrMsg("select weight_string('中' collate utf8_bin)", "[ddl:1253]COLLATION 'utf8_bin' is not valid for CHARACTER SET 'utf8mb4'")
-}
-
-func TestMathBuiltin(t *testing.T) {
- t.Skip("it has been broken. Please fix it as soon as possible.")
- ctx := context.Background()
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
-
- // for degrees
- result := tk.MustQuery("select degrees(0), degrees(1)")
- result.Check(testkit.Rows("0 57.29577951308232"))
- result = tk.MustQuery("select degrees(2), degrees(5)")
- result.Check(testkit.Rows("114.59155902616465 286.4788975654116"))
-
- // for sin
- result = tk.MustQuery("select sin(0), sin(1.5707963267949)")
- result.Check(testkit.Rows("0 1"))
- result = tk.MustQuery("select sin(1), sin(100)")
- result.Check(testkit.Rows("0.8414709848078965 -0.5063656411097588"))
- result = tk.MustQuery("select sin('abcd')")
- result.Check(testkit.Rows("0"))
-
- // for cos
- result = tk.MustQuery("select cos(0), cos(3.1415926535898)")
- result.Check(testkit.Rows("1 -1"))
- result = tk.MustQuery("select cos('abcd')")
- result.Check(testkit.Rows("1"))
-
- // for tan
- result = tk.MustQuery("select tan(0.00), tan(PI()/4)")
- result.Check(testkit.Rows("0 1"))
- result = tk.MustQuery("select tan('abcd')")
- result.Check(testkit.Rows("0"))
-
- // for log2
- result = tk.MustQuery("select log2(0.0)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select log2(4)")
- result.Check(testkit.Rows("2"))
- result = tk.MustQuery("select log2('8.0abcd')")
- result.Check(testkit.Rows("3"))
- result = tk.MustQuery("select log2(-1)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select log2(NULL)")
- result.Check(testkit.Rows(""))
-
- // for log10
- result = tk.MustQuery("select log10(0.0)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select log10(100)")
- result.Check(testkit.Rows("2"))
- result = tk.MustQuery("select log10('1000.0abcd')")
- result.Check(testkit.Rows("3"))
- result = tk.MustQuery("select log10(-1)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select log10(NULL)")
- result.Check(testkit.Rows(""))
-
- // for log
- result = tk.MustQuery("select log(0.0)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select log(100)")
- result.Check(testkit.Rows("4.605170185988092"))
- result = tk.MustQuery("select log('100.0abcd')")
- result.Check(testkit.Rows("4.605170185988092"))
- result = tk.MustQuery("select log(-1)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select log(NULL)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select log(NULL, NULL)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select log(1, 100)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select log(0.5, 0.25)")
- result.Check(testkit.Rows("2"))
- result = tk.MustQuery("select log(-1, 0.25)")
- result.Check(testkit.Rows(""))
-
- // for atan
- result = tk.MustQuery("select atan(0), atan(-1), atan(1), atan(1,2)")
- result.Check(testkit.Rows("0 -0.7853981633974483 0.7853981633974483 0.4636476090008061"))
- result = tk.MustQuery("select atan('tidb')")
- result.Check(testkit.Rows("0"))
-
- // for asin
- result = tk.MustQuery("select asin(0), asin(-2), asin(2), asin(1)")
- result.Check(testkit.Rows("0 1.5707963267948966"))
- result = tk.MustQuery("select asin('tidb')")
- result.Check(testkit.Rows("0"))
-
- // for acos
- result = tk.MustQuery("select acos(0), acos(-2), acos(2), acos(1)")
- result.Check(testkit.Rows("1.5707963267948966 0"))
- result = tk.MustQuery("select acos('tidb')")
- result.Check(testkit.Rows("1.5707963267948966"))
-
- // for pi
- result = tk.MustQuery("select pi()")
- result.Check(testkit.Rows("3.141592653589793"))
-
- // for floor
- result = tk.MustQuery("select floor(0), floor(null), floor(1.23), floor(-1.23), floor(1)")
- result.Check(testkit.Rows("0 1 -2 1"))
- result = tk.MustQuery("select floor('tidb'), floor('1tidb'), floor('tidb1')")
- result.Check(testkit.Rows("0 1 0"))
- result = tk.MustQuery("SELECT floor(t.c_datetime) FROM (select CAST('2017-07-19 00:00:00' AS DATETIME) AS c_datetime) AS t")
- result.Check(testkit.Rows("20170719000000"))
- result = tk.MustQuery("SELECT floor(t.c_time) FROM (select CAST('12:34:56' AS TIME) AS c_time) AS t")
- result.Check(testkit.Rows("123456"))
- result = tk.MustQuery("SELECT floor(t.c_time) FROM (select CAST('00:34:00' AS TIME) AS c_time) AS t")
- result.Check(testkit.Rows("3400"))
- result = tk.MustQuery("SELECT floor(t.c_time) FROM (select CAST('00:00:00' AS TIME) AS c_time) AS t")
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery("SELECT floor(t.c_decimal) FROM (SELECT CAST('-10.01' AS DECIMAL(10,2)) AS c_decimal) AS t")
- result.Check(testkit.Rows("-11"))
- result = tk.MustQuery("SELECT floor(t.c_decimal) FROM (SELECT CAST('-10.01' AS DECIMAL(10,1)) AS c_decimal) AS t")
- result.Check(testkit.Rows("-10"))
-
- // for ceil/ceiling
- result = tk.MustQuery("select ceil(0), ceil(null), ceil(1.23), ceil(-1.23), ceil(1)")
- result.Check(testkit.Rows("0 2 -1 1"))
- result = tk.MustQuery("select ceiling(0), ceiling(null), ceiling(1.23), ceiling(-1.23), ceiling(1)")
- result.Check(testkit.Rows("0 2 -1 1"))
- result = tk.MustQuery("select ceil('tidb'), ceil('1tidb'), ceil('tidb1'), ceiling('tidb'), ceiling('1tidb'), ceiling('tidb1')")
- result.Check(testkit.Rows("0 1 0 0 1 0"))
- result = tk.MustQuery("select ceil(t.c_datetime), ceiling(t.c_datetime) from (select cast('2017-07-20 00:00:00' as datetime) as c_datetime) as t")
- result.Check(testkit.Rows("20170720000000 20170720000000"))
- result = tk.MustQuery("select ceil(t.c_time), ceiling(t.c_time) from (select cast('12:34:56' as time) as c_time) as t")
- result.Check(testkit.Rows("123456 123456"))
- result = tk.MustQuery("select ceil(t.c_time), ceiling(t.c_time) from (select cast('00:34:00' as time) as c_time) as t")
- result.Check(testkit.Rows("3400 3400"))
- result = tk.MustQuery("select ceil(t.c_time), ceiling(t.c_time) from (select cast('00:00:00' as time) as c_time) as t")
- result.Check(testkit.Rows("0 0"))
- result = tk.MustQuery("select ceil(t.c_decimal), ceiling(t.c_decimal) from (select cast('-10.01' as decimal(10,2)) as c_decimal) as t")
- result.Check(testkit.Rows("-10 -10"))
- result = tk.MustQuery("select ceil(t.c_decimal), ceiling(t.c_decimal) from (select cast('-10.01' as decimal(10,1)) as c_decimal) as t")
- result.Check(testkit.Rows("-10 -10"))
- result = tk.MustQuery("select floor(18446744073709551615), ceil(18446744073709551615)")
- result.Check(testkit.Rows("18446744073709551615 18446744073709551615"))
- result = tk.MustQuery("select floor(18446744073709551615.1233), ceil(18446744073709551615.1233)")
- result.Check(testkit.Rows("18446744073709551615 18446744073709551616"))
- result = tk.MustQuery("select floor(-18446744073709551617), ceil(-18446744073709551617), floor(-18446744073709551617.11), ceil(-18446744073709551617.11)")
- result.Check(testkit.Rows("-18446744073709551617 -18446744073709551617 -18446744073709551618 -18446744073709551617"))
- tk.MustExec("drop table if exists t;")
- tk.MustExec("create table t(a decimal(40,20) UNSIGNED);")
- tk.MustExec("insert into t values(2.99999999900000000000), (12), (0);")
- tk.MustQuery("select a, ceil(a) from t where ceil(a) > 1;").Check(testkit.Rows("2.99999999900000000000 3", "12.00000000000000000000 12"))
- tk.MustQuery("select a, ceil(a) from t;").Check(testkit.Rows("2.99999999900000000000 3", "12.00000000000000000000 12", "0.00000000000000000000 0"))
- tk.MustQuery("select ceil(-29464);").Check(testkit.Rows("-29464"))
- tk.MustQuery("select a, floor(a) from t where floor(a) > 1;").Check(testkit.Rows("2.99999999900000000000 2", "12.00000000000000000000 12"))
- tk.MustQuery("select a, floor(a) from t;").Check(testkit.Rows("2.99999999900000000000 2", "12.00000000000000000000 12", "0.00000000000000000000 0"))
- tk.MustQuery("select floor(-29464);").Check(testkit.Rows("-29464"))
-
- tk.MustExec(`drop table if exists t;`)
- tk.MustExec(`create table t(a decimal(40,20), b bigint);`)
- tk.MustExec(`insert into t values(-2.99999990000000000000, -1);`)
- tk.MustQuery(`select floor(a), floor(a), floor(a) from t;`).Check(testkit.Rows(`-3 -3 -3`))
- tk.MustQuery(`select b, floor(b) from t;`).Check(testkit.Rows(`-1 -1`))
-
- // for cot
- result = tk.MustQuery("select cot(1), cot(-1), cot(NULL)")
- result.Check(testkit.Rows("0.6420926159343308 -0.6420926159343308 "))
- result = tk.MustQuery("select cot('1tidb')")
- result.Check(testkit.Rows("0.6420926159343308"))
- rs, err := tk.Exec("select cot(0)")
- require.NoError(t, err)
- _, err = session.GetRows4Test(ctx, tk.Session(), rs)
- require.Error(t, err)
- terr := errors.Cause(err).(*terror.Error)
- require.Equal(t, errors.ErrCode(mysql.ErrDataOutOfRange), terr.Code())
- require.NoError(t, rs.Close())
-
- // for exp
- result = tk.MustQuery("select exp(0), exp(1), exp(-1), exp(1.2), exp(NULL)")
- result.Check(testkit.Rows("1 2.718281828459045 0.36787944117144233 3.3201169227365472 "))
- result = tk.MustQuery("select exp('tidb'), exp('1tidb')")
- result.Check(testkit.Rows("1 2.718281828459045"))
- rs, err = tk.Exec("select exp(1000000)")
- require.NoError(t, err)
- _, err = session.GetRows4Test(ctx, tk.Session(), rs)
- require.Error(t, err)
- terr = errors.Cause(err).(*terror.Error)
- require.Equal(t, errors.ErrCode(mysql.ErrDataOutOfRange), terr.Code())
- require.NoError(t, rs.Close())
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a float)")
- tk.MustExec("insert into t values(1000000)")
- rs, err = tk.Exec("select exp(a) from t")
- require.NoError(t, err)
- _, err = session.GetRows4Test(ctx, tk.Session(), rs)
- require.Error(t, err)
- terr = errors.Cause(err).(*terror.Error)
- require.Equal(t, errors.ErrCode(mysql.ErrDataOutOfRange), terr.Code())
- require.EqualError(t, err, "[types:1690]DOUBLE value is out of range in 'exp(test.t.a)'")
- require.NoError(t, rs.Close())
-
- // for conv
- result = tk.MustQuery("SELECT CONV('a', 16, 2);")
- result.Check(testkit.Rows("1010"))
- result = tk.MustQuery("SELECT CONV('6E', 18, 8);")
- result.Check(testkit.Rows("172"))
- result = tk.MustQuery("SELECT CONV(-17, 10, -18);")
- result.Check(testkit.Rows("-H"))
- result = tk.MustQuery("SELECT CONV(10+'10'+'10'+X'0a', 10, 10);")
- result.Check(testkit.Rows("40"))
- result = tk.MustQuery("SELECT CONV('a', 1, 10);")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("SELECT CONV('a', 37, 10);")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("SELECT CONV(0x0020, 2, 2);")
- result.Check(testkit.Rows("100000"))
- result = tk.MustQuery("SELECT CONV(0b10, 16, 2)")
- result.Check(testkit.Rows("10"))
- result = tk.MustQuery("SELECT CONV(0b10, 16, 8)")
- result.Check(testkit.Rows("2"))
- tk.MustExec("drop table if exists bit")
- tk.MustExec("create table bit(b bit(10))")
- tk.MustExec(`INSERT INTO bit (b) VALUES
- (0b0000010101),
- (0b0000010101),
- (NULL),
- (0b0000000001),
- (0b0000000000),
- (0b1111111111),
- (0b1111111111),
- (0b1111111111),
- (0b0000000000),
- (0b0000000000),
- (0b0000000000),
- (0b0000000000),
- (0b0000100000);`)
- tk.MustQuery("select conv(b, 2, 2) from `bit`").Check(testkit.Rows(
- "10101",
- "10101",
- "",
- "1",
- "0",
- "1111111111",
- "1111111111",
- "1111111111",
- "0",
- "0",
- "0",
- "0",
- "100000"))
-
- // for abs
- result = tk.MustQuery("SELECT ABS(-1);")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("SELECT ABS('abc');")
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery("SELECT ABS(18446744073709551615);")
- result.Check(testkit.Rows("18446744073709551615"))
- result = tk.MustQuery("SELECT ABS(123.4);")
- result.Check(testkit.Rows("123.4"))
- result = tk.MustQuery("SELECT ABS(-123.4);")
- result.Check(testkit.Rows("123.4"))
- result = tk.MustQuery("SELECT ABS(1234E-1);")
- result.Check(testkit.Rows("123.4"))
- result = tk.MustQuery("SELECT ABS(-9223372036854775807);")
- result.Check(testkit.Rows("9223372036854775807"))
- result = tk.MustQuery("SELECT ABS(NULL);")
- result.Check(testkit.Rows(""))
- rs, err = tk.Exec("SELECT ABS(-9223372036854775808);")
- require.NoError(t, err)
- _, err = session.GetRows4Test(ctx, tk.Session(), rs)
- require.Error(t, err)
- terr = errors.Cause(err).(*terror.Error)
- require.Equal(t, errors.ErrCode(mysql.ErrDataOutOfRange), terr.Code())
- require.NoError(t, rs.Close())
-
- // for round
- result = tk.MustQuery("SELECT ROUND(2.5), ROUND(-2.5), ROUND(25E-1);")
- result.Check(testkit.Rows("3 -3 2"))
- result = tk.MustQuery("SELECT ROUND(2.5, NULL), ROUND(NULL, 4), ROUND(NULL, NULL), ROUND(NULL);")
- result.Check(testkit.Rows(" "))
- result = tk.MustQuery("SELECT ROUND('123.4'), ROUND('123e-2');")
- result.Check(testkit.Rows("123 1"))
- result = tk.MustQuery("SELECT ROUND(-9223372036854775808);")
- result.Check(testkit.Rows("-9223372036854775808"))
- result = tk.MustQuery("SELECT ROUND(123.456, 0), ROUND(123.456, 1), ROUND(123.456, 2), ROUND(123.456, 3), ROUND(123.456, 4), ROUND(123.456, -1), ROUND(123.456, -2), ROUND(123.456, -3), ROUND(123.456, -4);")
- result.Check(testkit.Rows("123 123.5 123.46 123.456 123.4560 120 100 0 0"))
- result = tk.MustQuery("SELECT ROUND(123456E-3, 0), ROUND(123456E-3, 1), ROUND(123456E-3, 2), ROUND(123456E-3, 3), ROUND(123456E-3, 4), ROUND(123456E-3, -1), ROUND(123456E-3, -2), ROUND(123456E-3, -3), ROUND(123456E-3, -4);")
- result.Check(testkit.Rows("123 123.5 123.46 123.456 123.456 120 100 0 0")) // TODO: Column 5 should be 123.4560
- result = tk.MustQuery("SELECT ROUND(1e14, 1), ROUND(1e15, 1), ROUND(1e308, 1)")
- result.Check(testkit.Rows("100000000000000 1000000000000000 100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"))
- result = tk.MustQuery("SELECT ROUND(1e-14, 1), ROUND(1e-15, 1), ROUND(1e-308, 1)")
- result.Check(testkit.Rows("0 0 0"))
-
- // for truncate
- result = tk.MustQuery("SELECT truncate(123, -2), truncate(123, 2), truncate(123, 1), truncate(123, -1);")
- result.Check(testkit.Rows("100 123 123 120"))
- result = tk.MustQuery("SELECT truncate(123.456, -2), truncate(123.456, 2), truncate(123.456, 1), truncate(123.456, 3), truncate(1.23, 100), truncate(123456E-3, 2);")
- result.Check(testkit.Rows("100 123.45 123.4 123.456 1.230000000000000000000000000000 123.45"))
- result = tk.MustQuery("SELECT truncate(9223372036854775807, -7), truncate(9223372036854775808, -10), truncate(cast(-1 as unsigned), -10);")
- result.Check(testkit.Rows("9223372036850000000 9223372030000000000 18446744070000000000"))
- // issues 17181, 19390
- tk.MustQuery("select truncate(42, -9223372036854775808);").Check(testkit.Rows("0"))
- tk.MustQuery("select truncate(42, 9223372036854775808);").Check(testkit.Rows("42"))
- tk.MustQuery("select truncate(42, -2147483648);").Check(testkit.Rows("0"))
- tk.MustQuery("select truncate(42, 2147483648);").Check(testkit.Rows("42"))
- tk.MustQuery("select truncate(42, 18446744073709551615);").Check(testkit.Rows("42"))
- tk.MustQuery("select truncate(42, 4294967295);").Check(testkit.Rows("42"))
- tk.MustQuery("select truncate(42, -0);").Check(testkit.Rows("42"))
- tk.MustQuery("select truncate(42, -307);").Check(testkit.Rows("0"))
- tk.MustQuery("select truncate(42, -308);").Check(testkit.Rows("0"))
- tk.MustQuery("select truncate(42, -309);").Check(testkit.Rows("0"))
- tk.MustExec(`drop table if exists t;`)
- tk.MustExec("create table t (a bigint unsigned);")
- tk.MustExec("insert into t values (18446744073709551615), (4294967295), (9223372036854775808), (2147483648);")
- tk.MustQuery("select truncate(42, a) from t;").Check(testkit.Rows("42", "42", "42", "42"))
-
- tk.MustExec(`drop table if exists t;`)
- tk.MustExec(`create table t(a date, b datetime, c timestamp, d varchar(20));`)
- tk.MustExec(`insert into t select "1234-12-29", "1234-12-29 16:24:13.9912", "2014-12-29 16:19:28", "12.34567";`)
-
- // NOTE: the actual result is: 12341220 12341229.0 12341200 12341229.00,
- // but Datum.ToString() doesn't format the decimal length for float numbers.
- result = tk.MustQuery(`select truncate(a, -1), truncate(a, 1), truncate(a, -2), truncate(a, 2) from t;`)
- result.Check(testkit.Rows("12341220 12341229 12341200 12341229"))
-
- // NOTE: the actual result is: 12341229162410 12341229162414.0 12341229162400 12341229162414.00,
- // but Datum.ToString() doesn't format the decimal length for float numbers.
- result = tk.MustQuery(`select truncate(b, -1), truncate(b, 1), truncate(b, -2), truncate(b, 2) from t;`)
- result.Check(testkit.Rows("12341229162410 12341229162414 12341229162400 12341229162414"))
-
- // NOTE: the actual result is: 20141229161920 20141229161928.0 20141229161900 20141229161928.00,
- // but Datum.ToString() doesn't format the decimal length for float numbers.
- result = tk.MustQuery(`select truncate(c, -1), truncate(c, 1), truncate(c, -2), truncate(c, 2) from t;`)
- result.Check(testkit.Rows("20141229161920 20141229161928 20141229161900 20141229161928"))
-
- result = tk.MustQuery(`select truncate(d, -1), truncate(d, 1), truncate(d, -2), truncate(d, 2) from t;`)
- result.Check(testkit.Rows("10 12.3 0 12.34"))
-
- result = tk.MustQuery(`select truncate(json_array(), 1), truncate("cascasc", 1);`)
- result.Check(testkit.Rows("0 0"))
-
- // for pow
- result = tk.MustQuery("SELECT POW('12', 2), POW(1.2e1, '2.0'), POW(12, 2.0);")
- result.Check(testkit.Rows("144 144 144"))
- result = tk.MustQuery("SELECT POW(null, 2), POW(2, null), POW(null, null);")
- result.Check(testkit.Rows(" "))
- result = tk.MustQuery("SELECT POW(0, 0);")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("SELECT POW(0, 0.1), POW(0, 0.5), POW(0, 1);")
- result.Check(testkit.Rows("0 0 0"))
- rs, err = tk.Exec("SELECT POW(0, -1);")
- require.NoError(t, err)
- _, err = session.GetRows4Test(ctx, tk.Session(), rs)
- require.Error(t, err)
- terr = errors.Cause(err).(*terror.Error)
- require.Equal(t, errors.ErrCode(mysql.ErrDataOutOfRange), terr.Code())
- require.NoError(t, rs.Close())
-
- // for sign
- result = tk.MustQuery("SELECT SIGN('12'), SIGN(1.2e1), SIGN(12), SIGN(0.0000012);")
- result.Check(testkit.Rows("1 1 1 1"))
- result = tk.MustQuery("SELECT SIGN('-12'), SIGN(-1.2e1), SIGN(-12), SIGN(-0.0000012);")
- result.Check(testkit.Rows("-1 -1 -1 -1"))
- result = tk.MustQuery("SELECT SIGN('0'), SIGN('-0'), SIGN(0);")
- result.Check(testkit.Rows("0 0 0"))
- result = tk.MustQuery("SELECT SIGN(NULL);")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("SELECT SIGN(-9223372036854775808), SIGN(9223372036854775808);")
- result.Check(testkit.Rows("-1 1"))
-
- // for sqrt
- result = tk.MustQuery("SELECT SQRT(-10), SQRT(144), SQRT(4.84), SQRT(0.04), SQRT(0);")
- result.Check(testkit.Rows(" 12 2.2 0.2 0"))
-
- // for crc32
- result = tk.MustQuery("SELECT crc32(0), crc32(-0), crc32('0'), crc32('abc'), crc32('ABC'), crc32(NULL), crc32(''), crc32('hello world!')")
- result.Check(testkit.Rows("4108050209 4108050209 4108050209 891568578 2743272264 0 62177901"))
-
- // for radians
- result = tk.MustQuery("SELECT radians(1.0), radians(pi()), radians(pi()/2), radians(180), radians(1.009);")
- result.Check(testkit.Rows("0.017453292519943295 0.05483113556160754 0.02741556778080377 3.141592653589793 0.01761037215262278"))
-
- // for rand
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a int)")
- tk.MustExec("insert into t values(1),(2),(3)")
- tk.Session().GetSessionVars().MaxChunkSize = 1
- tk.MustQuery("select rand(1) from t").Sort().Check(testkit.Rows("0.1418603212962489", "0.40540353712197724", "0.8716141803857071"))
- tk.MustQuery("select rand(a) from t").Check(testkit.Rows("0.40540353712197724", "0.6555866465490187", "0.9057697559760601"))
- tk.MustQuery("select rand(1), rand(2), rand(3)").Check(testkit.Rows("0.40540353712197724 0.6555866465490187 0.9057697559760601"))
- tk.MustQuery("set @@rand_seed1=10000000,@@rand_seed2=1000000")
- tk.MustQuery("select rand()").Check(testkit.Rows("0.028870999839968048"))
- tk.MustQuery("select rand(1)").Check(testkit.Rows("0.40540353712197724"))
- tk.MustQuery("select rand()").Check(testkit.Rows("0.11641535266900002"))
-}
-
func TestTimeBuiltin(t *testing.T) {
store := testkit.CreateMockStore(t)
@@ -1683,776 +1218,6 @@ func TestTimeBuiltin(t *testing.T) {
result.Check(testkit.Rows("2000-01-05 00:00:00.00000"))
}
-func TestBuiltin(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
-
- // for is true && is false
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a int, b int, index idx_b (b))")
- tk.MustExec("insert t values (1, 1)")
- tk.MustExec("insert t values (2, 2)")
- tk.MustExec("insert t values (3, 2)")
- result := tk.MustQuery("select * from t where b is true")
- result.Check(testkit.Rows("1 1", "2 2", "3 2"))
- result = tk.MustQuery("select all + a from t where a = 1")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select * from t where a is false")
- result.Check(nil)
- result = tk.MustQuery("select * from t where a is not true")
- result.Check(nil)
- result = tk.MustQuery(`select 1 is true, 0 is true, null is true, "aaa" is true, "" is true, -12.00 is true, 0.0 is true, 0.0000001 is true;`)
- result.Check(testkit.Rows("1 0 0 0 0 1 0 1"))
- result = tk.MustQuery(`select 1 is false, 0 is false, null is false, "aaa" is false, "" is false, -12.00 is false, 0.0 is false, 0.0000001 is false;`)
- result.Check(testkit.Rows("0 1 0 1 1 0 1 0"))
- // Issue https://github.com/pingcap/tidb/issues/19986
- result = tk.MustQuery("select 1 from dual where sec_to_time(2/10) is true")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select 1 from dual where sec_to_time(2/10) is false")
- result.Check(nil)
- // Issue https://github.com/pingcap/tidb/issues/19999
- result = tk.MustQuery("select 1 from dual where timediff((7/'2014-07-07 02:30:02'),'2012-01-16') is true")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select 1 from dual where timediff((7/'2014-07-07 02:30:02'),'2012-01-16') is false")
- result.Check(nil)
- // Issue https://github.com/pingcap/tidb/issues/20001
- result = tk.MustQuery("select 1 from dual where time(0.0001) is true")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select 1 from dual where time(0.0001) is false")
- result.Check(nil)
-
- // for in
- result = tk.MustQuery("select * from t where b in (a)")
- result.Check(testkit.Rows("1 1", "2 2"))
- result = tk.MustQuery("select * from t where b not in (a)")
- result.Check(testkit.Rows("3 2"))
-
- // test cast
- result = tk.MustQuery("select cast(1 as decimal(3,2))")
- result.Check(testkit.Rows("1.00"))
- result = tk.MustQuery("select cast('1991-09-05 11:11:11' as datetime)")
- result.Check(testkit.Rows("1991-09-05 11:11:11"))
- result = tk.MustQuery("select cast(cast('1991-09-05 11:11:11' as datetime) as char)")
- result.Check(testkit.Rows("1991-09-05 11:11:11"))
- result = tk.MustQuery("select cast('11:11:11' as time)")
- result.Check(testkit.Rows("11:11:11"))
- result = tk.MustQuery("select * from t where a > cast(2 as decimal)")
- result.Check(testkit.Rows("3 2"))
- result = tk.MustQuery("select cast(-1 as unsigned)")
- result.Check(testkit.Rows("18446744073709551615"))
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a decimal(3, 1), b double, c datetime, d time, e int)")
- tk.MustExec("insert into t value(12.3, 1.23, '2017-01-01 12:12:12', '12:12:12', 123)")
- result = tk.MustQuery("select cast(a as json), cast(b as json), cast(c as json), cast(d as json), cast(e as json) from t")
- result.Check(testkit.Rows(`12.3 1.23 "2017-01-01 12:12:12.000000" "12:12:12.000000" 123`))
- result = tk.MustQuery(`select cast(10101000000 as time);`)
- result.Check(testkit.Rows("00:00:00"))
- result = tk.MustQuery(`select cast(10101001000 as time);`)
- result.Check(testkit.Rows("00:10:00"))
- result = tk.MustQuery(`select cast(10000000000 as time);`)
- result.Check(testkit.Rows(""))
- result = tk.MustQuery(`select cast(20171222020005 as time);`)
- result.Check(testkit.Rows("02:00:05"))
- result = tk.MustQuery(`select cast(8380000 as time);`)
- result.Check(testkit.Rows("838:00:00"))
- result = tk.MustQuery(`select cast(8390000 as time);`)
- result.Check(testkit.Rows(""))
- result = tk.MustQuery(`select cast(8386000 as time);`)
- result.Check(testkit.Rows(""))
- result = tk.MustQuery(`select cast(8385960 as time);`)
- result.Check(testkit.Rows(""))
- result = tk.MustQuery(`select cast(cast('2017-01-01 01:01:11.12' as date) as datetime(2));`)
- result.Check(testkit.Rows("2017-01-01 00:00:00.00"))
- result = tk.MustQuery(`select cast(20170118.999 as datetime);`)
- result.Check(testkit.Rows("2017-01-18 00:00:00"))
- tk.MustQuery(`select convert(a2.a, unsigned int) from (select cast('"9223372036854775808"' as json) as a) as a2;`)
-
- tk.MustExec(`create table tb5(a bigint(64) unsigned, b double);`)
- tk.MustExec(`insert into tb5 (a, b) values (9223372036854776000, 9223372036854776000);`)
- tk.MustExec(`insert into tb5 (a, b) select * from (select cast(a as json) as a1, b from tb5) as t where t.a1 = t.b;`)
- tk.MustExec(`drop table tb5;`)
-
- tk.MustExec(`create table tb5(a float(53));`)
- tk.MustExec(`insert into tb5(a) values (13835058055282163712);`)
- tk.MustQuery(`select convert(t.a1, signed int) from (select convert(a, json) as a1 from tb5) as t`)
- tk.MustExec(`drop table tb5;`)
-
- // test builtinCastIntAsIntSig
- // Casting MaxUint64 to signed should yield -1
- tk.MustQuery("select cast(0xffffffffffffffff as signed);").Check(testkit.Rows("-1"))
- tk.MustQuery("select cast(0x9999999999999999999999999999999999999999999 as signed);").Check(testkit.Rows("-1"))
- tk.MustExec("create table tb5(a bigint);")
- tk.MustExec("set sql_mode=''")
- tk.MustExec("insert into tb5(a) values (0xfffffffffffffffffffffffff);")
- tk.MustQuery("select * from tb5;").Check(testkit.Rows("9223372036854775807"))
- tk.MustExec("drop table tb5;")
-
- tk.MustExec(`create table tb5(a double);`)
- tk.MustExec(`insert into test.tb5 (a) values (18446744073709551616);`)
- tk.MustExec(`insert into test.tb5 (a) values (184467440737095516160);`)
- result = tk.MustQuery(`select cast(a as unsigned) from test.tb5;`)
- // Note: MySQL returns 9223372036854775807 here, which appears to be a MySQL bug.
- result.Check(testkit.Rows("18446744073709551615", "18446744073709551615"))
- tk.MustExec(`drop table tb5;`)
-
- // test builtinCastIntAsDecimalSig
- tk.MustExec(`create table tb5(a bigint(64) unsigned, b decimal(64, 10));`)
- tk.MustExec(`insert into tb5 (a, b) values (9223372036854775808, 9223372036854775808);`)
- tk.MustExec(`insert into tb5 (select * from tb5 where a = b);`)
- result = tk.MustQuery(`select * from tb5;`)
- result.Check(testkit.Rows("9223372036854775808 9223372036854775808.0000000000", "9223372036854775808 9223372036854775808.0000000000"))
- tk.MustExec(`drop table tb5;`)
-
- // test builtinCastIntAsRealSig
- tk.MustExec(`create table tb5(a bigint(64) unsigned, b double(64, 10));`)
- tk.MustExec(`insert into tb5 (a, b) values (13835058000000000000, 13835058000000000000);`)
- tk.MustExec(`insert into tb5 (select * from tb5 where a = b);`)
- result = tk.MustQuery(`select * from tb5;`)
- result.Check(testkit.Rows("13835058000000000000 13835058000000000000", "13835058000000000000 13835058000000000000"))
- tk.MustExec(`drop table tb5;`)
-
- // test builtinCastRealAsIntSig
- tk.MustExec(`create table tb5(a double, b float);`)
- tk.MustExec(`insert into tb5 (a, b) values (184467440737095516160, 184467440737095516160);`)
- tk.MustQuery(`select * from tb5 where cast(a as unsigned int)=0;`).Check(testkit.Rows())
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1690 constant 1.844674407370955e+20 overflows bigint"))
- _ = tk.MustQuery(`select * from tb5 where cast(b as unsigned int)=0;`)
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1690 constant 1.844674407370955e+20 overflows bigint"))
- tk.MustExec(`drop table tb5;`)
- tk.MustExec(`create table tb5(a double, b bigint unsigned);`)
- tk.MustExec(`insert into tb5 (a, b) values (18446744073709551616, 18446744073709551615);`)
- _ = tk.MustQuery(`select * from tb5 where cast(a as unsigned int)=b;`)
- // TODO `obtained string = "[18446744073709552000 18446744073709551615]`
- // result.Check(testkit.Rows("18446744073709551616 18446744073709551615"))
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1690 constant 1.8446744073709552e+19 overflows bigint"))
- tk.MustExec(`drop table tb5;`)
-
- // test builtinCastJSONAsIntSig
- tk.MustExec(`create table tb5(a json, b bigint unsigned);`)
- tk.MustExec(`insert into tb5 (a, b) values ('184467440737095516160', 18446744073709551615);`)
- _ = tk.MustQuery(`select * from tb5 where cast(a as unsigned int)=b;`)
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1690 constant 1.844674407370955e+20 overflows bigint"))
- _ = tk.MustQuery(`select * from tb5 where cast(b as unsigned int)=0;`)
- tk.MustQuery("show warnings;").Check(testkit.Rows())
- tk.MustExec(`drop table tb5;`)
- tk.MustExec(`create table tb5(a json, b bigint unsigned);`)
- tk.MustExec(`insert into tb5 (a, b) values ('92233720368547758080', 18446744073709551615);`)
- _ = tk.MustQuery(`select * from tb5 where cast(a as signed int)=b;`)
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1690 constant 9.223372036854776e+19 overflows bigint"))
- tk.MustExec(`drop table tb5;`)
-
- // test builtinCastIntAsStringSig
- tk.MustExec(`create table tb5(a bigint(64) unsigned,b varchar(50));`)
- tk.MustExec(`insert into tb5(a, b) values (9223372036854775808, '9223372036854775808');`)
- tk.MustExec(`insert into tb5(select * from tb5 where a = b);`)
- result = tk.MustQuery(`select * from tb5;`)
- result.Check(testkit.Rows("9223372036854775808 9223372036854775808", "9223372036854775808 9223372036854775808"))
- tk.MustExec(`drop table tb5;`)
-
- // test builtinCastIntAsDecimalSig
- tk.MustExec(`drop table if exists tb5`)
- tk.MustExec(`create table tb5 (a decimal(65), b bigint(64) unsigned);`)
- tk.MustExec(`insert into tb5 (a, b) values (9223372036854775808, 9223372036854775808);`)
- result = tk.MustQuery(`select cast(b as decimal(64)) from tb5 union all select b from tb5;`)
- result.Check(testkit.Rows("9223372036854775808", "9223372036854775808"))
- tk.MustExec(`drop table tb5`)
-
- // test builtinCastIntAsRealSig
- tk.MustExec(`drop table if exists tb5`)
- tk.MustExec(`create table tb5 (a bigint(64) unsigned, b double(64, 10));`)
- tk.MustExec(`insert into tb5 (a, b) values (9223372036854775808, 9223372036854775808);`)
- result = tk.MustQuery(`select a from tb5 where a = b union all select b from tb5;`)
- result.Check(testkit.Rows("9223372036854776000", "9223372036854776000"))
- tk.MustExec(`drop table tb5`)
-
- // Test corner cases of cast string as datetime
- result = tk.MustQuery(`select cast("170102034" as datetime);`)
- result.Check(testkit.Rows("2017-01-02 03:04:00"))
- result = tk.MustQuery(`select cast("1701020304" as datetime);`)
- result.Check(testkit.Rows("2017-01-02 03:04:00"))
- result = tk.MustQuery(`select cast("1701020304." as datetime);`)
- result.Check(testkit.Rows("2017-01-02 03:04:00"))
- result = tk.MustQuery(`select cast("1701020304.1" as datetime);`)
- result.Check(testkit.Rows("2017-01-02 03:04:01"))
- result = tk.MustQuery(`select cast("1701020304.111" as datetime);`)
- result.Check(testkit.Rows("2017-01-02 03:04:11"))
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect datetime value: '1701020304.111'"))
- result = tk.MustQuery(`select cast("17011" as datetime);`)
- result.Check(testkit.Rows("2017-01-01 00:00:00"))
- result = tk.MustQuery(`select cast("150101." as datetime);`)
- result.Check(testkit.Rows("2015-01-01 00:00:00"))
- result = tk.MustQuery(`select cast("150101.a" as datetime);`)
- result.Check(testkit.Rows("2015-01-01 00:00:00"))
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect datetime value: '150101.a'"))
- result = tk.MustQuery(`select cast("150101.1a" as datetime);`)
- result.Check(testkit.Rows("2015-01-01 01:00:00"))
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect datetime value: '150101.1a'"))
- result = tk.MustQuery(`select cast("150101.1a1" as datetime);`)
- result.Check(testkit.Rows("2015-01-01 01:00:00"))
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect datetime value: '150101.1a1'"))
- result = tk.MustQuery(`select cast("1101010101.111" as datetime);`)
- result.Check(testkit.Rows("2011-01-01 01:01:11"))
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect datetime value: '1101010101.111'"))
- result = tk.MustQuery(`select cast("1101010101.11aaaaa" as datetime);`)
- result.Check(testkit.Rows("2011-01-01 01:01:11"))
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect datetime value: '1101010101.11aaaaa'"))
- result = tk.MustQuery(`select cast("1101010101.a1aaaaa" as datetime);`)
- result.Check(testkit.Rows("2011-01-01 01:01:00"))
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect datetime value: '1101010101.a1aaaaa'"))
- result = tk.MustQuery(`select cast("1101010101.11" as datetime);`)
- result.Check(testkit.Rows("2011-01-01 01:01:11"))
- tk.MustQuery("select @@warning_count;").Check(testkit.Rows("0"))
- result = tk.MustQuery(`select cast("1101010101.111" as datetime);`)
- result.Check(testkit.Rows("2011-01-01 01:01:11"))
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect datetime value: '1101010101.111'"))
- result = tk.MustQuery(`select cast("970101.111" as datetime);`)
- result.Check(testkit.Rows("1997-01-01 11:01:00"))
- tk.MustQuery("select @@warning_count;").Check(testkit.Rows("0"))
- result = tk.MustQuery(`select cast("970101.11111" as datetime);`)
- result.Check(testkit.Rows("1997-01-01 11:11:01"))
- tk.MustQuery("select @@warning_count;").Check(testkit.Rows("0"))
- result = tk.MustQuery(`select cast("970101.111a1" as datetime);`)
- result.Check(testkit.Rows("1997-01-01 11:01:00"))
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect datetime value: '970101.111a1'"))
-
- // for ISNULL
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a int, b int, c int, d char(10), e datetime, f float, g decimal(10, 3))")
- tk.MustExec("insert t values (1, 0, null, null, null, null, null)")
- result = tk.MustQuery("select ISNULL(a), ISNULL(b), ISNULL(c), ISNULL(d), ISNULL(e), ISNULL(f), ISNULL(g) from t")
- result.Check(testkit.Rows("0 0 1 1 1 1 1"))
-
- // fix issue #3942
- result = tk.MustQuery("select cast('-24 100:00:00' as time);")
- result.Check(testkit.Rows("-676:00:00"))
- result = tk.MustQuery("select cast('12:00:00.000000' as datetime);")
- result.Check(testkit.Rows("2012-00-00 00:00:00"))
- result = tk.MustQuery("select cast('-34 100:00:00' as time);")
- result.Check(testkit.Rows("-838:59:59"))
-
- // fix issue #4324. cast decimal/int/string to time compatibility.
- invalidTimes := []string{
- "10009010",
- "239010",
- "233070",
- "23:90:10",
- "23:30:70",
- "239010.2",
- "233070.8",
- }
- tk.MustExec("DROP TABLE IF EXISTS t;")
- tk.MustExec("CREATE TABLE t (ix TIME);")
- tk.MustExec("SET SQL_MODE='';")
- for _, invalidTime := range invalidTimes {
- msg := fmt.Sprintf("Warning 1292 Truncated incorrect time value: '%s'", invalidTime)
- result = tk.MustQuery(fmt.Sprintf("select cast('%s' as time);", invalidTime))
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("show warnings")
- result.Check(testkit.Rows(msg))
- _, err := tk.Exec(fmt.Sprintf("insert into t select cast('%s' as time);", invalidTime))
- require.NoError(t, err)
- result = tk.MustQuery("show warnings")
- result.Check(testkit.Rows(msg))
- }
- tk.MustExec("set sql_mode = 'STRICT_TRANS_TABLES'")
- for _, invalidTime := range invalidTimes {
- msg := fmt.Sprintf("Warning 1292 Truncated incorrect time value: '%s'", invalidTime)
- result = tk.MustQuery(fmt.Sprintf("select cast('%s' as time);", invalidTime))
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("show warnings")
- result.Check(testkit.Rows(msg))
- _, err := tk.Exec(fmt.Sprintf("insert into t select cast('%s' as time);", invalidTime))
- require.Error(t, err, fmt.Sprintf("[types:1292]Truncated incorrect time value: '%s'", invalidTime))
- }
-
- // Fix issue #3691, cast compatibility.
- result = tk.MustQuery("select cast('18446744073709551616' as unsigned);")
- result.Check(testkit.Rows("18446744073709551615"))
- result = tk.MustQuery("select cast('18446744073709551616' as signed);")
- result.Check(testkit.Rows("-1"))
- result = tk.MustQuery("select cast('9223372036854775808' as signed);")
- result.Check(testkit.Rows("-9223372036854775808"))
- result = tk.MustQuery("select cast('9223372036854775809' as signed);")
- result.Check(testkit.Rows("-9223372036854775807"))
- result = tk.MustQuery("select cast('9223372036854775807' as signed);")
- result.Check(testkit.Rows("9223372036854775807"))
- result = tk.MustQuery("select cast('18446744073709551615' as signed);")
- result.Check(testkit.Rows("-1"))
- result = tk.MustQuery("select cast('18446744073709551614' as signed);")
- result.Check(testkit.Rows("-2"))
- result = tk.MustQuery("select cast(18446744073709551615 as unsigned);")
- result.Check(testkit.Rows("18446744073709551615"))
- result = tk.MustQuery("select cast(18446744073709551616 as unsigned);")
- result.Check(testkit.Rows("18446744073709551615"))
- result = tk.MustQuery("select cast(18446744073709551616 as signed);")
- result.Check(testkit.Rows("9223372036854775807"))
- result = tk.MustQuery("select cast(18446744073709551617 as signed);")
- result.Check(testkit.Rows("9223372036854775807"))
- result = tk.MustQuery("select cast(18446744073709551615 as signed);")
- result.Check(testkit.Rows("-1"))
- result = tk.MustQuery("select cast(18446744073709551614 as signed);")
- result.Check(testkit.Rows("-2"))
- result = tk.MustQuery("select cast(-18446744073709551616 as signed);")
- result.Check(testkit.Rows("-9223372036854775808"))
- result = tk.MustQuery("select cast(18446744073709551614.9 as unsigned);") // Round up
- result.Check(testkit.Rows("18446744073709551615"))
- result = tk.MustQuery("select cast(18446744073709551614.4 as unsigned);") // Round down
- result.Check(testkit.Rows("18446744073709551614"))
- result = tk.MustQuery("select cast(-9223372036854775809 as signed);")
- result.Check(testkit.Rows("-9223372036854775808"))
- result = tk.MustQuery("select cast(-9223372036854775809 as unsigned);")
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery("select cast(-9223372036854775808 as unsigned);")
- result.Check(testkit.Rows("9223372036854775808"))
- result = tk.MustQuery("select cast('-9223372036854775809' as unsigned);")
- result.Check(testkit.Rows("9223372036854775808"))
- result = tk.MustQuery("select cast('-9223372036854775807' as unsigned);")
- result.Check(testkit.Rows("9223372036854775809"))
- result = tk.MustQuery("select cast('-2' as unsigned);")
- result.Check(testkit.Rows("18446744073709551614"))
- result = tk.MustQuery("select cast(cast(1-2 as unsigned) as signed integer);")
- result.Check(testkit.Rows("-1"))
- result = tk.MustQuery("select cast(1 as signed int)")
- result.Check(testkit.Rows("1"))
-
- // test cast as double
- result = tk.MustQuery("select cast(1 as double)")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(cast(12345 as unsigned) as double)")
- result.Check(testkit.Rows("12345"))
- result = tk.MustQuery("select cast(1.1 as double)")
- result.Check(testkit.Rows("1.1"))
- result = tk.MustQuery("select cast(-1.1 as double)")
- result.Check(testkit.Rows("-1.1"))
- result = tk.MustQuery("select cast('123.321' as double)")
- result.Check(testkit.Rows("123.321"))
- result = tk.MustQuery("select cast('12345678901234567890' as double) = 1.2345678901234567e19")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(-1 as double)")
- result.Check(testkit.Rows("-1"))
- result = tk.MustQuery("select cast(null as double)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select cast(12345678901234567890 as double) = 1.2345678901234567e19")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(cast(-1 as unsigned) as double) = 1.8446744073709552e19")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(1e100 as double) = 1e100")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(123456789012345678901234567890 as double) = 1.2345678901234568e29")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(0x12345678 as double)")
- result.Check(testkit.Rows("305419896"))
-
- // test cast as float
- result = tk.MustQuery("select cast(1 as float)")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(cast(12345 as unsigned) as float)")
- result.Check(testkit.Rows("12345"))
- result = tk.MustQuery("select cast(1.1 as float) = 1.1")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(-1.1 as float) = -1.1")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast('123.321' as float) =123.321")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast('12345678901234567890' as float) = 1.2345678901234567e19")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(-1 as float)")
- result.Check(testkit.Rows("-1"))
- result = tk.MustQuery("select cast(null as float)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select cast(12345678901234567890 as float) = 1.2345678901234567e19")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(cast(-1 as unsigned) as float) = 1.8446744073709552e19")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(1e100 as float(40)) = 1e100")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(123456789012345678901234567890 as float(40)) = 1.2345678901234568e29")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(0x12345678 as float(40)) = 305419896")
- result.Check(testkit.Rows("1"))
-
- // test cast as real
- result = tk.MustQuery("select cast(1 as real)")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(cast(12345 as unsigned) as real)")
- result.Check(testkit.Rows("12345"))
- result = tk.MustQuery("select cast(1.1 as real) = 1.1")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(-1.1 as real) = -1.1")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast('123.321' as real) =123.321")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast('12345678901234567890' as real) = 1.2345678901234567e19")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(-1 as real)")
- result.Check(testkit.Rows("-1"))
- result = tk.MustQuery("select cast(null as real)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select cast(12345678901234567890 as real) = 1.2345678901234567e19")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(cast(-1 as unsigned) as real) = 1.8446744073709552e19")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(1e100 as real) = 1e100")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(123456789012345678901234567890 as real) = 1.2345678901234568e29")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select cast(0x12345678 as real) = 305419896")
- result.Check(testkit.Rows("1"))
-
- // test cast time as decimal overflow
- tk.MustExec("drop table if exists t1")
- tk.MustExec("create table t1(s1 time);")
- tk.MustExec("insert into t1 values('11:11:11');")
- result = tk.MustQuery("select cast(s1 as decimal(7, 2)) from t1;")
- result.Check(testkit.Rows("99999.99"))
- result = tk.MustQuery("select cast(s1 as decimal(8, 2)) from t1;")
- result.Check(testkit.Rows("111111.00"))
- _, err := tk.Exec("insert into t1 values(cast('111111.00' as decimal(7, 2)));")
- require.Error(t, err)
-
- result = tk.MustQuery(`select CAST(0x8fffffffffffffff as signed) a,
- CAST(0xfffffffffffffffe as signed) b,
- CAST(0xffffffffffffffff as unsigned) c;`)
- result.Check(testkit.Rows("-8070450532247928833 -2 18446744073709551615"))
-
- result = tk.MustQuery(`select cast("1:2:3" as TIME) = "1:02:03"`)
- result.Check(testkit.Rows("0"))
-
- // fixed issue #3471
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a time(6));")
- tk.MustExec("insert into t value('12:59:59.999999')")
- result = tk.MustQuery("select cast(a as signed) from t")
- result.Check(testkit.Rows("130000"))
-
- // fixed issue #3762
- result = tk.MustQuery("select -9223372036854775809;")
- result.Check(testkit.Rows("-9223372036854775809"))
- result = tk.MustQuery("select --9223372036854775809;")
- result.Check(testkit.Rows("9223372036854775809"))
- result = tk.MustQuery("select -9223372036854775808;")
- result.Check(testkit.Rows("-9223372036854775808"))
-
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a bigint(30));")
- _, err = tk.Exec("insert into t values(-9223372036854775809)")
- require.Error(t, err)
-
- // test case decimal precision less than the scale.
- _, err = tk.Exec("select cast(12.1 as decimal(3, 4));")
- require.Error(t, err, "[types:1427]For float(M,D), double(M,D) or decimal(M,D), M must be >= D (column '12.1').")
-
- // test case cast(EXPR as datetime(x)) precision more than the scale(6).
- _, err = tk.Exec("SELECT CAST(1 AS DATETIME(7));")
- require.Error(t, err, "[types:1427]Too big precision 7 specified for column 'CAST'. Maximum is 6.")
-
- // test unhex and hex
- result = tk.MustQuery("select unhex('4D7953514C')")
- result.Check(testkit.Rows("MySQL"))
- result = tk.MustQuery("select unhex(hex('string'))")
- result.Check(testkit.Rows("string"))
- result = tk.MustQuery("select unhex('ggg')")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select unhex(-1)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select hex(unhex('1267'))")
- result.Check(testkit.Rows("1267"))
- result = tk.MustQuery("select hex(unhex(1267))")
- result.Check(testkit.Rows("1267"))
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a binary(8))")
- tk.MustExec(`insert into t values('test')`)
- result = tk.MustQuery("select hex(a) from t")
- result.Check(testkit.Rows("7465737400000000"))
- result = tk.MustQuery("select unhex(a) from t")
- result.Check(testkit.Rows(""))
-
- // select from_unixtime
- // NOTE (#17013): make from_unixtime stable in different timezone: the result of from_unixtime
- // depends on the local time zone of the test environment, thus the result checking must
- // consider the time zone convert.
- tz := tk.Session().GetSessionVars().StmtCtx.TimeZone
- result = tk.MustQuery("select from_unixtime(1451606400)")
- unixTime := time.Unix(1451606400, 0).In(tz).String()[:19]
- result.Check(testkit.Rows(unixTime))
- result = tk.MustQuery("select from_unixtime(14516064000/10)")
- result.Check(testkit.Rows(fmt.Sprintf("%s.0000", unixTime)))
- result = tk.MustQuery("select from_unixtime('14516064000'/10)")
- result.Check(testkit.Rows(fmt.Sprintf("%s.000000", unixTime)))
- result = tk.MustQuery("select from_unixtime(cast(1451606400 as double))")
- result.Check(testkit.Rows(fmt.Sprintf("%s.000000", unixTime)))
- result = tk.MustQuery("select from_unixtime(cast(cast(1451606400 as double) as DECIMAL))")
- result.Check(testkit.Rows(unixTime))
- result = tk.MustQuery("select from_unixtime(cast(cast(1451606400 as double) as DECIMAL(65,1)))")
- result.Check(testkit.Rows(fmt.Sprintf("%s.0", unixTime)))
- result = tk.MustQuery("select from_unixtime(1451606400.123456)")
- unixTime = time.Unix(1451606400, 123456000).In(tz).String()[:26]
- result.Check(testkit.Rows(unixTime))
- result = tk.MustQuery("select from_unixtime(1451606400.1234567)")
- unixTime = time.Unix(1451606400, 123456700).In(tz).Round(time.Microsecond).Format("2006-01-02 15:04:05.000000")[:26]
- result.Check(testkit.Rows(unixTime))
- result = tk.MustQuery("select from_unixtime(1451606400.999999)")
- unixTime = time.Unix(1451606400, 999999000).In(tz).String()[:26]
- result.Check(testkit.Rows(unixTime))
- result = tk.MustQuery("select from_unixtime(1511247196661)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select from_unixtime('1451606400.123');")
- unixTime = time.Unix(1451606400, 0).In(tz).String()[:19]
- result.Check(testkit.Rows(fmt.Sprintf("%s.123000", unixTime)))
-
- tk.MustExec("drop table if exists t;")
- tk.MustExec("create table t(a int);")
- tk.MustExec("insert into t value(1451606400);")
- result = tk.MustQuery("select from_unixtime(a) from t;")
- result.Check(testkit.Rows(unixTime))
-
- // test strcmp
- result = tk.MustQuery("select strcmp('abc', 'def')")
- result.Check(testkit.Rows("-1"))
- result = tk.MustQuery("select strcmp('abc', 'aba')")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select strcmp('abc', 'abc')")
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery("select substr(null, 1, 2)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select substr('123', null, 2)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select substr('123', 1, null)")
- result.Check(testkit.Rows(""))
-
- // for case
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a varchar(255), b int)")
- tk.MustExec("insert t values ('str1', 1)")
- result = tk.MustQuery("select * from t where a = case b when 1 then 'str1' when 2 then 'str2' end")
- result.Check(testkit.Rows("str1 1"))
- result = tk.MustQuery("select * from t where a = case b when 1 then 'str2' when 2 then 'str3' end")
- result.Check(nil)
- tk.MustExec("insert t values ('str2', 2)")
- result = tk.MustQuery("select * from t where a = case b when 2 then 'str2' when 3 then 'str3' end")
- result.Check(testkit.Rows("str2 2"))
- tk.MustExec("insert t values ('str3', 3)")
- result = tk.MustQuery("select * from t where a = case b when 4 then 'str4' when 5 then 'str5' else 'str3' end")
- result.Check(testkit.Rows("str3 3"))
- result = tk.MustQuery("select * from t where a = case b when 4 then 'str4' when 5 then 'str5' else 'str6' end")
- result.Check(nil)
- result = tk.MustQuery("select * from t where a = case when b then 'str3' when 1 then 'str1' else 'str2' end")
- result.Check(testkit.Rows("str3 3"))
- tk.MustExec("delete from t")
- tk.MustExec("insert t values ('str2', 0)")
- result = tk.MustQuery("select * from t where a = case when b then 'str3' when 0 then 'str1' else 'str2' end")
- result.Check(testkit.Rows("str2 0"))
- tk.MustExec("insert t values ('str1', null)")
- result = tk.MustQuery("select * from t where a = case b when null then 'str3' when 10 then 'str1' else 'str2' end")
- result.Check(testkit.Rows("str2 0"))
- result = tk.MustQuery("select * from t where a = case null when b then 'str3' when 10 then 'str1' else 'str2' end")
- result.Check(testkit.Rows("str2 0"))
- tk.MustExec("insert t values (null, 4)")
- result = tk.MustQuery("select * from t where b < case a when null then 0 when 'str2' then 0 else 9 end")
- result.Check(testkit.Rows(" 4"))
- result = tk.MustQuery("select * from t where b = case when a is null then 4 when a = 'str5' then 7 else 9 end")
- result.Check(testkit.Rows(" 4"))
- result = tk.MustQuery(`SELECT -Max(+23) * -+Cast(--10 AS SIGNED) * -CASE
- WHEN 0 > 85 THEN NULL
- WHEN NOT
- CASE +55
- WHEN +( +82 ) + -89 * -69 THEN +Count(-88)
- WHEN +CASE 57
- WHEN +89 THEN -89 * Count(*)
- WHEN 17 THEN NULL
- END THEN ( -10 )
- END IS NULL THEN NULL
- ELSE 83 + 48
- END AS col0; `)
- result.Check(testkit.Rows("-30130"))
-
- // return type of case when expr should not include NotNullFlag. issue-23036
- tk.MustExec("drop table if exists t1")
- tk.MustExec("create table t1(c1 int not null)")
- tk.MustExec("insert into t1 values(1)")
- result = tk.MustQuery("select (case when null then c1 end) is null from t1")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select (case when null then c1 end) is not null from t1")
- result.Check(testkit.Rows("0"))
-
- // test warnings
- tk.MustQuery("select case when b=0 then 1 else 1/b end from t")
- tk.MustQuery("show warnings").Check(testkit.Rows())
- tk.MustQuery("select if(b=0, 1, 1/b) from t")
- tk.MustQuery("show warnings").Check(testkit.Rows())
- tk.MustQuery("select ifnull(b, b/0) from t")
- tk.MustQuery("show warnings").Check(testkit.Rows())
-
- tk.MustQuery("select case when 1 then 1 else 1/0 end")
- tk.MustQuery("show warnings").Check(testkit.Rows())
- tk.MustQuery(" select if(1,1,1/0)")
- tk.MustQuery("show warnings").Check(testkit.Rows())
- tk.MustQuery("select ifnull(1, 1/0)")
- tk.MustQuery("show warnings").Check(testkit.Rows())
-
- tk.MustExec("delete from t")
- tk.MustExec("insert t values ('str2', 0)")
- tk.MustQuery("select case when b < 1 then 1 else 1/0 end from t")
- tk.MustQuery("show warnings").Check(testkit.Rows())
- tk.MustQuery("select case when b < 1 then 1 when 1/0 then b else 1/0 end from t")
- tk.MustQuery("show warnings").Check(testkit.Rows())
- tk.MustQuery("select if(b < 1 , 1, 1/0) from t")
- tk.MustQuery("show warnings").Check(testkit.Rows())
- tk.MustQuery("select ifnull(b, 1/0) from t")
- tk.MustQuery("show warnings").Check(testkit.Rows())
- tk.MustQuery("select COALESCE(1, b, b/0) from t")
- tk.MustQuery("show warnings").Check(testkit.Rows())
- tk.MustQuery("select 0 and b/0 from t")
- tk.MustQuery("show warnings").Check(testkit.Rows())
- tk.MustQuery("select 1 or b/0 from t")
- tk.MustQuery("show warnings").Check(testkit.Rows())
-
- tk.MustQuery("select 1 or 1/0")
- tk.MustQuery("show warnings").Check(testkit.Rows())
- tk.MustQuery("select 0 and 1/0")
- tk.MustQuery("show warnings").Check(testkit.Rows())
- tk.MustQuery("select COALESCE(1, 1/0)")
- tk.MustQuery("show warnings").Check(testkit.Rows())
- tk.MustQuery("select interval(1,0,1,2,1/0)")
- tk.MustQuery("show warnings").Check(testkit.Rows())
-
- tk.MustQuery("select case 2.0 when 2.0 then 3.0 when 3.0 then 2.0 end").Check(testkit.Rows("3.0"))
- tk.MustQuery("select case 2.0 when 3.0 then 2.0 when 4.0 then 3.0 else 5.0 end").Check(testkit.Rows("5.0"))
- tk.MustQuery("select case cast('2011-01-01' as date) when cast('2011-01-01' as date) then cast('2011-02-02' as date) end").Check(testkit.Rows("2011-02-02"))
- tk.MustQuery("select case cast('2012-01-01' as date) when cast('2011-01-01' as date) then cast('2011-02-02' as date) else cast('2011-03-03' as date) end").Check(testkit.Rows("2011-03-03"))
- tk.MustQuery("select case cast('10:10:10' as time) when cast('10:10:10' as time) then cast('11:11:11' as time) end").Check(testkit.Rows("11:11:11"))
- tk.MustQuery("select case cast('10:10:13' as time) when cast('10:10:10' as time) then cast('11:11:11' as time) else cast('22:22:22' as time) end").Check(testkit.Rows("22:22:22"))
-
- // for cast
- result = tk.MustQuery("select cast(1234 as char(3))")
- result.Check(testkit.Rows("123"))
- result = tk.MustQuery("select cast(1234 as char(0))")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("show warnings")
- result.Check(testkit.Rows("Warning 1406 Data Too Long, field len 0, data len 4"))
- result = tk.MustQuery("select CAST( - 8 AS DECIMAL ) * + 52 + 87 < - 86")
- result.Check(testkit.Rows("1"))
-
- // for char
- result = tk.MustQuery("select char(97, 100, 256, 89)")
- result.Check(testkit.Rows("ad\x01\x00Y"))
- result = tk.MustQuery("select char(97, null, 100, 256, 89)")
- result.Check(testkit.Rows("ad\x01\x00Y"))
- result = tk.MustQuery("select char(97, null, 100, 256, 89 using utf8)")
- result.Check(testkit.Rows("ad\x01\x00Y"))
- result = tk.MustQuery("select char(97, null, 100, 256, 89 using ascii)")
- result.Check(testkit.Rows("ad\x01\x00Y"))
- err = tk.ExecToErr("select char(97, null, 100, 256, 89 using tidb)")
- require.Error(t, err, "[parser:1115]Unknown character set: 'tidb'")
-
- // issue 3884
- tk.MustExec("drop table if exists t")
- tk.MustExec("CREATE TABLE t (c1 date, c2 datetime, c3 timestamp, c4 time, c5 year);")
- tk.MustExec("INSERT INTO t values ('2000-01-01', '2000-01-01 12:12:12', '2000-01-01 12:12:12', '12:12:12', '2000');")
- tk.MustExec("INSERT INTO t values ('2000-02-01', '2000-02-01 12:12:12', '2000-02-01 12:12:12', '13:12:12', 2000);")
- tk.MustExec("INSERT INTO t values ('2000-03-01', '2000-03-01', '2000-03-01 12:12:12', '1 12:12:12', 2000);")
- tk.MustExec("INSERT INTO t SET c1 = '2000-04-01', c2 = '2000-04-01', c3 = '2000-04-01 12:12:12', c4 = '-1 13:12:12', c5 = 2000;")
- result = tk.MustQuery("SELECT c4 FROM t where c4 < '-13:12:12';")
- result.Check(testkit.Rows("-37:12:12"))
- result = tk.MustQuery(`SELECT 1 DIV - - 28 + ( - SUM( - + 25 ) ) * - CASE - 18 WHEN 44 THEN NULL ELSE - 41 + 32 + + - 70 - + COUNT( - 95 ) * 15 END + 92`)
- result.Check(testkit.Rows("2442"))
-
- // for regexp, rlike
- // https://github.com/pingcap/tidb/issues/4080
- tk.MustExec(`drop table if exists t;`)
- tk.MustExec(`create table t (a char(10), b varchar(10), c binary(10), d varbinary(10));`)
- tk.MustExec(`insert into t values ('text','text','text','text');`)
- result = tk.MustQuery(`select a regexp 'xt' from t;`)
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery(`select b regexp 'xt' from t;`)
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery(`select b regexp binary 'Xt' from t;`)
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery(`select c regexp 'Xt' from t;`)
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery(`select d regexp 'Xt' from t;`)
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery(`select a rlike 'xt' from t;`)
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery(`select a rlike binary 'Xt' from t;`)
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery(`select b rlike 'xt' from t;`)
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery(`select c rlike 'Xt' from t;`)
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery(`select d rlike 'Xt' from t;`)
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery(`select 'a' regexp 'A', 'a' regexp binary 'A'`)
- result.Check(testkit.Rows("0 0"))
-
- // testCase is for like and regexp
- type testCase struct {
- pattern string
- val string
- result int
- }
- patternMatching := func(tk *testkit.TestKit, queryOp string, data []testCase) {
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (a varchar(255), b int)")
- for i, d := range data {
- tk.MustExec(fmt.Sprintf("insert into t values('%s', %d)", d.val, i))
- result = tk.MustQuery(fmt.Sprintf("select * from t where a %s '%s'", queryOp, d.pattern))
- if d.result == 1 {
- rowStr := fmt.Sprintf("%s %d", d.val, i)
- result.Check(testkit.Rows(rowStr))
- } else {
- result.Check(nil)
- }
- tk.MustExec(fmt.Sprintf("delete from t where b = %d", i))
- }
- }
- // for like
- likeTests := []testCase{
- {"a", "a", 1},
- {"a", "b", 0},
- {"aA", "Aa", 0},
- {`aA%`, "aAab", 1},
- {"aA_", "Aaab", 0},
- {"Aa_", "Aab", 1},
- {"", "", 1},
- {"", "a", 0},
- }
- patternMatching(tk, "like", likeTests)
- // for regexp
- likeTests = []testCase{
- {"^$", "a", 0},
- {"a", "a", 1},
- {"a", "b", 0},
- {"aA", "aA", 1},
- {".", "a", 1},
- {"^.$", "ab", 0},
- {"..", "b", 0},
- {".ab", "aab", 1},
- {"ab.", "abcd", 1},
- {".*", "abcd", 1},
- }
- patternMatching(tk, "regexp", likeTests)
-
- // for #9838
- result = tk.MustQuery("select cast(1 as signed) + cast(9223372036854775807 as unsigned);")
- result.Check(testkit.Rows("9223372036854775808"))
- result = tk.MustQuery("select cast(9223372036854775807 as unsigned) + cast(1 as signed);")
- result.Check(testkit.Rows("9223372036854775808"))
- err = tk.QueryToErr("select cast(9223372036854775807 as signed) + cast(9223372036854775809 as unsigned);")
- require.Error(t, err)
- err = tk.QueryToErr("select cast(9223372036854775809 as unsigned) + cast(9223372036854775807 as signed);")
- require.Error(t, err)
- err = tk.QueryToErr("select cast(-9223372036854775807 as signed) + cast(9223372036854775806 as unsigned);")
- require.Error(t, err)
- err = tk.QueryToErr("select cast(9223372036854775806 as unsigned) + cast(-9223372036854775807 as signed);")
- require.Error(t, err)
-
- result = tk.MustQuery(`select 1 / '2007' div 1;`)
- result.Check(testkit.Rows("0"))
-}
-
func TestSetVariables(t *testing.T) {
store := testkit.CreateMockStore(t)
@@ -2535,6 +1300,14 @@ func TestSetVariables(t *testing.T) {
_, err = tk.Exec("set @@global.max_prepared_stmt_count='';")
require.Error(t, err)
require.Error(t, err, variable.ErrWrongTypeForVar.GenWithStackByArgs("max_prepared_stmt_count").Error())
+
+ // Previously, global variable values were cached, which was incorrect.
+ // See: https://github.com/pingcap/tidb/issues/24368
+ tk.MustQuery("SHOW VARIABLES LIKE 'max_connections'").Check(testkit.Rows("max_connections 0"))
+ tk.MustExec("SET GLOBAL max_connections=1234")
+ tk.MustQuery("SHOW VARIABLES LIKE 'max_connections'").Check(testkit.Rows("max_connections 1234"))
+ // restore the default value
+ tk.MustExec("SET GLOBAL max_connections=0")
}
func TestPreparePlanCache(t *testing.T) {
@@ -2694,27 +1467,6 @@ func TestCacheConstEval(t *testing.T) {
tk.MustExec("admin reload expr_pushdown_blacklist")
}
-func TestNullValueRange(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a int, b int, index(a))")
- tk.MustExec("insert into t values (null, 0), (null, 1), (10, 11), (10, 12)")
- tk.MustQuery("select * from t use index(a) where a is null order by b").Check(testkit.Rows(" 0", " 1"))
- tk.MustQuery("select * from t use index(a) where a<=>null order by b").Check(testkit.Rows(" 0", " 1"))
- tk.MustQuery("select * from t use index(a) where a<=>10 order by b").Check(testkit.Rows("10 11", "10 12"))
-
- tk.MustExec("drop table if exists t1")
- tk.MustExec("create table t1(a int, b int, c int, unique key(a, b, c))")
- tk.MustExec("insert into t1 values (1, null, 1), (1, null, 2), (1, null, 3), (1, null, 4)")
- tk.MustExec("insert into t1 values (1, 1, 1), (1, 2, 2), (1, 3, 33), (1, 4, 44)")
- tk.MustQuery("select c from t1 where a=1 and b<=>null and c>2 order by c").Check(testkit.Rows("3", "4"))
- tk.MustQuery("select c from t1 where a=1 and b is null and c>2 order by c").Check(testkit.Rows("3", "4"))
- tk.MustQuery("select c from t1 where a=1 and b is not null and c>2 order by c").Check(testkit.Rows("33", "44"))
-}
-
// issues 14448, 19383, 17734
func TestNoopFunctions(t *testing.T) {
store := testkit.CreateMockStore(t)
@@ -2907,34 +1659,6 @@ PARTITION BY RANGE (c) (
tk.MustExec("set global tidb_enable_local_txn = off;")
}
-func TestPartitionPruningRelaxOP(t *testing.T) {
- // Discovered while looking at issue 19941 (not completely related)
- // relaxOP relax the op > to >= and < to <=
- // Sometime we need to relax the condition, for example:
- // col < const => f(col) <= const
- // datetime < 2020-02-11 16:18:42 => to_days(datetime) <= to_days(2020-02-11)
- // We can't say:
- // datetime < 2020-02-11 16:18:42 => to_days(datetime) < to_days(2020-02-11)
-
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
-
- tk.MustExec("DROP TABLE IF EXISTS t1;")
- tk.MustExec(`CREATE TABLE t1 (d date NOT NULL) PARTITION BY RANGE (YEAR(d))
- (PARTITION p2016 VALUES LESS THAN (2017), PARTITION p2017 VALUES LESS THAN (2018), PARTITION p2018 VALUES LESS THAN (2019),
- PARTITION p2019 VALUES LESS THAN (2020), PARTITION pmax VALUES LESS THAN MAXVALUE)`)
-
- tk.MustExec(`INSERT INTO t1 VALUES ('2016-01-01'), ('2016-06-01'), ('2016-09-01'), ('2017-01-01'),
- ('2017-06-01'), ('2017-09-01'), ('2018-01-01'), ('2018-06-01'), ('2018-09-01'), ('2018-10-01'),
- ('2018-11-01'), ('2018-12-01'), ('2018-12-31'), ('2019-01-01'), ('2019-06-01'), ('2019-09-01'),
- ('2020-01-01'), ('2020-06-01'), ('2020-09-01');`)
-
- tk.MustQuery("SELECT COUNT(*) FROM t1 WHERE d < '2018-01-01'").Check(testkit.Rows("6"))
- tk.MustQuery("SELECT COUNT(*) FROM t1 WHERE d > '2018-01-01'").Check(testkit.Rows("12"))
-}
-
func TestTiDBRowChecksumBuiltin(t *testing.T) {
store := testkit.CreateMockStore(t)
diff --git a/expression/integration_test/BUILD.bazel b/expression/integration_test/BUILD.bazel
index 37a8f62afde99..7305f7f9d7a86 100644
--- a/expression/integration_test/BUILD.bazel
+++ b/expression/integration_test/BUILD.bazel
@@ -8,7 +8,7 @@ go_test(
"main_test.go",
],
flaky = True,
- shard_count = 50,
+ shard_count = 27,
deps = [
"//config",
"//domain",
@@ -31,7 +31,6 @@ go_test(
"//util/codec",
"//util/collate",
"//util/sem",
- "//util/sqlexec",
"//util/timeutil",
"//util/versioninfo",
"@com_github_pingcap_errors//:errors",
diff --git a/expression/integration_test/integration_test.go b/expression/integration_test/integration_test.go
index 6de5266f8a9c9..5901c4c56b15f 100644
--- a/expression/integration_test/integration_test.go
+++ b/expression/integration_test/integration_test.go
@@ -46,63 +46,11 @@ import (
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/sem"
- "github.com/pingcap/tidb/util/sqlexec"
"github.com/pingcap/tidb/util/versioninfo"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
-func TestFuncREPEAT(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
-
- tk.MustExec("USE test")
- tk.MustExec("DROP TABLE IF EXISTS table_string;")
- tk.MustExec("CREATE TABLE table_string(a CHAR(20), b VARCHAR(20), c TINYTEXT, d TEXT(20), e MEDIUMTEXT, f LONGTEXT, g BIGINT);")
- tk.MustExec("INSERT INTO table_string (a, b, c, d, e, f, g) VALUES ('a', 'b', 'c', 'd', 'e', 'f', 2);")
- tk.CheckExecResult(1, 0)
-
- r := tk.MustQuery("SELECT REPEAT(a, g), REPEAT(b, g), REPEAT(c, g), REPEAT(d, g), REPEAT(e, g), REPEAT(f, g) FROM table_string;")
- r.Check(testkit.Rows("aa bb cc dd ee ff"))
-
- r = tk.MustQuery("SELECT REPEAT(NULL, g), REPEAT(NULL, g), REPEAT(NULL, g), REPEAT(NULL, g), REPEAT(NULL, g), REPEAT(NULL, g) FROM table_string;")
- r.Check(testkit.Rows(" "))
-
- r = tk.MustQuery("SELECT REPEAT(a, NULL), REPEAT(b, NULL), REPEAT(c, NULL), REPEAT(d, NULL), REPEAT(e, NULL), REPEAT(f, NULL) FROM table_string;")
- r.Check(testkit.Rows(" "))
-
- r = tk.MustQuery("SELECT REPEAT(a, 2), REPEAT(b, 2), REPEAT(c, 2), REPEAT(d, 2), REPEAT(e, 2), REPEAT(f, 2) FROM table_string;")
- r.Check(testkit.Rows("aa bb cc dd ee ff"))
-
- r = tk.MustQuery("SELECT REPEAT(NULL, 2), REPEAT(NULL, 2), REPEAT(NULL, 2), REPEAT(NULL, 2), REPEAT(NULL, 2), REPEAT(NULL, 2) FROM table_string;")
- r.Check(testkit.Rows(" "))
-
- r = tk.MustQuery("SELECT REPEAT(a, -1), REPEAT(b, -2), REPEAT(c, -2), REPEAT(d, -2), REPEAT(e, -2), REPEAT(f, -2) FROM table_string;")
- r.Check(testkit.Rows(" "))
-
- r = tk.MustQuery("SELECT REPEAT(a, 0), REPEAT(b, 0), REPEAT(c, 0), REPEAT(d, 0), REPEAT(e, 0), REPEAT(f, 0) FROM table_string;")
- r.Check(testkit.Rows(" "))
-
- r = tk.MustQuery("SELECT REPEAT(a, 16777217), REPEAT(b, 16777217), REPEAT(c, 16777217), REPEAT(d, 16777217), REPEAT(e, 16777217), REPEAT(f, 16777217) FROM table_string;")
- r.Check(testkit.Rows(" "))
-}
-
-func TestFuncLpadAndRpad(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
-
- tk.MustExec(`USE test;`)
- tk.MustExec(`DROP TABLE IF EXISTS t;`)
- tk.MustExec(`CREATE TABLE t(a BINARY(10), b CHAR(10));`)
- tk.MustExec(`INSERT INTO t SELECT "中文", "abc";`)
- result := tk.MustQuery(`SELECT LPAD(a, 11, "a"), LPAD(b, 2, "xx") FROM t;`)
- result.Check(testkit.Rows("a中文\x00\x00\x00\x00 ab"))
- result = tk.MustQuery(`SELECT RPAD(a, 11, "a"), RPAD(b, 2, "xx") FROM t;`)
- result.Check(testkit.Rows("中文\x00\x00\x00\x00a ab"))
-}
-
func TestGetLock(t *testing.T) {
ctx := context.Background()
store := testkit.CreateMockStore(t)
@@ -340,472 +288,6 @@ func TestMiscellaneousBuiltin(t *testing.T) {
tk.MustQuery(`SELECT RELEASE_ALL_LOCKS()`).Check(testkit.Rows("0")) // none acquired
}
-func TestConvertToBit(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- tk.MustExec("drop table if exists t, t1")
- tk.MustExec("create table t (a bit(64))")
- tk.MustExec("create table t1 (a varchar(2))")
- tk.MustExec(`insert t1 value ('10')`)
- tk.MustExec(`insert t select a from t1`)
- tk.MustQuery("select a+0 from t").Check(testkit.Rows("12592"))
-
- tk.MustExec("drop table if exists t, t1")
- tk.MustExec("create table t (a bit(64))")
- tk.MustExec("create table t1 (a binary(2))")
- tk.MustExec(`insert t1 value ('10')`)
- tk.MustExec(`insert t select a from t1`)
- tk.MustQuery("select a+0 from t").Check(testkit.Rows("12592"))
-
- tk.MustExec("drop table if exists t, t1")
- tk.MustExec("create table t (a bit(64))")
- tk.MustExec("create table t1 (a datetime)")
- tk.MustExec(`insert t1 value ('09-01-01')`)
- tk.MustExec(`insert t select a from t1`)
- tk.MustQuery("select a+0 from t").Check(testkit.Rows("20090101000000"))
-}
-
-func TestStringBuiltin(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- var err error
-
- // for length
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a int, b double, c datetime, d time, e char(20), f bit(10))")
- tk.MustExec(`insert into t values(1, 1.1, "2017-01-01 12:01:01", "12:01:01", "abcdef", 0b10101)`)
- result := tk.MustQuery("select length(a), length(b), length(c), length(d), length(e), length(f), length(null) from t")
- result.Check(testkit.Rows("1 3 19 8 6 2 "))
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a char(20))")
- tk.MustExec(`insert into t values("tidb "), (concat("a ", "b "))`)
- result = tk.MustQuery("select a, length(a) from t")
- result.Check(testkit.Rows("tidb 4", "a b 4"))
-
- // for concat
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a int, b double, c datetime, d time, e char(20))")
- tk.MustExec(`insert into t values(1, 1.1, "2017-01-01 12:01:01", "12:01:01", "abcdef")`)
- result = tk.MustQuery("select concat(a, b, c, d, e) from t")
- result.Check(testkit.Rows("11.12017-01-01 12:01:0112:01:01abcdef"))
- result = tk.MustQuery("select concat(null)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select concat(null, a, b) from t")
- result.Check(testkit.Rows(""))
-
- // for concat_ws
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a int, b double, c datetime, d time, e char(20))")
- tk.MustExec(`insert into t values(1, 1.1, "2017-01-01 12:01:01", "12:01:01", "abcdef")`)
- result = tk.MustQuery("select concat_ws('|', a, b, c, d, e) from t")
- result.Check(testkit.Rows("1|1.1|2017-01-01 12:01:01|12:01:01|abcdef"))
- result = tk.MustQuery("select concat_ws(null, null)")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select concat_ws(null, a, b) from t")
- result.Check(testkit.Rows(""))
- result = tk.MustQuery("select concat_ws(',', 'a', 'b')")
- result.Check(testkit.Rows("a,b"))
- result = tk.MustQuery("select concat_ws(',','First name',NULL,'Last Name')")
- result.Check(testkit.Rows("First name,Last Name"))
-
- tk.MustExec(`drop table if exists t;`)
- tk.MustExec(`create table t(a tinyint(2), b varchar(10));`)
- tk.MustExec(`insert into t values (1, 'a'), (12, 'a'), (126, 'a'), (127, 'a')`)
- tk.MustQuery(`select concat_ws('#', a, b) from t;`).Check(testkit.Rows(
- `1#a`,
- `12#a`,
- `126#a`,
- `127#a`,
- ))
-
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a binary(3))")
- tk.MustExec("insert into t values('a')")
- result = tk.MustQuery(`select concat_ws(',', a, 'test') = 'a\0\0,test' from t`)
- result.Check(testkit.Rows("1"))
-
- // for ascii
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4))")
- tk.MustExec(`insert into t values('2', 2, 2.3, "2017-01-01 12:01:01", "12:01:01", 0b1010)`)
- result = tk.MustQuery("select ascii(a), ascii(b), ascii(c), ascii(d), ascii(e), ascii(f) from t")
- result.Check(testkit.Rows("50 50 50 50 49 10"))
- result = tk.MustQuery("select ascii('123'), ascii(123), ascii(''), ascii('你好'), ascii(NULL)")
- result.Check(testkit.Rows("49 49 0 228 "))
-
- // for lower
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a int, b double, c datetime, d time, e char(20), f binary(3), g binary(3))")
- tk.MustExec(`insert into t values(1, 1.1, "2017-01-01 12:01:01", "12:01:01", "abcdef", 'aa', 'BB')`)
- result = tk.MustQuery("select lower(a), lower(b), lower(c), lower(d), lower(e), lower(f), lower(g), lower(null) from t")
- result.Check(testkit.Rows("1 1.1 2017-01-01 12:01:01 12:01:01 abcdef aa\x00 BB\x00 "))
-
- // for upper
- result = tk.MustQuery("select upper(a), upper(b), upper(c), upper(d), upper(e), upper(f), upper(g), upper(null) from t")
- result.Check(testkit.Rows("1 1.1 2017-01-01 12:01:01 12:01:01 ABCDEF aa\x00 BB\x00 "))
-
- // for strcmp
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time)")
- tk.MustExec(`insert into t values("123", 123, 12.34, "2017-01-01 12:01:01", "12:01:01")`)
- result = tk.MustQuery(`select strcmp(a, "123"), strcmp(b, "123"), strcmp(c, "12.34"), strcmp(d, "2017-01-01 12:01:01"), strcmp(e, "12:01:01") from t`)
- result.Check(testkit.Rows("0 0 0 0 0"))
- result = tk.MustQuery(`select strcmp("1", "123"), strcmp("123", "1"), strcmp("123", "45"), strcmp("123", null), strcmp(null, "123")`)
- result.Check(testkit.Rows("-1 1 -1 "))
- result = tk.MustQuery(`select strcmp("", "123"), strcmp("123", ""), strcmp("", ""), strcmp("", null), strcmp(null, "")`)
- result.Check(testkit.Rows("-1 1 0 "))
-
- // for left
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time)")
- tk.MustExec(`insert into t values('abcde', 1234, 12.34, "2017-01-01 12:01:01", "12:01:01")`)
- result = tk.MustQuery("select left(a, 2), left(b, 2), left(c, 2), left(d, 2), left(e, 2) from t")
- result.Check(testkit.Rows("ab 12 12 20 12"))
- result = tk.MustQuery(`select left("abc", 0), left("abc", -1), left(NULL, 1), left("abc", NULL)`)
- result.Check(testkit.Rows(" "))
- result = tk.MustQuery(`select left("abc", "a"), left("abc", 1.9), left("abc", 1.2)`)
- result.Check(testkit.Rows(" ab a"))
- result = tk.MustQuery(`select left("中文abc", 2), left("中文abc", 3), left("中文abc", 4)`)
- result.Check(testkit.Rows("中文 中文a 中文ab"))
- // for right, reuse the table created for left
- result = tk.MustQuery("select right(a, 3), right(b, 3), right(c, 3), right(d, 3), right(e, 3) from t")
- result.Check(testkit.Rows("cde 234 .34 :01 :01"))
- result = tk.MustQuery(`select right("abcde", 0), right("abcde", -1), right("abcde", 100), right(NULL, 1), right("abcde", NULL)`)
- result.Check(testkit.Rows(" abcde "))
- result = tk.MustQuery(`select right("abcde", "a"), right("abcde", 1.9), right("abcde", 1.2)`)
- result.Check(testkit.Rows(" de e"))
- result = tk.MustQuery(`select right("中文abc", 2), right("中文abc", 4), right("中文abc", 5)`)
- result.Check(testkit.Rows("bc 文abc 中文abc"))
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a binary(10))")
- tk.MustExec(`insert into t select "中文abc"`)
- result = tk.MustQuery(`select left(a, 3), left(a, 6), left(a, 7) from t`)
- result.Check(testkit.Rows("中 中文 中文a"))
- result = tk.MustQuery(`select right(a, 2), right(a, 7) from t`)
- result.Check(testkit.Rows("c\x00 文abc\x00"))
-
- // for ord
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time, f bit(4), g binary(20), h blob(10), i text(30))")
- tk.MustExec(`insert into t values('2', 2, 2.3, "2017-01-01 12:01:01", "12:01:01", 0b1010, "512", "48", "tidb")`)
- result = tk.MustQuery("select ord(a), ord(b), ord(c), ord(d), ord(e), ord(f), ord(g), ord(h), ord(i) from t")
- result.Check(testkit.Rows("50 50 50 50 49 10 53 52 116"))
- result = tk.MustQuery("select ord('123'), ord(123), ord(''), ord('你好'), ord(NULL), ord('👍')")
- result.Check(testkit.Rows("49 49 0 14990752 4036989325"))
- result = tk.MustQuery("select ord(X''), ord(X'6161'), ord(X'e4bd'), ord(X'e4bda0'), ord(_ascii'你'), ord(_latin1'你')")
- result.Check(testkit.Rows("0 97 228 228 228 228"))
-
- // for space
- result = tk.MustQuery(`select space(0), space(2), space(-1), space(1.1), space(1.9)`)
- result.Check(testkit.RowsWithSep(",", ", ,, , "))
- result = tk.MustQuery(`select space("abc"), space("2"), space("1.1"), space(''), space(null)`)
- result.Check(testkit.RowsWithSep(",", ", , ,,"))
-
- // for replace
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a char(20), b int, c double, d datetime, e time)")
- tk.MustExec(`insert into t values('www.mysql.com', 1234, 12.34, "2017-01-01 12:01:01", "12:01:01")`)
- result = tk.MustQuery(`select replace(a, 'mysql', 'pingcap'), replace(b, 2, 55), replace(c, 34, 0), replace(d, '-', '/'), replace(e, '01', '22') from t`)
- result.Check(testkit.RowsWithSep(",", "www.pingcap.com,15534,12.0,2017/01/01 12:01:01,12:22:22"))
- result = tk.MustQuery(`select replace('aaa', 'a', ''), replace(null, 'a', 'b'), replace('a', null, 'b'), replace('a', 'b', null)`)
- result.Check(testkit.Rows(" "))
-
- // for tobase64
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a int, b double, c datetime, d time, e char(20), f bit(10), g binary(20), h blob(10))")
- tk.MustExec(`insert into t values(1, 1.1, "2017-01-01 12:01:01", "12:01:01", "abcdef", 0b10101, "512", "abc")`)
- result = tk.MustQuery("select to_base64(a), to_base64(b), to_base64(c), to_base64(d), to_base64(e), to_base64(f), to_base64(g), to_base64(h), to_base64(null) from t")
- result.Check(testkit.Rows("MQ== MS4x MjAxNy0wMS0wMSAxMjowMTowMQ== MTI6MDE6MDE= YWJjZGVm ABU= NTEyAAAAAAAAAAAAAAAAAAAAAAA= YWJj "))
-
- // for from_base64
- result = tk.MustQuery(`select from_base64("abcd"), from_base64("asc")`)
- result.Check(testkit.Rows("i\xb7\x1d "))
- result = tk.MustQuery(`select from_base64("MQ=="), from_base64(1234)`)
- result.Check(testkit.Rows("1 \xd7m\xf8"))
-
- // for substr
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a char(10), b int, c double, d datetime, e time)")
- tk.MustExec(`insert into t values('Sakila', 12345, 123.45, "2017-01-01 12:01:01", "12:01:01")`)
- result = tk.MustQuery(`select substr(a, 3), substr(b, 2, 3), substr(c, -3), substr(d, -8), substr(e, -3, 100) from t`)
- result.Check(testkit.Rows("kila 234 .45 12:01:01 :01"))
- result = tk.MustQuery(`select substr('Sakila', 100), substr('Sakila', -100), substr('Sakila', -5, 3), substr('Sakila', 2, -1)`)
- result.Check(testkit.RowsWithSep(",", ",,aki,"))
- result = tk.MustQuery(`select substr('foobarbar' from 4), substr('Sakila' from -4 for 2)`)
- result.Check(testkit.Rows("barbar ki"))
- result = tk.MustQuery(`select substr(null, 2, 3), substr('foo', null, 3), substr('foo', 2, null)`)
- result.Check(testkit.Rows(" "))
- result = tk.MustQuery(`select substr('中文abc', 2), substr('中文abc', 3), substr("中文abc", 1, 2)`)
- result.Check(testkit.Rows("文abc abc 中文"))
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a binary(10))")
- tk.MustExec(`insert into t select "中文abc"`)
- result = tk.MustQuery(`select substr(a, 4), substr(a, 1, 3), substr(a, 1, 6) from t`)
- result.Check(testkit.Rows("文abc\x00 中 中文"))
- result = tk.MustQuery(`select substr("string", -1), substr("string", -2), substr("中文", -1), substr("中文", -2) from t`)
- result.Check(testkit.Rows("g ng 文 中文"))
-
- // for bit_length
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a int, b double, c datetime, d time, e char(20), f bit(10), g binary(20), h varbinary(20))")
- tk.MustExec(`insert into t values(1, 1.1, "2017-01-01 12:01:01", "12:01:01", "abcdef", 0b10101, "g", "h")`)
- result = tk.MustQuery("select bit_length(a), bit_length(b), bit_length(c), bit_length(d), bit_length(e), bit_length(f), bit_length(g), bit_length(h), bit_length(null) from t")
- result.Check(testkit.Rows("8 24 152 64 48 16 160 8 "))
-
- // for substring_index
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a char(20), b int, c double, d datetime, e time)")
- tk.MustExec(`insert into t values('www.pingcap.com', 12345, 123.45, "2017-01-01 12:01:01", "12:01:01")`)
- result = tk.MustQuery(`select substring_index(a, '.', 2), substring_index(b, '.', 2), substring_index(c, '.', -1), substring_index(d, '-', 1), substring_index(e, ':', -2) from t`)
- result.Check(testkit.Rows("www.pingcap 12345 45 2017 01:01"))
- result = tk.MustQuery(`select substring_index('www.pingcap.com', '.', 0), substring_index('www.pingcap.com', '.', 100), substring_index('www.pingcap.com', '.', -100)`)
- result.Check(testkit.Rows(" www.pingcap.com www.pingcap.com"))
- result = tk.MustQuery(`select substring_index('www.pingcap.com', 'd', 1), substring_index('www.pingcap.com', '', 1), substring_index('', '.', 1)`)
- result.Check(testkit.RowsWithSep(",", "www.pingcap.com,,"))
- result = tk.MustQuery(`select substring_index(null, '.', 1), substring_index('www.pingcap.com', null, 1), substring_index('www.pingcap.com', '.', null)`)
- result.Check(testkit.Rows(" "))
-
- // for substring_index with overflow
- tk.MustQuery(`select substring_index('xyz', 'abc', 9223372036854775808)`).Check(testkit.Rows(`xyz`))
- tk.MustQuery(`select substring_index("aaa.bbb.ccc.ddd.eee",'.',18446744073709551613);`).Check(testkit.Rows(`aaa.bbb.ccc.ddd.eee`))
- tk.MustQuery(`select substring_index("aaa.bbb.ccc.ddd.eee",'.',-18446744073709551613);`).Check(testkit.Rows(`aaa.bbb.ccc.ddd.eee`))
- tk.MustQuery(`select substring_index('aaa.bbb.ccc.ddd.eee', '.', 18446744073709551615 - 1 + id) from (select 1 as id) as t1`).Check(testkit.Rows(`aaa.bbb.ccc.ddd.eee`))
- tk.MustQuery(`select substring_index('aaa.bbb.ccc.ddd.eee', '.', -18446744073709551615 - 1 + id) from (select 1 as id) as t1`).Check(testkit.Rows(`aaa.bbb.ccc.ddd.eee`))
-
- tk.MustExec("set tidb_enable_vectorized_expression = 0;")
- tk.MustQuery(`select substring_index("aaa.bbb.ccc.ddd.eee",'.',18446744073709551613);`).Check(testkit.Rows(`aaa.bbb.ccc.ddd.eee`))
- tk.MustQuery(`select substring_index("aaa.bbb.ccc.ddd.eee",'.',-18446744073709551613);`).Check(testkit.Rows(`aaa.bbb.ccc.ddd.eee`))
- tk.MustQuery(`select substring_index('aaa.bbb.ccc.ddd.eee', '.', 18446744073709551615 - 1 + id) from (select 1 as id) as t1`).Check(testkit.Rows(`aaa.bbb.ccc.ddd.eee`))
- tk.MustQuery(`select substring_index('aaa.bbb.ccc.ddd.eee', '.', -18446744073709551615 - 1 + id) from (select 1 as id) as t1`).Check(testkit.Rows(`aaa.bbb.ccc.ddd.eee`))
- tk.MustExec("set tidb_enable_vectorized_expression = 1;")
-
- // for hex
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a char(20), b int, c double, d datetime, e time, f decimal(5, 2), g bit(4))")
- tk.MustExec(`insert into t values('www.pingcap.com', 12345, 123.45, "2017-01-01 12:01:01", "12:01:01", 123.45, 0b1100)`)
- result = tk.MustQuery(`select hex(a), hex(b), hex(c), hex(d), hex(e), hex(f), hex(g) from t`)
- result.Check(testkit.Rows("7777772E70696E676361702E636F6D 3039 7B 323031372D30312D30312031323A30313A3031 31323A30313A3031 7B C"))
- result = tk.MustQuery(`select hex('abc'), hex('你好'), hex(12), hex(12.3), hex(12.8)`)
- result.Check(testkit.Rows("616263 E4BDA0E5A5BD C C D"))
- result = tk.MustQuery(`select hex(-1), hex(-12.3), hex(-12.8), hex(0x12), hex(null)`)
- result.Check(testkit.Rows("FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFF4 FFFFFFFFFFFFFFF3 12 "))
- tk.MustExec("drop table if exists t")
- tk.MustExec("CREATE TABLE t(i int primary key auto_increment, a binary, b binary(0), c binary(20), d binary(255)) character set utf8 collate utf8_bin;")
- tk.MustExec("insert into t(a, b, c, d) values ('a', NULL, 'a','a');")
- tk.MustQuery("select i, hex(a), hex(b), hex(c), hex(d) from t;").Check(testkit.Rows("1 61 6100000000000000000000000000000000000000 610000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"))
-
- // for unhex
- result = tk.MustQuery(`select unhex('4D7953514C'), unhex('313233'), unhex(313233), unhex('')`)
- result.Check(testkit.Rows("MySQL 123 123 "))
- result = tk.MustQuery(`select unhex('string'), unhex('你好'), unhex(123.4), unhex(null)`)
- result.Check(testkit.Rows(" "))
-
- // for ltrim and rtrim
- result = tk.MustQuery(`select ltrim(' bar '), ltrim('bar'), ltrim(''), ltrim(null)`)
- result.Check(testkit.RowsWithSep(",", "bar ,bar,,"))
- result = tk.MustQuery(`select rtrim(' bar '), rtrim('bar'), rtrim(''), rtrim(null)`)
- result.Check(testkit.RowsWithSep(",", " bar,bar,,"))
- result = tk.MustQuery(`select ltrim("\t bar "), ltrim(" \tbar"), ltrim("\n bar"), ltrim("\r bar")`)
- result.Check(testkit.RowsWithSep(",", "\t bar ,\tbar,\n bar,\r bar"))
- result = tk.MustQuery(`select rtrim(" bar \t"), rtrim("bar\t "), rtrim("bar \n"), rtrim("bar \r")`)
- result.Check(testkit.RowsWithSep(",", " bar \t,bar\t,bar \n,bar \r"))
-
- // for reverse
- tk.MustExec(`DROP TABLE IF EXISTS t;`)
- tk.MustExec(`CREATE TABLE t(a BINARY(6));`)
- tk.MustExec(`INSERT INTO t VALUES("中文");`)
- result = tk.MustQuery(`SELECT a, REVERSE(a), REVERSE("中文"), REVERSE("123 ") FROM t;`)
- result.Check(testkit.Rows("中文 \x87\x96歸\xe4 文中 321"))
- result = tk.MustQuery(`SELECT REVERSE(123), REVERSE(12.09) FROM t;`)
- result.Check(testkit.Rows("321 90.21"))
-
- // for trim
- result = tk.MustQuery(`select trim(' bar '), trim(leading 'x' from 'xxxbarxxx'), trim(trailing 'xyz' from 'barxxyz'), trim(both 'x' from 'xxxbarxxx')`)
- result.Check(testkit.Rows("bar barxxx barx bar"))
- result = tk.MustQuery(`select trim('\t bar\n '), trim(' \rbar \t')`)
- result.Check(testkit.RowsWithSep(",", "\t bar\n,\rbar \t"))
- result = tk.MustQuery(`select trim(leading from ' bar'), trim('x' from 'xxxbarxxx'), trim('x' from 'bar'), trim('' from ' bar ')`)
- result.Check(testkit.RowsWithSep(",", "bar,bar,bar, bar "))
- result = tk.MustQuery(`select trim(''), trim('x' from '')`)
- result.Check(testkit.RowsWithSep(",", ","))
- result = tk.MustQuery(`select trim(null from 'bar'), trim('x' from null), trim(null), trim(leading null from 'bar')`)
- result.Check(testkit.Rows(" "))
-
- // for locate
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a char(20), b int, c double, d datetime, e time, f binary(5))")
- tk.MustExec(`insert into t values('www.pingcap.com', 12345, 123.45, "2017-01-01 12:01:01", "12:01:01", "HelLo")`)
- result = tk.MustQuery(`select locate(".ping", a), locate(".ping", a, 5) from t`)
- result.Check(testkit.Rows("4 0"))
- result = tk.MustQuery(`select locate("234", b), locate("235", b, 10) from t`)
- result.Check(testkit.Rows("2 0"))
- result = tk.MustQuery(`select locate(".45", c), locate(".35", b) from t`)
- result.Check(testkit.Rows("4 0"))
- result = tk.MustQuery(`select locate("El", f), locate("ll", f), locate("lL", f), locate("Lo", f), locate("lo", f) from t`)
- result.Check(testkit.Rows("0 0 3 4 0"))
- result = tk.MustQuery(`select locate("01 12", d) from t`)
- result.Check(testkit.Rows("9"))
- result = tk.MustQuery(`select locate("文", "中文字符串", 2)`)
- result.Check(testkit.Rows("2"))
- result = tk.MustQuery(`select locate("文", "中文字符串", 3)`)
- result.Check(testkit.Rows("0"))
- result = tk.MustQuery(`select locate("文", "中文字符串")`)
- result.Check(testkit.Rows("2"))
-
- // for bin
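- // bin() also formats negative numbers through their unsigned 64-bit representation (64 ones for -1); non-numeric strings convert to 0 with a truncation warning.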
- result = tk.MustQuery(`select bin(-1);`)
- result.Check(testkit.Rows("1111111111111111111111111111111111111111111111111111111111111111"))
- result = tk.MustQuery(`select bin(5);`)
- result.Check(testkit.Rows("101"))
- result = tk.MustQuery(`select bin("中文");`)
- result.Check(testkit.Rows("0"))
-
- // for character_length
- result = tk.MustQuery(`select character_length(null), character_length("Hello"), character_length("a中b文c"),
- character_length(123), character_length(12.3456);`)
- result.Check(testkit.Rows(" 5 5 3 7"))
-
- // for char_length
- result = tk.MustQuery(`select char_length(null), char_length("Hello"), char_length("a中b文c"), char_length(123),char_length(12.3456);`)
- result.Check(testkit.Rows(" 5 5 3 7"))
- result = tk.MustQuery(`select char_length(null), char_length("Hello"), char_length("a 中 b 文 c"), char_length("НОЧЬ НА ОКРАИНЕ МОСКВЫ");`)
- result.Check(testkit.Rows(" 5 9 22"))
- // for char_length, binary string type
- result = tk.MustQuery(`select char_length(null), char_length(binary("Hello")), char_length(binary("a 中 b 文 c")), char_length(binary("НОЧЬ НА ОКРАИНЕ МОСКВЫ"));`)
- result.Check(testkit.Rows(" 5 13 41"))
-
- // for elt
- result = tk.MustQuery(`select elt(0, "abc", "def"), elt(2, "hello", "中文", "tidb"), elt(4, "hello", "中文",
- "tidb");`)
- result.Check(testkit.Rows(" 中文 "))
-
- // for instr
- result = tk.MustQuery(`select instr("中国", "国"), instr("中国", ""), instr("abc", ""), instr("", ""), instr("", "abc");`)
- result.Check(testkit.Rows("2 1 1 1 0"))
- result = tk.MustQuery(`select instr("中国", null), instr(null, ""), instr(null, null);`)
- result.Check(testkit.Rows(" "))
- tk.MustExec(`drop table if exists t;`)
- tk.MustExec(`create table t(a binary(20), b char(20));`)
- tk.MustExec(`insert into t values("中国", cast("国" as binary)), ("中国", ""), ("abc", ""), ("", ""), ("", "abc");`)
- result = tk.MustQuery(`select instr(a, b) from t;`)
- result.Check(testkit.Rows("4", "1", "1", "1", "0"))
-
- // for oct
- result = tk.MustQuery(`select oct("aaaa"), oct("-1.9"), oct("-9999999999999999999999999"), oct("9999999999999999999999999");`)
- result.Check(testkit.Rows("0 1777777777777777777777 1777777777777777777777 1777777777777777777777"))
- result = tk.MustQuery(`select oct(-1.9), oct(1.9), oct(-1), oct(1), oct(-9999999999999999999999999), oct(9999999999999999999999999);`)
- result.Check(testkit.Rows("1777777777777777777777 1 1777777777777777777777 1 1777777777777777777777 1777777777777777777777"))
-
- // for find_in_set
- result = tk.MustQuery(`select find_in_set("", ""), find_in_set("", ","), find_in_set("中文", "字符串,中文"), find_in_set("b,", "a,b,c,d");`)
- result.Check(testkit.Rows("0 1 2 0"))
- result = tk.MustQuery(`select find_in_set(NULL, ""), find_in_set("", NULL), find_in_set(1, "2,3,1");`)
- result.Check(testkit.Rows(" 3"))
-
- // for make_set
- result = tk.MustQuery(`select make_set(0, "12"), make_set(3, "aa", "11"), make_set(3, NULL, "中文"), make_set(NULL, "aa");`)
- result.Check(testkit.Rows(" aa,11 中文 "))
-
- // for quote
- result = tk.MustQuery(`select quote("aaaa"), quote(""), quote("\"\""), quote("\n\n");`)
- result.Check(testkit.Rows("'aaaa' '' '\"\"' '\n\n'"))
- result = tk.MustQuery(`select quote(0121), quote(0000), quote("中文"), quote(NULL);`)
- result.Check(testkit.Rows("'121' '0' '中文' NULL"))
- tk.MustQuery(`select quote(null) is NULL;`).Check(testkit.Rows(`0`))
- tk.MustQuery(`select quote(null) is NOT NULL;`).Check(testkit.Rows(`1`))
- tk.MustQuery(`select length(quote(null));`).Check(testkit.Rows(`4`))
- tk.MustQuery(`select quote(null) REGEXP binary 'null'`).Check(testkit.Rows(`0`))
- tk.MustQuery(`select quote(null) REGEXP binary 'NULL'`).Check(testkit.Rows(`1`))
- tk.MustQuery(`select quote(null) REGEXP 'NULL'`).Check(testkit.Rows(`1`))
- tk.MustQuery(`select quote(null) REGEXP 'null'`).Check(testkit.Rows(`0`))
-
- // for convert
- result = tk.MustQuery(`select convert("123" using "binary"), convert("中文" using "binary"), convert("中文" using "utf8"), convert("中文" using "utf8mb4"), convert(cast("中文" as binary) using "utf8");`)
- result.Check(testkit.Rows("123 中文 中文 中文 中文"))
- // charset 866 does not have a default collation configured currently, so this will return an error.
- err = tk.ExecToErr(`select convert("123" using "866");`)
- require.Error(t, err, "[parser:1115]Unknown character set: '866'")
-
- // for insert
- result = tk.MustQuery(`select insert("中文", 1, 1, cast("aaa" as binary)), insert("ba", -1, 1, "aaa"), insert("ba", 1, 100, "aaa"), insert("ba", 100, 1, "aaa");`)
- result.Check(testkit.Rows("aaa\xb8\xad文 ba aaa ba"))
- result = tk.MustQuery(`select insert("bb", NULL, 1, "aa"), insert("bb", 1, NULL, "aa"), insert(NULL, 1, 1, "aaa"), insert("bb", 1, 1, NULL);`)
- result.Check(testkit.Rows(" "))
- result = tk.MustQuery(`SELECT INSERT("bb", 0, 1, NULL), INSERT("bb", 0, NULL, "aaa");`)
- result.Check(testkit.Rows(" "))
- result = tk.MustQuery(`SELECT INSERT("中文", 0, 1, NULL), INSERT("中文", 0, NULL, "aaa");`)
- result.Check(testkit.Rows(" "))
-
- // for export_set
- result = tk.MustQuery(`select export_set(7, "1", "0", ",", 65);`)
- result.Check(testkit.Rows("1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0"))
- result = tk.MustQuery(`select export_set(7, "1", "0", ",", -1);`)
- result.Check(testkit.Rows("1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0"))
- result = tk.MustQuery(`select export_set(7, "1", "0", ",");`)
- result.Check(testkit.Rows("1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0"))
- result = tk.MustQuery(`select export_set(7, "1", "0");`)
- result.Check(testkit.Rows("1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0"))
- result = tk.MustQuery(`select export_set(NULL, "1", "0", ",", 65);`)
- result.Check(testkit.Rows(""))
- result = tk.MustQuery(`select export_set(7, "1", "0", ",", 1);`)
- result.Check(testkit.Rows("1"))
-
- // for format
- result = tk.MustQuery(`select format(12332.1, 4), format(12332.2, 0), format(12332.2, 2,'en_US');`)
- result.Check(testkit.Rows("12,332.1000 12,332 12,332.20"))
- result = tk.MustQuery(`select format(NULL, 4), format(12332.2, NULL);`)
- result.Check(testkit.Rows(" "))
- result = tk.MustQuery(`select format(12332.2, 2,'es_EC');`)
- result.Check(testkit.Rows("12,332.20"))
- tk.MustQuery("show warnings").Check(testkit.Rows("Warning 1649 Unknown locale: 'es_EC'"))
-
- // for field
- result = tk.MustQuery(`select field(1, 2, 1), field(1, 0, NULL), field(1, NULL, 2, 1), field(NULL, 1, 2, NULL);`)
- result.Check(testkit.Rows("2 0 3 0"))
- result = tk.MustQuery(`select field("1", 2, 1), field(1, "0", NULL), field("1", NULL, 2, 1), field(NULL, 1, "2", NULL);`)
- result.Check(testkit.Rows("2 0 3 0"))
- result = tk.MustQuery(`select field("1", 2, 1), field(1, "abc", NULL), field("1", NULL, 2, 1), field(NULL, 1, "2", NULL);`)
- result.Check(testkit.Rows("2 0 3 0"))
- result = tk.MustQuery(`select field("abc", "a", 1), field(1.3, "1.3", 1.5);`)
- result.Check(testkit.Rows("1 1"))
-
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a decimal(11, 8), b decimal(11,8))")
- tk.MustExec("insert into t values('114.57011441','38.04620115'), ('-38.04620119', '38.04620115');")
- result = tk.MustQuery("select a,b,concat_ws(',',a,b) from t")
- result.Check(testkit.Rows("114.57011441 38.04620115 114.57011441,38.04620115",
- "-38.04620119 38.04620115 -38.04620119,38.04620115"))
-
- // issue 44359
- tk.MustExec("drop table if exists t1")
- tk.MustExec("CREATE TABLE t1 (c1 INT UNSIGNED NOT NULL )")
- tk.MustExec("INSERT INTO t1 VALUES (0)")
- tk.MustQuery("SELECT c1 FROM t1 WHERE c1 <> CAST(POW(-'0', 1) AS BINARY)").Check(testkit.Rows())
- tk.MustQuery("SELECT c1 FROM t1 WHERE c1 = CAST('-000' AS BINARY)").Check(testkit.Rows("0"))
-}
-
-func TestInvalidStrings(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
-
- // Test convert invalid string.
- tk.MustExec("drop table if exists t;")
- tk.MustExec("create table t (a binary(5));")
- tk.MustExec("insert into t values (0x1e240), ('ABCDE');")
- tk.MustExec("set tidb_enable_vectorized_expression = on;")
- tk.MustQuery("select convert(t.a using utf8) from t;").Check(testkit.Rows("", "ABCDE"))
- tk.MustQuery("select convert(0x1e240 using utf8);").Check(testkit.Rows(""))
- tk.MustExec("set tidb_enable_vectorized_expression = off;")
- tk.MustQuery("select convert(t.a using utf8) from t;").Check(testkit.Rows("", "ABCDE"))
- tk.MustQuery("select convert(0x1e240 using utf8);").Check(testkit.Rows(""))
-}
-
func TestEncryptionBuiltin(t *testing.T) {
store := testkit.CreateMockStore(t)
@@ -1008,101 +490,6 @@ func TestEncryptionBuiltin(t *testing.T) {
tk.MustQuery("SELECT VALIDATE_PASSWORD_STRENGTH(CAST(0xd2 AS BINARY(10)))").Check(testkit.Rows("50"))
}
-func TestOpBuiltin(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
-
- // for logicAnd
- result := tk.MustQuery("select 1 && 1, 1 && 0, 0 && 1, 0 && 0, 2 && -1, null && 1, '1a' && 'a'")
- result.Check(testkit.Rows("1 0 0 0 1 0"))
- // for bitNeg
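- // Bit operators work on unsigned 64-bit integers, so ~123 wraps to 2^64-124, and a NULL operand yields NULL.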
- result = tk.MustQuery("select ~123, ~-123, ~null")
- result.Check(testkit.Rows("18446744073709551492 122 "))
- // for logicNot
- result = tk.MustQuery("select !1, !123, !0, !null")
- result.Check(testkit.Rows("0 0 1 "))
- // for logicalXor
- result = tk.MustQuery("select 1 xor 1, 1 xor 0, 0 xor 1, 0 xor 0, 2 xor -1, null xor 1, '1a' xor 'a'")
- result.Check(testkit.Rows("0 1 1 0 0 1"))
- // for bitAnd
- result = tk.MustQuery("select 123 & 321, -123 & 321, null & 1")
- result.Check(testkit.Rows("65 257 "))
- // for bitOr
- result = tk.MustQuery("select 123 | 321, -123 | 321, null | 1")
- result.Check(testkit.Rows("379 18446744073709551557 "))
- // for bitXor
- result = tk.MustQuery("select 123 ^ 321, -123 ^ 321, null ^ 1")
- result.Check(testkit.Rows("314 18446744073709551300 "))
- // for leftShift
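- // Shifts likewise act on the unsigned 64-bit representation: -123 << 2 wraps to 2^64-492, and >> is a logical (not arithmetic) shift.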
- result = tk.MustQuery("select 123 << 2, -123 << 2, null << 1")
- result.Check(testkit.Rows("492 18446744073709551124 "))
- // for rightShift
- result = tk.MustQuery("select 123 >> 2, -123 >> 2, null >> 1")
- result.Check(testkit.Rows("30 4611686018427387873 "))
- // for logicOr
- result = tk.MustQuery("select 1 || 1, 1 || 0, 0 || 1, 0 || 0, 2 || -1, null || 1, '1a' || 'a'")
- result.Check(testkit.Rows("1 1 1 0 1 1 1"))
- // for unaryPlus
- result = tk.MustQuery(`select +1, +0, +(-9), +(-0.001), +0.999, +null, +"aaa"`)
- result.Check(testkit.Rows("1 0 -9 -0.001 0.999 aaa"))
- // for unaryMinus
- tk.MustExec("drop table if exists f")
- tk.MustExec("create table f(a decimal(65,0))")
- tk.MustExec("insert into f value (-17000000000000000000)")
- result = tk.MustQuery("select a from f")
- result.Check(testkit.Rows("-17000000000000000000"))
-}
-
-func TestDatetimeOverflow(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
-
- tk.MustExec("create table t1 (d date)")
- tk.MustExec("set sql_mode='traditional'")
- overflowSQLs := []string{
- "insert into t1 (d) select date_add('2000-01-01',interval 8000 year)",
- "insert into t1 (d) select date_sub('2000-01-01', INTERVAL 2001 YEAR)",
- "insert into t1 (d) select date_add('9999-12-31',interval 1 year)",
- "insert into t1 (d) select date_add('9999-12-31',interval 1 day)",
- }
-
- for _, sql := range overflowSQLs {
- tk.MustGetErrMsg(sql, "[types:1441]Datetime function: datetime field overflow")
- }
-
- tk.MustExec("set sql_mode=''")
- for _, sql := range overflowSQLs {
- tk.MustExec(sql)
- }
-
- rows := make([]string, 0, len(overflowSQLs))
- for range overflowSQLs {
- rows = append(rows, "")
- }
- tk.MustQuery("select * from t1").Check(testkit.Rows(rows...))
-}
-
-func TestExprDateTimeOnDST(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
-
- // insert DST datetime
- tk.MustExec("set @@session.time_zone = 'Europe/Amsterdam'")
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t (id int, dt datetime, primary key (id, dt))")
- tk.MustExec("insert into t values (1, date_add('2023-03-26 00:00:00', interval 2 hour))")
- tk.MustExec("insert into t values (4,'2023-03-26 02:00:00')")
-
- // check DST datetime
- tk.MustQuery("select * from t").Check(testkit.Rows("1 2023-03-26 02:00:00", "4 2023-03-26 02:00:00"))
-}
-
func TestInfoBuiltin(t *testing.T) {
store := testkit.CreateMockStore(t)
@@ -1257,1156 +644,134 @@ func TestInfoBuiltin(t *testing.T) {
// 2 * 2, error.
err = tk.ExecToErr(twoColumnQuery)
require.Error(t, err)
+
+ result = tk.MustQuery("select tidb_is_ddl_owner()")
+ var ret int64
+ if tk.Session().IsDDLOwner() {
+ ret = 1
+ }
+ result.Check(testkit.Rows(fmt.Sprintf("%v", ret)))
+}
+
+func TestColumnInfoModified(t *testing.T) {
+ store := testkit.CreateMockStore(t)
+
+ testKit := testkit.NewTestKit(t, store)
+ testKit.MustExec("use test")
+ testKit.MustExec("drop table if exists tab0")
+ testKit.MustExec("CREATE TABLE tab0(col0 INTEGER, col1 INTEGER, col2 INTEGER)")
+ testKit.MustExec("SELECT + - (- CASE + col0 WHEN + CAST( col0 AS SIGNED ) THEN col1 WHEN 79 THEN NULL WHEN + - col1 THEN col0 / + col0 END ) * - 16 FROM tab0")
+ ctx := testKit.Session()
+ is := domain.GetDomain(ctx).InfoSchema()
+ tbl, _ := is.TableByName(model.NewCIStr("test"), model.NewCIStr("tab0"))
+ col := table.FindCol(tbl.Cols(), "col1")
+ require.Equal(t, mysql.TypeLong, col.GetType())
}
-func TestControlBuiltin(t *testing.T) {
+func TestFilterExtractFromDNF(t *testing.T) {
store := testkit.CreateMockStore(t)
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
+ tk.MustExec("drop table if exists t")
+ tk.MustExec("create table t(a int, b int, c int)")
- // for ifnull
- result := tk.MustQuery("select ifnull(1, 2)")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select ifnull(null, 2)")
- result.Check(testkit.Rows("2"))
- result = tk.MustQuery("select ifnull(1, null)")
- result.Check(testkit.Rows("1"))
- result = tk.MustQuery("select ifnull(null, null)")
- result.Check(testkit.Rows(""))
+ tests := []struct {
+ exprStr string
+ result string
+ }{
+ {
+ exprStr: "a = 1 or a = 1 or a = 1",
+ result: "[eq(test.t.a, 1)]",
+ },
+ {
+ exprStr: "a = 1 or a = 1 or (a = 1 and b = 1)",
+ result: "[eq(test.t.a, 1)]",
+ },
+ {
+ exprStr: "(a = 1 and a = 1) or a = 1 or b = 1",
+ result: "[or(or(and(eq(test.t.a, 1), eq(test.t.a, 1)), eq(test.t.a, 1)), eq(test.t.b, 1))]",
+ },
+ {
+ exprStr: "(a = 1 and b = 2) or (a = 1 and b = 3) or (a = 1 and b = 4)",
+ result: "[eq(test.t.a, 1) or(eq(test.t.b, 2), or(eq(test.t.b, 3), eq(test.t.b, 4)))]",
+ },
+ {
+ exprStr: "(a = 1 and b = 1 and c = 1) or (a = 1 and b = 1) or (a = 1 and b = 1 and c > 2 and c < 3)",
+ result: "[eq(test.t.a, 1) eq(test.t.b, 1)]",
+ },
+ }
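+ // Each case below is parsed and planned; the selection conditions are run through PushDownNot and ExtractFiltersFromDNFs, and the sorted, stringified output must match `result`.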
- tk.MustExec("drop table if exists t1")
- tk.MustExec("create table t1(a bigint not null)")
- result = tk.MustQuery("select ifnull(max(a),0) from t1")
- result.Check(testkit.Rows("0"))
+ ctx := context.Background()
+ for _, tt := range tests {
+ sql := "select * from t where " + tt.exprStr
+ sctx := tk.Session()
+ sc := sctx.GetSessionVars().StmtCtx
+ stmts, err := session.Parse(sctx, sql)
+ require.NoError(t, err, "error %v, for expr %s", err, tt.exprStr)
+ require.Len(t, stmts, 1)
+ ret := &plannercore.PreprocessorReturn{}
+ err = plannercore.Preprocess(context.Background(), sctx, stmts[0], plannercore.WithPreprocessorReturn(ret))
+ require.NoError(t, err, "error %v, for resolve name, expr %s", err, tt.exprStr)
+ p, _, err := plannercore.BuildLogicalPlanForTest(ctx, sctx, stmts[0], ret.InfoSchema)
+ require.NoError(t, err, "error %v, for build plan, expr %s", err, tt.exprStr)
+ selection := p.(plannercore.LogicalPlan).Children()[0].(*plannercore.LogicalSelection)
+ conds := make([]expression.Expression, len(selection.Conditions))
+ for i, cond := range selection.Conditions {
+ conds[i] = expression.PushDownNot(sctx, cond)
+ }
+ afterFunc := expression.ExtractFiltersFromDNFs(sctx, conds)
+ sort.Slice(afterFunc, func(i, j int) bool {
+ return bytes.Compare(afterFunc[i].HashCode(sc), afterFunc[j].HashCode(sc)) < 0
+ })
+ require.Equal(t, fmt.Sprintf("%s", afterFunc), tt.result, "wrong result for expr: %s", tt.exprStr)
+ }
+}
- tk.MustExec("drop table if exists t1")
- tk.MustExec("drop table if exists t2")
- tk.MustExec("create table t1(a decimal(20,4))")
- tk.MustExec("create table t2(a decimal(20,4))")
- tk.MustExec("insert into t1 select 1.2345")
- tk.MustExec("insert into t2 select 1.2345")
-
- result = tk.MustQuery(`select sum(ifnull(a, 0)) from (
- select ifnull(a, 0) as a from t1
- union all
- select ifnull(a, 0) as a from t2
- ) t;`)
- result.Check(testkit.Rows("2.4690"))
-
- // for if
- result = tk.MustQuery(`select IF(0,"ERROR","this"),IF(1,"is","ERROR"),IF(NULL,"ERROR","a"),IF(1,2,3)|0,IF(1,2.0,3.0)+0;`)
- result.Check(testkit.Rows("this is a 2 2.0"))
- tk.MustExec("drop table if exists t1;")
- tk.MustExec("CREATE TABLE t1 (st varchar(255) NOT NULL, u int(11) NOT NULL);")
- tk.MustExec("INSERT INTO t1 VALUES ('a',1),('A',1),('aa',1),('AA',1),('a',1),('aaa',0),('BBB',0);")
- result = tk.MustQuery("select if(1,st,st) s from t1 order by s;")
- result.Check(testkit.Rows("A", "AA", "BBB", "a", "a", "aa", "aaa"))
- result = tk.MustQuery("select if(u=1,st,st) s from t1 order by s;")
- result.Check(testkit.Rows("A", "AA", "BBB", "a", "a", "aa", "aaa"))
- tk.MustExec("drop table if exists t1;")
- tk.MustExec("CREATE TABLE t1 (a varchar(255), b time, c int)")
- tk.MustExec("INSERT INTO t1 VALUE('abc', '12:00:00', 0)")
- tk.MustExec("INSERT INTO t1 VALUE('1abc', '00:00:00', 1)")
- tk.MustExec("INSERT INTO t1 VALUE('0abc', '12:59:59', 0)")
- result = tk.MustQuery("select if(a, b, c), if(b, a, c), if(c, a, b) from t1")
- result.Check(testkit.Rows("0 abc 12:00:00", "00:00:00 1 1abc", "0 0abc 12:59:59"))
- result = tk.MustQuery("select if(1, 1.0, 1)")
- result.Check(testkit.Rows("1.0"))
- // FIXME: MySQL returns `1.0`.
- result = tk.MustQuery("select if(1, 1, 1.0)")
- result.Check(testkit.Rows("1"))
- tk.MustQuery("select if(count(*), cast('2000-01-01' as date), cast('2011-01-01' as date)) from t1").Check(testkit.Rows("2000-01-01"))
- tk.MustQuery("select if(count(*)=0, cast('2000-01-01' as date), cast('2011-01-01' as date)) from t1").Check(testkit.Rows("2011-01-01"))
- tk.MustQuery("select if(count(*), cast('[]' as json), cast('{}' as json)) from t1").Check(testkit.Rows("[]"))
- tk.MustQuery("select if(count(*)=0, cast('[]' as json), cast('{}' as json)) from t1").Check(testkit.Rows("{}"))
+func TestTiDBDecodePlanFunc(t *testing.T) {
+ store := testkit.CreateMockStore(t)
- result = tk.MustQuery("SELECT 79 + + + CASE -87 WHEN -30 THEN COALESCE(COUNT(*), +COALESCE(+15, -33, -12 ) + +72) WHEN +COALESCE(+AVG(DISTINCT(60)), 21) THEN NULL ELSE NULL END AS col0;")
- result.Check(testkit.Rows(""))
+ tk := testkit.NewTestKit(t, store)
+ tk.MustQuery("select tidb_decode_plan('')").Check(testkit.Rows(""))
+ tk.MustQuery("select tidb_decode_plan('7APIMAk1XzEzCTAJMQlmdW5jczpjb3VudCgxKQoxCTE3XzE0CTAJMAlpbm5lciBqb2luLCBp" +
+ "AQyQOlRhYmxlUmVhZGVyXzIxLCBlcXVhbDpbZXEoQ29sdW1uIzEsIA0KCDkpIBkXADIVFywxMCldCjIJMzFfMTgFZXhkYXRhOlNlbGVjdGlvbl" +
+ "8xNwozCTFfMTcJMQkwCWx0HVlATlVMTCksIG5vdChpc251bGwVHAApUhcAUDIpKQo0CTEwXzE2CTEJMTAwMDAJdAHB2Dp0MSwgcmFuZ2U6Wy1p" +
+ "bmYsK2luZl0sIGtlZXAgb3JkZXI6ZmFsc2UsIHN0YXRzOnBzZXVkbwoFtgAyAZcEMAk6tgAEMjAFtgQyMDq2AAg5LCBmtgAAMFa3AAA5FbcAO" +
+ "T63AAAyzrcA')").Check(testkit.Rows("" +
+ "\tid \ttask\testRows\toperator info\n" +
+ "\tStreamAgg_13 \troot\t1 \tfuncs:count(1)\n" +
+ "\t└─HashJoin_14 \troot\t0 \tinner join, inner:TableReader_21, equal:[eq(Column#1, Column#9) eq(Column#2, Column#10)]\n" +
+ "\t ├─TableReader_18 \troot\t0 \tdata:Selection_17\n" +
+ "\t │ └─Selection_17 \tcop \t0 \tlt(Column#1, NULL), not(isnull(Column#1)), not(isnull(Column#2))\n" +
+ "\t │ └─TableScan_16\tcop \t10000 \ttable:t1, range:[-inf,+inf], keep order:false, stats:pseudo\n" +
+ "\t └─TableReader_21 \troot\t0 \tdata:Selection_20\n" +
+ "\t └─Selection_20 \tcop \t0 \tlt(Column#9, NULL), not(isnull(Column#10)), not(isnull(Column#9))\n" +
+ "\t └─TableScan_19\tcop \t10000 \ttable:t2, range:[-inf,+inf], keep order:false, stats:pseudo"))
+ tk.MustQuery("select tidb_decode_plan('rwPwcTAJNV8xNAkwCTEJZnVuY3M6bWF4KHRlc3QudC5hKS0+Q29sdW1uIzQJMQl0aW1lOj" +
+ "IyMy45MzXCtXMsIGxvb3BzOjIJMTI4IEJ5dGVzCU4vQQoxCTE2XzE4CTAJMQlvZmZzZXQ6MCwgY291bnQ6MQkxCQlHFDE4LjQyMjJHAAhOL0" +
+ "EBBCAKMgkzMl8yOAkBlEBpbmRleDpMaW1pdF8yNwkxCQ0+DDYuODUdPSwxLCBycGMgbnVtOiANDAUpGDE1MC44MjQFKjhwcm9jIGtleXM6MA" +
+ "kxOTgdsgAzAbIAMgFearIAFDU3LjM5NgVKAGwN+BGxIDQJMTNfMjYJMQGgHGFibGU6dCwgCbqwaWR4KGEpLCByYW5nZTooMCwraW5mXSwga2" +
+ "VlcCBvcmRlcjp0cnVlLCBkZXNjAT8kaW1lOjU2LjY2MR1rJDEJTi9BCU4vQQo=')").Check(testkit.Rows("" +
+ "\tid \ttask\testRows\toperator info \tactRows\texecution info \tmemory \tdisk\n" +
+ "\tStreamAgg_14 \troot\t1 \tfuncs:max(test.t.a)->Column#4 \t1 \ttime:223.935µs, loops:2 \t128 Bytes\tN/A\n" +
+ "\t└─Limit_18 \troot\t1 \toffset:0, count:1 \t1 \ttime:218.422µs, loops:2 \tN/A \tN/A\n" +
+ "\t └─IndexReader_28 \troot\t1 \tindex:Limit_27 \t1 \ttime:216.85µs, loops:1, rpc num: 1, rpc time:150.824µs, proc keys:0\t198 Bytes\tN/A\n" +
+ "\t └─Limit_27 \tcop \t1 \toffset:0, count:1 \t1 \ttime:57.396µs, loops:2 \tN/A \tN/A\n" +
+ "\t └─IndexScan_26\tcop \t1 \ttable:t, index:idx(a), range:(0,+inf], keep order:true, desc\t1 \ttime:56.661µs, loops:1 \tN/A \tN/A"))
- result = tk.MustQuery("SELECT -63 + COALESCE ( - 83, - 61 + - + 72 * - CAST( NULL AS SIGNED ) + + 3 );")
- result.Check(testkit.Rows("-146"))
+ // Test issue16939
+ tk.MustQuery("select tidb_decode_plan(query), time from information_schema.slow_query order by time desc limit 1;")
+ tk.MustQuery("select tidb_decode_plan('xxx')").Check(testkit.Rows("xxx"))
}
-func TestArithmeticBuiltin(t *testing.T) {
+func TestTiDBDecodeKeyFunc(t *testing.T) {
store := testkit.CreateMockStore(t)
+ collate.SetNewCollationEnabledForTest(false)
+ defer collate.SetNewCollationEnabledForTest(true)
+
tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
- ctx := context.Background()
-
- // for plus
- tk.MustExec("DROP TABLE IF EXISTS t;")
- tk.MustExec("CREATE TABLE t(a DECIMAL(4, 2), b DECIMAL(5, 3));")
- tk.MustExec("INSERT INTO t(a, b) VALUES(1.09, 1.999), (-1.1, -0.1);")
- result := tk.MustQuery("SELECT a+b FROM t;")
- result.Check(testkit.Rows("3.089", "-1.200"))
- result = tk.MustQuery("SELECT b+12, b+0.01, b+0.00001, b+12.00001 FROM t;")
- result.Check(testkit.Rows("13.999 2.009 1.99901 13.99901", "11.900 -0.090 -0.09999 11.90001"))
- result = tk.MustQuery("SELECT 1+12, 21+0.01, 89+\"11\", 12+\"a\", 12+NULL, NULL+1, NULL+NULL;")
- result.Check(testkit.Rows("13 21.01 100 12 "))
- tk.MustExec("DROP TABLE IF EXISTS t;")
- tk.MustExec("CREATE TABLE t(a BIGINT UNSIGNED, b BIGINT UNSIGNED);")
- tk.MustExec("INSERT INTO t SELECT 1<<63, 1<<63;")
- rs, err := tk.Exec("SELECT a+b FROM t;")
- require.NoError(t, err)
- require.NotNil(t, rs)
- rows, err := session.GetRows4Test(ctx, tk.Session(), rs)
- require.Nil(t, rows)
- require.Error(t, err)
- require.Error(t, err, "[types:1690]BIGINT UNSIGNED value is out of range in '(test.t.a + test.t.b)'")
- require.NoError(t, rs.Close())
- rs, err = tk.Exec("select cast(-3 as signed) + cast(2 as unsigned);")
- require.NoError(t, err)
- require.NotNil(t, rs)
- rows, err = session.GetRows4Test(ctx, tk.Session(), rs)
- require.Nil(t, rows)
- require.Error(t, err)
- require.Error(t, err, "[types:1690]BIGINT UNSIGNED value is out of range in '(-3 + 2)'")
- require.NoError(t, rs.Close())
- rs, err = tk.Exec("select cast(2 as unsigned) + cast(-3 as signed);")
- require.NoError(t, err)
- require.NotNil(t, rs)
- rows, err = session.GetRows4Test(ctx, tk.Session(), rs)
- require.Nil(t, rows)
- require.Error(t, err)
- require.Error(t, err, "[types:1690]BIGINT UNSIGNED value is out of range in '(2 + -3)'")
- require.NoError(t, rs.Close())
-
- // for minus
- tk.MustExec("DROP TABLE IF EXISTS t;")
- tk.MustExec("CREATE TABLE t(a DECIMAL(4, 2), b DECIMAL(5, 3));")
- tk.MustExec("INSERT INTO t(a, b) VALUES(1.09, 1.999), (-1.1, -0.1);")
- result = tk.MustQuery("SELECT a-b FROM t;")
- result.Check(testkit.Rows("-0.909", "-1.000"))
- result = tk.MustQuery("SELECT b-12, b-0.01, b-0.00001, b-12.00001 FROM t;")
- result.Check(testkit.Rows("-10.001 1.989 1.99899 -10.00101", "-12.100 -0.110 -0.10001 -12.10001"))
- result = tk.MustQuery("SELECT 1-12, 21-0.01, 89-\"11\", 12-\"a\", 12-NULL, NULL-1, NULL-NULL;")
- result.Check(testkit.Rows("-11 20.99 78 12 "))
-
- tk.MustExec("DROP TABLE IF EXISTS t;")
- tk.MustExec("CREATE TABLE t(a BIGINT UNSIGNED, b BIGINT UNSIGNED);")
- tk.MustExec("INSERT INTO t SELECT 1, 4;")
- err = tk.QueryToErr("SELECT a-b FROM t;")
- require.Error(t, err)
- require.Error(t, err, "[types:1690]BIGINT UNSIGNED value is out of range in '(test.t.a - test.t.b)'")
-
- err = tk.QueryToErr("select cast(1 as unsigned) - cast(4 as unsigned);")
- require.Error(t, err)
- // TODO: make the error message compatible with MySQL; it should be: BIGINT UNSIGNED value is out of range in '(cast(1 as unsigned) - cast(4 as unsigned))'
- require.Error(t, err, "[types:1690]BIGINT UNSIGNED value is out of range in '(1 - 4)'")
-
- err = tk.QueryToErr("select cast(-1 as signed) - cast(-1 as unsigned);")
- require.Error(t, err)
- require.Error(t, err, "[types:1690]BIGINT UNSIGNED value is out of range in '(-1 - 18446744073709551615)'")
-
- err = tk.QueryToErr("select cast(1 as signed) - cast(-1 as unsigned);")
- require.Error(t, err)
- require.Error(t, err, "[types:1690]BIGINT UNSIGNED value is out of range in '(1 - 18446744073709551615)'")
-
- err = tk.QueryToErr("select cast(-1 as unsigned) - cast(-1 as signed);")
- require.Error(t, err)
- require.Error(t, err, "[types:1690]BIGINT UNSIGNED value is out of range in '(18446744073709551615 - -1)'")
-
- err = tk.QueryToErr("select cast(-9223372036854775808 as unsigned) - (-9223372036854775808);")
- require.Error(t, err)
- require.Error(t, err, "[types:1690]BIGINT UNSIGNED value is out of range in '(9223372036854775808 - -9223372036854775808)'")
-
- err = tk.QueryToErr("select cast(12 as unsigned) - (14);")
- require.Error(t, err)
- require.Error(t, err, "[types:1690]BIGINT UNSIGNED value is out of range in '(12 - 14)'")
-
- err = tk.QueryToErr("select cast(9223372036854775807 as signed) - cast(-1 as signed);")
- require.Error(t, err, "[types:1690]BIGINT value is out of range in '(9223372036854775807 - -1)'")
-
- err = tk.QueryToErr("select cast(-9223372036854775808 as signed) - cast(1 as signed);")
- require.Error(t, err)
- require.Error(t, err, "[types:1690]BIGINT value is out of range in '(-9223372036854775808 - 1)'")
-
- err = tk.QueryToErr("select cast(12 as signed) - cast(-9223372036854775808 as signed);")
- require.Error(t, err)
- require.Error(t, err, "[types:1690]BIGINT value is out of range in '(12 - -9223372036854775808)'")
-
- tk.MustExec(`create table tb5(a int(10));`)
- tk.MustExec(`insert into tb5 (a) values (10);`)
- e := tk.QueryToErr(`select * from tb5 where a - -9223372036854775808;`)
- require.NotNil(t, e)
- require.True(t, strings.HasSuffix(e.Error(), `BIGINT value is out of range in '(Column#0 - -9223372036854775808)'`), "err: %v", e)
-
- tk.MustExec(`drop table tb5`)
- tk.MustQuery("select cast(-9223372036854775808 as unsigned) - (-9223372036854775807);").Check(testkit.Rows("18446744073709551615"))
- tk.MustQuery("select cast(-3 as unsigned) - cast(-1 as signed);").Check(testkit.Rows("18446744073709551614"))
- tk.MustQuery("select 1.11 - 1.11;").Check(testkit.Rows("0.00"))
- tk.MustQuery("select cast(-1 as unsigned) - cast(-12 as unsigned);").Check(testkit.Rows("11"))
- tk.MustQuery("select cast(-1 as unsigned) - cast(0 as unsigned);").Check(testkit.Rows("18446744073709551615"))
-
- // for multiply
- tk.MustQuery("select 1234567890 * 1234567890").Check(testkit.Rows("1524157875019052100"))
- rs, err = tk.Exec("select 1234567890 * 12345671890")
- require.NoError(t, err)
- _, err = session.GetRows4Test(ctx, tk.Session(), rs)
- require.True(t, terror.ErrorEqual(err, types.ErrOverflow))
- require.NoError(t, rs.Close())
- tk.MustQuery("select cast(1234567890 as unsigned int) * 12345671890").Check(testkit.Rows("15241570095869612100"))
- tk.MustQuery("select 123344532434234234267890.0 * 1234567118923479823749823749.230").Check(testkit.Rows("152277104042296270209916846800130443726237424001224.7000"))
- rs, err = tk.Exec("select 123344532434234234267890.0 * 12345671189234798237498232384982309489238402830480239849238048239084749.230")
- require.NoError(t, err)
- _, err = session.GetRows4Test(ctx, tk.Session(), rs)
- require.True(t, terror.ErrorEqual(err, types.ErrOverflow))
- require.NoError(t, rs.Close())
- // FIXME: There is something wrong with how float numbers are displayed.
- // tk.MustQuery("select 1.797693134862315708145274237317043567981e+308 * 1").Check(testkit.Rows("1.7976931348623157e308"))
- // tk.MustQuery("select 1.797693134862315708145274237317043567981e+308 * -1").Check(testkit.Rows("-1.7976931348623157e308"))
- rs, err = tk.Exec("select 1.797693134862315708145274237317043567981e+308 * 1.1")
- require.NoError(t, err)
- _, err = session.GetRows4Test(ctx, tk.Session(), rs)
- require.True(t, terror.ErrorEqual(err, types.ErrOverflow))
- require.NoError(t, rs.Close())
- rs, err = tk.Exec("select 1.797693134862315708145274237317043567981e+308 * -1.1")
- require.NoError(t, err)
- _, err = session.GetRows4Test(ctx, tk.Session(), rs)
- require.True(t, terror.ErrorEqual(err, types.ErrOverflow))
- require.NoError(t, rs.Close())
- tk.MustQuery("select 0.0 * -1;").Check(testkit.Rows("0.0"))
-
- tk.MustExec("DROP TABLE IF EXISTS t;")
- tk.MustExec("CREATE TABLE t(a DECIMAL(4, 2), b DECIMAL(5, 3));")
- tk.MustExec("INSERT INTO t(a, b) VALUES(-1.09, 1.999);")
- result = tk.MustQuery("SELECT a/b, a/12, a/-0.01, b/12, b/-0.01, b/0.000, NULL/b, b/NULL, NULL/NULL FROM t;")
- result.Check(testkit.Rows("-0.545273 -0.090833 109.000000 0.1665833 -199.9000000 "))
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1365 Division by 0"))
- rs, err = tk.Exec("select 1e200/1e-200")
- require.NoError(t, err)
- _, err = session.GetRows4Test(ctx, tk.Session(), rs)
- require.True(t, terror.ErrorEqual(err, types.ErrOverflow))
- require.NoError(t, rs.Close())
-
- // for intDiv
- result = tk.MustQuery("SELECT 13 DIV 12, 13 DIV 0.01, -13 DIV 2, 13 DIV NULL, NULL DIV 13, NULL DIV NULL;")
- result.Check(testkit.Rows("1 1300 -6 "))
- result = tk.MustQuery("SELECT 2.4 div 1.1, 2.4 div 1.2, 2.4 div 1.3;")
- result.Check(testkit.Rows("2 2 1"))
- result = tk.MustQuery("SELECT 1.175494351E-37 div 1.7976931348623157E+308, 1.7976931348623157E+308 div -1.7976931348623157E+307, 1 div 1e-82;")
- result.Check(testkit.Rows("0 -1 "))
- tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|",
- "Warning|1292|Truncated incorrect DECIMAL value: '1.7976931348623157e+308'",
- "Warning|1292|Truncated incorrect DECIMAL value: '1.7976931348623157e+308'",
- "Warning|1292|Truncated incorrect DECIMAL value: '-1.7976931348623158e+307'",
- "Warning|1365|Division by 0"))
- rs, err = tk.Exec("select 1e300 DIV 1.5")
- require.NoError(t, err)
- _, err = session.GetRows4Test(ctx, tk.Session(), rs)
- require.True(t, terror.ErrorEqual(err, types.ErrOverflow))
- require.NoError(t, rs.Close())
-
- tk.MustExec("drop table if exists t;")
- tk.MustExec("CREATE TABLE t (c_varchar varchar(255), c_time time, nonzero int, zero int, c_int_unsigned int unsigned, c_timestamp timestamp, c_enum enum('a','b','c'));")
- tk.MustExec("INSERT INTO t VALUE('abc', '12:00:00', 12, 0, 5, '2017-08-05 18:19:03', 'b');")
- result = tk.MustQuery("select c_varchar div nonzero, c_time div nonzero, c_time div zero, c_timestamp div nonzero, c_timestamp div zero, c_varchar div zero from t;")
- result.Check(testkit.Rows("0 10000 1680900431825 "))
- result = tk.MustQuery("select c_enum div nonzero from t;")
- result.Check(testkit.Rows("0"))
- tk.MustQuery("select c_enum div zero from t").Check(testkit.Rows(""))
- tk.MustQuery("select nonzero div zero from t").Check(testkit.Rows(""))
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1365 Division by 0"))
- result = tk.MustQuery("select c_time div c_enum, c_timestamp div c_time, c_timestamp div c_enum from t;")
- result.Check(testkit.Rows("60000 168090043 10085402590951"))
- result = tk.MustQuery("select c_int_unsigned div nonzero, nonzero div c_int_unsigned, c_int_unsigned div zero from t;")
- result.Check(testkit.Rows("0 2 "))
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1365 Division by 0"))
-
- // for mod
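- // MOD takes the sign of the dividend, so CAST(1 AS UNSIGNED) MOD -9223372036854775808 is 1 and -9223372036854775808 MOD CAST(1 AS UNSIGNED) is 0; a NULL operand yields NULL.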
- result = tk.MustQuery("SELECT CAST(1 AS UNSIGNED) MOD -9223372036854775808, -9223372036854775808 MOD CAST(1 AS UNSIGNED);")
- result.Check(testkit.Rows("1 0"))
- result = tk.MustQuery("SELECT 13 MOD 12, 13 MOD 0.01, -13 MOD 2, 13 MOD NULL, NULL MOD 13, NULL DIV NULL;")
- result.Check(testkit.Rows("1 0.00 -1 "))
- result = tk.MustQuery("SELECT 2.4 MOD 1.1, 2.4 MOD 1.2, 2.4 mod 1.30;")
- result.Check(testkit.Rows("0.2 0.0 1.10"))
- tk.MustExec("drop table if exists t;")
- tk.MustExec("CREATE TABLE t (c_varchar varchar(255), c_time time, nonzero int, zero int, c_timestamp timestamp, c_enum enum('a','b','c'));")
- tk.MustExec("INSERT INTO t VALUE('abc', '12:00:00', 12, 0, '2017-08-05 18:19:03', 'b');")
- result = tk.MustQuery("select c_varchar MOD nonzero, c_time MOD nonzero, c_timestamp MOD nonzero, c_enum MOD nonzero from t;")
- result.Check(testkit.Rows("0 0 3 2"))
- result = tk.MustQuery("select c_time MOD c_enum, c_timestamp MOD c_time, c_timestamp MOD c_enum from t;")
- result.Check(testkit.Rows("0 21903 1"))
- tk.MustQuery("select c_enum MOD zero from t;").Check(testkit.Rows(""))
- tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1365 Division by 0"))
- tk.MustExec("SET SQL_MODE='ERROR_FOR_DIVISION_BY_ZERO,STRICT_ALL_TABLES';")
- tk.MustExec("drop table if exists t;")
- tk.MustExec("CREATE TABLE t (v int);")
- tk.MustExec("INSERT IGNORE INTO t VALUE(12 MOD 0);")
- tk.MustQuery("show warnings;").CheckContain("Division by 0")
- tk.MustQuery("select v from t;").Check(testkit.Rows(""))
- tk.MustQuery("select 0.000 % 0.11234500000000000000;").Check(testkit.Rows("0.00000000000000000000"))
-
- tk.MustGetDBError("INSERT INTO t VALUE(12 MOD 0);", expression.ErrDivisionByZero)
-
- tk.MustQuery("select sum(1.2e2) * 0.1").Check(testkit.Rows("12"))
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a double)")
- tk.MustExec("insert into t value(1.2)")
- tk.MustQuery("select sum(a) * 0.1 from t").Check(testkit.Rows("0.12"))
-
- tk.MustExec("drop table if exists t")
- tk.MustExec("create table t(a double)")
- tk.MustExec("insert into t value(1.2)")
- result = tk.MustQuery("select * from t where a/0 > 1")
- result.Check(testkit.Rows())
- tk.MustQuery("show warnings").Check(testkit.RowsWithSep("|", "Warning|1365|Division by 0"))
-
- tk.MustExec("USE test;")
- tk.MustExec("DROP TABLE IF EXISTS t;")
- tk.MustExec("CREATE TABLE t(a BIGINT, b DECIMAL(6, 2));")
- tk.MustExec("INSERT INTO t VALUES(0, 1.12), (1, 1.21);")
- tk.MustQuery("SELECT a/b FROM t;").Check(testkit.Rows("0.0000", "0.8264"))
-}
-
-func TestGreatestTimeType(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
-
- tk.MustExec("drop table if exists t1;")
- tk.MustExec("create table t1(c_time time(5), c_dt datetime(4), c_ts timestamp(3), c_d date, c_str varchar(100));")
- tk.MustExec("insert into t1 values('-800:10:10', '2021-10-10 10:10:10.1234', '2021-10-10 10:10:10.1234', '2021-10-11', '2021-10-10 10:10:10.1234');")
-
- for i := 0; i < 2; i++ {
- if i == 0 {
- tk.MustExec("set @@tidb_enable_vectorized_expression = off;")
- } else {
- tk.MustExec("set @@tidb_enable_vectorized_expression = on;")
- }
- tk.MustQuery("select greatest(c_time, c_time) from t1;").Check(testkit.Rows("-800:10:10.00000"))
- tk.MustQuery("select greatest(c_dt, c_dt) from t1;").Check(testkit.Rows("2021-10-10 10:10:10.1234"))
- tk.MustQuery("select greatest(c_ts, c_ts) from t1;").Check(testkit.Rows("2021-10-10 10:10:10.123"))
- tk.MustQuery("select greatest(c_d, c_d) from t1;").Check(testkit.Rows("2021-10-11"))
- tk.MustQuery("select greatest(c_str, c_str) from t1;").Check(testkit.Rows("2021-10-10 10:10:10.1234"))
-
- tk.MustQuery("select least(c_time, c_time) from t1;").Check(testkit.Rows("-800:10:10.00000"))
- tk.MustQuery("select least(c_dt, c_dt) from t1;").Check(testkit.Rows("2021-10-10 10:10:10.1234"))
- tk.MustQuery("select least(c_ts, c_ts) from t1;").Check(testkit.Rows("2021-10-10 10:10:10.123"))
- tk.MustQuery("select least(c_d, c_d) from t1;").Check(testkit.Rows("2021-10-11"))
- tk.MustQuery("select least(c_str, c_str) from t1;").Check(testkit.Rows("2021-10-10 10:10:10.1234"))
-
- tk.MustQuery("select greatest(c_time, cast('10:01:01' as time)) from t1;").Check(testkit.Rows("10:01:01.00000"))
- tk.MustQuery("select least(c_time, cast('10:01:01' as time)) from t1;").Check(testkit.Rows("-800:10:10.00000"))
-
- tk.MustQuery("select greatest(c_d, cast('1999-10-10' as date)) from t1;").Check(testkit.Rows("2021-10-11"))
- tk.MustQuery("select least(c_d, cast('1999-10-10' as date)) from t1;").Check(testkit.Rows("1999-10-10"))
-
- tk.MustQuery("select greatest(c_dt, cast('1999-10-10 10:10:10.1234' as datetime)) from t1;").Check(testkit.Rows("2021-10-10 10:10:10.1234"))
- tk.MustQuery("select least(c_dt, cast('1999-10-10 10:10:10.1234' as datetime)) from t1;").Check(testkit.Rows("1999-10-10 10:10:10"))
- }
-}
-
-func TestCompareBuiltin(t *testing.T) {
- store := testkit.CreateMockStore(t)
-
- tk := testkit.NewTestKit(t, store)
- tk.MustExec("use test")
-
- // compare as JSON
- tk.MustExec("drop table if exists t")
- tk.MustExec("CREATE TABLE t (pk int NOT NULL PRIMARY KEY AUTO_INCREMENT, i INT, j JSON);")
- tk.MustExec(`INSERT INTO t(i, j) VALUES (0, NULL)`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (1, '{"a": 2}')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (2, '[1,2]')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (3, '{"a":"b", "c":"d","ab":"abc", "bc": ["x", "y"]}')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (4, '["here", ["I", "am"], "!!!"]')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (5, '"scalar string"')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (6, 'true')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (7, 'false')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (8, 'null')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (9, '-1')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (10, CAST(CAST(1 AS UNSIGNED) AS JSON))`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (11, '32767')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (12, '32768')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (13, '-32768')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (14, '-32769')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (15, '2147483647')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (16, '2147483648')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (17, '-2147483648')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (18, '-2147483649')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (19, '18446744073709551615')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (20, '18446744073709551616')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (21, '3.14')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (22, '{}')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (23, '[]')`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (24, CAST(CAST('2015-01-15 23:24:25' AS DATETIME) AS JSON))`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (25, CAST(CAST('23:24:25' AS TIME) AS JSON))`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (26, CAST(CAST('2015-01-15' AS DATE) AS JSON))`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (27, CAST(TIMESTAMP('2015-01-15 23:24:25') AS JSON))`)
- tk.MustExec(`INSERT INTO t(i, j) VALUES (28, CAST('[]' AS CHAR CHARACTER SET 'ascii'))`)
-
- result := tk.MustQuery(`SELECT i,
- (j = '"scalar string"') AS c1,
- (j = 'scalar string') AS c2,
- (j = CAST('"scalar string"' AS JSON)) AS c3,
- (j = CAST(CAST(j AS CHAR CHARACTER SET 'utf8mb4') AS JSON)) AS c4,
- (j = CAST(NULL AS JSON)) AS c5,
- (j = NULL) AS c6,
- (j <=> NULL) AS c7,
- (j <=> CAST(NULL AS JSON)) AS c8,
- (j IN (-1, 2, 32768, 3.14)) AS c9,
- (j IN (CAST('[1, 2]' AS JSON), CAST('{}' AS JSON), CAST(3.14 AS JSON))) AS c10,
- (j = (SELECT j FROM t WHERE j = CAST('null' AS JSON))) AS c11,
- (j = (SELECT j FROM t WHERE j IS NULL)) AS c12,
- (j = (SELECT j FROM t WHERE 1<>1)) AS c13,
- (j = DATE('2015-01-15')) AS c14,
- (j = TIME('23:24:25')) AS c15,
- (j = TIMESTAMP('2015-01-15 23:24:25')) AS c16,
- (j = CURRENT_TIMESTAMP) AS c17,
- (JSON_EXTRACT(j, '$.a') = 2) AS c18
- FROM t
- ORDER BY i;`)
- result.Check(testkit.Rows("0 1 1 ",
- "1 0 0 0 1 0 0 0 0 0 0 0 0 0 1",
- "2 0 0 0 1 0 0 0 1 0 0 0 0 0 ",
- "3 0 0 0 1 0 0 0 0 0 0 0 0 0 0",
- "4 0 0 0 1 0 0 0 0 0 0 0 0 0 ",
- "5 0 1 1 1 0 0 0 0 0 0 0 0 0