diff --git a/sdk/azcore/core.go b/sdk/azcore/core.go index b8bf33df5b33..ec313526c0ba 100644 --- a/sdk/azcore/core.go +++ b/sdk/azcore/core.go @@ -23,14 +23,16 @@ type TokenCredential = shared.TokenCredential // holds sentinel values used to send nulls var nullables map[reflect.Type]interface{} = map[reflect.Type]interface{}{} +func typeOfT[T any]() reflect.Type { + // you can't, at present, obtain the type of + // a type parameter, so this is the trick + return reflect.TypeOf((*T)(nil)).Elem() +} + // NullValue is used to send an explicit 'null' within a request. // This is typically used in JSON-MERGE-PATCH operations to delete a value. -func NullValue(v interface{}) interface{} { - t := reflect.TypeOf(v) - if k := t.Kind(); k != reflect.Ptr && k != reflect.Slice && k != reflect.Map { - // t is not of pointer type, make it be of pointer type - t = reflect.PtrTo(t) - } +func NullValue[T any]() T { + t := typeOfT[T]() v, found := nullables[t] if !found { var o reflect.Value @@ -48,18 +50,14 @@ func NullValue(v interface{}) interface{} { nullables[t] = v } // return the sentinel object - return v + return v.(T) } // IsNullValue returns true if the field contains a null sentinel value. // This is used by custom marshallers to properly encode a null value. -func IsNullValue(v interface{}) bool { +func IsNullValue[T any](v T) bool { // see if our map has a sentinel object for this *T t := reflect.TypeOf(v) - if k := t.Kind(); k != reflect.Ptr && k != reflect.Slice && k != reflect.Map { - // v isn't a pointer type so it can never be a null - return false - } if o, found := nullables[t]; found { o1 := reflect.ValueOf(o) v1 := reflect.ValueOf(v) diff --git a/sdk/azcore/core_test.go b/sdk/azcore/core_test.go index 2cbd1d893de2..56b8b5504542 100644 --- a/sdk/azcore/core_test.go +++ b/sdk/azcore/core_test.go @@ -12,24 +12,11 @@ import ( ) func TestNullValue(t *testing.T) { - v := NullValue("") - if _, ok := v.(*string); !ok { - t.Fatalf("unexpected type %T", v) - } - vv := NullValue((*string)(nil)) - if _, ok := vv.(*string); !ok { - t.Fatalf("unexpected type %T", vv) - } + v := NullValue[*string]() + vv := NullValue[*string]() if v != vv { t.Fatal("null values should match for the same types") } - i := NullValue(1) - if _, ok := i.(*int); !ok { - t.Fatalf("unexpected type %T", v) - } - if v == i { - t.Fatal("null values for string and int should not match") - } } func TestIsNullValue(t *testing.T) { @@ -44,7 +31,7 @@ func TestIsNullValue(t *testing.T) { if IsNullValue(i) { t.Fatal("i isn't a null value") } - i = NullValue(0).(*int) + i = NullValue[*int]() if !IsNullValue(i) { t.Fatal("expected null value for i") } @@ -56,21 +43,12 @@ func TestIsNullValue(t *testing.T) { } func TestNullValueMapSlice(t *testing.T) { - v := NullValue([]string{}) - if _, ok := v.([]string); !ok { - t.Fatalf("unexpected type %T", v) - } - vv := NullValue(([]string)(nil)) - if _, ok := vv.([]string); !ok { - t.Fatalf("unexpected type %T", vv) - } + v := NullValue[[]string]() + vv := NullValue[[]string]() if reflect.TypeOf(v) != reflect.TypeOf(vv) { t.Fatal("null values should match for the same types") } - m := NullValue(map[string]int{}) - if _, ok := m.(map[string]int); !ok { - t.Fatalf("unexpected type %T", m) - } + m := NullValue[map[string]int]() if reflect.TypeOf(v) == reflect.TypeOf(m) { t.Fatal("null values for string and int should not match") } @@ -83,11 +61,11 @@ func TestIsNullValueMapSlice(t *testing.T) { if IsNullValue(map[int]string{}) { t.Fatal("map literal can't be a null value") } - s := NullValue([]int{}).([]int) + s := NullValue[[]int]() if !IsNullValue(s) { t.Fatal("expected null value for s") } - m := NullValue(map[string]interface{}{}).(map[string]interface{}) + m := NullValue[map[string]interface{}]() if !IsNullValue(m) { t.Fatal("expected null value for s") } @@ -114,8 +92,8 @@ func TestIsNullValueMapSlice(t *testing.T) { t.Fatal("unexpected null slice") } - nf.Map = NullValue(map[string]int{}).(map[string]int) - nf.Slice = NullValue([]string{}).([]string) + nf.Map = NullValue[map[string]int]() + nf.Slice = NullValue[[]string]() if !IsNullValue(nf.Map) { t.Fatal("expected null map") } diff --git a/sdk/azcore/example_test.go b/sdk/azcore/example_test.go index 434051bcefe9..955f7777fd40 100644 --- a/sdk/azcore/example_test.go +++ b/sdk/azcore/example_test.go @@ -56,7 +56,7 @@ func (w Widget) MarshalJSON() ([]byte, error) { func ExampleNullValue() { w := Widget{ - Count: azcore.NullValue(0).(*int), + Count: azcore.NullValue[*int](), } b, _ := json.Marshal(w) fmt.Println(string(b)) diff --git a/sdk/azcore/to/to.go b/sdk/azcore/to/to.go index 57a8d10ecc3f..e0e4817b90d1 100644 --- a/sdk/azcore/to/to.go +++ b/sdk/azcore/to/to.go @@ -6,102 +6,16 @@ package to -import "time" - -// BoolPtr returns a pointer to the provided bool. -func BoolPtr(b bool) *bool { - return &b -} - -// Float32Ptr returns a pointer to the provided float32. -func Float32Ptr(i float32) *float32 { - return &i -} - -// Float64Ptr returns a pointer to the provided float64. -func Float64Ptr(i float64) *float64 { - return &i -} - -// Int32Ptr returns a pointer to the provided int32. -func Int32Ptr(i int32) *int32 { - return &i -} - -// Int64Ptr returns a pointer to the provided int64. -func Int64Ptr(i int64) *int64 { - return &i -} - -// StringPtr returns a pointer to the provided string. -func StringPtr(s string) *string { - return &s -} - -// TimePtr returns a pointer to the provided time.Time. -func TimePtr(t time.Time) *time.Time { - return &t -} - -// Int32PtrArray returns an array of *int32 from the specified values. -func Int32PtrArray(vals ...int32) []*int32 { - arr := make([]*int32, len(vals)) - for i := range vals { - arr[i] = Int32Ptr(vals[i]) - } - return arr -} - -// Int64PtrArray returns an array of *int64 from the specified values. -func Int64PtrArray(vals ...int64) []*int64 { - arr := make([]*int64, len(vals)) - for i := range vals { - arr[i] = Int64Ptr(vals[i]) - } - return arr -} - -// Float32PtrArray returns an array of *float32 from the specified values. -func Float32PtrArray(vals ...float32) []*float32 { - arr := make([]*float32, len(vals)) - for i := range vals { - arr[i] = Float32Ptr(vals[i]) - } - return arr -} - -// Float64PtrArray returns an array of *float64 from the specified values. -func Float64PtrArray(vals ...float64) []*float64 { - arr := make([]*float64, len(vals)) - for i := range vals { - arr[i] = Float64Ptr(vals[i]) - } - return arr -} - -// BoolPtrArray returns an array of *bool from the specified values. -func BoolPtrArray(vals ...bool) []*bool { - arr := make([]*bool, len(vals)) - for i := range vals { - arr[i] = BoolPtr(vals[i]) - } - return arr -} - -// StringPtrArray returns an array of *string from the specified values. -func StringPtrArray(vals ...string) []*string { - arr := make([]*string, len(vals)) - for i := range vals { - arr[i] = StringPtr(vals[i]) - } - return arr +// Ptr returns a pointer to the provided value. +func Ptr[T any](v T) *T { + return &v } -// TimePtrArray returns an array of *time.Time from the specified values. -func TimePtrArray(vals ...time.Time) []*time.Time { - arr := make([]*time.Time, len(vals)) - for i := range vals { - arr[i] = TimePtr(vals[i]) +// SliceOfPtrs returns a slice of *T from the specified values. +func SliceOfPtrs[T any](vv ...T) []*T { + slc := make([]*T, len(vv)) + for i := range vv { + slc[i] = Ptr(vv[i]) } - return arr + return slc } diff --git a/sdk/azcore/to/to_test.go b/sdk/azcore/to/to_test.go index 177e9a48a3e4..175f52a31aea 100644 --- a/sdk/azcore/to/to_test.go +++ b/sdk/azcore/to/to_test.go @@ -7,16 +7,12 @@ package to import ( - "fmt" - "reflect" - "strconv" "testing" - "time" ) -func TestBoolPtr(t *testing.T) { +func TestPtr(t *testing.T) { b := true - pb := BoolPtr(b) + pb := Ptr(b) if pb == nil { t.Fatal("unexpected nil conversion") } @@ -25,168 +21,15 @@ func TestBoolPtr(t *testing.T) { } } -func TestFloat32Ptr(t *testing.T) { - f32 := float32(3.1415926) - pf32 := Float32Ptr(f32) - if pf32 == nil { - t.Fatal("unexpected nil conversion") - } - if *pf32 != f32 { - t.Fatalf("got %v, want %v", *pf32, f32) - } -} - -func TestFloat64Ptr(t *testing.T) { - f64 := float64(2.71828182845904) - pf64 := Float64Ptr(f64) - if pf64 == nil { - t.Fatal("unexpected nil conversion") - } - if *pf64 != f64 { - t.Fatalf("got %v, want %v", *pf64, f64) - } -} - -func TestInt32Ptr(t *testing.T) { - i32 := int32(123456789) - pi32 := Int32Ptr(i32) - if pi32 == nil { - t.Fatal("unexpected nil conversion") - } - if *pi32 != i32 { - t.Fatalf("got %v, want %v", *pi32, i32) - } -} - -func TestInt64Ptr(t *testing.T) { - i64 := int64(9876543210) - pi64 := Int64Ptr(i64) - if pi64 == nil { - t.Fatal("unexpected nil conversion") - } - if *pi64 != i64 { - t.Fatalf("got %v, want %v", *pi64, i64) - } -} - -func TestStringPtr(t *testing.T) { - s := "the string" - ps := StringPtr(s) - if ps == nil { - t.Fatal("unexpected nil conversion") - } - if *ps != s { - t.Fatalf("got %v, want %v", *ps, s) - } -} - -func TestTimePtr(t *testing.T) { - tt := time.Now() - pt := TimePtr(tt) - if pt == nil { - t.Fatal("unexpected nil conversion") - } - if *pt != tt { - t.Fatalf("got %v, want %v", *pt, tt) - } -} - -func TestInt32PtrArray(t *testing.T) { - arr := Int32PtrArray() - if len(arr) != 0 { - t.Fatal("expected zero length") - } - arr = Int32PtrArray(1, 2, 3, 4, 5) - for i, v := range arr { - if *v != int32(i+1) { - t.Fatal("values don't match") - } - } -} - -func TestInt64PtrArray(t *testing.T) { - arr := Int64PtrArray() +func TestSliceOfPtrs(t *testing.T) { + arr := SliceOfPtrs[int]() if len(arr) != 0 { t.Fatal("expected zero length") } - arr = Int64PtrArray(1, 2, 3, 4, 5) + arr = SliceOfPtrs(1, 2, 3, 4, 5) for i, v := range arr { - if *v != int64(i+1) { + if *v != i+1 { t.Fatal("values don't match") } } } - -func TestFloat32PtrArray(t *testing.T) { - arr := Float32PtrArray() - if len(arr) != 0 { - t.Fatal("expected zero length") - } - arr = Float32PtrArray(1.1, 2.2, 3.3, 4.4, 5.5) - for i, v := range arr { - f, err := strconv.ParseFloat(fmt.Sprintf("%d.%d", i+1, i+1), 32) - if err != nil { - t.Fatal(err) - } - if *v != float32(f) { - t.Fatal("values don't match") - } - } -} - -func TestFloat64PtrArray(t *testing.T) { - arr := Float64PtrArray() - if len(arr) != 0 { - t.Fatal("expected zero length") - } - arr = Float64PtrArray(1.1, 2.2, 3.3, 4.4, 5.5) - for i, v := range arr { - f, err := strconv.ParseFloat(fmt.Sprintf("%d.%d", i+1, i+1), 64) - if err != nil { - t.Fatal(err) - } - if *v != f { - t.Fatal("values don't match") - } - } -} - -func TestBoolPtrArray(t *testing.T) { - arr := BoolPtrArray() - if len(arr) != 0 { - t.Fatal("expected zero length") - } - arr = BoolPtrArray(true, false, true) - curr := true - for _, v := range arr { - if *v != curr { - t.Fatal("values don'p match") - } - curr = !curr - } -} - -func TestStringPtrArray(t *testing.T) { - arr := StringPtrArray() - if len(arr) != 0 { - t.Fatal("expected zero length") - } - arr = StringPtrArray("one", "", "three") - if !reflect.DeepEqual(arr, []*string{StringPtr("one"), StringPtr(""), StringPtr("three")}) { - t.Fatal("values don't match") - } -} - -func TestTimePtrArray(t *testing.T) { - arr := TimePtrArray() - if len(arr) != 0 { - t.Fatal("expected zero length") - } - t1 := time.Now() - t2 := time.Time{} - t3 := t1.Add(24 * time.Hour) - arr = TimePtrArray(t1, t2, t3) - if !reflect.DeepEqual(arr, []*time.Time{&t1, &t2, &t3}) { - t.Fatal("values don't match") - } -}