From 953dd991b4a16c0b22ae3dbf0e92637cff2fc9fc Mon Sep 17 00:00:00 2001 From: lidezhu <47731263+lidezhu@users.noreply.github.com> Date: Mon, 6 Dec 2021 21:21:56 +0800 Subject: [PATCH] improve auto generated fullstack tests (#3592) --- .gitignore | 2 + .../ddl/binary_default_value.test | 21 +- tests/fullstack-test2/dml/text_blob_type.test | 14 + tests/generate-fullstack-test.py | 991 ++++++++++-------- 4 files changed, 565 insertions(+), 463 deletions(-) create mode 100644 tests/fullstack-test2/dml/text_blob_type.test diff --git a/.gitignore b/.gitignore index f0b6abd85b3..7f3d6acfdd9 100644 --- a/.gitignore +++ b/.gitignore @@ -262,4 +262,6 @@ libs/libtiflash-proxy tmp/ +tests/fullstack-test2/auto_gen + CoverageReport diff --git a/tests/fullstack-test2/ddl/binary_default_value.test b/tests/fullstack-test2/ddl/binary_default_value.test index 6deffefea0b..1f6cc15136e 100644 --- a/tests/fullstack-test2/ddl/binary_default_value.test +++ b/tests/fullstack-test2/ddl/binary_default_value.test @@ -10,19 +10,20 @@ mysql> alter table test.t add column b1 binary(8) not null; mysql> alter table test.t add column b2 binary(8) default X'3132'; mysql> alter table test.t add column b3 binary(8) not null default X'003132'; mysql> alter table test.t add column b4 binary(8) not null default X'0031323334353637'; +mysql> alter table test.t add column b5 varbinary(8) not null default X'0031323334353637'; mysql> set session tidb_isolation_read_engines='tiflash'; select * from test.t; -+------+------------------+----------------+----------------+-----------+ -| a | b1 | b2 | b3 | b4 | -+------+------------------+------------------+----------------+-----------+ -| 1 | \0\0\0\0\0\0\0\0 | 12\0\0\0\0\0\0 | \012\0\0\0\0\0 | \01234567 | -+------+------------------+----------------+----------------+-----------+ ++------+------------------+----------------+----------------+-----------+-----------+ +| a | b1 | b2 | b3 | b4 | b5 | ++------+------------------+------------------+----------------+-----------+-----------+ +| 1 | \0\0\0\0\0\0\0\0 | 12\0\0\0\0\0\0 | \012\0\0\0\0\0 | \01234567 | \01234567 | ++------+------------------+----------------+----------------+-----------+-----------+ mysql> set session tidb_isolation_read_engines='tikv'; select * from test.t; -+------+------------------+----------------+----------------+-----------+ -| a | b1 | b2 | b3 | b4 | -+------+------------------+----------------+----------------+-----------+ -| 1 | \0\0\0\0\0\0\0\0 | 12\0\0\0\0\0\0 | \012\0\0\0\0\0 | \01234567 | -+------+------------------+----------------+----------------+-----------+ ++------+------------------+----------------+----------------+-----------+-----------+ +| a | b1 | b2 | b3 | b4 | b5 | ++------+------------------+----------------+----------------+-----------+-----------+ +| 1 | \0\0\0\0\0\0\0\0 | 12\0\0\0\0\0\0 | \012\0\0\0\0\0 | \01234567 | \01234567 | ++------+------------------+----------------+----------------+-----------+-----------+ mysql> drop table if exists test.t diff --git a/tests/fullstack-test2/dml/text_blob_type.test b/tests/fullstack-test2/dml/text_blob_type.test new file mode 100644 index 00000000000..8069f1ce9f1 --- /dev/null +++ b/tests/fullstack-test2/dml/text_blob_type.test @@ -0,0 +1,14 @@ +mysql> drop table if exists test.t +mysql> create table test.t(a text, b blob) +mysql> alter table test.t set tiflash replica 1 + +func> wait_table test t + +mysql> insert into test.t values('test1', '01223'); + +mysql> set session tidb_isolation_read_engines='tiflash'; select * from test.t; ++-------+--------------+ +| a | b | ++-------+--------------+ +| test1 | 01223 | ++-------+--------------+ diff --git a/tests/generate-fullstack-test.py b/tests/generate-fullstack-test.py index 0ef12ec65d3..1e324308972 100644 --- a/tests/generate-fullstack-test.py +++ b/tests/generate-fullstack-test.py @@ -1,473 +1,558 @@ -# -*- coding:utf-8 -*- - -import sys -from string import Template -import random -import string -import re -import copy -import os import errno +import os +import sys -drop_stmt = Template("mysql> drop table if exists $database.$table\n") -create_stmt = Template("mysql> create table $database.$table($schema)\n") -alter_stmt = Template("mysql> alter table $database.$table set tiflash replica 1 location labels 'rack', 'host', 'abc'\n") -insert_stmt = Template("mysql> insert into $database.$table($columns) values($data)\n") -update_stmt = Template("mysql> update $database.$table set $exprs $condition\n") -delete_stmt = Template("mysql> delete from $database.$table $condition\n") -select_stmt = Template(">> select $columns from $database.$table\n") -tidb_select_stmt = Template("mysql> set SESSION tidb_isolation_read_engines = 'tiflash' ;select $columns from $database.$table ttt\n") -sleep_string = "\nSLEEP 15\n\n" -wait_table_stmt = Template("\nfunc> wait_table $database $table\n\n") - - -INSERT = "insert" -UPDATE = "update" -DELETE = "delete" -SELECT = "select" - - -def generate_column_name(types): - name_prefix = "a" - names = [] - for i, _ in enumerate(types): - names.append(name_prefix + str(i)) - return names - - -def generate_schema(names, types, primary_key_type): - columns = [] - primary_key_name = "" - for (name, t) in zip(names, types): - if t == primary_key_type: - primary_key_name = name - columns.append(name + " " + t) - columns.append("primary key (" + primary_key_name + ")") - return ", ".join(columns) - +class ColumnType(object): + typeTinyInt = "tinyint" + typeUTinyInt = "tinyint unsigned" + typeSmallInt = "smallint" + typeUSmallInt = "smallint unsigned" + typeMediumInt = "mediumint" + typeUMediumInt = "mediumint unsigned" + typeInt = "int" + typeUInt = "int unsigned" + typeBigInt = "bigint" + typeUBigInt = "bigint unsigned" + typeBit64 = "bit(64)" + typeBoolean = "boolean" + typeFloat = "float" + typeDouble = "double" + typeDecimal1 = "decimal(1, 0)" + typeDecimal2 = "decimal(65, 2)" + typeChar = "char(200)" + typeVarchar = "varchar(200)" + + +class ColumnTypeManager(object): + # the index must be continuous below + MinValueIndex = 0 + MaxValueIndex = 1 + NormalValueIndex = 2 + # when the column is not null and doesn't have default value + EmptyValueIndex = 3 + + def __init__(self): + self.column_value_map = {} + self.funcs_process_value_for_insert_stmt = {} + self.funcs_get_select_expr = {} + self.funcs_format_result = {} + + def register_type(self, column_type, values, process_value_func=None, get_select_expr=None, format_result=None): + self.column_value_map[column_type] = values + if process_value_func is not None: + self.funcs_process_value_for_insert_stmt[column_type] = process_value_func + if get_select_expr is not None: + self.funcs_get_select_expr[column_type] = get_select_expr + if format_result is not None: + self.funcs_format_result[column_type] = format_result + + def get_all_column_elements(self): + column_elements = [] + for i, t in enumerate(self.column_value_map): + default_value = self.get_normal_value(t) + column_elements.append(ColumnElement("a" + str(2 * i), t, False, False, default_value)) + column_elements.append(ColumnElement("a" + str(2 * i + 1), t, True, False, default_value)) + return column_elements + + def get_all_primary_key_elements(self): + column_elements = [] + for i, t in enumerate(self.column_value_map): + column_elements.append(ColumnElement("a" + str(2 * i), t, False, True)) + return column_elements + + def get_all_nullable_column_elements(self): + column_elements = [] + for i, t in enumerate(self.column_value_map): + column_elements.append(ColumnElement("a" + str(2 * i), t, True, False)) + return column_elements + + def get_min_value(self, column_type): + return self.column_value_map[column_type][ColumnTypeManager.MinValueIndex] + + def get_max_value(self, column_type): + return self.column_value_map[column_type][ColumnTypeManager.MaxValueIndex] + + def get_normal_value(self, column_type): + return self.column_value_map[column_type][ColumnTypeManager.NormalValueIndex] + + def get_empty_value(self, column_type): + return self.column_value_map[column_type][ColumnTypeManager.EmptyValueIndex] + + def get_value_for_dml_stmt(self, column_type, value): + if column_type in self.funcs_process_value_for_insert_stmt \ + and self.funcs_process_value_for_insert_stmt[column_type] is not None: + return self.funcs_process_value_for_insert_stmt[column_type](value) + else: + return str(value) -def random_string(n): - letters = string.ascii_lowercase - return ''.join(random.choice(letters) for i in range(int(n))) + # some type's outputs are hard to parse and compare, try convert it's result to string + def get_select_expr(self, column_type, name): + if column_type in self.funcs_get_select_expr \ + and self.funcs_get_select_expr[column_type] is not None: + return self.funcs_get_select_expr[column_type](name) + else: + return name + + def format_result_value(self, column_type, value): + if str(value).lower() == "null": + return value + if column_type in self.funcs_format_result \ + and self.funcs_format_result[column_type] is not None: + return self.funcs_format_result[column_type](value) + else: + return value + + +def register_all_types(column_type_manager): + # column_type_manager.register_type(ColumnType, values) + # the first three value in values should be different for some primary key tests + column_type_manager.register_type(ColumnType.typeTinyInt, [-128, 127, 100, 0]) + column_type_manager.register_type(ColumnType.typeUTinyInt, [0, 255, 68, 0]) + column_type_manager.register_type(ColumnType.typeSmallInt, [-32768, 32767, 11100, 0]) + column_type_manager.register_type(ColumnType.typeUSmallInt, [0, 65535, 68, 0]) + column_type_manager.register_type(ColumnType.typeMediumInt, [-8388608, 8388607, 68, 0]) + column_type_manager.register_type(ColumnType.typeUMediumInt, [0, 16777215, 68, 0]) + column_type_manager.register_type(ColumnType.typeInt, [-2147483648, 2147483647, 68, 0]) + column_type_manager.register_type(ColumnType.typeUInt, [0, 4294967295, 4239013, 0]) + column_type_manager.register_type(ColumnType.typeBigInt, [-9223372036854775808, 9223372036854775807, 68, 0]) + column_type_manager.register_type(ColumnType.typeUBigInt, [0, 18446744073709551615, 68, 0]) + column_type_manager.register_type( + ColumnType.typeBit64, + [0, (1 << 64) - 1, 79, 0], + None, + lambda name: "bin({})".format(name), + lambda value: "{0:b}".format(int(value))) + column_type_manager.register_type(ColumnType.typeBoolean, [0, 1, 1, 0]) + column_type_manager.register_type(ColumnType.typeFloat, ["-3.402e38", "3.402e38", "-1.17e-38", 0]) + column_type_manager.register_type(ColumnType.typeDouble, ["-1.797e308", "1.797e308", "2.225e-308", 0]) + column_type_manager.register_type(ColumnType.typeDecimal1, [-9, 9, 3, 0]) + column_type_manager.register_type( + ColumnType.typeDecimal2, + ['-' + '9' * 63 + '.' + '9' * 2, '9' * 63 + '.' + '9' * 2, 100.23, "0.00"]) + column_type_manager.register_type( + ColumnType.typeChar, + ["", "a" * 200, "test", ""], + lambda value: "'" + value + "'") + column_type_manager.register_type( + ColumnType.typeVarchar, + ["", "a" * 200, "tiflash", ""], + lambda value: "'" + value + "'") + + +class ColumnElement(object): + def __init__(self, name, column_type, nullable=False, is_primary_key=False, default_value=None): + self.name = name + self.column_type = column_type + self.nullable = nullable + self.default_value = default_value + self.is_primary_key = is_primary_key + + def get_schema(self, column_type_manager): + schema = self.name + " " + str(self.column_type) + if not self.nullable: + schema += " not null" + if self.default_value is not None: + schema += " default {}".format(column_type_manager.get_value_for_dml_stmt(self.column_type, self.default_value)) + return schema + + +class MysqlClientResultBuilder(object): + def __init__(self, column_names): + self.column_names = column_names + self.rows = [] + # + 2 is for left and right blank + self.cell_length = [len(name) + 2 for name in self.column_names] + self.point = "+" + self.vertical = "|" + self.horizontal = "-" + + def add_row(self, column_values): + assert len(column_values) == len(self.column_names) + for i in range(len(self.cell_length)): + self.cell_length[i] = max(self.cell_length[i], len(str(column_values[i])) + 2) + self.rows.append(column_values) + + def finish(self, file): + self._write_horizontal_line(file) + self._write_header_content(file) + self._write_horizontal_line(file) + self._write_body_content(file) + self._write_horizontal_line(file) + + def _write_horizontal_line(self, file): + file.write(self.point) + for length in self.cell_length: + file.write(self.horizontal * length) + file.write(self.point) + file.write("\n") + + def _write_row(self, file, row_cells): + file.write(self.vertical) + for i in range(len(row_cells)): + content = str(row_cells[i]) + file.write(" ") + file.write(content) + file.write(" " * (self.cell_length[i] - 1 - len(content))) + file.write(self.vertical) + file.write("\n") + + def _write_header_content(self, file): + self._write_row(file, self.column_names) + + def _write_body_content(self, file): + for row in self.rows: + self._write_row(file, row) + + +class StmtWriter(object): + def __init__(self, file, db_name, table_name): + self.file = file + self.db_name = db_name + self.table_name = table_name + + def write_newline(self): + self.file.write("\n") + + def write_drop_table_stmt(self): + command = "mysql> drop table if exists {}.{}\n".format(self.db_name, self.table_name) + self.file.write(command) + + def write_create_table_schema_stmt(self, column_elements, column_type_manager): + column_schema = ", ".join([c.get_schema(column_type_manager) for c in column_elements]) + primary_key_names = [] + for c in column_elements: + if c.is_primary_key: + primary_key_names.append(c.name) + if len(primary_key_names) > 0: + column_schema += ", primary key({})".format(", ".join(primary_key_names)) + + command = "mysql> create table {}.{}({})\n".format(self.db_name, self.table_name, column_schema) + self.file.write(command) + + def write_create_tiflash_replica_stmt(self): + command = "mysql> alter table {}.{} set tiflash replica 1\n".format(self.db_name, self.table_name) + self.file.write(command) + + def write_wait_table_stmt(self): + command = "func> wait_table {} {}\n".format(self.db_name, self.table_name) + self.file.write(command) + + def write_insert_stmt(self, column_names, column_values): + command = "mysql> insert into {}.{} ({}) values({})\n".format( + self.db_name, self.table_name, ", ".join(column_names), ", ".join(column_values)) + self.file.write(command) + + def write_update_stmt(self, column_names, prev_values, after_values): + assert len(column_names) == len(prev_values) == len(after_values) + update_part = "" + filter_part = "" + for i in range(len(column_names)): + update_part += "{}={}".format(column_names[i], after_values[i]) + filter_part += "{}={}".format(column_names[i], prev_values[i]) + + command = "mysql> update {}.{} set {} where {}\n".format( + self.db_name, self.table_name, update_part, filter_part) + self.file.write(command) + + def write_delete_stmt(self, column_names, values): + assert len(column_names) == len(values) + filter_part = "" + for i in range(len(column_names)): + filter_part += "{}={}".format(column_names[i], values[i]) + + command = "mysql> delete from {}.{} where {}\n".format( + self.db_name, self.table_name, filter_part) + self.file.write(command) + + def write_select_stmt(self, column_select_exprs): + command = "mysql> set SESSION tidb_isolation_read_engines='tiflash'; select {} from {}.{}\n".format( + ", ".join(column_select_exprs), self.db_name, self.table_name) + self.file.write(command) + + def write_result(self, column_select_exprs, *row_column_values): + result_builder = MysqlClientResultBuilder(column_select_exprs) + for column_values in row_column_values: + result_builder.add_row(column_values) + result_builder.finish(self.file) + + +class TestCaseWriter(object): + def __init__(self, column_type_manager): + self.column_type_manager = column_type_manager + + def _build_dml_values(self, column_elements, column_values): + return [self.column_type_manager.get_value_for_dml_stmt( + column_elements[i].column_type, column_values[i]) for i in range(len(column_values))] + + def _build_formatted_values(self, column_elements, column_values): + result = [self.column_type_manager.format_result_value( + column_elements[i].column_type, column_values[i]) for i in range(len(column_values))] + return result + + def _write_create_table(self, writer, column_elements): + writer.write_drop_table_stmt() + writer.write_create_table_schema_stmt(column_elements, self.column_type_manager) + writer.write_create_tiflash_replica_stmt() + writer.write_wait_table_stmt() + writer.write_newline() + + def _write_min_max_value_test(self, writer): + column_elements = self.column_type_manager.get_all_column_elements() + self._write_create_table(writer, column_elements) + # insert values + column_names = [c.name for c in column_elements] + column_min_values = [self.column_type_manager.get_min_value(c.column_type) for c in column_elements] + writer.write_insert_stmt(column_names, self._build_dml_values(column_elements, column_min_values)) + column_max_values = [self.column_type_manager.get_max_value(c.column_type) for c in column_elements] + writer.write_insert_stmt(column_names, self._build_dml_values(column_elements, column_max_values)) + column_normal_values = [self.column_type_manager.get_normal_value(c.column_type) for c in column_elements] + writer.write_insert_stmt(column_names, self._build_dml_values(column_elements, column_normal_values)) + # check result + column_select_exprs = [self.column_type_manager.get_select_expr(c.column_type, c.name) for c in column_elements] + writer.write_select_stmt(column_select_exprs) + writer.write_result( + column_select_exprs, + self._build_formatted_values(column_elements, column_min_values), + self._build_formatted_values(column_elements, column_max_values), + self._build_formatted_values(column_elements, column_normal_values)) + writer.write_newline() + + def _write_default_value_test(self, writer): + column_elements = self.column_type_manager.get_all_column_elements() + primary_key_name = "mykey" + primary_key_value = 10000 + column_default_values = [c.default_value for c in column_elements] + column_elements.append(ColumnElement(primary_key_name, ColumnType.typeUInt, False, True)) + column_default_values.append(primary_key_value) + self._write_create_table(writer, column_elements) + # write a row which only specify key value + writer.write_insert_stmt([primary_key_name], [str(primary_key_value)]) + # check result + column_select_exprs = [self.column_type_manager.get_select_expr(c.column_type, c.name) for c in column_elements] + writer.write_select_stmt(column_select_exprs) + writer.write_result( + column_select_exprs, + self._build_formatted_values(column_elements, column_default_values)) + + def _write_null_value_test(self, writer): + nonnull_column_elements = self.column_type_manager.get_all_nullable_column_elements() + primary_key_name = "mykey" + primary_key_value = 10000 + column_null_values = ["NULL" for _ in nonnull_column_elements] + nonnull_column_elements.append(ColumnElement(primary_key_name, ColumnType.typeUInt, False, True)) + column_null_values.append(str(primary_key_value)) + self._write_create_table(writer, nonnull_column_elements) + # write a row which only specify key value + writer.write_insert_stmt([primary_key_name], [str(primary_key_value)]) + # check result + column_select_exprs = [self.column_type_manager.get_select_expr(c.column_type, c.name) for c in + nonnull_column_elements] + writer.write_select_stmt(column_select_exprs) + writer.write_result( + column_select_exprs, + self._build_formatted_values(nonnull_column_elements, column_null_values)) + + def write_basic_type_codec_test(self, file, db_name, table_name): + writer = StmtWriter(file, db_name, table_name) + self._write_min_max_value_test(writer) + self._write_default_value_test(writer) + self._write_null_value_test(writer) + + def _write_non_cluster_index_test(self, writer): + column_elements = self.column_type_manager.get_all_column_elements() + filter_column_name = "myfilter" + filter_value1 = 10000 + filter_value2 = 20000 + column_names = [c.name for c in column_elements] + column_values1 = [self.column_type_manager.get_normal_value(c.column_type) for c in column_elements] + column_values2 = [self.column_type_manager.get_normal_value(c.column_type) for c in column_elements] + column_elements.append(ColumnElement(filter_column_name, ColumnType.typeUInt)) + column_names.append(filter_column_name) + column_values1.append(filter_value1) + column_values2.append(filter_value2) + self._write_create_table(writer, column_elements) + writer.write_insert_stmt(column_names, self._build_dml_values(column_elements, column_values1)) + writer.write_insert_stmt(column_names, self._build_dml_values(column_elements, column_values2)) + # check result + column_select_exprs = [self.column_type_manager.get_select_expr(c.column_type, c.name) for c in + column_elements] + writer.write_select_stmt(column_select_exprs) + writer.write_result( + column_select_exprs, + self._build_formatted_values(column_elements, column_values1), + self._build_formatted_values(column_elements, column_values2)) + new_filter_value1 = 30000 + writer.write_newline() + # update + writer.write_update_stmt([filter_column_name], [filter_value1], [new_filter_value1]) + column_values1[len(column_values1) - 1] = new_filter_value1 + writer.write_select_stmt(column_select_exprs) + writer.write_result( + column_select_exprs, + self._build_formatted_values(column_elements, column_values1), + self._build_formatted_values(column_elements, column_values2)) + # delete + writer.write_delete_stmt([filter_column_name], [filter_value2]) + writer.write_select_stmt(column_select_exprs) + writer.write_result( + column_select_exprs, + self._build_formatted_values(column_elements, column_values1)) + + def _write_pk_is_handle_test(self, writer, pk_column_element): + column_elements = self.column_type_manager.get_all_column_elements() + filter_column_name = "myfilter" + filter_value1 = 10000 + filter_value2 = 20000 + column_names = [c.name for c in column_elements] + column_values1 = [self.column_type_manager.get_normal_value(c.column_type) for c in column_elements] + column_values2 = [self.column_type_manager.get_normal_value(c.column_type) for c in column_elements] + # add pk column + column_elements.append(pk_column_element) + column_names.append(pk_column_element.name) + pk_values1 = self.column_type_manager.get_normal_value(pk_column_element.column_type) + pk_values2 = self.column_type_manager.get_normal_value(pk_column_element.column_type) + 1 + column_values1.append(pk_values1) + column_values2.append(pk_values2) + + # add filter column + column_elements.append(ColumnElement(filter_column_name, ColumnType.typeUInt)) + column_names.append(filter_column_name) + column_values1.append(filter_value1) + column_values2.append(filter_value2) + self._write_create_table(writer, column_elements) + writer.write_insert_stmt(column_names, self._build_dml_values(column_elements, column_values1)) + writer.write_insert_stmt(column_names, self._build_dml_values(column_elements, column_values2)) + # check result + column_select_exprs = [self.column_type_manager.get_select_expr(c.column_type, c.name) for c in + column_elements] + writer.write_select_stmt(column_select_exprs) + writer.write_result( + column_select_exprs, + self._build_formatted_values(column_elements, column_values1), + self._build_formatted_values(column_elements, column_values2)) + new_filter_value1 = 30000 + writer.write_newline() + # update + writer.write_update_stmt([filter_column_name], [filter_value1], [new_filter_value1]) + column_values1[len(column_values1) - 1] = new_filter_value1 + writer.write_select_stmt(column_select_exprs) + writer.write_result( + column_select_exprs, + self._build_formatted_values(column_elements, column_values1), + self._build_formatted_values(column_elements, column_values2)) + # delete + writer.write_delete_stmt([filter_column_name], [filter_value2]) + writer.write_select_stmt(column_select_exprs) + writer.write_result( + column_select_exprs, + self._build_formatted_values(column_elements, column_values1)) + + def _write_cluster_index_test(self, writer, primary_key_columns): + column_elements = primary_key_columns + column_names = [c.name for c in column_elements] + column_values1 = [self.column_type_manager.get_normal_value(c.column_type) for c in column_elements] + column_values2 = [self.column_type_manager.get_min_value(c.column_type) for c in column_elements] + + filter_column_name = "myfilter" + filter_value1 = 10000 + filter_value2 = 20000 + column_elements.append(ColumnElement(filter_column_name, ColumnType.typeInt)) + column_names.append(filter_column_name) + column_values1.append(filter_value1) + column_values2.append(filter_value2) + + self._write_create_table(writer, column_elements) + writer.write_insert_stmt(column_names, self._build_dml_values(column_elements, column_values1)) + writer.write_insert_stmt(column_names, self._build_dml_values(column_elements, column_values2)) + # check result + column_select_exprs = [self.column_type_manager.get_select_expr(c.column_type, c.name) for c in + column_elements] + writer.write_select_stmt(column_select_exprs) + writer.write_result( + column_select_exprs, + self._build_formatted_values(column_elements, column_values1), + self._build_formatted_values(column_elements, column_values2)) + new_filter_value1 = 30000 + writer.write_newline() + # update + writer.write_update_stmt([filter_column_name], [filter_value1], [new_filter_value1]) + column_values1[len(column_values1) - 1] = new_filter_value1 + writer.write_select_stmt(column_select_exprs) + writer.write_result( + column_select_exprs, + self._build_formatted_values(column_elements, column_values1), + self._build_formatted_values(column_elements, column_values2)) + # delete + writer.write_delete_stmt([filter_column_name], [filter_value2]) + writer.write_select_stmt(column_select_exprs) + writer.write_result( + column_select_exprs, + self._build_formatted_values(column_elements, column_values1)) + + def write_update_delete_test1(self, file, db_name, table_name): + writer = StmtWriter(file, db_name, table_name) + self._write_non_cluster_index_test(writer) + file.write("# pk_is_handle test\n") + for column_type in [ColumnType.typeTinyInt, ColumnType.typeUTinyInt, ColumnType.typeSmallInt, ColumnType.typeUSmallInt, + ColumnType.typeMediumInt, ColumnType.typeUMediumInt, ColumnType.typeInt, ColumnType.typeUInt, + ColumnType.typeBigInt, ColumnType.typeUInt]: + self._write_pk_is_handle_test(writer, ColumnElement("mypk", column_type, False, True)) + + def write_update_delete_test2(self, file, db_name, table_name): + writer = StmtWriter(file, db_name, table_name) + file.write("# cluster index test\n") + file.write("mysql> set global tidb_enable_clustered_index=ON\n") + primary_key_elements = self.column_type_manager.get_all_primary_key_elements() + for primary_key_element in primary_key_elements: + self._write_cluster_index_test(writer, [primary_key_element]) + # the max column num in primary key is 16 + max_columns_in_primary_key = 16 + self._write_cluster_index_test(writer, primary_key_elements[:max_columns_in_primary_key]) + # TODO: INT_ONLY may be removed in future release + file.write("mysql> set global tidb_enable_clustered_index=INT_ONLY\n") + + +def write_case(path, db_name, table_name, case_func): + with open(path, "w") as file: + case_func(file, db_name, table_name) + + +def run(db_name, table_name, test_dir): + column_type_manager = ColumnTypeManager() + register_all_types(column_type_manager) + + case_writer = TestCaseWriter(column_type_manager) + # case: decode min/max/normal/default/null values of different type + write_case(test_dir + "/basic_codec.test", db_name, table_name, case_writer.write_basic_type_codec_test) + # case: update/delete for different kinds of primary key + write_case(test_dir + "/update_delete1.test", db_name, table_name, case_writer.write_update_delete_test1) + write_case(test_dir + "/update_delete2.test", db_name, table_name, case_writer.write_update_delete_test2) -def generate_data(type_name, types, sample_data): - if type_name.startswith("varchar"): - lengths = re.findall(r"\d+", type_name) - if len(lengths) < 1: - return "" - else: - return random_string(lengths[0]) - elif "int" in type_name: - return str(random.randint(0, 100)) - else: - return str(sample_data[random.choice(range(len(sample_data)))][types.index(type_name)]) - - -def generate_exprs(names, values): - exprs = [] - for (name, value) in zip(names, values): - exprs.append(name + "=" + value) - return ", ".join(exprs) - - -def generate_result(names, dataset): - dataset = copy.deepcopy(dataset) - for data_point in dataset: - for i in range(len(data_point)): - if data_point[i] == "null": - data_point[i] = "\N" - left_top_corner = "┌" - right_top_corner = "┐" - left_bottom_corner = "└" - right_bottom_corner = "┘" - header_split = "┬" - footer_split = "┴" - border = "─" - body_border = "│" - blank = " " - - cell_length = [] - for name in names: - cell_length.append(len(name)) - for data in dataset: - for i, ele in enumerate(data): - if len(ele) > cell_length[i]: - cell_length[i] = len(ele) - - lines = [] - - header = [] - for i, name in enumerate(names): - header_cell = "" - if i == 0: - header_cell += left_top_corner - header_cell += border - header_cell += name - j = 0 - while cell_length[i] > len(name) + j: - header_cell += border - j += 1 - header_cell += border - if i == len(names) - 1: - header_cell += right_top_corner - - header.append(header_cell) - - lines.append(header_split.join(header)) - - for data in dataset: - cur = [] - for i, ele in enumerate(data): - body_cell = "" - if i == 0: - body_cell += body_border - body_cell += blank - body_cell += ele - j = 0 - while cell_length[i] > len(ele) + j: - body_cell += blank - j += 1 - body_cell += blank - if i == len(data) - 1: - body_cell += body_border - cur.append(body_cell) - - lines.append(body_border.join(cur)) - - footer = [] - for i, _ in enumerate(names): - footer_cell = "" - if i == 0: - footer_cell += left_bottom_corner - footer_cell += border - j = 0 - while cell_length[i] > j: - footer_cell += border - j += 1 - footer_cell += border - if i == len(names) - 1: - footer_cell += right_bottom_corner - - footer.append(footer_cell) - - lines.append(footer_split.join(footer)) - - return "\n".join(lines) - -def tidb_generate_result(names, dataset): - dataset = copy.deepcopy(dataset) - for data_point in dataset: - for i in range(len(data_point)): - if data_point[i] == "null": - data_point[i] = "NULL" - left_top_corner = "+" - right_top_corner = "+" - left_bottom_corner = "+" - right_bottom_corner = "+" - header_split = "+" - footer_split = "+" - border = "-" - body_border = "|" - blank = " " - - cell_length = [] - for name in names: - cell_length.append(len(name)) - for data in dataset: - for i, ele in enumerate(data): - if len(ele) > cell_length[i]: - cell_length[i] = len(ele) - - lines = [] - - topline = [] - for i, name in enumerate(names): - topline_cell = "" - if i == 0: - topline_cell += left_top_corner - topline_cell += border - j = 0 - while cell_length[i] > j: - topline_cell += border - j += 1 - topline_cell += border - if i == len(names) - 1: - topline_cell += right_top_corner - - topline.append(topline_cell) - - lines.append(header_split.join(topline)) - - header = [] - for i, name in enumerate(names): - header_cell = "" - if i == 0: - header_cell += body_border - header_cell += blank - header_cell += name - j = 0 - while cell_length[i] >= len(name) + j: - header_cell += blank - j += 1 - if i == len(names) - 1: - header_cell += body_border - - header.append(header_cell) - - lines.append(body_border.join(header)) - - lines.append(header_split.join(topline)) - - for data in dataset: - cur = [] - for i, ele in enumerate(data): - body_cell = "" - if i == 0: - body_cell += body_border - body_cell += blank - body_cell += ele - j = 0 - while cell_length[i] > len(ele) + j: - body_cell += blank - j += 1 - body_cell += blank - if i == len(data) - 1: - body_cell += body_border - cur.append(body_cell) - - lines.append(body_border.join(cur)) - - footer = [] - for i, _ in enumerate(names): - footer_cell = "" - if i == 0: - footer_cell += left_bottom_corner - footer_cell += border - j = 0 - while cell_length[i] > j: - footer_cell += border - j += 1 - footer_cell += border - if i == len(names) - 1: - footer_cell += right_bottom_corner - - footer.append(footer_cell) - - lines.append(footer_split.join(footer)) - - return "\n".join(lines) - -def generate_cases_inner(database, table, column_names, types, sample_data, - schema, primary_key_type, test_cases, parent_dir): - primary_key = column_names[len(column_names) - 1] - for num, case in enumerate(test_cases): - case_data = copy.deepcopy(sample_data) - path = parent_dir + primary_key_type.replace(" ", "_") + "_case" + str(num) + ".test" - with open(path, "w") as file: - file.write(drop_stmt.substitute({"database": database, "table": table})) - file.write(create_stmt.substitute({"database": database, "table": table, "schema": schema})) - file.write(alter_stmt.substitute({"database": database, "table": table})) - file.write(wait_table_stmt.substitute({"database": database, "table": table})) - - for op in case: - if op == INSERT: - for k in range(len(case_data)): - file.write(insert_stmt.substitute({"database": database, - "table": table, - "columns": ", ".join(column_names), - "data": ", ".join([repr(d) if d != "null" else d for d in case_data[k]])})) - if op == UPDATE: - for data_point in case_data: - condition = "" - exprs = [] - for i in range(len(types)): - if column_names[i] == primary_key: - condition = "where " + primary_key + " = " + repr(data_point[i]) - continue - ele = generate_data(types[i], types, sample_data) - data_point[i] = ele - value = repr(data_point[i]) - if data_point[i] == "null": - value = data_point[i] - exprs.append(column_names[i] + "=" + value) - file.write(update_stmt.substitute({"database": database, - "table": table, - "exprs": ", ".join(exprs), - "condition": condition})) - if op == DELETE: - new_case_data = random.sample(case_data, len(case_data) // 2) - for data_point in case_data: - if data_point in new_case_data: - continue - condition = "" - for i in range(len(types)): - if column_names[i] == primary_key: - condition = "where " + primary_key + " = " + repr(data_point[i]) - break - file.write(delete_stmt.substitute({"database": database, - "table": table, - "condition": condition})) - case_data = new_case_data - if op == SELECT: - # file.write(select_stmt.substitute({"columns": ", ".join(column_names), - # "database": database, - # "table": table})) - # file.write(generate_result(column_names, case_data) + "\n\n") - file.write(tidb_select_stmt.substitute({"columns": ", ".join(column_names), - "database": database, - "table": table})) - file.write(tidb_generate_result(column_names, case_data) + "\n\n") - - file.write(drop_stmt.substitute({"database": database, "table": table})) - - -def generate_cases(database, table, types, sample_data, - primary_key_candidates, primary_key_sample_data, test_cases, parent_dir): - for i, primary_key_type in enumerate(primary_key_candidates): - case_types = copy.deepcopy(types) - case_types.append(primary_key_type) - column_names = generate_column_name(case_types) - schema = generate_schema(column_names, case_types, primary_key_type) - case_sample_data = copy.deepcopy(sample_data) - for j in range(len(case_sample_data)): - case_sample_data[j].append(primary_key_sample_data[j][i]) - generate_cases_inner(database, table, column_names, case_types, case_sample_data, - schema, primary_key_type, test_cases, parent_dir) - - -def generate_data_for_types(types, sample_data, allow_empty=True, no_duplicate=False, result_len=1): - result = [] - if no_duplicate: - for name in types: - if len(sample_data[name]) < result_len: - raise Exception("not enough data sample for type: ", name) - for i in range(result_len): - cur = [] - for name in types: - cur.append(str(sample_data[name][i])) - result.append(cur) - else: - for _ in range(result_len): - cur = [] - for name in types: - if name in sample_data: - samples = sample_data[name] - cur.append(str(random.choice(samples))) - elif allow_empty: - cur.append("null") - else: - raise Exception("type without valid data_sample: ", name) - result.append(cur) - return result - - -def run(): +def main(): if len(sys.argv) != 3: - print 'usage: database table' + print('usage: database table') sys.exit(1) - database = sys.argv[1] - table = sys.argv[2] - - primary_key_candidates = ["tinyint", "smallint", "mediumint", "int", "bigint", - "tinyint unsigned", "smallint unsigned", "mediumint unsigned", "int unsigned", "bigint unsigned", ] - types = ["decimal(1, 0)", "decimal(5, 2)", "decimal(65, 0)", - "varchar(20)", "char(10)", - "date", "datetime", "timestamp", ] - min_values = { - "tinyint": [-(1 << 7), ], - "smallint": [-(1 << 15), ], - "mediumint": [-(1 << 23), ], - "int": [-(1 << 31), ], - "bigint": [-(1 << 63), ], - "tinyint unsigned": [0, ], - "smallint unsigned": [0, ], - "mediumint unsigned": [0, ], - "int unsigned": [0, ], - "bigint unsigned": [0, ], - "decimal(1, 0)": [-9, ], - "decimal(5, 2)": [-999.99, ], - "decimal(65, 0)": [-(pow(10, 65) - 1), ], - "decimal(65, 30)":[-99999999999999999999999999999999999.999999999999999999999999999999, ], - } - max_values = { - "tinyint": [(1 << 7) - 1, ], - "smallint": [(1 << 15) - 1, ], - "mediumint": [(1 << 23) - 1, ], - "int": [(1 << 31) - 1, ], - "bigint": [(1 << 63) - 1, ], - "tinyint unsigned": [(1 << 8) - 1, ], - "smallint unsigned": [(1 << 16) - 1, ], - "mediumint unsigned": [(1 << 24) - 1, ], - "int unsigned": [(1 << 32) - 1, ], - "bigint unsigned": [(1 << 64) - 1, ], - "decimal(1, 0)": [9, ], - "decimal(5, 2)": [999.99, ], - "decimal(65, 0)": [pow(10, 65) - 1, ], - "decimal(65, 30)": [99999999999999999999999999999999999.999999999999999999999999999999, ], - } - data_sample = { - "tinyint": [8, 9, 10, 11, 12, 13, 14], - "smallint": [8, 9, 10, 11, 12, 13, 14], - "mediumint": [8, 9, 10, 11, 12, 13, 14], - "int": [8, 9, 10, 11, 12, 13, 14], - "bigint": [8, 9, 10, 11, 12, 13, 14], - "tinyint unsigned": [8, 9, 10, 11, 12, 13, 14], - "smallint unsigned": [8, 9, 10, 11, 12, 13, 14], - "mediumint unsigned": [8, 9, 10, 11, 12, 13, 14], - "int unsigned": [8, 9, 10, 11, 12, 13, 14], - "bigint unsigned": [8, 9, 10, 11, 12, 13, 14], - "decimal(1, 0)": [7, 3], - "decimal(5, 2)": [3.45, 5.71], - "decimal(65, 0)": [11, ], - "decimal(65, 30)": [11, ], - "varchar(20)": ["hello world", "hello world2", "hello world3", "hello world4", ], - "char(10)": ["a" * 10, "b" * 10, ], - "date": ["2000-01-01", "2019-10-10"], - "datetime": ["2000-01-01 00:00:00", "2019-10-10 00:00:00"], - "timestamp": ["2000-01-01 00:00:00", "2019-10-10 00:00:00"], - } - - data_sample_num = 7 - primary_key_data = [] - for d in generate_data_for_types(primary_key_candidates, min_values, False, True, 1): - primary_key_data.append(d) - for d in generate_data_for_types(primary_key_candidates, max_values, False, True, 1): - primary_key_data.append(d) - for d in generate_data_for_types(primary_key_candidates, data_sample, False, True, data_sample_num - 2): - primary_key_data.append(d) - data = [] - for d in generate_data_for_types(types, min_values, True, False, 1): - data.append(d) - for d in generate_data_for_types(types, max_values, True, False, 1): - data.append(d) - for d in generate_data_for_types(types, data_sample, True, False, data_sample_num - 2): - data.append(d) - - dml_test_cases = [ - [INSERT, SELECT, UPDATE, SELECT, ], - [INSERT, SELECT, UPDATE, SELECT, DELETE, SELECT], - [INSERT, SELECT, UPDATE, SELECT, UPDATE, SELECT, DELETE, SELECT], - [INSERT, SELECT, UPDATE, SELECT, UPDATE, SELECT, UPDATE, SELECT, DELETE, SELECT], - ] - parent_dir = "./fullstack-test2/dml/dml_gen/" - directory = os.path.dirname(parent_dir) + db_name = sys.argv[1] + table_name = sys.argv[2] + test_dir = "./fullstack-test2/auto_gen/" + directory = os.path.dirname(test_dir) try: os.makedirs(directory) except OSError as e: if e.errno != errno.EEXIST: raise - generate_cases(database, table, types, data, primary_key_candidates, primary_key_data, dml_test_cases, parent_dir) - -def main(): try: - run() + print("begin to create test case to path {}".format(test_dir)) + run(db_name, table_name, test_dir) + print("create test case done") except KeyboardInterrupt: - print 'KeyboardInterrupted' + print('Test interrupted') sys.exit(1) -main() +if __name__ == "__main__": + main()