Skip to content

Commit

Permalink
fix: use resilient HTTP client
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl committed Oct 5, 2022
1 parent 6571bae commit e431978
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 30 deletions.
20 changes: 20 additions & 0 deletions .schema/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,26 @@
},
"additionalProperties": false
},
"clients": {
"title": "Global outgoing network settings",
"description": "Configure how outgoing network calls behave.",
"type": "object",
"properties": {
"http": {
"title": "Global HTTP client configuration",
"description": "Configure how outgoing HTTP calls behave.",
"type": "object",
"properties": {
"disallow_private_ip_ranges": {
"title": "Disallow private IP ranges",
"description": "Disallow all outgoing HTTP calls to private IP ranges. This feature can help protect against SSRF attacks.",
"type": "boolean",
"default": false
}
}
}
}
},
"version": {
"type": "string",
"title": "The Keto version this config is written for.",
Expand Down
20 changes: 20 additions & 0 deletions embedx/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,26 @@
},
"additionalProperties": false
},
"clients": {
"title": "Global outgoing network settings",
"description": "Configure how outgoing network calls behave.",
"type": "object",
"properties": {
"http": {
"title": "Global HTTP client configuration",
"description": "Configure how outgoing HTTP calls behave.",
"type": "object",
"properties": {
"disallow_private_ip_ranges": {
"title": "Disallow private IP ranges",
"description": "Disallow all outgoing HTTP calls to private IP ranges. This feature can help protect against SSRF attacks.",
"type": "boolean",
"default": false
}
}
}
}
},
"version": {
"type": "string",
"title": "The Keto version this config is written for.",
Expand Down
23 changes: 4 additions & 19 deletions internal/driver/config/opl_config_namespace_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"sync"

"github.com/ory/x/logrusx"
Expand Down Expand Up @@ -34,9 +32,9 @@ type (

var _ namespace.Manager = (*oplConfigWatcher)(nil)

func newOPLConfigWatcher(ctx context.Context, l *logrusx.Logger, target string) (*oplConfigWatcher, error) {
func newOPLConfigWatcher(ctx context.Context, c *Config, target string) (*oplConfigWatcher, error) {
nw := &oplConfigWatcher{
logger: l,
logger: c.l,
target: target,
files: configFiles{byPath: make(map[string]io.Reader)},
memoryNamespaceManager: *NewMemoryNamespaceManager(),
Expand All @@ -49,13 +47,12 @@ func newOPLConfigWatcher(ctx context.Context, l *logrusx.Logger, target string)

switch targetUrl.Scheme {
case "file":
return nw, watchTarget(ctx, target, nw, l)
return nw, watchTarget(ctx, target, nw, c.l)
case "http", "https":
file, err := download(ctx, targetUrl)
file, err := c.Fetcher().Fetch(target)
if err != nil {
return nil, err
}
defer file.Close()
nw.files.byPath[targetUrl.String()] = file
nw.parseFiles()
return nw, err
Expand All @@ -64,18 +61,6 @@ func newOPLConfigWatcher(ctx context.Context, l *logrusx.Logger, target string)
}
}

func download(ctx context.Context, url *url.URL) (io.ReadCloser, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url.String(), nil)
if err != nil {
return nil, errors.WithStack(err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, errors.WithStack(err)
}
return resp.Body, nil
}

func (nw *oplConfigWatcher) handleChange(e *watcherx.ChangeEvent) {
// the lock is acquired before parsing to ensure that the getters are
// waiting for the updated values
Expand Down
27 changes: 20 additions & 7 deletions internal/driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import (
"fmt"
"sync"

"github.com/ory/x/fetcher"
"github.com/ory/x/httpx"

"github.com/ory/keto/embedx"

"github.com/ory/herodot"
Expand Down Expand Up @@ -192,6 +195,16 @@ func (k *Config) DSN() string {
return dsn
}

func (k *Config) Fetcher() *fetcher.Fetcher {
var opts []httpx.ResilientOptions
if k.p.Bool("clients.http.disallow_private_ip_ranges") {
opts = append(opts, httpx.ResilientClientDisallowInternalIPs())
}
return fetcher.NewFetcher(
fetcher.WithClient(httpx.NewResilientClient(opts...)),
)
}

func (k *Config) TracingServiceName() string {
return k.p.StringF("tracing.service_name", "Ory Keto")
}
Expand All @@ -217,7 +230,7 @@ func (k *Config) NamespaceManager() (namespace.Manager, error) {
return nil, err
}

k.nm, err = nnCfg.newManager()(ctx, k.l)
k.nm, err = nnCfg.newManager()(ctx, k)
if err != nil {
return nil, err
}
Expand All @@ -227,7 +240,7 @@ func (k *Config) NamespaceManager() (namespace.Manager, error) {
}

type (
buildNamespaceFn func(context.Context, *logrusx.Logger) (namespace.Manager, error)
buildNamespaceFn func(context.Context, *Config) (namespace.Manager, error)

namespaceConfig interface {
// newManager builds a new namespace manager.
Expand All @@ -242,16 +255,16 @@ type (
)

func (uri legacyURINamespaceConfig) newManager() buildNamespaceFn {
return func(ctx context.Context, l *logrusx.Logger) (namespace.Manager, error) {
return NewNamespaceWatcher(ctx, l, string(uri))
return func(ctx context.Context, c *Config) (namespace.Manager, error) {
return NewNamespaceWatcher(ctx, c.l, string(uri))
}
}
func (uri legacyURINamespaceConfig) value() any {
return string(uri)
}

func (namespaces literalNamespaceConfig) newManager() buildNamespaceFn {
return func(ctx context.Context, l *logrusx.Logger) (namespace.Manager, error) {
return func(ctx context.Context, _ *Config) (namespace.Manager, error) {
return NewMemoryNamespaceManager(namespaces...), nil
}
}
Expand All @@ -260,7 +273,7 @@ func (namespaces literalNamespaceConfig) value() any {
}

func (oplConfig oplNamespaceConfig) newManager() buildNamespaceFn {
return func(ctx context.Context, l *logrusx.Logger) (namespace.Manager, error) {
return func(ctx context.Context, c *Config) (namespace.Manager, error) {
entry, ok := oplConfig["location"]
if !ok {
return nil, errors.New("location key not found")
Expand All @@ -269,7 +282,7 @@ func (oplConfig oplNamespaceConfig) newManager() buildNamespaceFn {
if !ok {
return nil, fmt.Errorf("config value must be string, was %T", entry)
}
return newOPLConfigWatcher(ctx, l, target)
return newOPLConfigWatcher(ctx, c, target)
}
}
func (oplConfig oplNamespaceConfig) value() any {
Expand Down
22 changes: 18 additions & 4 deletions internal/driver/config/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,25 +189,39 @@ class Group implements Namespace {
t.Cleanup(func() { srv.Close() })

cases := []struct {
name string
location string
name string
location string
disallowPrivateIPRanges bool
expectErr bool
}{{
name: "local file",
location: "file://" + oplConfigFile,
}, {
name: "HTTP url",
name: "HTTP url forbidden",
location: srv.URL,
disallowPrivateIPRanges: true,
expectErr: true,
}, {
name: "HTTP url allowed",
location: srv.URL,
}}

for _, tc := range cases {
t.Run("case="+tc.name, func(t *testing.T) {
config := createFileF(t, `
dsn: memory
clients:
http:
disallow_private_ip_ranges: %v
namespaces:
location: %s`, tc.location)
location: %s`, tc.disallowPrivateIPRanges, tc.location)

_, p := setup(t, config)
nm, err := p.NamespaceManager()
if tc.expectErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
namespaces, err := nm.Namespaces(context.Background())
require.NoError(t, err)
Expand Down

0 comments on commit e431978

Please sign in to comment.