-
Notifications
You must be signed in to change notification settings - Fork 846
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fixes #338
- Loading branch information
Showing
5 changed files
with
446 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
package pgtype | ||
|
||
import ( | ||
"database/sql/driver" | ||
|
||
"github.com/pkg/errors" | ||
) | ||
|
||
type EnumArray struct { | ||
Elements []GenericText | ||
Dimensions []ArrayDimension | ||
Status Status | ||
} | ||
|
||
func (dst *EnumArray) Set(src interface{}) error { | ||
// untyped nil and typed nil interfaces are different | ||
if src == nil { | ||
*dst = EnumArray{Status: Null} | ||
return nil | ||
} | ||
|
||
switch value := src.(type) { | ||
|
||
case []string: | ||
if value == nil { | ||
*dst = EnumArray{Status: Null} | ||
} else if len(value) == 0 { | ||
*dst = EnumArray{Status: Present} | ||
} else { | ||
elements := make([]GenericText, len(value)) | ||
for i := range value { | ||
if err := elements[i].Set(value[i]); err != nil { | ||
return err | ||
} | ||
} | ||
*dst = EnumArray{ | ||
Elements: elements, | ||
Dimensions: []ArrayDimension{{Length: int32(len(elements)), LowerBound: 1}}, | ||
Status: Present, | ||
} | ||
} | ||
|
||
default: | ||
if originalSrc, ok := underlyingSliceType(src); ok { | ||
return dst.Set(originalSrc) | ||
} | ||
return errors.Errorf("cannot convert %v to EnumArray", value) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (dst *EnumArray) Get() interface{} { | ||
switch dst.Status { | ||
case Present: | ||
return dst | ||
case Null: | ||
return nil | ||
default: | ||
return dst.Status | ||
} | ||
} | ||
|
||
func (src *EnumArray) AssignTo(dst interface{}) error { | ||
switch src.Status { | ||
case Present: | ||
switch v := dst.(type) { | ||
|
||
case *[]string: | ||
*v = make([]string, len(src.Elements)) | ||
for i := range src.Elements { | ||
if err := src.Elements[i].AssignTo(&((*v)[i])); err != nil { | ||
return err | ||
} | ||
} | ||
return nil | ||
|
||
default: | ||
if nextDst, retry := GetAssignToDstType(dst); retry { | ||
return src.AssignTo(nextDst) | ||
} | ||
} | ||
case Null: | ||
return NullAssignTo(dst) | ||
} | ||
|
||
return errors.Errorf("cannot decode %v into %T", src, dst) | ||
} | ||
|
||
func (dst *EnumArray) DecodeText(ci *ConnInfo, src []byte) error { | ||
if src == nil { | ||
*dst = EnumArray{Status: Null} | ||
return nil | ||
} | ||
|
||
uta, err := ParseUntypedTextArray(string(src)) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
var elements []GenericText | ||
|
||
if len(uta.Elements) > 0 { | ||
elements = make([]GenericText, len(uta.Elements)) | ||
|
||
for i, s := range uta.Elements { | ||
var elem GenericText | ||
var elemSrc []byte | ||
if s != "NULL" { | ||
elemSrc = []byte(s) | ||
} | ||
err = elem.DecodeText(ci, elemSrc) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
elements[i] = elem | ||
} | ||
} | ||
|
||
*dst = EnumArray{Elements: elements, Dimensions: uta.Dimensions, Status: Present} | ||
|
||
return nil | ||
} | ||
|
||
func (src *EnumArray) EncodeText(ci *ConnInfo, buf []byte) ([]byte, error) { | ||
switch src.Status { | ||
case Null: | ||
return nil, nil | ||
case Undefined: | ||
return nil, errUndefined | ||
} | ||
|
||
if len(src.Dimensions) == 0 { | ||
return append(buf, '{', '}'), nil | ||
} | ||
|
||
buf = EncodeTextArrayDimensions(buf, src.Dimensions) | ||
|
||
// dimElemCounts is the multiples of elements that each array lies on. For | ||
// example, a single dimension array of length 4 would have a dimElemCounts of | ||
// [4]. A multi-dimensional array of lengths [3,5,2] would have a | ||
// dimElemCounts of [30,10,2]. This is used to simplify when to render a '{' | ||
// or '}'. | ||
dimElemCounts := make([]int, len(src.Dimensions)) | ||
dimElemCounts[len(src.Dimensions)-1] = int(src.Dimensions[len(src.Dimensions)-1].Length) | ||
for i := len(src.Dimensions) - 2; i > -1; i-- { | ||
dimElemCounts[i] = int(src.Dimensions[i].Length) * dimElemCounts[i+1] | ||
} | ||
|
||
inElemBuf := make([]byte, 0, 32) | ||
for i, elem := range src.Elements { | ||
if i > 0 { | ||
buf = append(buf, ',') | ||
} | ||
|
||
for _, dec := range dimElemCounts { | ||
if i%dec == 0 { | ||
buf = append(buf, '{') | ||
} | ||
} | ||
|
||
elemBuf, err := elem.EncodeText(ci, inElemBuf) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if elemBuf == nil { | ||
buf = append(buf, `NULL`...) | ||
} else { | ||
buf = append(buf, QuoteArrayElementIfNeeded(string(elemBuf))...) | ||
} | ||
|
||
for _, dec := range dimElemCounts { | ||
if (i+1)%dec == 0 { | ||
buf = append(buf, '}') | ||
} | ||
} | ||
} | ||
|
||
return buf, nil | ||
} | ||
|
||
// Scan implements the database/sql Scanner interface. | ||
func (dst *EnumArray) Scan(src interface{}) error { | ||
if src == nil { | ||
return dst.DecodeText(nil, nil) | ||
} | ||
|
||
switch src := src.(type) { | ||
case string: | ||
return dst.DecodeText(nil, []byte(src)) | ||
case []byte: | ||
srcCopy := make([]byte, len(src)) | ||
copy(srcCopy, src) | ||
return dst.DecodeText(nil, srcCopy) | ||
} | ||
|
||
return errors.Errorf("cannot scan %T", src) | ||
} | ||
|
||
// Value implements the database/sql/driver Valuer interface. | ||
func (src *EnumArray) Value() (driver.Value, error) { | ||
buf, err := src.EncodeText(nil, nil) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if buf == nil { | ||
return nil, nil | ||
} | ||
|
||
return string(buf), nil | ||
} |
Oops, something went wrong.