diff --git a/pkg/sql/sem/tree/datum.go b/pkg/sql/sem/tree/datum.go index baf6e2ef2438..166b5f530aa4 100644 --- a/pkg/sql/sem/tree/datum.go +++ b/pkg/sql/sem/tree/datum.go @@ -4059,7 +4059,7 @@ func (d *DTSVector) Min(_ CompareContext) (Datum, bool) { // Size implements the Datum interface. func (d *DTSVector) Size() uintptr { - return uintptr(len(d.TSVector.String())) + return uintptr(d.TSVector.StringSize()) } // AsDTSVector attempts to retrieve a DTSVector from an Expr, returning a diff --git a/pkg/util/tsearch/tsvector.go b/pkg/util/tsearch/tsvector.go index e75d15b83b2a..62e1bb018085 100644 --- a/pkg/util/tsearch/tsvector.go +++ b/pkg/util/tsearch/tsvector.go @@ -11,10 +11,12 @@ package tsearch import ( + "math/bits" "sort" "strconv" "strings" "unicode" + "unicode/utf8" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgcode" "github.com/cockroachdb/cockroach/pkg/sql/pgwire/pgerror" @@ -72,6 +74,7 @@ const ( weightAny = weightA | weightB | weightC | weightD ) +// NB: must be kept in sync with stringSize(). func (w tsWeight) writeString(buf *strings.Builder) { if w&weightStar != 0 { buf.WriteByte('*') @@ -90,6 +93,14 @@ func (w tsWeight) writeString(buf *strings.Builder) { } } +// stringSize returns the length of the string that corresponds to this +// tsWeight. +// NB: must be kept in sync with writeString(). +func (w tsWeight) stringSize() int { + // Count the number of bits set in the lowest 5 bits. + return bits.OnesCount8(uint8(w & 31)) +} + // TSVectorPGEncoding returns the PG-compatible wire protocol encoding for a // given weight. Note that this is only allowable for TSVector tsweights, which // can't have more than one weight set at the same time. In a TSQuery, you might @@ -164,6 +175,7 @@ func newLexemeTerm(lexeme string) (tsTerm, error) { return tsTerm{lexeme: lexeme}, nil } +// NB: must be kept in sync with stringSize(). func (t tsTerm) writeString(buf *strings.Builder) { if t.operator != 0 { switch t.operator { @@ -217,6 +229,46 @@ func (t tsTerm) writeString(buf *strings.Builder) { } } +// stringSize returns the length of the string representation of this tsTerm. +// NB: must be kept in sync with writeString(). +func (t tsTerm) stringSize() int { + if t.operator != 0 { + switch t.operator { + case and, or, not, lparen, rparen: + return 1 + case followedby: + if t.followedN == 1 { + return 3 // '<->' + } + return 2 + len(strconv.Itoa(int(t.followedN))) // fmt.Sprintf("<%d>", t.followedN) + } + } + size := 1 // '\'' + for _, r := range t.lexeme { + if r == '\'' { + // Single quotes are escaped as double single quotes inside of a + // TSVector. + size += 2 + } else { + // Compare as uint32 to correctly handle negative runes. + if uint32(r) < utf8.RuneSelf { + size++ + } else { + size += utf8.RuneLen(r) + } + } + } + size++ // '\'' + size += len(t.positions) // ':' or ',' for each position + for _, pos := range t.positions { + if pos.position > 0 { + size += len(strconv.Itoa(int(pos.position))) + } + size += pos.weight.stringSize() + } + return size +} + func (t tsTerm) matchesWeight(targetWeight tsWeight) bool { if targetWeight == weightAny { return true @@ -249,6 +301,19 @@ func (t TSVector) String() string { return buf.String() } +// StringSize returns the length of the string that would have been returned on +// String() call, without actually constructing that string. +func (t TSVector) StringSize() int { + var size int + if len(t) > 0 { + size = len(t) - 1 // space + } + for _, term := range t { + size += term.stringSize() + } + return size +} + // ParseTSVector produces a TSVector from an input string. The input will be // sorted by lexeme, but will not be automatically stemmed or stop-worded. func ParseTSVector(input string) (TSVector, error) { diff --git a/pkg/util/tsearch/tsvector_test.go b/pkg/util/tsearch/tsvector_test.go index 6bf57fdb0a9a..07b08777e28e 100644 --- a/pkg/util/tsearch/tsvector_test.go +++ b/pkg/util/tsearch/tsvector_test.go @@ -188,6 +188,14 @@ func TestParseTSRandom(t *testing.T) { } } +func TestTSVectorStringSize(t *testing.T) { + r, _ := randutil.NewTestRand() + for i := 0; i < 1000; i++ { + v := RandomTSVector(r) + require.Equal(t, len(v.String()), v.StringSize()) + } +} + func BenchmarkTSVector(b *testing.B) { r, _ := randutil.NewTestRand() tsVectors := make([]TSVector, 10000) @@ -205,7 +213,7 @@ func BenchmarkTSVector(b *testing.B) { b.Run("StringSize", func(b *testing.B) { for i := 0; i < b.N; i++ { for _, v := range tsVectors { - _ = len(v.String()) + _ = v.StringSize() } } })