Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add let statement #52

Merged
merged 6 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ documentation is representative of the current pql api.
- [`as`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/as-operator)
- [`count`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/count-operator)
- [`join`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/join-operator)
- [`let` statements](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/let-statement),
but only scalar expressions are supported.
- [`project`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/project-operator)
- [`extend`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/extend-operator)
- [`sort`/`order`](https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/sort-operator)
Expand Down
16 changes: 15 additions & 1 deletion cmd/pql/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func run(ctx context.Context, output io.Writer, input io.Reader, logError func(e
}

var finalError error
letStatements := new(strings.Builder)
for scanner.Scan() {
sb.Write(scanner.Bytes())
sb.WriteByte('\n')
Expand All @@ -80,7 +81,20 @@ func run(ctx context.Context, output io.Writer, input io.Reader, logError func(e
}

for _, stmt := range statements[:len(statements)-1] {
sql, err := pql.Compile(stmt)
// Valid let statements are prepended to an ongoing prelude.
tokens := parser.Scan(stmt)
if len(tokens) > 0 && tokens[0].Kind == parser.TokenIdentifier && tokens[0].Value == "let" {
if _, err := pql.Compile(letStatements.String() + stmt + ";X"); err != nil {
logError(err)
finalError = errors.New("one or more statements could not be compiled")
} else {
letStatements.WriteString(stmt)
letStatements.WriteString(";\n")
}
continue
}

sql, err := pql.Compile(letStatements.String() + stmt)
if err != nil {
logError(err)
finalError = errors.New("one or more statements could not be compiled")
Expand Down
36 changes: 36 additions & 0 deletions parser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,20 @@ func (id *QualifiedIdent) Span() Span {

func (id *QualifiedIdent) expression() {}

type Statement interface {
Node
statement()
}

// TabularExpr is a query expression that produces a table.
// It implements [Statement].
type TabularExpr struct {
Source TabularDataSource
Operators []TabularOperator
}

func (x *TabularExpr) statement() {}

func (x *TabularExpr) Span() Span {
if x == nil {
return nullSpan()
Expand Down Expand Up @@ -548,6 +556,29 @@ func (idx *IndexExpr) Span() Span {

func (idx *IndexExpr) expression() {}

// A LetStatement node represents a let statement,
// assigning an expression to a name.
// It implements [Statement].
type LetStatement struct {
Keyword Span
Name *Ident
Assign Span
X Expr
}

func (stmt *LetStatement) statement() {}

func (stmt *LetStatement) Span() Span {
if stmt == nil {
return nullSpan()
}
xSpan := nullSpan()
if stmt.X != nil {
xSpan = stmt.X.Span()
}
return unionSpans(stmt.Keyword, stmt.Name.Span(), stmt.Assign, xSpan)
}

// Walk traverses an AST in depth-first order.
// If the visit function returns true for a node,
// the visit function will be called for its children.
Expand Down Expand Up @@ -685,6 +716,11 @@ func Walk(n Node, visit func(n Node) bool) {
stack = append(stack, n.Index)
stack = append(stack, n.X)
}
case *LetStatement:
if visit(n) {
stack = append(stack, n.X)
stack = append(stack, n.Name)
}
default:
panic(fmt.Errorf("unknown Node type %T", n))
}
Expand Down
223 changes: 156 additions & 67 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,40 +21,119 @@ type parser struct {
splitKind TokenKind
}

// Parse converts a Pipeline Query Language tabular expression
// Parse converts a Pipeline Query Language query
// into an Abstract Syntax Tree (AST).
func Parse(query string) (*TabularExpr, error) {
func Parse(query string) ([]Statement, error) {
p := &parser{
source: query,
tokens: Scan(query),
}
expr, err := p.tabularExpr()
if p.pos < len(p.tokens) {
trailingToken := p.tokens[p.pos]
if trailingToken.Kind == TokenError {
err = joinErrors(err, &parseError{
source: p.source,
span: trailingToken.Span,
err: errors.New(trailingToken.Value),
})
var result []Statement
var resultError error
for {
stmtParser := p.splitSemi()

stmt, err := firstParse(
func() (Statement, error) {
stmt, err := stmtParser.letStatement()
if stmt == nil {
// Prevent returning a non-nil interface.
return nil, err
}
return stmt, err
},
func() (Statement, error) {
expr, err := stmtParser.tabularExpr()
if expr == nil {
// Prevent returning a non-nil interface.
return nil, err
}
return expr, err
},
)

if isNotFound(err) {
// We're okay with empty statements, we just ignore them.
if stmtParser.pos < len(stmtParser.tokens) {
trailingToken := stmtParser.tokens[stmtParser.pos]
if trailingToken.Kind == TokenError {
resultError = joinErrors(err, &parseError{
source: p.source,
span: trailingToken.Span,
err: errors.New(trailingToken.Value),
})
} else {
resultError = joinErrors(err, &parseError{
source: p.source,
span: trailingToken.Span,
err: errors.New("unrecognized token"),
})
}
}
} else {
err = joinErrors(err, &parseError{
source: p.source,
span: trailingToken.Span,
err: errors.New("unrecognized token"),
})
if stmt != nil {
result = append(result, stmt)
}
resultError = joinErrors(resultError, makeErrorOpaque(err))
resultError = joinErrors(resultError, stmtParser.endSplit())
}

// Next token, if present, guaranteed to be a semicolon.
if _, ok := p.next(); !ok {
break
}
} else if isNotFound(err) {
err = &parseError{
}

if resultError != nil {
return result, fmt.Errorf("parse pipeline query language: %w", resultError)
}
return result, nil
}

func firstParse[T any](productions ...func() (T, error)) (T, error) {
for _, p := range productions[:len(productions)-1] {
x, err := p()
if !isNotFound(err) {
return x, err
}
}
return productions[len(productions)-1]()
}

func (p *parser) letStatement() (*LetStatement, error) {
keyword, _ := p.next()
if keyword.Kind != TokenIdentifier || keyword.Value != "let" {
p.prev()
return nil, &parseError{
source: p.source,
span: keyword.Span,
err: notFoundError{fmt.Errorf("expected 'let', got %s", formatToken(p.source, keyword))},
}
}

stmt := &LetStatement{
Keyword: keyword.Span,
Assign: nullSpan(),
}
var err error
stmt.Name, err = p.ident()
if err != nil {
return stmt, makeErrorOpaque(err)
}
assign, _ := p.next()
if assign.Kind != TokenAssign {
return stmt, &parseError{
source: p.source,
span: indexSpan(len(query)),
err: errors.New("empty query"),
span: assign.Span,
err: fmt.Errorf("expected '=', got %s", formatToken(p.source, assign)),
}
}
stmt.Assign = assign.Span
stmt.X, err = p.expr()
if err != nil {
return expr, fmt.Errorf("parse pipeline query language: %w", err)
return stmt, makeErrorOpaque(err)
}
return expr, nil
return stmt, nil
}

func (p *parser) tabularExpr() (*TabularExpr, error) {
Expand Down Expand Up @@ -294,27 +373,10 @@ func (p *parser) takeOperator(pipe, keyword Token) (*TakeOperator, error) {
Pipe: pipe.Span,
Keyword: keyword.Span,
}

tok, _ := p.next()
if tok.Kind != TokenNumber {
return op, &parseError{
source: p.source,
span: tok.Span,
err: fmt.Errorf("expected integer, got %s", formatToken(p.source, tok)),
}
}
rowCount := &BasicLit{
Kind: tok.Kind,
Value: tok.Value,
ValueSpan: tok.Span,
}
op.RowCount = rowCount
if !rowCount.IsInteger() {
return op, &parseError{
source: p.source,
span: tok.Span,
err: fmt.Errorf("expected integer, got %s", formatToken(p.source, tok)),
}
var err error
op.RowCount, err = p.rowCount()
if err != nil {
return op, makeErrorOpaque(err)
}
return op, nil
}
Expand All @@ -326,30 +388,13 @@ func (p *parser) topOperator(pipe, keyword Token) (*TopOperator, error) {
By: nullSpan(),
}

tok, _ := p.next()
if tok.Kind != TokenNumber {
p.prev()
return op, &parseError{
source: p.source,
span: tok.Span,
err: fmt.Errorf("expected integer, got %s", formatToken(p.source, tok)),
}
}
rowCount := &BasicLit{
Kind: tok.Kind,
Value: tok.Value,
ValueSpan: tok.Span,
}
op.RowCount = rowCount
if !rowCount.IsInteger() {
return op, &parseError{
source: p.source,
span: tok.Span,
err: fmt.Errorf("expected integer, got %s", formatToken(p.source, tok)),
}
var err error
op.RowCount, err = p.rowCount()
if err != nil {
return op, makeErrorOpaque(err)
}

tok, _ = p.next()
tok, _ := p.next()
if tok.Kind != TokenBy {
p.prev()
return op, &parseError{
Expand All @@ -360,11 +405,28 @@ func (p *parser) topOperator(pipe, keyword Token) (*TopOperator, error) {
}
op.By = tok.Span

var err error
op.Col, err = p.sortTerm()
return op, makeErrorOpaque(err)
}

func (p *parser) rowCount() (Expr, error) {
x, err := p.expr()
if err != nil {
return x, err
}
if lit, ok := x.(*BasicLit); ok {
// Do basic check for common case of literals.
if !lit.IsInteger() {
return x, fmt.Errorf("expected integer, got %s", formatToken(p.source, Token{
Kind: lit.Kind,
Span: lit.ValueSpan,
Value: lit.Value,
}))
}
}
return x, nil
}

func (p *parser) projectOperator(pipe, keyword Token) (*ProjectOperator, error) {
op := &ProjectOperator{
Pipe: pipe.Span,
Expand Down Expand Up @@ -1034,7 +1096,9 @@ func (p *parser) qualifiedIdent() (*QualifiedIdent, error) {
// split advances the parser to right before the next token of the given kind,
// and returns a new parser that reads the tokens that were skipped over.
// It ignores tokens that are in parenthetical groups after the initial parse position.
// If no such token is found, skipTo advances to EOF.
// If no such token is found, split advances to EOF.
//
// For splitting by semicolon, see [*parser.splitSemi].
func (p *parser) split(search TokenKind) *parser {
// stack is the list of expected closing parentheses/brackets.
// When a closing parenthesis/bracket is encountered,
Expand Down Expand Up @@ -1095,6 +1159,31 @@ loop:
}
}

// splitSemi advances the parser to right before the next semicolon,
// and returns a new parser that reads the tokens that were skipped over.
// If no semicolon is found, splitSemi advances to EOF.
func (p *parser) splitSemi() *parser {
start := p.pos
for {
tok, ok := p.next()
if !ok {
return &parser{
source: p.source,
tokens: p.tokens[start:],
splitKind: TokenSemi,
}
}
if tok.Kind == TokenSemi {
p.prev()
return &parser{
source: p.source,
tokens: p.tokens[start:p.pos],
splitKind: TokenSemi,
}
}
}
}

func (p *parser) endSplit() error {
if p.splitKind == 0 {
// This is a bug, but treating as an error instead of panicing.
Expand Down
Loading