From adc88893038e5f2ae2bb535e965cf21a61c3dda4 Mon Sep 17 00:00:00 2001 From: 1eedaegon Date: Tue, 2 Apr 2024 11:33:10 +0900 Subject: [PATCH] feat: Add Json marshalling, unmarshalling --- hashset.go | 34 ++++++++++++++++++++++++++++++++++ hashset_test.go | 17 ++++++++++++++++- 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/hashset.go b/hashset.go index 2891bb6..df9bb63 100644 --- a/hashset.go +++ b/hashset.go @@ -1,6 +1,8 @@ package hashset import ( + "encoding/json" + "fmt" "reflect" "sync" ) @@ -147,6 +149,38 @@ func (s *Set) ToSlice() []interface{} { return uniTypeSlice } +func (s *Set) MarshalJSON() ([]byte, error) { + stringMap := make(map[string]bool) + // s.mu.RLock() + for k, v := range s.hash { + if reflect.TypeOf(k).Kind() == reflect.Func { + fmt.Printf("[WARN] Skipped function pointer value in set: %v (hashset - MarshalJSON)", k) + continue + } + key := fmt.Sprintf("%v", k) + stringMap[key] = v + } + // s.mu.RUnlock() + jsonByte, err := json.Marshal(stringMap) + if err != nil { + return nil, err + } + return jsonByte, nil +} +func (s *Set) UnmarshalJSON(data []byte) error { + stringMap := make(map[string]bool) + // Here, it is guaranteed that Unmarshal will be appropriate. + s.mu.Lock() + if err := json.Unmarshal(data, &stringMap); err != nil { + return err + } + s.mu.Unlock() + for k := range stringMap { + s.Add(k) + } + return nil +} + // MakeComparable returns pointer(address) not comparable types: slice, map, function func MakeComparable(element interface{}) interface{} { /* diff --git a/hashset_test.go b/hashset_test.go index a872c2d..24d3f7f 100644 --- a/hashset_test.go +++ b/hashset_test.go @@ -1,6 +1,7 @@ package hashset import ( + "encoding/json" "reflect" "strconv" "sync" @@ -127,8 +128,10 @@ func TestConvertToSet(t *testing.T) { func TestConvertToSlice(t *testing.T) { caseSlice := []int{1, 2, 3} caseSliceTwo := []string{"1", "a", "b"} + // Splitting slices s := New(caseSlice, caseSliceTwo) require.Equal(t, 6, s.Len()) + // Converting set to slice arr := s.ToSlice() require.Equal(t, 6, len(arr)) require.True(t, reflect.ValueOf(arr).Kind() == reflect.Slice) @@ -136,7 +139,6 @@ func TestConvertToSlice(t *testing.T) { require.Contains(t, arr, 3) require.Contains(t, arr, "1") require.Contains(t, arr, "b") - } func TestUnion(t *testing.T) { @@ -170,6 +172,19 @@ func TestIntersection(t *testing.T) { func TestDifference(t *testing.T) {} func TestFunctionElement(t *testing.T) {} func TestStructElement(t *testing.T) {} + +func TestMarshalJSON(t *testing.T) { + + caseSlice := []int{1, 2, 3} + caseSliceTwo := []string{"1", "a", "b"} + // Splitting slices + s := New(caseSlice, caseSliceTwo) + ms, err := json.Marshal(s) + require.NoError(t, err) + s2 := New() + err = json.Unmarshal(ms, s2) + require.NoError(t, err) +} func TestConcurrentAddElement10Goroutine100000Loop(t *testing.T) { var wg sync.WaitGroup s := New()