diff --git a/br/pkg/task/backup.go b/br/pkg/task/backup.go index de94b2d385ad5..311f57e9de020 100644 --- a/br/pkg/task/backup.go +++ b/br/pkg/task/backup.go @@ -137,7 +137,7 @@ func (cfg *BackupConfig) ParseFromFlags(flags *pflag.FlagSet) error { if err != nil { return errors.Trace(err) } - cfg.BackupTS, err = ParseTSString(backupTS) + cfg.BackupTS, err = ParseTSString(backupTS, false) if err != nil { return errors.Trace(err) } @@ -528,7 +528,7 @@ func RunBackup(c context.Context, g glue.Glue, cmdName string, cfg *BackupConfig } // ParseTSString port from tidb setSnapshotTS. -func ParseTSString(ts string) (uint64, error) { +func ParseTSString(ts string, tzCheck bool) (uint64, error) { if len(ts) == 0 { return 0, nil } @@ -540,6 +540,12 @@ func ParseTSString(ts string) (uint64, error) { sc := &stmtctx.StatementContext{ TimeZone: loc, } + if tzCheck { + tzIdx, _, _, _, _ := types.GetTimezone(ts) + if tzIdx < 0 { + return 0, errors.Errorf("must set timezone when using datetime format ts") + } + } t, err := types.ParseTime(sc, ts, mysql.TypeTimestamp, types.MaxFsp) if err != nil { return 0, errors.Trace(err) diff --git a/br/pkg/task/backup_test.go b/br/pkg/task/backup_test.go index ef3f33d27ee54..25d6de8ee66fb 100644 --- a/br/pkg/task/backup_test.go +++ b/br/pkg/task/backup_test.go @@ -16,21 +16,40 @@ func TestParseTSString(t *testing.T) { err error ) - ts, err = ParseTSString("") + ts, err = ParseTSString("", false) require.NoError(t, err) require.Zero(t, ts) - ts, err = ParseTSString("400036290571534337") + ts, err = ParseTSString("400036290571534337", false) require.NoError(t, err) require.Equal(t, uint64(400036290571534337), ts) - ts, err = ParseTSString("2021-01-01 01:42:23") + ts, err = ParseTSString("2021-01-01 01:42:23", false) require.NoError(t, err) localTime := time.Date(2021, time.Month(1), 1, 1, 42, 23, 0, time.Local) - localTimestamp := localTime.Unix() localTSO := uint64((localTimestamp << 18) * 1000) require.Equal(t, localTSO, ts) + + _, err = ParseTSString("2021-01-01 01:42:23", true) + require.Error(t, err) + require.Regexp(t, "must set timezone*", err.Error()) + + ts, err = ParseTSString("2021-01-01 01:42:23+00:00", true) + require.NoError(t, err) + localTime = time.Date(2021, time.Month(1), 1, 1, 42, 23, 0, time.UTC) + localTimestamp = localTime.Unix() + localTSO = uint64((localTimestamp << 18) * 1000) + require.Equal(t, localTSO, ts) + + ts, err = ParseTSString("2021-01-01 01:42:23+08:00", true) + require.NoError(t, err) + secondsEastOfUTC := int((8 * time.Hour).Seconds()) + beijing := time.FixedZone("Beijing Time", secondsEastOfUTC) + localTime = time.Date(2021, time.Month(1), 1, 1, 42, 23, 0, beijing) + localTimestamp = localTime.Unix() + localTSO = uint64((localTimestamp << 18) * 1000) + require.Equal(t, localTSO, ts) } func TestParseCompressionType(t *testing.T) { diff --git a/br/pkg/task/restore.go b/br/pkg/task/restore.go index 1e6cbd2432065..d139d58a12186 100644 --- a/br/pkg/task/restore.go +++ b/br/pkg/task/restore.go @@ -182,14 +182,14 @@ func (cfg *RestoreConfig) ParseStreamRestoreFlags(flags *pflag.FlagSet) error { if err != nil { return errors.Trace(err) } - if cfg.StartTS, err = ParseTSString(tsString); err != nil { + if cfg.StartTS, err = ParseTSString(tsString, true); err != nil { return errors.Trace(err) } tsString, err = flags.GetString(FlagStreamRestoreTS) if err != nil { return errors.Trace(err) } - if cfg.RestoreTS, err = ParseTSString(tsString); err != nil { + if cfg.RestoreTS, err = ParseTSString(tsString, true); err != nil { return errors.Trace(err) } diff --git a/br/pkg/task/stream.go b/br/pkg/task/stream.go index 7d8b717f07e0d..e3337cc49b039 100644 --- a/br/pkg/task/stream.go +++ b/br/pkg/task/stream.go @@ -189,7 +189,7 @@ func (cfg *StreamConfig) ParseStreamTruncateFromFlags(flags *pflag.FlagSet) erro if err != nil { return errors.Trace(err) } - if cfg.Until, err = ParseTSString(tsString); err != nil { + if cfg.Until, err = ParseTSString(tsString, true); err != nil { return errors.Trace(err) } if cfg.SkipPrompt, err = flags.GetBool(flagYes); err != nil { @@ -213,7 +213,7 @@ func (cfg *StreamConfig) ParseStreamStartFromFlags(flags *pflag.FlagSet) error { return errors.Trace(err) } - if cfg.StartTS, err = ParseTSString(tsString); err != nil { + if cfg.StartTS, err = ParseTSString(tsString, true); err != nil { return errors.Trace(err) } @@ -222,7 +222,7 @@ func (cfg *StreamConfig) ParseStreamStartFromFlags(flags *pflag.FlagSet) error { return errors.Trace(err) } - if cfg.EndTS, err = ParseTSString(tsString); err != nil { + if cfg.EndTS, err = ParseTSString(tsString, true); err != nil { return errors.Trace(err) }