From f6e8acdbfed92de85cb6b3e17fc50b4092a22653 Mon Sep 17 00:00:00 2001 From: Jordan Lewis Date: Mon, 11 Jul 2022 00:24:06 -0400 Subject: [PATCH 1/5] util: add tsearch package; tsvector parsing This commit adds a tsearch package which will contain text search algorithms for the tsvector/tsquery full text search capability. It also adds parsing of the tsvector "sub language". Release note: None --- pkg/BUILD.bazel | 4 + pkg/util/tsearch/BUILD.bazel | 29 +++ pkg/util/tsearch/lex.go | 351 ++++++++++++++++++++++++++++++ pkg/util/tsearch/tsvector.go | 201 +++++++++++++++++ pkg/util/tsearch/tsvector_test.go | 119 ++++++++++ 5 files changed, 704 insertions(+) create mode 100644 pkg/util/tsearch/BUILD.bazel create mode 100644 pkg/util/tsearch/lex.go create mode 100644 pkg/util/tsearch/tsvector.go create mode 100644 pkg/util/tsearch/tsvector_test.go diff --git a/pkg/BUILD.bazel b/pkg/BUILD.bazel index bfc9dab23d50..2f24e20288a4 100644 --- a/pkg/BUILD.bazel +++ b/pkg/BUILD.bazel @@ -614,6 +614,7 @@ ALL_TESTS = [ "//pkg/util/tracing:tracing_test", "//pkg/util/treeprinter:treeprinter_test", "//pkg/util/trigram:trigram_test", + "//pkg/util/tsearch:tsearch_test", "//pkg/util/uint128:uint128_test", "//pkg/util/ulid:ulid_test", "//pkg/util/unique:unique_test", @@ -2094,6 +2095,8 @@ GO_TARGETS = [ "//pkg/util/treeprinter:treeprinter_test", "//pkg/util/trigram:trigram", "//pkg/util/trigram:trigram_test", + "//pkg/util/tsearch:tsearch", + "//pkg/util/tsearch:tsearch_test", "//pkg/util/uint128:uint128", "//pkg/util/uint128:uint128_test", "//pkg/util/ulid:ulid", @@ -3046,6 +3049,7 @@ GET_X_DATA_TARGETS = [ "//pkg/util/tracing/zipper:get_x_data", "//pkg/util/treeprinter:get_x_data", "//pkg/util/trigram:get_x_data", + "//pkg/util/tsearch:get_x_data", "//pkg/util/uint128:get_x_data", "//pkg/util/ulid:get_x_data", "//pkg/util/unaccent:get_x_data", diff --git a/pkg/util/tsearch/BUILD.bazel b/pkg/util/tsearch/BUILD.bazel new file mode 100644 index 000000000000..8fff7e2655e0 --- /dev/null +++ 
b/pkg/util/tsearch/BUILD.bazel @@ -0,0 +1,29 @@ +load("//build/bazelutil/unused_checker:unused.bzl", "get_x_data") +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "tsearch", + srcs = [ + "lex.go", + "tsvector.go", + ], + importpath = "github.com/cockroachdb/cockroach/pkg/util/tsearch", + visibility = ["//visibility:public"], + deps = [ + "//pkg/sql/pgwire/pgcode", + "//pkg/sql/pgwire/pgerror", + ], +) + +go_test( + name = "tsearch_test", + srcs = ["tsvector_test.go"], + embed = [":tsearch"], + deps = [ + "//pkg/util/randutil", + "@com_github_stretchr_testify//assert", + "@com_github_stretchr_testify//require", + ], +) + +get_x_data(name = "get_x_data") diff --git a/pkg/util/tsearch/lex.go b/pkg/util/tsearch/lex.go new file mode 100644 index 000000000000..ee0ae85c6713 --- /dev/null +++ b/pkg/util/tsearch/lex.go @@ -0,0 +1,351 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package tsearch + +import ( + "sort" + "strconv" + "unicode" + "unicode/utf8" + + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" +) + +type tsVectorParseState int + +const ( + // Waiting for term (whitespace, ', or any other char) + expectingTerm tsVectorParseState = iota + // Inside of a normal term (single quotes are processed as normal chars) + insideNormalTerm + // Inside of a ' term + insideQuoteTerm + // Finished with ' term (waiting for : or space) + finishedQuoteTerm + // Found a colon (or comma) and expecting a position + expectingPosList + // Finished parsing a position, expecting a comma or whitespace + expectingPosDelimiter +) + +type tsVectorLexer struct { + input string + lastLen int + pos int + state tsVectorParseState + + // If true, we're in "TSQuery lexing mode" + tsquery bool +} + +func (p *tsVectorLexer) back() { + p.pos -= p.lastLen + p.lastLen = 0 +} + +func (p *tsVectorLexer) advance() rune { + r, n := utf8.DecodeRuneInString(p.input[p.pos:]) + p.pos += n + p.lastLen = n + return r +} + +// lex lexes the input in the receiver according to the TSVector "grammar". +// +// A simple TSVector input could look like this: +// +// foo bar:3 baz:3A 'blah :blah' +// +// A TSVector is a list of terms. +// +// Each term is a word and an optional "position list". +// +// A word may be single-quote wrapped, in which case the next term may begin +// without any whitespace in between (if there is no position list on the word). +// In a single-quote wrapped word, the word must terminate with a single quote. +// All other characters are treated as literals. Backslashes can be used to +// escape single quotes, and are otherwise no-ops. +// +// If a word is not single-quote wrapped, the next term will begin if there is +// whitespace after the word. Whitespace and colons may be entered by escaping +// them with backslashes. All other uses of backslashes are no-ops. 
+// +// A word is delimited from its position list with a colon. +// +// A position list is made up of a comma-delimited list of numbers, each +// of which may have an optional "strength" which is a letter from A-D. +// +// See examples in tsvector_test.go. +func (p tsVectorLexer) lex() (TSVector, error) { + // termBuf will be reused as a temporary buffer to assemble each term before + // copying into the vector. + termBuf := make([]rune, 0, 32) + ret := TSVector{} + + for p.pos < len(p.input) { + r := p.advance() + switch p.state { + case expectingTerm: + // Expect either a single quote, a whitespace, or anything else. + if r == '\'' { + p.state = insideQuoteTerm + continue + } + if unicode.IsSpace(r) { + continue + } + + if p.tsquery { + // Check for &, |, !, and <-> (or ) + switch r { + case '&': + ret = append(ret, tsTerm{operator: and}) + continue + case '|': + ret = append(ret, tsTerm{operator: or}) + continue + case '!': + ret = append(ret, tsTerm{operator: not}) + continue + case '(': + ret = append(ret, tsTerm{operator: lparen}) + continue + case ')': + ret = append(ret, tsTerm{operator: rparen}) + continue + case '<': + r = p.advance() + n := 1 + if r == '-' { + r = p.advance() + } else { + for unicode.IsNumber(r) { + termBuf = append(termBuf, r) + r = p.advance() + } + var err error + n, err = strconv.Atoi(string(termBuf)) + termBuf = termBuf[:0] + if err != nil { + return p.syntaxError() + } + } + if r != '>' { + return p.syntaxError() + } + ret = append(ret, tsTerm{operator: followedby, followedN: n}) + continue + } + } + + p.state = insideNormalTerm + // Need to consume the rune we just found again. + p.back() + continue + + case insideQuoteTerm: + // If escaped, eat character and continue. 
+ switch r { + case '\\': + r = p.advance() + termBuf = append(termBuf, r) + continue + case '\'': + ret = append(ret, tsTerm{lexeme: string(termBuf)}) + termBuf = termBuf[:0] + p.state = finishedQuoteTerm + continue + } + termBuf = append(termBuf, r) + case finishedQuoteTerm: + if unicode.IsSpace(r) { + p.state = expectingTerm + } else if r == ':' { + lastTerm := &ret[len(ret)-1] + lastTerm.positions = append(lastTerm.positions, tsposition{}) + p.state = expectingPosList + } else { + p.state = expectingTerm + p.back() + } + case insideNormalTerm: + // If escaped, eat character and continue. + if r == '\\' { + r = p.advance() + termBuf = append(termBuf, r) + continue + } + + if p.tsquery { + switch r { + case '&', '!', '|', '<', '(', ')': + // These are all "operators" in the TSQuery language. End the current + // term and start a new one. + ret = append(ret, tsTerm{lexeme: string(termBuf)}) + termBuf = termBuf[:0] + p.state = expectingTerm + p.back() + continue + } + } + + // Colon that comes first is an ordinary character. + space := unicode.IsSpace(r) + if space || r == ':' && len(termBuf) > 0 { + // Found a terminator. + // Copy the termBuf into the vector, resize the termBuf, continue on. + term := tsTerm{lexeme: string(termBuf)} + if r == ':' { + term.positions = append(term.positions, tsposition{}) + } + ret = append(ret, term) + termBuf = termBuf[:0] + if space { + p.state = expectingTerm + } else { + p.state = expectingPosList + } + continue + } + if p.tsquery && r == ':' { + return p.syntaxError() + } + termBuf = append(termBuf, r) + case expectingPosList: + var pos int + if !p.tsquery { + // If we have nothing in our termBuf, we need to see at least one number. 
+ if unicode.IsNumber(r) { + termBuf = append(termBuf, r) + continue + } + if len(termBuf) == 0 { + return p.syntaxError() + } + var err error + pos, err = strconv.Atoi(string(termBuf)) + if err != nil { + return p.syntaxError() + } + if pos == 0 { + return ret, pgerror.Newf(pgcode.Syntax, "wrong position info in TSVector: %s", p.input) + } + termBuf = termBuf[:0] + } + lastTerm := &ret[len(ret)-1] + lastTermPos := len(lastTerm.positions) - 1 + lastTerm.positions[lastTermPos].position = pos + if unicode.IsSpace(r) { + // Done with our term. Advance to next term! + p.state = expectingTerm + continue + } + switch r { + case ',': + if p.tsquery { + // Not valid! No , allowed in position lists in tsqueries. + return ret, pgerror.Newf(pgcode.Syntax, "syntax error in TSVector: %s", p.input) + } + lastTerm.positions = append(lastTerm.positions, tsposition{}) + // Expecting another number next. + continue + case '*': + if p.tsquery { + lastTerm.positions[lastTermPos].weight |= weightStar + } else { + p.state = expectingPosDelimiter + lastTerm.positions[lastTermPos].weight |= weightA + } + case 'a', 'A': + if !p.tsquery { + p.state = expectingPosDelimiter + } + lastTerm.positions[lastTermPos].weight |= weightA + case 'b', 'B': + if !p.tsquery { + p.state = expectingPosDelimiter + } + lastTerm.positions[lastTermPos].weight |= weightB + case 'c', 'C': + if !p.tsquery { + p.state = expectingPosDelimiter + } + lastTerm.positions[lastTermPos].weight |= weightC + case 'd', 'D': + if p.tsquery { + lastTerm.positions[lastTermPos].weight |= weightD + } else { + p.state = expectingPosDelimiter + } + default: + return p.syntaxError() + } + case expectingPosDelimiter: + if r == ',' { + p.state = expectingPosList + lastTerm := &ret[len(ret)-1] + lastTerm.positions = append(lastTerm.positions, tsposition{}) + } else if unicode.IsSpace(r) { + p.state = expectingTerm + } else { + return p.syntaxError() + } + default: + panic("invalid TSVector lex state") + } + } + // Reached the end of the 
string. + switch p.state { + case insideQuoteTerm: + // Unfinished quote term. + return p.syntaxError() + case insideNormalTerm: + // Finish normal term. + ret = append(ret, tsTerm{lexeme: string(termBuf)}) + case expectingPosList: + // Finish number. + if !p.tsquery { + if len(termBuf) == 0 { + return p.syntaxError() + } + pos, err := strconv.Atoi(string(termBuf)) + if err != nil { + return p.syntaxError() + } + if pos == 0 { + return ret, pgerror.Newf(pgcode.Syntax, "wrong position info in TSVector: %s", p.input) + } + lastTerm := &ret[len(ret)-1] + lastTerm.positions[len(lastTerm.positions)-1].position = pos + } + case expectingTerm, finishedQuoteTerm: + // We are good to go, we just finished a term and nothing needs to be cleaned up. + case expectingPosDelimiter: + // We are good to go, we just finished a position and nothing needs to be cleaned up. + default: + panic("invalid TSVector lex state") + } + for _, t := range ret { + sort.Slice(t.positions, func(i, j int) bool { + return t.positions[i].position < t.positions[j].position + }) + } + return ret, nil +} + +func (p *tsVectorLexer) syntaxError() (TSVector, error) { + typ := "TSVector" + if p.tsquery { + typ = "TSQuery" + } + return TSVector{}, pgerror.Newf(pgcode.Syntax, "syntax error in %s: %s", typ, p.input) +} diff --git a/pkg/util/tsearch/tsvector.go b/pkg/util/tsearch/tsvector.go new file mode 100644 index 000000000000..46746365c625 --- /dev/null +++ b/pkg/util/tsearch/tsvector.go @@ -0,0 +1,201 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package tsearch + +import ( + "fmt" + "sort" + "strconv" + "strings" +) + +// tsweight is a string A, B, C, or D. 
+type tsweight int + +const ( + weightD tsweight = 1 << iota + weightC + weightB + weightA + weightStar +) + +func (w tsweight) String() string { + var ret strings.Builder + if w&weightStar != 0 { + ret.WriteByte('*') + } + if w&weightA != 0 { + ret.WriteByte('A') + } + if w&weightB != 0 { + ret.WriteByte('B') + } + if w&weightC != 0 { + ret.WriteByte('C') + } + if w&weightD != 0 { + ret.WriteByte('D') + } + return ret.String() +} + +type tsposition struct { + position int + weight tsweight +} + +type tsoperator int + +const ( + // Parentheses can be used to control nesting of the TSQuery operators. + // Without parentheses, | binds least tightly, + // then &, then <->, and ! most tightly. + + invalid tsoperator = iota + and // binary + or // binary + not // unary prefix + followedby // binary + lparen // grouping + rparen // grouping +) + +func (o tsoperator) precedence() int { + switch o { + case not: + return 4 + case followedby: + return 3 + case and: + return 2 + case or: + return 1 + } + panic(fmt.Sprintf("no precedence for operator %d", o)) +} + +type tsTerm struct { + // A term is either a lexeme and position list, or an operator. + lexeme string + positions []tsposition + + operator tsoperator + // Set only when operator = followedby + followedN int +} + +func (t tsTerm) String() string { + if t.operator != 0 { + switch t.operator { + case and: + return "&" + case or: + return "|" + case not: + return "!" + case lparen: + return "(" + case rparen: + return ")" + case followedby: + if t.followedN == 1 { + return "<->" + } + return fmt.Sprintf("<%d>", t.followedN) + } + } + + var buf strings.Builder + buf.WriteByte('\'') + for _, r := range t.lexeme { + if r == '\'' { + // Single quotes are escaped as double single quotes inside of a TSVector. 
+ buf.WriteString(`''`) + } else { + buf.WriteRune(r) + } + } + buf.WriteByte('\'') + for i, pos := range t.positions { + if i > 0 { + buf.WriteByte(',') + } else { + buf.WriteByte(':') + } + if pos.position > 0 { + buf.WriteString(strconv.Itoa(pos.position)) + } + buf.WriteString(pos.weight.String()) + } + return buf.String() +} + +// TSVector is a sorted list of terms, each of which is a lexeme that might have +// an associated position within an original document. +type TSVector []tsTerm + +func (t TSVector) String() string { + var buf strings.Builder + for i, term := range t { + if i > 0 { + buf.WriteByte(' ') + } + buf.WriteString(term.String()) + } + return buf.String() +} + +// ParseTSVector produces a TSVector from an input string. The input will be +// sorted by lexeme, but will not be automatically stemmed or stop-worded. +func ParseTSVector(input string) (TSVector, error) { + parser := tsVectorLexer{ + input: input, + state: expectingTerm, + } + ret, err := parser.lex() + if err != nil { + return ret, err + } + + if len(ret) > 1 { + // Sort and de-duplicate the resultant TSVector. + sort.Slice(ret, func(i, j int) bool { + return ret[i].lexeme < ret[j].lexeme + }) + // Then distinct: (wouldn't it be nice if Go had generics?) + lastUniqueIdx := 0 + for j := 1; j < len(ret); j++ { + if ret[j].lexeme != ret[lastUniqueIdx].lexeme { + // We found a unique entry, at index i. The last unique entry in the + // array was at lastUniqueIdx, so set the entry after that one to our + // new unique entry, and bump lastUniqueIdx for the next loop iteration. + // First, sort and unique the position list now that we've collapsed all + // of the identical lexemes. + ret[lastUniqueIdx].positions = sortAndUniqTSPositions(ret[lastUniqueIdx].positions) + lastUniqueIdx++ + ret[lastUniqueIdx] = ret[j] + } else { + // The last entries were not unique. Collapse their positions into the + // first entry's list. 
+ ret[lastUniqueIdx].positions = append(ret[lastUniqueIdx].positions, ret[j].positions...) + } + } + ret = ret[:lastUniqueIdx+1] + } + if len(ret) >= 1 { + // Make sure to sort and uniq the position list even if there's only 1 + // entry. + latsIdx := len(ret) - 1 + ret[latsIdx].positions = sortAndUniqTSPositions(ret[latsIdx].positions) + } + return ret, nil +} diff --git a/pkg/util/tsearch/tsvector_test.go b/pkg/util/tsearch/tsvector_test.go new file mode 100644 index 000000000000..f8bd7f373947 --- /dev/null +++ b/pkg/util/tsearch/tsvector_test.go @@ -0,0 +1,119 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package tsearch + +import ( + "testing" + + "github.com/cockroachdb/cockroach/pkg/util/randutil" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseTSVector(t *testing.T) { + for _, tc := range []struct { + input string + expectedStr string + }{ + {``, ``}, + {`foo`, `'foo'`}, + {`foo bar`, `'bar' 'foo'`}, + {`foo bar`, `'bar' 'foo'`}, + {`foo +bar `, `'bar' 'foo'`}, + {`foo` + "\t" + `bar `, `'bar' 'foo'`}, + {`foo ` + "\t" + `bar `, `'bar' 'foo'`}, + {`foo:3`, `'foo':3`}, + {`foo:03`, `'foo':3`}, + {`foo:3a`, `'foo':3A`}, + {`foo:3A`, `'foo':3A`}, + {`foo:3D`, `'foo':3`}, + {`foo:3D`, `'foo':3`}, + {`foo:30A`, `'foo':30A`}, + {`foo:20,30b,1a`, `'foo':1A,20,30B`}, + {`foo:3 bar`, `'bar' 'foo':3`}, + {`foo:3a bar`, `'bar' 'foo':3A`}, + {`foo:3a bar:100`, `'bar':100 'foo':3A`}, + {`:foo`, `':foo'`}, + {`:`, `':'`}, + {`\:`, `':'`}, + {`'\:'`, `':'`}, + {`'\ '`, `' '`}, + {`\ `, `' '`}, + {`:3`, `':3'`}, + {`::3`, `':':3`}, + {`:3:3`, `':3':3`}, + + {`blah'blah`, 
`'blah''blah'`}, + {`blah'`, `'blah'''`}, + {`blah''`, `'blah'''''`}, + {`blah'':3`, `'blah''''':3`}, + {`blah\'':3`, `'blah''''':3`}, + {`b\:lah\::3`, `'b:lah:':3`}, + {`b\ lah\ :3`, `'b lah ':3`}, + {`b\lah:3`, `'blah':3`}, + {`b\lah\:`, `'blah:'`}, + + {`'blah'`, `'blah'`}, + {`'blah':3`, `'blah':3`}, + {`'bla\'h'`, `'bla''h'`}, + {`'bla\h'`, `'blah'`}, + {`'bla\ h'`, `'bla h'`}, + {`'bla h'`, `'bla h'`}, + {`'blah'arg:3`, `'arg':3 'blah'`}, + {`'blah'arg`, `'arg' 'blah'`}, + + // Test that we sort and de-duplicate inputs (both lexemes and position + // lists). + {`foo:3,2,1`, `'foo':1,2,3`}, + {`foo:3,1`, `'foo':1,3`}, + {`foo:3,2,1 foo:1,2,3`, `'foo':1,2,3`}, + {`a:3 b:2 a:1`, `'a':1,3 'b':2`}, + } { + t.Log(tc.input) + vec, err := ParseTSVector(tc.input) + require.NoError(t, err) + actual := vec.String() + assert.Equal(t, tc.expectedStr, actual) + } +} + +func TestParseTSVectorError(t *testing.T) { + for _, tc := range []string{ + `foo:`, + `foo:a`, + `foo:ab`, + `foo:0`, + `foo:0,3`, + `foo:1f`, + `foo:1f blah`, + `foo:1,`, + `foo:1, a`, + `foo:1, `, + `foo:\1`, + `foo:-1`, + `'foo`, + } { + t.Log(tc) + _, err := ParseTSVector(tc) + assert.Error(t, err) + } +} + +func TestParseTSRandom(t *testing.T) { + r, _ := randutil.NewTestRand() + for i := 0; i < 10000; i++ { + b := randutil.RandBytes(r, randutil.RandIntInRange(r, 0, 50)) + // Just make sure we don't panic + _, _ = ParseTSVector(string(b)) + _, _ = ParseTSQuery(string(b)) + } +} From 56ffe66776b0ee0582b50ba3566110816ff84386 Mon Sep 17 00:00:00 2001 From: Jordan Lewis Date: Mon, 11 Jul 2022 00:25:27 -0400 Subject: [PATCH 2/5] tsearch: add parsing for tsquery sublanguage This commit adds lexing and parsing for the tsquery sublanguage, which is similar to tsvector but with a few different bells and whistles. The tsquery is lexed and then parsed into a simple AST. 
Release note: None --- pkg/util/tsearch/BUILD.bazel | 8 +- pkg/util/tsearch/tsquery.go | 224 ++++++++++++++++++++++++++ pkg/util/tsearch/tsquery_test.go | 265 +++++++++++++++++++++++++++++++ 3 files changed, 496 insertions(+), 1 deletion(-) create mode 100644 pkg/util/tsearch/tsquery.go create mode 100644 pkg/util/tsearch/tsquery_test.go diff --git a/pkg/util/tsearch/BUILD.bazel b/pkg/util/tsearch/BUILD.bazel index 8fff7e2655e0..d953811d9522 100644 --- a/pkg/util/tsearch/BUILD.bazel +++ b/pkg/util/tsearch/BUILD.bazel @@ -5,6 +5,7 @@ go_library( name = "tsearch", srcs = [ "lex.go", + "tsquery.go", "tsvector.go", ], importpath = "github.com/cockroachdb/cockroach/pkg/util/tsearch", @@ -17,10 +18,15 @@ go_library( go_test( name = "tsearch_test", - srcs = ["tsvector_test.go"], + srcs = [ + "tsquery_test.go", + "tsvector_test.go", + ], embed = [":tsearch"], deps = [ + "//pkg/testutils/skip", "//pkg/util/randutil", + "@com_github_jackc_pgx_v4//:pgx", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", ], diff --git a/pkg/util/tsearch/tsquery.go b/pkg/util/tsearch/tsquery.go new file mode 100644 index 000000000000..32e67cf75c8b --- /dev/null +++ b/pkg/util/tsearch/tsquery.go @@ -0,0 +1,224 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package tsearch + +import ( + "fmt" + "strings" + + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" + "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" +) + +// tsNode represents a single AST node within the tree of a TSQuery. +type tsNode struct { + // Only one of term, op, or paren will be set. + // If term is set, this is a leaf node containing a lexeme. 
+ term tsTerm + // If op is set, this is an operator node: either not, and, or, or followedby. + op tsoperator + // set only when op is followedby. Indicates the number n within the + // operator, which means the number of terms separating the left and the right + // argument. + followedN int + + // l is the left child of the node if op is set, or the only child if + // op is set to "not". + l *tsNode + // r is the right child of the node if op is set. + r *tsNode +} + +func (n tsNode) String() string { + return n.infixString(0) +} + +func (n tsNode) infixString(parentPrecedence int) string { + if n.op == invalid { + return n.term.String() + } + var s strings.Builder + prec := n.op.precedence() + needParen := prec < parentPrecedence + if needParen { + s.WriteString("( ") + } + switch n.op { + case not: + fmt.Fprintf(&s, "!%s", n.l.infixString(prec)) + default: + fmt.Fprintf(&s, "%s %s %s", + n.l.infixString(prec), + tsTerm{operator: n.op, followedN: n.followedN}, + n.r.infixString(prec), + ) + } + if needParen { + s.WriteString(" )") + } + return s.String() +} + +// UnambiguousString returns a string representation of this tsNode that wraps +// all expressions with parentheses. It's just for testing. +func (n tsNode) UnambiguousString() string { + switch n.op { + case invalid: + return n.term.lexeme + case not: + return fmt.Sprintf("!%s", n.l.UnambiguousString()) + } + return fmt.Sprintf("[%s%s%s]", n.l.UnambiguousString(), tsTerm{operator: n.op, followedN: n.followedN}, n.r.UnambiguousString()) +} + +// TSQuery represents a tsquery AST root. A tsquery is a tree of text search +// operators that can be run against a tsvector to produce a predicate of +// whether the query matched. 
+type TSQuery struct { + root *tsNode +} + +func (q TSQuery) String() string { + if q.root == nil { + return "" + } + return q.root.String() +} + +func lexTSQuery(input string) (TSVector, error) { + parser := tsVectorLexer{ + input: input, + state: expectingTerm, + tsquery: true, + } + + return parser.lex() +} + +// ParseTSQuery produces a TSQuery from an input string. +func ParseTSQuery(input string) (TSQuery, error) { + terms, err := lexTSQuery(input) + if err != nil { + return TSQuery{}, err + } + + // Now create the operator tree. + queryParser := tsQueryParser{terms: terms, input: input} + return queryParser.parse() +} + +type tsQueryParser struct { + input string + terms TSVector +} + +func (p tsQueryParser) peek() (*tsTerm, bool) { + if len(p.terms) == 0 { + return nil, false + } + return &p.terms[0], true +} + +func (p *tsQueryParser) nextTerm() (*tsTerm, bool) { + if len(p.terms) == 0 { + return nil, false + } + ret := &p.terms[0] + p.terms = p.terms[1:] + return ret, true +} + +func (p *tsQueryParser) parse() (TSQuery, error) { + expr, err := p.parseTSExpr(0) + if err != nil { + return TSQuery{}, err + } + if len(p.terms) > 0 { + _, err := p.syntaxError() + return TSQuery{}, err + } + return TSQuery{root: expr}, nil +} + +func (p *tsQueryParser) parseTSExpr(minBindingPower int) (*tsNode, error) { + t, ok := p.nextTerm() + if !ok { + return nil, pgerror.Newf(pgcode.Syntax, "text-search query doesn't contain lexemes: %s", p.input) + } + + // First section: grab either atoms, nots, or parens. 
+ var lExpr *tsNode + switch t.operator { + case invalid: + lExpr = &tsNode{term: *t} + case lparen: + expr, err := p.parseTSExpr(0) + if err != nil { + return nil, err + } + t, ok := p.nextTerm() + if !ok || t.operator != rparen { + return p.syntaxError() + } + lExpr = expr + case not: + t, ok := p.nextTerm() + if !ok { + return p.syntaxError() + } + switch t.operator { + case invalid: + lExpr = &tsNode{op: not, l: &tsNode{term: *t}} + case lparen: + expr, err := p.parseTSExpr(0) + if err != nil { + return nil, err + } + lExpr = &tsNode{op: not, l: expr} + t, ok := p.nextTerm() + if !ok || t.operator != rparen { + return p.syntaxError() + } + default: + return p.syntaxError() + } + default: + return p.syntaxError() + } + + // Now we do our "Pratt parser loop". + for { + next, ok := p.peek() + if !ok { + return lExpr, nil + } + switch next.operator { + case and, or, followedby: + default: + return lExpr, nil + } + precedence := next.operator.precedence() + if precedence < minBindingPower { + break + } + p.nextTerm() + rExpr, err := p.parseTSExpr(precedence) + if err != nil { + return nil, err + } + lExpr = &tsNode{op: next.operator, followedN: next.followedN, l: lExpr, r: rExpr} + } + return lExpr, nil +} + +func (p *tsQueryParser) syntaxError() (*tsNode, error) { + return nil, pgerror.Newf(pgcode.Syntax, "syntax error in TSQuery: %s", p.input) +} diff --git a/pkg/util/tsearch/tsquery_test.go b/pkg/util/tsearch/tsquery_test.go new file mode 100644 index 000000000000..cce8b046a115 --- /dev/null +++ b/pkg/util/tsearch/tsquery_test.go @@ -0,0 +1,265 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package tsearch + +import ( + "context" + "testing" + + "github.com/cockroachdb/cockroach/pkg/testutils/skip" + "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestScanTSQuery(t *testing.T) { + for _, tc := range []struct { + input string + expectedStr string + }{ + {`foo&bar`, `'foo' & 'bar'`}, + + {``, ``}, + {`foo`, `'foo'`}, + {`foo bar`, `'foo' 'bar'`}, + {`foo bar`, `'foo' 'bar'`}, + {`foo +bar `, `'foo' 'bar'`}, + {`foo` + "\t" + `bar `, `'foo' 'bar'`}, + {`foo ` + "\t" + `bar `, `'foo' 'bar'`}, + + {`foo:a`, `'foo':A`}, + {`foo:ab`, `'foo':AB`}, + {`foo:ba`, `'foo':AB`}, + {`foo:dbaba`, `'foo':ABD`}, + {`foo:*`, `'foo':*`}, + {`foo:cab*cccdba`, `'foo':*ABCD`}, + + {`\:`, `':'`}, + {`'\:'`, `':'`}, + {`'\ '`, `' '`}, + {`\ `, `' '`}, + + {`blah'blah`, `'blah''blah'`}, + {`blah'`, `'blah'''`}, + {`blah''`, `'blah'''''`}, + {`b\lah\:`, `'blah:'`}, + + {`'blah'`, `'blah'`}, + {`'bla\'h'`, `'bla''h'`}, + {`'bla\h'`, `'blah'`}, + {`'bla\ h'`, `'bla h'`}, + {`'bla h'`, `'bla h'`}, + {`'blah'arg`, `'blah' 'arg'`}, + + {`foo&bar`, `'foo' & 'bar'`}, + {`foo\&bar`, `'foo&bar'`}, + {`\&bar`, `'&bar'`}, + {`'b&ar'`, `'b&ar'`}, + {`foo |bar`, `'foo' | 'bar'`}, + {`foo!bar`, `'foo' ! 
'bar'`}, + {`foo<->bar`, `'foo' <-> 'bar'`}, + {`foo<1>bar`, `'foo' <-> 'bar'`}, + {`foo<2>bar`, `'foo' <2> 'bar'`}, + {`foo<0>bar`, `'foo' <0> 'bar'`}, + {`foo<20>bar`, `'foo' <20> 'bar'`}, + {`foo<020>bar`, `'foo' <20> 'bar'`}, + + {`()`, `( )`}, + {`f(a)b`, `'f' ( 'a' ) 'b'`}, + {`'(' ')'`, `'(' ')'`}, + } { + t.Log(tc.input) + vec, err := lexTSQuery(tc.input) + require.NoError(t, err) + actual := vec.String() + assert.Equal(t, tc.expectedStr, actual) + } +} + +func TestScanTSQueryError(t *testing.T) { + for _, tc := range []string{ + `foo:3`, + `foo:3A`, + `foo:3D`, + `foo:3D`, + `foo:30A`, + `foo:30A,20,1a`, + `foo:3 bar`, + `foo:3a bar`, + `foo:3a bar:100`, + + `:`, + `:3`, + `::3`, + `:3:3`, + `:foo`, + + `blah'':3`, + `blah\'':3`, + `b\:lah\::3`, + `b\ lah\ :3`, + `b\lah:3`, + `'blah':3`, + `'blah'arg:3`, + + `a<`, + `a<-x`, + `a<-0`, + `a`, + } { + t.Log(tc) + _, err := lexTSQuery(tc) + assert.Error(t, err) + } +} + +func TestParseTSQuery(t *testing.T) { + tcs := []struct { + input string + expectedStr string + }{ + {`foo`, `'foo'`}, + + {`foo:a`, `'foo':A`}, + {`foo:ab`, `'foo':AB`}, + {`foo:ba`, `'foo':AB`}, + {`foo:dbaba`, `'foo':ABD`}, + {`foo:*`, `'foo':*`}, + {`foo:cab*cccdba`, `'foo':*ABCD`}, + + {`\:`, `':'`}, + {`'\:'`, `':'`}, + {`'\ '`, `' '`}, + {`\ `, `' '`}, + + {`blah'blah`, `'blah''blah'`}, + {`blah'`, `'blah'''`}, + {`blah''`, `'blah'''''`}, + {`b\lah\:`, `'blah:'`}, + + {`'blah'`, `'blah'`}, + {`'bla\'h'`, `'bla''h'`}, + {`'bla\h'`, `'blah'`}, + {`'bla\ h'`, `'bla h'`}, + {`'bla h'`, `'bla h'`}, + + {`foo&bar`, `'foo' & 'bar'`}, + {`foo\&bar`, `'foo&bar'`}, + {`\&bar`, `'&bar'`}, + {`'b&ar'`, `'b&ar'`}, + {`foo |bar`, `'foo' | 'bar'`}, + {`foo<->bar`, `'foo' <-> 'bar'`}, + {`foo<1>bar`, `'foo' <-> 'bar'`}, + {`foo<2>bar`, `'foo' <2> 'bar'`}, + {`foo<0>bar`, `'foo' <0> 'bar'`}, + {`foo<20>bar`, `'foo' <20> 'bar'`}, + {`foo<020>bar`, `'foo' <20> 'bar'`}, + {`foo<->bar<->baz`, `'foo' <-> 'bar' <-> 'baz'`}, + {`foo<->bar&baz|qux`, `'foo' <-> 
'bar' & 'baz' | 'qux'`}, + {`(foo)`, `'foo'`}, + {`((foo))`, `'foo'`}, + + {`!(a | !b)`, `!( 'a' | !'b' )`}, + {`!(a | (b & c))`, `!( 'a' | 'b' & 'c' )`}, + {`a & b & c`, `'a' & 'b' & 'c'`}, + {`a & b & !c`, `'a' & 'b' & !'c'`}, + {`d <0> !x | y`, `'d' <0> !'x' | 'y'`}, + {`a <-> (d <0> !x | y)`, `'a' <-> ( 'd' <0> !'x' | 'y' )`}, + {`(a) | b`, `'a' | 'b'`}, + {`!a | b`, `!'a' | 'b'`}, + + {`(b & c) <-> (d <0> !x | y)`, `( 'b' & 'c' ) <-> ( 'd' <0> !'x' | 'y' )`}, + {`a | (b & c) <-> (d <0> !x | y)`, `'a' | ( 'b' & 'c' ) <-> ( 'd' <0> !'x' | 'y' )`}, + {`!(a | (b & c) <-> (d <0> !x | y))`, `!( 'a' | ( 'b' & 'c' ) <-> ( 'd' <0> !'x' | 'y' ) )`}, + } + for _, tc := range tcs { + t.Log(tc.input) + query, err := ParseTSQuery(tc.input) + assert.NoError(t, err) + actual := query.String() + assert.Equal(t, tc.expectedStr, actual) + } + + t.Run("ComparePG", func(t *testing.T) { + skip.IgnoreLint(t, "need to manually enable") + conn, err := pgx.Connect(context.Background(), "postgresql://jordan@localhost:5432") + require.NoError(t, err) + for _, tc := range tcs { + t.Log(tc) + + var actual string + row := conn.QueryRow(context.Background(), "SELECT $1::TSQuery", tc.input) + require.NoError(t, row.Scan(&actual)) + assert.Equal(t, tc.expectedStr, actual) + } + }) +} + +func TestParseTSQueryAssociativity(t *testing.T) { + for _, tc := range []struct { + input string + expectedStr string + }{ + {`a<->b<->c`, `[a<->[b<->c]]`}, + {`a|b|c`, `[a|[b|c]]`}, + {`a&b&c`, `[a&[b&c]]`}, + {`a<->b&c|d`, `[[[a<->b]&c]|d]`}, + {`a&b<->c`, `[a&[b<->c]]`}, + {`a|b&c<->d`, `[a|[b&[c<->d]]]`}, + {`a|b<->c&d`, `[a|[[b<->c]&d]]`}, + } { + t.Log(tc.input) + query, err := ParseTSQuery(tc.input) + assert.NoError(t, err) + actual := query.root.UnambiguousString() + assert.Equal(t, tc.expectedStr, actual) + } +} + +func TestParseTSQueryError(t *testing.T) { + for _, tc := range []string{ + ``, + ` `, + ` | `, + `&`, + `<->`, + `<0>`, + `!`, + `(`, + `)`, + `(foo))`, + `(foo`, + `foo )`, + `foo bar`, + 
`foo(bar)`, + `foo!bar`, + `&foo`, + `|foo`, + `<->foo`, + `foo&`, + `foo|`, + `foo!`, + `foo<->`, + `()`, + `(foo bar)`, + `f(a)b`, + } { + t.Log(tc) + _, err := ParseTSQuery(tc) + assert.Error(t, err) + } +} From 2e510690dcfbefb8bafb039d662a641462fd62a9 Mon Sep 17 00:00:00 2001 From: Jordan Lewis Date: Mon, 11 Jul 2022 00:26:15 -0400 Subject: [PATCH 3/5] tsearch: add eval package The eval package contains a function to evaluate a tsquery against a tsvector, in memory. Release note: None --- pkg/util/tsearch/BUILD.bazel | 3 + pkg/util/tsearch/eval.go | 312 ++++++++++++++++++++++++++++++++++ pkg/util/tsearch/eval_test.go | 137 +++++++++++++++ 3 files changed, 452 insertions(+) create mode 100644 pkg/util/tsearch/eval.go create mode 100644 pkg/util/tsearch/eval_test.go diff --git a/pkg/util/tsearch/BUILD.bazel b/pkg/util/tsearch/BUILD.bazel index d953811d9522..af90c2c62508 100644 --- a/pkg/util/tsearch/BUILD.bazel +++ b/pkg/util/tsearch/BUILD.bazel @@ -4,6 +4,7 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "tsearch", srcs = [ + "eval.go", "lex.go", "tsquery.go", "tsvector.go", @@ -13,12 +14,14 @@ go_library( deps = [ "//pkg/sql/pgwire/pgcode", "//pkg/sql/pgwire/pgerror", + "@com_github_cockroachdb_errors//:errors", ], ) go_test( name = "tsearch_test", srcs = [ + "eval_test.go", "tsquery_test.go", "tsvector_test.go", ], diff --git a/pkg/util/tsearch/eval.go b/pkg/util/tsearch/eval.go new file mode 100644 index 000000000000..e2344ecdfc91 --- /dev/null +++ b/pkg/util/tsearch/eval.go @@ -0,0 +1,312 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. 
+ +package tsearch + +import ( + "math" + "sort" + "strings" + + "github.com/cockroachdb/errors" +) + +// EvalTSQuery runs the provided TSQuery against the provided TSVector, +// returning whether or not the query matches the vector. +func EvalTSQuery(q TSQuery, v TSVector) (bool, error) { + evaluator := tsEvaluator{ + v: v, + q: q, + } + return evaluator.eval() +} + +type tsEvaluator struct { + v TSVector + q TSQuery +} + +func (e *tsEvaluator) eval() (bool, error) { + return e.evalNode(e.q.root) +} + +func (e *tsEvaluator) evalNode(node *tsNode) (bool, error) { + switch node.op { + case invalid: + prefixMatch := false + if len(node.term.positions) > 0 && node.term.positions[0].weight == weightStar { + prefixMatch = true + } + + // To evaluate a term, we search the vector for a match. + target := node.term.lexeme + i := sort.Search(len(e.v), func(i int) bool { + return e.v[i].lexeme >= target + }) + if i < len(e.v) { + t := e.v[i] + if prefixMatch { + return strings.HasPrefix(t.lexeme, target), nil + } + return t.lexeme == target, nil + } + return false, nil + case and: + l, err := e.evalNode(node.l) + if err != nil || !l { + return false, err + } + return e.evalNode(node.r) + case or: + l, err := e.evalNode(node.l) + if err != nil { + return false, err + } + if l { + return true, nil + } + return e.evalNode(node.r) + case not: + ret, err := e.evalNode(node.l) + return !ret, err + case followedby: + positions, err := e.evalWithinFollowedBy(node) + return len(positions.positions) > 0, err + } + return false, errors.AssertionFailedf("invalid operator %d", node.op) +} + +type tspositionset struct { + positions []tsposition + width int + invert bool +} + +type emitMode int + +const ( + emitMatches emitMode = 1 << iota + emitLeftUnmatched + emitRightUnmatched +) + +func (e *tsEvaluator) evalFollowedBy( + lPositions, rPositions tspositionset, lOffset, rOffset int, emitMode emitMode, +) (tspositionset, error) { + // Followed by makes sure that two terms are separated by 
exactly n words. + // First, find all slots that match for the left expression. + + // Find the offsetted intersection of 2 sorted integer lists, using the + // followedN as the offset. + var ret tspositionset + var lIdx, rIdx int + // Loop through the two sorted position lists, until the position on the + // right is as least as large as the position on the left. + for { + lExhausted := lIdx >= len(lPositions.positions) + rExhausted := rIdx >= len(rPositions.positions) + if lExhausted && rExhausted { + break + } + var lPos, rPos int + if !lExhausted { + lPos = lPositions.positions[lIdx].position + lOffset + } else { + // Quit unless we're outputting all of the RHS, which we will if we have + // a negative match on the LHS. + if emitMode&emitRightUnmatched == 0 { + break + } + lPos = math.MaxInt64 + } + if !rExhausted { + rPos = rPositions.positions[rIdx].position + rOffset + } else { + // Quit unless we're outputting all of the LHS, which we will if we have + // a negative match on the RHS. + if emitMode&emitLeftUnmatched == 0 { + break + } + rPos = math.MaxInt64 + } + + if lPos < rPos { + if emitMode&emitLeftUnmatched > 0 { + ret.positions = append(ret.positions, tsposition{position: lPos}) + } + lIdx++ + } else if lPos == rPos { + if emitMode&emitMatches > 0 { + ret.positions = append(ret.positions, tsposition{position: rPos}) + } + lIdx++ + rIdx++ + } else { + if emitMode&emitRightUnmatched > 0 { + ret.positions = append(ret.positions, tsposition{position: rPos}) + } + rIdx++ + } + } + return ret, nil +} + +// evalWithinFollowedBy is the evaluator for subexpressions of a followed by +// operator. Instead of just returning true or false, and possibly short +// circuiting on boolean ops, we need to return all of the tspositions at which +// each arm of the followed by expression matches. 
+func (e *tsEvaluator) evalWithinFollowedBy(node *tsNode) (tspositionset, error) { + switch node.op { + case invalid: + prefixMatch := false + if len(node.term.positions) > 0 && node.term.positions[0].weight == weightStar { + prefixMatch = true + } + + // To evaluate a term, we search the vector for a match. + target := node.term.lexeme + i := sort.Search(len(e.v), func(i int) bool { + return e.v[i].lexeme >= target + }) + if i >= len(e.v) { + // No match. + return tspositionset{}, nil + } + var ret []tsposition + if prefixMatch { + for j := i; j < len(e.v); j++ { + t := e.v[j] + if !strings.HasPrefix(t.lexeme, target) { + break + } + ret = append(ret, t.positions...) + } + ret = sortAndUniqTSPositions(ret) + return tspositionset{positions: ret}, nil + } else if e.v[i].lexeme != target { + // No match. + return tspositionset{}, nil + } + return tspositionset{positions: e.v[i].positions}, nil + case or: + var lOffset, rOffset, width int + + lPositions, err := e.evalWithinFollowedBy(node.l) + if err != nil { + return tspositionset{}, err + } + rPositions, err := e.evalWithinFollowedBy(node.r) + if err != nil { + return tspositionset{}, err + } + + width = lPositions.width + if rPositions.width > width { + width = rPositions.width + } + lOffset = width - lPositions.width + rOffset = width - rPositions.width + + mode := emitMatches | emitLeftUnmatched | emitRightUnmatched + invertResults := false + switch { + case lPositions.invert && rPositions.invert: + invertResults = true + mode = emitMatches + case lPositions.invert: + invertResults = true + mode = emitLeftUnmatched + case rPositions.invert: + invertResults = true + mode = emitRightUnmatched + } + ret, err := e.evalFollowedBy(lPositions, rPositions, lOffset, rOffset, mode) + if invertResults { + ret.invert = true + } + ret.width = width + return ret, err + case not: + ret, err := e.evalWithinFollowedBy(node.l) + if err != nil { + return tspositionset{}, err + } + ret.invert = !ret.invert + return ret, nil + case 
followedby: + // Followed by and and have similar handling. + fallthrough + case and: + var lOffset, rOffset, width int + + lPositions, err := e.evalWithinFollowedBy(node.l) + if err != nil { + return tspositionset{}, err + } + rPositions, err := e.evalWithinFollowedBy(node.r) + if err != nil { + return tspositionset{}, err + } + if node.op == followedby { + lOffset = node.followedN + rPositions.width + width = lOffset + lPositions.width + } else { + width = lPositions.width + if rPositions.width > width { + width = rPositions.width + } + lOffset = width - lPositions.width + rOffset = width - rPositions.width + } + + mode := emitMatches + invertResults := false + switch { + case lPositions.invert && rPositions.invert: + invertResults = true + mode |= emitLeftUnmatched | emitRightUnmatched + case lPositions.invert: + mode = emitRightUnmatched + case rPositions.invert: + mode = emitLeftUnmatched + } + ret, err := e.evalFollowedBy(lPositions, rPositions, lOffset, rOffset, mode) + if invertResults { + ret.invert = true + } + ret.width = width + return ret, err + } + return tspositionset{}, errors.AssertionFailedf("invalid operator %d", node.op) +} + +// sortAndUniqTSPositions sorts and uniquifies the input tsposition list by +// their position attributes. +func sortAndUniqTSPositions(pos []tsposition) []tsposition { + if len(pos) <= 1 { + return pos + } + sort.Slice(pos, func(i, j int) bool { + return pos[i].position < pos[j].position + }) + // Then distinct: (wouldn't it be nice if Go had generics?) + lastUniqueIdx := 0 + for j := 1; j < len(pos); j++ { + if pos[j].position != pos[lastUniqueIdx].position { + // We found a unique entry, at index i. The last unique entry in the array + // was at lastUniqueIdx, so set the entry after that one to our new unique + // entry, and bump lastUniqueIdx for the next loop iteration. 
+ lastUniqueIdx++ + pos[lastUniqueIdx] = pos[j] + } + } + pos = pos[:lastUniqueIdx+1] + return pos +} diff --git a/pkg/util/tsearch/eval_test.go b/pkg/util/tsearch/eval_test.go new file mode 100644 index 000000000000..2c87e4219a86 --- /dev/null +++ b/pkg/util/tsearch/eval_test.go @@ -0,0 +1,137 @@ +// Copyright 2022 The Cockroach Authors. +// +// Use of this software is governed by the Business Source License +// included in the file licenses/BSL.txt. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0, included in the file +// licenses/APL.txt. + +package tsearch + +import ( + "context" + "testing" + + "github.com/cockroachdb/cockroach/pkg/testutils/skip" + "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEval(t *testing.T) { + tcs := []struct { + query string + vector string + expected bool + }{ + {`a`, `a:10`, true}, + {`a`, ``, false}, + {`b`, `a:10`, false}, + {`!a`, `a:10`, false}, + {`!b`, `a:10`, true}, + {`!b`, ``, true}, + {`c`, `a:10 b:3 c:7`, true}, + {`a|b`, `a:10 b:3 c:7`, true}, + {`c|d`, `a:10 b:3 c:7`, true}, + {`d|e`, `a:10 b:3 c:7`, false}, + {`a&b`, `a:10 b:3 c:7`, true}, + {`c&d`, `a:10 b:3 c:7`, false}, + {`d&e`, `a:10 b:3 c:7`, false}, + {`a&!b`, `a:10 b:3 c:7`, false}, + {`c&!d`, `a:10 b:3 c:7`, true}, + {`d&!e`, `a:10 b:3 c:7`, false}, + {`!d&!e`, `a:10 b:3 c:7`, true}, + + // Tests for prefix matching. + {`a:*`, `ar:10`, true}, + {`ar:*`, `ar:10`, true}, + {`arg:*`, `ar:10`, false}, + {`b:*`, `ar:10`, false}, + {`a:* & ar:*`, `ar:10`, true}, + {`a:* & ar:* & arg:*`, `ar:10`, false}, + + // Tests for followed-by. 
+ {`a <-> b`, `a:1 b:2`, true}, + {`a <-> b`, `a:2 b:1`, false}, + {`a <-> b`, `a:1 b:3`, false}, + {`a <-> b <-> c`, `a:1 b:2 c:3`, true}, + {`(a <-> b) <-> c`, `a:1 b:2 c:3`, true}, + {`a <2> b`, `a:1 b:3`, true}, + {`a:* <0> ab:*`, `abba:1`, true}, + {`a:* <0> ab:*`, `a:1`, false}, + + // Negations. + {`a <-> !b`, `a:1 b:2`, false}, + {`a <-> !b`, `a:1 c:2`, true}, + {`a <-> !b`, `a:1,3 c:2 b:4`, true}, + {`a <-> !b`, `a:1,3 b:2,4`, false}, + {`a <-> !b`, `b:2 a:3`, true}, + {`!a <-> b`, `a:1 b:2`, false}, + {`!a <-> b`, `a:1 c:2`, false}, + {`!a <-> b`, `a:1,3 c:2 b:4,5`, true}, + {`!a <-> b`, `a:1,3 b:2,4`, false}, + {`!a <-> b`, `b:2 a:3`, true}, + {`!a <-> !b`, `b:2 a:3`, true}, + {`!a <-> !b`, `a:3 b:4`, true}, + {`!a <-> !b`, `a:3`, true}, + + //{`!a <-> !b`, ``, true}, + // Or on the RHS of a follows-by. + {`a <-> (b|c)`, `a:1 b:2`, true}, + {`a <-> (b|c)`, `a:1 c:2`, true}, + {`a <-> (b|c)`, `a:1 d:2`, false}, + {`a <-> (!b|c)`, `a:1 b:2`, false}, + {`a <-> (!b|c)`, `a:1 c:2`, true}, + {`a <-> (!b|c)`, `a:1 d:2`, true}, + {`a <-> (b|!c)`, `a:1 b:2`, true}, + {`a <-> (b|!c)`, `a:1 c:2`, false}, + {`a <-> (b|!c)`, `a:1 d:2`, true}, + {`a <-> (!b|!c)`, `a:1 b:2`, true}, + {`a <-> (!b|!c)`, `a:1 c:2`, true}, + {`a <-> (!b|!c)`, `a:1 d:2`, true}, + // And on the RHS of a follows-by. 
+ {`a <-> (b&c)`, `a:1 b:2`, false}, + {`a <-> (b&c)`, `a:1 c:2`, false}, + {`a <-> (b&c)`, `a:1 d:2`, false}, + {`a <-> (!b&c)`, `a:1 b:2`, false}, + {`a <-> (!b&c)`, `a:1 c:2`, true}, + {`a <-> (!b&c)`, `a:1 d:2`, false}, + {`a <-> (b&!c)`, `a:1 b:2`, true}, + {`a <-> (b&!c)`, `a:1 c:2`, false}, + {`a <-> (b&!c)`, `a:1 d:2`, false}, + {`a <-> (!b&!c)`, `a:1 b:2`, false}, + {`a <-> (!b&!c)`, `a:1 c:2`, false}, + {`a <-> (!b&!c)`, `a:1 d:2`, true}, + } + for _, tc := range tcs { + t.Log(tc) + q, err := ParseTSQuery(tc.query) + require.NoError(t, err) + v, err := ParseTSVector(tc.vector) + require.NoError(t, err) + eval, err := EvalTSQuery(q, v) + require.NoError(t, err) + + assert.Equal(t, tc.expected, eval) + } + + // This subtest runs all the test cases against PG to ensure the behavior is + // the same. + t.Run("ComparePG", func(t *testing.T) { + skip.IgnoreLint(t, "need to manually enable") + conn, err := pgx.Connect(context.Background(), "postgresql://jordan@localhost:5432") + require.NoError(t, err) + for _, tc := range tcs { + t.Log(tc) + + var actual bool + row := conn.QueryRow(context.Background(), "SELECT $1::TSQuery @@ $2::TSVector", + tc.query, tc.vector, + ) + require.NoError(t, row.Scan(&actual)) + assert.Equal(t, tc.expected, actual) + } + }) +} From c637e6467f062d34a78f2938cd3c5724ab8cd5ec Mon Sep 17 00:00:00 2001 From: Kai Sun Date: Mon, 10 Oct 2022 09:04:25 -0400 Subject: [PATCH 4/5] allocatorimpl: Prioritize non-voters in voter additions Previously, when adding a voter via the allocator, we would not prioritize stores with non-voters when selecting candidate stores. This was inadequate because selecting a store without a non-voter would require sending a snapshot over the WAN in order to add a voter onto that store. On the other hand, selecting a store with a non-voter would only require promoting that non-voter to a voter, no snapshot needs to be sent over the WAN. 
To address this, this patch determines if candidate stores have a non-voter replica and prioritize this status when sorting candidate stores (priority lower than balance score, but higher than range count). By prioritizing selecting a store with a non-voter, we can reduce the number of snapshots that need to be sent over the WAN. Release note (ops change): We prioritized non-voters in voter additions, meaning that when selecting a store to add a voter on (in the allocator), we would prioritize candidate stores that contain a non-voter replica higher. This allows us to reduce the number of snapshots that need to be sent over the WAN. --- .../allocator/allocatorimpl/allocator.go | 2 + .../allocatorimpl/allocator_scorer.go | 56 ++++-- .../allocator/allocatorimpl/allocator_test.go | 118 +++++++++++ pkg/kv/kvserver/replicate_queue_test.go | 190 +++++++++++++++++- 4 files changed, 347 insertions(+), 19 deletions(-) diff --git a/pkg/kv/kvserver/allocator/allocatorimpl/allocator.go b/pkg/kv/kvserver/allocator/allocatorimpl/allocator.go index 454ba69f2d10..10b73e73eb39 100644 --- a/pkg/kv/kvserver/allocator/allocatorimpl/allocator.go +++ b/pkg/kv/kvserver/allocator/allocatorimpl/allocator.go @@ -1028,10 +1028,12 @@ func (a *Allocator) AllocateTargetFromList( candidateStores, constraintsChecker, existingReplicaSet, + existingNonVoters, a.StorePool.GetLocalitiesByStore(existingReplicaSet), a.StorePool.IsStoreReadyForRoutineReplicaTransfer, allowMultipleReplsPerNode, options, + targetType, ) log.KvDistribution.VEventf(ctx, 3, "allocate %s: %s", targetType, candidates) diff --git a/pkg/kv/kvserver/allocator/allocatorimpl/allocator_scorer.go b/pkg/kv/kvserver/allocator/allocatorimpl/allocator_scorer.go index 8528b20b0c75..930368d20324 100644 --- a/pkg/kv/kvserver/allocator/allocatorimpl/allocator_scorer.go +++ b/pkg/kv/kvserver/allocator/allocatorimpl/allocator_scorer.go @@ -584,15 +584,16 @@ type candidate struct { l0SubLevels int convergesScore int balanceScore balanceStatus + 
hasNonVoter bool rangeCount int details string } func (c candidate) String() string { str := fmt.Sprintf("s%d, valid:%t, fulldisk:%t, necessary:%t, diversity:%.2f, highReadAmp: %t, l0SubLevels: %d, converges:%d, "+ - "balance:%d, rangeCount:%d, queriesPerSecond:%.2f", + "balance:%d, hasNonVoter:%t, rangeCount:%d, queriesPerSecond:%.2f", c.store.StoreID, c.valid, c.fullDisk, c.necessary, c.diversityScore, c.highReadAmp, c.l0SubLevels, c.convergesScore, - c.balanceScore, c.rangeCount, c.store.Capacity.QueriesPerSecond) + c.balanceScore, c.hasNonVoter, c.rangeCount, c.store.Capacity.QueriesPerSecond) if c.details != "" { return fmt.Sprintf("%s, details:(%s)", str, c.details) } @@ -640,54 +641,60 @@ func (c candidate) less(o candidate) bool { // candidate is. func (c candidate) compare(o candidate) float64 { if !o.valid { - return 60 + return 600 } if !c.valid { - return -60 + return -600 } if o.fullDisk { - return 50 + return 500 } if c.fullDisk { - return -50 + return -500 } if c.necessary != o.necessary { if c.necessary { - return 40 + return 400 } - return -40 + return -400 } if !scoresAlmostEqual(c.diversityScore, o.diversityScore) { if c.diversityScore > o.diversityScore { - return 30 + return 300 } - return -30 + return -300 } // If both o and c have high read amplification, then we prefer the // canidate with lower read amp. 
if o.highReadAmp && c.highReadAmp { if o.l0SubLevels > c.l0SubLevels { - return 25 + return 250 } } if c.highReadAmp { - return -25 + return -250 } if o.highReadAmp { - return 25 + return 250 } if c.convergesScore != o.convergesScore { if c.convergesScore > o.convergesScore { - return 2 + float64(c.convergesScore-o.convergesScore)/10.0 + return 200 + float64(c.convergesScore-o.convergesScore)/10.0 } - return -(2 + float64(o.convergesScore-c.convergesScore)/10.0) + return -(200 + float64(o.convergesScore-c.convergesScore)/10.0) } if c.balanceScore != o.balanceScore { if c.balanceScore > o.balanceScore { - return 1 + (float64(c.balanceScore-o.balanceScore))/10.0 + return 150 + (float64(c.balanceScore-o.balanceScore))/10.0 } - return -(1 + (float64(o.balanceScore-c.balanceScore))/10.0) + return -(150 + (float64(o.balanceScore-c.balanceScore))/10.0) + } + if c.hasNonVoter != o.hasNonVoter { + if c.hasNonVoter { + return 100 + } + return -100 } // Sometimes we compare partially-filled in candidates, e.g. those with // diversity scores filled in but not balance scores or range counts. 
This @@ -736,6 +743,7 @@ func (c byScoreAndID) Less(i, j int) bool { if scoresAlmostEqual(c[i].diversityScore, c[j].diversityScore) && c[i].convergesScore == c[j].convergesScore && c[i].balanceScore == c[j].balanceScore && + c[i].hasNonVoter == c[j].hasNonVoter && c[i].rangeCount == c[j].rangeCount && c[i].necessary == c[j].necessary && c[i].fullDisk == c[j].fullDisk && @@ -770,7 +778,8 @@ func (cl candidateList) best() candidateList { if cl[i].necessary == cl[0].necessary && scoresAlmostEqual(cl[i].diversityScore, cl[0].diversityScore) && cl[i].convergesScore == cl[0].convergesScore && - cl[i].balanceScore == cl[0].balanceScore { + cl[i].balanceScore == cl[0].balanceScore && + cl[i].hasNonVoter == cl[0].hasNonVoter { continue } return cl[:i] @@ -938,13 +947,16 @@ func rankedCandidateListForAllocation( candidateStores storepool.StoreList, constraintsCheck constraintsCheckFn, existingReplicas []roachpb.ReplicaDescriptor, + nonVoterReplicas []roachpb.ReplicaDescriptor, existingStoreLocalities map[roachpb.StoreID]roachpb.Locality, isStoreValidForRoutineReplicaTransfer func(context.Context, roachpb.StoreID) bool, allowMultipleReplsPerNode bool, options ScorerOptions, + targetType TargetReplicaType, ) candidateList { var candidates candidateList existingReplTargets := roachpb.MakeReplicaSet(existingReplicas).ReplicationTargets() + var nonVoterReplTargets []roachpb.ReplicationTarget for _, s := range candidateStores.Stores { // Disregard all the stores that already have replicas. 
if StoreHasReplica(s.StoreID, existingReplTargets) { @@ -1009,6 +1021,13 @@ func rankedCandidateListForAllocation( }).balanceScore(candidateStores, s.Capacity) } } + var hasNonVoter bool + if targetType == VoterTarget { + if nonVoterReplTargets == nil { + nonVoterReplTargets = roachpb.MakeReplicaSet(nonVoterReplicas).ReplicationTargets() + } + hasNonVoter = StoreHasReplica(s.StoreID, nonVoterReplTargets) + } candidates = append(candidates, candidate{ store: s, valid: constraintsOK, @@ -1016,6 +1035,7 @@ func rankedCandidateListForAllocation( diversityScore: diversityScore, convergesScore: convergesScore, balanceScore: balanceScore, + hasNonVoter: hasNonVoter, rangeCount: int(s.Capacity.RangeCount), }) } diff --git a/pkg/kv/kvserver/allocator/allocatorimpl/allocator_test.go b/pkg/kv/kvserver/allocator/allocatorimpl/allocator_test.go index 5531ef96c355..34c4d50859ce 100644 --- a/pkg/kv/kvserver/allocator/allocatorimpl/allocator_test.go +++ b/pkg/kv/kvserver/allocator/allocatorimpl/allocator_test.go @@ -3409,10 +3409,12 @@ func TestAllocateCandidatesExcludeNonReadyNodes(t *testing.T) { sl, allocationConstraintsChecker, existingRepls, + nil, a.StorePool.GetLocalitiesByStore(existingRepls), a.StorePool.IsStoreReadyForRoutineReplicaTransfer, false, /* allowMultipleReplsPerNode */ a.ScorerOptions(ctx), + VoterTarget, ) if !expectedStoreIDsMatch(tc.expected, candidates) { @@ -3751,10 +3753,12 @@ func TestAllocateCandidatesNumReplicasConstraints(t *testing.T) { sl, checkFn, existingRepls, + nil, a.StorePool.GetLocalitiesByStore(existingRepls), func(context.Context, roachpb.StoreID) bool { return true }, false, /* allowMultipleReplsPerNode */ a.ScorerOptions(ctx), + VoterTarget, ) best := candidates.best() match := true @@ -8178,3 +8182,117 @@ func initTestStores(testStores []testStore, firstRangeSize int64, firstStoreQPS // Initialize the cluster with a single range. 
testStores[0].add(firstRangeSize, firstStoreQPS) } + +func TestNonVoterPrioritizationInVoterAdditions(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + stores := []*roachpb.StoreDescriptor{ + { + StoreID: 1, + Node: roachpb.NodeDescriptor{NodeID: 1}, + Capacity: roachpb.StoreCapacity{Capacity: 200, Available: 100, RangeCount: 514}, + }, + { + StoreID: 2, + Node: roachpb.NodeDescriptor{NodeID: 2}, + Capacity: roachpb.StoreCapacity{Capacity: 200, Available: 100, RangeCount: 520}, + }, + { + StoreID: 3, + Node: roachpb.NodeDescriptor{NodeID: 3}, + Capacity: roachpb.StoreCapacity{Capacity: 200, Available: 100, RangeCount: 470}, + }, + { + StoreID: 4, + Node: roachpb.NodeDescriptor{ + NodeID: 4, + Attrs: roachpb.Attributes{Attrs: []string{"a"}}, + }, + Capacity: roachpb.StoreCapacity{Capacity: 200, Available: 100, RangeCount: 480}, + }, + { + StoreID: 5, + Node: roachpb.NodeDescriptor{NodeID: 5}, + Capacity: roachpb.StoreCapacity{Capacity: 200, Available: 100, RangeCount: 516}, + }, + { + StoreID: 6, + Node: roachpb.NodeDescriptor{NodeID: 6}, + Capacity: roachpb.StoreCapacity{Capacity: 200, Available: 100, RangeCount: 500}, + }, + } + + ctx := context.Background() + stopper, g, _, a, _ := CreateTestAllocator(ctx, 10, false /* deterministic */) + defer stopper.Stop(ctx) + gossiputil.NewStoreGossiper(g).GossipStores(stores, t) + + testCases := []struct { + existingVoters []roachpb.ReplicaDescriptor + existingNonVoters []roachpb.ReplicaDescriptor + spanConfig roachpb.SpanConfig + expectedTargetAllocate roachpb.ReplicationTarget + }{ + // NB: Store 5 has a non-voter and range count 516 and store 4 can add a + // voter and has range count 480. Allocator should select store 5 with the + // non-voter despite having higher range count than store 4 with the voter. 
+ { + existingVoters: []roachpb.ReplicaDescriptor{ + {NodeID: 1, StoreID: 1}, + {NodeID: 2, StoreID: 2}, + {NodeID: 3, StoreID: 3}, + }, + existingNonVoters: []roachpb.ReplicaDescriptor{ + {NodeID: 5, StoreID: 5}, + }, + spanConfig: emptySpanConfig(), + expectedTargetAllocate: roachpb.ReplicationTarget{NodeID: 5, StoreID: 5}, + }, + // NB: Store 3 can add a voter and has range count 470 (underfull), and + // stores 4 and 5 have non-voters and range counts > 475 (not underfull). + // Allocator should select store 3 with the lower balance score despite + // not having a non-voter like stores 4 and 5. + { + existingVoters: []roachpb.ReplicaDescriptor{ + {NodeID: 1, StoreID: 1}, + {NodeID: 2, StoreID: 2}, + }, + existingNonVoters: []roachpb.ReplicaDescriptor{ + {NodeID: 4, StoreID: 4}, + {NodeID: 5, StoreID: 5}, + }, + spanConfig: emptySpanConfig(), + expectedTargetAllocate: roachpb.ReplicationTarget{NodeID: 3, StoreID: 3}, + }, + // NB: Store 4 can add a voter and satisfies the voter constraints, and + // store 5 has a non-voter but does not satisfy the voter constraints. + // Allocator should select store 4 that satisfies the voter constraints + // despite not having a non-voter like store 5. 
+ { + existingVoters: []roachpb.ReplicaDescriptor{ + {NodeID: 1, StoreID: 1}, + {NodeID: 2, StoreID: 2}, + {NodeID: 3, StoreID: 3}, + }, + existingNonVoters: []roachpb.ReplicaDescriptor{ + {NodeID: 5, StoreID: 5}, + }, + spanConfig: roachpb.SpanConfig{ + VoterConstraints: []roachpb.ConstraintsConjunction{ + { + Constraints: []roachpb.Constraint{ + {Value: "a", Type: roachpb.Constraint_REQUIRED}, + }, + }, + }, + }, + expectedTargetAllocate: roachpb.ReplicationTarget{NodeID: 4, StoreID: 4}, + }, + } + + for i, tc := range testCases { + result, _, _ := a.AllocateVoter(ctx, tc.spanConfig, tc.existingVoters, tc.existingNonVoters, Alive) + assert.Equal(t, tc.expectedTargetAllocate, result, "Unexpected replication target returned by allocate voter in test %d", i) + } +} diff --git a/pkg/kv/kvserver/replicate_queue_test.go b/pkg/kv/kvserver/replicate_queue_test.go index a9df99fa0dbb..6a7d9bd77753 100644 --- a/pkg/kv/kvserver/replicate_queue_test.go +++ b/pkg/kv/kvserver/replicate_queue_test.go @@ -1517,7 +1517,7 @@ func filterRangeLog( eventType kvserverpb.RangeLogEventType, reason kvserverpb.RangeLogEventReason, ) ([]kvserverpb.RangeLogEvent_Info, error) { - return queryRangeLog(conn, `SELECT info FROM system.rangelog WHERE "rangeID" = $1 AND "eventType" = $2 AND info LIKE concat('%', $3, '%');`, rangeID, eventType.String(), reason) + return queryRangeLog(conn, `SELECT info FROM system.rangelog WHERE "rangeID" = $1 AND "eventType" = $2 AND info LIKE concat('%', $3, '%') ORDER BY timestamp ASC;`, rangeID, eventType.String(), reason) } func toggleReplicationQueues(tc *testcluster.TestCluster, active bool) { @@ -1964,3 +1964,191 @@ func TestReplicateQueueAcquiresInvalidLeases(t *testing.T) { return nil }) } + +func iterateOverAllStores( + t *testing.T, tc *testcluster.TestCluster, f func(*kvserver.Store) error, +) { + for _, server := range tc.Servers { + require.NoError(t, server.Stores().VisitStores(f)) + } +} + +// TestPromoteNonVoterInAddVoter tests the prioritization 
of promoting +// non-voters when switching from ZONE to REGION survival i.e. +// +// ZONE survival configuration: +// Region 1: Voter, Voter, Voter +// Region 2: Non-Voter +// Region 3: Non-Voter +// to REGION survival configuration: +// Region 1: Voter, Voter +// Region 2: Voter, Voter +// Region 3: Voter +// +// Here we have 7 stores: 3 in Region 1, 2 in Region 2, and 2 in Region 3. +// +// The expected behaviour is that there should not be any add voter events in +// the range log where the added replica type is a LEARNER. +func TestPromoteNonVoterInAddVoter(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // This test is slow under stress and can time out when upreplicating / + // rebalancing to ensure all stores have the same range count initially, due + // to slow heartbeats. + skip.UnderStress(t) + + ctx := context.Background() + + // Create 7 stores: 3 in Region 1, 2 in Region 2, and 2 in Region 3. + const numNodes = 7 + serverArgs := make(map[int]base.TestServerArgs) + regions := [numNodes]int{1, 1, 1, 2, 2, 3, 3} + for i := 0; i < numNodes; i++ { + serverArgs[i] = base.TestServerArgs{ + Locality: roachpb.Locality{ + Tiers: []roachpb.Tier{ + { + Key: "region", Value: strconv.Itoa(regions[i]), + }, + }, + }, + } + } + + // Start test cluster. + clusterArgs := base.TestClusterArgs{ + ReplicationMode: base.ReplicationAuto, + ServerArgsPerNode: serverArgs, + } + tc := testcluster.StartTestCluster(t, numNodes, clusterArgs) + defer tc.Stopper().Stop(ctx) + db := tc.ServerConn(0) + + setConstraintFn := func(object string, numReplicas, numVoters int, additionalConstraints string) { + _, err := db.Exec( + fmt.Sprintf("ALTER %s CONFIGURE ZONE USING num_replicas = %d, num_voters = %d%s", + object, numReplicas, numVoters, additionalConstraints)) + require.NoError(t, err) + } + + // Ensure all stores have the same range count initially, to allow for more + // predictable behaviour when the allocator ranks stores using balance score. 
+ setConstraintFn("DATABASE system", 7, 7, "") + setConstraintFn("RANGE system", 7, 7, "") + setConstraintFn("RANGE liveness", 7, 7, "") + setConstraintFn("RANGE meta", 7, 7, "") + setConstraintFn("RANGE default", 7, 7, "") + testutils.SucceedsSoon(t, func() error { + if err := forceScanOnAllReplicationQueues(tc); err != nil { + return err + } + rangeCount := -1 + allEqualRangeCount := true + iterateOverAllStores(t, tc, func(s *kvserver.Store) error { + if rangeCount == -1 { + rangeCount = s.ReplicaCount() + } else if rangeCount != s.ReplicaCount() { + allEqualRangeCount = false + } + return nil + }) + if !allEqualRangeCount { + return errors.New("Range counts are not all equal") + } + return nil + }) + + // Create a new range to simulate switching from ZONE to REGION survival. + _, err := db.Exec("CREATE TABLE t (i INT PRIMARY KEY, s STRING)") + require.NoError(t, err) + + // ZONE survival configuration. + setConstraintFn("TABLE t", 5, 3, + ", constraints = '{\"+region=2\": 1, \"+region=3\": 1}', voter_constraints = '{\"+region=1\": 3}'") + + // computeNumberOfReplicas is used to find the number of voters and + // non-voters to check if we are meeting our zone configuration. 
+ computeNumberOfReplicas := func( + t *testing.T, + tc *testcluster.TestCluster, + db *gosql.DB, + ) (numVoters, numNonVoters int, err error) { + if err := forceScanOnAllReplicationQueues(tc); err != nil { + return 0, 0, err + } + + var rangeID roachpb.RangeID + if err := db.QueryRow("select range_id from [show ranges from table t] limit 1").Scan(&rangeID); err != nil { + return 0, 0, err + } + iterateOverAllStores(t, tc, func(s *kvserver.Store) error { + if replica, err := s.GetReplica(rangeID); err == nil && replica.OwnsValidLease(ctx, replica.Clock().NowAsClockTimestamp()) { + desc := replica.Desc() + numVoters = len(desc.Replicas().VoterDescriptors()) + numNonVoters = len(desc.Replicas().NonVoterDescriptors()) + } + return nil + }) + return numVoters, numNonVoters, nil + } + + // Ensure we are meeting our ZONE survival configuration. + testutils.SucceedsSoon(t, func() error { + numVoters, numNonVoters, err := computeNumberOfReplicas(t, tc, db) + require.NoError(t, err) + if numVoters != 3 { + return errors.Newf("expected 3 voters; got %d", numVoters) + } + if numNonVoters != 2 { + return errors.Newf("expected 2 non-voters; got %v", numNonVoters) + } + return nil + }) + + // REGION survival configuration. + setConstraintFn("TABLE t", 5, 5, + ", constraints = '{}', voter_constraints = '{\"+region=1\": 2, \"+region=2\": 2, \"+region=3\": 1}'") + require.NoError(t, err) + + // Ensure we are meeting our REGION survival configuration. + testutils.SucceedsSoon(t, func() error { + numVoters, numNonVoters, err := computeNumberOfReplicas(t, tc, db) + require.NoError(t, err) + if numVoters != 5 { + return errors.Newf("expected 5 voters; got %d", numVoters) + } + if numNonVoters != 0 { + return errors.Newf("expected 0 non-voters; got %v", numNonVoters) + } + return nil + }) + + // Retrieve the add voter events from the range log. 
+ var rangeID roachpb.RangeID + err = db.QueryRow("select range_id from [show ranges from table t] limit 1").Scan(&rangeID) + require.NoError(t, err) + addVoterEvents, err := filterRangeLog(tc.Conns[0], rangeID, kvserverpb.RangeLogEventType_add_voter, kvserverpb.ReasonRangeUnderReplicated) + require.NoError(t, err) + + // Check if an add voter event has an added replica of type LEARNER, and if + // it does, it shows that we are adding a new voter rather than promoting an + // existing non-voter, which is unexpected. + for _, addVoterEvent := range addVoterEvents { + switch addVoterEvent.AddedReplica.Type { + case roachpb.LEARNER: + require.Failf( + t, + "Expected to promote non-voter, instead added voter", + "Added voter store ID: %v\nAdd voter events: %v", + addVoterEvent.AddedReplica.StoreID, addVoterEvents) + case roachpb.VOTER_FULL: + default: + require.Failf( + t, + "Unexpected added replica type", + "Replica type: %v\nAdd voter events: %v", + addVoterEvent.AddedReplica.Type, addVoterEvents) + } + } +} From a517c9fe94ed548cad0e973d22d84c18c88f7018 Mon Sep 17 00:00:00 2001 From: Jordan Lewis Date: Wed, 26 Oct 2022 22:36:31 -0400 Subject: [PATCH 5/5] tsearch: add more explanatory comments ... and rename a few structs to be canonical. 
Release note: None --- pkg/util/tsearch/BUILD.bazel | 1 + pkg/util/tsearch/eval.go | 86 +++++++++++++++++++++-------- pkg/util/tsearch/eval_test.go | 10 +++- pkg/util/tsearch/lex.go | 60 +++++++++++++------- pkg/util/tsearch/tsquery.go | 63 +++++++++++++++++++-- pkg/util/tsearch/tsquery_test.go | 12 ++++ pkg/util/tsearch/tsvector.go | 92 +++++++++++++++++-------------- pkg/util/tsearch/tsvector_test.go | 34 +++++++++++- 8 files changed, 263 insertions(+), 95 deletions(-) diff --git a/pkg/util/tsearch/BUILD.bazel b/pkg/util/tsearch/BUILD.bazel index af90c2c62508..cf5332d80551 100644 --- a/pkg/util/tsearch/BUILD.bazel +++ b/pkg/util/tsearch/BUILD.bazel @@ -25,6 +25,7 @@ go_test( "tsquery_test.go", "tsvector_test.go", ], + args = ["-test.timeout=295s"], embed = [":tsearch"], deps = [ "//pkg/testutils/skip", diff --git a/pkg/util/tsearch/eval.go b/pkg/util/tsearch/eval.go index e2344ecdfc91..0933efea90c2 100644 --- a/pkg/util/tsearch/eval.go +++ b/pkg/util/tsearch/eval.go @@ -37,9 +37,12 @@ func (e *tsEvaluator) eval() (bool, error) { return e.evalNode(e.q.root) } +// evalNode is used to evaluate a query node that's not nested within any +// followed by operators. it returns true if the match was successful. func (e *tsEvaluator) evalNode(node *tsNode) (bool, error) { switch node.op { case invalid: + // If there's no operator we're evaluating a leaf term. prefixMatch := false if len(node.term.positions) > 0 && node.term.positions[0].weight == weightStar { prefixMatch = true @@ -59,12 +62,14 @@ func (e *tsEvaluator) evalNode(node *tsNode) (bool, error) { } return false, nil case and: + // Match if both operands are true. l, err := e.evalNode(node.l) if err != nil || !l { return false, err } return e.evalNode(node.r) case or: + // Match if either operand is true. 
l, err := e.evalNode(node.l) if err != nil { return false, err @@ -74,38 +79,69 @@ func (e *tsEvaluator) evalNode(node *tsNode) (bool, error) { } return e.evalNode(node.r) case not: + // Match if the operand is false. ret, err := e.evalNode(node.l) return !ret, err case followedby: + // For followed-by queries, we recurse into the special followed-by handler. + // Then, we return true if there is at least one position at which the + // followed-by query matches. positions, err := e.evalWithinFollowedBy(node) return len(positions.positions) > 0, err } return false, errors.AssertionFailedf("invalid operator %d", node.op) } -type tspositionset struct { - positions []tsposition - width int - invert bool +// tsPositionSet keeps track of metadata for a followed-by match. It's used to +// pass information about followed by queries during evaluation of them. +type tsPositionSet struct { + // positions is the list of positions that the match is successful at (or, + // if invert is true, unsuccessful at). + positions []tsPosition + // width is the width of the match. This is important to track to deal with + // chained followed by queries with possibly different widths (<-> vs <2> etc). + // A match of a single term within a followed by has width 0. + width int + // invert, if true, indicates that this match should be inverted. It's used + // to handle followed by matches within not operators. + invert bool } +// emitMode is a bitfield that controls the output of followed by matches. type emitMode int const ( + // emitMatches causes evalFollowedBy to emit matches - positions at which + // the left argument is found separated from the right argument by the right + // width. emitMatches emitMode = 1 << iota + // emitLeftUnmatched causes evalFollowedBy to emit places at which the left + // arm doesn't match. emitLeftUnmatched + // emitRightUnmatched causes evalFollowedBy to emit places at which the right + // arm doesn't match. 
emitRightUnmatched ) +// evalFollowedBy handles evaluating a followed by operator. It needs +// information about the positions at which the left and right arms of the +// followed by operator matches, as well as the offsets for each of the arms: +// the number of lexemes apart each of the matches were. +// the emitMode controls the output - see the comments on each of the emitMode +// values for details. +// This function is a little bit confusing, because it's operating on two +// input position sets, and not directly on search terms. Its job is to do set +// operations on the input sets, depending on emitMode - an intersection or +// difference depending on the desired outcome by evalWithinFollowedBy. func (e *tsEvaluator) evalFollowedBy( - lPositions, rPositions tspositionset, lOffset, rOffset int, emitMode emitMode, -) (tspositionset, error) { + lPositions, rPositions tsPositionSet, lOffset, rOffset int, emitMode emitMode, +) (tsPositionSet, error) { // Followed by makes sure that two terms are separated by exactly n words. // First, find all slots that match for the left expression. // Find the offsetted intersection of 2 sorted integer lists, using the // followedN as the offset. - var ret tspositionset + var ret tsPositionSet var lIdx, rIdx int // Loop through the two sorted position lists, until the position on the // right is as least as large as the position on the left. 
@@ -139,18 +175,18 @@ func (e *tsEvaluator) evalFollowedBy( if lPos < rPos { if emitMode&emitLeftUnmatched > 0 { - ret.positions = append(ret.positions, tsposition{position: lPos}) + ret.positions = append(ret.positions, tsPosition{position: lPos}) } lIdx++ } else if lPos == rPos { if emitMode&emitMatches > 0 { - ret.positions = append(ret.positions, tsposition{position: rPos}) + ret.positions = append(ret.positions, tsPosition{position: rPos}) } lIdx++ rIdx++ } else { if emitMode&emitRightUnmatched > 0 { - ret.positions = append(ret.positions, tsposition{position: rPos}) + ret.positions = append(ret.positions, tsPosition{position: rPos}) } rIdx++ } @@ -162,9 +198,10 @@ func (e *tsEvaluator) evalFollowedBy( // operator. Instead of just returning true or false, and possibly short // circuiting on boolean ops, we need to return all of the tspositions at which // each arm of the followed by expression matches. -func (e *tsEvaluator) evalWithinFollowedBy(node *tsNode) (tspositionset, error) { +func (e *tsEvaluator) evalWithinFollowedBy(node *tsNode) (tsPositionSet, error) { switch node.op { case invalid: + // We're evaluating a leaf (a term). prefixMatch := false if len(node.term.positions) > 0 && node.term.positions[0].weight == weightStar { prefixMatch = true @@ -177,9 +214,9 @@ func (e *tsEvaluator) evalWithinFollowedBy(node *tsNode) (tspositionset, error) }) if i >= len(e.v) { // No match. - return tspositionset{}, nil + return tsPositionSet{}, nil } - var ret []tsposition + var ret []tsPosition if prefixMatch { for j := i; j < len(e.v); j++ { t := e.v[j] @@ -189,22 +226,23 @@ func (e *tsEvaluator) evalWithinFollowedBy(node *tsNode) (tspositionset, error) ret = append(ret, t.positions...) } ret = sortAndUniqTSPositions(ret) - return tspositionset{positions: ret}, nil + return tsPositionSet{positions: ret}, nil } else if e.v[i].lexeme != target { // No match. 
- return tspositionset{}, nil + return tsPositionSet{}, nil } - return tspositionset{positions: e.v[i].positions}, nil + // Return all of the positions at which the term is present. + return tsPositionSet{positions: e.v[i].positions}, nil case or: var lOffset, rOffset, width int lPositions, err := e.evalWithinFollowedBy(node.l) if err != nil { - return tspositionset{}, err + return tsPositionSet{}, err } rPositions, err := e.evalWithinFollowedBy(node.r) if err != nil { - return tspositionset{}, err + return tsPositionSet{}, err } width = lPositions.width @@ -236,7 +274,7 @@ func (e *tsEvaluator) evalWithinFollowedBy(node *tsNode) (tspositionset, error) case not: ret, err := e.evalWithinFollowedBy(node.l) if err != nil { - return tspositionset{}, err + return tsPositionSet{}, err } ret.invert = !ret.invert return ret, nil @@ -248,11 +286,11 @@ func (e *tsEvaluator) evalWithinFollowedBy(node *tsNode) (tspositionset, error) lPositions, err := e.evalWithinFollowedBy(node.l) if err != nil { - return tspositionset{}, err + return tsPositionSet{}, err } rPositions, err := e.evalWithinFollowedBy(node.r) if err != nil { - return tspositionset{}, err + return tsPositionSet{}, err } if node.op == followedby { lOffset = node.followedN + rPositions.width @@ -284,12 +322,12 @@ func (e *tsEvaluator) evalWithinFollowedBy(node *tsNode) (tspositionset, error) ret.width = width return ret, err } - return tspositionset{}, errors.AssertionFailedf("invalid operator %d", node.op) + return tsPositionSet{}, errors.AssertionFailedf("invalid operator %d", node.op) } -// sortAndUniqTSPositions sorts and uniquifies the input tsposition list by +// sortAndUniqTSPositions sorts and uniquifies the input tsPosition list by // their position attributes. 
-func sortAndUniqTSPositions(pos []tsposition) []tsposition { +func sortAndUniqTSPositions(pos []tsPosition) []tsPosition { if len(pos) <= 1 { return pos } diff --git a/pkg/util/tsearch/eval_test.go b/pkg/util/tsearch/eval_test.go index 2c87e4219a86..deccbaff2160 100644 --- a/pkg/util/tsearch/eval_test.go +++ b/pkg/util/tsearch/eval_test.go @@ -61,6 +61,7 @@ func TestEval(t *testing.T) { {`a <2> b`, `a:1 b:3`, true}, {`a:* <0> ab:*`, `abba:1`, true}, {`a:* <0> ab:*`, `a:1`, false}, + {`a:* <1> abc:*`, `a:1 abd:2`, false}, // Negations. {`a <-> !b`, `a:1 b:2`, false}, @@ -77,7 +78,6 @@ func TestEval(t *testing.T) { {`!a <-> !b`, `a:3 b:4`, true}, {`!a <-> !b`, `a:3`, true}, - //{`!a <-> !b`, ``, true}, // Or on the RHS of a follows-by. {`a <-> (b|c)`, `a:1 b:2`, true}, {`a <-> (b|c)`, `a:1 c:2`, true}, @@ -91,6 +91,8 @@ func TestEval(t *testing.T) { {`a <-> (!b|!c)`, `a:1 b:2`, true}, {`a <-> (!b|!c)`, `a:1 c:2`, true}, {`a <-> (!b|!c)`, `a:1 d:2`, true}, + {`a <-> ((b <-> c) | d)`, `a:1 b:2 c:3 d:4`, true}, + {`a <-> (b | (c <-> d))`, `a:1 b:2 c:3 d:4`, true}, // And on the RHS of a follows-by. {`a <-> (b&c)`, `a:1 b:2`, false}, {`a <-> (b&c)`, `a:1 c:2`, false}, @@ -104,6 +106,8 @@ func TestEval(t *testing.T) { {`a <-> (!b&!c)`, `a:1 b:2`, false}, {`a <-> (!b&!c)`, `a:1 c:2`, false}, {`a <-> (!b&!c)`, `a:1 d:2`, true}, + {`a <-> ((b <-> c) & d)`, `a:1 b:2 c:3 d:4`, false}, + {`a <-> (b & (c <-> d))`, `a:1 b:2 c:3 d:4`, false}, } for _, tc := range tcs { t.Log(tc) @@ -120,6 +124,10 @@ func TestEval(t *testing.T) { // This subtest runs all the test cases against PG to ensure the behavior is // the same. t.Run("ComparePG", func(t *testing.T) { + // This test can be manually run by pointing it to a local Postgres. There + // are no requirements for the contents of the local Postgres - it just runs + // expressions. The test validates that all of the test cases in this test + // file work the same in Postgres as they do this package. 
skip.IgnoreLint(t, "need to manually enable") conn, err := pgx.Connect(context.Background(), "postgresql://jordan@localhost:5432") require.NoError(t, err) diff --git a/pkg/util/tsearch/lex.go b/pkg/util/tsearch/lex.go index ee0ae85c6713..10f5c462d2c3 100644 --- a/pkg/util/tsearch/lex.go +++ b/pkg/util/tsearch/lex.go @@ -37,6 +37,8 @@ const ( expectingPosDelimiter ) +// tsVectorLexer is a lexing state machine for the TSVector and TSQuery input +// formats. See the comment above lex() for more details. type tsVectorLexer struct { input string lastLen int @@ -44,7 +46,7 @@ type tsVectorLexer struct { state tsVectorParseState // If true, we're in "TSQuery lexing mode" - tsquery bool + tsQuery bool } func (p *tsVectorLexer) back() { @@ -59,7 +61,8 @@ func (p *tsVectorLexer) advance() rune { return r } -// lex lexes the input in the receiver according to the TSVector "grammar". +// lex lexes the input in the receiver according to the TSVector "grammar", or +// according the TSQuery "grammar" if tsQuery is set to true. // // A simple TSVector input could look like this: // @@ -73,18 +76,29 @@ func (p *tsVectorLexer) advance() rune { // without any whitespace in between (if there is no position list on the word). // In a single-quote wrapped word, the word must terminate with a single quote. // All other characters are treated as literals. Backlashes can be used to -// escape single quotes, and are otherwise no-ops. +// escape single quotes, and are otherwise skipped, allowing the following +// character to be included as a literal (such as the backslash character itself). // // If a word is not single-quote wrapped, the next term will begin if there is // whitespace after the word. Whitespace and colons may be entered by escaping -// them with backslashes. All other uses of backslashes are no-ops. +// them with backslashes. All other uses of backslashes are skipped, allowing +// the following character to be included as a literal. 
// // A word is delimited from its position list with a colon. // // A position list is made up of a comma-delimited list of numbers, each // of which may have an optional "strength" which is a letter from A-D. // -// See examples in tsvector_test.go. +// In TSQuery mode, there are a few differences: +// - Terms must be separated with tsOperators (!, <->, |, &), not just spaces. +// - Terms may be surrounded by the ( ) grouping tokens. +// - Terms cannot include multiple positions. +// - Terms can include more than one "strength", as well as the * prefix search +// operator. For example, foo:3AC* +// +// See examples in tsvector_test.go and tsquery_test.go, and see the +// documentation in tsvector.go for more information and a link to the Postgres +// documentation that is the spec for all of this behavior. func (p tsVectorLexer) lex() (TSVector, error) { // termBuf will be reused as a temporary buffer to assemble each term before // copying into the vector. @@ -104,7 +118,7 @@ func (p tsVectorLexer) lex() (TSVector, error) { continue } - if p.tsquery { + if p.tsQuery { // Check for &, |, !, and <-> (or ) switch r { case '&': @@ -171,7 +185,7 @@ func (p tsVectorLexer) lex() (TSVector, error) { p.state = expectingTerm } else if r == ':' { lastTerm := &ret[len(ret)-1] - lastTerm.positions = append(lastTerm.positions, tsposition{}) + lastTerm.positions = append(lastTerm.positions, tsPosition{}) p.state = expectingPosList } else { p.state = expectingTerm @@ -185,7 +199,7 @@ func (p tsVectorLexer) lex() (TSVector, error) { continue } - if p.tsquery { + if p.tsQuery { switch r { case '&', '!', '|', '<', '(', ')': // These are all "operators" in the TSQuery language. End the current @@ -205,7 +219,7 @@ func (p tsVectorLexer) lex() (TSVector, error) { // Copy the termBuf into the vector, resize the termBuf, continue on. 
term := tsTerm{lexeme: string(termBuf)} if r == ':' { - term.positions = append(term.positions, tsposition{}) + term.positions = append(term.positions, tsPosition{}) } ret = append(ret, term) termBuf = termBuf[:0] @@ -216,13 +230,13 @@ func (p tsVectorLexer) lex() (TSVector, error) { } continue } - if p.tsquery && r == ':' { + if p.tsQuery && r == ':' { return p.syntaxError() } termBuf = append(termBuf, r) case expectingPosList: var pos int - if !p.tsquery { + if !p.tsQuery { // If we have nothing in our termBuf, we need to see at least one number. if unicode.IsNumber(r) { termBuf = append(termBuf, r) @@ -251,37 +265,41 @@ func (p tsVectorLexer) lex() (TSVector, error) { } switch r { case ',': - if p.tsquery { + if p.tsQuery { // Not valid! No , allowed in position lists in tsqueries. return ret, pgerror.Newf(pgcode.Syntax, "syntax error in TSVector: %s", p.input) } - lastTerm.positions = append(lastTerm.positions, tsposition{}) + lastTerm.positions = append(lastTerm.positions, tsPosition{}) // Expecting another number next. continue case '*': - if p.tsquery { + if p.tsQuery { lastTerm.positions[lastTermPos].weight |= weightStar } else { p.state = expectingPosDelimiter lastTerm.positions[lastTermPos].weight |= weightA } case 'a', 'A': - if !p.tsquery { + if !p.tsQuery { p.state = expectingPosDelimiter } lastTerm.positions[lastTermPos].weight |= weightA case 'b', 'B': - if !p.tsquery { + if !p.tsQuery { p.state = expectingPosDelimiter } lastTerm.positions[lastTermPos].weight |= weightB case 'c', 'C': - if !p.tsquery { + if !p.tsQuery { p.state = expectingPosDelimiter } lastTerm.positions[lastTermPos].weight |= weightC case 'd', 'D': - if p.tsquery { + // Weight D is handled differently in TSQuery parsing than TSVector. In + // TSVector parsing, the default is already D - so we don't record any + // weight at all. This matches Postgres behavior - a default D weight is + // not printed or stored. In TSQuery, we have to record it explicitly. 
+ if p.tsQuery { lastTerm.positions[lastTermPos].weight |= weightD } else { p.state = expectingPosDelimiter @@ -293,7 +311,7 @@ func (p tsVectorLexer) lex() (TSVector, error) { if r == ',' { p.state = expectingPosList lastTerm := &ret[len(ret)-1] - lastTerm.positions = append(lastTerm.positions, tsposition{}) + lastTerm.positions = append(lastTerm.positions, tsPosition{}) } else if unicode.IsSpace(r) { p.state = expectingTerm } else { @@ -313,7 +331,7 @@ func (p tsVectorLexer) lex() (TSVector, error) { ret = append(ret, tsTerm{lexeme: string(termBuf)}) case expectingPosList: // Finish number. - if !p.tsquery { + if !p.tsQuery { if len(termBuf) == 0 { return p.syntaxError() } @@ -344,7 +362,7 @@ func (p tsVectorLexer) lex() (TSVector, error) { func (p *tsVectorLexer) syntaxError() (TSVector, error) { typ := "TSVector" - if p.tsquery { + if p.tsQuery { typ = "TSQuery" } return TSVector{}, pgerror.Newf(pgcode.Syntax, "syntax error in %s: %s", typ, p.input) diff --git a/pkg/util/tsearch/tsquery.go b/pkg/util/tsearch/tsquery.go index 32e67cf75c8b..1f101dcf9200 100644 --- a/pkg/util/tsearch/tsquery.go +++ b/pkg/util/tsearch/tsquery.go @@ -16,15 +16,62 @@ import ( "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" + "github.com/cockroachdb/errors" ) +// tsOperator is an enum that represents the different operators within a +// TSQuery. +type tsOperator int + +const ( + // Parentheses can be used to control nesting of the TSQuery operators. + // Without parentheses, | binds least tightly, + // then &, then <->, and ! most tightly. + + invalid tsOperator = iota + // and is the & operator, which requires both of its operands to exist in + // the searched document. + and + // or is the | operator, which requires one or more of its operands to exist + // in the searched document. + or + // not is the ! operator, which requires that its single operand doesn't exist + // in the searched document. 
+ not + // followedby is the <-> operator. It can also be specified with a number like + // <1> or <2> or <3>. It requires that the left operand is followed by the right + // operand. The <-> and <1> forms mean that they should be directly followed + // by each other. A number indicates how many terms away the operands should be. + followedby + // lparen and rparen are grouping operators. They're just used in parsing and + // don't appear in the TSQuery tree. + lparen + rparen +) + +// precedence returns the parsing precedence of the receiver. A higher +// precedence means that the operator binds more tightly. +func (o tsOperator) precedence() int { + switch o { + case not: + return 4 + case followedby: + return 3 + case and: + return 2 + case or: + return 1 + } + panic(errors.AssertionFailedf("no precedence for operator %d", o)) +} + // tsNode represents a single AST node within the tree of a TSQuery. type tsNode struct { - // Only one of term, op, or paren will be set. + // Only one of term or op will be set. // If term is set, this is a leaf node containing a lexeme. term tsTerm // If op is set, this is an operator node: either not, and, or, or followedby. - op tsoperator + op tsOperator // set only when op is followedby. Indicates the number n within the // operator, which means the number of terms separating the left and the right // argument. @@ -79,8 +126,8 @@ func (n tsNode) UnambiguousString() string { return fmt.Sprintf("[%s%s%s]", n.l.UnambiguousString(), tsTerm{operator: n.op, followedN: n.followedN}, n.r.UnambiguousString()) } -// TSQuery represents a tsquery AST root. A tsquery is a tree of text search -// operators that can be run against a tsvector to produce a predicate of +// TSQuery represents a tsNode AST root. A TSQuery is a tree of text search +// operators that can be run against a TSVector to produce a predicate of // whether the query matched. 
type TSQuery struct { root *tsNode @@ -97,7 +144,7 @@ func lexTSQuery(input string) (TSVector, error) { parser := tsVectorLexer{ input: input, state: expectingTerm, - tsquery: true, + tsQuery: true, } return parser.lex() @@ -115,6 +162,8 @@ func ParseTSQuery(input string) (TSQuery, error) { return queryParser.parse() } +// tsQueryParser is a parser that operates on a set of lexed tokens, represented +// as the tsTerms in a TSVector. type tsQueryParser struct { input string terms TSVector @@ -148,6 +197,10 @@ func (p *tsQueryParser) parse() (TSQuery, error) { return TSQuery{root: expr}, nil } +// parseTSExpr is a "Pratt parser" which constructs a query tree out of the +// lexed tsTerms, respecting the precedence of the tsOperators. +// See this nice article about Pratt parsing, which this parser was adapted from: +// https://matklad.github.io/2020/04/13/simple-but-powerful-pratt-parsing.html func (p *tsQueryParser) parseTSExpr(minBindingPower int) (*tsNode, error) { t, ok := p.nextTerm() if !ok { diff --git a/pkg/util/tsearch/tsquery_test.go b/pkg/util/tsearch/tsquery_test.go index cce8b046a115..c16e7b06dab0 100644 --- a/pkg/util/tsearch/tsquery_test.go +++ b/pkg/util/tsearch/tsquery_test.go @@ -12,6 +12,7 @@ package tsearch import ( "context" + "strings" "testing" "github.com/cockroachdb/cockroach/pkg/testutils/skip" @@ -47,6 +48,7 @@ bar `, `'foo' 'bar'`}, {`'\:'`, `':'`}, {`'\ '`, `' '`}, {`\ `, `' '`}, + {`\\`, `'\'`}, {`blah'blah`, `'blah''blah'`}, {`blah'`, `'blah'''`}, @@ -145,6 +147,7 @@ func TestParseTSQuery(t *testing.T) { {`'\:'`, `':'`}, {`'\ '`, `' '`}, {`\ `, `' '`}, + {`\\`, `'\'`}, {`blah'blah`, `'blah''blah'`}, {`blah'`, `'blah'''`}, @@ -195,6 +198,10 @@ func TestParseTSQuery(t *testing.T) { } t.Run("ComparePG", func(t *testing.T) { + // This test can be manually run by pointing it to a local Postgres. There + // are no requirements for the contents of the local Postgres - it just runs + // expressions. 
The test validates that all of the test cases in this test + // file work the same in Postgres as they do in this package. skip.IgnoreLint(t, "need to manually enable") conn, err := pgx.Connect(context.Background(), "postgresql://jordan@localhost:5432") require.NoError(t, err) @@ -204,6 +211,11 @@ func TestParseTSQuery(t *testing.T) { var actual string row := conn.QueryRow(context.Background(), "SELECT $1::TSQuery", tc.input) require.NoError(t, row.Scan(&actual)) + // To make sure that our tests behave as expected, we need to remove + // doubled backslashes from Postgres's output. Confusingly, even in the + // protocol that pgx uses, the actual text returned for a string + // containing a backslash doubles all backslash literals. + actual = strings.ReplaceAll(actual, `\\`, `\`) assert.Equal(t, tc.expectedStr, actual) } }) diff --git a/pkg/util/tsearch/tsvector.go b/pkg/util/tsearch/tsvector.go index 46746365c625..e0271c8e637c 100644 --- a/pkg/util/tsearch/tsvector.go +++ b/pkg/util/tsearch/tsvector.go @@ -17,18 +17,55 @@ import ( "strings" ) -// tsweight is a string A, B, C, or D. -type tsweight int +// This file defines the TSVector data structure, which is used to implement +// Postgres's tsvector text search mechanism. +// See https://www.postgresql.org/docs/current/datatype-textsearch.html for +// context on what each of the pieces do. +// +// TSVector is ultimately used to represent a document as a posting list - a +// list of lexemes in the doc (typically without stop words like common or short +// words, and typically stemmed using a stemming algorithm like snowball +// stemming (https://snowballstem.org/)), along with an associated set of +// positions that those lexemes occur within the document. +// +// Typically, this posting list is then stored in an inverted index within the +// database to accelerate searches of terms within the document. 
+// +// The key structures are: +// - tsTerm is a document term (also referred to as lexeme) along with a +// position list, which contains the positions within a document that a term +// appeared, along with an optional weight, which controls matching. +// - tsTerm is also used during parsing of both TSQueries and TSVectors, so a +// tsTerm also can represent an TSQuery operator. +// - tsWeight represents the weight of a given lexeme. It's also used for +// queries, when a "star" weight is available that matches any weight. +// - TSVector is a list of tsTerms, ordered by their lexeme. + +// tsWeight is a bitfield that represents the weight of a given term. When +// stored in a TSVector, only 1 of the bits will be set. The default weight is +// D - as a result, we store 0 for the weight of terms with weight D or no +// specified weight. The weightStar value is never set in a TSVector weight. +// +// tsWeight is also used inside of TSQueries, to specify the weight to search. +// Within TSQueries, the absence of a weight is the default, and indicates that +// the search term should match any matching term, regardless of its weight. If +// one or more of the weights are set in a search term, it indicates that the +// query should match only terms with the given weights. +type tsWeight int const ( - weightD tsweight = 1 << iota + // These enum values are a bitfield and must be kept in order. + weightD tsWeight = 1 << iota weightC weightB weightA + // weightStar is a special "weight" that can be specified only in a search + // term. It indicates prefix matching, which will allow the term to match any + // document term that begins with the search term. 
weightStar ) -func (w tsweight) String() string { +func (w tsWeight) String() string { var ret strings.Builder if w&weightStar != 0 { ret.WriteByte('*') @@ -48,47 +85,20 @@ func (w tsweight) String() string { return ret.String() } -type tsposition struct { +// tsPosition is a position within a document, along with an optional weight. +type tsPosition struct { position int - weight tsweight -} - -type tsoperator int - -const ( - // Parentheses can be used to control nesting of the TSQuery operators. - // Without parentheses, | binds least tightly, - // then &, then <->, and ! most tightly. - - invalid tsoperator = iota - and // binary - or // binary - not // unary prefix - followedby // binary - lparen // grouping - rparen // grouping -) - -func (o tsoperator) precedence() int { - switch o { - case not: - return 4 - case followedby: - return 3 - case and: - return 2 - case or: - return 1 - } - panic(fmt.Sprintf("no precdence for operator %d", o)) + weight tsWeight } +// tsTerm is either a lexeme and position list, or an operator (when parsing +// a TSQuery). type tsTerm struct { - // A term is either a lexeme and position list, or an operator. lexeme string - positions []tsposition + positions []tsPosition - operator tsoperator + // The operator and followedN fields are only used when parsing a TSQuery. + operator tsOperator // Set only when operator = followedby followedN int } @@ -194,8 +204,8 @@ func ParseTSVector(input string) (TSVector, error) { if len(ret) >= 1 { // Make sure to sort and uniq the position list even if there's only 1 // entry. 
- latsIdx := len(ret) - 1 - ret[latsIdx].positions = sortAndUniqTSPositions(ret[latsIdx].positions) + lastIdx := len(ret) - 1 + ret[lastIdx].positions = sortAndUniqTSPositions(ret[lastIdx].positions) } return ret, nil } diff --git a/pkg/util/tsearch/tsvector_test.go b/pkg/util/tsearch/tsvector_test.go index f8bd7f373947..18a351dd7412 100644 --- a/pkg/util/tsearch/tsvector_test.go +++ b/pkg/util/tsearch/tsvector_test.go @@ -11,15 +11,19 @@ package tsearch import ( + "context" + "strings" "testing" + "github.com/cockroachdb/cockroach/pkg/testutils/skip" "github.com/cockroachdb/cockroach/pkg/util/randutil" + "github.com/jackc/pgx/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestParseTSVector(t *testing.T) { - for _, tc := range []struct { + tcs := []struct { input string expectedStr string }{ @@ -36,7 +40,6 @@ bar `, `'bar' 'foo'`}, {`foo:3a`, `'foo':3A`}, {`foo:3A`, `'foo':3A`}, {`foo:3D`, `'foo':3`}, - {`foo:3D`, `'foo':3`}, {`foo:30A`, `'foo':30A`}, {`foo:20,30b,1a`, `'foo':1A,20,30B`}, {`foo:3 bar`, `'bar' 'foo':3`}, @@ -48,6 +51,7 @@ bar `, `'bar' 'foo'`}, {`'\:'`, `':'`}, {`'\ '`, `' '`}, {`\ `, `' '`}, + {`\\`, `'\'`}, {`:3`, `':3'`}, {`::3`, `':':3`}, {`:3:3`, `':3':3`}, @@ -77,13 +81,37 @@ bar `, `'bar' 'foo'`}, {`foo:3,1`, `'foo':1,3`}, {`foo:3,2,1 foo:1,2,3`, `'foo':1,2,3`}, {`a:3 b:2 a:1`, `'a':1,3 'b':2`}, - } { + } + for _, tc := range tcs { t.Log(tc.input) vec, err := ParseTSVector(tc.input) require.NoError(t, err) actual := vec.String() assert.Equal(t, tc.expectedStr, actual) } + + t.Run("ComparePG", func(t *testing.T) { + // This test can be manually run by pointing it to a local Postgres. There + // are no requirements for the contents of the local Postgres - it just runs + // expressions. The test validates that all of the test cases in this test + // file work the same in Postgres as they do this package. 
+ skip.IgnoreLint(t, "need to manually enable") + conn, err := pgx.Connect(context.Background(), "postgresql://jordan@localhost:5432") + require.NoError(t, err) + for _, tc := range tcs { + t.Log(tc) + + var actual string + row := conn.QueryRow(context.Background(), "SELECT $1::TSVector", tc.input) + require.NoError(t, row.Scan(&actual)) + // To make sure that our tests behave as expected, we need to remove + // doubled backslashes from Postgres's output. Confusingly, even in the + // protocol that pgx uses, the actual text returned for a string + // containing a backslash doubles all backslash literals. + actual = strings.ReplaceAll(actual, `\\`, `\`) + assert.Equal(t, tc.expectedStr, actual) + } + }) } func TestParseTSVectorError(t *testing.T) {