diff --git a/func.go b/func.go index 02942a4..65cd7d9 100644 --- a/func.go +++ b/func.go @@ -240,7 +240,7 @@ func startwithFunc(arg1, arg2 query) func(query, iterator) interface{} { m, n string ok bool ) - switch typ := arg1.Evaluate(t).(type) { + switch typ := functionArgs(arg1).Evaluate(t).(type) { case string: m = typ case query: @@ -252,7 +252,7 @@ func startwithFunc(arg1, arg2 query) func(query, iterator) interface{} { default: panic(errors.New("starts-with() function argument type must be string")) } - n, ok = arg2.Evaluate(t).(string) + n, ok = functionArgs(arg2).Evaluate(t).(string) if !ok { panic(errors.New("starts-with() function argument type must be string")) } @@ -267,7 +267,7 @@ func endwithFunc(arg1, arg2 query) func(query, iterator) interface{} { m, n string ok bool ) - switch typ := arg1.Evaluate(t).(type) { + switch typ := functionArgs(arg1).Evaluate(t).(type) { case string: m = typ case query: @@ -279,7 +279,7 @@ func endwithFunc(arg1, arg2 query) func(query, iterator) interface{} { default: panic(errors.New("ends-with() function argument type must be string")) } - n, ok = arg2.Evaluate(t).(string) + n, ok = functionArgs(arg2).Evaluate(t).(string) if !ok { panic(errors.New("ends-with() function argument type must be string")) } @@ -294,8 +294,7 @@ func containsFunc(arg1, arg2 query) func(query, iterator) interface{} { m, n string ok bool ) - - switch typ := arg1.Evaluate(t).(type) { + switch typ := functionArgs(arg1).Evaluate(t).(type) { case string: m = typ case query: @@ -308,7 +307,7 @@ func containsFunc(arg1, arg2 query) func(query, iterator) interface{} { panic(errors.New("contains() function argument type must be string")) } - n, ok = arg2.Evaluate(t).(string) + n, ok = functionArgs(arg2).Evaluate(t).(string) if !ok { panic(errors.New("contains() function argument type must be string")) } @@ -345,7 +344,7 @@ func normalizespaceFunc(q query, t iterator) interface{} { func substringFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} { return func(q query, t iterator) interface{} { var m string - switch typ := arg1.Evaluate(t).(type) { + switch typ := functionArgs(arg1).Evaluate(t).(type) { case string: m = typ case query: @@ -359,14 +358,14 @@ func substringFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} { var start, length float64 var ok bool - if start, ok = arg2.Evaluate(t).(float64); !ok { + if start, ok = functionArgs(arg2).Evaluate(t).(float64); !ok { panic(errors.New("substring() function first argument type must be int")) } else if start < 1 { panic(errors.New("substring() function first argument type must be >= 1")) } start-- if arg3 != nil { - if length, ok = arg3.Evaluate(t).(float64); !ok { + if length, ok = functionArgs(arg3).Evaluate(t).(float64); !ok { panic(errors.New("substring() function second argument type must be int")) } } @@ -384,7 +383,7 @@ func substringFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} { func substringIndFunc(arg1, arg2 query, after bool) func(query, iterator) interface{} { return func(q query, t iterator) interface{} { var str string - switch v := arg1.Evaluate(t).(type) { + switch v := functionArgs(arg1).Evaluate(t).(type) { case string: str = v case query: @@ -395,7 +394,7 @@ func substringIndFunc(arg1, arg2 query, after bool) func(query, iterator) interf str = node.Value() } var word string - switch v := arg2.Evaluate(t).(type) { + switch v := functionArgs(arg2).Evaluate(t).(type) { case string: word = v case query: @@ -424,7 +423,7 @@ func substringIndFunc(arg1, arg2 query, after bool) func(query, iterator) interf // equal to the number of characters in a given string. func stringLengthFunc(arg1 query) func(query, iterator) interface{} { return func(q query, t iterator) interface{} { - switch v := arg1.Evaluate(t).(type) { + switch v := functionArgs(arg1).Evaluate(t).(type) { case string: return float64(len(v)) case query: @@ -441,9 +440,9 @@ func stringLengthFunc(arg1 query) func(query, iterator) interface{} { // translateFunc is XPath functions translate() function returns a replaced string. func translateFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} { return func(q query, t iterator) interface{} { - str := asString(t, arg1.Evaluate(t)) - src := asString(t, arg2.Evaluate(t)) - dst := asString(t, arg3.Evaluate(t)) + str := asString(t, functionArgs(arg1).Evaluate(t)) + src := asString(t, functionArgs(arg2).Evaluate(t)) + dst := asString(t, functionArgs(arg3).Evaluate(t)) var replace []string for i, s := range src { @@ -460,9 +459,9 @@ func translateFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} { // replaceFunc is XPath functions replace() function returns a replaced string. func replaceFunc(arg1, arg2, arg3 query) func(query, iterator) interface{} { return func(q query, t iterator) interface{} { - str := asString(t, arg1.Evaluate(t)) - src := asString(t, arg2.Evaluate(t)) - dst := asString(t, arg3.Evaluate(t)) + str := asString(t, functionArgs(arg1).Evaluate(t)) + src := asString(t, functionArgs(arg2).Evaluate(t)) + dst := asString(t, functionArgs(arg3).Evaluate(t)) return strings.Replace(str, src, dst, -1) } @@ -488,6 +487,7 @@ func concatFunc(args ...query) func(query, iterator) interface{} { return func(q query, t iterator) interface{} { var a []string for _, v := range args { + v = functionArgs(v) switch v := v.Evaluate(t).(type) { case string: a = append(a, v) @@ -501,3 +501,11 @@ func concatFunc(args ...query) func(query, iterator) interface{} { return strings.Join(a, "") } } + +// https://github.com/antchfx/xpath/issues/43 +func functionArgs(q query) query { + if _, ok := q.(*functionQuery); ok { + return q + } + return q.Clone() +}