Skip to content

Commit

Permalink
Ensure http.send caching works in system.authz (open-policy-agent#4195)
Browse files Browse the repository at this point in the history
Fixes open-policy-agent#3946

Signed-off-by: Anders Eknert <[email protected]>
  • Loading branch information
anderseknert authored Jan 7, 2022
1 parent cf37313 commit 829086a
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 4 deletions.
10 changes: 10 additions & 0 deletions server/authorizer/authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/open-policy-agent/opa/server/types"
"github.com/open-policy-agent/opa/server/writer"
"github.com/open-policy-agent/opa/storage"
"github.com/open-policy-agent/opa/topdown/cache"
"github.com/open-policy-agent/opa/topdown/print"
"github.com/open-policy-agent/opa/util"
)
Expand All @@ -31,6 +32,7 @@ type Basic struct {
decision func() ast.Ref
printHook print.Hook
enablePrintStatements bool
interQueryCache cache.InterQueryCache
}

// Runtime returns an argument that sets the runtime on the authorizer.
Expand Down Expand Up @@ -65,6 +67,13 @@ func EnablePrintStatements(yes bool) func(r *Basic) {
}
}

// InterQueryCache enables the inter-query cache on the authorizer
func InterQueryCache(interQueryCache cache.InterQueryCache) func(*Basic) {
return func(b *Basic) {
b.interQueryCache = interQueryCache
}
}

// NewBasic returns a new Basic object.
func NewBasic(inner http.Handler, compiler func() *ast.Compiler, store storage.Store, opts ...func(*Basic)) http.Handler {
b := &Basic{
Expand Down Expand Up @@ -98,6 +107,7 @@ func (h *Basic) ServeHTTP(w http.ResponseWriter, r *http.Request) {
rego.Runtime(h.runtime),
rego.EnablePrintStatements(h.enablePrintStatements),
rego.PrintHook(h.printHook),
rego.InterQueryBuiltinCache(h.interQueryCache),
)

rs, err := rego.Eval(r.Context())
Expand Down
66 changes: 63 additions & 3 deletions server/authorizer/authorizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package authorizer
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"reflect"
Expand All @@ -17,6 +18,7 @@ import (
"github.com/open-policy-agent/opa/server/identifier"
"github.com/open-policy-agent/opa/server/types"
"github.com/open-policy-agent/opa/storage/inmem"
"github.com/open-policy-agent/opa/topdown/cache"
"github.com/open-policy-agent/opa/topdown/print"
"github.com/open-policy-agent/opa/util"
)
Expand Down Expand Up @@ -260,7 +262,7 @@ func TestBasicEscapeError(t *testing.T) {
recorder := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, "http://localhost:8181", nil)
if err != nil {
panic(err)
t.Fatal(err)
}

req.URL.Path = `/invalid/path/foo%LALALA`
Expand Down Expand Up @@ -293,7 +295,7 @@ func TestMakeInput(t *testing.T) {
path := "/foo/bar?pretty=true&explain=\"full\""
req, err := http.NewRequest(http.MethodGet, "http://localhost:8181"+path, nil)
if err != nil {
panic(err)
t.Fatal(err)
}

req.Header.Add("x-custom", "foo")
Expand All @@ -312,7 +314,7 @@ func TestMakeInput(t *testing.T) {

_, result, err := makeInput(req)
if err != nil {
panic(err)
t.Fatal(err)
}

expectedResult := util.MustUnmarshalJSON([]byte(`
Expand Down Expand Up @@ -455,6 +457,64 @@ func TestMakeInputWithBody(t *testing.T) {

}

func TestInterQueryCache(t *testing.T) {

count := 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
count++
}))

t.Cleanup(func() {
ts.Close()
})

compiler := func() *ast.Compiler {
module := fmt.Sprintf(`
package system.authz
allow {
http.send({
"method": "GET",
"url": "%v",
"force_cache": true,
"force_cache_duration_seconds": 60
}).status_code == 200
}
`, ts.URL)
c := ast.NewCompiler()
c.Compile(map[string]*ast.Module{
"test.rego": ast.MustParseModule(module),
})
if c.Failed() {
t.Fatalf("Unexpected error compiling test module: %v", c.Errors)
}
return c
}

recorder := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, "http://localhost:8181/v1/data", nil)
if err != nil {
t.Fatal(err)
}

config, _ := cache.ParseCachingConfig(nil)
interQueryCache := cache.NewInterQueryCache(config)

basic := NewBasic(&mockHandler{}, compiler, inmem.New(), InterQueryCache(interQueryCache), Decision(func() ast.Ref {
return ast.MustParseRef("data.system.authz.allow")
}))

// Execute the policy twice
basic.ServeHTTP(recorder, req)
basic.ServeHTTP(recorder, req)

// And make sure the test server was only hit once
if count != 1 {
t.Error("Expected http.send response to be cached")
}
}

func Equal(a, b []string) bool {
if len(a) != len(b) {
return false
Expand Down
3 changes: 2 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,8 @@ func (s *Server) initHandlerAuth(handler http.Handler) http.Handler {
authorizer.Runtime(s.runtime),
authorizer.Decision(s.manager.Config.DefaultAuthorizationDecisionRef),
authorizer.PrintHook(s.manager.PrintHook()),
authorizer.EnablePrintStatements(s.manager.EnablePrintStatements()))
authorizer.EnablePrintStatements(s.manager.EnablePrintStatements()),
authorizer.InterQueryCache(s.interQueryBuiltinCache))
}

switch s.authentication {
Expand Down

0 comments on commit 829086a

Please sign in to comment.