diff --git a/proxy/credentials_issuer_headers.go b/proxy/credentials_issuer_headers.go index c05ad38fc9..17631e8a46 100644 --- a/proxy/credentials_issuer_headers.go +++ b/proxy/credentials_issuer_headers.go @@ -14,12 +14,12 @@ import ( type CredentialsHeadersConfig map[string]string type CredentialsHeaders struct { - rulesCache *template.Template + RulesCache *template.Template } func NewCredentialsIssuerHeaders() *CredentialsHeaders { return &CredentialsHeaders{ - rulesCache: template.New("rules"). + RulesCache: template.New("rules"). Option("missingkey=zero"). Funcs(template.FuncMap{ "print": func(i interface{}) string { @@ -52,9 +52,9 @@ func (a *CredentialsHeaders) Issue(r *http.Request, session *AuthenticationSessi var err error templateId := fmt.Sprintf("%s:%s", rl.ID, hdr) - tmpl = a.rulesCache.Lookup(templateId) + tmpl = a.RulesCache.Lookup(templateId) if tmpl == nil { - tmpl, err = a.rulesCache.New(templateId).Parse(templateString) + tmpl, err = a.RulesCache.New(templateId).Parse(templateString) if err != nil { return errors.Wrapf(err, `error parsing header template "%s" in rule "%s"`, templateString, rl.ID) } diff --git a/proxy/credentials_issuer_headers_test.go b/proxy/credentials_issuer_headers_test.go index cafb448ca5..2dffca5f0e 100644 --- a/proxy/credentials_issuer_headers_test.go +++ b/proxy/credentials_issuer_headers_test.go @@ -1,9 +1,12 @@ package proxy import ( + "bytes" "encoding/json" + "fmt" "net/http" "testing" + "text/template" "github.com/ory/oathkeeper/rule" "github.com/stretchr/testify/assert" @@ -106,4 +109,30 @@ func TestCredentialsIssuerHeaders(t *testing.T) { assert.Equal(t, specs.Match, specs.Request.Header) }) } + + t.Run("Caching", func(t *testing.T) { + for _, specs := range testMap { + issuer := NewCredentialsIssuerHeaders() + + overrideHeaders := http.Header{} + + cache := template.New("rules") + + var cfg CredentialsHeadersConfig + d := json.NewDecoder(bytes.NewBuffer(specs.Config)) + require.NoError(t, d.Decode(&cfg)) + + for hdr, _ := range cfg { + templateId := fmt.Sprintf("%s:%s", specs.Rule.ID, hdr) + cache.New(templateId).Parse("override") + overrideHeaders.Add(hdr, "override") + } + + issuer.RulesCache = cache + + require.NoError(t, issuer.Issue(specs.Request, specs.Session, specs.Config, specs.Rule)) + + assert.Equal(t, overrideHeaders, specs.Request.Header) + } + }) }