diff --git a/go.mod b/go.mod index 28b1dd0452..42e0e77765 100644 --- a/go.mod +++ b/go.mod @@ -8,11 +8,11 @@ require ( github.com/PuerkitoBio/goquery v1.8.1 github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a github.com/cockroachdb/errors v1.7.5 - github.com/dolthub/dolt/go v0.40.5-0.20240918171330-ae4c8c780fd7 + github.com/dolthub/dolt/go v0.40.5-0.20240918224257-88ae8c98593a github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20240827111219-e4bb9ca3442d github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662 - github.com/dolthub/go-mysql-server v0.18.2-0.20240918043246-e3846f5468f5 + github.com/dolthub/go-mysql-server v0.18.2-0.20240918214853-7e76e21750a6 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 github.com/dolthub/vitess v0.0.0-20240916204416-9d4d4a09b1d9 github.com/fatih/color v1.13.0 diff --git a/go.sum b/go.sum index 4b6c85f3df..3af26e7ef8 100644 --- a/go.sum +++ b/go.sum @@ -214,8 +214,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/dolthub/dolt/go v0.40.5-0.20240918171330-ae4c8c780fd7 h1:qdlLYlU0bLqOEi3BDcWzRj31LUjy40B6e2nXkWs8xWo= -github.com/dolthub/dolt/go v0.40.5-0.20240918171330-ae4c8c780fd7/go.mod h1:95aBt3R6EbixJ32k/mYTKj8XekDpgfDItulcz5VpEcU= +github.com/dolthub/dolt/go v0.40.5-0.20240918224257-88ae8c98593a h1:PkWAP7KQ954Plu1qkEYJFTKZVNsud6Vf1sRNzSr6Pmg= +github.com/dolthub/dolt/go v0.40.5-0.20240918224257-88ae8c98593a/go.mod h1:BQ/uK6QhfC8A0Lfik3JIwdVRY/JyGDsOJt1FkSvqEUk= github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20240827111219-e4bb9ca3442d h1:RZkQeYOrDrOWzCxaP2ttkvg4E2TM9n8lnEsIBLKjqkM= github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20240827111219-e4bb9ca3442d/go.mod h1:L5RDYZbC9BBWmoU2+TjTekeqqhFXX5EqH9ln00O0stY= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= @@ -224,8 +224,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662 h1:aC17hZD6iwzBwwfO5M+3oBT5E5gGRiQPdn+vzpDXqIA= github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168= -github.com/dolthub/go-mysql-server v0.18.2-0.20240918043246-e3846f5468f5 h1:m12ohMzoZUQExnXqWeTYziQiSwwtzjoEFBSKzBVHin8= -github.com/dolthub/go-mysql-server v0.18.2-0.20240918043246-e3846f5468f5/go.mod h1:m88EMm9OthVVa6qIhbpnRDpj/eYUXuNpvY/+0YWKVwc= +github.com/dolthub/go-mysql-server v0.18.2-0.20240918214853-7e76e21750a6 h1:bqXlOmbV1cX3G83xLz4+czZUsAn4Fp2hRsypkz6N1fM= +github.com/dolthub/go-mysql-server v0.18.2-0.20240918214853-7e76e21750a6/go.mod h1:m88EMm9OthVVa6qIhbpnRDpj/eYUXuNpvY/+0YWKVwc= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= diff --git 
a/postgres/messages/data_row.go b/postgres/messages/data_row.go index 911a8adeaa..035d23f8be 100644 --- a/postgres/messages/data_row.go +++ b/postgres/messages/data_row.go @@ -17,8 +17,6 @@ package messages import ( "fmt" - "github.com/dolthub/vitess/go/sqltypes" - "github.com/dolthub/doltgresql/postgres/connection" ) @@ -28,7 +26,7 @@ func init() { // DataRow represents a row of data. type DataRow struct { - Values []sqltypes.Value + Values [][]byte } var dataRowDefault = connection.MessageFormat{ @@ -75,12 +73,13 @@ var _ connection.Message = DataRow{} func (m DataRow) Encode() (connection.MessageFormat, error) { outputMessage := m.DefaultMessage().Copy() for i := 0; i < len(m.Values); i++ { - if m.Values[i].IsNull() { + if m.Values[i] == nil { outputMessage.Field("Columns").Child("ColumnLength", i).MustWrite(-1) } else { - value := []byte(m.Values[i].ToString()) - outputMessage.Field("Columns").Child("ColumnLength", i).MustWrite(len(value)) - outputMessage.Field("Columns").Child("ColumnData", i).MustWrite(value) + value := m.Values[i] + valLen := len(value) + outputMessage.Field("Columns").Child("ColumnLength", i).MustWrite(valLen) + outputMessage.Field("Columns").Child("ColumnData", i).MustWrite(value[:valLen]) } } return outputMessage, nil diff --git a/postgres/messages/row_description.go b/postgres/messages/row_description.go index 3c16d9f28d..72822514d5 100644 --- a/postgres/messages/row_description.go +++ b/postgres/messages/row_description.go @@ -17,8 +17,8 @@ package messages import ( "fmt" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/jackc/pgx/v5/pgproto3" "github.com/dolthub/doltgresql/postgres/connection" ) @@ -110,7 +110,7 @@ func init() { // RowDescription represents a RowDescription message intended for the client. type RowDescription struct { - Fields []*query.Field + Fields []pgproto3.FieldDescription } var rowDescriptionDefault = connection.MessageFormat{ @@ -182,22 +182,11 @@ func (m RowDescription) Encode() (connection.MessageFormat, error) { outputMessage := m.DefaultMessage().Copy() for i := 0; i < len(m.Fields); i++ { field := m.Fields[i] - dataTypeObjectID, err := VitessFieldToDataTypeObjectID(field) - if err != nil { - return connection.MessageFormat{}, err - } - dataTypeSize, err := VitessFieldToDataTypeSize(field) - if err != nil { - return connection.MessageFormat{}, err - } - dataTypeModifier, err := VitessFieldToDataTypeModifier(field) - if err != nil { - return connection.MessageFormat{}, err - } - outputMessage.Field("Fields").Child("ColumnName", i).MustWrite(field.Name) - outputMessage.Field("Fields").Child("DataTypeObjectID", i).MustWrite(dataTypeObjectID) - outputMessage.Field("Fields").Child("DataTypeSize", i).MustWrite(dataTypeSize) - outputMessage.Field("Fields").Child("DataTypeModifier", i).MustWrite(dataTypeModifier) + outputMessage.Field("Fields").Child("ColumnName", i).MustWrite(string(field.Name)) + outputMessage.Field("Fields").Child("DataTypeObjectID", i).MustWrite(field.DataTypeOID) + outputMessage.Field("Fields").Child("DataTypeSize", i).MustWrite(field.DataTypeSize) + outputMessage.Field("Fields").Child("DataTypeModifier", i).MustWrite(field.TypeModifier) + outputMessage.Field("Fields").Child("FormatCode", i).MustWrite(field.Format) } return outputMessage, nil } @@ -215,13 +204,7 @@ func (m RowDescription) DefaultMessage() *connection.MessageFormat { return &rowDescriptionDefault } -// VitessFieldToDataTypeObjectID returns the type of a vitess Field into a type as defined by Postgres. 
-// OIDs can be obtained with the following query: `SELECT oid, typname FROM pg_type ORDER BY 1;` -func VitessFieldToDataTypeObjectID(field *query.Field) (uint32, error) { - return VitessTypeToObjectID(field.Type) -} - -// VitessFieldToDataTypeObjectID returns a type, as defined by Vitess, into a type as defined by Postgres. +// VitessTypeToObjectID converts a type, as defined by Vitess, into a type as defined by Postgres. // OIDs can be obtained with the following query: `SELECT oid, typname FROM pg_type ORDER BY 1;` func VitessTypeToObjectID(typ query.Type) (uint32, error) { switch typ { @@ -272,10 +255,8 @@ func VitessTypeToObjectID(typ query.Type) (uint32, error) { case query.Type_JSON: return OidJson, nil case query.Type_TIMESTAMP, query.Type_DATETIME: - const OidTimestamp = 1114 return OidTimestamp, nil case query.Type_DATE: - const OidDate = 1082 return OidDate, nil case query.Type_NULL_TYPE: return OidText, nil // NULL is treated as TEXT on the wire @@ -285,118 +266,3 @@ func VitessTypeToObjectID(typ query.Type) (uint32, error) { return 0, fmt.Errorf("unsupported type: %s", typ) } } - -// VitessFieldToDataTypeSize returns the type's size, as defined by Vitess, into the size as defined by Postgres. -func VitessFieldToDataTypeSize(field *query.Field) (int16, error) { - switch field.Type { - case query.Type_INT8: - return 1, nil - case query.Type_INT16: - return 2, nil - case query.Type_INT24: - return 4, nil - case query.Type_INT32: - return 4, nil - case query.Type_INT64: - return 8, nil - case query.Type_UINT8: - return 4, nil - case query.Type_UINT16: - return 4, nil - case query.Type_UINT24: - return 4, nil - case query.Type_UINT32: - return 4, nil - case query.Type_UINT64: - // Since this has an upperbound greater than `INT64`, we'll treat it as `NUMERIC` - return -1, nil - case query.Type_FLOAT32: - return 4, nil - case query.Type_FLOAT64: - return 8, nil - case query.Type_DECIMAL: - return -1, nil - case query.Type_CHAR: - return -1, nil - case query.Type_VARCHAR: - return -1, nil - case query.Type_TEXT: - return -1, nil - case query.Type_BLOB: - return -1, nil - case query.Type_JSON: - return -1, nil - case query.Type_TIMESTAMP, query.Type_DATETIME: - return 8, nil - case query.Type_DATE: - return 4, nil - case query.Type_NULL_TYPE: - return -1, nil // NULL is treated as TEXT on the wire - case query.Type_ENUM: - return -1, nil // TODO: temporary solution until we support CREATE TYPE - default: - return 0, fmt.Errorf("unsupported type returned from engine: %s", field.Type) - } -} - -// VitessFieldToDataTypeModifier returns the field's data type modifier as defined by Postgres. 
-func VitessFieldToDataTypeModifier(field *query.Field) (int32, error) { - switch field.Type { - case query.Type_INT8: - return -1, nil - case query.Type_INT16: - return -1, nil - case query.Type_INT24: - return -1, nil - case query.Type_INT32: - return -1, nil - case query.Type_INT64: - return -1, nil - case query.Type_UINT8: - return -1, nil - case query.Type_UINT16: - return -1, nil - case query.Type_UINT24: - return -1, nil - case query.Type_UINT32: - return -1, nil - case query.Type_UINT64: - // Since we're encoding this as `NUMERIC`, we emulate a `NUMERIC` type with a precision of 19 and a scale of 0 - return (19 << 16) + 4, nil - case query.Type_FLOAT32: - return -1, nil - case query.Type_FLOAT64: - return -1, nil - case query.Type_DECIMAL: - // This is how we encode the precision and scale for some reason - precision := int32(field.ColumnLength - 1) - scale := int32(field.Decimals) - if scale > 0 { - precision-- - } - // PostgreSQL adds 4 to the length for an unknown reason - return (precision<<16 + scale) + 4, nil - case query.Type_CHAR: - // PostgreSQL adds 4 to the length for an unknown reason - return int32(int64(field.ColumnLength)/sql.CharacterSetID(field.Charset).MaxLength()) + 4, nil - case query.Type_VARCHAR: - // PostgreSQL adds 4 to the length for an unknown reason - return int32(int64(field.ColumnLength)/sql.CharacterSetID(field.Charset).MaxLength()) + 4, nil - case query.Type_TEXT: - return -1, nil - case query.Type_BLOB: - return -1, nil - case query.Type_JSON: - return -1, nil - case query.Type_TIMESTAMP, query.Type_DATETIME: - return -1, nil - case query.Type_DATE: - return -1, nil - case query.Type_NULL_TYPE: - return -1, nil // NULL is treated as TEXT on the wire - case query.Type_ENUM: - return -1, nil // TODO: temporary solution until we support CREATE TYPE - default: - return 0, fmt.Errorf("unsupported type returned from engine: %s", field.Type) - } -} diff --git a/server/connection_data.go b/server/connection_data.go new file mode 100644 index 0000000000..725d2e673b --- /dev/null +++ b/server/connection_data.go @@ -0,0 +1,302 @@ +// Copyright 2023 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/transform" + vitess "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/lib/pq/oid" + + "github.com/dolthub/doltgresql/core/dataloader" + "github.com/dolthub/doltgresql/postgres/messages" + pgexprs "github.com/dolthub/doltgresql/server/expression" + "github.com/dolthub/doltgresql/server/node" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// ConvertedQuery represents a query that has been converted from the Postgres representation to the Vitess +// representation. String may contain the string version of the converted query. 
AST will contain the tree +// version of the converted query, and is the recommended form to use. If AST is nil, then use the String version, +// otherwise always prefer the AST. +type ConvertedQuery struct { + String string + AST vitess.Statement + StatementTag string +} + +// copyFromStdinState tracks the metadata for an import of data into a table using a COPY FROM STDIN statement. When +// this statement is processed, the server accepts COPY DATA messages from the client with chunks of data to load +// into a table. +type copyFromStdinState struct { + copyFromStdinNode *node.CopyFrom + dataLoader *dataloader.TabularDataLoader +} + +type PortalData struct { + Query ConvertedQuery + IsEmptyQuery bool + Fields []pgproto3.FieldDescription + BoundPlan sql.Node +} + +type PreparedStatementData struct { + Query ConvertedQuery + ReturnFields []pgproto3.FieldDescription + BindVarTypes []uint32 +} + +// extractBindVarTypes returns the bind variable types inferred from the given query plan. +// It is currently used only to get bind variable types when running our prepared-statement +// tests; a valid prepared query and its execution messages must have +// the types defined. +func extractBindVarTypes(queryPlan sql.Node) ([]uint32, error) { + inspectNode := queryPlan + switch queryPlan := queryPlan.(type) { + case *plan.InsertInto: + inspectNode = queryPlan.Source + } + + types := make([]uint32, 0) + var err error + extractBindVars := func(expr sql.Expression) bool { + if err != nil { + return false + } + switch e := expr.(type) { + case *expression.BindVar: + var typOid uint32 + if doltgresType, ok := e.Type().(pgtypes.DoltgresType); ok { + typOid = doltgresType.OID() + } else { + // TODO: should remove usage of non-Doltgres types + typOid, err = messages.VitessTypeToObjectID(e.Type().Type()) + if err != nil { + err = fmt.Errorf("could not determine OID for placeholder %s: %w", e.Name, err) + return false + } + } + types = append(types, typOid) + case *pgexprs.ExplicitCast: + if bindVar, ok := e.Child().(*expression.BindVar); ok { + var typOid uint32 + if doltgresType, ok := bindVar.Type().(pgtypes.DoltgresType); ok { + typOid = doltgresType.OID() + } else { + typOid, err = messages.VitessTypeToObjectID(e.Type().Type()) + if err != nil { + err = fmt.Errorf("could not determine OID for placeholder %s: %w", bindVar.Name, err) + return false + } + } + types = append(types, typOid) + return false + } + // $1::text and similar get converted to a Convert expression wrapping the bindvar + case *expression.Convert: + if bindVar, ok := e.Child.(*expression.BindVar); ok { + var typOid uint32 + typOid, err = messages.VitessTypeToObjectID(e.Type().Type()) + if err != nil { + err = fmt.Errorf("could not determine OID for placeholder %s: %w", bindVar.Name, err) + return false + } + types = append(types, typOid) + return false + } + } + + return true + } + + transform.InspectExpressions(inspectNode, extractBindVars) + return types, err +} + +// OidToDoltgresType maps a Postgres OID to its Doltgres type.
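+// As a quick illustration (hypothetical caller; the oid and pgtypes packages
+// are imported above), a lookup and the mapping it is expected to return:
+//
+//	typ, ok := OidToDoltgresType[uint32(oid.T_int8)]
+//	// ok == true, typ == pgtypes.Int64
+//
+// OIDs that are not supported yet are mapped explicitly to pgtypes.Unknown.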
+var OidToDoltgresType = map[uint32]pgtypes.DoltgresType{ + uint32(oid.T_bool): pgtypes.Bool, + uint32(oid.T_bytea): pgtypes.Bytea, + uint32(oid.T_char): pgtypes.InternalChar, + uint32(oid.T_name): pgtypes.Name, + uint32(oid.T_int8): pgtypes.Int64, + uint32(oid.T_int2): pgtypes.Int16, + uint32(oid.T_int2vector): pgtypes.Unknown, + uint32(oid.T_int4): pgtypes.Int32, + uint32(oid.T_regproc): pgtypes.Regproc, + uint32(oid.T_text): pgtypes.Text, + uint32(oid.T_oid): pgtypes.Oid, + uint32(oid.T_tid): pgtypes.Unknown, + uint32(oid.T_xid): pgtypes.Xid, + uint32(oid.T_cid): pgtypes.Unknown, + uint32(oid.T_oidvector): pgtypes.Unknown, + uint32(oid.T_pg_ddl_command): pgtypes.Unknown, + uint32(oid.T_pg_type): pgtypes.Unknown, + uint32(oid.T_pg_attribute): pgtypes.Unknown, + uint32(oid.T_pg_proc): pgtypes.Unknown, + uint32(oid.T_pg_class): pgtypes.Unknown, + uint32(oid.T_json): pgtypes.Json, + uint32(oid.T_xml): pgtypes.Unknown, + uint32(oid.T__xml): pgtypes.Unknown, + uint32(oid.T_pg_node_tree): pgtypes.Unknown, + uint32(oid.T__json): pgtypes.JsonArray, + uint32(oid.T_smgr): pgtypes.Unknown, + uint32(oid.T_index_am_handler): pgtypes.Unknown, + uint32(oid.T_point): pgtypes.Unknown, + uint32(oid.T_lseg): pgtypes.Unknown, + uint32(oid.T_path): pgtypes.Unknown, + uint32(oid.T_box): pgtypes.Unknown, + uint32(oid.T_polygon): pgtypes.Unknown, + uint32(oid.T_line): pgtypes.Unknown, + uint32(oid.T__line): pgtypes.Unknown, + uint32(oid.T_cidr): pgtypes.Unknown, + uint32(oid.T__cidr): pgtypes.Unknown, + uint32(oid.T_float4): pgtypes.Float32, + uint32(oid.T_float8): pgtypes.Float64, + uint32(oid.T_abstime): pgtypes.Unknown, + uint32(oid.T_reltime): pgtypes.Unknown, + uint32(oid.T_tinterval): pgtypes.Unknown, + uint32(oid.T_unknown): pgtypes.Unknown, + uint32(oid.T_circle): pgtypes.Unknown, + uint32(oid.T__circle): pgtypes.Unknown, + uint32(oid.T_money): pgtypes.Unknown, + uint32(oid.T__money): pgtypes.Unknown, + uint32(oid.T_macaddr): pgtypes.Unknown, + uint32(oid.T_inet): pgtypes.Unknown, + uint32(oid.T__bool): pgtypes.BoolArray, + uint32(oid.T__bytea): pgtypes.ByteaArray, + uint32(oid.T__char): pgtypes.InternalCharArray, + uint32(oid.T__name): pgtypes.NameArray, + uint32(oid.T__int2): pgtypes.Int16Array, + uint32(oid.T__int2vector): pgtypes.Unknown, + uint32(oid.T__int4): pgtypes.Int32Array, + uint32(oid.T__regproc): pgtypes.RegprocArray, + uint32(oid.T__text): pgtypes.TextArray, + uint32(oid.T__tid): pgtypes.Unknown, + uint32(oid.T__xid): pgtypes.XidArray, + uint32(oid.T__cid): pgtypes.Unknown, + uint32(oid.T__oidvector): pgtypes.Unknown, + uint32(oid.T__bpchar): pgtypes.BpCharArray, + uint32(oid.T__varchar): pgtypes.VarCharArray, + uint32(oid.T__int8): pgtypes.Int64Array, + uint32(oid.T__point): pgtypes.Unknown, + uint32(oid.T__lseg): pgtypes.Unknown, + uint32(oid.T__path): pgtypes.Unknown, + uint32(oid.T__box): pgtypes.Unknown, + uint32(oid.T__float4): pgtypes.Float32Array, + uint32(oid.T__float8): pgtypes.Float64Array, + uint32(oid.T__abstime): pgtypes.Unknown, + uint32(oid.T__reltime): pgtypes.Unknown, + uint32(oid.T__tinterval): pgtypes.Unknown, + uint32(oid.T__polygon): pgtypes.Unknown, + uint32(oid.T__oid): pgtypes.OidArray, + uint32(oid.T_aclitem): pgtypes.Unknown, + uint32(oid.T__aclitem): pgtypes.Unknown, + uint32(oid.T__macaddr): pgtypes.Unknown, + uint32(oid.T__inet): pgtypes.Unknown, + uint32(oid.T_bpchar): pgtypes.BpChar, + uint32(oid.T_varchar): pgtypes.VarChar, + uint32(oid.T_date): pgtypes.Date, + uint32(oid.T_time): pgtypes.Time, + uint32(oid.T_timestamp): pgtypes.Timestamp, + 
uint32(oid.T__timestamp): pgtypes.TimestampArray, + uint32(oid.T__date): pgtypes.DateArray, + uint32(oid.T__time): pgtypes.TimeArray, + uint32(oid.T_timestamptz): pgtypes.TimestampTZ, + uint32(oid.T__timestamptz): pgtypes.TimestampTZArray, + uint32(oid.T_interval): pgtypes.Interval, + uint32(oid.T__interval): pgtypes.IntervalArray, + uint32(oid.T__numeric): pgtypes.NumericArray, + uint32(oid.T_pg_database): pgtypes.Unknown, + uint32(oid.T__cstring): pgtypes.Unknown, + uint32(oid.T_timetz): pgtypes.TimeTZ, + uint32(oid.T__timetz): pgtypes.TimeTZArray, + uint32(oid.T_bit): pgtypes.Unknown, + uint32(oid.T__bit): pgtypes.Unknown, + uint32(oid.T_varbit): pgtypes.Unknown, + uint32(oid.T__varbit): pgtypes.Unknown, + uint32(oid.T_numeric): pgtypes.Numeric, + uint32(oid.T_refcursor): pgtypes.Unknown, + uint32(oid.T__refcursor): pgtypes.Unknown, + uint32(oid.T_regprocedure): pgtypes.Unknown, + uint32(oid.T_regoper): pgtypes.Unknown, + uint32(oid.T_regoperator): pgtypes.Unknown, + uint32(oid.T_regclass): pgtypes.Regclass, + uint32(oid.T_regtype): pgtypes.Regtype, + uint32(oid.T__regprocedure): pgtypes.Unknown, + uint32(oid.T__regoper): pgtypes.Unknown, + uint32(oid.T__regoperator): pgtypes.Unknown, + uint32(oid.T__regclass): pgtypes.RegclassArray, + uint32(oid.T__regtype): pgtypes.RegtypeArray, + uint32(oid.T_record): pgtypes.Unknown, + uint32(oid.T_cstring): pgtypes.Unknown, + uint32(oid.T_any): pgtypes.Unknown, + uint32(oid.T_anyarray): pgtypes.AnyArray, + uint32(oid.T_void): pgtypes.Unknown, + uint32(oid.T_trigger): pgtypes.Unknown, + uint32(oid.T_language_handler): pgtypes.Unknown, + uint32(oid.T_internal): pgtypes.Unknown, + uint32(oid.T_opaque): pgtypes.Unknown, + uint32(oid.T_anyelement): pgtypes.AnyElement, + uint32(oid.T__record): pgtypes.Unknown, + uint32(oid.T_anynonarray): pgtypes.AnyNonArray, + uint32(oid.T_pg_authid): pgtypes.Unknown, + uint32(oid.T_pg_auth_members): pgtypes.Unknown, + uint32(oid.T__txid_snapshot): pgtypes.Unknown, + uint32(oid.T_uuid): pgtypes.Uuid, + uint32(oid.T__uuid): pgtypes.UuidArray, + uint32(oid.T_txid_snapshot): pgtypes.Unknown, + uint32(oid.T_fdw_handler): pgtypes.Unknown, + uint32(oid.T_pg_lsn): pgtypes.Unknown, + uint32(oid.T__pg_lsn): pgtypes.Unknown, + uint32(oid.T_tsm_handler): pgtypes.Unknown, + uint32(oid.T_anyenum): pgtypes.Unknown, + uint32(oid.T_tsvector): pgtypes.Unknown, + uint32(oid.T_tsquery): pgtypes.Unknown, + uint32(oid.T_gtsvector): pgtypes.Unknown, + uint32(oid.T__tsvector): pgtypes.Unknown, + uint32(oid.T__gtsvector): pgtypes.Unknown, + uint32(oid.T__tsquery): pgtypes.Unknown, + uint32(oid.T_regconfig): pgtypes.Unknown, + uint32(oid.T__regconfig): pgtypes.Unknown, + uint32(oid.T_regdictionary): pgtypes.Unknown, + uint32(oid.T__regdictionary): pgtypes.Unknown, + uint32(oid.T_jsonb): pgtypes.JsonB, + uint32(oid.T__jsonb): pgtypes.JsonBArray, + uint32(oid.T_anyrange): pgtypes.Unknown, + uint32(oid.T_event_trigger): pgtypes.Unknown, + uint32(oid.T_int4range): pgtypes.Unknown, + uint32(oid.T__int4range): pgtypes.Unknown, + uint32(oid.T_numrange): pgtypes.Unknown, + uint32(oid.T__numrange): pgtypes.Unknown, + uint32(oid.T_tsrange): pgtypes.Unknown, + uint32(oid.T__tsrange): pgtypes.Unknown, + uint32(oid.T_tstzrange): pgtypes.Unknown, + uint32(oid.T__tstzrange): pgtypes.Unknown, + uint32(oid.T_daterange): pgtypes.Unknown, + uint32(oid.T__daterange): pgtypes.Unknown, + uint32(oid.T_int8range): pgtypes.Unknown, + uint32(oid.T__int8range): pgtypes.Unknown, + uint32(oid.T_pg_shseclabel): pgtypes.Unknown, + uint32(oid.T_regnamespace): pgtypes.Unknown, 
+ uint32(oid.T__regnamespace): pgtypes.Unknown, + uint32(oid.T_regrole): pgtypes.Unknown, + uint32(oid.T__regrole): pgtypes.Unknown, +} diff --git a/server/connection_handler.go b/server/connection_handler.go index ef693a4298..9c51960207 100644 --- a/server/connection_handler.go +++ b/server/connection_handler.go @@ -26,14 +26,11 @@ import ( "strings" "sync/atomic" + "github.com/dolthub/dolt/go/libraries/doltcore/sqlserver" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/expression" - "github.com/dolthub/go-mysql-server/sql/plan" - "github.com/dolthub/go-mysql-server/sql/transform" "github.com/dolthub/vitess/go/mysql" - "github.com/dolthub/vitess/go/sqltypes" - querypb "github.com/dolthub/vitess/go/vt/proto/query" "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" "github.com/sirupsen/logrus" @@ -45,7 +42,6 @@ import ( "github.com/dolthub/doltgresql/server/ast" pgexprs "github.com/dolthub/doltgresql/server/expression" "github.com/dolthub/doltgresql/server/node" - pgtypes "github.com/dolthub/doltgresql/server/types" ) // ConnectionHandler is responsible for the entire lifecycle of a user connection: receiving messages they send, @@ -54,7 +50,7 @@ type ConnectionHandler struct { mysqlConn *mysql.Conn preparedStatements map[string]PreparedStatementData portals map[string]PortalData - handler mysql.Handler + doltgresHandler *DoltgresHandler pgTypeMap *pgtype.Map waitForSync bool // copyFromStdinState is set when this connection is in the COPY FROM STDIN mode, meaning it is waiting on @@ -62,14 +58,6 @@ type ConnectionHandler struct { copyFromStdinState *copyFromStdinState } -// copyFromStdinState tracks the metadata for an import of data into a table using a COPY FROM STDIN statement. When -// this statement is processed, the server accepts COPY DATA messages from the client with chunks of data to load -// into a table. -type copyFromStdinState struct { - copyFromStdinNode *node.CopyFrom - dataLoader *dataloader.TabularDataLoader -} - // Set this env var to disable panic handling in the connection, which is useful when debugging a panic const disablePanicHandlingEnvVar = "DOLT_PGSQL_PANIC" @@ -97,11 +85,23 @@ func NewConnectionHandler(conn net.Conn, handler mysql.Handler) *ConnectionHandl preparedStatements := make(map[string]PreparedStatementData) portals := make(map[string]PortalData) + // TODO: possibly should define engine and session manager ourselves + // instead of depending on the GetRunningServer method. + server := sqlserver.GetRunningServer() + doltgresHandler := &DoltgresHandler{ + e: server.Engine, + sm: server.SessionManager(), + readTimeout: 0, // cfg.ConnReadTimeout, + encodeLoggedQuery: false, // cfg.EncodeLoggedQuery, + } + + // TODO: should we use this backend??? 
+ //pgproto3.NewBackend() return &ConnectionHandler{ mysqlConn: mysqlConn, preparedStatements: preparedStatements, portals: portals, - handler: handler, + doltgresHandler: doltgresHandler, pgTypeMap: pgtype.NewMap(), } } @@ -137,13 +137,13 @@ func (h *ConnectionHandler) HandleConnection() { fmt.Println(returnErr.Error()) } - h.handler.ConnectionClosed(h.mysqlConn) + h.doltgresHandler.ConnectionClosed(h.mysqlConn) if err := h.Conn().Close(); err != nil { fmt.Printf("Failed to properly close connection:\n%v\n", err) } }() } - h.handler.NewConnection(h.mysqlConn) + h.doltgresHandler.NewConnection(h.mysqlConn) startupMessage, err := h.receiveStartupMessage() if err != nil { @@ -308,27 +308,31 @@ InitialMessageLoop: // chooseInitialDatabase attempts to choose the initial database for the connection, if one is specified in the // startup message provided func (h *ConnectionHandler) chooseInitialDatabase(startupMessage messages.StartupMessage) error { - if db, ok := startupMessage.Parameters["database"]; ok && len(db) > 0 { - err := h.handler.ComQuery(context.Background(), h.mysqlConn, fmt.Sprintf("USE `%s`;", db), func(res *sqltypes.Result, more bool) error { - return nil - }) - if err != nil { - _ = connection.Send(h.Conn(), messages.ErrorResponse{ - Severity: messages.ErrorResponseSeverity_Fatal, - SqlStateCode: "3D000", - Message: fmt.Sprintf(`"database "%s" does not exist"`, db), - Optional: messages.ErrorResponseOptionalFields{ - Routine: "InitPostgres", - }, - }) - return err - } - } else { - // If a database isn't specified, then we attempt to connect to a database with the same name as the user, - // ignoring any error - _ = h.handler.ComQuery(context.Background(), h.mysqlConn, fmt.Sprintf("USE `%s`;", h.mysqlConn.User), func(*sqltypes.Result, bool) error { - return nil + db, ok := startupMessage.Parameters["database"] + dbSpecified := ok && len(db) > 0 + if !dbSpecified { + db = h.mysqlConn.User + } + useStmt := fmt.Sprintf("SET database TO '%s';", db) + parsed, err := sql.GlobalParser.ParseSimple(useStmt) + if err != nil { + return err + } + err = h.doltgresHandler.ComQuery(context.Background(), h.mysqlConn, useStmt, parsed, func(res *Result) error { + return nil + }) + // If the database wasn't explicitly specified, we fell back to the user's name above and ignore any error; + // otherwise, report the failure to the client + if err != nil && dbSpecified { + _ = connection.Send(h.Conn(), messages.ErrorResponse{ + Severity: messages.ErrorResponseSeverity_Fatal, + SqlStateCode: "3D000", + Message: fmt.Sprintf(`"database "%s" does not exist"`, db), + Optional: messages.ErrorResponseOptionalFields{ + Routine: "InitPostgres", + }, }) + return err } return nil } @@ -450,11 +454,16 @@ func (h *ConnectionHandler) handleParse(message messages.Parse) error { return nil } - analyzedPlan, fields, err := h.getPlanAndFields(query) + parsedQuery, fields, err := h.doltgresHandler.ComPrepareParsed(context.Background(), h.mysqlConn, query.String, query.AST) if err != nil { return err } + analyzedPlan, ok := parsedQuery.(sql.Node) + if !ok { + return fmt.Errorf("expected a sql.Node, got %T", parsedQuery) + } + // A valid Parse message must have ParameterObjectIDs if there are any binding variables.
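+	// (Illustrative only: for a query such as `INSERT INTO t VALUES ($1, $2)`, a client
+	// may send ParameterObjectIDs = [23, 25], the standard OIDs for int4 and text, or it
+	// may omit them, in which case the types are extracted from the analyzed plan below.)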
bindVarTypes := message.ParameterObjectIDs if len(bindVarTypes) == 0 { @@ -465,16 +474,6 @@ func (h *ConnectionHandler) handleParse(message messages.Parse) error { } } - // Nil fields means an OKResult, fill one in here - if fields == nil { - fields = []*querypb.Field{ - { - Name: "Rows", - Type: sqltypes.Int32, - }, - } - } - h.preparedStatements[message.Name] = PreparedStatementData{ Query: query, ReturnFields: fields, @@ -486,7 +485,7 @@ func (h *ConnectionHandler) handleParse(message messages.Parse) error { // handleDescribe handles a Describe message, returning any error that occurs func (h *ConnectionHandler) handleDescribe(message messages.Describe) error { - var fields []*querypb.Field + var fields []pgproto3.FieldDescription var bindvarTypes []uint32 var tag string @@ -539,11 +538,16 @@ func (h *ConnectionHandler) handleBind(message messages.Bind) error { return err } - boundPlan, fields, err := h.bindParams(preparedData.Query.String, preparedData.Query.AST, bindVars) + analyzedPlan, fields, err := h.doltgresHandler.ComBind(context.Background(), h.mysqlConn, preparedData.Query.String, preparedData.Query.AST, bindVars) if err != nil { return err } + boundPlan, ok := analyzedPlan.(sql.Node) + if !ok { + return fmt.Errorf("expected a sql.Node, got %T", analyzedPlan) + } + h.portals[message.DestinationPortal] = PortalData{ Query: preparedData.Query, Fields: fields, @@ -581,7 +585,7 @@ func (h *ConnectionHandler) handleExecute(message messages.Execute) error { return err } - err = h.handler.(mysql.ExtendedHandler).ComExecuteBound(context.Background(), h.mysqlConn, query.String, portalData.BoundPlan, spoolRowsCallback(h.Conn(), &complete, true)) + err = h.doltgresHandler.ComExecuteBound(context.Background(), h.mysqlConn, query.String, portalData.BoundPlan, spoolRowsCallback(h.Conn(), &complete, true)) if err != nil { return err } @@ -600,11 +604,7 @@ func (h *ConnectionHandler) handleCopyData(message messages.CopyData) (stop bool } // Grab a sql.Context - ctxProvider, ok := h.handler.(sql.ContextProvider) - if !ok { - return false, true, fmt.Errorf("%T does not implement server.ContextProvider", h.handler) - } - sqlCtx, err := ctxProvider.NewContext(context.Background(), h.mysqlConn, "") + sqlCtx, err := h.doltgresHandler.NewContext(context.Background(), h.mysqlConn, "") if err != nil { return false, false, err } @@ -664,11 +664,7 @@ func (h *ConnectionHandler) handleCopyDone(_ messages.CopyDone) (stop bool, endO fmt.Errorf("no data loader found for COPY FROM STDIN operation") } - ctxProvider, ok := h.handler.(sql.ContextProvider) - if !ok { - return false, true, fmt.Errorf("%T does not implement server.ContextProvider", h.handler) - } - sqlCtx, err := ctxProvider.NewContext(context.Background(), h.mysqlConn, "") + sqlCtx, err := h.doltgresHandler.NewContext(context.Background(), h.mysqlConn, "") if err != nil { return false, false, err } @@ -725,127 +721,30 @@ func (h *ConnectionHandler) deallocatePreparedStatement(name string, preparedSta return connection.Send(conn, commandComplete) } -func extractBindVarTypes(queryPlan sql.Node) ([]uint32, error) { - inspectNode := queryPlan - switch queryPlan := queryPlan.(type) { - case *plan.InsertInto: - inspectNode = queryPlan.Source - } - - types := make([]uint32, 0) - var err error - extractBindVars := func(expr sql.Expression) bool { - if err != nil { - return false - } - switch e := expr.(type) { - case *expression.BindVar: - var oid uint32 - if doltgresType, ok := e.Type().(pgtypes.DoltgresType); ok { - oid = doltgresType.OID() - } else { 
- oid, err = messages.VitessTypeToObjectID(e.Type().Type()) - if err != nil { - err = fmt.Errorf("could not determine OID for placeholder %s: %w", e.Name, err) - return false - } - } - types = append(types, oid) - case *pgexprs.ExplicitCast: - if bindVar, ok := e.Child().(*expression.BindVar); ok { - var oid uint32 - if doltgresType, ok := bindVar.Type().(pgtypes.DoltgresType); ok { - oid = doltgresType.OID() - } else { - oid, err = messages.VitessTypeToObjectID(e.Type().Type()) - if err != nil { - err = fmt.Errorf("could not determine OID for placeholder %s: %w", bindVar.Name, err) - return false - } - } - types = append(types, oid) - return false - } - // $1::text and similar get converted to a Convert expression wrapping the bindvar - case *expression.Convert: - if bindVar, ok := e.Child.(*expression.BindVar); ok { - var oid uint32 - oid, err = messages.VitessTypeToObjectID(e.Type().Type()) - if err != nil { - err = fmt.Errorf("could not determine OID for placeholder %s: %w", bindVar.Name, err) - return false - } - types = append(types, oid) - return false - } - } - - return true - } - - transform.InspectExpressions(inspectNode, extractBindVars) - return types, err -} - // convertBindParameters handles the conversion from bind parameters to variable values. -func (h *ConnectionHandler) convertBindParameters(types []uint32, formatCodes []int32, values []messages.BindParameterValue) (map[string]*querypb.BindVariable, error) { - bindings := make(map[string]*querypb.BindVariable, len(values)) +func (h *ConnectionHandler) convertBindParameters(types []uint32, formatCodes []int32, values []messages.BindParameterValue) (map[string]sqlparser.Expr, error) { + bindings := make(map[string]sqlparser.Expr, len(values)) for i := range values { - bindingName := fmt.Sprintf("v%d", i+1) - typ := convertType(types[i]) + typ := types[i] var bindVarString string - // We'll rely on a library to decode each format, which will deal with text and binary representations for us - if err := h.pgTypeMap.Scan(types[i], int16(formatCodes[i]), values[i].Data, &bindVarString); err != nil { + if err := h.pgTypeMap.Scan(typ, int16(formatCodes[i]), values[i].Data, &bindVarString); err != nil { return nil, err } - bindVar := &querypb.BindVariable{ - Type: typ, - Value: []byte(bindVarString), - Values: nil, // TODO + + pgTyp, ok := OidToDoltgresType[typ] + if !ok { + return nil, fmt.Errorf("unhandled oid type: %v", typ) + } + v, err := pgTyp.IoInput(nil, bindVarString) + if err != nil { + return nil, err } - bindings[bindingName] = bindVar + bindings[fmt.Sprintf("v%d", i+1)] = sqlparser.InjectedExpr{Expression: pgexprs.NewUnsafeLiteral(v, pgTyp)} } return bindings, nil } -// TODO: we need to migrate this away from vitess types and deal strictly with OIDs which are compatible with Postgres types -func convertType(oid uint32) querypb.Type { - switch oid { - // TODO: this should never be 0 - case 0: - return sqltypes.Int32 - case messages.OidInt2: - return sqltypes.Int16 - case messages.OidInt4: - return sqltypes.Int32 - case messages.OidInt8: - return sqltypes.Int64 - case messages.OidFloat4: - return sqltypes.Float32 - case messages.OidFloat8: - return sqltypes.Float64 - case messages.OidName: - return sqltypes.Text - case messages.OidNumeric: - return sqltypes.Decimal - case messages.OidText: - return sqltypes.Text - case messages.OidBool: - return sqltypes.Bit - case messages.OidDate: - return sqltypes.Date - case messages.OidTimestamp: - return sqltypes.Timestamp - case messages.OidVarchar: - return sqltypes.Text - case 
messages.OidOid: - return sqltypes.Uint32 - default: - panic(fmt.Sprintf("convertType(oid): unhandled type %d", oid)) - } -} - // sendClientStartupMessages sends introductory messages to the client and returns any error // TODO: implement users and authentication func (h *ConnectionHandler) sendClientStartupMessages(startupMessage messages.StartupMessage) error { @@ -892,7 +791,7 @@ func (h *ConnectionHandler) sendClientStartupMessages(startupMessage messages.St } if err := connection.Send(h.Conn(), messages.BackendKeyData{ - ProcessID: processID, + ProcessID: int32(processID), SecretKey: 0, }); err != nil { return err @@ -908,8 +807,7 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error { Tag: query.StatementTag, } - err := h.comQuery(query, spoolRowsCallback(h.Conn(), &commandComplete, false)) - + err := h.doltgresHandler.ComQuery(context.Background(), h.mysqlConn, query.String, query.AST, spoolRowsCallback(h.Conn(), &commandComplete, false)) if err != nil { if strings.HasPrefix(err.Error(), "syntax error at position") { return fmt.Errorf("This statement is not yet supported") @@ -926,8 +824,8 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error { // spoolRowsCallback returns a callback function that will send RowDescription message, then a DataRow message for // each row in the result set. -func spoolRowsCallback(conn net.Conn, commandComplete *messages.CommandComplete, isExecute bool) mysql.ResultSpoolFn { - return func(res *sqltypes.Result, more bool) error { +func spoolRowsCallback(conn net.Conn, commandComplete *messages.CommandComplete, isExecute bool) func(res *Result) error { + return func(res *Result) error { if messages.ReturnsRow(commandComplete.Tag) { // EXECUTE does not send RowDescription; instead it should be sent from DESCRIBE prior to it if !isExecute { @@ -940,7 +838,7 @@ func spoolRowsCallback(conn net.Conn, commandComplete *messages.CommandComplete, for _, row := range res.Rows { if err := connection.Send(conn, messages.DataRow{ - Values: row, + Values: row.val, }); err != nil { return err } @@ -957,7 +855,7 @@ func spoolRowsCallback(conn net.Conn, commandComplete *messages.CommandComplete, } // sendDescribeResponse sends a response message for a Describe message -func (h *ConnectionHandler) sendDescribeResponse(conn net.Conn, fields []*querypb.Field, types []uint32, tag string) (err error) { +func (h *ConnectionHandler) sendDescribeResponse(conn net.Conn, fields []pgproto3.FieldDescription, types []uint32, tag string) (err error) { // The prepared statement variant of the describe command returns the OIDs of the parameters. 
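	// (Sketch of the message flow this implements, following the extended-query protocol:
	// Describe on a prepared statement elicits ParameterDescription followed by RowDescription,
	// while Describe on a portal elicits RowDescription only, which is why |types| may be nil here.)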
if types != nil { if err := connection.Send(conn, messages.ParameterDescription{ @@ -1133,64 +1031,9 @@ func (h *ConnectionHandler) convertQuery(query string) (ConvertedQuery, error) { }, nil } -// getPlanAndFields builds a plan and return fields for the given query -func (h *ConnectionHandler) getPlanAndFields(query ConvertedQuery) (sql.Node, []*querypb.Field, error) { - if query.AST == nil { - return nil, nil, fmt.Errorf("cannot prepare a query that has not been parsed") - } - - parsedQuery, fields, err := h.handler.(mysql.ExtendedHandler).ComPrepareParsed(context.Background(), h.mysqlConn, query.String, query.AST, &mysql.PrepareData{ - PrepareStmt: query.String, - }) - - if err != nil { - return nil, nil, err - } - - analyzedPlan, ok := parsedQuery.(sql.Node) - if !ok { - return nil, nil, fmt.Errorf("expected a sql.Node, got %T", parsedQuery) - } - - return analyzedPlan, fields, nil -} - -// comQuery is a shortcut that determines which version of ComQuery to call based on whether the query has been parsed. -func (h *ConnectionHandler) comQuery(query ConvertedQuery, callback func(res *sqltypes.Result, more bool) error) error { - if query.AST == nil { - return h.handler.ComQuery(context.Background(), h.mysqlConn, query.String, callback) - } else { - return h.handler.(mysql.ExtendedHandler).ComParsedQuery(context.Background(), h.mysqlConn, query.String, query.AST, callback) - } -} - -// bindParams binds the paramters given to the query plan given and returns the resulting plan and fields. -func (h *ConnectionHandler) bindParams( - query string, - parsedQuery sqlparser.Statement, - bindVars map[string]*querypb.BindVariable, -) (sql.Node, []*querypb.Field, error) { - bound, fields, err := h.handler.(mysql.ExtendedHandler).ComBind(context.Background(), h.mysqlConn, query, parsedQuery, &mysql.PrepareData{ - PrepareStmt: query, - ParamsCount: uint16(len(bindVars)), - BindVars: bindVars, - }) - - if err != nil { - return nil, nil, err - } - - plan, ok := bound.(sql.Node) - if !ok { - return nil, nil, fmt.Errorf("expected a sql.Node, got %T", bound) - } - - return plan, fields, err -} - // discardAll handles the DISCARD ALL command func (h *ConnectionHandler) discardAll(query ConvertedQuery, conn net.Conn) error { - err := h.handler.ComResetConnection(h.mysqlConn) + err := h.doltgresHandler.ComResetConnection(h.mysqlConn) if err != nil { return err } @@ -1207,11 +1050,7 @@ func (h *ConnectionHandler) discardAll(query ConvertedQuery, conn net.Conn) erro // COPY FROM STDIN can't be handled directly by the GMS engine, since COPY FROM STDIN relies on multiple messages sent // over the wire. func (h *ConnectionHandler) handleCopyFromStdinQuery(copyFrom *node.CopyFrom, conn net.Conn) error { - ctxProvider, ok := h.handler.(sql.ContextProvider) - if !ok { - return fmt.Errorf("%T does not implement server.ContextProvider", h.handler) - } - sqlCtx, err := ctxProvider.NewContext(context.Background(), h.mysqlConn, "") + sqlCtx, err := h.doltgresHandler.NewContext(context.Background(), h.mysqlConn, "") if err != nil { return err } diff --git a/server/converted_query.go b/server/converted_query.go deleted file mode 100644 index 333192bd26..0000000000 --- a/server/converted_query.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2023 Dolthub, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package server - -import ( - "github.com/dolthub/go-mysql-server/sql" - querypb "github.com/dolthub/vitess/go/vt/proto/query" - vitess "github.com/dolthub/vitess/go/vt/sqlparser" -) - -// ConvertedQuery represents a query that has been converted from the Postgres representation to the Vitess -// representation. String may contain the string version of the converted query. AST will contain the tree -// version of the converted query, and is the recommended form to use. If AST is nil, then use the String version, -// otherwise always prefer to AST. -type ConvertedQuery struct { - String string - AST vitess.Statement - StatementTag string -} - -type PreparedStatementData struct { - Query ConvertedQuery - ReturnFields []*querypb.Field - BindVarTypes []uint32 -} - -type PortalData struct { - Query ConvertedQuery - IsEmptyQuery bool - Fields []*querypb.Field - BoundPlan sql.Node -} diff --git a/server/doltgres_handler.go b/server/doltgres_handler.go new file mode 100644 index 0000000000..2a35552b09 --- /dev/null +++ b/server/doltgres_handler.go @@ -0,0 +1,536 @@ +package server + +import ( + "context" + "encoding/base64" + "fmt" + "io" + "regexp" + "runtime/trace" + "sync" + "time" + + sqle "github.com/dolthub/go-mysql-server" + "github.com/dolthub/go-mysql-server/server" + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/analyzer" + "github.com/dolthub/go-mysql-server/sql/plan" + "github.com/dolthub/go-mysql-server/sql/types" + "github.com/dolthub/vitess/go/mysql" + "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/jackc/pgx/v5/pgproto3" + "github.com/sirupsen/logrus" + + "github.com/dolthub/doltgresql/postgres/messages" + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// Result represents a query result. +type Result struct { + Fields []pgproto3.FieldDescription `json:"fields"` + Rows []Row `json:"rows"` + RowsAffected uint64 `json:"rows_affected"` +} + +// Row represents a single row in wire format. +// |val| holds the elements of a single row, +// each encoded as a byte array. +type Row struct { + val [][]byte +} + +const rowsBatch = 128 + +// DoltgresHandler is a handler that uses the SQLe engine directly +// to run Doltgres-specific queries. +type DoltgresHandler struct { + e *sqle.Engine + sm *server.SessionManager + readTimeout time.Duration + encodeLoggedQuery bool +} + +var _ Handler = &DoltgresHandler{} + +// ComBind implements the Handler interface. 
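+// (For orientation, the call sequence wired up in connection_handler.go for the
+// extended-query flow is: Parse -> ComPrepareParsed, Bind -> ComBind,
+// Execute -> ComExecuteBound.)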
+func (h *DoltgresHandler) ComBind(ctx context.Context, c *mysql.Conn, query string, parsedQuery mysql.ParsedQuery, bindVars map[string]sqlparser.Expr) (mysql.BoundQuery, []pgproto3.FieldDescription, error) { + sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query) + if err != nil { + return nil, nil, err + } + + stmt, ok := parsedQuery.(sqlparser.Statement) + if !ok { + return nil, nil, fmt.Errorf("parsedQuery must be a sqlparser.Statement, but got %T", parsedQuery) + } + + queryPlan, err := h.e.BoundQueryPlan(sqlCtx, query, stmt, bindVars) + if err != nil { + return nil, nil, err + } + + return queryPlan, schemaToFieldDescriptions(sqlCtx, queryPlan.Schema()), nil +} + +// ComExecuteBound implements the Handler interface. +func (h *DoltgresHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn, query string, boundQuery mysql.BoundQuery, callback func(*Result) error) error { + analyzedPlan, ok := boundQuery.(sql.Node) + if !ok { + return fmt.Errorf("boundQuery must be a sql.Node, but got %T", boundQuery) + } + + err := h.doQuery(ctx, conn, query, nil, analyzedPlan, h.executeBoundPlan, callback) + if err != nil { + err = sql.CastSQLError(err) + } + + return err +} + +// ComPrepareParsed implements the Handler interface. +func (h *DoltgresHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement) (mysql.ParsedQuery, []pgproto3.FieldDescription, error) { + sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query) + if err != nil { + return nil, nil, err + } + + analyzed, err := h.e.PrepareParsedQuery(sqlCtx, query, query, parsed) + if err != nil { + logrus.WithField("query", query).Errorf("unable to prepare query: %s", err.Error()) + err := sql.CastSQLError(err) + return nil, nil, err + } + + var fields []pgproto3.FieldDescription + // The query is not a SELECT statement if it corresponds to an OK result. + if nodeReturnsOkResultSchema(analyzed) { + fields = []pgproto3.FieldDescription{ + { + Name: []byte("Rows"), + DataTypeOID: pgtypes.Int32.OID(), + DataTypeSize: int16(pgtypes.Int32.MaxTextResponseByteLength(nil)), + }, + } + } else { + fields = schemaToFieldDescriptions(sqlCtx, analyzed.Schema()) + } + return analyzed, fields, nil +} + +// ComQuery implements the Handler interface. +func (h *DoltgresHandler) ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement, callback func(*Result) error) error { + err := h.doQuery(ctx, c, query, parsed, nil, h.executeQuery, callback) + if err != nil { + err = sql.CastSQLError(err) + } + return err +} + +// ComResetConnection implements the Handler interface. +func (h *DoltgresHandler) ComResetConnection(c *mysql.Conn) error { + logrus.WithField("connectionId", c.ConnectionID).Debug("COM_RESET_CONNECTION command received") + + // Grab the currently selected database name + db := h.sm.GetCurrentDB(c) + + // Dispose of the connection's current session + h.maybeReleaseAllLocks(c) + h.e.CloseSession(c.ConnectionID) + + // Create a new session and set the current database + err := h.sm.NewSession(context.Background(), c) + if err != nil { + return err + } + return h.sm.SetDB(c, db) +} + +// ConnectionClosed implements the Handler interface. +func (h *DoltgresHandler) ConnectionClosed(c *mysql.Conn) { + defer h.sm.RemoveConn(c) + defer h.e.CloseSession(c.ConnectionID) + + h.maybeReleaseAllLocks(c) + + logrus.WithField(sql.ConnectionIdLogField, c.ConnectionID).Infof("ConnectionClosed") +} + +// NewConnection implements the Handler interface. 
+func (h *DoltgresHandler) NewConnection(c *mysql.Conn) { + h.sm.AddConn(c) + sql.StatusVariables.IncrementGlobal("Connections", 1) + + c.DisableClientMultiStatements = true // TODO: h.disableMultiStmts + logrus.WithField(sql.ConnectionIdLogField, c.ConnectionID).WithField("DisableClientMultiStatements", c.DisableClientMultiStatements).Infof("NewConnection") +} + +// NewContext implements the Handler interface. +func (h *DoltgresHandler) NewContext(ctx context.Context, c *mysql.Conn, query string) (*sql.Context, error) { + return h.sm.NewContext(ctx, c, query) +} + +var queryLoggingRegex = regexp.MustCompile(`[\r\n\t ]+`) + +func (h *DoltgresHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement, analyzedPlan sql.Node, queryExec QueryExecutor, callback func(*Result) error) error { + sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query) + if err != nil { + return err + } + + start := time.Now() + var queryStrToLog string + if h.encodeLoggedQuery { + queryStrToLog = base64.StdEncoding.EncodeToString([]byte(query)) + } else if logrus.IsLevelEnabled(logrus.DebugLevel) { + // this is expensive, so skip this unless we're logging at DEBUG level + queryStrToLog = string(queryLoggingRegex.ReplaceAll([]byte(query), []byte(" "))) + } + + if queryStrToLog != "" { + sqlCtx.SetLogger(sqlCtx.GetLogger().WithField("query", queryStrToLog)) + } + sqlCtx.GetLogger().Debugf("Starting query") + sqlCtx.GetLogger().Tracef("beginning execution") + + oCtx := ctx + + // TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be + // marked done until we're done spooling rows over the wire + ctx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query) + defer func() { + if err != nil && ctx != nil { + sqlCtx.ProcessList.EndQuery(sqlCtx) + } + }() + + schema, rowIter, qFlags, err := queryExec(sqlCtx, query, parsed, analyzedPlan) + if err != nil { + sqlCtx.GetLogger().WithError(err).Warn("error running query") + fmt.Printf("Err: %+v", err) + return err + } + + // create result before goroutines to avoid |ctx| racing + var r *Result + var processedAtLeastOneBatch bool + + // zero/single return schema use spooling shortcut + if types.IsOkResultSchema(schema) { + r, err = resultForOkIter(sqlCtx, rowIter) + } else if schema == nil { + r, err = resultForEmptyIter(sqlCtx, rowIter) + } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { + resultFields := schemaToFieldDescriptions(sqlCtx, schema) + r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields) + } else { + resultFields := schemaToFieldDescriptions(sqlCtx, schema) + r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, schema, rowIter, callback, resultFields) + } + if err != nil { + return err + } + + // errGroup context is now canceled + ctx = oCtx + + sqlCtx.GetLogger().Debugf("Query finished in %d ms", time.Since(start).Milliseconds()) + + // processedAtLeastOneBatch means we already called callback() at least + // once, so no need to call it if RowsAffected == 0. + if r != nil && (r.RowsAffected == 0 && processedAtLeastOneBatch) { + return nil + } + + return callback(r) +} + +// QueryExecutor is a function that executes a query and returns the result as a schema and iterator. 
Either of +// |parsed| or |analyzed| can be nil depending on the use case. +type QueryExecutor func(ctx *sql.Context, query string, parsed sqlparser.Statement, analyzed sql.Node) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) + +// executeQuery is a QueryExecutor that calls QueryWithBindings on the given engine using the given query and parsed +// statement, which may be nil. +func (h *DoltgresHandler) executeQuery(ctx *sql.Context, query string, parsed sqlparser.Statement, _ sql.Node) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { + return h.e.QueryWithBindings(ctx, query, parsed, nil, nil) +} + +// executeBoundPlan is a QueryExecutor that runs the given pre-analyzed plan on the engine via +// PrepQueryPlanForExecution; the parsed statement is ignored. +func (h *DoltgresHandler) executeBoundPlan(ctx *sql.Context, query string, _ sqlparser.Statement, plan sql.Node) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { + return h.e.PrepQueryPlanForExecution(ctx, query, plan) +} + +// maybeReleaseAllLocks makes a best effort attempt to release all locks on the given connection. If the attempt fails, +// an error is logged but not returned. +func (h *DoltgresHandler) maybeReleaseAllLocks(c *mysql.Conn) { + if ctx, err := h.sm.NewContextWithQuery(context.Background(), c, ""); err != nil { + logrus.Errorf("unable to release all locks on session close: %s", err) + logrus.Errorf("unable to unlock tables on session close: %s", err) + } else { + _, err = h.e.LS.ReleaseAll(ctx) + if err != nil { + logrus.Errorf("unable to release all locks on session close: %s", err) + } + if err = h.e.Analyzer.Catalog.UnlockTables(ctx, c.ConnectionID); err != nil { + logrus.Errorf("unable to unlock tables on session close: %s", err) + } + } +} + +// nodeReturnsOkResultSchema returns whether the node returns an OK result, or whether its schema is the OK result schema. +// These nodes will eventually return an OK result, but their intermediate forms here return a different schema +// than they will at execution time. +func nodeReturnsOkResultSchema(node sql.Node) bool { + switch node.(type) { + case *plan.InsertInto, *plan.Update, *plan.UpdateJoin, *plan.DeleteFrom: + return true + } + return types.IsOkResultSchema(node.Schema()) +} + +func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema) []pgproto3.FieldDescription { + fields := make([]pgproto3.FieldDescription, len(s)) + for i, c := range s { + var oid uint32 + var err error + if doltgresType, ok := c.Type.(pgtypes.DoltgresType); ok { + oid = doltgresType.OID() + } else { + oid, err = messages.VitessTypeToObjectID(c.Type.Type()) + if err != nil { + panic(err) + } + } + + // "Format" field: The format code being used for the field. + // Currently, it will be zero (text) or one (binary). + // In a RowDescription returned from the statement variant of Describe, + // the format code is not yet known and will always be zero. + + fields[i] = pgproto3.FieldDescription{ + Name: []byte(c.Name), + TableOID: uint32(0), + TableAttributeNumber: uint16(0), + DataTypeOID: oid, + DataTypeSize: int16(c.Type.MaxTextResponseByteLength(ctx)), + TypeModifier: int32(-1), // TODO: used for domain type, which we don't support yet + Format: int16(0), + } + } + + return fields +} + +// resultForOkIter reads a maximum of one result row from a result iterator. 
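+// (An OK result wraps a types.OkResult; only its RowsAffected count is surfaced
+// to the client, so no field descriptions or data rows are produced here.)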
+func resultForOkIter(ctx *sql.Context, iter sql.RowIter) (*Result, error) { + defer trace.StartRegion(ctx, "DoltgresHandler.resultForOkIter").End() + + row, err := iter.Next(ctx) + if err != nil { + return nil, err + } + _, err = iter.Next(ctx) + if err != io.EOF { + return nil, fmt.Errorf("result schema iterator returned more than one row") + } + if err := iter.Close(ctx); err != nil { + return nil, err + } + + return &Result{ + RowsAffected: row[0].(types.OkResult).RowsAffected, + }, nil +} + +// resultForEmptyIter ensures that an expected empty iterator returns no rows. +func resultForEmptyIter(ctx *sql.Context, iter sql.RowIter) (*Result, error) { + defer trace.StartRegion(ctx, "DoltgresHandler.resultForEmptyIter").End() + if _, err := iter.Next(ctx); err != io.EOF { + return nil, fmt.Errorf("result schema iterator returned more than zero rows") + } + if err := iter.Close(ctx); err != nil { + return nil, err + } + return &Result{Fields: nil}, nil +} + +// resultForMax1RowIter ensures that the iterator returns at most one row. +func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []pgproto3.FieldDescription) (*Result, error) { + defer trace.StartRegion(ctx, "DoltgresHandler.resultForMax1RowIter").End() + row, err := iter.Next(ctx) + if err == io.EOF { + return &Result{Fields: resultFields}, nil + } else if err != nil { + return nil, err + } + + if _, err = iter.Next(ctx); err != io.EOF { + return nil, fmt.Errorf("result max1Row iterator returned more than one row") + } + if err := iter.Close(ctx); err != nil { + return nil, err + } + + outputRow, err := rowToBytes(ctx, schema, row) + if err != nil { + return nil, err + } + + ctx.GetLogger().Tracef("spooling result row %s", outputRow) + + return &Result{Fields: resultFields, Rows: []Row{{outputRow}}, RowsAffected: 1}, nil +} + +// resultForDefaultIter reads batches of rows from the iterator +// and writes results into the callback function. +func (h *DoltgresHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, callback func(*Result) error, resultFields []pgproto3.FieldDescription) (r *Result, processedAtLeastOneBatch bool, returnErr error) { + defer trace.StartRegion(ctx, "DoltgresHandler.resultForDefaultIter").End() + + eg, ctx := ctx.NewErrgroup() + + var rowChan = make(chan sql.Row, 512) + + pan2err := func() { + if recoveredPanic := recover(); recoveredPanic != nil { + returnErr = fmt.Errorf("DoltgresHandler caught panic: %v", recoveredPanic) + } + } + + wg := sync.WaitGroup{} + wg.Add(2) + // Read rows off the row iterator and send them to the row channel. + eg.Go(func() error { + defer pan2err() + defer wg.Done() + defer close(rowChan) + for { + select { + case <-ctx.Done(): + return nil + default: + row, err := iter.Next(ctx) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + select { + case rowChan <- row: + case <-ctx.Done(): + return nil + } + } + } + }) + + // The default waitTime is one minute if there is no timeout configured, in which case + // the loop simply iterates again unless the socket has died from an OS timeout or another problem. + // If there is a timeout, it is enforced to ensure that the server has a chance to + // call DoltgresHandler.ConnectionClosed() + waitTime := 1 * time.Minute + if h.readTimeout > 0 { + waitTime = h.readTimeout + } + timer := time.NewTimer(waitTime) + defer timer.Stop() + + // Read rows from the channel, convert them to wire format, + // and call |callback| to send them to the client. 
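+	// (Rows are flushed to |callback| in batches: once r.RowsAffected reaches
+	// rowsBatch (128), the batch is sent and reset; see the check at the top of
+	// the loop below.)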
+ eg.Go(func() error { + defer pan2err() + defer wg.Done() + for { + if r == nil { + r = &Result{Fields: resultFields} + } + if r.RowsAffected == rowsBatch { + if err := callback(r); err != nil { + return err + } + r = nil + processedAtLeastOneBatch = true + continue + } + + select { + case <-ctx.Done(): + return nil + case row, ok := <-rowChan: + if !ok { + return nil + } + if types.IsOkResult(row) { + if len(r.Rows) > 0 { + panic("Got OkResult mixed with RowResult") + } + result := row[0].(types.OkResult) + r = &Result{ + RowsAffected: result.RowsAffected, + } + continue + } + + outputRow, err := rowToBytes(ctx, schema, row) + if err != nil { + return err + } + + ctx.GetLogger().Tracef("spooling result row %s", outputRow) + r.Rows = append(r.Rows, Row{outputRow}) + r.RowsAffected++ + case <-timer.C: + if h.readTimeout != 0 { + // Cancel and return so Vitess can call the CloseConnection callback + ctx.GetLogger().Tracef("connection timeout") + return fmt.Errorf("row read wait bigger than connection timeout") + } + } + if !timer.Stop() { + <-timer.C + } + timer.Reset(waitTime) + } + }) + + // Close() kills this PID in the process list, + // wait until all rows have been sent over the wire + eg.Go(func() error { + defer pan2err() + wg.Wait() + return iter.Close(ctx) + }) + + err := eg.Wait() + if err != nil { + ctx.GetLogger().WithError(err).Warn("error running query") + returnErr = err + } + + return +} + +func rowToBytes(ctx *sql.Context, s sql.Schema, row sql.Row) ([][]byte, error) { + if len(row) == 0 { + return nil, nil + } + if len(s) == 0 { + // should not happen + return nil, fmt.Errorf("received empty schema") + } + o := make([][]byte, len(row)) + for i, v := range row { + if v == nil { + o[i] = nil + } else { + val, err := s[i].Type.SQL(ctx, []byte{}, v) + if err != nil { + return nil, err + } + o[i] = val.ToBytes() + } + } + return o, nil +} diff --git a/server/functions/extract.go b/server/functions/extract.go index 395c3733bb..46834d4a77 100644 --- a/server/functions/extract.go +++ b/server/functions/extract.go @@ -146,7 +146,7 @@ var extract_text_interval = framework.Function2{ case "epoch": epoch := float64(duration.SecsPerDay*duration.DaysPerMonth*dur.Months) + float64(duration.SecsPerDay*dur.Days) + (float64(dur.Nanos()) / (NanosPerSec)) - return decimal.NewFromFloatWithExponent(epoch, -6), nil + return decimal.NewFromString(decimal.NewFromFloat(epoch).StringFixed(6)) case "hour", "hours": hours := math.Floor(float64(dur.Nanos()) / (NanosPerSec * duration.SecsPerHour)) return decimal.NewFromFloat(hours), nil @@ -159,7 +159,7 @@ var extract_text_interval = framework.Function2{ case "millisecond", "milliseconds": secondsInNanos := dur.Nanos() % (NanosPerSec * duration.SecsPerMinute) milliseconds := float64(secondsInNanos) / NanosPerMilli - return decimal.NewFromFloatWithExponent(milliseconds, -3), nil + return decimal.NewFromString(decimal.NewFromFloat(milliseconds).StringFixed(3)) case "minute", "minutes": minutesInNanos := dur.Nanos() % (NanosPerSec * duration.SecsPerHour) minutes := math.Floor(float64(minutesInNanos) / (NanosPerSec * duration.SecsPerMinute)) @@ -171,7 +171,7 @@ var extract_text_interval = framework.Function2{ case "second", "seconds": secondsInNanos := dur.Nanos() % (NanosPerSec * duration.SecsPerMinute) seconds := float64(secondsInNanos) / NanosPerSec - return decimal.NewFromFloatWithExponent(seconds, -6), nil + return decimal.NewFromString(decimal.NewFromFloat(seconds).StringFixed(6)) case "year", "years":
return decimal.NewFromFloat(math.Floor(float64(dur.Months) / 12)), nil case "dow", "doy", "isodow", "isoyear", "julian", "timezone", "timezone_hour", "timezone_minute", "week": @@ -196,7 +196,7 @@ func getFieldFromTimeVal(field string, tVal time.Time) (decimal.Decimal, error) case "doy": return decimal.NewFromInt(int64(tVal.YearDay())), nil case "epoch": - return decimal.NewFromFloat(float64(tVal.UnixMicro()) / 1000000), nil + return decimal.NewFromString(decimal.NewFromFloat(float64(tVal.UnixMicro()) / 1000000).StringFixed(6)) case "hour", "hours": return decimal.NewFromInt(int64(tVal.Hour())), nil case "isodow": @@ -219,7 +219,7 @@ func getFieldFromTimeVal(field string, tVal time.Time) (decimal.Decimal, error) case "millisecond", "milliseconds": w := float64(tVal.Second() * 1000) f := float64(tVal.Nanosecond()) / float64(1000000) - return decimal.NewFromFloatWithExponent(w+f, -3), nil + return decimal.NewFromString(decimal.NewFromFloat(w + f).StringFixed(3)) case "minute", "minutes": return decimal.NewFromInt(int64(tVal.Minute())), nil case "month", "months": @@ -230,7 +230,7 @@ func getFieldFromTimeVal(field string, tVal time.Time) (decimal.Decimal, error) case "second", "seconds": w := float64(tVal.Second()) f := float64(tVal.Nanosecond()) / float64(1000000000) - return decimal.NewFromFloatWithExponent(w+f, -6), nil + return decimal.NewFromString(decimal.NewFromFloat(w + f).StringFixed(6)) case "timezone": // TODO: postgres seem to use server timezone regardless of input value return decimal.NewFromInt(-28800), nil diff --git a/server/handler.go b/server/handler.go new file mode 100644 index 0000000000..f03e5d533f --- /dev/null +++ b/server/handler.go @@ -0,0 +1,32 @@ +package server + +import ( + "context" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/vitess/go/mysql" + "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/jackc/pgx/v5/pgproto3" +) + +type Handler interface { + // ComBind is called when a connection receives a request to bind a prepared statement to a set of values. + ComBind(ctx context.Context, c *mysql.Conn, query string, parsedQuery mysql.ParsedQuery, bindVars map[string]sqlparser.Expr) (mysql.BoundQuery, []pgproto3.FieldDescription, error) + // ComExecuteBound is called when a connection receives a request to execute a prepared statement that has already been bound to a set of values. + ComExecuteBound(ctx context.Context, conn *mysql.Conn, query string, boundQuery mysql.BoundQuery, callback func(*Result) error) error + // ComPrepareParsed is called when a connection receives a prepared statement query that has already been parsed. + ComPrepareParsed(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement) (mysql.ParsedQuery, []pgproto3.FieldDescription, error) + // ComQuery is called when a connection receives a query. Note that the contents of the query slice may change + // after the first call to callback, so the DoltgresHandler should not hang on to the byte slice. + ComQuery(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement, callback func(*Result) error) error + // ComResetConnection resets the connection's session, clearing out any cached prepared statements, locks, user and + // session variables. The currently selected database is preserved. + ComResetConnection(c *mysql.Conn) error + // ConnectionClosed reports that a connection has been closed. + ConnectionClosed(c *mysql.Conn) + // NewConnection reports that a new connection has been established.
+ NewConnection(c *mysql.Conn) + // NewContext creates a new sql.Context instance for the connection |c|. The + // optional |query| can be specified to populate the sql.Context's query field. + NewContext(ctx context.Context, c *mysql.Conn, query string) (*sql.Context, error) +} diff --git a/server/listener.go b/server/listener.go index 4cf4b45cfb..c400722cd7 100644 --- a/server/listener.go +++ b/server/listener.go @@ -27,7 +27,7 @@ import ( var ( connectionIDCounter uint32 - processID = int32(os.Getpid()) + processID = uint32(os.Getpid()) certificate tls.Certificate //TODO: move this into the mysql.ListenerConfig ) diff --git a/server/types/date.go b/server/types/date.go index dfbcda618e..ec5ac29752 100644 --- a/server/types/date.go +++ b/server/types/date.go @@ -158,7 +158,7 @@ func (b DateType) MaxSerializedWidth() types.ExtendedTypeSerializedWidth { // MaxTextResponseByteLength implements the DoltgresType interface. func (b DateType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { - return 32 + return 4 } // OID implements the DoltgresType interface. diff --git a/server/types/internal_char.go b/server/types/internal_char.go index a9c4c93142..57d662add4 100644 --- a/server/types/internal_char.go +++ b/server/types/internal_char.go @@ -131,7 +131,8 @@ func (b InternalCharType) GetSerializationID() SerializationID { // IoInput implements the DoltgresType interface. func (b InternalCharType) IoInput(ctx *sql.Context, input string) (any, error) { - if uint32(len(input)) > InternalCharLength { + c := []byte(input) + if uint32(len(c)) > InternalCharLength { return input[:InternalCharLength], nil } return input, nil diff --git a/server/types/interval.go b/server/types/interval.go index 379e17b3f5..b942b8e718 100644 --- a/server/types/interval.go +++ b/server/types/interval.go @@ -157,7 +157,7 @@ func (b IntervalType) MaxSerializedWidth() types.ExtendedTypeSerializedWidth { // MaxTextResponseByteLength implements the DoltgresType interface. func (b IntervalType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { - return 64 + return 16 } // OID implements the DoltgresType interface. diff --git a/server/types/jsonb.go b/server/types/jsonb.go index ecebcca189..de49f769b5 100644 --- a/server/types/jsonb.go +++ b/server/types/jsonb.go @@ -279,15 +279,15 @@ func (b JsonBType) unmarshalToJsonDocument(val []byte) (JsonDocument, error) { if err := json.Unmarshal(val, &decoded); err != nil { return JsonDocument{}, err } - jsonValue, err := b.convertToJsonDocument(decoded) + jsonValue, err := b.ConvertToJsonDocument(decoded) if err != nil { return JsonDocument{}, err } return JsonDocument{Value: jsonValue}, nil } -// convertToJsonDocument recursively constructs a valid JsonDocument based on the structures returned by the decoder. -func (b JsonBType) convertToJsonDocument(val interface{}) (JsonValue, error) { +// ConvertToJsonDocument recursively constructs a valid JsonDocument based on the structures returned by the decoder. 
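+// It is exported (previously convertToJsonDocument) so that callers outside this package, such as the Go test
+// framework's expected-value normalization, can reuse it.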
+func (b JsonBType) ConvertToJsonDocument(val interface{}) (JsonValue, error) { var err error switch val := val.(type) { case map[string]interface{}: @@ -306,7 +306,7 @@ func (b JsonBType) convertToJsonDocument(val interface{}) (JsonValue, error) { index := make(map[string]int) for i, key := range keys { items[i].Key = key - items[i].Value, err = b.convertToJsonDocument(val[key]) + items[i].Value, err = b.ConvertToJsonDocument(val[key]) if err != nil { return nil, err } @@ -319,7 +319,7 @@ func (b JsonBType) convertToJsonDocument(val interface{}) (JsonValue, error) { case []interface{}: values := make(JsonValueArray, len(val)) for i, item := range val { - values[i], err = b.convertToJsonDocument(item) + values[i], err = b.ConvertToJsonDocument(item) if err != nil { return nil, err } diff --git a/server/types/numeric.go b/server/types/numeric.go index 9fd9783b46..75b8dc4941 100644 --- a/server/types/numeric.go +++ b/server/types/numeric.go @@ -154,7 +154,12 @@ func (b NumericType) IoOutput(ctx *sql.Context, output any) (string, error) { if err != nil { return "", err } - return converted.(decimal.Decimal).String(), nil + dec := converted.(decimal.Decimal) + scale := b.Scale + if scale == -1 { + scale = dec.Exponent() * -1 + } + return dec.StringFixed(scale), nil } // IsPreferredType implements the DoltgresType interface. diff --git a/server/types/time.go b/server/types/time.go index 2ba16d9a06..21d99f14d7 100644 --- a/server/types/time.go +++ b/server/types/time.go @@ -161,7 +161,7 @@ func (b TimeType) MaxSerializedWidth() types.ExtendedTypeSerializedWidth { // MaxTextResponseByteLength implements the DoltgresType interface. func (b TimeType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { - return 32 + return 8 } // OID implements the DoltgresType interface. diff --git a/server/types/timestamp.go b/server/types/timestamp.go index 498032030d..75cb49ffb6 100644 --- a/server/types/timestamp.go +++ b/server/types/timestamp.go @@ -158,7 +158,7 @@ func (b TimestampType) MaxSerializedWidth() types.ExtendedTypeSerializedWidth { // MaxTextResponseByteLength implements the DoltgresType interface. func (b TimestampType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { - return 32 + return 8 } // OID implements the DoltgresType interface. diff --git a/server/types/timestamptz.go b/server/types/timestamptz.go index 82b792c6ff..f99aac1a36 100644 --- a/server/types/timestamptz.go +++ b/server/types/timestamptz.go @@ -171,7 +171,7 @@ func (b TimestampTZType) MaxSerializedWidth() types.ExtendedTypeSerializedWidth // MaxTextResponseByteLength implements the DoltgresType interface. func (b TimestampTZType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { - return 40 + return 8 } // OID implements the DoltgresType interface. diff --git a/server/types/timetz.go b/server/types/timetz.go index af27979238..23dc492588 100644 --- a/server/types/timetz.go +++ b/server/types/timetz.go @@ -161,7 +161,7 @@ func (b TimeTZType) MaxSerializedWidth() types.ExtendedTypeSerializedWidth { // MaxTextResponseByteLength implements the DoltgresType interface. func (b TimeTZType) MaxTextResponseByteLength(ctx *sql.Context) uint32 { - return 32 + return 12 } // OID implements the DoltgresType interface. 
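The MaxTextResponseByteLength changes in this group (date 4; time, timestamp, and timestamptz 8; timetz 12; interval 16) line up with PostgreSQL's fixed binary widths for these types, i.e. pg_type.typlen. A minimal standalone sketch of that mapping (the helper below is illustrative, not code from this PR):

package main

import "fmt"

// typlen mirrors pg_type.typlen for the types whose
// MaxTextResponseByteLength was updated above.
var typlen = map[string]int16{
	"date":        4,
	"time":        8,
	"timestamp":   8,
	"timestamptz": 8,
	"timetz":      12,
	"interval":    16,
}

func main() {
	for name, l := range typlen {
		fmt.Printf("%-12s typlen=%d\n", name, l)
	}
}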
diff --git a/testing/go/coercion_test.go b/testing/go/coercion_test.go index 2d7b9f2051..3838de291a 100644 --- a/testing/go/coercion_test.go +++ b/testing/go/coercion_test.go @@ -34,8 +34,8 @@ func TestCoercion(t *testing.T) { Expected: []sql.Row{{Numeric("0.5")}}, }, { - Query: `SELECT 0.5`, - Expected: []sql.Row{{Numeric("0.5")}}, + Query: `SELECT 0.50`, + Expected: []sql.Row{{Numeric("0.50")}}, }, { Query: `SELECT -0.5`, diff --git a/testing/go/enginetest/doltgres_harness_test.go b/testing/go/enginetest/doltgres_harness_test.go index d5c055aa1f..2c0975a8fa 100644 --- a/testing/go/enginetest/doltgres_harness_test.go +++ b/testing/go/enginetest/doltgres_harness_test.go @@ -33,7 +33,6 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/analyzer" gmstypes "github.com/dolthub/go-mysql-server/sql/types" - "github.com/dolthub/vitess/go/vt/proto/query" vitess "github.com/dolthub/vitess/go/vt/sqlparser" _ "github.com/jackc/pgx/v4/stdlib" "github.com/stretchr/testify/assert" @@ -700,7 +699,7 @@ func (d *DoltgresQueryEngine) EnginePreparedDataCache() *gms.PreparedDataCache { panic("implement me") } -func (d *DoltgresQueryEngine) QueryWithBindings(ctx *sql.Context, query string, parsed vitess.Statement, bindings map[string]*query.BindVariable, qFlags *sql.QueryFlags) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { +func (d *DoltgresQueryEngine) QueryWithBindings(ctx *sql.Context, query string, parsed vitess.Statement, bindings map[string]vitess.Expr, qFlags *sql.QueryFlags) (sql.Schema, sql.RowIter, *sql.QueryFlags, error) { if len(bindings) > 0 { return nil, nil, nil, fmt.Errorf("bindings not supported") } @@ -734,7 +733,7 @@ func columns(rows *gosql.Rows) (sql.Schema, []interface{}, error) { colVal := gosql.NullBool{} columnVals = append(columnVals, &colVal) schema = append(schema, &sql.Column{Name: columnType.Name(), Type: gmstypes.Int8, Nullable: true}) - case "TEXT", "VARCHAR", "MEDIUMTEXT", "CHAR", "TINYTEXT", "NAME", "BYTEA": + case "TEXT", "VARCHAR", "MEDIUMTEXT", "CHAR", "TINYTEXT", "NAME", "BYTEA", "_TEXT": colVal := gosql.NullString{} columnVals = append(columnVals, &colVal) schema = append(schema, &sql.Column{Name: columnType.Name(), Type: gmstypes.LongText, Nullable: true}) diff --git a/testing/go/framework.go b/testing/go/framework.go index f94aa5f66c..0485b30bf3 100644 --- a/testing/go/framework.go +++ b/testing/go/framework.go @@ -28,12 +28,17 @@ import ( "github.com/dolthub/dolt/go/libraries/utils/svcs" "github.com/dolthub/go-mysql-server/sql" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/dolthub/doltgresql/postgres/parser/duration" + "github.com/dolthub/doltgresql/postgres/parser/uuid" dserver "github.com/dolthub/doltgresql/server" + "github.com/dolthub/doltgresql/server/functions" + "github.com/dolthub/doltgresql/server/types" "github.com/dolthub/doltgresql/servercfg" ) @@ -168,7 +173,7 @@ func runScript(t *testing.T, ctx context.Context, script ScriptTest, conn *pgx.C } if normalizeRows { - assert.Equal(t, NormalizeRows(assertion.Expected), readRows) + assert.Equal(t, NormalizeExpectedRow(rows.FieldDescriptions(), assertion.Expected), readRows) } else { assert.Equal(t, assertion.Expected, readRows) } @@ -275,106 +280,286 @@ func ReadRows(rows pgx.Rows, normalizeRows bool) (readRows []sql.Row, err error) } slice = append(slice, row) } - if normalizeRows { - return 
NormalizeRows(slice), nil - } else { - // We must always normalize Numeric values, as they have an infinite number of ways to represent the same value - return NormalizeRowsOnlyNumeric(slice), nil + return NormalizeRows(rows.FieldDescriptions(), slice, normalizeRows), nil +} + +// NormalizeRows normalizes each value's type within each row, as the tests only want to compare values. Returns a new +// set of rows in the same order. +func NormalizeRows(fds []pgconn.FieldDescription, rows []sql.Row, normalize bool) []sql.Row { + newRows := make([]sql.Row, len(rows)) + for i := range rows { + newRows[i] = NormalizeRow(fds, rows[i], normalize) } + return newRows } -// NormalizeRow normalizes each value's type, as the tests only want to compare values. Returns a new row. -func NormalizeRow(row sql.Row) sql.Row { +// NormalizeRow normalizes each value's type, as the tests only want to compare values. +// Returns a new row. +func NormalizeRow(fds []pgconn.FieldDescription, row sql.Row, normalize bool) sql.Row { if len(row) == 0 { return nil } newRow := make(sql.Row, len(row)) for i := range row { - switch val := row[i].(type) { - case int: - newRow[i] = int64(val) - case int8: - newRow[i] = int64(val) - case int16: - newRow[i] = int64(val) - case int32: - newRow[i] = int64(val) - case uint: - newRow[i] = int64(val) - case uint8: - newRow[i] = int64(val) - case uint16: - newRow[i] = int64(val) - case uint32: - newRow[i] = int64(val) - case uint64: - // PostgreSQL does not support an uint64 type, so we can always convert this to an int64 safely. - newRow[i] = int64(val) - case float32: - newRow[i] = float64(val) - case pgtype.Numeric: - if val.NaN { - newRow[i] = math.NaN() - } else if val.InfinityModifier != pgtype.Finite { - newRow[i] = math.Inf(int(val.InfinityModifier)) - } else if !val.Valid { - newRow[i] = nil - } else { - fVal, err := val.Float64Value() - if err != nil { - panic(err) - } - if !fVal.Valid { - panic("no idea why the numeric float value is invalid") - } - newRow[i] = fVal.Float64 - } - case time.Time: - newRow[i] = val.Format("2006-01-02 15:04:05") - case map[string]interface{}: - str, err := json.Marshal(val) - if err != nil { - panic(err) - } - newRow[i] = string(str) - default: - newRow[i] = val + dt, ok := dserver.OidToDoltgresType[fds[i].DataTypeOID] + if !ok { + panic(fmt.Sprintf("unhandled oid type: %v", fds[i].DataTypeOID)) + } + newRow[i] = NormalizeValToString(dt, row[i]) + if normalize { + newRow[i] = NormalizeIntsAndFloats(newRow[i]) } } return newRow } -// NormalizeRows normalizes each value's type within each row, as the tests only want to compare values. Returns a new -// set of rows in the same order. -func NormalizeRows(rows []sql.Row) []sql.Row { +// NormalizeExpectedRow normalizes each value's type, as the tests only want to compare values. Returns a new set of rows in the same order.
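+// Unlike NormalizeRow, the values here come from the test scripts themselves, so JSON strings are re-marshaled
+// (via UnmarshalAndMarshalJsonString) before comparison.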
+func NormalizeExpectedRow(fds []pgconn.FieldDescription, rows []sql.Row) []sql.Row { newRows := make([]sql.Row, len(rows)) - for i := range rows { - newRows[i] = NormalizeRow(rows[i]) + for ri, row := range rows { + if len(row) == 0 { + newRows[ri] = nil + } else { + newRow := make(sql.Row, len(row)) + for i := range row { + dt, ok := dserver.OidToDoltgresType[fds[i].DataTypeOID] + if !ok { + panic(fmt.Sprintf("unhandled oid type: %v", fds[i].DataTypeOID)) + } + if dt == types.Json { + newRow[i] = UnmarshalAndMarshalJsonString(row[i].(string)) + } else if dta, ok := dt.(types.DoltgresArrayType); ok && dta.BaseType() == types.Json { + v, err := dta.IoInput(nil, row[i].(string)) + if err != nil { + panic(err) + } + arr := v.([]any) + newArr := make([]any, len(arr)) + for j, el := range arr { + newArr[j] = UnmarshalAndMarshalJsonString(el.(string)) + } + ret, err := dt.IoOutput(nil, newArr) + if err != nil { + panic(err) + } + newRow[i] = ret + } else { + newRow[i] = NormalizeIntsAndFloats(row[i]) + } + } + newRows[ri] = newRow + } } return newRows } -// NormalizeRowsOnlyNumeric normalizes Numeric values only. There are an infinite number of ways to represent the same -// value in-memory, so we must at least normalize Numeric values. -func NormalizeRowsOnlyNumeric(rows []sql.Row) []sql.Row { - newRows := make([]sql.Row, len(rows)) - for rowIdx, row := range rows { - newRow := make(sql.Row, len(row)) - copy(newRow, row) - for colIdx := range newRow { - if numericValue, ok := newRow[colIdx].(pgtype.Numeric); ok { - val, err := numericValue.Value() - if err != nil { - panic(err) // Should never happen - } - // Using decimal as an intermediate value will remove all differences between the string formatting - d := decimal.RequireFromString(val.(string)) - newRow[colIdx] = Numeric(d.String()) - } +// UnmarshalAndMarshalJsonString normalizes an expected JSON value so that it can be compared against the actual value. +// JSON values are strings, and since the Postgres JSON type preserves the input string exactly (when valid), +// the expected string cannot be compared directly to the re-marshaled result: json.Marshal space-pads key/value pairs. +// To allow result matching, we unmarshal and then re-marshal the expected string. The trade-off is that we no longer +// verify that the output format is byte-for-byte identical to the input JSON string. +func UnmarshalAndMarshalJsonString(val string) string { + var decoded any + err := json.Unmarshal([]byte(val), &decoded) + if err != nil { + panic(err) + } + ret, err := json.Marshal(decoded) + if err != nil { + panic(err) + } + return string(ret) +} + +// NormalizeValToString normalizes values into types that can be compared: JSON values, pg-specific types, and +// time and decimal values are all converted into strings. Numeric values are always normalized, since there are +// an infinite number of ways to represent the same numeric value in memory.
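+// For example, a NUMERIC(5,2) value of 100.3 arrives as a pgtype.Numeric with Int=10030 and Exp=-2,
+// and is rendered as "100.30" (wrapped via the Numeric test helper).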
+func NormalizeValToString(dt types.DoltgresType, v any) any { + switch t := dt.(type) { + case types.JsonType: + str, err := json.Marshal(v) + if err != nil { + panic(err) + } + ret, err := t.IoOutput(nil, string(str)) + if err != nil { + panic(err) + } + return ret + case types.JsonBType: + jv, err := t.ConvertToJsonDocument(v) + if err != nil { + panic(err) + } + str, err := t.IoOutput(nil, types.JsonDocument{Value: jv}) + if err != nil { + panic(err) + } + return str + case types.InternalCharType: + if v == nil { + return nil + } + var b []byte + if v.(int32) == 0 { + b = []byte{} + } else { + b = []byte{uint8(v.(int32))} + } + val, err := t.IoOutput(nil, string(b)) + if err != nil { + panic(err) + } + return val + } + + switch val := v.(type) { + case bool: + if val { + return "t" + } else { + return "f" + } + case pgtype.Numeric: + if val.NaN { + return math.NaN() + } else if val.InfinityModifier != pgtype.Finite { + return math.Inf(int(val.InfinityModifier)) + } else if !val.Valid { + return nil + } else { + decStr := decimal.NewFromBigInt(val.Int, val.Exp).StringFixed(val.Exp * -1) + return Numeric(decStr) + } + case pgtype.Time, pgtype.Interval, [16]byte, time.Time: + // These values need to be normalized into the appropriate types + // before being converted to string type using the Doltgres + // IoOutput method. + // - pgtype.Time is specific to Time type. + // - pgtype.Interval is specific to Interval type. + // - [16]byte is specific to UUID type + // - time.Time is specific to date, timetz, timestamp and timestamptz types. + tVal, err := dt.IoOutput(nil, NormalizeVal(dt, val)) + if err != nil { + panic(err) + } + return tVal + case []any: + if dta, ok := dt.(types.DoltgresArrayType); ok { + return NormalizeArrayType(dta, val) + } + } + return v +} + +// NormalizeArrayType normalizes an array value by normalizing its elements first, +// then converting the array to a string using the type's IoOutput method. +func NormalizeArrayType(dta types.DoltgresArrayType, arr []any) any { + newVal := make([]any, len(arr)) + for i, el := range arr { + newVal[i] = NormalizeVal(dta.BaseType(), el) + } + baseType := dta.BaseType() + if baseType == types.Bool { + sqlVal, err := dta.SQL(nil, nil, newVal) + if err != nil { + panic(err) + } + return sqlVal.ToString() + } else { + ret, err := dta.IoOutput(nil, newVal) + if err != nil { + panic(err) + } + return ret + } +} + +// NormalizeVal normalizes values to what the given Doltgres type expects, so that they can be converted +// using that type. This is used when normalizing array elements, as the type conversion expects values +// of specific types.
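+// For example, pgtype.Time becomes a time.Time offset from the zero time, pgtype.Interval becomes a
+// duration.Duration, and a [16]byte becomes a uuid.UUID.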
+func NormalizeVal(dt types.DoltgresType, v any) any { + switch t := dt.(type) { + case types.JsonType: + str, err := json.Marshal(v) + if err != nil { + panic(err) + } + return string(str) + case types.JsonBType: + jv, err := t.ConvertToJsonDocument(v) + if err != nil { + panic(err) + } + return types.JsonDocument{Value: jv} + } + + switch val := v.(type) { + case pgtype.Numeric: + if val.NaN { + return math.NaN() + } else if val.InfinityModifier != pgtype.Finite { + return math.Inf(int(val.InfinityModifier)) + } else if !val.Valid { + return nil + } else { + return decimal.NewFromBigInt(val.Int, val.Exp) + } + case pgtype.Time: + dur := time.Duration(val.Microseconds) * time.Microsecond + return time.Time{}.Add(dur) + case pgtype.Interval: + return duration.MakeDuration(val.Microseconds*functions.NanosPerMicro, int64(val.Days), int64(val.Months)) + case [16]byte: + u, err := uuid.FromBytes(val[:]) + if err != nil { + panic(err) + } + return u + case []any: + baseType := dt + if dta, ok := baseType.(types.DoltgresArrayType); ok { + baseType = dta.BaseType() + } + newVal := make([]any, len(val)) + for i, el := range val { + newVal[i] = NormalizeVal(baseType, el) + } + return newVal + } + return v +} + +// NormalizeIntsAndFloats normalizes all int and float types +// to int64 and float64, respectively. +func NormalizeIntsAndFloats(v any) any { + switch val := v.(type) { + case int: + return int64(val) + case int8: + return int64(val) + case int16: + return int64(val) + case int32: + return int64(val) + case uint: + return int64(val) + case uint8: + return int64(val) + case uint16: + return int64(val) + case uint32: + return int64(val) + case uint64: + // PostgreSQL does not support a uint64 type, so we can always convert this to an int64 safely. + return int64(val) + case float32: + return float64(val) + default: + return val } - return newRows } // GetUnusedPort returns an unused port.
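Several hunks in this diff switch from decimal's String to StringFixed so that the declared scale (trailing zeros) survives the round-trip. A minimal, self-contained sketch of the difference, using only github.com/shopspring/decimal (the values are illustrative):

package main

import (
	"fmt"
	"math/big"

	"github.com/shopspring/decimal"
)

func main() {
	// pgtype.Numeric represents 100.30 as Int=10030, Exp=-2.
	d := decimal.NewFromBigInt(big.NewInt(10030), -2)
	fmt.Println(d.String())       // "100.3"  -- String trims trailing zeros
	fmt.Println(d.StringFixed(2)) // "100.30" -- StringFixed keeps the declared scale
}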
diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index bb46dc1a92..0cd8d4add8 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -43,9 +43,9 @@ func TestFunctionsMath(t *testing.T) { Query: `SELECT round(cbrt(v1)::numeric, 10), round(cbrt(v2)::numeric, 10), round(cbrt(v3)::numeric, 10) FROM test ORDER BY pk;`, Cols: []string{"round", "round", "round"}, Expected: []sql.Row{ - {-1.0000000000, -1.2599210499, -1.4422495703}, - {1.9129311828, 2.2239800906, 2.3513346877}, - {2.6684016487, -2.8438669799, 3.0723168257}, + {Numeric("-1.0000000000"), Numeric("-1.2599210499"), Numeric("-1.4422495703")}, + {Numeric("1.9129311828"), Numeric("2.2239800906"), Numeric("2.3513346877")}, + {Numeric("2.6684016487"), Numeric("-2.8438669799"), Numeric("3.0723168257")}, }, }, { @@ -991,56 +991,40 @@ func TestSchemaVisibilityInquiryFunctions(t *testing.T) { }, }, { - Query: `SHOW search_path;`, - Expected: []sql.Row{ - {"testschema"}, - }, + Query: `SHOW search_path;`, + Expected: []sql.Row{{"testschema"}}, }, { - Query: `select pg_table_is_visible(1613758465);`, // index from testschema - Expected: []sql.Row{ - {"t"}, - }, + Query: `select pg_table_is_visible(1613758465);`, // index from testschema + Expected: []sql.Row{{"t"}}, }, { - Query: `select pg_table_is_visible(2687500288);`, // table from testschema - Expected: []sql.Row{ - {"t"}, - }, + Query: `select pg_table_is_visible(2687500288);`, // table from testschema + Expected: []sql.Row{{"t"}}, }, { - Query: `select pg_table_is_visible(2419064832);`, // sequence from testschema - Expected: []sql.Row{ - {"t"}, - }, + Query: `select pg_table_is_visible(2419064832);`, // sequence from testschema + Expected: []sql.Row{{"t"}}, }, { - Query: `select pg_table_is_visible(2952790016);`, // view from myschema - Expected: []sql.Row{ - {"f"}, - }, + Query: `select pg_table_is_visible(2952790016);`, // view from myschema + Expected: []sql.Row{{"f"}}, }, { Query: `SET search_path = 'myschema';`, Expected: []sql.Row{}, }, { - Query: `SHOW search_path;`, - Expected: []sql.Row{ - {"myschema"}, - }, + Query: `SHOW search_path;`, + Expected: []sql.Row{{"myschema"}}, }, { - Query: `select pg_table_is_visible(2952790016);`, // view from myschema - Expected: []sql.Row{ - {"t"}, - }, + Query: `select pg_table_is_visible(2952790016);`, // view from myschema + Expected: []sql.Row{{"t"}}, }, { - Query: `select pg_table_is_visible(2684354560);`, // table from myschema - Expected: []sql.Row{ - {"t"}, - }, + Query: `select pg_table_is_visible(2684354560);`, // table from myschema + Expected: []sql.Row{{"t"}}, }, }, }, @@ -1132,115 +1116,115 @@ func TestDateAndTimeFunction(t *testing.T) { Assertions: []ScriptTestAssertion{ { Query: `SELECT EXTRACT(CENTURY FROM TIMESTAMP '2000-12-16 12:21:13');`, - Expected: []sql.Row{{float64(20)}}, + Expected: []sql.Row{{Numeric("20")}}, }, { Query: `SELECT EXTRACT(CENTURY FROM TIMESTAMP '2001-02-16 20:38:40');`, - Expected: []sql.Row{{float64(21)}}, + Expected: []sql.Row{{Numeric("21")}}, }, { Skip: true, // TODO: cannot parse calendar era Query: `SELECT EXTRACT(CENTURY FROM DATE '0001-01-01 AD');`, - Expected: []sql.Row{{float64(1)}}, + Expected: []sql.Row{{Numeric("1")}}, }, { Skip: true, // TODO: cannot parse calendar era Query: `SELECT EXTRACT(CENTURY FROM DATE '0001-12-31 BC');`, - Expected: []sql.Row{{float64(-1)}}, + Expected: []sql.Row{{Numeric("-1")}}, }, { Query: `SELECT EXTRACT(DAY FROM TIMESTAMP '2001-02-16 20:38:40');`, - Expected: []sql.Row{{float64(16)}}, + Expected: 
[]sql.Row{{Numeric("16")}}, }, { Query: `SELECT EXTRACT(DECADE FROM TIMESTAMP '2001-02-16 20:38:40');`, - Expected: []sql.Row{{float64(200)}}, + Expected: []sql.Row{{Numeric("200")}}, }, { Query: `SELECT EXTRACT(DOW FROM TIMESTAMP '2001-02-16 20:38:40');`, - Expected: []sql.Row{{float64(5)}}, + Expected: []sql.Row{{Numeric("5")}}, }, { Query: `SELECT EXTRACT(DOY FROM TIMESTAMP '2001-02-16 20:38:40');`, - Expected: []sql.Row{{float64(47)}}, + Expected: []sql.Row{{Numeric("47")}}, }, { Query: `SELECT EXTRACT(EPOCH FROM TIMESTAMP WITH TIME ZONE '2001-02-16 20:38:40.12-08');`, - Expected: []sql.Row{{float64(982384720.120000)}}, + Expected: []sql.Row{{Numeric("982384720.120000")}}, }, { Query: `SELECT EXTRACT(EPOCH FROM TIMESTAMP '2001-02-16 20:38:40.12');`, - Expected: []sql.Row{{float64(982355920.120000)}}, + Expected: []sql.Row{{Numeric("982355920.120000")}}, }, { Query: `SELECT EXTRACT(HOUR FROM TIMESTAMP '2001-02-16 20:38:40');`, - Expected: []sql.Row{{float64(20)}}, + Expected: []sql.Row{{Numeric("20")}}, }, { Query: `SELECT EXTRACT(ISODOW FROM TIMESTAMP '2001-02-18 20:38:40');`, - Expected: []sql.Row{{float64(7)}}, + Expected: []sql.Row{{Numeric("7")}}, }, { Query: `SELECT EXTRACT(ISOYEAR FROM DATE '2006-01-01');`, - Expected: []sql.Row{{float64(2005)}}, + Expected: []sql.Row{{Numeric("2005")}}, }, { Query: `SELECT EXTRACT(ISOYEAR FROM DATE '2006-01-02');`, - Expected: []sql.Row{{float64(2006)}}, + Expected: []sql.Row{{Numeric("2006")}}, }, { Skip: true, // TODO: not supported yet Query: `SELECT EXTRACT(JULIAN FROM DATE '2006-01-01');`, - Expected: []sql.Row{{float64(2453737)}}, + Expected: []sql.Row{{Numeric("2453737")}}, }, { Skip: true, // TODO: not supported yet Query: `SELECT EXTRACT(JULIAN FROM TIMESTAMP '2006-01-01 12:00');`, - Expected: []sql.Row{{float64(2453737.50000000000000000000)}}, + Expected: []sql.Row{{Numeric("2453737.50000000000000000000")}}, }, { Query: `SELECT EXTRACT(MICROSECONDS FROM TIME '17:12:28.5');`, - Expected: []sql.Row{{float64(28500000)}}, + Expected: []sql.Row{{Numeric("28500000")}}, }, { Query: `SELECT EXTRACT(MILLENNIUM FROM TIMESTAMP '2001-02-16 20:38:40');`, - Expected: []sql.Row{{float64(3)}}, + Expected: []sql.Row{{Numeric("3")}}, }, { Query: `SELECT EXTRACT(MILLENNIUM FROM TIMESTAMP '2000-02-16 20:38:40');`, - Expected: []sql.Row{{float64(2)}}, + Expected: []sql.Row{{Numeric("2")}}, }, { Query: `SELECT EXTRACT(MILLISECONDS FROM TIME '17:12:28.5');`, - Expected: []sql.Row{{float64(28500.000)}}, + Expected: []sql.Row{{Numeric("28500.000")}}, }, { Query: `SELECT EXTRACT(MINUTE FROM TIMESTAMP '2001-02-16 20:38:40');`, - Expected: []sql.Row{{float64(38)}}, + Expected: []sql.Row{{Numeric("38")}}, }, { Query: `SELECT EXTRACT(MONTH FROM TIMESTAMP '2001-02-16 20:38:40');`, - Expected: []sql.Row{{float64(2)}}, + Expected: []sql.Row{{Numeric("2")}}, }, { Query: `SELECT EXTRACT(QUARTER FROM TIMESTAMP '2001-02-16 20:38:40');`, - Expected: []sql.Row{{float64(1)}}, + Expected: []sql.Row{{Numeric("1")}}, }, { Query: `SELECT EXTRACT(SECOND FROM TIMESTAMP '2001-02-16 20:38:40');`, - Expected: []sql.Row{{float64(40.000000)}}, + Expected: []sql.Row{{Numeric("40.000000")}}, }, { Query: `SELECT EXTRACT(SECOND FROM TIME '17:12:28.5');`, - Expected: []sql.Row{{float64(28.500000)}}, + Expected: []sql.Row{{Numeric("28.500000")}}, }, { Query: `SELECT EXTRACT(WEEK FROM TIMESTAMP '2001-02-16 20:38:40');`, - Expected: []sql.Row{{float64(7)}}, + Expected: []sql.Row{{Numeric("7")}}, }, { Query: `SELECT EXTRACT(YEAR FROM TIMESTAMP '2001-02-16 20:38:40');`, - Expected: 
[]sql.Row{{float64(2001)}}, + Expected: []sql.Row{{Numeric("2001")}}, }, }, }, @@ -1250,71 +1234,71 @@ func TestDateAndTimeFunction(t *testing.T) { Assertions: []ScriptTestAssertion{ { Query: `SELECT EXTRACT(CENTURY FROM INTERVAL '2001 years');`, - Expected: []sql.Row{{float64(20)}}, + Expected: []sql.Row{{Numeric("20")}}, }, { Query: `SELECT EXTRACT(DAY FROM INTERVAL '40 days 1 minute');`, - Expected: []sql.Row{{float64(40)}}, + Expected: []sql.Row{{Numeric("40")}}, }, { Query: `select extract(decades from interval '1000 months');`, - Expected: []sql.Row{{float64(8)}}, + Expected: []sql.Row{{Numeric("8")}}, }, { Query: `SELECT EXTRACT(EPOCH FROM INTERVAL '5 days 3 hours');`, - Expected: []sql.Row{{float64(442800.000000)}}, + Expected: []sql.Row{{Numeric("442800.000000")}}, }, { Query: `select extract(epoch from interval '10 months 10 seconds');`, - Expected: []sql.Row{{float64(25920010.000000)}}, + Expected: []sql.Row{{Numeric("25920010.000000")}}, }, { Query: `select extract(hours from interval '10 months 65 minutes 10 seconds');`, - Expected: []sql.Row{{float64(1)}}, + Expected: []sql.Row{{Numeric("1")}}, }, { Query: `select extract(microsecond from interval '10 months 65 minutes 10 seconds');`, - Expected: []sql.Row{{float64(10000000)}}, + Expected: []sql.Row{{Numeric("10000000")}}, }, { Query: `SELECT EXTRACT(MILLENNIUM FROM INTERVAL '2001 years');`, - Expected: []sql.Row{{float64(2)}}, + Expected: []sql.Row{{Numeric("2")}}, }, { Query: `select extract(millenniums from interval '3000 years 65 minutes 10 seconds');`, - Expected: []sql.Row{{float64(3)}}, + Expected: []sql.Row{{Numeric("3")}}, }, { Query: `select extract(millisecond from interval '10 months 65 minutes 10 seconds');`, - Expected: []sql.Row{{float64(10000.000)}}, + Expected: []sql.Row{{Numeric("10000.000")}}, }, { Query: `select extract(minutes from interval '10 months 65 minutes 10 seconds');`, - Expected: []sql.Row{{float64(5)}}, + Expected: []sql.Row{{Numeric("5")}}, }, { Query: `SELECT EXTRACT(MONTH FROM INTERVAL '2 years 3 months');`, - Expected: []sql.Row{{float64(3)}}, + Expected: []sql.Row{{Numeric("3")}}, }, { Query: `SELECT EXTRACT(MONTH FROM INTERVAL '2 years 13 months');`, - Expected: []sql.Row{{float64(1)}}, + Expected: []sql.Row{{Numeric("1")}}, }, { Query: `select extract(months from interval '20 months 65 minutes 10 seconds');`, - Expected: []sql.Row{{float64(8)}}, + Expected: []sql.Row{{Numeric("8")}}, }, { Query: `select extract(quarter from interval '20 months 65 minutes 10 seconds');`, - Expected: []sql.Row{{float64(3)}}, + Expected: []sql.Row{{Numeric("3")}}, }, { Query: `select extract(seconds from interval '65 minutes 10 seconds 5 millisecond');`, - Expected: []sql.Row{{float64(10.005000)}}, + Expected: []sql.Row{{Numeric("10.005000")}}, }, { Query: `select extract(years from interval '20 months 65 minutes 10 seconds');`, - Expected: []sql.Row{{float64(1)}}, + Expected: []sql.Row{{Numeric("1")}}, }, }, }, diff --git a/testing/go/operators_test.go b/testing/go/operators_test.go index 2a8ebfe563..2da6d60f5c 100644 --- a/testing/go/operators_test.go +++ b/testing/go/operators_test.go @@ -574,7 +574,7 @@ func TestOperators(t *testing.T) { }, { Query: `SELECT 8::int2 / 2::numeric;`, - Expected: []sql.Row{{Numeric("4")}}, + Expected: []sql.Row{{Numeric("4.0000000000000000")}}, }, { Query: `SELECT 8::int4 / 2::float4;`, @@ -598,7 +598,7 @@ func TestOperators(t *testing.T) { }, { Query: `SELECT 8::int4 / 2::numeric;`, - Expected: []sql.Row{{Numeric("4")}}, + Expected: 
[]sql.Row{{Numeric("4.0000000000000000")}}, }, { Query: `SELECT 8::int8 / 2::float4;`, @@ -622,7 +622,7 @@ func TestOperators(t *testing.T) { }, { Query: `SELECT 8::int8 / 2::numeric;`, - Expected: []sql.Row{{Numeric("4")}}, + Expected: []sql.Row{{Numeric("4.0000000000000000")}}, }, { Query: `SELECT 8::numeric / 2::float4;`, @@ -634,19 +634,19 @@ func TestOperators(t *testing.T) { }, { Query: `SELECT 8::numeric / 2::int2;`, - Expected: []sql.Row{{Numeric("4")}}, + Expected: []sql.Row{{Numeric("4.0000000000000000")}}, }, { Query: `SELECT 8::numeric / 2::int4;`, - Expected: []sql.Row{{Numeric("4")}}, + Expected: []sql.Row{{Numeric("4.0000000000000000")}}, }, { Query: `SELECT 8::numeric / 2::int8;`, - Expected: []sql.Row{{Numeric("4")}}, + Expected: []sql.Row{{Numeric("4.0000000000000000")}}, }, { Query: `SELECT 8::numeric / 2::numeric;`, - Expected: []sql.Row{{Numeric("4")}}, + Expected: []sql.Row{{Numeric("4.0000000000000000")}}, }, { Query: `select interval '20 days' / 2.3`, @@ -3182,7 +3182,7 @@ func TestOperators(t *testing.T) { Assertions: []ScriptTestAssertion{ { Query: `SELECT '[{"a":"foo"},{"b":"bar"},{"c":"baz"}]'::json -> 2;`, - Expected: []sql.Row{{`{"c": "baz"}`}}, + Expected: []sql.Row{{`{"c":"baz"}`}}, }, { Query: `SELECT '[{"a":"foo"},{"b":"bar"},{"c":"baz"}]'::jsonb -> 2;`, @@ -3190,7 +3190,7 @@ func TestOperators(t *testing.T) { }, { Query: `SELECT '[{"a":"foo"},{"b":"bar"},{"c":"baz"}]'::json -> -3;`, - Expected: []sql.Row{{`{"a": "foo"}`}}, + Expected: []sql.Row{{`{"a":"foo"}`}}, }, { Query: `SELECT '[{"a":"foo"},{"b":"bar"},{"c":"baz"}]'::jsonb -> -3;`, @@ -3198,7 +3198,7 @@ func TestOperators(t *testing.T) { }, { Query: `SELECT '{"a": {"b":"foo"}}'::json -> 'a';`, - Expected: []sql.Row{{`{"b": "foo"}`}}, + Expected: []sql.Row{{`{"b":"foo"}`}}, }, { Query: `SELECT '{"a": {"b":"foo"}}'::jsonb -> 'a';`, diff --git a/testing/go/prepared_statement_test.go b/testing/go/prepared_statement_test.go index 3d253b262e..330a898bf8 100755 --- a/testing/go/prepared_statement_test.go +++ b/testing/go/prepared_statement_test.go @@ -546,7 +546,7 @@ func RunScriptN(t *testing.T, script ScriptTest, n int) { foundRows, err := ReadRows(rows, true) if assertion.ExpectedErr == "" { require.NoError(t, err) - assert.Equal(t, NormalizeRows(assertion.Expected), foundRows) + assert.Equal(t, NormalizeExpectedRow(rows.FieldDescriptions(), assertion.Expected), foundRows) } else if err != nil { errorSeen = err.Error() } diff --git a/testing/go/replication_test.go b/testing/go/replication_test.go index 9baa415599..2ca8ca3332 100755 --- a/testing/go/replication_test.go +++ b/testing/go/replication_test.go @@ -654,7 +654,7 @@ func runReplicationScript( require.NoError(t, err) readRows, err := ReadRows(rows, true) require.NoError(t, err) - normalizedRows := NormalizeRows(assertion.Expected) + normalizedRows := NormalizeExpectedRow(rows.FieldDescriptions(), assertion.Expected) // For queries against the replica, whether or not replication is caught up is a heuristic that can be // incorrect. 
So we retry queries with sleeps in between to give replication a chance to catch up when this diff --git a/testing/go/smoke_test.go b/testing/go/smoke_test.go index 73b1c09f38..528a24eb15 100644 --- a/testing/go/smoke_test.go +++ b/testing/go/smoke_test.go @@ -326,9 +326,17 @@ func TestSmokeTests(t *testing.T) { }, }, { + Skip: true, // TODO: result differs from Postgres + Query: `SELECT '{"\x68656c6c6f", "\x776f726c64", "\x6578616d706c65"}'::bytea[]::text[];`, + Expected: []sql.Row{ + {`{"\\x7836383635366336633666","\\x7837373666373236633634","\\x783635373836313664373036633635"}`}, + }, + }, + { + Skip: true, // TODO: result differs from Postgres Query: `SELECT '{"\\x68656c6c6f", "\\x776f726c64", "\\x6578616d706c65"}'::bytea[]::text[];`, Expected: []sql.Row{ - {`{"\x68656c6c6f","\x776f726c64","\x6578616d706c65"}`}, + {`{"\\x68656c6c6f", "\\x776f726c64", "\\x6578616d706c65"}`}, }, }, { diff --git a/testing/go/ssl_test.go b/testing/go/ssl_test.go index a73047e172..cf5b472ef9 100644 --- a/testing/go/ssl_test.go +++ b/testing/go/ssl_test.go @@ -97,5 +97,5 @@ func TestSSL(t *testing.T) { require.NoError(t, err) readRows, err := ReadRows(rows, true) require.NoError(t, err) - assert.Equal(t, NormalizeRows([]sql.Row{{3645, 37643}}), readRows) + assert.Equal(t, NormalizeExpectedRow(rows.FieldDescriptions(), []sql.Row{{3645, 37643}}), readRows) } diff --git a/testing/go/types_test.go b/testing/go/types_test.go index 1e47ea7711..5ddf4efdca 100644 --- a/testing/go/types_test.go +++ b/testing/go/types_test.go @@ -732,14 +732,25 @@ var typesTests = []ScriptTest{ Name: "JSON type", SetUpScript: []string{ "CREATE TABLE t_json (id INTEGER primary key, v1 JSON);", - "INSERT INTO t_json VALUES (1, '{\"key\": \"value\"}'), (2, '{\"num\":42}');", + `INSERT INTO t_json VALUES (1, '{"key1": {"key": "value"}}'), (2, '{"num":42}'), (3, '{"key1": "value1", "key2": "value2"}'), (4, '{"key1": {"key": [2,3]}}');`, }, Assertions: []ScriptTestAssertion{ + { + Query: "SELECT * FROM t_json ORDER BY 1;", + Expected: []sql.Row{ + {1, `{"key1": {"key": "value"}}`}, + {2, `{"num":42}`}, + {3, `{"key1": "value1", "key2": "value2"}`}, + {4, `{"key1": {"key": [2,3]}}`}, + }, + }, { Query: "SELECT * FROM t_json ORDER BY id;", Expected: []sql.Row{ - {1, `{"key": "value"}`}, + {1, `{"key1": {"key": "value"}}`}, {2, `{"num":42}`}, + {3, `{"key1": "value1", "key2": "value2"}`}, + {4, `{"key1": {"key": [2,3]}}`}, }, }, { @@ -1180,7 +1191,7 @@ var typesTests = []ScriptTest{ { Query: "SELECT v1::smallint, v1::integer, v1::bigint, v1::float4, v1::float8, v1::numeric FROM t_name WHERE id=2;", Expected: []sql.Row{ - {12345, 12345, 12345, float64(12345), float64(12345), float64(12345)}, + {12345, 12345, 12345, float64(12345), float64(12345), Numeric("12345")}, }, }, { @@ -1298,14 +1309,15 @@ var typesTests = []ScriptTest{ Name: "Numeric type", SetUpScript: []string{ "CREATE TABLE t_numeric (id INTEGER primary key, v1 NUMERIC(5,2));", - "INSERT INTO t_numeric VALUES (1, 123.45), (2, 67.89);", + "INSERT INTO t_numeric VALUES (1, 123.45), (2, 67.89), (3, 100.3);", }, Assertions: []ScriptTestAssertion{ { Query: "SELECT * FROM t_numeric ORDER BY id;", Expected: []sql.Row{ - {1, 123.45}, - {2, 67.89}, + {1, Numeric("123.45")}, + {2, Numeric("67.89")}, + {3, Numeric("100.30")}, }, }, }, @@ -1314,14 +1326,15 @@ var typesTests = []ScriptTest{ Name: "Numeric type, no scale or precision", SetUpScript: []string{ "CREATE TABLE t_numeric (id INTEGER primary key, v1 NUMERIC);", - "INSERT INTO t_numeric VALUES (1, 123.45), (2, 67.875);", + "INSERT INTO 
t_numeric VALUES (1, 123.45), (2, 67.875), (3, 100.3);", }, Assertions: []ScriptTestAssertion{ { Query: "SELECT * FROM t_numeric ORDER BY id;", Expected: []sql.Row{ - {1, 123.45}, - {2, 67.875}, + {1, Numeric("123.45")}, + {2, Numeric("67.875")}, + {3, Numeric("100.3")}, }, }, }, @@ -2577,14 +2590,15 @@ func TestSameTypes(t *testing.T) { Name: "Arbitrary precision types", SetUpScript: []string{ "CREATE TABLE test (v1 DECIMAL(10, 1), v2 NUMERIC(11, 2));", - "INSERT INTO test VALUES (14854.5, 2504.25), (566821525.5, 735134574.75);", + "INSERT INTO test VALUES (14854.5, 2504.25), (566821525.5, 735134574.75), (21525, 134574.7);", }, Assertions: []ScriptTestAssertion{ { Query: "SELECT * FROM test ORDER BY 1;", Expected: []sql.Row{ - {14854.5, 2504.25}, - {566821525.5, 735134574.75}, + {Numeric("14854.5"), Numeric("2504.25")}, + {Numeric("21525.0"), Numeric("134574.70")}, + {Numeric("566821525.5"), Numeric("735134574.75")}, }, }, }, @@ -2648,30 +2662,5 @@ func TestSameTypes(t *testing.T) { }, }, }, - { - Name: "JSON type", - SetUpScript: []string{ - "CREATE TABLE test (v1 INT, v2 JSON);", - `INSERT INTO test VALUES (1, '{"key1": {"key": "value"}}'), (2, '{"key1": "value1", "key2": "value2"}'), (3, '{"key1": {"key": [2,3]}}');`, - }, - Assertions: []ScriptTestAssertion{ - { - Query: "SELECT * FROM test ORDER BY 1;", - Expected: []sql.Row{ - {1, `{"key1": {"key": "value"}}`}, - {2, `{"key1": "value1", "key2": "value2"}`}, - {3, `{"key1": {"key": [2,3]}}`}, - }, - }, - { - Query: "SELECT * FROM test ORDER BY v1;", - Expected: []sql.Row{ - {1, `{"key1": {"key": "value"}}`}, - {2, `{"key1": "value1", "key2": "value2"}`}, - {3, `{"key1": {"key": [2,3]}}`}, - }, - }, - }, - }, }) } diff --git a/testing/postgres-client-tests/node/fields.js b/testing/postgres-client-tests/node/fields.js index 94b633daf4..71dc65e8d4 100644 --- a/testing/postgres-client-tests/node/fields.js +++ b/testing/postgres-client-tests/node/fields.js @@ -4,7 +4,7 @@ export const countFields = [ tableID: 0, columnID: 0, dataTypeID: 20, - dataTypeSize: 8, + dataTypeSize: 20, dataTypeModifier: -1, format: "text", }, @@ -15,7 +15,7 @@ export const doltAddFields = [ name: "dolt_add", tableID: 0, columnID: 0, - dataTypeID: 25, + dataTypeID: 1009, dataTypeSize: -1, dataTypeModifier: -1, format: "text", @@ -27,7 +27,7 @@ export const doltBranchFields = [ name: "dolt_branch", tableID: 0, columnID: 0, - dataTypeID: 25, + dataTypeID: 1009, dataTypeSize: -1, dataTypeModifier: -1, format: "text", @@ -39,7 +39,7 @@ export const doltCheckoutFields = [ name: "dolt_checkout", tableID: 0, columnID: 0, - dataTypeID: 25, + dataTypeID: 1009, dataTypeSize: -1, dataTypeModifier: -1, format: "text", @@ -51,7 +51,7 @@ export const doltCommitFields = [ name: "dolt_commit", tableID: 0, columnID: 0, - dataTypeID: 25, + dataTypeID: 1009, dataTypeSize: -1, dataTypeModifier: -1, format: "text", @@ -64,7 +64,7 @@ export const doltStatusFields = [ tableID: 0, columnID: 0, dataTypeID: 25, - dataTypeSize: -1, + dataTypeSize: -4, dataTypeModifier: -1, format: "text", }, @@ -73,7 +73,7 @@ export const doltStatusFields = [ tableID: 0, columnID: 0, dataTypeID: 21, - dataTypeSize: 1, + dataTypeSize: 4, dataTypeModifier: -1, format: "text", }, @@ -82,7 +82,7 @@ export const doltStatusFields = [ tableID: 0, columnID: 0, dataTypeID: 25, - dataTypeSize: -1, + dataTypeSize: -4, dataTypeModifier: -1, format: "text", }, diff --git a/testing/postgres-client-tests/node/helpers.js b/testing/postgres-client-tests/node/helpers.js index a88d591e3c..1183c8de08 100644 --- 
a/testing/postgres-client-tests/node/helpers.js +++ b/testing/postgres-client-tests/node/helpers.js @@ -29,14 +29,13 @@ export function assertQueryResult(q, expected, data, matcher) { } if (q.toLowerCase().includes("dolt_commit")) { if (data.rows.length !== 1) return false; - const hash = data.rows[0].dolt_commit.slice(1, -1); - // dolt_commit row returns 32 character hash enclosed in brackets. + const hash = data.rows[0].dolt_commit[0]; + // dolt_commit row returns 32 character hash return hash.length === 32; } if (q.toLowerCase().includes("dolt_merge")) { if (data.rows.length !== 1) return false; - const res = data.rows[0].dolt_merge.slice(1, -1).split(","); - const [hash, fastForward, conflicts, message] = res; + const [hash, fastForward, conflicts, message] = data.rows[0].dolt_merge; return ( hash.length === 32 && expected.fastForward === fastForward && diff --git a/testing/postgres-client-tests/node/index.js b/testing/postgres-client-tests/node/index.js index 35fdfc4abd..2e8dbfd224 100644 --- a/testing/postgres-client-tests/node/index.js +++ b/testing/postgres-client-tests/node/index.js @@ -92,7 +92,7 @@ const tests = [ command: "SELECT", rowCount: 1, oid: null, - rows: [{ dolt_add: "{0}" }], + rows: [{ dolt_add: ["0"] }], fields: doltAddFields, }, }, @@ -102,7 +102,7 @@ const tests = [ command: "SELECT", rowCount: 1, oid: null, - rows: [{ dolt_commit: "" }], + rows: [{ dolt_commit: [""] }], fields: doltCommitFields, }, }, @@ -122,7 +122,7 @@ const tests = [ command: "SELECT", rowCount: 1, oid: null, - rows: [{ dolt_checkout: `{0,"Switched to branch 'mybranch'"}` }], + rows: [{ dolt_checkout: ["0","Switched to branch 'mybranch'"] }], fields: doltCheckoutFields, }, }, @@ -139,11 +139,21 @@ const tests = [ { q: "select dolt_commit('-a', '-m', 'my commit2')", res: { - command: "CALL", - rowCount: null, + command: "SELECT", + rowCount: 1, oid: null, - rows: [], - fields: [], + rows: [{ dolt_commit: [""] }], + fields: [ + { + name: "dolt_commit", + tableID: 0, + columnID: 0, + dataTypeID: 1009, + dataTypeSize: -1, + dataTypeModifier: -1, + format: "text", + }, + ], }, }, { @@ -152,7 +162,7 @@ const tests = [ command: "SELECT", rowCount: 1, oid: null, - rows: [{ dolt_checkout: `{0,"Switched to branch 'main'"}` }], + rows: [{ dolt_checkout: ["0","Switched to branch 'main'"] }], fields: doltCheckoutFields, }, }, @@ -161,7 +171,7 @@ const tests = [ res: { fastForward: "1", conflicts: "0", - message: `"merge successful"`, + message: "merge successful", }, }, { diff --git a/testing/postgres-client-tests/node/package-lock.json b/testing/postgres-client-tests/node/package-lock.json index 5174f06e2c..70ce76096a 100644 --- a/testing/postgres-client-tests/node/package-lock.json +++ b/testing/postgres-client-tests/node/package-lock.json @@ -10,7 +10,7 @@ "license": "ISC", "dependencies": { "knex": "^2.5.1", - "pg": "^8.11.5", + "pg": "^8.12.0", "pg-promise": "^11.6.0", "wtfnode": "^0.9.2" } @@ -182,9 +182,9 @@ "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==" }, "node_modules/pg": { - "version": "8.11.5", - "resolved": "https://registry.npmjs.org/pg/-/pg-8.11.5.tgz", - "integrity": "sha512-jqgNHSKL5cbDjFlHyYsCXmQDrfIX/3RsNwYqpd4N0Kt8niLuNoRNH+aazv6cOd43gPh9Y4DjQCtb+X0MH0Hvnw==", + "version": "8.12.0", + "resolved": "https://registry.npmjs.org/pg/-/pg-8.12.0.tgz", + "integrity": "sha512-A+LHUSnwnxrnL/tZ+OLfqR1SxLN3c/pgDztZ47Rpbsd4jUytsTtwQo/TLPRzPJMp/1pbhYVhH9cuSZLAajNfjQ==", "dependencies": { "pg-connection-string": "^2.6.4", "pg-pool": 
"^3.6.2", @@ -256,6 +256,37 @@ "node": ">=14.0" } }, + "node_modules/pg-promise/node_modules/pg": { + "version": "8.11.5", + "resolved": "https://registry.npmjs.org/pg/-/pg-8.11.5.tgz", + "integrity": "sha512-jqgNHSKL5cbDjFlHyYsCXmQDrfIX/3RsNwYqpd4N0Kt8niLuNoRNH+aazv6cOd43gPh9Y4DjQCtb+X0MH0Hvnw==", + "dependencies": { + "pg-connection-string": "^2.6.4", + "pg-pool": "^3.6.2", + "pg-protocol": "^1.6.1", + "pg-types": "^2.1.0", + "pgpass": "1.x" + }, + "engines": { + "node": ">= 8.0.0" + }, + "optionalDependencies": { + "pg-cloudflare": "^1.1.1" + }, + "peerDependencies": { + "pg-native": ">=3.0.1" + }, + "peerDependenciesMeta": { + "pg-native": { + "optional": true + } + } + }, + "node_modules/pg-promise/node_modules/pg-connection-string": { + "version": "2.6.4", + "resolved": "https://registry.npmjs.org/pg-connection-string/-/pg-connection-string-2.6.4.tgz", + "integrity": "sha512-v+Z7W/0EO707aNMaAEfiGnGL9sxxumwLl2fJvCQtMn9Fxsg+lPpPkdcyBSv/KFgpGdYkMfn+EI1Or2EHjpgLCA==" + }, "node_modules/pg-protocol": { "version": "1.6.1", "resolved": "https://registry.npmjs.org/pg-protocol/-/pg-protocol-1.6.1.tgz", @@ -526,9 +557,9 @@ "integrity": "sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==" }, "pg": { - "version": "8.11.5", - "resolved": "https://registry.npmjs.org/pg/-/pg-8.11.5.tgz", - "integrity": "sha512-jqgNHSKL5cbDjFlHyYsCXmQDrfIX/3RsNwYqpd4N0Kt8niLuNoRNH+aazv6cOd43gPh9Y4DjQCtb+X0MH0Hvnw==", + "version": "8.12.0", + "resolved": "https://registry.npmjs.org/pg/-/pg-8.12.0.tgz", + "integrity": "sha512-A+LHUSnwnxrnL/tZ+OLfqR1SxLN3c/pgDztZ47Rpbsd4jUytsTtwQo/TLPRzPJMp/1pbhYVhH9cuSZLAajNfjQ==", "requires": { "pg-cloudflare": "^1.1.1", "pg-connection-string": "^2.6.4", @@ -581,6 +612,26 @@ "pg": "8.11.5", "pg-minify": "1.6.3", "spex": "3.3.0" + }, + "dependencies": { + "pg": { + "version": "8.11.5", + "resolved": "https://registry.npmjs.org/pg/-/pg-8.11.5.tgz", + "integrity": "sha512-jqgNHSKL5cbDjFlHyYsCXmQDrfIX/3RsNwYqpd4N0Kt8niLuNoRNH+aazv6cOd43gPh9Y4DjQCtb+X0MH0Hvnw==", + "requires": { + "pg-cloudflare": "^1.1.1", + "pg-connection-string": "^2.6.4", + "pg-pool": "^3.6.2", + "pg-protocol": "^1.6.1", + "pg-types": "^2.1.0", + "pgpass": "1.x" + } + }, + "pg-connection-string": { + "version": "2.6.4", + "resolved": "https://registry.npmjs.org/pg-connection-string/-/pg-connection-string-2.6.4.tgz", + "integrity": "sha512-v+Z7W/0EO707aNMaAEfiGnGL9sxxumwLl2fJvCQtMn9Fxsg+lPpPkdcyBSv/KFgpGdYkMfn+EI1Or2EHjpgLCA==" + } } }, "pg-protocol": { diff --git a/testing/postgres-client-tests/node/package.json b/testing/postgres-client-tests/node/package.json index fc2ac1d935..154266be71 100644 --- a/testing/postgres-client-tests/node/package.json +++ b/testing/postgres-client-tests/node/package.json @@ -11,7 +11,7 @@ "license": "ISC", "dependencies": { "knex": "^2.5.1", - "pg": "^8.11.5", + "pg": "^8.12.0", "pg-promise": "^11.6.0", "wtfnode": "^0.9.2" } diff --git a/testing/postgres-client-tests/node/workbenchTests/branches.js b/testing/postgres-client-tests/node/workbenchTests/branches.js index 94df90b6af..854e5004b7 100644 --- a/testing/postgres-client-tests/node/workbenchTests/branches.js +++ b/testing/postgres-client-tests/node/workbenchTests/branches.js @@ -16,7 +16,7 @@ export const branchTests = [ command: "SELECT", rowCount: 1, oid: null, - rows: [{ dolt_branch: "{0}" }], + rows: [{ dolt_branch: ["0"] }], fields: doltBranchFields, }, }, @@ -96,7 +96,7 @@ export const branchTests = [ command: "SELECT", rowCount: 1, oid: null, - rows: [{ dolt_checkout: 
`{0,"Switched to branch 'branch-to-delete'"}` }], + rows: [{ dolt_checkout: ["0","Switched to branch 'branch-to-delete'"] }], fields: doltCheckoutFields, }, }, @@ -117,7 +117,7 @@ export const branchTests = [ command: "SELECT", rowCount: 1, oid: null, - rows: [{ dolt_checkout: `{0,"Switched to branch 'main'"}` }], + rows: [{ dolt_checkout: ["0","Switched to branch 'main'"] }], fields: doltCheckoutFields, }, }, @@ -128,7 +128,7 @@ export const branchTests = [ command: "SELECT", rowCount: 1, oid: null, - rows: [{ dolt_branch: "{0}" }], + rows: [{ dolt_branch: ["0"] }], fields: doltBranchFields, }, }, diff --git a/testing/postgres-client-tests/node/workbenchTests/databases.js b/testing/postgres-client-tests/node/workbenchTests/databases.js index 23a0a9f755..78f30f441c 100644 --- a/testing/postgres-client-tests/node/workbenchTests/databases.js +++ b/testing/postgres-client-tests/node/workbenchTests/databases.js @@ -23,8 +23,8 @@ export const databaseTests = [ name: "datname", tableID: 0, columnID: 0, - dataTypeID: 25, - dataTypeSize: -1, + dataTypeID: 19, + dataTypeSize: 252, dataTypeModifier: -1, format: "text", }, @@ -57,8 +57,8 @@ export const databaseTests = [ name: "datname", tableID: 0, columnID: 0, - dataTypeID: 25, - dataTypeSize: -1, + dataTypeID: 19, + dataTypeSize: 252, dataTypeModifier: -1, format: "text", },