diff --git a/README.md b/README.md
index 8e658fe..1f7a2d0 100644
--- a/README.md
+++ b/README.md
@@ -262,6 +262,17 @@ p.AllowAttrs("style").OnElements("span", "p")
p.AllowStyles("color").MatchingHandler(myHandler).Globally()
```
+### Callback Function for element's attributes
+
+If you need to add/modify/delete the attributes of a given element you can set a callback function with:
+```go
+SetCallbackForAttributes(func(elementName string, attrs []html.Attribute) []html.Attribute {
+ return attrs
+})
+```
+This function will be called before the element's attributes are parsed. The callback function can add/remove/modify the element's attributes.
+If the callback returns nil or an empty slice of HTML attributes, no attributes will be included in the output for that element.
+
### Links
Links are difficult beasts to sanitise safely and also one of the biggest attack vectors for malicious content.
diff --git a/policy.go b/policy.go
index 1a5e00c..1c49aa9 100644
--- a/policy.go
+++ b/policy.go
@@ -37,6 +37,7 @@ import (
"strings"
"github.com/microcosm-cc/bluemonday/css"
+ "golang.org/x/net/html"
)
// Policy encapsulates the allowlist of HTML elements and attributes that will
@@ -150,6 +151,10 @@ type Policy struct {
// and can lead to XSS being rendered thus defeating the purpose of using a
// HTML sanitizer.
allowUnsafe bool
+
+ // callbackAttr, when non-nil, is a callback function that is called before an element's attributes are parsed. The callback function can add/remove/modify the element's attributes.
+ // If the callback returns nil or an empty slice of html attributes, no attributes are included in the output for that element.
+ callbackAttr callbackAttrFunc
}
type attrPolicy struct {
@@ -190,6 +195,10 @@ type stylePolicyBuilder struct {
handler func(string) bool
}
+// callbackAttrFunc is the signature of the attribute callback stored in Policy.callbackAttr: it receives an element's name and its attributes and returns the attributes to use in their place.
+// If the callback returns nil or an empty slice of html attributes, no attributes are included in the output for that element.
+type callbackAttrFunc = func(elementName string, attrs []html.Attribute) []html.Attribute
+
type urlPolicy func(url *url.URL) (allowUrl bool)
type SandboxValue int64
@@ -241,6 +250,14 @@ func NewPolicy() *Policy {
return &p
}
+// SetCallbackForAttributes registers cb, which is called with each element's name and attributes at the start of attribute sanitization; cb may add, remove, or modify the attributes it returns.
+// If cb returns nil or an empty slice, no attributes are included in the output for that element. Passing a nil cb disables the callback. It returns p to allow chaining.
+// SetCallbackForAttributes is not goroutine safe.
+func (p *Policy) SetCallbackForAttributes(cb func(elementName string, attrs []html.Attribute) []html.Attribute) *Policy {
+ p.callbackAttr = cb
+ return p
+}
+
// AllowAttrs takes a range of HTML attribute names and returns an
// attribute policy builder that allows you to specify the pattern and scope of
// the allowed attribute.
diff --git a/sanitize.go b/sanitize.go
index 1046145..fb0932b 100644
--- a/sanitize.go
+++ b/sanitize.go
@@ -497,6 +497,10 @@ func (p *Policy) sanitizeAttrs(
aps map[string][]attrPolicy,
) []html.Attribute {
+ if p.callbackAttr != nil {
+ attrs = p.callbackAttr(elementName, attrs)
+ }
+
if len(attrs) == 0 {
return attrs
}
diff --git a/sanitize_test.go b/sanitize_test.go
index c13a23c..f667c1e 100644
--- a/sanitize_test.go
+++ b/sanitize_test.go
@@ -37,6 +37,8 @@ import (
"strings"
"sync"
"testing"
+
+ "golang.org/x/net/html"
)
// test is a simple input vs output struct used to construct a slice of many
@@ -3931,3 +3933,80 @@ func TestRemovingEmptySelfClosingTag(t *testing.T) {
expected)
}
}
+
+func TestCallbackForAttributes(t *testing.T) {
+ // Registers an attribute callback on a UGC policy and compares Sanitize output against each table entry.
+ tests := []test{
+ {
+ in: ``,
+ expected: ``,
+ },
+ {
+ in: ``,
+ expected: ``,
+ },
+ {
+ in: ``,
+ expected: ``,
+ },
+ {
+ in: ``,
+ expected: ``,
+ },
+ }
+ // NOTE(review): as shown here every case has empty in/expected, so no callback branch below is exercised — confirm the fixtures.
+ p := UGCPolicy()
+ p.RequireParseableURLs(true)
+ p.RequireNoFollowOnLinks(false)
+ p.RequireNoFollowOnFullyQualifiedLinks(true)
+ p.AddTargetBlankToFullyQualifiedLinks(true)
+
+ p.SetCallbackForAttributes(func(elementName string, attrs []html.Attribute) []html.Attribute {
+ // For img: rewrite src giraffe.gif to giraffe1.gif; return nil (no attributes) when src is new.gif.
+ if elementName == "img" {
+ for i := 0; i < len(attrs); i++ {
+ if attrs[i].Key == "src" && attrs[i].Val == "giraffe.gif" {
+ attrs[i].Val = "giraffe1.gif"
+ break
+ }
+ if attrs[i].Key == "src" && attrs[i].Val == "new.gif" {
+ return nil
+ }
+ }
+ }
+ if elementName == "a" {
+ for i := 0; i < len(attrs); i++ {
+ if attrs[i].Key == "href" && attrs[i].Val == "?q=1" {
+ attrs[i].Val = "?q=2"
+ break
+ }
+ if attrs[i].Key == "href" && attrs[i].Val == "http://www.google.com" {
+ attrs[i].Val = "http://www.google.com/ATTR"
+ break
+ }
+ }
+ }
+ // Any other element, or an unmatched value, is returned as received.
+ return attrs
+ })
+ // These tests are run concurrently to enable the race detector to pick up
+ // potential issues
+ wg := sync.WaitGroup{}
+ wg.Add(len(tests))
+ for ii, tt := range tests {
+ go func(ii int, tt test) {
+ out := p.Sanitize(tt.in)
+ if out != tt.expected {
+ t.Errorf(
+ "test %d failed;\ninput : %s\noutput : %s\nexpected: %s",
+ ii,
+ tt.in,
+ out,
+ tt.expected,
+ )
+ }
+ wg.Done()
+ }(ii, tt)
+ }
+ wg.Wait()
+}