Skip to content

Commit

Permalink
support INTERVAL type and AGE() function (#585)
Browse files Browse the repository at this point in the history
  • Loading branch information
jennifersp authored Aug 10, 2024
1 parent 05eb996 commit fb68084
Show file tree
Hide file tree
Showing 31 changed files with 1,147 additions and 54 deletions.
2 changes: 2 additions & 0 deletions postgres/messages/row_description.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ const (
OidTimestampArray = 1115
OidDateArray = 1182
OidTimeArray = 1183
OidInterval = 1186
OidIntervalArray = 1187
OidNumeric = 1700
OidRefcursor = 1790
OidRegprocedure = 2202
Expand Down
10 changes: 9 additions & 1 deletion server/ast/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,15 @@ func nodeExpr(node tree.Expr) (vitess.Expr, error) {
case *tree.DInt:
return nil, fmt.Errorf("the statement is not yet supported")
case *tree.DInterval:
return nil, fmt.Errorf("the statement is not yet supported")
cast, err := pgexprs.NewExplicitCastInjectable(pgtypes.Interval)
if err != nil {
return nil, err
}
expr := pgexprs.NewIntervalLiteral(node.Duration)
return vitess.InjectedExpr{
Expression: cast,
Children: vitess.Exprs{vitess.InjectedExpr{Expression: expr}},
}, nil
case *tree.DJSON:
return nil, fmt.Errorf("the statement is not yet supported")
case *tree.DOid:
Expand Down
2 changes: 2 additions & 0 deletions server/ast/resolvable_type_reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ func nodeResolvableTypeReference(typ tree.ResolvableTypeReference) (*vitess.Conv
resolvedType = pgtypes.Int32
case oid.T_int8:
resolvedType = pgtypes.Int64
case oid.T_interval:
resolvedType = pgtypes.Interval
case oid.T_json:
resolvedType = pgtypes.Json
case oid.T_jsonb:
Expand Down
2 changes: 2 additions & 0 deletions server/cast/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func Init() {
initInt32()
initInt64()
initInternalChar()
initInterval()
initJson()
initJsonB()
initName()
Expand All @@ -33,5 +34,6 @@ func Init() {
initRegproc()
initRegtype()
initText()
initTime()
initVarChar()
}
1 change: 0 additions & 1 deletion server/cast/internal_char.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,4 @@ func internalCharImplicit() {
return val, nil
},
})

}
57 changes: 57 additions & 0 deletions server/cast/interval.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cast

import (
"time"

"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/postgres/parser/duration"
"github.com/dolthub/doltgresql/server/functions/framework"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

// initInterval handles all casts that are built-in. This comprises only the "From" types.
func initInterval() {
intervalAssignment()
intervalImplicit()
}

// intervalAssignment registers all assignment casts. This comprises only the "From" types.
func intervalAssignment() {
framework.MustAddAssignmentTypeCast(framework.TypeCast{
FromType: pgtypes.Interval,
ToType: pgtypes.Time,
Function: func(ctx *sql.Context, val any, targetType pgtypes.DoltgresType) (any, error) {
dur := val.(duration.Duration)
// truncate the month and day of the duration.
dur.Months = 0
dur.Days = 0
return time.Parse("15:04:05.999", dur.String())
},
})
}

// intervalImplicit registers all implicit casts. This comprises only the "From" types.
func intervalImplicit() {
framework.MustAddImplicitTypeCast(framework.TypeCast{
FromType: pgtypes.Interval,
ToType: pgtypes.Interval,
Function: func(ctx *sql.Context, val any, targetType pgtypes.DoltgresType) (any, error) {
return val.(duration.Duration), nil
},
})
}
43 changes: 43 additions & 0 deletions server/cast/time.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package cast

import (
"time"

"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/server/functions"
"github.com/dolthub/doltgresql/server/functions/framework"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

// initTime handles all casts that are built-in. This comprises only the "From" types.
func initTime() {
timeImplicit()
}

// timeImplicit registers all implicit casts. This comprises only the "From" types.
func timeImplicit() {
framework.MustAddImplicitTypeCast(framework.TypeCast{
FromType: pgtypes.Time,
ToType: pgtypes.Interval,
Function: func(ctx *sql.Context, val any, targetType pgtypes.DoltgresType) (any, error) {
t := val.(time.Time)
dur := functions.GetIntervalDurationFromTimeComponents(0, 0, 0, int64(t.Hour()), int64(t.Minute()), int64(t.Second()), int64(t.Nanosecond()))
return dur, nil
},
})
}
9 changes: 9 additions & 0 deletions server/expression/literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
"github.com/shopspring/decimal"

"github.com/dolthub/doltgresql/postgres/parser/duration"
"github.com/dolthub/doltgresql/server/functions/framework"
pgtypes "github.com/dolthub/doltgresql/server/types"
)
Expand Down Expand Up @@ -88,6 +89,14 @@ func NewStringLiteral(stringValue string) *Literal {
}
}

// NewIntervalLiteral returns a new *Literal containing a INTERVAL value.
func NewIntervalLiteral(duration duration.Duration) *Literal {
return &Literal{
value: duration,
typ: pgtypes.Interval,
}
}

// NewJSONLiteral returns a new *Literal containing a JSON value. This is different from JSONB.
func NewJSONLiteral(jsonValue string) *Literal {
return &Literal{
Expand Down
122 changes: 122 additions & 0 deletions server/functions/age.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functions

import (
"time"

"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/postgres/parser/duration"
"github.com/dolthub/doltgresql/server/functions/framework"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

// initAge registers the functions to the catalog.
func initAge() {
framework.RegisterFunction(age_timestamp_timestamp)
framework.RegisterFunction(age_timestamp)
}

// age_timestamp_timestamp represents the PostgreSQL date/time function.
var age_timestamp_timestamp = framework.Function2{
Name: "age",
Return: pgtypes.Interval,
Parameters: [2]pgtypes.DoltgresType{pgtypes.Timestamp, pgtypes.Timestamp},
IsNonDeterministic: true,
Strict: true,
Callable: func(ctx *sql.Context, _ [3]pgtypes.DoltgresType, val1, val2 any) (any, error) {
t1 := val1.(time.Time)
t2 := val2.(time.Time)
return diffTimes(t1, t2), nil
},
}

// age_timestamp_timestamp represents the PostgreSQL date/time function.
var age_timestamp = framework.Function1{
Name: "age",
Return: pgtypes.Interval,
Parameters: [1]pgtypes.DoltgresType{pgtypes.Timestamp},
IsNonDeterministic: true,
Strict: true,
Callable: func(ctx *sql.Context, _ [2]pgtypes.DoltgresType, val any) (any, error) {
t := val.(time.Time)
// current_date (at midnight)
cur, err := time.Parse("2006-01-02", time.Now().Format("2006-01-02"))
if err != nil {
return nil, err
}
return diffTimes(cur, t), nil
},
}

// diffTimes returns the duration t1-t2. It subtracts each time component separately,
// unlike time.Sub() function.
func diffTimes(t1, t2 time.Time) duration.Duration {
// if t1 is before t2, then negate the result.
negate := t1.Before(t2)
if negate {
t1, t2 = t2, t1
}

// Calculate difference in each unit
years := int64(t1.Year() - t2.Year())
months := int64(t1.Month() - t2.Month())
days := int64(t1.Day() - t2.Day())
hours := int64(t1.Hour() - t2.Hour())
minutes := int64(t1.Minute() - t2.Minute())
seconds := int64(t1.Second() - t2.Second())
nanoseconds := int64(t1.Nanosecond() - t2.Nanosecond())

// Adjust for any negative values
if nanoseconds < 0 {
nanoseconds += 1e9
seconds--
}
if seconds < 0 {
seconds += 60
minutes--
}
if minutes < 0 {
minutes += 60
hours--
}
if hours < 0 {
hours += 24
days--
}
if days < 0 {
days += 30
months--
}
if months < 0 {
months += 12
years--
}

dur := GetIntervalDurationFromTimeComponents(years, months, days, hours, minutes, seconds, nanoseconds)
if negate {
return dur.Mul(-1)
}
return dur
}

func GetIntervalDurationFromTimeComponents(years, months, days, hours, minutes, seconds, nanos int64) duration.Duration {
durNanos := nanos + seconds*NanosPerSec + minutes*NanosPerSec*duration.SecsPerMinute + hours*NanosPerSec*duration.SecsPerHour
durDays := days
durMonths := months + years*duration.MonthsPerYear

return duration.MakeDuration(durNanos, durDays, durMonths)
}
16 changes: 16 additions & 0 deletions server/functions/binary/divide.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/shopspring/decimal"

"github.com/dolthub/doltgresql/postgres/parser/duration"
"github.com/dolthub/doltgresql/server/functions/framework"
pgtypes "github.com/dolthub/doltgresql/server/types"
)
Expand All @@ -42,6 +43,7 @@ func initBinaryDivide() {
framework.RegisterBinaryFunction(framework.Operator_BinaryDivide, int8div)
framework.RegisterBinaryFunction(framework.Operator_BinaryDivide, int82div)
framework.RegisterBinaryFunction(framework.Operator_BinaryDivide, int84div)
framework.RegisterBinaryFunction(framework.Operator_BinaryDivide, interval_div)
framework.RegisterBinaryFunction(framework.Operator_BinaryDivide, numeric_div)
}

Expand Down Expand Up @@ -227,6 +229,20 @@ var int84div = framework.Function2{
},
}

// interval_div represents the PostgreSQL function of the same name, taking the same parameters.
var interval_div = framework.Function2{
Name: "interval_div",
Return: pgtypes.Interval,
Parameters: [2]pgtypes.DoltgresType{pgtypes.Interval, pgtypes.Float64},
Strict: true,
Callable: func(ctx *sql.Context, _ [3]pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
if val2.(float64) == 0 {
return nil, fmt.Errorf("division by zero")
}
return val1.(duration.Duration).DivFloat(val2.(float64)), nil
},
}

// numeric_div represents the PostgreSQL function of the same name, taking the same parameters.
var numeric_div = framework.Function2{
Name: "numeric_div",
Expand Down
14 changes: 14 additions & 0 deletions server/functions/binary/equal.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/shopspring/decimal"

"github.com/dolthub/doltgresql/postgres/parser/duration"
"github.com/dolthub/doltgresql/postgres/parser/uuid"
"github.com/dolthub/doltgresql/server/functions/framework"
pgtypes "github.com/dolthub/doltgresql/server/types"
Expand Down Expand Up @@ -50,6 +51,7 @@ func initBinaryEqual() {
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, int82eq)
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, int84eq)
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, int8eq)
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, interval_eq)
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, jsonb_eq)
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, nameeq)
framework.RegisterBinaryFunction(framework.Operator_BinaryEqual, nameeqtext)
Expand Down Expand Up @@ -310,6 +312,18 @@ var int8eq = framework.Function2{
},
}

// interval_eq represents the PostgreSQL function of the same name, taking the same parameters.
var interval_eq = framework.Function2{
Name: "interval_eq",
Return: pgtypes.Bool,
Parameters: [2]pgtypes.DoltgresType{pgtypes.Interval, pgtypes.Interval},
Strict: true,
Callable: func(ctx *sql.Context, _ [3]pgtypes.DoltgresType, val1 any, val2 any) (any, error) {
res, err := pgtypes.Interval.Compare(val1.(duration.Duration), val2.(duration.Duration))
return res == 0, err
},
}

// jsonb_eq represents the PostgreSQL function of the same name, taking the same parameters.
var jsonb_eq = framework.Function2{
Name: "jsonb_eq",
Expand Down
Loading

0 comments on commit fb68084

Please sign in to comment.