diff --git a/docs/generated/sql/functions.md b/docs/generated/sql/functions.md
index 20d0f653770f..e4a8cdc76c12 100644
--- a/docs/generated/sql/functions.md
+++ b/docs/generated/sql/functions.md
@@ -905,7 +905,15 @@ available replica will error.

 to_tsvector(text: string) → tsvector
   Converts text to a tsvector, normalizing words according to the default configuration. Position information is included in the result.
   Volatility: Stable

 ts_parse(parser_name: string, document: string) → tuple{int AS tokid, string AS token}
   ts_parse parses the given document and returns a series of records, one for each token produced by parsing. Each record includes a tokid showing the assigned token type and a token which is the text of the token.
   Volatility: Stable

+ts_rank(vector: tsvector, query: tsquery) → float4
+  Ranks vectors based on the frequency of their matching lexemes.
+  Volatility: Immutable
+
+ts_rank(vector: tsvector, query: tsquery, normalization: int) → float4
+  Ranks vectors based on the frequency of their matching lexemes.
+  Volatility: Immutable
+
+ts_rank(weights: float[], vector: tsvector, query: tsquery) → float4
+  Ranks vectors based on the frequency of their matching lexemes.
+  Volatility: Immutable
+
+ts_rank(weights: float[], vector: tsvector, query: tsquery, normalization: int) → float4
+  Ranks vectors based on the frequency of their matching lexemes.
+  Volatility: Immutable

 ### Fuzzy String Matching functions
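All four documented overloads map onto a single Go entry point, tsearch.Rank, added later in this patch. A minimal sketch of the two-argument form, using the same parse helpers the new rank_test.go uses; the input literals are taken from its test cases:

```go
package main

import (
	"fmt"

	"github.com/cockroachdb/cockroach/pkg/util/tsearch"
)

func main() {
	// ts_rank(v, q) in SQL: nil weights fall back to the defaults
	// {D: 0.1, C: 0.2, B: 0.4, A: 1.0}, and normalization 0 applies no
	// document-length normalization.
	v, err := tsearch.ParseTSVector("a:1 s:2 d g")
	if err != nil {
		panic(err)
	}
	q, err := tsearch.ParseTSQuery("a | s")
	if err != nil {
		panic(err)
	}
	rank, err := tsearch.Rank(nil /* weights */, v, q, 0 /* method */)
	if err != nil {
		panic(err)
	}
	fmt.Println(rank) // 0.06079271, per the rank_test.go expectations below
}
```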
diff --git a/pkg/sql/logictest/testdata/logic_test/tsvector b/pkg/sql/logictest/testdata/logic_test/tsvector
index 33d8a851d593..14c17947e0ad 100644
--- a/pkg/sql/logictest/testdata/logic_test/tsvector
+++ b/pkg/sql/logictest/testdata/logic_test/tsvector
@@ -298,3 +298,39 @@ query T
 SELECT to_tsvector('Hello I am a potato')
 ----
 'am':3 'hell':1 'i':2 'potat':5
+
+query TT
+SELECT to_tsvector('english', ''), to_tsvector('english', 'and the')
+----
+· ·
+
+statement error doesn't contain lexemes
+SELECT to_tsquery('english', 'the')
+
+statement ok
+CREATE TABLE sentences (sentence text, v TSVECTOR AS (to_tsvector('english', sentence)) STORED, INVERTED INDEX (v));
+INSERT INTO sentences VALUES
+ ('Future users of large data banks must be protected from having to know how the data is organized in the machine (the internal representation).'),
+ ('A prompting service which supplies such information is not a satisfactory solution.'),
+ ('Activities of users at terminals and most application programs should remain unaffected when the internal representation of data is changed and even when some aspects of the external representation are changed.'),
+ ('Changes in data representation will often be needed as a result of changes in query, update, and report traffic and natural growth in the types of stored information.'),
+ ('Existing noninferential, formatted data systems provide users with tree-structured files or slightly more general network models of the data.'),
+ ('In Section 1, inadequacies of these models are discussed.'),
+ ('A model based on n-ary relations, a normal form for data base relations, and the concept of a universal data sublanguage are introduced.'),
+ ('In Section 2, certain operations on relations (other than logical inference) are discussed and applied to the problems of redundancy and consistency in the user’s model.')
+
+query FFFFT
+SELECT
+ts_rank(v, query) AS rank,
+ts_rank(ARRAY[0.2, 0.3, 0.5, 0.9]:::FLOAT[], v, query) AS wrank,
+ts_rank(v, query, 2|8) AS nrank,
+ts_rank(ARRAY[0.3, 0.4, 0.6, 0.95]:::FLOAT[], v, query, 1|2|4|8|16|32) AS wnrank,
+v
+FROM sentences, to_tsquery('english', 'relation') query
+WHERE query @@ v
+ORDER BY rank DESC
+LIMIT 10
+----
+0.075990885 0.15198177 0.00042217158 8.555783e-05 'ari':6 'base':3,13 'concept':17 'data':12,21 'form':10 'introduc':24 'model':2 'n':5 'normal':9 'relat':7,14 'sublanguag':22 'univers':20
+0.06079271 0.12158542 0.0003101669 6.095758e-05 '2':3 'appli':15 'certain':4 'consist':22 'discuss':13 'infer':11 'logic':10 'model':27 'oper':5 'problem':18 'redund':20 'relat':7 'section':2 'user':25
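The 2|8 and 1|2|4|8|16|32 normalization arguments in the test above are bitmasks over the options implemented in rank.go later in this patch. Decoded against those constants (values shown literally, since the constants are unexported):

```go
package tsearch

// 2|8 = rankNormLength | rankNormUniq: divide the rank by the document
// length, then by the number of unique words in the document.
const nrankMethod = 0x02 | 0x08

// 1|2|4|8|16|32 sets every flag; 0x04 (rankNormExtdist, the mean harmonic
// distance between extents) only applies to ts_rank_cd and is ignored by
// ts_rank.
const wnrankMethod = 0x01 | 0x02 | 0x04 | 0x08 | 0x10 | 0x20
```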
diff --git a/pkg/sql/sem/builtins/builtins.go b/pkg/sql/sem/builtins/builtins.go
index d3100ea03f49..34b2a1094b6b 100644
--- a/pkg/sql/sem/builtins/builtins.go
+++ b/pkg/sql/sem/builtins/builtins.go
@@ -3765,7 +3765,6 @@ value if you rely on the HLC for accuracy.`,
 	"jsonb_to_tsvector": makeBuiltin(tree.FunctionProperties{UnsupportedWithIssue: 7821, Category: builtinconstants.CategoryFullTextSearch}),
 	"ts_delete": makeBuiltin(tree.FunctionProperties{UnsupportedWithIssue: 7821, Category: builtinconstants.CategoryFullTextSearch}),
 	"ts_filter": makeBuiltin(tree.FunctionProperties{UnsupportedWithIssue: 7821, Category: builtinconstants.CategoryFullTextSearch}),
-	"ts_rank": makeBuiltin(tree.FunctionProperties{UnsupportedWithIssue: 7821, Category: builtinconstants.CategoryFullTextSearch}),
 	"ts_rank_cd": makeBuiltin(tree.FunctionProperties{UnsupportedWithIssue: 7821, Category: builtinconstants.CategoryFullTextSearch}),
 	"ts_rewrite": makeBuiltin(tree.FunctionProperties{UnsupportedWithIssue: 7821, Category: builtinconstants.CategoryFullTextSearch}),
 	"tsquery_phrase": makeBuiltin(tree.FunctionProperties{UnsupportedWithIssue: 7821, Category: builtinconstants.CategoryFullTextSearch}),
diff --git a/pkg/sql/sem/builtins/fixed_oids.go b/pkg/sql/sem/builtins/fixed_oids.go
index 435b16d75e9c..680d96e4a513 100644
--- a/pkg/sql/sem/builtins/fixed_oids.go
+++ b/pkg/sql/sem/builtins/fixed_oids.go
@@ -2373,6 +2373,10 @@ var builtinOidsArray = []string{
 	2399: `to_tsvector(text: string) -> tsvector`,
 	2400: `phraseto_tsquery(text: string) -> tsquery`,
 	2401: `plainto_tsquery(text: string) -> tsquery`,
+	2402: `ts_rank(weights: float[], vector: tsvector, query: tsquery, normalization: int) -> float4`,
+	2403: `ts_rank(vector: tsvector, query: tsquery, normalization: int) -> float4`,
+	2404: `ts_rank(vector: tsvector, query: tsquery) -> float4`,
+	2405: `ts_rank(weights: float[], vector: tsvector, query: tsquery) -> float4`,
 }

 var builtinOidsBySignature map[string]oid.Oid
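The four overloads registered below share one evaluation path, differing only in the defaults they fill in (nil weights, normalization 0). Condensed, the repeated body is essentially the following; rankFn is a hypothetical helper, not part of the patch:

```go
package builtins

import (
	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
	"github.com/cockroachdb/cockroach/pkg/util/tsearch"
)

// rankFn is a hypothetical distillation of the four overload bodies below:
// each one is this call with nil weights and/or method 0 substituted for
// the omitted SQL arguments.
func rankFn(weights []float32, v *tree.DTSVector, q *tree.DTSQuery, method int) (tree.Datum, error) {
	rank, err := tsearch.Rank(weights, v.TSVector, q.TSQuery, method)
	if err != nil {
		return nil, err
	}
	return tree.NewDFloat(tree.DFloat(rank)), nil
}
```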
diff --git a/pkg/sql/sem/builtins/tsearch_builtins.go b/pkg/sql/sem/builtins/tsearch_builtins.go
index e5e3a4633c79..627f1897b9b3 100644
--- a/pkg/sql/sem/builtins/tsearch_builtins.go
+++ b/pkg/sql/sem/builtins/tsearch_builtins.go
@@ -231,4 +231,117 @@ var tsearchBuiltins = map[string]builtinDefinition{
 			Volatility: volatility.Stable,
 		},
 	),
+	"ts_rank": makeBuiltin(
+		tree.FunctionProperties{},
+		tree.Overload{
+			Types: tree.ParamTypes{
+				{Name: "weights", Typ: types.FloatArray},
+				{Name: "vector", Typ: types.TSVector},
+				{Name: "query", Typ: types.TSQuery},
+				{Name: "normalization", Typ: types.Int},
+			},
+			ReturnType: tree.FixedReturnType(types.Float4),
+			Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) {
+				weights, err := getWeights(tree.MustBeDArray(args[0]))
+				if err != nil {
+					return nil, err
+				}
+				rank, err := tsearch.Rank(
+					weights,
+					tree.MustBeDTSVector(args[1]).TSVector,
+					tree.MustBeDTSQuery(args[2]).TSQuery,
+					int(tree.MustBeDInt(args[3])),
+				)
+				if err != nil {
+					return nil, err
+				}
+				return tree.NewDFloat(tree.DFloat(rank)), nil
+			},
+			Info:       "Ranks vectors based on the frequency of their matching lexemes.",
+			Volatility: volatility.Immutable,
+		},
+		tree.Overload{
+			Types: tree.ParamTypes{
+				{Name: "weights", Typ: types.FloatArray},
+				{Name: "vector", Typ: types.TSVector},
+				{Name: "query", Typ: types.TSQuery},
+			},
+			ReturnType: tree.FixedReturnType(types.Float4),
+			Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) {
+				weights, err := getWeights(tree.MustBeDArray(args[0]))
+				if err != nil {
+					return nil, err
+				}
+				rank, err := tsearch.Rank(
+					weights,
+					tree.MustBeDTSVector(args[1]).TSVector,
+					tree.MustBeDTSQuery(args[2]).TSQuery,
+					0, /* method */
+				)
+				if err != nil {
+					return nil, err
+				}
+				return tree.NewDFloat(tree.DFloat(rank)), nil
+			},
+			Info:       "Ranks vectors based on the frequency of their matching lexemes.",
+			Volatility: volatility.Immutable,
+		},
+		tree.Overload{
+			Types: tree.ParamTypes{
+				{Name: "vector", Typ: types.TSVector},
+				{Name: "query", Typ: types.TSQuery},
+				{Name: "normalization", Typ: types.Int},
+			},
+			ReturnType: tree.FixedReturnType(types.Float4),
+			Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) {
+				rank, err := tsearch.Rank(
+					nil, /* weights */
+					tree.MustBeDTSVector(args[0]).TSVector,
+					tree.MustBeDTSQuery(args[1]).TSQuery,
+					int(tree.MustBeDInt(args[2])),
+				)
+				if err != nil {
+					return nil, err
+				}
+				return tree.NewDFloat(tree.DFloat(rank)), nil
+			},
+			Info:       "Ranks vectors based on the frequency of their matching lexemes.",
+			Volatility: volatility.Immutable,
+		},
+		tree.Overload{
+			Types: tree.ParamTypes{
+				{Name: "vector", Typ: types.TSVector},
+				{Name: "query", Typ: types.TSQuery},
+			},
+			ReturnType: tree.FixedReturnType(types.Float4),
+			Fn: func(_ context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) {
+				rank, err := tsearch.Rank(
+					nil, /* weights */
+					tree.MustBeDTSVector(args[0]).TSVector,
+					tree.MustBeDTSQuery(args[1]).TSQuery,
+					0, /* method */
+				)
+				if err != nil {
+					return nil, err
+				}
+				return tree.NewDFloat(tree.DFloat(rank)), nil
+			},
+			Info:       "Ranks vectors based on the frequency of their matching lexemes.",
+			Volatility: volatility.Immutable,
+		},
+	),
+}
+
+func getWeights(arr *tree.DArray) ([]float32, error) {
+	ret := make([]float32, 4)
+	if arr.Len() < len(ret) {
+		return ret, pgerror.New(pgcode.ArraySubscript, "array of weight is too short (must be at least 4)")
+	}
+	for i, d := range arr.Array {
+		if i >= len(ret) {
+			// Like Postgres, ignore any weights beyond the first four.
+			break
+		}
+		if d == tree.DNull {
+			return ret, pgerror.New(pgcode.NullValueNotAllowed, "array of weight must not contain null")
+		}
+		ret[i] = float32(tree.MustBeDFloat(d))
+	}
+	return ret, nil
 }
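Note the ordering getWeights assumes: index 0 holds the weight for D-labeled lexemes and index 3 the weight for A, matching defaultWeights in rank.go and the {D, C, B, A} layout in Postgres's tsrank.c:

```go
package tsearch

// Index-to-label mapping assumed by getWeights and defaultWeights:
//
//	index:  0    1    2    3
//	label:  D    C    B    A
//
// So ARRAY[0.2, 0.3, 0.5, 0.9] in the logic test above means
// D=0.2, C=0.3, B=0.5, A=0.9.
var weightLabels = [4]string{"D", "C", "B", "A"}
```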
diff --git a/pkg/util/tsearch/BUILD.bazel b/pkg/util/tsearch/BUILD.bazel
index cc7cc235500e..01e25c3c2780 100644
--- a/pkg/util/tsearch/BUILD.bazel
+++ b/pkg/util/tsearch/BUILD.bazel
@@ -9,6 +9,7 @@ go_library(
         "eval.go",
         "lex.go",
         "random.go",
+        "rank.go",
         "snowball.go",
         "stopwords.go",
         "tsquery.go",
@@ -63,6 +64,7 @@ go_test(
     srcs = [
         "encoding_test.go",
         "eval_test.go",
+        "rank_test.go",
         "tsquery_test.go",
         "tsvector_test.go",
     ],
diff --git a/pkg/util/tsearch/rank.go b/pkg/util/tsearch/rank.go
new file mode 100644
index 000000000000..757f736af6ec
--- /dev/null
+++ b/pkg/util/tsearch/rank.go
@@ -0,0 +1,286 @@
+// Copyright 2023 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package tsearch
+
+import (
+	"math"
+	"sort"
+	"strings"
+)
+
+// defaultWeights is the default list of weights corresponding to the tsvector
+// lexeme weights D, C, B, and A.
+var defaultWeights = [4]float32{0.1, 0.2, 0.4, 1.0}
+
+// Bitmask for the normalization integer. These define different ranking
+// behaviors. They're defined in Postgres in tsrank.c.
+// 0, the default, ignores the document length.
+// 1 divides the rank by 1 + the logarithm of the document length.
+// 2 divides the rank by the document length.
+// 4 divides the rank by the mean harmonic distance between extents.
+//
+// NOTE: This is only implemented by ts_rank_cd, which is currently not
+// implemented by CockroachDB. This constant is left for consistency with
+// the original PostgreSQL source code.
+//
+// 8 divides the rank by the number of unique words in document.
+// 16 divides the rank by 1 + the logarithm of the number of unique words in document.
+// 32 divides the rank by itself + 1.
+type rankBehavior int
+
+const (
+	// rankNoNorm is the default. It ignores the document length.
+	rankNoNorm rankBehavior = 0x0
+	// rankNormLoglength divides the rank by 1 + the logarithm of the document length.
+	rankNormLoglength = 0x01
+	// rankNormLength divides the rank by the document length.
+	rankNormLength = 0x02
+	// rankNormExtdist divides the rank by the mean harmonic distance between extents.
+	// Note, this is only implemented by ts_rank_cd, which is not currently implemented
+	// by CockroachDB. The constant is kept for consistency with Postgres.
+	rankNormExtdist = 0x04
+	// rankNormUniq divides the rank by the number of unique words in document.
+	rankNormUniq = 0x08
+	// rankNormLoguniq divides the rank by 1 + the logarithm of the number of unique words in document.
+	rankNormLoguniq = 0x10
+	// rankNormRdivrplus1 divides the rank by itself + 1.
+	rankNormRdivrplus1 = 0x20
+)
+
+// Defeat the unused linter.
+var _ = rankNoNorm
+var _ = rankNormExtdist
+
+// cntLen returns the number of lexeme occurrences represented by a tsvector:
+// every position of a repeated lexeme counts, and a lexeme with no position
+// list counts once.
+func cntLen(v TSVector) int {
+	var ret int
+	for i := range v {
+		posLen := len(v[i].positions)
+		if posLen > 0 {
+			ret += posLen
+		} else {
+			ret += 1
+		}
+	}
+	return ret
+}
+
+// Rank implements the ts_rank functionality, which ranks a tsvector against a
+// tsquery. The weights parameter is a list of weights corresponding to the
+// tsvector lexeme weights D, C, B, and A. The method parameter is a bitmask
+// defining different ranking behaviors, defined in the rankBehavior type
+// above in this file. The default ranking behavior is 0, which doesn't perform
+// any normalization based on the document length.
+//
+// N.B.: this function is directly translated from the calc_rank function in
+// tsrank.c, which contains almost no comments. As of this time, I am unable
+// to sufficiently explain how this ranker works, but I'm confident that the
+// implementation is at least compatible with Postgres.
+// https://github.com/postgres/postgres/blob/765f5df726918bcdcfd16bcc5418e48663d1dd59/src/backend/utils/adt/tsrank.c#L357
+func Rank(weights []float32, v TSVector, q TSQuery, method int) (float32, error) {
+	w := defaultWeights
+	if weights != nil {
+		copy(w[:4], weights[:4])
+	}
+	if len(v) == 0 || q.root == nil {
+		return 0, nil
+	}
+	var res float32
+	if q.root.op == and || q.root.op == followedby {
+		res = rankAnd(w, v, q)
+	} else {
+		res = rankOr(w, v, q)
+	}
+	if res < 0 {
+		// This constant is taken from the Postgres source code, unfortunately I
+		// don't understand its meaning.
+		res = 1e-20
+	}
+	if method&rankNormLoglength > 0 {
+		res /= float32(math.Log(float64(cntLen(v)+1)) / math.Log(2.0))
+	}
+
+	if method&rankNormLength > 0 {
+		l := cntLen(v)
+		if l > 0 {
+			res /= float32(l)
+		}
+	}
+	// rankNormExtdist is not applicable - it's only used for ts_rank_cd.
+
+	if method&rankNormUniq > 0 {
+		res /= float32(len(v))
+	}
+
+	if method&rankNormLoguniq > 0 {
+		res /= float32(math.Log(float64(len(v)+1)) / math.Log(2.0))
+	}
+
+	if method&rankNormRdivrplus1 > 0 {
+		res /= res + 1
+	}
+
+	return res, nil
+}
+
+func sortAndDistinctQueryTerms(q TSQuery) []*tsNode {
+	// Extract all leaf nodes from the query tree.
+	leafNodes := make([]*tsNode, 0)
+	var extractTerms func(q *tsNode)
+	extractTerms = func(q *tsNode) {
+		if q == nil {
+			return
+		}
+		if q.op != invalid {
+			extractTerms(q.l)
+			extractTerms(q.r)
+		} else {
+			leafNodes = append(leafNodes, q)
+		}
+	}
+	extractTerms(q.root)
+	// Sort the terms.
+	sort.Slice(leafNodes, func(i, j int) bool {
+		return leafNodes[i].term.lexeme < leafNodes[j].term.lexeme
+	})
+	// Then distinct: (wouldn't it be nice if Go had generics?)
+	lastUniqueIdx := 0
+	for j := 1; j < len(leafNodes); j++ {
+		if leafNodes[j].term.lexeme != leafNodes[lastUniqueIdx].term.lexeme {
+			// We found a unique entry, at index j. The last unique entry in the
+			// array was at lastUniqueIdx, so set the entry after that one to our
+			// new unique entry, and bump lastUniqueIdx for the next loop iteration.
+			lastUniqueIdx++
+			leafNodes[lastUniqueIdx] = leafNodes[j]
+		}
+	}
+	leafNodes = leafNodes[:lastUniqueIdx+1]
+	return leafNodes
+}
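The manual sort-then-compact loop above predates the generic slices helpers the comment wishes for; on Go 1.21+, the same dedup could be written as below (a sketch, not part of the patch):

```go
package tsearch

import (
	"slices"
	"strings"
)

// sortAndCompact is a hypothetical generic equivalent of the manual
// lastUniqueIdx loop in sortAndDistinctQueryTerms (Go 1.21+).
func sortAndCompact(leafNodes []*tsNode) []*tsNode {
	slices.SortFunc(leafNodes, func(a, b *tsNode) int {
		return strings.Compare(a.term.lexeme, b.term.lexeme)
	})
	// CompactFunc drops adjacent duplicates in place, exactly what the
	// manual loop does after sorting.
	return slices.CompactFunc(leafNodes, func(a, b *tsNode) bool {
		return a.term.lexeme == b.term.lexeme
	})
}
```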
+
+// findRankMatches finds all matches for a given query term in a tsvector,
+// regardless of the expected query weight.
+// query is the term being matched. v is the tsvector being searched.
+// matches is a slice of matches to append to, to save on allocations as this
+// function is called in a loop.
+func findRankMatches(query *tsNode, v TSVector, matches [][]tsPosition) [][]tsPosition {
+	target := query.term.lexeme
+	i := sort.Search(len(v), func(i int) bool {
+		return v[i].lexeme >= target
+	})
+	if i >= len(v) {
+		return matches
+	}
+	if query.term.isPrefixMatch() {
+		for j := i; j < len(v); j++ {
+			t := v[j]
+			if !strings.HasPrefix(t.lexeme, target) {
+				break
+			}
+			matches = append(matches, t.positions)
+		}
+	} else if v[i].lexeme == target {
+		matches = append(matches, v[i].positions)
+	}
+	return matches
+}
+
+// rankOr computes the rank for a query with an OR operator at its root.
+// It takes the same parameters as Rank.
+func rankOr(weights [4]float32, v TSVector, q TSQuery) float32 {
+	queryLeaves := sortAndDistinctQueryTerms(q)
+	var matches = make([][]tsPosition, 0)
+	var res float32
+	for i := range queryLeaves {
+		matches = matches[:0]
+		matches = findRankMatches(queryLeaves[i], v, matches)
+		if len(matches) == 0 {
+			continue
+		}
+		resj := float32(0.0)
+		wjm := float32(-1.0)
+		jm := 0
+		for _, innerMatches := range matches {
+			for j, pos := range innerMatches {
+				termWeight := pos.weight.val()
+				weight := weights[termWeight]
+				resj = resj + weight/float32((j+1)*(j+1))
+				if weight > wjm {
+					wjm = weight
+					jm = j
+				}
+			}
+		}
+		// Explanation from Postgres tsrank.c:
+		// limit (sum(1/i^2),i=1,inf) = pi^2/6
+		// resj = sum(wi/i^2),i=1,noccurence,
+		// wi - should be sorted desc,
+		// don't sort for now, just choose maximum weight. This should be corrected
+		// Oleg Bartunov
+		res = res + (wjm+resj-wjm/float32((jm+1)*(jm+1)))/1.64493406685
+	}
+	if len(queryLeaves) > 0 {
+		res /= float32(len(queryLeaves))
+	}
+	return res
+}
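For a term that matches once, the `(wjm+resj-wjm/((jm+1)^2))/1.64493406685` expression collapses to `weight/1.64493...`, where 1.64493406685 is pi^2/6, the limit of the sum of 1/i^2 quoted in the comment. A worked sketch, under the simplifying assumption that all of a term's occurrences share one weight:

```go
package main

import "fmt"

// occurrenceScore mirrors rankOr's inner loop for a single term when every
// occurrence carries the same weight w: the score is w * sum(1/i^2) for
// i = 1..n, normalized by pi^2/6 so it approaches w as n grows.
func occurrenceScore(w float64, n int) float64 {
	s := 0.0
	for i := 1; i <= n; i++ {
		s += w / float64(i*i)
	}
	return s / 1.64493406685
}

func main() {
	// 'a:1 s:2 d g' vs 'a | s': both terms match once at the default
	// D weight 0.1, so each contributes 0.1/1.64493 ≈ 0.0608; averaging
	// over the two query terms keeps 0.0608, matching the 0.06079271
	// expectation in rank_test.go below.
	per := occurrenceScore(0.1, 1)
	fmt.Println(per, (per+per)/2)
}
```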
+
+// rankAnd computes the rank for a query with an AND or followed-by operator at
+// its root. It takes the same parameters as Rank.
+func rankAnd(weights [4]float32, v TSVector, q TSQuery) float32 {
+	queryLeaves := sortAndDistinctQueryTerms(q)
+	if len(queryLeaves) < 2 {
+		return rankOr(weights, v, q)
+	}
+	pos := make([][]tsPosition, len(queryLeaves))
+	res := float32(-1)
+	var matches = make([][]tsPosition, 0)
+	for i := range queryLeaves {
+		matches = matches[:0]
+		matches = findRankMatches(queryLeaves[i], v, matches)
+		for _, innerMatches := range matches {
+			pos[i] = innerMatches
+			// Loop back through the earlier query terms' position matches.
+			for k := 0; k < i; k++ {
+				if pos[k] == nil {
+					continue
+				}
+				for l := range pos[i] {
+					// For each pairing of this term's positions with an earlier
+					// term's positions, accumulate a distance-based weight.
+					for p := range pos[k] {
+						dist := int(pos[i][l].position) - int(pos[k][p].position)
+						if dist < 0 {
+							dist = -dist
+						}
+						if dist != 0 {
+							curw := float32(math.Sqrt(float64(weights[pos[i][l].weight.val()] * weights[pos[k][p].weight.val()] * wordDistance(dist))))
+							if res < 0 {
+								res = curw
+							} else {
+								res = 1.0 - (1.0-res)*(1.0-curw)
+							}
+						}
+					}
+				}
+			}
+		}
+	}
+	return res
+}
+
+// wordDistance returns the weight of a word collocation at the given
+// distance. See Postgres tsrank.c.
+func wordDistance(dist int) float32 {
+	if dist > 100 {
+		return 1e-30
+	}
+	return float32(1.0 / (1.005 + 0.05*math.Exp(float64(float32(dist)/1.5-2))))
+}
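The `1.0-(1.0-res)*(1.0-curw)` accumulation in rankAnd composes pair scores like independent probabilities (a noisy-OR): each additional matching pair can only push the rank closer to 1. Isolated as a sketch:

```go
package tsearch

// combine isolates rankAnd's accumulation step. res starts at the -1
// sentinel, meaning no qualifying pair has been seen yet; afterwards the
// complement product treats res and curw as independent probabilities of a
// match, so the running rank is nondecreasing and bounded by 1.
func combine(res, curw float32) float32 {
	if res < 0 {
		return curw
	}
	return 1.0 - (1.0-res)*(1.0-curw)
}
```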
diff --git a/pkg/util/tsearch/rank_test.go b/pkg/util/tsearch/rank_test.go
new file mode 100644
index 000000000000..900f6e13100f
--- /dev/null
+++ b/pkg/util/tsearch/rank_test.go
@@ -0,0 +1,46 @@
+// Copyright 2023 The Cockroach Authors.
+//
+// Use of this software is governed by the Business Source License
+// included in the file licenses/BSL.txt.
+//
+// As of the Change Date specified in that file, in accordance with
+// the Business Source License, use of this software will be governed
+// by the Apache License, Version 2.0, included in the file
+// licenses/APL.txt.
+
+package tsearch
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestRank(t *testing.T) {
+	tests := []struct {
+		weights  []float32
+		v        string
+		q        string
+		method   int
+		expected float32
+	}{
+		{v: "a:1 s:2C d g", q: "a | s", expected: 0.091189064},
+		{v: "a:1 sa:2C d g", q: "a | s", expected: 0.030396355},
+		{v: "a:1 sa:2C d g", q: "a | s:*", expected: 0.091189064},
+		{v: "a:1 sa:2C d g", q: "a | sa:*", expected: 0.091189064},
+		{v: "a:1 s:2B d g", q: "a | s", expected: 0.15198177},
+		{v: "a:1 s:2 d g", q: "a | s", expected: 0.06079271},
+		{v: "a:1 s:2C d g", q: "a & s", expected: 0.14015312},
+		{v: "a:1 s:2B d g", q: "a & s", expected: 0.19820644},
+		{v: "a:1 s:2 d g", q: "a & s", expected: 0.09910322},
+	}
+	for _, tt := range tests {
+		v, err := ParseTSVector(tt.v)
+		assert.NoError(t, err)
+		q, err := ParseTSQuery(tt.q)
+		assert.NoError(t, err)
+		actual, err := Rank(tt.weights, v, q, tt.method)
+		assert.NoError(t, err)
+		assert.Equalf(t, tt.expected, actual, "Rank(%v, %v, %v, %v)", tt.weights, tt.v, tt.q, tt.method)
+	}
+}
diff --git a/pkg/util/tsearch/tsquery.go b/pkg/util/tsearch/tsquery.go
index bc5c0a7ea76a..761860495911 100644
--- a/pkg/util/tsearch/tsquery.go
+++ b/pkg/util/tsearch/tsquery.go
@@ -499,12 +499,13 @@ func toTSQuery(config string, interpose tsOperator, input string) (TSQuery, erro
 			}
 			tokens = append(tokens, term)
 		}
-		lexeme, ok, err := TSLexize(config, lexemeTokens[j])
+		lexeme, stopWord, err := TSLexize(config, lexemeTokens[j])
 		if err != nil {
 			return TSQuery{}, err
 		}
-		if !ok {
+		if stopWord {
 			foundStopwords = true
+			//continue
 		}
 		tokens = append(tokens, tsTerm{lexeme: lexeme, positions: tok.positions})
 	}
@@ -518,17 +519,20 @@ func toTSQuery(config string, interpose tsOperator, input string) (TSQuery, erro
 	}

 	if foundStopwords {
-		return cleanupStopwords(query)
+		query = cleanupStopwords(query)
+		if query.root == nil {
+			return query, pgerror.Newf(pgcode.Syntax, "text-search query doesn't contain lexemes: %s", input)
+		}
 	}
-	return query, nil
+	return query, err
 }

-func cleanupStopwords(query TSQuery) (TSQuery, error) {
+func cleanupStopwords(query TSQuery) TSQuery {
 	query.root, _, _ = cleanupStopword(query.root)
 	if query.root == nil {
-		return TSQuery{}, nil
+		return TSQuery{}
 	}
-	return query, nil
+	return query
 }

 // cleanupStopword cleans up a query tree by removing stop words and adjusting
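With the TSLexize return parameter flipped (see the tsvector.go hunk below), a stopword now reports stopWord == true rather than ok == false, and a query reduced to nothing but stopwords surfaces the new "doesn't contain lexemes" error exercised in the logic test above. A minimal sketch of the lexize side:

```go
package main

import (
	"fmt"

	"github.com/cockroachdb/cockroach/pkg/util/tsearch"
)

func main() {
	// "the" is an english stopword: empty lexeme, stopWord = true.
	lexeme, stopWord, err := tsearch.TSLexize("english", "the")
	fmt.Println(lexeme, stopWord, err) // "" true <nil>

	// "banks" stems to "bank": a real lexeme, stopWord = false.
	lexeme, stopWord, err = tsearch.TSLexize("english", "banks")
	fmt.Println(lexeme, stopWord, err) // "bank" false <nil>
}
```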
diff --git a/pkg/util/tsearch/tsvector.go b/pkg/util/tsearch/tsvector.go
index f811f9d4144e..3417b5d99809 100644
--- a/pkg/util/tsearch/tsvector.go
+++ b/pkg/util/tsearch/tsvector.go
@@ -121,6 +121,14 @@ func (w tsWeight) TSVectorPGEncoding() (byte, error) {
 	return 0, errors.Errorf("invalid tsvector weight %d", w)
 }

+// val returns the weight as an int, for use as an index into a weights
+// array. It panics if the weight is invalid.
+func (w tsWeight) val() int {
+	b, err := w.TSVectorPGEncoding()
+	if err != nil {
+		panic(err)
+	}
+	return int(b)
+}
+
 // matches returns true if the receiver is matched by the input tsquery weight.
 func (w tsWeight) matches(queryWeight tsWeight) bool {
 	if queryWeight == weightAny {
@@ -287,6 +295,10 @@ func (t tsTerm) matchesWeight(targetWeight tsWeight) bool {
 	return false
 }

+// isPrefixMatch returns true if the term is a prefix match (e.g. 'sa:*').
+func (t tsTerm) isPrefixMatch() bool {
+	return len(t.positions) >= 1 && t.positions[0].weight&weightStar != 0
+}
+
 // TSVector is a sorted list of terms, each of which is a lexeme that might have
 // an associated position within an original document.
 type TSVector []tsTerm
@@ -392,16 +404,16 @@ func TSParse(input string) []string {

 // TSLexize implements the "dictionary" construct that's exposed via ts_lexize.
 // It gets invoked once per input token to produce an output lexeme during
 // routines like to_tsvector and to_tsquery.
-// It can return false in the second parameter to indicate a stopword was found.
-func TSLexize(config string, token string) (lexeme string, notAStopWord bool, err error) {
-	stopwords, notAStopWord := stopwordsMap[config]
-	if !notAStopWord {
+// It can return true in the second parameter to indicate a stopword was found.
+func TSLexize(config string, token string) (lexeme string, stopWord bool, err error) {
+	stopwords, ok := stopwordsMap[config]
+	if !ok {
 		return "", false, pgerror.Newf(pgcode.UndefinedObject, "text search configuration %q does not exist", config)
 	}
 	lower := strings.ToLower(token)
 	if _, ok := stopwords[lower]; ok {
-		return "", false, nil
+		return "", true, nil
 	}
 	stemmer, err := getStemmer(config)
 	if err != nil {
@@ -409,7 +421,7 @@ func TSLexize(config string, token string) (lexeme string, notAStopWord bool, er
 	}
 	env := snowballstem.NewEnv(lower)
 	stemmer(env)
-	return env.Current(), true, nil
+	return env.Current(), false, nil
 }

 // DocumentToTSVector parses an input document into lexemes, removes stop words,
@@ -419,17 +431,18 @@ func DocumentToTSVector(config string, input string) (TSVector, error) {
 	tokens := TSParse(input)
 	vector := make(TSVector, 0, len(tokens))
 	for i := range tokens {
-		lexeme, ok, err := TSLexize(config, tokens[i])
+		lexeme, stopWord, err := TSLexize(config, tokens[i])
 		if err != nil {
 			return nil, err
 		}
-		if !ok {
+		if stopWord {
 			continue
 		}

 		term := tsTerm{lexeme: lexeme}
 		pos := i + 1
 		if i > maxTSVectorPosition {
+			// Postgres silently truncates positions larger than 16383 to 16383.
 			pos = maxTSVectorPosition
 		}
 		term.positions = []tsPosition{{position: uint16(pos)}}
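The new comment documents Postgres-compatible clamping: token positions beyond 16383 are silently pinned rather than rejected (the diff is cut off here, but maxTSVectorPosition is the constant being compared against). A toy illustration of the clamp, with the 16383 value taken from that comment:

```go
package tsearch

// clampPosition mirrors the truncation applied in DocumentToTSVector,
// assuming maxTSVectorPosition == 16383 as the comment above states.
func clampPosition(i int) uint16 {
	const maxTSVectorPosition = 16383
	pos := i + 1 // positions are 1-based
	if i > maxTSVectorPosition {
		pos = maxTSVectorPosition
	}
	return uint16(pos)
}
```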