From ec0faf28c1487c87c0d5434c8fdf0695ab3a62da Mon Sep 17 00:00:00 2001 From: Benjamin Ramser Date: Fri, 14 Jan 2022 15:31:26 -0700 Subject: [PATCH] Refactor providers/postgis to use pgx v4 * refactored providers/postgis to use the pgx4 client. Support for Postgres versions > 12 is now possible. * provider/postgis: Properly wrap errors in messages by moving from using %v -> %w when returning errors in messages. * Added error string check for context.Canceled. The underlying net.Dial function is not properly reporting context.Cancel errors. Becuase of this, a string check on the error is performed. There's an open issue for this and it appears it will be fixed eventually but for now we have this check to avoid unnecessary logs. Related issue: https://github.com/golang/go/issues/36208 * added ctxErr() check thewill check if the supplied context has an error (i.e. context canceled) and if so, return that error, else return the supplied error. This is useful as not all of Go's stdlib has adopted error wrapping so context.Canceled errors are not always easy to capture. closes #748 --- atlas/map.go | 8 + go.mod | 1 - provider/postgis/postgis.go | 256 +++++++++++++++---------- provider/postgis/postgis_test.go | 94 ++++----- provider/postgis/util.go | 46 +++-- provider/postgis/util_internal_test.go | 24 ++- server/handle_map_layer_zxy.go | 10 +- 7 files changed, 255 insertions(+), 184 deletions(-) diff --git a/atlas/map.go b/atlas/map.go index c3441225a..eb1e8acc6 100644 --- a/atlas/map.go +++ b/atlas/map.go @@ -323,6 +323,14 @@ func (m Map) encodeMVTTile(ctx context.Context, tile *slippy.Tile) ([]byte, erro case errors.Is(err, context.Canceled): // Do nothing if we were cancelled. + // the underlying net.Dial function is not properly reporting + // context.Canceled errors. Because of this, a string check on the error is performed. + // there's an open issue for this and it appears it will be fixed eventually + // but for now we have this check to avoid unnecessary logs + // https://github.com/golang/go/issues/36208 + case strings.Contains(err.Error(), "operation was canceled"): + // Do nothing, context was canceled + default: z, x, y := tile.ZXY() // TODO (arolek): should we return an error to the response or just log the error? diff --git a/go.mod b/go.mod index 5fb5294c7..54af48e08 100644 --- a/go.mod +++ b/go.mod @@ -18,7 +18,6 @@ require ( github.com/go-test/deep v0.0.0-20170429201529-f49763a6ea0a github.com/gofrs/uuid v4.0.0+incompatible // indirect github.com/golang/protobuf v1.4.3 - github.com/jackc/pgproto3 v1.1.0 github.com/jackc/pgproto3/v2 v2.2.0 github.com/jackc/pgtype v1.9.1 github.com/jackc/pgx/v4 v4.14.1 diff --git a/provider/postgis/postgis.go b/provider/postgis/postgis.go index 3a3e6584b..3d809c0a7 100644 --- a/provider/postgis/postgis.go +++ b/provider/postgis/postgis.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" "fmt" "io/ioutil" "log" @@ -12,24 +13,24 @@ import ( "strings" "time" - "github.com/go-spatial/tegola/observability" - - "github.com/prometheus/client_golang/prometheus" - - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" - "github.com/go-spatial/geom" "github.com/go-spatial/geom/encoding/wkb" "github.com/go-spatial/tegola" "github.com/go-spatial/tegola/dict" + "github.com/go-spatial/tegola/observability" "github.com/go-spatial/tegola/provider" + "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgtype" + gofrs "github.com/jackc/pgtype/ext/gofrs-uuid" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/pgxpool" + "github.com/prometheus/client_golang/prometheus" ) const Name = "postgis" type connectionPoolCollector struct { - *pgx.ConnPool + *pgxpool.Pool maxConnectionDesc *prometheus.Desc currentConnectionsDesc *prometheus.Desc availableConnectionsDesc *prometheus.Desc @@ -40,24 +41,24 @@ func (c connectionPoolCollector) Describe(ch chan<- *prometheus.Desc) { } func (c connectionPoolCollector) Collect(ch chan<- prometheus.Metric) { - if c.ConnPool == nil { + if c.Pool == nil { return } - stat := c.ConnPool.Stat() + stat := c.Pool.Stat() ch <- prometheus.MustNewConstMetric( c.maxConnectionDesc, prometheus.GaugeValue, - float64(stat.MaxConnections), + float64(stat.MaxConns()), ) ch <- prometheus.MustNewConstMetric( c.currentConnectionsDesc, prometheus.GaugeValue, - float64(stat.CurrentConnections), + float64(stat.AcquiredConns()), ) ch <- prometheus.MustNewConstMetric( c.availableConnectionsDesc, prometheus.GaugeValue, - float64(stat.AvailableConnections), + float64(stat.MaxConns()-stat.AcquiredConns()), ) } @@ -94,7 +95,7 @@ func (c *connectionPoolCollector) Collectors(prefix string, _ func(configKey str // Provider provides the postgis data provider. type Provider struct { - config pgx.ConnPoolConfig + config pgxpool.Config pool *connectionPoolCollector // map of layer name and corresponding sql layers map[string]Layer @@ -189,6 +190,68 @@ const ( // case-insensitive and ignoring any preceeding whitespace and SQL comments. var isSelectQuery = regexp.MustCompile(`(?i)^((\s*)(--.*\n)?)*select`) +type hstoreOID struct { + OID uint32 + hasInit bool +} + +// BuildDBConfig build db config with defaults +func BuildDBConfig(cs string) (*pgxpool.Config, error) { + dbconfig, err := pgxpool.ParseConfig(cs) + if err != nil { + return nil, err + } + + dbconfig.ConnConfig.LogLevel = pgx.LogLevelWarn + dbconfig.ConnConfig.RuntimeParams = map[string]string{ + "default_transaction_read_only": "TRUE", + "application_name": "tegola", + } + + var hstore hstoreOID + + dbconfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { + // The AfterConnect call runs everytime a new connection is acquired, + // including everytime the connection pool expands. The hstore OID + // is not constant, so we lookup the OID once per provider and store it. + // Extensions have to be registered for every new connection. + + if !hstore.hasInit { + row := conn.QueryRow(ctx, "SELECT oid FROM pg_type WHERE typname = 'hstore';") + if err = row.Scan(&hstore.OID); err != nil { + switch { + case errors.Is(err, pgx.ErrNoRows): + // do nothing, because query can be empty if hstore is not installed + break + default: + return fmt.Errorf("error fetching hstore oid: %w", err) + } + } + hstore.hasInit = true + } + + // dont register hstore data type if hstore extension is not installed + if hstore.OID != 0 { + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: &pgtype.Hstore{}, + Name: "hstore", + OID: hstore.OID, + }) + } + + // register UUID type, see https://github.com/jackc/pgx/wiki/UUID-Support + conn.ConnInfo().RegisterDataType(pgtype.DataType{ + Value: &gofrs.UUID{}, + Name: "uuid", + OID: pgtype.UUIDOID, + }) + + return nil + } + + return dbconfig, nil +} + // CreateProvider instantiates and returns a new postgis provider or an error. // The function will validate that the config object looks good before // trying to create a driver. This Provider supports the following fields @@ -275,37 +338,29 @@ func CreateProvider(config dict.Dicter, providerType string) (*Provider, error) return nil, err } - connConfig := pgx.ConnConfig{ - Host: host, - Port: uint16(port), - Database: db, - User: user, - Password: password, - LogLevel: pgx.LogLevelWarn, - RuntimeParams: map[string]string{ - "default_transaction_read_only": "TRUE", - "application_name": "tegola", - }, - } + // TODO: allow connection string option in config + cs := fmt.Sprintf("postgres://%v:%v@%v:%v/%v?sslmode=%v&pool_max_conns=%v", + user, password, host, port, db, sslmode, maxcon) - err = ConfigTLS(sslmode, sslkey, sslcert, sslrootcert, &connConfig) + dbconfig, err := BuildDBConfig(cs) if err != nil { + return nil, fmt.Errorf("Failed while building db config: %w", err) + } + + if err = ConfigTLS(sslmode, sslkey, sslcert, sslrootcert, dbconfig); err != nil { return nil, err } p := Provider{ - srid: uint64(srid), - config: pgx.ConnPoolConfig{ - ConnConfig: connConfig, - MaxConnections: int(maxcon), - }, + srid: uint64(srid), + config: *dbconfig, } - pool, err := pgx.NewConnPool(p.config) + pool, err := pgxpool.ConnectConfig(context.Background(), &p.config) if err != nil { - return nil, fmt.Errorf("Failed while creating connection pool: %v", err) + return nil, fmt.Errorf("Failed while creating connection pool: %w", err) } - p.pool = &connectionPoolCollector{ConnPool: pool} + p.pool = &connectionPoolCollector{Pool: pool} layers, err := config.MapSlice(ConfigKeyLayers) if err != nil { @@ -319,7 +374,7 @@ func CreateProvider(config dict.Dicter, providerType string) (*Provider, error) lName, err := layer.String(ConfigKeyLayerName, nil) if err != nil { - return nil, fmt.Errorf("For layer (%v) we got the following error trying to get the layer's name field: %v", i, err) + return nil, fmt.Errorf("For layer (%v) we got the following error trying to get the layer's name field: %w", i, err) } if j, ok := lyrsSeen[lName]; ok { @@ -333,19 +388,19 @@ func CreateProvider(config dict.Dicter, providerType string) (*Provider, error) fields, err := layer.StringSlice(ConfigKeyFields) if err != nil { - return nil, fmt.Errorf("for layer (%v) %v %v field had the following error: %v", i, lName, ConfigKeyFields, err) + return nil, fmt.Errorf("for layer (%v) %v %v field had the following error: %w", i, lName, ConfigKeyFields, err) } geomfld := "geom" geomfld, err = layer.String(ConfigKeyGeomField, &geomfld) if err != nil { - return nil, fmt.Errorf("for layer (%v) %v : %v", i, lName, err) + return nil, fmt.Errorf("for layer (%v) %v : %w", i, lName, err) } idfld := "" idfld, err = layer.String(ConfigKeyGeomIDField, &idfld) if err != nil { - return nil, fmt.Errorf("for layer (%v) %v : %v", i, lName, err) + return nil, fmt.Errorf("for layer (%v) %v : %w", i, lName, err) } if idfld == geomfld { return nil, fmt.Errorf("for layer (%v) %v: %v (%v) and %v field (%v) is the same", i, lName, ConfigKeyGeomField, geomfld, ConfigKeyGeomIDField, idfld) @@ -354,19 +409,19 @@ func CreateProvider(config dict.Dicter, providerType string) (*Provider, error) geomType := "" geomType, err = layer.String(ConfigKeyGeomType, &geomType) if err != nil { - return nil, fmt.Errorf("for layer (%v) %v : %v", i, lName, err) + return nil, fmt.Errorf("for layer (%v) %v : %w", i, lName, err) } var tblName string tblName, err = layer.String(ConfigKeyTablename, &lName) if err != nil { - return nil, fmt.Errorf("for %v layer (%v) %v has an error: %v", i, lName, ConfigKeyTablename, err) + return nil, fmt.Errorf("for %v layer (%v) %v has an error: %w", i, lName, ConfigKeyTablename, err) } var sql string sql, err = layer.String(ConfigKeySQL, &sql) if err != nil { - return nil, fmt.Errorf("for %v layer (%v) %v has an error: %v", i, lName, ConfigKeySQL, err) + return nil, fmt.Errorf("for %v layer (%v) %v has an error: %w", i, lName, ConfigKeySQL, err) } if tblName != lName && sql != "" { @@ -415,7 +470,7 @@ func CreateProvider(config dict.Dicter, providerType string) (*Provider, error) // and if not add them to the list. If Fields list is empty/nil we will use '*' for the field list. l.sql, err = genSQL(&l, p.pool, tblName, fields, true, providerType) if err != nil { - return nil, fmt.Errorf("could not generate sql, for layer(%v): %v", lName, err) + return nil, fmt.Errorf("could not generate sql, for layer(%v): %w", lName, err) } } @@ -426,11 +481,11 @@ func CreateProvider(config dict.Dicter, providerType string) (*Provider, error) // set the layer geom type if geomType != "" { if err = p.setLayerGeomType(&l, geomType); err != nil { - return nil, fmt.Errorf("error fetching geometry type for layer (%v): %v", l.name, err) + return nil, fmt.Errorf("error fetching geometry type for layer (%v): %w", l.name, err) } } else { if err = p.inspectLayerGeomType(&l); err != nil { - return nil, fmt.Errorf("error fetching geometry type for layer (%v): %v", l.name, err) + return nil, fmt.Errorf("error fetching geometry type for layer (%v): %w", l.name, err) } } @@ -445,26 +500,21 @@ func CreateProvider(config dict.Dicter, providerType string) (*Provider, error) } // derived from github.com/jackc/pgx configTLS (https://github.com/jackc/pgx/blob/master/conn.go) -func ConfigTLS(sslMode string, sslKey string, sslCert string, sslRootCert string, cc *pgx.ConnConfig) error { +func ConfigTLS(sslMode string, sslKey string, sslCert string, sslRootCert string, cc *pgxpool.Config) error { switch sslMode { case "disable": - cc.UseFallbackTLS = false - cc.TLSConfig = nil - cc.FallbackTLSConfig = nil + cc.ConnConfig.TLSConfig = nil return nil case "allow": - cc.UseFallbackTLS = true - cc.FallbackTLSConfig = &tls.Config{InsecureSkipVerify: true} + cc.ConnConfig.TLSConfig = &tls.Config{InsecureSkipVerify: true} case "prefer": - cc.TLSConfig = &tls.Config{InsecureSkipVerify: true} - cc.UseFallbackTLS = true - cc.FallbackTLSConfig = nil + cc.ConnConfig.TLSConfig = &tls.Config{InsecureSkipVerify: true} case "require": - cc.TLSConfig = &tls.Config{InsecureSkipVerify: true} + cc.ConnConfig.TLSConfig = &tls.Config{InsecureSkipVerify: true} case "verify-ca", "verify-full": - cc.TLSConfig = &tls.Config{ - ServerName: cc.Host, + cc.ConnConfig.TLSConfig = &tls.Config{ + ServerName: cc.ConnConfig.Host, } default: return ErrInvalidSSLMode(sslMode) @@ -475,15 +525,15 @@ func ConfigTLS(sslMode string, sslKey string, sslCert string, sslRootCert string caCert, err := ioutil.ReadFile(sslRootCert) if err != nil { - return fmt.Errorf("unable to read CA file (%q): %v", sslRootCert, err) + return fmt.Errorf("unable to read CA file (%q): %w", sslRootCert, err) } if !caCertPool.AppendCertsFromPEM(caCert) { return fmt.Errorf("unable to add CA to cert pool") } - cc.TLSConfig.RootCAs = caCertPool - cc.TLSConfig.ClientCAs = caCertPool + cc.ConnConfig.TLSConfig.RootCAs = caCertPool + cc.ConnConfig.TLSConfig.ClientCAs = caCertPool } if (sslCert == "") != (sslKey == "") { @@ -491,10 +541,10 @@ func ConfigTLS(sslMode string, sslKey string, sslCert string, sslRootCert string } else if sslCert != "" { // we must have both now cert, err := tls.LoadX509KeyPair(sslCert, sslKey) if err != nil { - return fmt.Errorf("unable to read cert: %v", err) + return fmt.Errorf("unable to read cert: %w", err) } - cc.TLSConfig.Certificates = []tls.Certificate{cert} + cc.ConnConfig.TLSConfig.Certificates = []tls.Certificate{cert} } return nil @@ -561,7 +611,7 @@ func (p Provider) inspectLayerGeomType(l *Layer) error { return err } - rows, err := p.pool.Query(sql) + rows, err := p.pool.Query(context.Background(), sql) if err != nil { return err } @@ -573,12 +623,12 @@ func (p Provider) inspectLayerGeomType(l *Layer) error { vals, err := rows.Values() if err != nil { - return fmt.Errorf("error running SQL: %v ; %v", sql, err) + return fmt.Errorf("error running SQL: %v ; %w", sql, err) } // iterate the values returned from our row, sniffing for the geomField or st_geometrytype field name for i, v := range vals { - switch fdescs[i].Name { + switch string(fdescs[i].Name) { case l.geomField, "st_geometrytype": switch v { case "ST_Point": @@ -645,8 +695,8 @@ func (p Provider) TileFeatures(ctx context.Context, layer string, tile provider. } sql, err := replaceTokens(plyr.sql, &plyr, tile, true) - if err != nil { - return fmt.Errorf("error replacing layer tokens for layer (%v) SQL (%v): %v", layer, sql, err) + if err := ctxErr(ctx, err); err != nil { + return fmt.Errorf("error replacing layer tokens for layer (%v) SQL (%v): %w", layer, sql, err) } if debugExecuteSQL { @@ -659,7 +709,7 @@ func (p Provider) TileFeatures(ctx context.Context, layer string, tile provider. } now := time.Now() - rows, err := p.pool.Query(sql) + rows, err := p.pool.Query(ctx, sql) if p.queryHistogramSeconds != nil { z, _, _ := tile.ZXY() lbls := prometheus.Labels{ @@ -669,29 +719,18 @@ func (p Provider) TileFeatures(ctx context.Context, layer string, tile provider. } p.queryHistogramSeconds.With(lbls).Observe(time.Since(now).Seconds()) } - if err != nil { - return fmt.Errorf("error running layer (%v) SQL (%v): %v", layer, sql, err) - } + // when using ctxErr, it's import to make sure the defer rows.Close() + // statement happens before the error check. The context may have been + // canceled, but rows were also returned. If we don't close the rows + // the the provider can't clean up the pool and the process will hang + // trying to clean itself up. defer rows.Close() - - // fetch rows FieldDescriptions. this gives us the OID for the data types returned to aid in decoding - fdescs := rows.FieldDescriptions() - - // loop our field descriptions looking for the geometry field - var geomFieldFound bool - for i := range fdescs { - if fdescs[i].Name == plyr.GeomFieldName() { - geomFieldFound = true - break - } - } - if !geomFieldFound { - return ErrGeomFieldNotFound{ - GeomFieldName: plyr.GeomFieldName(), - LayerName: plyr.Name(), - } + if err := ctxErr(ctx, err); err != nil { + return fmt.Errorf("error running layer (%v) SQL (%v): %w", layer, sql, err) } + // fieldDescriptions + var fdescs []pgproto3.FieldDescription reportedLayerFieldName := "" for rows.Next() { // context check @@ -699,20 +738,35 @@ func (p Provider) TileFeatures(ctx context.Context, layer string, tile provider. return err } + // fetch rows FieldDescriptions. this gives us the OID for the data types + // returned to aid in decoding. This only needs to be done once. + if fdescs == nil { + fdescs = rows.FieldDescriptions() + // loop our field descriptions looking for the geometry field + var geomFieldFound bool + for i := range fdescs { + if string(fdescs[i].Name) == plyr.GeomFieldName() { + geomFieldFound = true + break + } + } + if !geomFieldFound { + return ErrGeomFieldNotFound{ + GeomFieldName: plyr.GeomFieldName(), + LayerName: plyr.Name(), + } + } + } + // fetch row values vals, err := rows.Values() - if err != nil { - return fmt.Errorf("error running layer (%v) SQL (%v): %v", layer, sql, err) + if err := ctxErr(ctx, err); err != nil { + return fmt.Errorf("error running layer (%v) SQL (%v): %w", layer, sql, err) } gid, geobytes, tags, err := decipherFields(ctx, plyr.GeomFieldName(), plyr.IDFieldName(), fdescs, vals) - if err != nil { - switch err { - case context.Canceled: - return err - default: - return fmt.Errorf("for layer (%v) %v", plyr.Name(), err) - } + if err := ctxErr(ctx, err); err != nil { + return fmt.Errorf("for layer (%v) %w", plyr.Name(), err) } // check that we have geometry data. if not, skip the feature @@ -733,7 +787,7 @@ func (p Provider) TileFeatures(ctx context.Context, layer string, tile provider. } continue default: - return fmt.Errorf("unable to decode layer (%v) geometry field (%v) into wkb where (%v = %v): %v", layer, plyr.GeomFieldName(), plyr.IDFieldName(), gid, err) + return fmt.Errorf("unable to decode layer (%v) geometry field (%v) into wkb where (%v = %v): %w", layer, plyr.GeomFieldName(), plyr.IDFieldName(), gid, err) } } @@ -782,7 +836,7 @@ func (p Provider) MVTForLayers(ctx context.Context, tile provider.Tile, layers [ log.Printf("SQL for Layer(%v):\n%v\n", l.Name(), l.sql) } sql, err := replaceTokens(l.sql, &l, tile, false) - if err != nil { + if err := ctxErr(ctx, err); err != nil { return nil, err } @@ -814,7 +868,7 @@ func (p Provider) MVTForLayers(ctx context.Context, tile provider.Tile, layers [ } { now := time.Now() - err = p.pool.QueryRow(fsql).Scan(&data) + err = p.pool.QueryRow(ctx, fsql).Scan(&data) if p.mvtProviderQueryHistogramSeconds != nil { z, _, _ := tile.ZXY() lbls := prometheus.Labels{ @@ -835,7 +889,7 @@ func (p Provider) MVTForLayers(ctx context.Context, tile provider.Tile, layers [ } // data may have garbage in it. - if err != nil { + if err := ctxErr(ctx, err); err != nil { return []byte{}, err } return data.Bytes, nil diff --git a/provider/postgis/postgis_test.go b/provider/postgis/postgis_test.go index 28eb4e3e1..17cd03b2f 100644 --- a/provider/postgis/postgis_test.go +++ b/provider/postgis/postgis_test.go @@ -1,6 +1,7 @@ package postgis_test import ( + "fmt" "testing" "context" @@ -9,17 +10,24 @@ import ( "github.com/go-spatial/tegola/internal/ttools" "github.com/go-spatial/tegola/provider" "github.com/go-spatial/tegola/provider/postgis" - "github.com/jackc/pgx" + "github.com/jackc/pgx/v4/pgxpool" ) func TestTLSConfig(t *testing.T) { - testConnConfig := pgx.ConnConfig{ - Host: "testhost", - Port: 8080, - Database: "testdb", - User: "testuser", - Password: "testpassword", + var ( + host = "testhost" + port = 8080 + database = "testdb" + user = "testuser" + password = "testpassword" + ) + + cs := fmt.Sprintf("postgres://%v:%v@%v:%v/%v", user, password, host, port, database) + testConnConfig, err := postgis.BuildDBConfig(cs) + + if err != nil { + t.Fatalf("unable to build db config: %v", err) } type tcase struct { @@ -27,13 +35,13 @@ func TestTLSConfig(t *testing.T) { sslKey string sslCert string sslRootCert string - testFunc func(config pgx.ConnConfig) + testFunc func(config *pgxpool.Config) shouldError bool } fn := func(tc tcase) func(t *testing.T) { return func(t *testing.T) { - err := postgis.ConfigTLS(tc.sslMode, tc.sslKey, tc.sslCert, tc.sslRootCert, &testConnConfig) + err := postgis.ConfigTLS(tc.sslMode, tc.sslKey, tc.sslCert, tc.sslRootCert, testConnConfig) if !tc.shouldError && err != nil { t.Errorf("unable to create a new provider: %v", err) return @@ -53,7 +61,7 @@ func TestTLSConfig(t *testing.T) { sslCert: "", sslRootCert: "", shouldError: true, - testFunc: func(config pgx.ConnConfig) { + testFunc: func(config *pgxpool.Config) { }, }, "2": { @@ -62,17 +70,9 @@ func TestTLSConfig(t *testing.T) { sslCert: "", sslRootCert: "", shouldError: false, - testFunc: func(config pgx.ConnConfig) { - if config.UseFallbackTLS != false { - t.Error("When using disable ssl mode; UseFallbackTLS, expected false got true") - } - - if config.TLSConfig != nil { - t.Errorf("When using disable ssl mode; UseFallbackTLS, expected nil got %v", testConnConfig.TLSConfig) - } - - if config.FallbackTLSConfig != nil { - t.Errorf("When using disable ssl mode; UseFallbackTLS, expected nil got %v", testConnConfig.FallbackTLSConfig) + testFunc: func(config *pgxpool.Config) { + if config.ConnConfig.TLSConfig != nil { + t.Errorf("When using disable ssl mode; UseFallbackTLS, expected nil got %v", testConnConfig.ConnConfig.TLSConfig) } }, }, @@ -82,16 +82,8 @@ func TestTLSConfig(t *testing.T) { sslCert: "", sslRootCert: "", shouldError: false, - testFunc: func(config pgx.ConnConfig) { - if config.UseFallbackTLS != true { - t.Error("When using allow ssl mode; UseFallbackTLS, expected true got false") - } - - if config.FallbackTLSConfig == nil { - t.Error("When using allow ssl mode; UseFallbackTLS, expected not nil got nil") - } - - if config.FallbackTLSConfig != nil && config.FallbackTLSConfig.InsecureSkipVerify == false { + testFunc: func(config *pgxpool.Config) { + if config.ConnConfig.TLSConfig.InsecureSkipVerify == false { t.Error("When using allow ssl mode; UseFallbackTLS.InsecureSkipVerify, expected true got false") } }, @@ -102,20 +94,12 @@ func TestTLSConfig(t *testing.T) { sslCert: "", sslRootCert: "", shouldError: false, - testFunc: func(config pgx.ConnConfig) { - if config.UseFallbackTLS != true { - t.Error("When using prefer ssl mode; UseFallbackTLS, expected true got false") - } - - if config.FallbackTLSConfig != nil { - t.Errorf("When using prefer ssl mode; UseFallbackTLS, expected nil got %v", config.FallbackTLSConfig) - } - - if config.TLSConfig == nil { + testFunc: func(config *pgxpool.Config) { + if config.ConnConfig.TLSConfig == nil { t.Error("When using prefer ssl mode; TLSConfig, expected not nil got nil") } - if config.TLSConfig != nil && config.TLSConfig.InsecureSkipVerify == false { + if config.ConnConfig.TLSConfig != nil && config.ConnConfig.TLSConfig.InsecureSkipVerify == false { t.Error("When using prefer ssl mode; TLSConfig.InsecureSkipVerify, expected true got false") } }, @@ -126,12 +110,12 @@ func TestTLSConfig(t *testing.T) { sslCert: "", sslRootCert: "", shouldError: false, - testFunc: func(config pgx.ConnConfig) { - if config.TLSConfig == nil { + testFunc: func(config *pgxpool.Config) { + if config.ConnConfig.TLSConfig == nil { t.Error("When using prefer ssl mode; TLSConfig, expected not nil got nil") } - if config.TLSConfig != nil && config.TLSConfig.InsecureSkipVerify == false { + if config.ConnConfig.TLSConfig != nil && config.ConnConfig.TLSConfig.InsecureSkipVerify == false { t.Error("When using prefer ssl mode; TLSConfig.InsecureSkipVerify, expected true got false") } }, @@ -142,13 +126,13 @@ func TestTLSConfig(t *testing.T) { sslCert: "", sslRootCert: "", shouldError: false, - testFunc: func(config pgx.ConnConfig) { - if config.TLSConfig == nil { + testFunc: func(config *pgxpool.Config) { + if config.ConnConfig.TLSConfig == nil { t.Error("When using prefer ssl mode; TLSConfig, expected not nil got nil") } - if config.TLSConfig != nil && config.TLSConfig.ServerName != testConnConfig.Host { - t.Errorf("When using prefer ssl mode; TLSConfig.ServerName, expected %s got %s", testConnConfig.Host, config.TLSConfig.ServerName) + if config.ConnConfig.TLSConfig != nil && config.ConnConfig.TLSConfig.ServerName != testConnConfig.ConnConfig.Host { + t.Errorf("When using prefer ssl mode; TLSConfig.ServerName, expected %s got %s", testConnConfig.ConnConfig.Host, config.ConnConfig.TLSConfig.ServerName) } }, }, @@ -158,13 +142,13 @@ func TestTLSConfig(t *testing.T) { sslCert: "", sslRootCert: "", shouldError: false, - testFunc: func(config pgx.ConnConfig) { - if config.TLSConfig == nil { + testFunc: func(config *pgxpool.Config) { + if config.ConnConfig.TLSConfig == nil { t.Error("When using prefer ssl mode; TLSConfig, expected not nil got nil") } - if config.TLSConfig != nil && config.TLSConfig.ServerName != testConnConfig.Host { - t.Errorf("When using prefer ssl mode; TLSConfig.ServerName, expected %s got %s", testConnConfig.Host, config.TLSConfig.ServerName) + if config.ConnConfig.TLSConfig != nil && config.ConnConfig.TLSConfig.ServerName != testConnConfig.ConnConfig.Host { + t.Errorf("When using prefer ssl mode; TLSConfig.ServerName, expected %s got %s", testConnConfig.ConnConfig.Host, config.ConnConfig.TLSConfig.ServerName) } }, }, @@ -473,7 +457,9 @@ func TestTileFeatures(t *testing.T) { LayerConfig: []map[string]interface{}{{ postgis.ConfigKeyLayerName: "missing_geom_field_name", postgis.ConfigKeyGeomField: "geom", - postgis.ConfigKeySQL: "SELECT ST_AsBinary(geom) FROM three_d_test WHERE geom && !BBOX!", + // this SQL is a workaround the normal !BBOX! token check. We don't care about the bounding + // box query, but rather simulating the missing geom column to trigger the error we're testing for. + postgis.ConfigKeySQL: "SELECT ST_AsBinary(geom), !BBOX! AS bbox FROM three_d_test", }}, }, tile: provider.NewTile(16, 11241, 26168, 64, tegola.WebMercator), diff --git a/provider/postgis/util.go b/provider/postgis/util.go index 5338ede4b..0bac49ed9 100644 --- a/provider/postgis/util.go +++ b/provider/postgis/util.go @@ -12,8 +12,8 @@ import ( "github.com/go-spatial/tegola" "github.com/go-spatial/tegola/basic" "github.com/go-spatial/tegola/provider" - "github.com/jackc/pgx" - "github.com/jackc/pgx/pgtype" + "github.com/jackc/pgproto3/v2" + "github.com/jackc/pgtype" ) // isMVT will return true if the provider is MVT based @@ -37,7 +37,7 @@ func genSQL(l *Layer, pool *connectionPoolCollector, tblname string, flds []stri return "", err } - rows, err := pool.Query(sql) + rows, err := pool.Query(context.Background(), sql) if err != nil { return "", err } @@ -45,13 +45,13 @@ func genSQL(l *Layer, pool *connectionPoolCollector, tblname string, flds []stri fdescs := rows.FieldDescriptions() if len(fdescs) == 0 { - return "", fmt.Errorf("No fields were returned for table %v", tblname) + return "", fmt.Errorf("no fields were returned for table %v", tblname) } // to avoid field names possibly colliding with Postgres keywords, // we wrap the field names in quotes for i := range fdescs { - flds = append(flds, fdescs[i].Name) + flds = append(flds, string(fdescs[i].Name)) } } @@ -144,12 +144,12 @@ func replaceTokens(sql string, lyr *Layer, tile provider.Tile, withBuffer bool) // TODO: it's currently assumed the tile will always be in WebMercator. Need to support different projections minGeo, err := basic.FromWebMercator(srid, geom.Point{extent.MinX(), extent.MinY()}) if err != nil { - return "", fmt.Errorf("Error trying to convert tile point: %v ", err) + return "", fmt.Errorf("Error trying to convert tile point: %w ", err) } maxGeo, err := basic.FromWebMercator(srid, geom.Point{extent.MaxX(), extent.MaxY()}) if err != nil { - return "", fmt.Errorf("Error trying to convert tile point: %v ", err) + return "", fmt.Errorf("Error trying to convert tile point: %w ", err) } minPt, maxPt := minGeo.(geom.Point), maxGeo.(geom.Point) @@ -209,7 +209,7 @@ func transformVal(valType pgtype.OID, val interface{}) (interface{}, error) { } case pgtype.BoolOID, pgtype.ByteaOID, pgtype.TextOID, pgtype.OIDOID, pgtype.VarcharOID, pgtype.JSONBOID: return val, nil - case pgtype.Int8OID, pgtype.Int2OID, pgtype.Int4OID, pgtype.Float4OID, pgtype.Float8OID: + case pgtype.Int8OID, pgtype.Int2OID, pgtype.NumericOID, pgtype.Int4OID, pgtype.Float4OID, pgtype.Float8OID: switch vt := val.(type) { case int8: return int64(vt), nil @@ -238,13 +238,14 @@ func transformVal(valType pgtype.OID, val interface{}) (interface{}, error) { } // decipherFields is responsible for processing the SQL result set, decoding geometries, ids and feature tags. -func decipherFields(ctx context.Context, geomFieldname, idFieldname string, descriptions []pgx.FieldDescription, values []interface{}) (gid uint64, geom []byte, tags map[string]interface{}, err error) { +func decipherFields(ctx context.Context, geomFieldname, idFieldname string, descriptions []pgproto3.FieldDescription, values []interface{}) (gid uint64, geom []byte, tags map[string]interface{}, err error) { var ok bool tags = make(map[string]interface{}) var idParsed bool for i := range values { + // do a quick check if err := ctx.Err(); err != nil { return 0, nil, nil, err @@ -256,8 +257,9 @@ func decipherFields(ctx context.Context, geomFieldname, idFieldname string, desc } desc := descriptions[i] + descName := string(desc.Name) - switch desc.Name { + switch descName { case geomFieldname: if geom, ok = values[i].([]byte); !ok { return 0, nil, nil, fmt.Errorf("unable to convert geometry field (%v) into bytes", geomFieldname) @@ -284,18 +286,16 @@ func decipherFields(ctx context.Context, geomFieldname, idFieldname string, desc tags[k] = v.String } } - case *pgtype.Numeric: + case pgtype.Numeric: var num float64 vex.AssignTo(&num) - - tags[desc.Name] = num + tags[descName] = num default: - value, err := transformVal(desc.DataType, values[i]) + value, err := transformVal(pgtype.OID(desc.DataTypeOID), values[i]) if err != nil { - return gid, geom, tags, fmt.Errorf("unable to convert field [%v] (%v) of type (%v - %v) to a suitable value: %+v", i, desc.Name, desc.DataType, desc.DataTypeName, values[i]) + return gid, geom, tags, fmt.Errorf("unable to convert field [%v] (%v) of type (%v - %v) to a suitable value: %+v (%T)", i, descName, desc.DataTypeOID, pgtype.OID(desc.DataTypeOID), values[i], values[i]) } - - tags[desc.Name] = value + tags[descName] = value } } } @@ -329,3 +329,15 @@ func gId(v interface{}) (gid uint64, err error) { return gid, fmt.Errorf("unable to convert field into a uint64") } } + +// ctxErr will check if the supplied context has an error (i.e. context canceled) +// and if so, return that error, else return the supplied error. This is useful +// as not all of Go's stdlib has adopted error wrapping so context.Canceled +// errors are not always easy to capture. +func ctxErr(ctx context.Context, err error) error { + if ctxErr := ctx.Err(); ctxErr != nil { + return ctxErr + } + + return err +} diff --git a/provider/postgis/util_internal_test.go b/provider/postgis/util_internal_test.go index c4a30b7c5..33a5c774f 100644 --- a/provider/postgis/util_internal_test.go +++ b/provider/postgis/util_internal_test.go @@ -2,10 +2,11 @@ package postgis import ( "context" + "fmt" "os" "testing" - "github.com/jackc/pgx" + "github.com/jackc/pgx/v4/pgxpool" "github.com/go-spatial/tegola" "github.com/go-spatial/tegola/internal/ttools" @@ -123,15 +124,20 @@ func TestDecipherFields(t *testing.T) { expectedTags map[string]interface{} } - cc := pgx.ConnConfig{ - Host: os.Getenv("PGHOST"), - Port: 5432, - Database: os.Getenv("PGDATABASE"), - User: os.Getenv("PGUSER"), - Password: os.Getenv("PGPASSWORD"), + host := os.Getenv("PGHOST") + port := 5432 + db := os.Getenv("PGDATABASE") + user := os.Getenv("PGUSER") + password := os.Getenv("PGPASSWORD") + + cs := fmt.Sprintf("postgres://%v:%v@%v:%v/%v", user, password, host, port, db) + dbconfig, err := BuildDBConfig(cs) + + if err != nil { + t.Fatalf("unable to build db config: %v", err) } - conn, err := pgx.Connect(cc) + conn, err := pgxpool.ConnectConfig(context.Background(), dbconfig) if err != nil { t.Fatalf("unable to connect to database: %v", err) } @@ -139,7 +145,7 @@ func TestDecipherFields(t *testing.T) { fn := func(tc tcase) func(t *testing.T) { return func(t *testing.T) { - rows, err := conn.Query(tc.sql) + rows, err := conn.Query(context.Background(), tc.sql) if err != nil { t.Errorf("Error performing query: %v", err) return diff --git a/server/handle_map_layer_zxy.go b/server/handle_map_layer_zxy.go index bfd1fe448..30651764e 100644 --- a/server/handle_map_layer_zxy.go +++ b/server/handle_map_layer_zxy.go @@ -2,6 +2,7 @@ package server import ( "context" + "errors" "fmt" "net/http" "strconv" @@ -158,10 +159,15 @@ func (req HandleMapLayerZXY) ServeHTTP(w http.ResponseWriter, r *http.Request) { encodeCtx := context.WithValue(r.Context(), observability.ObserveVarMapName, m.Name) pbyte, err := m.Encode(encodeCtx, tile) + if err != nil { - switch err { - case context.Canceled: + switch { + case errors.Is(err, context.Canceled): // TODO: add debug logs + // do nothing + return + case strings.Contains(err.Error(), "operation was canceled"): + // do nothing return default: errMsg := fmt.Sprintf("error marshalling tile: %v", err)