Skip to content

Commit

Permalink
add ltree pgtype support
Browse files Browse the repository at this point in the history
  • Loading branch information
luxifer authored and jackc committed Jan 26, 2024
1 parent c90f82a commit 0fa5333
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 0 deletions.
122 changes: 122 additions & 0 deletions pgtype/ltree.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package pgtype

import (
"database/sql/driver"
"fmt"
)

type LtreeCodec struct{}

func (l LtreeCodec) FormatSupported(format int16) bool {
return format == TextFormatCode || format == BinaryFormatCode
}

// PreferredFormat returns the preferred format.
func (l LtreeCodec) PreferredFormat() int16 {
return TextFormatCode
}

// PlanEncode returns an EncodePlan for encoding value into PostgreSQL format for oid and format. If no plan can be
// found then nil is returned.
func (l LtreeCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan {
switch format {
case TextFormatCode:
return (TextCodec)(l).PlanEncode(m, oid, format, value)
case BinaryFormatCode:
switch value.(type) {
case string:
return encodeLtreeCodecBinaryString{}
case []byte:
return encodeLtreeCodecBinaryByteSlice{}
case TextValuer:
return encodeLtreeCodecBinaryTextValuer{}
}
}

return nil
}

type encodeLtreeCodecBinaryString struct{}

func (encodeLtreeCodecBinaryString) Encode(value any, buf []byte) (newBuf []byte, err error) {
ltree := value.(string)
buf = append(buf, 1)
return append(buf, ltree...), nil
}

type encodeLtreeCodecBinaryByteSlice struct{}

func (encodeLtreeCodecBinaryByteSlice) Encode(value any, buf []byte) (newBuf []byte, err error) {
ltree := value.([]byte)
buf = append(buf, 1)
return append(buf, ltree...), nil
}

type encodeLtreeCodecBinaryTextValuer struct{}

func (encodeLtreeCodecBinaryTextValuer) Encode(value any, buf []byte) (newBuf []byte, err error) {
t, err := value.(TextValuer).TextValue()
if err != nil {
return nil, err
}
if !t.Valid {
return nil, nil
}

buf = append(buf, 1)
return append(buf, t.String...), nil
}

// PlanScan returns a ScanPlan for scanning a PostgreSQL value into a destination with the same type as target. If
// no plan can be found then nil is returned.
func (l LtreeCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan {
switch format {
case TextFormatCode:
return (TextCodec)(l).PlanScan(m, oid, format, target)
case BinaryFormatCode:
switch target.(type) {
case *string:
return scanPlanBinaryLtreeToString{}
case TextScanner:
return scanPlanBinaryLtreeToTextScanner{}
}
}

return nil
}

type scanPlanBinaryLtreeToString struct{}

func (scanPlanBinaryLtreeToString) Scan(src []byte, target any) error {
version := src[0]
if version != 1 {
return fmt.Errorf("unsupported ltree version %d", version)
}

p := (target).(*string)
*p = string(src[1:])

return nil
}

type scanPlanBinaryLtreeToTextScanner struct{}

func (scanPlanBinaryLtreeToTextScanner) Scan(src []byte, target any) error {
version := src[0]
if version != 1 {
return fmt.Errorf("unsupported ltree version %d", version)
}

scanner := (target).(TextScanner)
return scanner.ScanText(Text{String: string(src[1:]), Valid: true})
}

// DecodeDatabaseSQLValue returns src decoded into a value compatible with the sql.Scanner interface.
func (l LtreeCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) {
return (TextCodec)(l).DecodeDatabaseSQLValue(m, oid, format, src)
}

// DecodeValue returns src decoded into its default format.
func (l LtreeCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) {
return (TextCodec)(l).DecodeValue(m, oid, format, src)
}
26 changes: 26 additions & 0 deletions pgtype/ltree_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package pgtype_test

import (
"context"
"testing"

"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/pgxtest"
)

func TestLtreeCodec(t *testing.T) {
skipCockroachDB(t, "Server does not support type ltree")

pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, pgxtest.KnownOIDQueryExecModes, "ltree", []pgxtest.ValueRoundTripTest{
{
Param: "A.B.C",
Result: new(string),
Test: isExpectedEq("A.B.C"),
},
{
Param: pgtype.Text{String: "", Valid: true},
Result: new(pgtype.Text),
Test: isExpectedEq(pgtype.Text{String: "", Valid: true}),
},
})
}

0 comments on commit 0fa5333

Please sign in to comment.