Skip to content

Commit

Permalink
Add support for array of enum
Browse files Browse the repository at this point in the history
fixes #338
  • Loading branch information
jackc committed Oct 18, 2017
1 parent ac5d463 commit ab9a1af
Show file tree
Hide file tree
Showing 5 changed files with 446 additions and 0 deletions.
39 changes: 39 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,45 @@ where (

c.ConnInfo = pgtype.NewConnInfo()
c.ConnInfo.InitializeDataTypes(nameOIDs)

return c.initConnInfoEnumArray()
}

// initConnInfoEnumArray introspects for arrays of enums and registers a data type for them.
func (c *Conn) initConnInfoEnumArray() error {
nameOIDs := make(map[string]pgtype.OID, 16)

rows, err := c.Query(`select t.oid, t.typname
from pg_type t
join pg_type base_type on t.typelem=base_type.oid
where t.typtype = 'b'
and base_type.typtype = 'e'`)
if err != nil {
return err
}

for rows.Next() {
var oid pgtype.OID
var name pgtype.Text
if err := rows.Scan(&oid, &name); err != nil {
return err
}

nameOIDs[name.String] = oid
}

if rows.Err() != nil {
return rows.Err()
}

for name, oid := range nameOIDs {
c.ConnInfo.RegisterDataType(pgtype.DataType{
&pgtype.EnumArray{},
name,
oid,
})
}

return nil
}

Expand Down
41 changes: 41 additions & 0 deletions pgmock/pgmock.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,47 @@ where (
steps = append(steps, SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT 163"}))
steps = append(steps, SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}))

steps = append(steps, []Step{
ExpectMessage(&pgproto3.Parse{
Query: "select t.oid, t.typname\nfrom pg_type t\n join pg_type base_type on t.typelem=base_type.oid\nwhere t.typtype = 'b'\n and base_type.typtype = 'e'",
}),
ExpectMessage(&pgproto3.Describe{
ObjectType: 'S',
}),
ExpectMessage(&pgproto3.Sync{}),
SendMessage(&pgproto3.ParseComplete{}),
SendMessage(&pgproto3.ParameterDescription{}),
SendMessage(&pgproto3.RowDescription{
Fields: []pgproto3.FieldDescription{
{Name: "oid",
TableOID: 1247,
TableAttributeNumber: 65534,
DataTypeOID: 26,
DataTypeSize: 4,
TypeModifier: 4294967295,
Format: 0,
},
{Name: "typname",
TableOID: 1247,
TableAttributeNumber: 1,
DataTypeOID: 19,
DataTypeSize: 64,
TypeModifier: 4294967295,
Format: 0,
},
},
}),
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
ExpectMessage(&pgproto3.Bind{
ResultFormatCodes: []int16{1, 1},
}),
ExpectMessage(&pgproto3.Execute{}),
ExpectMessage(&pgproto3.Sync{}),
SendMessage(&pgproto3.BindComplete{}),
SendMessage(&pgproto3.CommandComplete{CommandTag: "SELECT 0"}),
SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
}...)

return steps
}

Expand Down
212 changes: 212 additions & 0 deletions pgtype/enum_array.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
package pgtype

import (
"database/sql/driver"

"github.com/pkg/errors"
)

type EnumArray struct {
Elements []GenericText
Dimensions []ArrayDimension
Status Status
}

func (dst *EnumArray) Set(src interface{}) error {
// untyped nil and typed nil interfaces are different
if src == nil {
*dst = EnumArray{Status: Null}
return nil
}

switch value := src.(type) {

case []string:
if value == nil {
*dst = EnumArray{Status: Null}
} else if len(value) == 0 {
*dst = EnumArray{Status: Present}
} else {
elements := make([]GenericText, len(value))
for i := range value {
if err := elements[i].Set(value[i]); err != nil {
return err
}
}
*dst = EnumArray{
Elements: elements,
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}},
Status: Present,
}
}

default:
if originalSrc, ok := underlyingSliceType(src); ok {
return dst.Set(originalSrc)
}
return errors.Errorf("cannot convert %v to EnumArray", value)
}

return nil
}

func (dst *EnumArray) Get() interface{} {
switch dst.Status {
case Present:
return dst
case Null:
return nil
default:
return dst.Status
}
}

func (src *EnumArray) AssignTo(dst interface{}) error {
switch src.Status {
case Present:
switch v := dst.(type) {

case *[]string:
*v = make([]string, len(src.Elements))
for i := range src.Elements {
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil {
return err
}
}
return nil

default:
if nextDst, retry := GetAssignToDstType(dst); retry {
return src.AssignTo(nextDst)
}
}
case Null:
return NullAssignTo(dst)
}

return errors.Errorf("cannot decode %v into %T", src, dst)
}

func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = EnumArray{Status: Null}
return nil
}

uta, err := ParseUntypedTextArray(string(src))
if err != nil {
return err
}

var elements []GenericText

if len(uta.Elements) > 0 {
elements = make([]GenericText, len(uta.Elements))

for i, s := range uta.Elements {
var elem GenericText
var elemSrc []byte
if s != "NULL" {
elemSrc = []byte(s)
}
err = elem.DecodeText(ci, elemSrc)
if err != nil {
return err
}

elements[i] = elem
}
}

*dst = EnumArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present}

return nil
}

func (src *EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) {
switch src.Status {
case Null:
return nil, nil
case Undefined:
return nil, errUndefined
}

if len(src.Dimensions) == 0 {
return append(buf, '{', '}'), nil
}

buf = EncodeTextArrayDimensions(buf, src.Dimensions)

// dimElemCounts is the multiples of elements that each array lies on. For
// example, a single dimension array of length 4 would have a dimElemCounts of
// [4]. A multi-dimensional array of lengths [3,5,2] would have a
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{'
// or '}'.
dimElemCounts := make([]int, len(src.Dimensions))
dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length)
for i := len(src.Dimensions) - 2; i > -1; i-- {
dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1]
}

inElemBuf := make([]byte, 0, 32)
for i, elem := range src.Elements {
if i > 0 {
buf = append(buf, ',')
}

for _, dec := range dimElemCounts {
if i%dec == 0 {
buf = append(buf, '{')
}
}

elemBuf, err := elem.EncodeText(ci, inElemBuf)
if err != nil {
return nil, err
}
if elemBuf == nil {
buf = append(buf, `NULL`...)
} else {
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...)
}

for _, dec := range dimElemCounts {
if (i+1)%dec == 0 {
buf = append(buf, '}')
}
}
}

return buf, nil
}

// Scan implements the database/sql Scanner interface.
func (dst *EnumArray) Scan(src interface{}) error {
if src == nil {
return dst.DecodeText(nil, nil)
}

switch src := src.(type) {
case string:
return dst.DecodeText(nil, []byte(src))
case []byte:
srcCopy := make([]byte, len(src))
copy(srcCopy, src)
return dst.DecodeText(nil, srcCopy)
}

return errors.Errorf("cannot scan %T", src)
}

// Value implements the database/sql/driver Valuer interface.
func (src *EnumArray) Value() (driver.Value, error) {
buf, err := src.EncodeText(nil, nil)
if err != nil {
return nil, err
}
if buf == nil {
return nil, nil
}

return string(buf), nil
}
Loading

0 comments on commit ab9a1af

Please sign in to comment.