From d55304782ef5d990c45b4ec62e0f45ba781ffb9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Wei=C3=9Fe?= <66256922+daniel-weisse@users.noreply.github.com> Date: Sat, 17 Aug 2024 21:44:21 +0200 Subject: [PATCH] Allow int, uint, and floats as map keys (#958) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniel Weiße --- marshaler.go | 12 ++++ marshaler_test.go | 61 +++++++++++++++- unmarshaler.go | 28 ++++++++ unmarshaler_test.go | 172 +++++++++++++++++++++++++++++++++++++++----- 4 files changed, 251 insertions(+), 22 deletions(-) diff --git a/marshaler.go b/marshaler.go index fe67eccd..f9e6d09d 100644 --- a/marshaler.go +++ b/marshaler.go @@ -631,6 +631,18 @@ func (enc *Encoder) keyToString(k reflect.Value) (string, error) { return "", fmt.Errorf("toml: error marshalling key %v from text: %w", k, err) } return string(keyB), nil + + case keyType.Kind() == reflect.Int || keyType.Kind() == reflect.Int8 || keyType.Kind() == reflect.Int16 || keyType.Kind() == reflect.Int32 || keyType.Kind() == reflect.Int64: + return strconv.FormatInt(k.Int(), 10), nil + + case keyType.Kind() == reflect.Uint || keyType.Kind() == reflect.Uint8 || keyType.Kind() == reflect.Uint16 || keyType.Kind() == reflect.Uint32 || keyType.Kind() == reflect.Uint64: + return strconv.FormatUint(k.Uint(), 10), nil + + case keyType.Kind() == reflect.Float32: + return strconv.FormatFloat(k.Float(), 'f', -1, 32), nil + + case keyType.Kind() == reflect.Float64: + return strconv.FormatFloat(k.Float(), 'f', -1, 64), nil } return "", fmt.Errorf("toml: type %s is not supported as a map key", keyType.Kind()) } diff --git a/marshaler_test.go b/marshaler_test.go index b3473e38..e5d1f11c 100644 --- a/marshaler_test.go +++ b/marshaler_test.go @@ -587,13 +587,69 @@ foo = 42 `, }, { - desc: "invalid map key", + desc: "int map key", v: map[int]interface{}{1: "a"}, + expected: `1 = 'a' +`, + }, + { + desc: "int8 map key", + v: map[int8]interface{}{1: "a"}, + expected: `1 = 'a' +`, + }, + { + desc: "int64 map key", + v: map[int64]interface{}{1: "a"}, + expected: `1 = 'a' +`, + }, + { + desc: "uint map key", + v: map[uint]interface{}{1: "a"}, + expected: `1 = 'a' +`, + }, + { + desc: "uint8 map key", + v: map[uint8]interface{}{1: "a"}, + expected: `1 = 'a' +`, + }, + { + desc: "uint64 map key", + v: map[uint64]interface{}{1: "a"}, + expected: `1 = 'a' +`, + }, + { + desc: "float32 map key", + v: map[float32]interface{}{ + 1.1: "a", + 1.0020: "b", + }, + expected: `'1.002' = 'b' +'1.1' = 'a' +`, + }, + { + desc: "float64 map key", + v: map[float64]interface{}{ + 1.1: "a", + 1.0020: "b", + }, + expected: `'1.002' = 'b' +'1.1' = 'a' +`, + }, + { + desc: "invalid map key", + v: map[struct{ int }]interface{}{{1}: "a"}, err: true, }, { desc: "invalid map key but empty", - v: map[int]interface{}{}, + v: map[struct{ int }]interface{}{}, expected: "", }, { @@ -1565,7 +1621,6 @@ func ExampleMarshal() { // configuration file that has commented out sections (example from // go-graphite/graphite-clickhouse). func ExampleMarshal_commented() { - type Common struct { Listen string `toml:"listen" comment:"general listener"` PprofListen string `toml:"pprof-listen" comment:"listener to serve /debug/pprof requests. '-pprof' argument overrides it"` diff --git a/unmarshaler.go b/unmarshaler.go index 94f1113e..c3df8bee 100644 --- a/unmarshaler.go +++ b/unmarshaler.go @@ -7,6 +7,7 @@ import ( "io" "math" "reflect" + "strconv" "strings" "sync/atomic" "time" @@ -1079,6 +1080,33 @@ func (d *decoder) keyFromData(keyType reflect.Type, data []byte) (reflect.Value, return reflect.Value{}, fmt.Errorf("toml: error unmarshalling key type %s from text: %w", stringType, err) } return mk.Elem(), nil + + case keyType.Kind() == reflect.Int || keyType.Kind() == reflect.Int8 || keyType.Kind() == reflect.Int16 || keyType.Kind() == reflect.Int32 || keyType.Kind() == reflect.Int64: + key, err := strconv.ParseInt(string(data), 10, 64) + if err != nil { + return reflect.Value{}, fmt.Errorf("toml: error parsing key of type %s from integer: %w", stringType, err) + } + return reflect.ValueOf(key).Convert(keyType), nil + case keyType.Kind() == reflect.Uint || keyType.Kind() == reflect.Uint8 || keyType.Kind() == reflect.Uint16 || keyType.Kind() == reflect.Uint32 || keyType.Kind() == reflect.Uint64: + key, err := strconv.ParseUint(string(data), 10, 64) + if err != nil { + return reflect.Value{}, fmt.Errorf("toml: error parsing key of type %s from unsigned integer: %w", stringType, err) + } + return reflect.ValueOf(key).Convert(keyType), nil + + case keyType.Kind() == reflect.Float32: + key, err := strconv.ParseFloat(string(data), 32) + if err != nil { + return reflect.Value{}, fmt.Errorf("toml: error parsing key of type %s from float: %w", stringType, err) + } + return reflect.ValueOf(float32(key)), nil + + case keyType.Kind() == reflect.Float64: + key, err := strconv.ParseFloat(string(data), 64) + if err != nil { + return reflect.Value{}, fmt.Errorf("toml: error parsing key of type %s from float: %w", stringType, err) + } + return reflect.ValueOf(float64(key)), nil } return reflect.Value{}, fmt.Errorf("toml: cannot convert map key of type %s to expected type %s", stringType, keyType) } diff --git a/unmarshaler_test.go b/unmarshaler_test.go index c1833fb0..3cbd81d1 100644 --- a/unmarshaler_test.go +++ b/unmarshaler_test.go @@ -205,7 +205,6 @@ func TestUnmarshal_Floats(t *testing.T) { testFn func(t *testing.T, v float64) err bool }{ - { desc: "float pi", input: `3.1415`, @@ -840,8 +839,10 @@ huey = 'dewey' return test{ target: &doc{}, - expected: &doc{A: []interface{}{"0", "1", "2", "3", "4", "5", "6", - "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17"}}, + expected: &doc{A: []interface{}{ + "0", "1", "2", "3", "4", "5", "6", + "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", + }}, } }, }, @@ -1696,16 +1697,6 @@ B = "data"`, } }, }, - { - desc: "empty map into map with invalid key type", - input: ``, - gen: func() test { - return test{ - target: &map[int]string{}, - expected: &map[int]string{}, - } - }, - }, { desc: "into map with convertible key type", input: `A = "hello"`, @@ -1942,6 +1933,150 @@ B = "data"`, } }, }, + { + desc: "into map of int to string", + input: `1 = "a"`, + gen: func() test { + return test{ + target: &map[int]string{}, + expected: &map[int]string{1: "a"}, + assert: func(t *testing.T, test test) { + assert.Equal(t, test.expected, test.target) + }, + } + }, + }, + { + desc: "into map of int8 to string", + input: `1 = "a"`, + gen: func() test { + return test{ + target: &map[int8]string{}, + expected: &map[int8]string{1: "a"}, + assert: func(t *testing.T, test test) { + assert.Equal(t, test.expected, test.target) + }, + } + }, + }, + { + desc: "into map of int64 to string", + input: `1 = "a"`, + gen: func() test { + return test{ + target: &map[int64]string{}, + expected: &map[int64]string{1: "a"}, + assert: func(t *testing.T, test test) { + assert.Equal(t, test.expected, test.target) + }, + } + }, + }, + { + desc: "into map of uint to string", + input: `1 = "a"`, + gen: func() test { + return test{ + target: &map[uint]string{}, + expected: &map[uint]string{1: "a"}, + assert: func(t *testing.T, test test) { + assert.Equal(t, test.expected, test.target) + }, + } + }, + }, + { + desc: "into map of uint8 to string", + input: `1 = "a"`, + gen: func() test { + return test{ + target: &map[uint8]string{}, + expected: &map[uint8]string{1: "a"}, + assert: func(t *testing.T, test test) { + assert.Equal(t, test.expected, test.target) + }, + } + }, + }, + { + desc: "into map of uint64 to string", + input: `1 = "a"`, + gen: func() test { + return test{ + target: &map[uint64]string{}, + expected: &map[uint64]string{1: "a"}, + assert: func(t *testing.T, test test) { + assert.Equal(t, test.expected, test.target) + }, + } + }, + }, + { + desc: "into map of uint with invalid key", + input: `-1 = "a"`, + gen: func() test { + return test{ + target: &map[uint]string{}, + err: true, + } + }, + }, + { + desc: "into map of float64 to string", + input: `'1.01' = "a"`, + gen: func() test { + return test{ + target: &map[float64]string{}, + expected: &map[float64]string{1.01: "a"}, + assert: func(t *testing.T, test test) { + assert.Equal(t, test.expected, test.target) + }, + } + }, + }, + { + desc: "into map of float64 with invalid key", + input: `key = "a"`, + gen: func() test { + return test{ + target: &map[float64]string{}, + err: true, + } + }, + }, + { + desc: "into map of float32 to string", + input: `'1.01' = "a"`, + gen: func() test { + return test{ + target: &map[float32]string{}, + expected: &map[float32]string{1.01: "a"}, + assert: func(t *testing.T, test test) { + assert.Equal(t, test.expected, test.target) + }, + } + }, + }, + { + desc: "into map of float32 with invalid key", + input: `key = "a"`, + gen: func() test { + return test{ + target: &map[float32]string{}, + err: true, + } + }, + }, + { + desc: "invalid map key type", + input: `1 = "a"`, + gen: func() test { + return test{ + target: &map[struct{ int }]string{}, + err: true, + } + }, + }, } for _, e := range examples { @@ -2653,7 +2788,7 @@ func TestIssue772(t *testing.T) { FileHandling `toml:"filehandling"` } - var defaultConfigFile = []byte(` + defaultConfigFile := []byte(` [filehandling] pattern = "reach-masterdev-"`) @@ -2750,7 +2885,7 @@ func TestIssue866(t *testing.T) { PipelineMapping map[string]*Pipeline `toml:"pipelines"` } - var badToml = ` + badToml := ` [pipelines.register] mapping.inst.req = [ ["param1", "value1"], @@ -2768,7 +2903,7 @@ mapping.inst.res = [ t.Fatal("unmarshal failed with mismatch value") } - var goodTooToml = ` + goodTooToml := ` [pipelines.register] mapping.inst.req = [ ["param1", "value1"], @@ -2783,7 +2918,7 @@ mapping.inst.req = [ t.Fatal("unmarshal failed with mismatch value") } - var goodToml = ` + goodToml := ` [pipelines.register.mapping.inst] req = [ ["param1", "value1"], @@ -3362,7 +3497,7 @@ func TestOmitEmpty(t *testing.T) { X []elem `toml:",inline"` } - d := doc{X: []elem{elem{ + d := doc{X: []elem{{ Foo: "test", Inner: inner{ V: "alue", @@ -3785,7 +3920,6 @@ func (k *CustomUnmarshalerKey) UnmarshalTOML(value *unstable.Node) error { } k.A = item return nil - } func TestUnmarshal_CustomUnmarshaler(t *testing.T) {