From 9d51f6a8fb7bfa632b3f12eee64336456f338fb3 Mon Sep 17 00:00:00 2001 From: Marcus Gartner Date: Mon, 11 Sep 2023 15:59:32 -0400 Subject: [PATCH] util/pretty: mitigate stack overflows of Pretty This commit reduces the chance of a stack overflow from recursive calls of `*beExec.be`. The `Pretty` function will now return an internal error if the recursive depth of `be` surpasses 100,000. While still an internal error, this is preferable to a stack overflow which will crash the process. Informs #91197 Release note: None --- pkg/cli/sqlfmt.go | 6 +- pkg/cmd/reduce/reduce/reducesql/reducesql.go | 6 +- pkg/internal/sqlsmith/setup_test.go | 5 +- pkg/internal/sqlsmith/sqlsmith.go | 6 +- pkg/sql/explain_bundle.go | 22 +++- pkg/sql/logictest/logic.go | 6 +- pkg/sql/opt/optgen/cmd/optfmt/main.go | 5 +- pkg/sql/sem/builtins/builtins.go | 6 +- pkg/sql/sem/tree/pretty.go | 4 +- pkg/sql/sem/tree/pretty_test.go | 43 ++++++- pkg/sql/show_create_clauses.go | 7 +- pkg/testutils/sqlutils/pretty.go | 5 +- pkg/util/pretty/BUILD.bazel | 1 + pkg/util/pretty/pretty.go | 112 +++++++++++++++---- 14 files changed, 190 insertions(+), 44 deletions(-) diff --git a/pkg/cli/sqlfmt.go b/pkg/cli/sqlfmt.go index bf8e7e7f14df..b5aa31425f81 100644 --- a/pkg/cli/sqlfmt.go +++ b/pkg/cli/sqlfmt.go @@ -76,7 +76,11 @@ func runSQLFmt(cmd *cobra.Command, args []string) error { } for i := range sl { - fmt.Print(cfg.Pretty(sl[i].AST)) + p, err := cfg.Pretty(sl[i].AST) + if err != nil { + return err + } + fmt.Print(p) if len(sl) > 1 { fmt.Print(";") } diff --git a/pkg/cmd/reduce/reduce/reducesql/reducesql.go b/pkg/cmd/reduce/reduce/reducesql/reducesql.go index 690a9d713b79..e9f9896a8eef 100644 --- a/pkg/cmd/reduce/reduce/reducesql/reducesql.go +++ b/pkg/cmd/reduce/reduce/reducesql/reducesql.go @@ -438,7 +438,11 @@ func joinASTs(stmts []tree.NodeFormatter) string { UseTabs: false, Simplify: true, } - sb.WriteString(cfg.Pretty(stmt)) + p, err := cfg.Pretty(stmt) + if err != nil { + panic(err) + } + sb.WriteString(p) sb.WriteString(";") } return sb.String() diff --git a/pkg/internal/sqlsmith/setup_test.go b/pkg/internal/sqlsmith/setup_test.go index 76cf1ca266fb..47dc3bdfe708 100644 --- a/pkg/internal/sqlsmith/setup_test.go +++ b/pkg/internal/sqlsmith/setup_test.go @@ -123,7 +123,10 @@ func TestGenerateParse(t *testing.T) { if err != nil { t.Fatalf("%v: %v", stmt, err) } - stmt = sqlsmith.TestingPrettyCfg.Pretty(parsed.AST) + stmt, err = sqlsmith.TestingPrettyCfg.Pretty(parsed.AST) + if err != nil { + t.Fatal(err) + } fmt.Print("STMT: ", i, "\n", stmt, ";\n\n") if *flagExec { db.Exec(t, `SET statement_timeout = '9s'`) diff --git a/pkg/internal/sqlsmith/sqlsmith.go b/pkg/internal/sqlsmith/sqlsmith.go index ec12fed83a92..47271c086aef 100644 --- a/pkg/internal/sqlsmith/sqlsmith.go +++ b/pkg/internal/sqlsmith/sqlsmith.go @@ -190,7 +190,11 @@ func (s *Smither) Generate() string { continue } i = 0 - return prettyCfg.Pretty(stmt) + p, err := prettyCfg.Pretty(stmt) + if err != nil { + panic(err) + } + return p } } diff --git a/pkg/sql/explain_bundle.go b/pkg/sql/explain_bundle.go index 301413552a7e..b9edcfdafe49 100644 --- a/pkg/sql/explain_bundle.go +++ b/pkg/sql/explain_bundle.go @@ -148,7 +148,10 @@ func buildStatementBundle( if plan == nil { return diagnosticsBundle{collectionErr: errors.AssertionFailedf("execution terminated early")} } - b := makeStmtBundleBuilder(explainFlags, db, ie, stmtRawSQL, plan, trace, placeholders, sv) + b, err := makeStmtBundleBuilder(explainFlags, db, ie, stmtRawSQL, plan, trace, placeholders, sv) + if err != nil { + return diagnosticsBundle{collectionErr: err, errorStrings: b.errorStrings} + } b.addStatement() b.addOptPlans(ctx) @@ -226,18 +229,21 @@ func makeStmtBundleBuilder( trace tracingpb.Recording, placeholders *tree.PlaceholderInfo, sv *settings.Values, -) stmtBundleBuilder { +) (stmtBundleBuilder, error) { b := stmtBundleBuilder{ flags: flags, db: db, ie: ie, plan: plan, trace: trace, placeholders: placeholders, sv: sv, } - b.buildPrettyStatement(stmtRawSQL) + err := b.buildPrettyStatement(stmtRawSQL) + if err != nil { + return stmtBundleBuilder{}, err + } b.z.Init() - return b + return b, nil } // buildPrettyStatement saves the pretty-printed statement (without any // placeholder arguments). -func (b *stmtBundleBuilder) buildPrettyStatement(stmtRawSQL string) { +func (b *stmtBundleBuilder) buildPrettyStatement(stmtRawSQL string) (err error) { // If we hit an early error, stmt or stmt.AST might not be initialized yet. In // this case use the original raw SQL. if b.plan.stmt == nil || b.plan.stmt.AST == nil { @@ -255,7 +261,10 @@ func (b *stmtBundleBuilder) buildPrettyStatement(stmtRawSQL string) { cfg.Align = tree.PrettyNoAlign cfg.JSONFmt = true cfg.ValueRedaction = b.flags.RedactValues - b.stmt = cfg.Pretty(b.plan.stmt.AST) + b.stmt, err = cfg.Pretty(b.plan.stmt.AST) + if err != nil { + return err + } // If we had ValueRedaction set, Pretty surrounded all constants with // redaction markers. We must call Redact to fully redact them. @@ -266,6 +275,7 @@ func (b *stmtBundleBuilder) buildPrettyStatement(stmtRawSQL string) { if b.stmt == "" { b.stmt = "-- no statement" } + return nil } // addStatement adds the pretty-printed statement in b.stmt as file diff --git a/pkg/sql/logictest/logic.go b/pkg/sql/logictest/logic.go index 09431143c972..7c7d04701e65 100644 --- a/pkg/sql/logictest/logic.go +++ b/pkg/sql/logictest/logic.go @@ -728,7 +728,11 @@ func (ls *logicStatement) readSQL( if i > 0 { fmt.Fprintln(&newSyntax, ";") } - fmt.Fprint(&newSyntax, pcfg.Pretty(stmtList[i].AST)) + p, err := pcfg.Pretty(stmtList[i].AST) + if err != nil { + return "", errors.Wrapf(err, "error while pretty printing") + } + fmt.Fprint(&newSyntax, p) } return newSyntax.String(), nil }(ls.sql) diff --git a/pkg/sql/opt/optgen/cmd/optfmt/main.go b/pkg/sql/opt/optgen/cmd/optfmt/main.go index ad4bfbcdf7ea..bd52d2b7a7c6 100644 --- a/pkg/sql/opt/optgen/cmd/optfmt/main.go +++ b/pkg/sql/opt/optgen/cmd/optfmt/main.go @@ -173,7 +173,10 @@ func prettyify(r io.Reader, n int, exprgen bool) (string, error) { exprs = parser.Exprs() } d := p.toDoc(exprs) - s := pretty.Pretty(d, n, false, 4, nil) + s, err := pretty.Pretty(d, n, false, 4, nil) + if err != nil { + return "", err + } // Remove any whitespace at EOL. This can happen in define rules where // we always insert a blank line above comments which are nested with diff --git a/pkg/sql/sem/builtins/builtins.go b/pkg/sql/sem/builtins/builtins.go index 2a2bfead9b51..f9565239cbcc 100644 --- a/pkg/sql/sem/builtins/builtins.go +++ b/pkg/sql/sem/builtins/builtins.go @@ -11157,7 +11157,11 @@ func prettyStatement(p tree.PrettyCfg, stmt string) (string, error) { } var formattedStmt strings.Builder for idx := range stmts { - formattedStmt.WriteString(p.Pretty(stmts[idx].AST)) + p, err := p.Pretty(stmts[idx].AST) + if err != nil { + return "", err + } + formattedStmt.WriteString(p) if len(stmts) > 1 { formattedStmt.WriteString(";") } diff --git a/pkg/sql/sem/tree/pretty.go b/pkg/sql/sem/tree/pretty.go index 21e210d22e19..7d1586b8c1c7 100644 --- a/pkg/sql/sem/tree/pretty.go +++ b/pkg/sql/sem/tree/pretty.go @@ -157,13 +157,13 @@ func (p *PrettyCfg) bracketKeyword( } // Pretty pretty prints stmt with default options. -func Pretty(stmt NodeFormatter) string { +func Pretty(stmt NodeFormatter) (string, error) { cfg := DefaultPrettyCfg() return cfg.Pretty(stmt) } // Pretty pretty prints stmt with specified options. -func (p *PrettyCfg) Pretty(stmt NodeFormatter) string { +func (p *PrettyCfg) Pretty(stmt NodeFormatter) (string, error) { doc := p.Doc(stmt) return pretty.Pretty(doc, p.LineWidth, p.UseTabs, p.TabWidth, p.Case) } diff --git a/pkg/sql/sem/tree/pretty_test.go b/pkg/sql/sem/tree/pretty_test.go index e51c353d7c1f..8ac3deb4a807 100644 --- a/pkg/sql/sem/tree/pretty_test.go +++ b/pkg/sql/sem/tree/pretty_test.go @@ -18,6 +18,7 @@ import ( "os" "path/filepath" "runtime" + "strconv" "strings" "testing" @@ -30,6 +31,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/leaktest" "github.com/cockroachdb/cockroach/pkg/util/log" "github.com/cockroachdb/cockroach/pkg/util/pretty" + "github.com/stretchr/testify/assert" "golang.org/x/sync/errgroup" ) @@ -112,7 +114,10 @@ func runTestPrettyData( for p := range work { thisCfg := cfg thisCfg.LineWidth = p.numCols - res[p.idx] = thisCfg.Pretty(stmt.AST) + res[p.idx], err = thisCfg.Pretty(stmt.AST) + if err != nil { + t.Fatal(err) + } } return nil } @@ -178,7 +183,10 @@ func TestPrettyVerify(t *testing.T) { if err != nil { t.Fatal(err) } - got := tree.Pretty(stmt.AST) + got, err := tree.Pretty(stmt.AST) + if err != nil { + t.Fatal(err) + } if pretty != got { t.Fatalf("got: %s\nexpected: %s", got, pretty) } @@ -186,6 +194,32 @@ func TestPrettyVerify(t *testing.T) { } } +func TestPrettyBigStatement(t *testing.T) { + defer leaktest.AfterTest(t)() + defer log.Scope(t).Close(t) + + // Create a SELECT statement with a 1 million item IN expression. Without + // mitigation, this can cause stack overflows - see #91197. + var sb strings.Builder + sb.WriteString("SELECT * FROM foo WHERE id IN (") + for i := 0; i < 1_000_000; i++ { + if i != 0 { + sb.WriteByte(',') + } + sb.WriteString(strconv.Itoa(i)) + } + sb.WriteString(");") + + stmt, err := parser.ParseOne(sb.String()) + if err != nil { + t.Fatal(err) + } + + cfg := tree.DefaultPrettyCfg() + _, err = cfg.Pretty(stmt.AST) + assert.Errorf(t, err, "max call stack depth of be exceeded") +} + func BenchmarkPrettyData(b *testing.B) { matches, err := filepath.Glob(datapathutils.TestDataPath(b, "pretty", "*.sql")) if err != nil { @@ -226,7 +260,10 @@ func TestPrettyExprs(t *testing.T) { } for expr, pretty := range tests { - got := tree.Pretty(expr) + got, err := tree.Pretty(expr) + if err != nil { + t.Fatal(err) + } if pretty != got { t.Fatalf("got: %s\nexpected: %s", got, pretty) } diff --git a/pkg/sql/show_create_clauses.go b/pkg/sql/show_create_clauses.go index 9b742de89780..4a3eee2bf762 100644 --- a/pkg/sql/show_create_clauses.go +++ b/pkg/sql/show_create_clauses.go @@ -169,7 +169,12 @@ func formatViewQueryForDisplay( desc.GetName(), desc.GetID(), err) return } - query = cfg.Pretty(parsed.AST) + query, err = cfg.Pretty(parsed.AST) + if err != nil { + log.Warningf(ctx, "error printing query for view %s (%v): %+v", + desc.GetName(), desc.GetID(), err) + return + } }() typeReplacedViewQuery, err := formatViewQueryTypesForDisplay(ctx, semaCtx, sessionData, desc) diff --git a/pkg/testutils/sqlutils/pretty.go b/pkg/testutils/sqlutils/pretty.go index 38351fa3330a..c6e075c795f5 100644 --- a/pkg/testutils/sqlutils/pretty.go +++ b/pkg/testutils/sqlutils/pretty.go @@ -66,7 +66,10 @@ func VerifyStatementPrettyRoundtrip(t *testing.T, sql string) { // origStmt := stmts[i].AST // Be careful to not simplify otherwise the tests won't round trip. - prettyStmt := cfg.Pretty(origStmt) + prettyStmt, err := cfg.Pretty(origStmt) + if err != nil { + t.Fatalf("%s: %s", err, prettyStmt) + } parsedPretty, err := parser.ParseOne(prettyStmt) if err != nil { t.Fatalf("%s: %s", err, prettyStmt) diff --git a/pkg/util/pretty/BUILD.bazel b/pkg/util/pretty/BUILD.bazel index a37f969901e1..ac7f8f2b6e03 100644 --- a/pkg/util/pretty/BUILD.bazel +++ b/pkg/util/pretty/BUILD.bazel @@ -9,6 +9,7 @@ go_library( ], importpath = "github.com/cockroachdb/cockroach/pkg/util/pretty", visibility = ["//visibility:public"], + deps = ["@com_github_cockroachdb_errors//:errors"], ) go_test( diff --git a/pkg/util/pretty/pretty.go b/pkg/util/pretty/pretty.go index 210223e3592b..4767dab59154 100644 --- a/pkg/util/pretty/pretty.go +++ b/pkg/util/pretty/pretty.go @@ -13,6 +13,8 @@ package pretty import ( "fmt" "strings" + + "github.com/cockroachdb/errors" ) // See the referenced paper in the package documentation for explanations @@ -45,7 +47,9 @@ const ( // if not nil. keywordTransform must not change the visible length of its // argument. It can, for example, add invisible characters like control codes // (colors, etc.). -func Pretty(d Doc, n int, useTabs bool, tabWidth int, keywordTransform func(string) string) string { +func Pretty( + d Doc, n int, useTabs bool, tabWidth int, keywordTransform func(string) string, +) (string, error) { var sb strings.Builder b := beExec{ w: int16(n), @@ -54,13 +58,16 @@ func Pretty(d Doc, n int, useTabs bool, tabWidth int, keywordTransform func(stri memoiDoc: make(map[iDoc]*iDoc), keywordTransform: keywordTransform, } - ldoc := b.best(d) + ldoc, err := b.best(d) + if err != nil { + return "", err + } b.layout(&sb, useTabs, ldoc) - return sb.String() + return sb.String(), nil } // w is the max line width. -func (b *beExec) best(x Doc) *docBest { +func (b *beExec) best(x Doc) (*docBest, error) { return b.be(docPos{0, 0}, b.iDoc(docPos{0, 0}, x, nil)) } @@ -103,18 +110,30 @@ type beExec struct { // keywordTransform filters keywords if not nil. keywordTransform func(string) string + + // beDepth is the depth of recursive calls of be. It is used to detect deep + // call stacks before a stack overflow occurs. + beDepth int } -func (b *beExec) be(k docPos, xlist *iDoc) *docBest { +const maxBeDepth = 10_000 + +func (b *beExec) be(k docPos, xlist *iDoc) (_ *docBest, err error) { + b.beDepth++ + defer func() { b.beDepth-- }() + if b.beDepth > maxBeDepth { + return nil, errors.AssertionFailedf("max call stack depth of be exceeded") + } + // Shortcut: be k [] = Nil if xlist == nil { - return nil + return nil, nil } // If we've computed this result before, short cut here too. memoKey := beArgs{k: k, d: xlist} if cached, ok := b.memoBe[memoKey]; ok { - return cached + return cached, nil } // General case. @@ -127,54 +146,99 @@ func (b *beExec) be(k docPos, xlist *iDoc) *docBest { switch t := d.d.(type) { case nilDoc: - res = b.be(k, z) + res, err = b.be(k, z) + if err != nil { + return nil, err + } case *concat: - res = b.be(k, b.iDoc(d.i, t.a, b.iDoc(d.i, t.b, z))) + res, err = b.be(k, b.iDoc(d.i, t.a, b.iDoc(d.i, t.b, z))) + if err != nil { + return nil, err + } case nests: - res = b.be(k, b.iDoc(docPos{d.i.tabs, d.i.spaces + t.n}, t.d, z)) + res, err = b.be(k, b.iDoc(docPos{d.i.tabs, d.i.spaces + t.n}, t.d, z)) + if err != nil { + return nil, err + } case nestt: - res = b.be(k, b.iDoc(docPos{d.i.tabs + 1 + d.i.spaces/b.tabWidth, 0}, t.d, z)) + res, err = b.be(k, b.iDoc(docPos{d.i.tabs + 1 + d.i.spaces/b.tabWidth, 0}, t.d, z)) + if err != nil { + return nil, err + } case text: + d, err := b.be(docPos{k.tabs, k.spaces + int16(len(t))}, z) + if err != nil { + return nil, err + } res = b.newDocBest(docBest{ tag: textB, s: string(t), - d: b.be(docPos{k.tabs, k.spaces + int16(len(t))}, z), + d: d, }) case keyword: + d, err := b.be(docPos{k.tabs, k.spaces + int16(len(t))}, z) + if err != nil { + return nil, err + } res = b.newDocBest(docBest{ tag: keywordB, s: string(t), - d: b.be(docPos{k.tabs, k.spaces + int16(len(t))}, z), + d: d, }) case line, softbreak: + d, err := b.be(d.i, z) + if err != nil { + return nil, err + } res = b.newDocBest(docBest{ tag: lineB, i: d.i, - d: b.be(d.i, z), + d: d, }) case hardline: + d, err := b.be(d.i, z) + if err != nil { + return nil, err + } res = b.newDocBest(docBest{ tag: hardlineB, i: d.i, - d: b.be(d.i, z), + d: d, }) case *union: - res = b.better(k, - b.be(k, b.iDoc(d.i, t.x, z)), + d, err := b.be(k, b.iDoc(d.i, t.x, z)) + if err != nil { + return nil, err + } + res, err = b.better(k, + d, // We eta-lift the second argument to avoid eager evaluation. - func() *docBest { + func() (*docBest, error) { return b.be(k, b.iDoc(d.i, t.y, z)) }, ) + if err != nil { + return nil, err + } case *scolumn: - res = b.be(k, b.iDoc(d.i, t.f(k.spaces), z)) + res, err = b.be(k, b.iDoc(d.i, t.f(k.spaces), z)) + if err != nil { + return nil, err + } case *snesting: - res = b.be(k, b.iDoc(d.i, t.f(d.i.spaces), z)) + res, err = b.be(k, b.iDoc(d.i, t.f(d.i.spaces), z)) + if err != nil { + return nil, err + } case pad: + d, err := b.be(docPos{k.tabs, k.spaces + t.n}, z) + if err != nil { + return nil, err + } res = b.newDocBest(docBest{ tag: spacesB, i: docPos{spaces: t.n}, - d: b.be(docPos{k.tabs, k.spaces + t.n}, z), + d: d, }) default: panic(fmt.Errorf("unknown type: %T", d.d)) @@ -183,7 +247,7 @@ func (b *beExec) be(k docPos, xlist *iDoc) *docBest { // Memoize so we don't compute the same result twice. b.memoBe[memoKey] = res - return res + return res, nil } // newDocBest makes a new docBest on the heap. Allocations @@ -234,10 +298,10 @@ type beArgs struct { k docPos } -func (b *beExec) better(k docPos, x *docBest, y func() *docBest) *docBest { +func (b *beExec) better(k docPos, x *docBest, y func() (*docBest, error)) (*docBest, error) { remainder := b.w - k.spaces - k.tabs*b.tabWidth if fits(remainder, x) { - return x + return x, nil } return y() }