diff --git a/README.md b/README.md index 8e658fe..1f7a2d0 100644 --- a/README.md +++ b/README.md @@ -262,6 +262,17 @@ p.AllowAttrs("style").OnElements("span", "p") p.AllowStyles("color").MatchingHandler(myHandler).Globally() ``` +### Callback Function for element's attributes + +If you need to add/modify/delete the attributes of a given element you can use set a callback function with: +```go +SetCallbackForAttributes(func(elementName string, attrs []html.Attribute) []html.Attribute { + return attrs +}) +``` +This function will be called before the element's attributes are parsed. The callback function can add/remove/modify the element's attributes. +If the callback returns nil or empty array of html attributes then the attributes will not be included in the output. + ### Links Links are difficult beasts to sanitise safely and also one of the biggest attack vectors for malicious content. diff --git a/policy.go b/policy.go index 1a5e00c..1c49aa9 100644 --- a/policy.go +++ b/policy.go @@ -37,6 +37,7 @@ import ( "strings" "github.com/microcosm-cc/bluemonday/css" + "golang.org/x/net/html" ) // Policy encapsulates the allowlist of HTML elements and attributes that will @@ -150,6 +151,10 @@ type Policy struct { // and can lead to XSS being rendered thus defeating the purpose of using a // HTML sanitizer. allowUnsafe bool + + //callbackAttr is callback function that will be called before element's attributes are parsed. The callback function can add/remove/modify the element's attributes. + // If the callback returns nil or empty array of html attributes then the attributes will not be included in the output. + callbackAttr callbackAttrFunc } type attrPolicy struct { @@ -190,6 +195,10 @@ type stylePolicyBuilder struct { handler func(string) bool } +// callbackAttrFunc is callback function that will be called before element's attributes are parsed. The callback function can add/remove/modify the element's attributes. +// If the callback returns nil or empty array of html attributes then the attributes will not be included in the output. +type callbackAttrFunc = func(elementName string, attrs []html.Attribute) []html.Attribute + type urlPolicy func(url *url.URL) (allowUrl bool) type SandboxValue int64 @@ -241,6 +250,14 @@ func NewPolicy() *Policy { return &p } +// SetCallbackForAttributes sets the callback function that will be called before element's attributes are parsed. The callback function can add/remove/modify the element's attributes. +// If the callback returns nil or empty array of html attributes then the attributes will not be included in the output. +// SetCallbackForAttributes is not goroutine safe. +func (p *Policy) SetCallbackForAttributes(cb callbackAttrFunc) *Policy { + p.callbackAttr = cb + return p +} + // AllowAttrs takes a range of HTML attribute names and returns an // attribute policy builder that allows you to specify the pattern and scope of // the allowed attribute. diff --git a/sanitize.go b/sanitize.go index 1046145..fb0932b 100644 --- a/sanitize.go +++ b/sanitize.go @@ -497,6 +497,10 @@ func (p *Policy) sanitizeAttrs( aps map[string][]attrPolicy, ) []html.Attribute { + if p.callbackAttr != nil { + attrs = p.callbackAttr(elementName, attrs) + } + if len(attrs) == 0 { return attrs } diff --git a/sanitize_test.go b/sanitize_test.go index c13a23c..f667c1e 100644 --- a/sanitize_test.go +++ b/sanitize_test.go @@ -37,6 +37,8 @@ import ( "strings" "sync" "testing" + + "golang.org/x/net/html" ) // test is a simple input vs output struct used to construct a slice of many @@ -3931,3 +3933,80 @@ func TestRemovingEmptySelfClosingTag(t *testing.T) { expected) } } + +func TestCallbackForAttributes(t *testing.T) { + + tests := []test{ + { + in: ``, + expected: ``, + }, + { + in: ``, + expected: ``, + }, + { + in: ``, + expected: ``, + }, + { + in: ``, + expected: ``, + }, + } + + p := UGCPolicy() + p.RequireParseableURLs(true) + p.RequireNoFollowOnLinks(false) + p.RequireNoFollowOnFullyQualifiedLinks(true) + p.AddTargetBlankToFullyQualifiedLinks(true) + + p.SetCallbackForAttributes(func(elementName string, attrs []html.Attribute) []html.Attribute { + + if elementName == "img" { + for i := 0; i < len(attrs); i++ { + if attrs[i].Key == "src" && attrs[i].Val == "giraffe.gif" { + attrs[i].Val = "giraffe1.gif" + break + } + if attrs[i].Key == "src" && attrs[i].Val == "new.gif" { + return nil + } + } + } + if elementName == "a" { + for i := 0; i < len(attrs); i++ { + if attrs[i].Key == "href" && attrs[i].Val == "?q=1" { + attrs[i].Val = "?q=2" + break + } + if attrs[i].Key == "href" && attrs[i].Val == "http://www.google.com" { + attrs[i].Val = "http://www.google.com/ATTR" + break + } + } + } + + return attrs + }) + // These tests are run concurrently to enable the race detector to pick up + // potential issues + wg := sync.WaitGroup{} + wg.Add(len(tests)) + for ii, tt := range tests { + go func(ii int, tt test) { + out := p.Sanitize(tt.in) + if out != tt.expected { + t.Errorf( + "test %d failed;\ninput : %s\noutput : %s\nexpected: %s", + ii, + tt.in, + out, + tt.expected, + ) + } + wg.Done() + }(ii, tt) + } + wg.Wait() +}