Skip to content

Commit

Permalink
cmd/shfmt: implement --from-json and improve -tojson
Browse files Browse the repository at this point in the history
For the sake of consistency, -tojson is now also available as --to-json.

The following changes are made to --to-json:

1) JSON object keys are no longer sorted alphabetically.
   The new order is: the derived keys (Type, Pos, End),
   and then the Node's struct fields in their original order.
   This helps consistency across nodes and with the Go documentation.

2) File.Name is empty by default, rather than `<standard input>`.
   It did not make sense as a default for emitting JSON,
   as the flag always required the input to be stdin.

3) Empty values are no longer emitted, to help with verbosity.
   This includes `false`, `""`, `null`, `[]`, and zero positions.
   Position offsets are exempt, as 0 is a valid byte offset.

4) All position fields in Node structs are now emitted.
   Some positions are redundant given the derived Pos and End keys,
   but some others are entirely separate, like IfClause.ThenPos.

As part of point 1 above, JSON encoding no longer uses Go maps.
It now uses reflect.StructOf, which further leans into Go reflection.

Any of these four changes could potentially break users,
as they slightly change the way we encode syntax trees as JSON.
We are working under the assumption that the changes are reasonable,
and that any breakage is unlikely and easy to fix.
If those assumptions turn out to be false once this change is released,
we can always swap the -tojson flag to be exactly the old behavior,
and --to-json can then add the new behavior in a safer way.

We also had other ideas to further improve the JSON format,
such as renaming the Type and Pos/End JSON keys,
but we leave those as v4 TODOs as they will surely break most users.

The new --from-json flag does the reverse; it decodes the typed JSON,
and fills a *syntax.File with all the information.
The derived Type field is used to select syntax.Node types,
and the derived Pos and End fields are ignored entirely.

It's worth noting that neither --to-json nor --from-json is optimized.
The decoding side first decodes into an empty interface, for example,
which leaves plenty of room for improvement.
Once we're happy with the functionality, we can look at faster
implementations and even dropping the need for reflection.

While here, I noticed that the godoc for Pos.IsValid was slightly wrong.

As a proof of concept, the following commands all produce the same result:

	shfmt <input.sh
	shfmt --to-json <input.sh | shfmt --from-json
	shfmt --to-json <input.sh | jq | shfmt --from-json

Fixes mvdan#35 again, as we never implemented the "read JSON" side.
  • Loading branch information
mvdan committed Jul 16, 2022
1 parent ab95a7f commit c016564
Show file tree
Hide file tree
Showing 9 changed files with 4,391 additions and 180 deletions.
259 changes: 225 additions & 34 deletions cmd/shfmt/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package main

import (
"encoding/json"
"go/ast"
"fmt"
"io"
"reflect"

Expand All @@ -14,68 +14,259 @@ import (

// writeJSON converts node with encode and writes the result to w as JSON.
// When pretty is set, the output is indented with one tab per level.
// The error from the JSON encoder is returned as-is.
//
// NOTE(review): this span is a rendered diff without +/- markers, so the
// removed and the added variant of two statements appear back to back.
func writeJSON(w io.Writer, node syntax.Node, pretty bool) error {
val := reflect.ValueOf(node)
// Appears to be the removed line: the old encode returned an interface{}.
v, _ := encode(val)
// Appears to be the added line: the new encode returns a reflect.Value.
encVal, _ := encode(val)
enc := json.NewEncoder(w)
if pretty {
enc.SetIndent("", "\t")
}
// Appears to be the removed line.
return enc.Encode(v)
// Appears to be the added line; Interface unwraps the reflect.Value so that
// encoding/json sees the underlying encoded value.
return enc.Encode(encVal.Interface())
}

// encode converts val into a value that encoding/json can emit, and returns
// it along with the name of the struct type it was built from, if any.
//
// In the added variant shown here, empty values (nil pointers and interfaces,
// empty slices, false, "", and zero uint32 values) yield noValue, so that the
// corresponding omitempty field is left unset. Structs become pointers to
// values of a freshly built reflect.StructOf type whose first three fields
// are Type, Pos, and End. Any other kind causes a panic.
//
// NOTE(review): this span is a rendered diff without +/- markers; statements
// from the removed map-based implementation are interleaved with the added
// reflect.StructOf-based one, so it is not valid Go as shown.
func encode(val reflect.Value) (interface{}, string) {
func encode(val reflect.Value) (reflect.Value, string) {
switch val.Kind() {
case reflect.Ptr:
elem := val.Elem()
if !elem.IsValid() {
return nil, ""
break
}
return encode(elem)
case reflect.Interface:
if val.IsNil() {
return nil, ""
break
}
v, tname := encode(val.Elem())
m := v.(map[string]interface{})
m["Type"] = tname
return m, ""
enc, tname := encode(val.Elem())
if tname != "" {
// Field 0 of every generated struct is the derived Type field.
enc.Elem().Field(0).SetString(tname)
}
return enc, ""
case reflect.Struct:
m := make(map[string]interface{}, val.NumField()+1)
// Construct a new struct with an optional Type, Pos and End,
// and then all the visible fields; fields of type syntax.Pos
// are declared as *exportedPos, and all others as interface{}.
typ := val.Type()
for i := 0; i < val.NumField(); i++ {
ftyp := typ.Field(i)
if ftyp.Type.Name() == "Pos" {
continue
}
if !ast.IsExported(ftyp.Name) {
continue
fields := []reflect.StructField{typeField, posField, endField}
for _, field := range reflect.VisibleFields(typ) {
typ := anyType
if field.Type == posType {
typ = exportedPosType
}
fval := val.Field(i)
v, _ := encode(fval)
m[ftyp.Name] = v
fields = append(fields, reflect.StructField{
Name: field.Name,
Type: typ,
Tag: `json:",omitempty"`,
})
}
encTyp := reflect.StructOf(fields)
enc := reflect.New(encTyp).Elem()

// Pos methods are defined on struct pointer receivers.
for _, name := range [...]string{"Pos", "End"} {
for i, name := range [...]string{"Pos", "End"} {
if fn := val.Addr().MethodByName(name); fn.IsValid() {
m[name] = translatePos(fn.Call(nil)[0])
// Fields 1 and 2 of the generated struct are Pos and End.
encodePos(enc.Field(1+i), fn.Call(nil)[0])
}
}
// Do the rest of the fields.
// They start at index 3, after the derived Type, Pos, and End fields.
for i := 3; i < encTyp.NumField(); i++ {
ftyp := encTyp.Field(i)
fval := val.FieldByName(ftyp.Name)
if ftyp.Type == exportedPosType {
encodePos(enc.Field(i), fval)
} else {
encElem, _ := encode(fval)
// An invalid value means "empty"; leave the field as its zero value.
if encElem.IsValid() {
enc.Field(i).Set(encElem)
}
}
}
return m, typ.Name()

// Addr helps prevent an allocation as we use interface{} fields.
return enc.Addr(), typ.Name()
case reflect.Slice:
l := make([]interface{}, val.Len())
for i := 0; i < val.Len(); i++ {
n := val.Len()
if n == 0 {
break
}
enc := reflect.MakeSlice(anySliceType, n, n)
for i := 0; i < n; i++ {
elem := val.Index(i)
l[i], _ = encode(elem)
encElem, _ := encode(elem)
enc.Index(i).Set(encElem)
}
return enc, ""
case reflect.Bool:
if val.Bool() {
return val, ""
}
case reflect.String:
if val.String() != "" {
return val, ""
}
case reflect.Uint32:
if val.Uint() != 0 {
return val, ""
}
return l, ""
default:
return val.Interface(), ""
panic(val.Kind().String())
}
// Reached by every "empty value" case above via break or fallthrough
// out of the switch.
return noValue, ""
}

// Reflection values shared by the JSON encoding and decoding helpers.
var (
	// noValue is the zero reflect.Value, used to signal "nothing to emit".
	noValue = reflect.Value{}

	// Types used when building and inspecting the generated structs.
	anyType         = reflect.TypeOf(new(interface{})).Elem() // interface{}
	anySliceType    = reflect.SliceOf(anyType)                // []interface{}
	posType         = reflect.TypeOf(new(syntax.Pos)).Elem()  // syntax.Pos
	exportedPosType = reflect.TypeOf(new(exportedPos))        // *exportedPos

	// The three derived fields that lead every generated struct, in order.
	//
	// TODO(v4): derived fields like Type, Pos, and End should have clearly
	// different names to prevent confusion. For example: _type, _pos, _end.
	typeField = reflect.StructField{
		Name: "Type",
		Type: reflect.TypeOf(""),
		Tag:  `json:",omitempty"`,
	}
	posField = reflect.StructField{
		Name: "Pos",
		Type: exportedPosType,
		Tag:  `json:",omitempty"`,
	}
	endField = reflect.StructField{
		Name: "End",
		Type: exportedPosType,
		Tag:  `json:",omitempty"`,
	}
)

// exportedPos is the JSON representation of a position.
// Its fields are filled, in declaration order, from the Offset, Line, and Col
// methods of the position being encoded.
type exportedPos struct {
	Offset uint
	Line   uint
	Col    uint
}

// encodePos stores the position held by val into encPtr as a newly allocated
// *exportedPos. The position's IsValid, Offset, Line, and Col methods are
// looked up by name via reflection. If IsValid reports false, encPtr is left
// untouched so that the omitempty field is skipped.
func encodePos(encPtr, val reflect.Value) {
	call := func(method string) reflect.Value {
		return val.MethodByName(method).Call(nil)[0]
	}
	if !call("IsValid").Bool() {
		return
	}

	ptr := reflect.New(exportedPosType.Elem())
	encPtr.Set(ptr)

	// The exportedPos fields are declared in the same order as these methods.
	dst := ptr.Elem()
	for i, method := range [...]string{"Offset", "Line", "Col"} {
		dst.Field(i).Set(call(method))
	}
}

// decodePos sets val to the position described by the "Offset", "Line", and
// "Col" keys of enc. Each key must hold a float64, which is how encoding/json
// decodes numbers into an empty interface; anything else causes a panic.
func decodePos(val reflect.Value, enc map[string]interface{}) {
	num := func(key string) uint {
		return uint(enc[key].(float64))
	}
	pos := syntax.NewPos(num("Offset"), num("Line"), num("Col"))
	val.Set(reflect.ValueOf(pos))
}

// NOTE(review): the first five lines below appear to be the removed side of
// the diff (the old translatePos helper, which built a map with Offset, Line,
// and Col keys). The rendered diff dropped its +/- markers, so they run
// straight into the added readJSON function.
func translatePos(val reflect.Value) map[string]interface{} {
return map[string]interface{}{
"Offset": val.MethodByName("Offset").Call(nil)[0].Uint(),
"Line": val.MethodByName("Line").Call(nil)[0].Uint(),
"Col": val.MethodByName("Col").Call(nil)[0].Uint(),
// readJSON decodes a single JSON value from r into an empty interface, and
// then uses decode to fill a new *syntax.File with it.
// On a JSON decoding error or a decode error, it returns a nil node and the
// error as-is.
func readJSON(r io.Reader) (syntax.Node, error) {
var enc interface{}
if err := json.NewDecoder(r).Decode(&enc); err != nil {
return nil, err
}
node := &syntax.File{}
if err := decode(reflect.ValueOf(node), enc); err != nil {
return nil, err
}
return node, nil
}

// nodeByName maps the value of a JSON "Type" key to the syntax struct type
// which should be allocated for it. Each key is the name of its struct type.
var nodeByName = func() map[string]reflect.Type {
	ptrs := [...]interface{}{
		(*syntax.Word)(nil),

		(*syntax.Lit)(nil),
		(*syntax.SglQuoted)(nil),
		(*syntax.DblQuoted)(nil),
		(*syntax.ParamExp)(nil),
		(*syntax.CmdSubst)(nil),
		(*syntax.CallExpr)(nil),
		(*syntax.ArithmExp)(nil),
		(*syntax.ProcSubst)(nil),
		(*syntax.ExtGlob)(nil),
		(*syntax.BraceExp)(nil),

		(*syntax.ArithmCmd)(nil),
		(*syntax.BinaryCmd)(nil),
		(*syntax.IfClause)(nil),
		(*syntax.ForClause)(nil),
		(*syntax.WhileClause)(nil),
		(*syntax.CaseClause)(nil),
		(*syntax.Block)(nil),
		(*syntax.Subshell)(nil),
		(*syntax.FuncDecl)(nil),
		(*syntax.TestClause)(nil),
		(*syntax.DeclClause)(nil),
		(*syntax.LetClause)(nil),
		(*syntax.TimeClause)(nil),
		(*syntax.CoprocClause)(nil),
		(*syntax.TestDecl)(nil),

		(*syntax.UnaryArithm)(nil),
		(*syntax.BinaryArithm)(nil),
		(*syntax.ParenArithm)(nil),

		(*syntax.UnaryTest)(nil),
		(*syntax.BinaryTest)(nil),
		(*syntax.ParenTest)(nil),

		(*syntax.WordIter)(nil),
		(*syntax.CStyleLoop)(nil),
	}
	byName := make(map[string]reflect.Type, len(ptrs))
	for _, ptr := range ptrs {
		// Elem turns the *T pointer type into the T struct type.
		typ := reflect.TypeOf(ptr).Elem()
		byName[typ.Name()] = typ
	}
	return byName
}()

// decode fills val with enc, a JSON value as decoded by encoding/json into
// an empty interface.
//
// JSON objects fill structs. A non-empty "Type" key selects the node type to
// allocate from nodeByName, nil pointers are allocated as needed, and the
// derived "Pos" and "End" keys are ignored. JSON arrays are appended to
// slices, JSON numbers are stored into unsigned integers, and strings and
// booleans are assigned directly. A JSON null leaves val untouched.
//
// Since the input is untrusted, any mismatch between enc and val is reported
// as an error rather than left to panic inside the reflect package.
func decode(val reflect.Value, enc interface{}) error {
	switch enc := enc.(type) {
	case map[string]interface{}:
		if val.Kind() == reflect.Ptr && val.IsNil() {
			if !val.CanSet() {
				return fmt.Errorf("cannot decode object into unsettable %s", val.Type())
			}
			val.Set(reflect.New(val.Type().Elem()))
		}
		if typeName, _ := enc["Type"].(string); typeName != "" {
			typ := nodeByName[typeName]
			if typ == nil {
				return fmt.Errorf("unknown type: %q", typeName)
			}
			node := reflect.New(typ)
			if !val.CanSet() || !node.Type().AssignableTo(val.Type()) {
				return fmt.Errorf("cannot decode node of type %q into %s", typeName, val.Type())
			}
			val.Set(node)
		}
		// Dereference into target, keeping val around as its type is
		// still needed for error messages if we end up with nothing.
		target := val
		for target.Kind() == reflect.Ptr || target.Kind() == reflect.Interface {
			target = target.Elem()
		}
		if target.Kind() != reflect.Struct {
			// For example, a nil interface without a "Type" key,
			// or an object where a string or slice was expected.
			return fmt.Errorf("cannot decode object into %s", val.Type())
		}
		for name, fv := range enc {
			switch name {
			case "Type", "Pos", "End":
				// Type is already used above. Pos and End came from method calls.
				continue
			}
			fval := target.FieldByName(name)
			if !fval.IsValid() || !fval.CanSet() {
				return fmt.Errorf("unknown field for %s: %q", target.Type(), name)
			}
			if fval.Type() == posType {
				// Validate the shape up front, as decodePos panics on bad input.
				posEnc, ok := fv.(map[string]interface{})
				if !ok {
					return fmt.Errorf("invalid position for %s.%s: got %T", target.Type(), name, fv)
				}
				for _, key := range [...]string{"Offset", "Line", "Col"} {
					if _, ok := posEnc[key].(float64); !ok {
						return fmt.Errorf("invalid position for %s.%s: missing or non-numeric %q", target.Type(), name, key)
					}
				}
				decodePos(fval, posEnc)
				continue
			}
			if err := decode(fval, fv); err != nil {
				return err
			}
		}
	case []interface{}:
		if val.Kind() != reflect.Slice || !val.CanSet() {
			return fmt.Errorf("cannot decode array into %s", val.Type())
		}
		for _, encElem := range enc {
			elem := reflect.New(val.Type().Elem()).Elem()
			if err := decode(elem, encElem); err != nil {
				return err
			}
			val.Set(reflect.Append(val, elem))
		}
	case float64:
		// Tokens and thus operators are uint32, but encoding/json defaults to float64.
		// TODO: reject invalid operators.
		switch val.Kind() {
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		default:
			return fmt.Errorf("cannot decode number into %s", val.Type())
		}
		// Only accept numbers which survive the conversion to uint64 intact,
		// rejecting negative, fractional, and out-of-range values.
		if enc < 0 || enc >= 1<<64 || enc != float64(uint64(enc)) {
			return fmt.Errorf("invalid unsigned integer for %s: %v", val.Type(), enc)
		}
		u := uint64(enc)
		if !val.CanSet() || val.OverflowUint(u) {
			return fmt.Errorf("cannot store %d into %s", u, val.Type())
		}
		val.SetUint(u)
	default:
		if enc == nil {
			break
		}
		encVal := reflect.ValueOf(enc)
		if !val.CanSet() || !encVal.Type().AssignableTo(val.Type()) {
			return fmt.Errorf("cannot decode %T into %s", enc, val.Type())
		}
		val.Set(encVal)
	}
	return nil
}
70 changes: 70 additions & 0 deletions cmd/shfmt/json_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) 2017, Daniel Martí <[email protected]>
// See LICENSE for licensing information

package main

import (
"bytes"
"os"
"strings"
"testing"

qt "github.com/frankban/quicktest"

"mvdan.cc/sh/v3/syntax"
)

// TestRoundtripJSON checks that a shell file can be converted to JSON with
// writeJSON and back with readJSON without losing any information.
//
// The shell input lives in testdata/json.sh, and the expected pretty JSON
// in testdata/json.json. When the update flag is set, the JSON file is
// rewritten with the current output instead of being compared against it.
func TestRoundtripJSON(t *testing.T) {
	t.Parallel()

	// Read testdata files.
	inputShell, err := os.ReadFile("testdata/json.sh")
	qt.Assert(t, err, qt.IsNil)
	inputJSON, err := os.ReadFile("testdata/json.json")
	if !*update { // allow it to not exist
		qt.Assert(t, err, qt.IsNil)
	}
	sb := new(strings.Builder)

	// Parse the shell source and check that it is well formatted.
	parser := syntax.NewParser(syntax.KeepComments(true))
	node, err := parser.Parse(bytes.NewReader(inputShell), "")
	qt.Assert(t, err, qt.IsNil)

	printer := syntax.NewPrinter()
	sb.Reset()
	err = printer.Print(sb, node)
	qt.Assert(t, err, qt.IsNil)
	qt.Assert(t, sb.String(), qt.Equals, string(inputShell))

	// Validate writing the pretty JSON.
	sb.Reset()
	err = writeJSON(sb, node, true)
	qt.Assert(t, err, qt.IsNil)
	got := sb.String()
	if *update {
		err := os.WriteFile("testdata/json.json", []byte(got), 0o666)
		qt.Assert(t, err, qt.IsNil)
		// Continue with the JSON we just wrote; what we read earlier
		// may be stale, or empty if the file did not exist yet,
		// which would make the rest of the test fail spuriously.
		inputJSON = []byte(got)
	} else {
		qt.Assert(t, got, qt.Equals, string(inputJSON))
	}

	// Ensure we don't use the originally parsed node again.
	node = nil

	// Validate reading the pretty JSON and check that it formats the same.
	node2, err := readJSON(bytes.NewReader(inputJSON))
	qt.Assert(t, err, qt.IsNil)

	sb.Reset()
	err = printer.Print(sb, node2)
	qt.Assert(t, err, qt.IsNil)
	qt.Assert(t, sb.String(), qt.Equals, string(inputShell))

	// Validate that emitting the JSON again produces the same result.
	sb.Reset()
	err = writeJSON(sb, node2, true)
	qt.Assert(t, err, qt.IsNil)
	got = sb.String()
	qt.Assert(t, got, qt.Equals, string(inputJSON))
}
Loading

0 comments on commit c016564

Please sign in to comment.