Skip to content

Commit

Permalink
util/pretty: mitigate stack overflows of Pretty
Browse files Browse the repository at this point in the history
This commit reduces the chance of a stack overflow from recursive calls
of `*beExec.be`. The `Pretty` function will now return an internal error
if the recursive depth of `be` surpasses 10,000.

Informs cockroachdb#91197

Release note: None
  • Loading branch information
mgartner committed Sep 19, 2023
1 parent 9e751ef commit e8d3de2
Show file tree
Hide file tree
Showing 18 changed files with 203 additions and 39 deletions.
6 changes: 5 additions & 1 deletion pkg/cli/sqlfmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ func runSQLFmt(cmd *cobra.Command, args []string) error {
}

for i := range sl {
fmt.Print(cfg.Pretty(sl[i].AST))
p, err := cfg.Pretty(sl[i].AST)
if err != nil {
return err
}
fmt.Print(p)
if len(sl) > 1 {
fmt.Print(";")
}
Expand Down
12 changes: 11 additions & 1 deletion pkg/cmd/reduce/reduce/reducesql/reducesql.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ func collectASTs(stmts statements.Statements) []tree.NodeFormatter {

func joinASTs(stmts []tree.NodeFormatter) string {
var sb strings.Builder
var fmtCtx *tree.FmtCtx
for i, stmt := range stmts {
if i > 0 {
sb.WriteString("\n\n")
Expand All @@ -438,7 +439,16 @@ func joinASTs(stmts []tree.NodeFormatter) string {
UseTabs: false,
Simplify: true,
}
sb.WriteString(cfg.Pretty(stmt))
p, err := cfg.Pretty(stmt)
if err != nil {
// Use simple printing if pretty-printing fails.
if fmtCtx == nil {
fmtCtx = tree.NewFmtCtx(tree.FmtParsable)
}
stmt.Format(fmtCtx)
p = fmtCtx.CloseAndGetString()
}
sb.WriteString(p)
sb.WriteString(";")
}
return sb.String()
Expand Down
5 changes: 4 additions & 1 deletion pkg/internal/sqlsmith/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ func TestGenerateParse(t *testing.T) {
if err != nil {
t.Fatalf("%v: %v", stmt, err)
}
stmt = sqlsmith.TestingPrettyCfg.Pretty(parsed.AST)
stmt, err = sqlsmith.TestingPrettyCfg.Pretty(parsed.AST)
if err != nil {
t.Fatal(err)
}
fmt.Print("STMT: ", i, "\n", stmt, ";\n\n")
if *flagExec {
db.Exec(t, `SET statement_timeout = '9s'`)
Expand Down
7 changes: 6 additions & 1 deletion pkg/internal/sqlsmith/sqlsmith.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,12 @@ func (s *Smither) Generate() string {
continue
}
i = 0
return prettyCfg.Pretty(stmt)
p, err := prettyCfg.Pretty(stmt)
if err != nil {
// Use simple printing if pretty-printing fails.
p = tree.AsStringWithFlags(stmt, tree.FmtParsable)
}
return p
}
}

Expand Down
13 changes: 6 additions & 7 deletions pkg/kv/kvpb/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -1630,16 +1630,15 @@ func (e *MissingRecordError) SafeFormatError(p errors.Printer) (next error) {

// DescNotFoundError is reported when a descriptor is missing.
type DescNotFoundError struct {
storeID roachpb.StoreID
nodeID roachpb.NodeID
id int32
isStore bool
}

// NewStoreDescNotFoundError initializes a new DescNotFoundError for a missing
// store descriptor.
func NewStoreDescNotFoundError(storeID roachpb.StoreID) *DescNotFoundError {
return &DescNotFoundError{
storeID: storeID,
id: int32(storeID),
isStore: true,
}
}
Expand All @@ -1648,7 +1647,7 @@ func NewStoreDescNotFoundError(storeID roachpb.StoreID) *DescNotFoundError {
// node descriptor.
func NewNodeDescNotFoundError(nodeID roachpb.NodeID) *DescNotFoundError {
return &DescNotFoundError{
nodeID: nodeID,
id: int32(nodeID),
isStore: false,
}
}
Expand All @@ -1658,11 +1657,11 @@ func (e *DescNotFoundError) Error() string {
}

func (e *DescNotFoundError) SafeFormatError(p errors.Printer) (next error) {
s := redact.SafeString("node")
if e.isStore {
p.Printf("store descriptor with store ID %d was not found", e.storeID)
} else {
p.Printf("node descriptor with node ID %d was not found", e.nodeID)
s = "store"
}
p.Printf("%s descriptor with %s ID %d was not found", s, s, e.id)
return nil
}

Expand Down
1 change: 1 addition & 0 deletions pkg/sql/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ go_library(
"//pkg/util/metric",
"//pkg/util/mon",
"//pkg/util/optional",
"//pkg/util/pretty",
"//pkg/util/protoutil",
"//pkg/util/quotapool",
"//pkg/util/randutil",
Expand Down
32 changes: 26 additions & 6 deletions pkg/sql/explain_bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/duration"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/memzipper"
"github.com/cockroachdb/cockroach/pkg/util/pretty"
"github.com/cockroachdb/cockroach/pkg/util/tracing"
"github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb"
"github.com/cockroachdb/errors"
Expand Down Expand Up @@ -148,7 +149,10 @@ func buildStatementBundle(
if plan == nil {
return diagnosticsBundle{collectionErr: errors.AssertionFailedf("execution terminated early")}
}
b := makeStmtBundleBuilder(explainFlags, db, ie, stmtRawSQL, plan, trace, placeholders, sv)
b, err := makeStmtBundleBuilder(explainFlags, db, ie, stmtRawSQL, plan, trace, placeholders, sv)
if err != nil {
return diagnosticsBundle{collectionErr: err}
}

b.addStatement()
b.addOptPlans(ctx)
Expand Down Expand Up @@ -226,18 +230,21 @@ func makeStmtBundleBuilder(
trace tracingpb.Recording,
placeholders *tree.PlaceholderInfo,
sv *settings.Values,
) stmtBundleBuilder {
) (stmtBundleBuilder, error) {
b := stmtBundleBuilder{
flags: flags, db: db, ie: ie, plan: plan, trace: trace, placeholders: placeholders, sv: sv,
}
b.buildPrettyStatement(stmtRawSQL)
err := b.buildPrettyStatement(stmtRawSQL)
if err != nil {
return stmtBundleBuilder{}, err
}
b.z.Init()
return b
return b, nil
}

// buildPrettyStatement saves the pretty-printed statement (without any
// placeholder arguments).
func (b *stmtBundleBuilder) buildPrettyStatement(stmtRawSQL string) {
func (b *stmtBundleBuilder) buildPrettyStatement(stmtRawSQL string) error {
// If we hit an early error, stmt or stmt.AST might not be initialized yet. In
// this case use the original raw SQL.
if b.plan.stmt == nil || b.plan.stmt.AST == nil {
Expand All @@ -255,7 +262,19 @@ func (b *stmtBundleBuilder) buildPrettyStatement(stmtRawSQL string) {
cfg.Align = tree.PrettyNoAlign
cfg.JSONFmt = true
cfg.ValueRedaction = b.flags.RedactValues
b.stmt = cfg.Pretty(b.plan.stmt.AST)
var err error
b.stmt, err = cfg.Pretty(b.plan.stmt.AST)
if errors.Is(err, pretty.ErrPrettyMaxRecursionDepthExceeded) {
// Use the raw statement string if pretty-printing fails.
b.stmt = stmtRawSQL
// If we're collecting a redacted bundle, redact the raw SQL
// completely.
if b.flags.RedactValues && b.stmt != "" {
b.stmt = string(redact.RedactedMarker())
}
} else if err != nil {
return err
}

// If we had ValueRedaction set, Pretty surrounded all constants with
// redaction markers. We must call Redact to fully redact them.
Expand All @@ -266,6 +285,7 @@ func (b *stmtBundleBuilder) buildPrettyStatement(stmtRawSQL string) {
if b.stmt == "" {
b.stmt = "-- no statement"
}
return nil
}

// addStatement adds the pretty-printed statement in b.stmt as file
Expand Down
6 changes: 5 additions & 1 deletion pkg/sql/logictest/logic.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,11 @@ func (ls *logicStatement) readSQL(
if i > 0 {
fmt.Fprintln(&newSyntax, ";")
}
fmt.Fprint(&newSyntax, pcfg.Pretty(stmtList[i].AST))
p, err := pcfg.Pretty(stmtList[i].AST)
if err != nil {
return "", errors.Wrapf(err, "error while pretty printing")
}
fmt.Fprint(&newSyntax, p)
}
return newSyntax.String(), nil
}(ls.sql)
Expand Down
5 changes: 4 additions & 1 deletion pkg/sql/opt/optgen/cmd/optfmt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,10 @@ func prettyify(r io.Reader, n int, exprgen bool) (string, error) {
exprs = parser.Exprs()
}
d := p.toDoc(exprs)
s := pretty.Pretty(d, n, false, 4, nil)
s, err := pretty.Pretty(d, n, false, 4, nil)
if err != nil {
return "", err
}

// Remove any whitespace at EOL. This can happen in define rules where
// we always insert a blank line above comments which are nested with
Expand Down
1 change: 1 addition & 0 deletions pkg/sql/sem/builtins/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ go_library(
"//pkg/util/json",
"//pkg/util/log",
"//pkg/util/mon",
"//pkg/util/pretty",
"//pkg/util/protoutil",
"//pkg/util/randident",
"//pkg/util/randident/randidentcfg",
Expand Down
11 changes: 10 additions & 1 deletion pkg/sql/sem/builtins/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/ipaddr"
"github.com/cockroachdb/cockroach/pkg/util/json"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/pretty"
"github.com/cockroachdb/cockroach/pkg/util/protoutil"
"github.com/cockroachdb/cockroach/pkg/util/syncutil"
"github.com/cockroachdb/cockroach/pkg/util/timeofday"
Expand Down Expand Up @@ -11157,7 +11158,15 @@ func prettyStatement(p tree.PrettyCfg, stmt string) (string, error) {
}
var formattedStmt strings.Builder
for idx := range stmts {
formattedStmt.WriteString(p.Pretty(stmts[idx].AST))
p, err := p.Pretty(stmts[idx].AST)
if errors.Is(err, pretty.ErrPrettyMaxRecursionDepthExceeded) {
// If pretty-printing the statement fails, use the original
// statement.
p = stmt
} else if err != nil {
return "", err
}
formattedStmt.WriteString(p)
if len(stmts) > 1 {
formattedStmt.WriteString(";")
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/sem/tree/pretty.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,13 @@ func (p *PrettyCfg) bracketKeyword(
}

// Pretty pretty prints stmt with default options.
func Pretty(stmt NodeFormatter) string {
func Pretty(stmt NodeFormatter) (string, error) {
cfg := DefaultPrettyCfg()
return cfg.Pretty(stmt)
}

// Pretty pretty prints stmt with specified options.
func (p *PrettyCfg) Pretty(stmt NodeFormatter) string {
func (p *PrettyCfg) Pretty(stmt NodeFormatter) (string, error) {
doc := p.Doc(stmt)
return pretty.Pretty(doc, p.LineWidth, p.UseTabs, p.TabWidth, p.Case)
}
Expand Down
48 changes: 44 additions & 4 deletions pkg/sql/sem/tree/pretty_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"

Expand All @@ -30,6 +31,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/pretty"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
)

Expand Down Expand Up @@ -112,7 +114,10 @@ func runTestPrettyData(
for p := range work {
thisCfg := cfg
thisCfg.LineWidth = p.numCols
res[p.idx] = thisCfg.Pretty(stmt.AST)
res[p.idx], err = thisCfg.Pretty(stmt.AST)
if err != nil {
t.Fatal(err)
}
}
return nil
}
Expand Down Expand Up @@ -178,14 +183,43 @@ func TestPrettyVerify(t *testing.T) {
if err != nil {
t.Fatal(err)
}
got := tree.Pretty(stmt.AST)
got, err := tree.Pretty(stmt.AST)
if err != nil {
t.Fatal(err)
}
if pretty != got {
t.Fatalf("got: %s\nexpected: %s", got, pretty)
}
})
}
}

func TestPrettyBigStatement(t *testing.T) {
defer leaktest.AfterTest(t)()
defer log.Scope(t).Close(t)

// Create a SELECT statement with a 1 million item IN expression. Without
// mitigation, this can cause stack overflows - see #91197.
var sb strings.Builder
sb.WriteString("SELECT * FROM foo WHERE id IN (")
for i := 0; i < 1_000_000; i++ {
if i != 0 {
sb.WriteByte(',')
}
sb.WriteString(strconv.Itoa(i))
}
sb.WriteString(");")

stmt, err := parser.ParseOne(sb.String())
if err != nil {
t.Fatal(err)
}

cfg := tree.DefaultPrettyCfg()
_, err = cfg.Pretty(stmt.AST)
assert.Errorf(t, err, "max call stack depth of be exceeded")
}

func BenchmarkPrettyData(b *testing.B) {
matches, err := filepath.Glob(datapathutils.TestDataPath(b, "pretty", "*.sql"))
if err != nil {
Expand All @@ -209,7 +243,10 @@ func BenchmarkPrettyData(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, doc := range docs {
for _, w := range []int{1, 30, 80} {
pretty.Pretty(doc, w, true /*useTabs*/, 4 /*tabWidth*/, nil /* keywordTransform */)
_, err := pretty.Pretty(doc, w, true /*useTabs*/, 4 /*tabWidth*/, nil /* keywordTransform */)
if err != nil {
b.Fatal(err)
}
}
}
}
Expand All @@ -226,7 +263,10 @@ func TestPrettyExprs(t *testing.T) {
}

for expr, pretty := range tests {
got := tree.Pretty(expr)
got, err := tree.Pretty(expr)
if err != nil {
t.Fatal(err)
}
if pretty != got {
t.Fatalf("got: %s\nexpected: %s", got, pretty)
}
Expand Down
Loading

0 comments on commit e8d3de2

Please sign in to comment.