diff --git a/interval/athena/athena.go b/interval/athena/athena.go index 6340fa0c2ac56..fdc8cb1bf5324 100644 --- a/interval/athena/athena.go +++ b/interval/athena/athena.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-sdk-go/service/athena" "github.com/pingcap/tidb/dumpling/context" "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/util/logutil" "go.uber.org/zap" ) @@ -230,7 +231,8 @@ func buildCreateTableSQL(table string, s3BucketName string, tbInfo *model.TableI buf.WriteString(", ") } writeKey(buf, col.Name.L) - buf.WriteString(" string") + buf.WriteString(" ") + buf.WriteString(getColumnType(col)) } buf.WriteString(" ) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS INPUTFORMAT 'org.apache.hadoop.mapred.TextInputFormat' ") buf.WriteString(" OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat' ") @@ -240,6 +242,27 @@ func buildCreateTableSQL(table string, s3BucketName string, tbInfo *model.TableI return buf.String() } +func getColumnType(col *model.ColumnInfo) string { + switch col.Tp { + case mysql.TypeTiny: + return "TINYINT" + case mysql.TypeShort: + return "SMALLINT" + case mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + return "BIGINT" + case mysql.TypeTimestamp: + return "VARCHAR(64)" + case mysql.TypeDouble, mysql.TypeFloat, mysql.TypeNewDecimal: + return col.FieldType.String() + case mysql.TypeVarchar: + return fmt.Sprintf("VARCHAR(%d)", col.FieldType.Flen) + case mysql.TypeYear: + return "BIGINT" + default: + return "STRING" + } +} + func writeKey(buf *bytes.Buffer, name string) { buf.WriteByte('`') buf.WriteString(name)