diff --git a/marshal.go b/marshal.go index 42fb3cfb..1dbec2dc 100644 --- a/marshal.go +++ b/marshal.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "strconv" "strings" "time" ) @@ -12,6 +13,7 @@ import ( type tomlOpts struct { name string comment *string + commented bool include bool omitempty bool } @@ -148,7 +150,7 @@ func valueToTree(mtype reflect.Type, mval reflect.Value) (*Tree, error) { if err != nil { return nil, err } - tval.Set(opts.name, opts.comment, val) + tval.Set(opts.name, opts.comment, opts.commented, val) } } case reflect.Map: @@ -158,7 +160,7 @@ func valueToTree(mtype reflect.Type, mval reflect.Value) (*Tree, error) { if err != nil { return nil, err } - tval.Set(key.String(), nil, val) + tval.Set(key.String(), nil, false, val) } } return tval, nil @@ -453,7 +455,8 @@ func tomlOptions(vf reflect.StructField) tomlOpts { if c := vf.Tag.Get("comment"); c != "" { comment = &c } - result := tomlOpts{vf.Name, comment, true, false} + commented, _ := strconv.ParseBool(vf.Tag.Get("commented")) + result := tomlOpts{name: vf.Name, comment: comment, commented: commented, include: true, omitempty: false} if parse[0] != "" { if parse[0] == "-" && len(parse) == 1 { result.include = false diff --git a/marshal_test.go b/marshal_test.go index 1dd7be65..a44490e3 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -600,8 +600,9 @@ func TestNestedCustomMarshaler(t *testing.T) { } var commentTestToml = []byte(` -# postgres it's a comment on type +# it's a comment on type [postgres] + # isCommented = "dvalue" noComment = "cvalue" # A comment on AttrB @@ -613,7 +614,6 @@ var commentTestToml = []byte(` # a comment on My [[postgres.My]] - # a comment on My [[postgres.My]] `) @@ -625,6 +625,7 @@ func TestMarshalComment(t *testing.T) { AttrA string `toml:"user" comment:"A comment on AttrA"` AttrB string `toml:"password" comment:"A comment on AttrB"` AttrC string `toml:"noComment"` + AttrD string `toml:"isCommented" commented:"true"` My []TypeC `comment:"a comment on My"` } type TypeA struct { @@ -632,7 +633,7 @@ func TestMarshalComment(t *testing.T) { } ta := []TypeC{{my: "Foo"}, {my: "Baar"}} - config := TypeA{TypeB{AttrA: "avalue", AttrB: "bvalue", AttrC: "cvalue", My: ta}} + config := TypeA{TypeB{AttrA: "avalue", AttrB: "bvalue", AttrC: "cvalue", AttrD: "dvalue", My: ta}} result, err := Marshal(config) if err != nil { t.Fatal(err) diff --git a/parser.go b/parser.go index 5ca65dff..85feafeb 100644 --- a/parser.go +++ b/parser.go @@ -110,7 +110,7 @@ func (p *tomlParser) parseGroupArray() tomlParserStateFn { newTree := newTree() newTree.position = startToken.Position array = append(array, newTree) - p.tree.SetPath(p.currentTable, nil, array) + p.tree.SetPath(p.currentTable, nil, false, array) // remove all keys that were children of this table array prefix := key.val + "." @@ -299,7 +299,7 @@ Loop: key := p.getToken() p.assume(tokenEqual) value := p.parseRvalue() - tree.Set(key.val, nil, value) + tree.Set(key.val, nil, false, value) case tokenComma: if previous == nil { p.raiseError(follow, "inline table cannot start with a comma") diff --git a/parser_test.go b/parser_test.go index 7a6c21cd..81cbb21a 100644 --- a/parser_test.go +++ b/parser_test.go @@ -46,7 +46,7 @@ func assertTree(t *testing.T, tree *Tree, err error, ref map[string]interface{}) func TestCreateSubTree(t *testing.T) { tree := newTree() tree.createSubTree([]string{"a", "b", "c"}, Position{}) - tree.Set("a.b.c", nil, 42) + tree.Set("a.b.c", nil, false, 42) if tree.Get("a.b.c") != 42 { t.Fail() } diff --git a/toml.go b/toml.go index 4da92478..75df92e9 100644 --- a/toml.go +++ b/toml.go @@ -11,16 +11,18 @@ import ( ) type tomlValue struct { - value interface{} // string, int64, uint64, float64, bool, time.Time, [] of any of this list - comment *string - position Position + value interface{} // string, int64, uint64, float64, bool, time.Time, [] of any of this list + comment *string + commented bool + position Position } // Tree is the result of the parsing of a TOML file. type Tree struct { - values map[string]interface{} // string -> *tomlValue, *Tree, []*Tree - comment *string - position Position + values map[string]interface{} // string -> *tomlValue, *Tree, []*Tree + comment *string + commented bool + position Position } func newTree() *Tree { @@ -179,14 +181,14 @@ func (t *Tree) GetDefault(key string, def interface{}) interface{} { // Set an element in the tree. // Key is a dot-separated path (e.g. a.b.c). // Creates all necessary intermediate trees, if needed. -func (t *Tree) Set(key string, comment *string, value interface{}) { - t.SetPath(strings.Split(key, "."), comment, value) +func (t *Tree) Set(key string, comment *string, commented bool, value interface{}) { + t.SetPath(strings.Split(key, "."), comment, commented, value) } // SetPath sets an element in the tree. // Keys is an array of path elements (e.g. {"a","b","c"}). // Creates all necessary intermediate trees, if needed. -func (t *Tree) SetPath(keys []string, comment *string, value interface{}) { +func (t *Tree) SetPath(keys []string, comment *string, commented bool, value interface{}) { subtree := t for _, intermediateKey := range keys[:len(keys)-1] { nextTree, exists := subtree.values[intermediateKey] @@ -213,13 +215,15 @@ func (t *Tree) SetPath(keys []string, comment *string, value interface{}) { case *Tree: toInsert = value subtree.comment = comment + subtree.commented = commented case []*Tree: toInsert = value subtree.comment = comment + subtree.commented = commented case *tomlValue: toInsert = value default: - toInsert = &tomlValue{value: value, comment: comment} + toInsert = &tomlValue{value: value, comment: comment, commented: commented} } subtree.values[keys[len(keys)-1]] = toInsert diff --git a/tomltree_write.go b/tomltree_write.go index 233b1c55..85568907 100644 --- a/tomltree_write.go +++ b/tomltree_write.go @@ -119,14 +119,19 @@ func (t *Tree) writeTo(w io.Writer, indent, keyspace string, bytesCount int64) ( } if v.comment != nil { - writtenBytesCountComment, errc := writeStrings(w, "\n", indent, "# ", *v.comment, "\n") + comment := strings.Replace(*v.comment, "\n", "\n"+indent, -1) + writtenBytesCountComment, errc := writeStrings(w, "\n", indent, "# ", comment, "\n") bytesCount += int64(writtenBytesCountComment) if errc != nil { return bytesCount, errc } } - writtenBytesCount, err := writeStrings(w, indent, k, " = ", repr, "\n") + var commented string + if v.commented { + commented = "# " + } + writtenBytesCount, err := writeStrings(w, indent, commented, k, " = ", repr, "\n") bytesCount += int64(writtenBytesCount) if err != nil { return bytesCount, err @@ -140,17 +145,24 @@ func (t *Tree) writeTo(w io.Writer, indent, keyspace string, bytesCount int64) ( if keyspace != "" { combinedKey = keyspace + "." + combinedKey } + var commented string + if t.commented { + commented = "# " + } + + if t.comment != nil { + comment := strings.Replace(*t.comment, "\n", "\n"+indent, -1) + writtenBytesCountComment, errc := writeStrings(w, "\n", indent, "# ", comment) + bytesCount += int64(writtenBytesCountComment) + if errc != nil { + return bytesCount, errc + } + } + switch node := v.(type) { // node has to be of those two types given how keys are sorted above case *Tree: - if t.comment != nil { - writtenBytesCountComment, errc := writeStrings(w, "\n", indent, "# ", combinedKey, " ", *t.comment) - bytesCount += int64(writtenBytesCountComment) - if errc != nil { - return bytesCount, errc - } - } - writtenBytesCount, err := writeStrings(w, "\n", indent, "[", combinedKey, "]\n") + writtenBytesCount, err := writeStrings(w, "\n", indent, commented, "[", combinedKey, "]\n") bytesCount += int64(writtenBytesCount) if err != nil { return bytesCount, err @@ -161,14 +173,7 @@ func (t *Tree) writeTo(w io.Writer, indent, keyspace string, bytesCount int64) ( } case []*Tree: for _, subTree := range node { - if t.comment != nil { - writtenBytesCountComment, errc := writeStrings(w, "\n", indent, "# ", *t.comment) - bytesCount += int64(writtenBytesCountComment) - if errc != nil { - return bytesCount, errc - } - } - writtenBytesCount, err := writeStrings(w, "\n", indent, "[[", combinedKey, "]]\n") + writtenBytesCount, err := writeStrings(w, "\n", indent, commented, "[[", combinedKey, "]]\n") bytesCount += int64(writtenBytesCount) if err != nil { return bytesCount, err