Skip to content

Commit

Permalink
fix: CORS config values are ignored (#789)
Browse files Browse the repository at this point in the history
Co-authored-by: zepatrik <[email protected]>
  • Loading branch information
vancanhuit and zepatrik authored Dec 1, 2021
1 parent e459b2e commit ffeb5e3
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 18 deletions.
10 changes: 8 additions & 2 deletions internal/driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,14 @@ func (k *Config) WriteAPIListenOn() string {
)
}

func (k *Config) CORS() (cors.Options, bool) {
return k.p.CORS("serve", cors.Options{
func (k *Config) CORS(iface string) (cors.Options, bool) {
switch iface {
case "read", "write":
default:
panic("expected interface 'read' or 'write', but got unknown interface " + iface)
}

return k.p.CORS("serve."+iface, cors.Options{
AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE"},
AllowedHeaders: []string{"Authorization", "Content-Type"},
ExposedHeaders: []string{"Content-Type"},
Expand Down
17 changes: 15 additions & 2 deletions internal/driver/registry_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sync"

"github.com/ory/x/networkx"
"github.com/rs/cors"

"github.com/gobuffalo/pop/v5"
"github.com/ory/x/popx"
Expand Down Expand Up @@ -293,7 +294,13 @@ func (r *RegistryDefault) ReadRouter() http.Handler {
n.Use(r.sqaService)
}

return n
var handler http.Handler = n
options, enabled := r.Config().CORS("read")
if enabled {
handler = cors.New(options).Handler(handler)
}

return handler
}

func (r *RegistryDefault) WriteRouter() http.Handler {
Expand All @@ -318,7 +325,13 @@ func (r *RegistryDefault) WriteRouter() http.Handler {
n.Use(r.sqaService)
}

return n
var handler http.Handler = n
options, enabled := r.Config().CORS("write")
if enabled {
handler = cors.New(options).Handler(handler)
}

return handler
}

func (r *RegistryDefault) unaryInterceptors() []grpc.UnaryServerInterceptor {
Expand Down
31 changes: 30 additions & 1 deletion internal/e2e/full_suit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ package e2e

import (
"fmt"
"net/http"
"testing"
"time"

"github.com/stretchr/testify/assert"

"github.com/ory/keto/internal/x/dbx"

Expand Down Expand Up @@ -40,7 +44,7 @@ type (
func Test(t *testing.T) {
for _, dsn := range dbx.GetDSNs(t, false) {
t.Run(fmt.Sprintf("dsn=%s", dsn.Name), func(t *testing.T) {
ctx, reg, addNamespace := newInitializedReg(t, dsn)
ctx, reg, addNamespace := newInitializedReg(t, dsn, nil)

closeServer := startServer(ctx, t, reg)
defer closeServer()
Expand Down Expand Up @@ -76,3 +80,28 @@ func Test(t *testing.T) {
})
}
}

func TestServeConfig(t *testing.T) {
ctx, reg, _ := newInitializedReg(t, dbx.GetSqlite(t, dbx.SQLiteMemory), map[string]interface{}{
"serve.read.cors.enabled": true,
"serve.read.cors.debug": true,
"serve.read.cors.allowed_methods": []string{http.MethodGet},
"serve.read.cors.allowed_origins": []string{"https://ory.sh"},
})

closeServer := startServer(ctx, t, reg)
defer closeServer()

for !healthReady(t, "http://"+reg.Config().ReadAPIListenOn()) {
t.Log("Waiting for health check to be ready")
time.Sleep(10 * time.Millisecond)
}

req, err := http.NewRequest(http.MethodOptions, "http://"+reg.Config().ReadAPIListenOn()+relationtuple.RouteBase, nil)
require.NoError(t, err)
req.Header.Set("Origin", "https://ory.sh")
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "https://ory.sh", resp.Header.Get("Access-Control-Allow-Origin"), "%+v", resp.Header)
}
11 changes: 8 additions & 3 deletions internal/e2e/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
"github.com/ory/keto/internal/driver"
)

func newInitializedReg(t testing.TB, dsn *dbx.DsnT) (context.Context, driver.Registry, func(*testing.T, ...*namespace.Namespace)) {
func newInitializedReg(t testing.TB, dsn *dbx.DsnT, cfgOverwrites map[string]interface{}) (context.Context, driver.Registry, func(*testing.T, ...*namespace.Namespace)) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(func() {
cancel()
Expand All @@ -38,15 +38,20 @@ func newInitializedReg(t testing.TB, dsn *dbx.DsnT) (context.Context, driver.Reg
flags := pflag.NewFlagSet("", pflag.ContinueOnError)
configx.RegisterConfigFlag(flags, nil)

cf := dbx.ConfigFile(t, map[string]interface{}{
cfgValues := map[string]interface{}{
config.KeyDSN: dsn.Conn,
"log.level": "debug",
"log.leak_sensitive_values": true,
config.KeyReadAPIHost: "127.0.0.1",
config.KeyReadAPIPort: ports[0],
config.KeyWriteAPIHost: "127.0.0.1",
config.KeyWriteAPIPort: ports[1],
})
}
for k, v := range cfgOverwrites {
cfgValues[k] = v
}

cf := dbx.ConfigFile(t, cfgValues)
require.NoError(t, flags.Parse([]string{"--" + configx.FlagConfig, cf}))

reg, err := driver.NewDefaultRegistry(ctx, flags, true)
Expand Down
21 changes: 11 additions & 10 deletions internal/e2e/rest_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,19 @@ func (rc *restClient) expand(t require.TestingT, r *relationtuple.SubjectSet, de
return tree
}

func (rc *restClient) waitUntilLive(t require.TestingT) {
var healthReady = func() bool {
req, err := http.NewRequest("GET", rc.readURL+healthx.ReadyCheckPath, nil)
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false
}
return resp.StatusCode == http.StatusOK
func healthReady(t require.TestingT, readURL string) bool {
req, err := http.NewRequest("GET", readURL+healthx.ReadyCheckPath, nil)
require.NoError(t, err)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false
}
return resp.StatusCode == http.StatusOK
}

func (rc *restClient) waitUntilLive(t require.TestingT) {
// wait for /health/ready
for !healthReady() {
for !healthReady(t, rc.readURL) {
time.Sleep(10 * time.Millisecond)
}
}
3 changes: 3 additions & 0 deletions internal/x/dbx/dsn_testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ func GetSqlite(t testing.TB, mode sqliteMode) *DsnT {
case SQLiteMemory:
dsn.Name = "memory"
dsn.Conn = fmt.Sprintf("sqlite://file:%s?_fk=true&cache=shared&mode=memory", t.Name())
t.Cleanup(func() {
_ = os.Remove(t.Name())
})
case SQLiteFile:
t.Cleanup(func() {
_ = os.Remove("TestDB.sqlite")
Expand Down

0 comments on commit ffeb5e3

Please sign in to comment.