Skip to content

Commit

Permalink
util/pretty: mitigate stack overflows of Pretty
Browse files Browse the repository at this point in the history
This commit reduces the chance of a stack overflow from recursive calls
of `*beExec.be`. The `Pretty` function will now return an internal error
if the recursion depth of `be` surpasses 10,000.

Informs cockroachdb#91197

Release note: None
  • Loading branch information
mgartner committed Sep 13, 2023
1 parent d497d7a commit 964e2ec
Show file tree
Hide file tree
Showing 15 changed files with 153 additions and 17 deletions.
6 changes: 5 additions & 1 deletion pkg/cli/sqlfmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ func runSQLFmt(cmd *cobra.Command, args []string) error {
}

for i := range sl {
fmt.Print(cfg.Pretty(sl[i].AST))
p, err := cfg.Pretty(sl[i].AST)
if err != nil {
return err
}
fmt.Print(p)
if len(sl) > 1 {
fmt.Print(";")
}
Expand Down
12 changes: 11 additions & 1 deletion pkg/cmd/reduce/reduce/reducesql/reducesql.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ func collectASTs(stmts statements.Statements) []tree.NodeFormatter {

func joinASTs(stmts []tree.NodeFormatter) string {
var sb strings.Builder
var fmtCtx *tree.FmtCtx
for i, stmt := range stmts {
if i > 0 {
sb.WriteString("\n\n")
Expand All @@ -438,7 +439,16 @@ func joinASTs(stmts []tree.NodeFormatter) string {
UseTabs: false,
Simplify: true,
}
sb.WriteString(cfg.Pretty(stmt))
p, err := cfg.Pretty(stmt)
if err != nil {
// Use simple printing if pretty-printing fails.
if fmtCtx == nil {
fmtCtx = tree.NewFmtCtx(tree.FmtParsable)
}
stmt.Format(fmtCtx)
p = fmtCtx.CloseAndGetString()
}
sb.WriteString(p)
sb.WriteString(";")
}
return sb.String()
Expand Down
5 changes: 4 additions & 1 deletion pkg/internal/sqlsmith/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ func TestGenerateParse(t *testing.T) {
if err != nil {
t.Fatalf("%v: %v", stmt, err)
}
stmt = sqlsmith.TestingPrettyCfg.Pretty(parsed.AST)
stmt, err = sqlsmith.TestingPrettyCfg.Pretty(parsed.AST)
if err != nil {
t.Fatal(err)
}
fmt.Print("STMT: ", i, "\n", stmt, ";\n\n")
if *flagExec {
db.Exec(t, `SET statement_timeout = '9s'`)
Expand Down
7 changes: 6 additions & 1 deletion pkg/internal/sqlsmith/sqlsmith.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,12 @@ func (s *Smither) Generate() string {
continue
}
i = 0
return prettyCfg.Pretty(stmt)
p, err := prettyCfg.Pretty(stmt)
if err != nil {
// Use simple printing if pretty-printing fails.
p = tree.AsStringWithFlags(stmt, tree.FmtParsable)
}
return p
}
}

Expand Down
1 change: 1 addition & 0 deletions pkg/sql/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,7 @@ go_library(
"//pkg/util/metric",
"//pkg/util/mon",
"//pkg/util/optional",
"//pkg/util/pretty",
"//pkg/util/protoutil",
"//pkg/util/quotapool",
"//pkg/util/randutil",
Expand Down
13 changes: 12 additions & 1 deletion pkg/sql/explain_bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/duration"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/memzipper"
"github.com/cockroachdb/cockroach/pkg/util/pretty"
"github.com/cockroachdb/cockroach/pkg/util/tracing"
"github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb"
"github.com/cockroachdb/errors"
Expand Down Expand Up @@ -255,7 +256,17 @@ func (b *stmtBundleBuilder) buildPrettyStatement(stmtRawSQL string) {
cfg.Align = tree.PrettyNoAlign
cfg.JSONFmt = true
cfg.ValueRedaction = b.flags.RedactValues
b.stmt = cfg.Pretty(b.plan.stmt.AST)
var err error
b.stmt, err = cfg.Pretty(b.plan.stmt.AST)
if errors.Is(err, pretty.ErrPrettyMaxRecursionDepthExceeded) {
// Use the raw statement string if pretty-printing fails.
b.stmt = stmtRawSQL
// If we're collecting a redacted bundle, redact the raw SQL
// completely.
if b.flags.RedactValues && b.stmt != "" {
b.stmt = string(redact.RedactedMarker())
}
}

// If we had ValueRedaction set, Pretty surrounded all constants with
// redaction markers. We must call Redact to fully redact them.
Expand Down
6 changes: 5 additions & 1 deletion pkg/sql/logictest/logic.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,11 @@ func (ls *logicStatement) readSQL(
if i > 0 {
fmt.Fprintln(&newSyntax, ";")
}
fmt.Fprint(&newSyntax, pcfg.Pretty(stmtList[i].AST))
p, err := pcfg.Pretty(stmtList[i].AST)
if err != nil {
return "", errors.Wrapf(err, "error while pretty printing")
}
fmt.Fprint(&newSyntax, p)
}
return newSyntax.String(), nil
}(ls.sql)
Expand Down
5 changes: 4 additions & 1 deletion pkg/sql/opt/optgen/cmd/optfmt/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,10 @@ func prettyify(r io.Reader, n int, exprgen bool) (string, error) {
exprs = parser.Exprs()
}
d := p.toDoc(exprs)
s := pretty.Pretty(d, n, false, 4, nil)
s, err := pretty.Pretty(d, n, false, 4, nil)
if err != nil {
return "", err
}

// Remove any whitespace at EOL. This can happen in define rules where
// we always insert a blank line above comments which are nested with
Expand Down
8 changes: 7 additions & 1 deletion pkg/sql/sem/builtins/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -11157,7 +11157,13 @@ func prettyStatement(p tree.PrettyCfg, stmt string) (string, error) {
}
var formattedStmt strings.Builder
for idx := range stmts {
formattedStmt.WriteString(p.Pretty(stmts[idx].AST))
p, err := p.Pretty(stmts[idx].AST)
if err != nil {
// If pretty-printing the statement fails, use the original
// statement.
p = stmt
}
formattedStmt.WriteString(p)
if len(stmts) > 1 {
formattedStmt.WriteString(";")
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/sql/sem/tree/pretty.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,13 @@ func (p *PrettyCfg) bracketKeyword(
}

// Pretty pretty prints stmt with default options.
func Pretty(stmt NodeFormatter) string {
func Pretty(stmt NodeFormatter) (string, error) {
cfg := DefaultPrettyCfg()
return cfg.Pretty(stmt)
}

// Pretty pretty prints stmt with specified options.
func (p *PrettyCfg) Pretty(stmt NodeFormatter) string {
func (p *PrettyCfg) Pretty(stmt NodeFormatter) (string, error) {
doc := p.Doc(stmt)
return pretty.Pretty(doc, p.LineWidth, p.UseTabs, p.TabWidth, p.Case)
}
Expand Down
43 changes: 40 additions & 3 deletions pkg/sql/sem/tree/pretty_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"

Expand All @@ -30,6 +31,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/pretty"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
)

Expand Down Expand Up @@ -112,7 +114,10 @@ func runTestPrettyData(
for p := range work {
thisCfg := cfg
thisCfg.LineWidth = p.numCols
res[p.idx] = thisCfg.Pretty(stmt.AST)
res[p.idx], err = thisCfg.Pretty(stmt.AST)
if err != nil {
t.Fatal(err)
}
}
return nil
}
Expand Down Expand Up @@ -178,14 +183,43 @@ func TestPrettyVerify(t *testing.T) {
if err != nil {
t.Fatal(err)
}
got := tree.Pretty(stmt.AST)
got, err := tree.Pretty(stmt.AST)
if err != nil {
t.Fatal(err)
}
if pretty != got {
t.Fatalf("got: %s\nexpected: %s", got, pretty)
}
})
}
}

// TestPrettyBigStatement verifies that pretty-printing a statement with a very
// large IN list fails with pretty.ErrPrettyMaxRecursionDepthExceeded rather
// than overflowing the stack.
func TestPrettyBigStatement(t *testing.T) {
	defer leaktest.AfterTest(t)()
	defer log.Scope(t).Close(t)

	// Create a SELECT statement with a 1 million item IN expression. Without
	// mitigation, this can cause stack overflows - see #91197.
	const numItems = 1_000_000
	var sb strings.Builder
	sb.WriteString("SELECT * FROM foo WHERE id IN (")
	for i := 0; i < numItems; i++ {
		if i != 0 {
			sb.WriteByte(',')
		}
		sb.WriteString(strconv.Itoa(i))
	}
	sb.WriteString(");")

	stmt, err := parser.ParseOne(sb.String())
	if err != nil {
		t.Fatal(err)
	}

	cfg := tree.DefaultPrettyCfg()
	_, err = cfg.Pretty(stmt.AST)
	// assert.Errorf only checks that err is non-nil and treats its string
	// argument as the failure message, so it would accept any error at all.
	// Assert on the sentinel to pin the exact failure mode under test.
	assert.ErrorIs(t, err, pretty.ErrPrettyMaxRecursionDepthExceeded)
}

func BenchmarkPrettyData(b *testing.B) {
matches, err := filepath.Glob(datapathutils.TestDataPath(b, "pretty", "*.sql"))
if err != nil {
Expand Down Expand Up @@ -226,7 +260,10 @@ func TestPrettyExprs(t *testing.T) {
}

for expr, pretty := range tests {
got := tree.Pretty(expr)
got, err := tree.Pretty(expr)
if err != nil {
t.Fatal(err)
}
if pretty != got {
t.Fatalf("got: %s\nexpected: %s", got, pretty)
}
Expand Down
8 changes: 7 additions & 1 deletion pkg/sql/show_create_clauses.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/sql/sessiondata"
"github.com/cockroachdb/cockroach/pkg/sql/types"
"github.com/cockroachdb/cockroach/pkg/util/log"
"github.com/cockroachdb/cockroach/pkg/util/pretty"
"github.com/cockroachdb/errors"
)

Expand Down Expand Up @@ -169,7 +170,12 @@ func formatViewQueryForDisplay(
desc.GetName(), desc.GetID(), err)
return
}
query = cfg.Pretty(parsed.AST)
query, err = cfg.Pretty(parsed.AST)
if errors.Is(err, pretty.ErrPrettyMaxRecursionDepthExceeded) {
// Use simple printing if pretty-printing fails.
query = tree.AsStringWithFlags(parsed.AST, tree.FmtParsable)
return
}
}()

typeReplacedViewQuery, err := formatViewQueryTypesForDisplay(ctx, semaCtx, sessionData, desc)
Expand Down
5 changes: 4 additions & 1 deletion pkg/testutils/sqlutils/pretty.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ func VerifyStatementPrettyRoundtrip(t *testing.T, sql string) {
//
origStmt := stmts[i].AST
// Be careful to not simplify otherwise the tests won't round trip.
prettyStmt := cfg.Pretty(origStmt)
prettyStmt, err := cfg.Pretty(origStmt)
if err != nil {
t.Fatalf("%s: %s", err, prettyStmt)
}
parsedPretty, err := parser.ParseOne(prettyStmt)
if err != nil {
t.Fatalf("%s: %s", err, prettyStmt)
Expand Down
4 changes: 4 additions & 0 deletions pkg/util/pretty/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ go_library(
],
importpath = "github.com/cockroachdb/cockroach/pkg/util/pretty",
visibility = ["//visibility:public"],
deps = [
"//pkg/util/errorutil",
"@com_github_cockroachdb_errors//:errors",
],
)

go_test(
Expand Down
43 changes: 41 additions & 2 deletions pkg/util/pretty/pretty.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ package pretty
import (
"fmt"
"strings"

"github.com/cockroachdb/cockroach/pkg/util/errorutil"
"github.com/cockroachdb/errors"
)

// See the referenced paper in the package documentation for explanations
Expand Down Expand Up @@ -45,7 +48,25 @@ const (
// if not nil. keywordTransform must not change the visible length of its
// argument. It can, for example, add invisible characters like control codes
// (colors, etc.).
func Pretty(d Doc, n int, useTabs bool, tabWidth int, keywordTransform func(string) string) string {
func Pretty(
d Doc, n int, useTabs bool, tabWidth int, keywordTransform func(string) string,
) (_ string, err error) {
defer func() {
if r := recover(); r != nil {
// This code allows us to propagate internal errors without having
// to add error checks everywhere throughout the code. This is only
// possible because the code does not update shared state and does
// not manipulate locks.
if ok, e := errorutil.ShouldCatch(r); ok {
err = e
} else {
// Other panic objects can't be considered "safe" and thus are
// propagated as panics.
panic(r)
}
}
}()

var sb strings.Builder
b := beExec{
w: int16(n),
Expand All @@ -56,7 +77,7 @@ func Pretty(d Doc, n int, useTabs bool, tabWidth int, keywordTransform func(stri
}
ldoc := b.best(d)
b.layout(&sb, useTabs, ldoc)
return sb.String()
return sb.String(), nil
}

// w is the max line width.
Expand Down Expand Up @@ -103,9 +124,27 @@ type beExec struct {

// keywordTransform filters keywords if not nil.
keywordTransform func(string) string

// beDepth is the depth of recursive calls of be. It is used to detect deep
// call stacks before a stack overflow occurs.
beDepth int
}

// maxBeDepth is the maximum allowed recursive call depth of be. If the depth
// exceeds this value, be panics with ErrPrettyMaxRecursionDepthExceeded; Pretty
// recovers that panic and returns it to the caller as an error.
const maxBeDepth = 10_000

// ErrPrettyMaxRecursionDepthExceeded is returned from Pretty when the maximum
// recursion depth of a function invoked by Pretty is exceeded. It is a single
// package-level value, so callers should test for it with errors.Is.
//
// NOTE(review): the value is constructed once at package initialization, so any
// call-site details recorded by AssertionFailedf presumably describe
// initialization rather than the failing Pretty call - confirm.
var ErrPrettyMaxRecursionDepthExceeded = errors.AssertionFailedf("max recursion depth exceeded")

func (b *beExec) be(k docPos, xlist *iDoc) *docBest {
b.beDepth++
defer func() { b.beDepth-- }()
if b.beDepth > maxBeDepth {
panic(ErrPrettyMaxRecursionDepthExceeded)
}

// Shortcut: be k [] = Nil
if xlist == nil {
return nil
Expand Down

0 comments on commit 964e2ec

Please sign in to comment.