Skip to content

Commit

Permalink
Convert to package, JSON null funcs to generics (#16973)
Browse files Browse the repository at this point in the history
NullValue now takes a generic type parameter instead of an interface arg
to determine the type of null sentinel value to create.
IsNullValue infers its generic parameter to determine the type of null
sentinel value to look for.
At present, there is no way to express a 'nillable' generic type
constraint so the funcs simply take a [T any] which should be fine as
they typically take/return pointer-to-types.
The 'to' package has been reduced to two funcs.
  • Loading branch information
jhendrixMSFT authored Feb 4, 2022
1 parent e0eaa3f commit 072a1b3
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 303 deletions.
22 changes: 10 additions & 12 deletions sdk/azcore/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@ type TokenCredential = shared.TokenCredential
// holds sentinel values used to send nulls
var nullables map[reflect.Type]interface{} = map[reflect.Type]interface{}{}

func typeOfT[T any]() reflect.Type {
// you can't, at present, obtain the type of
// a type parameter, so this is the trick
return reflect.TypeOf((*T)(nil)).Elem()
}

// NullValue is used to send an explicit 'null' within a request.
// This is typically used in JSON-MERGE-PATCH operations to delete a value.
func NullValue(v interface{}) interface{} {
t := reflect.TypeOf(v)
if k := t.Kind(); k != reflect.Ptr && k != reflect.Slice && k != reflect.Map {
// t is not of pointer type, make it be of pointer type
t = reflect.PtrTo(t)
}
func NullValue[T any]() T {
t := typeOfT[T]()
v, found := nullables[t]
if !found {
var o reflect.Value
Expand All @@ -48,18 +50,14 @@ func NullValue(v interface{}) interface{} {
nullables[t] = v
}
// return the sentinel object
return v
return v.(T)
}

// IsNullValue returns true if the field contains a null sentinel value.
// This is used by custom marshallers to properly encode a null value.
func IsNullValue(v interface{}) bool {
func IsNullValue[T any](v T) bool {
// see if our map has a sentinel object for this *T
t := reflect.TypeOf(v)
if k := t.Kind(); k != reflect.Ptr && k != reflect.Slice && k != reflect.Map {
// v isn't a pointer type so it can never be a null
return false
}
if o, found := nullables[t]; found {
o1 := reflect.ValueOf(o)
v1 := reflect.ValueOf(v)
Expand Down
42 changes: 10 additions & 32 deletions sdk/azcore/core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,11 @@ import (
)

func TestNullValue(t *testing.T) {
v := NullValue("")
if _, ok := v.(*string); !ok {
t.Fatalf("unexpected type %T", v)
}
vv := NullValue((*string)(nil))
if _, ok := vv.(*string); !ok {
t.Fatalf("unexpected type %T", vv)
}
v := NullValue[*string]()
vv := NullValue[*string]()
if v != vv {
t.Fatal("null values should match for the same types")
}
i := NullValue(1)
if _, ok := i.(*int); !ok {
t.Fatalf("unexpected type %T", v)
}
if v == i {
t.Fatal("null values for string and int should not match")
}
}

func TestIsNullValue(t *testing.T) {
Expand All @@ -44,7 +31,7 @@ func TestIsNullValue(t *testing.T) {
if IsNullValue(i) {
t.Fatal("i isn't a null value")
}
i = NullValue(0).(*int)
i = NullValue[*int]()
if !IsNullValue(i) {
t.Fatal("expected null value for i")
}
Expand All @@ -56,21 +43,12 @@ func TestIsNullValue(t *testing.T) {
}

func TestNullValueMapSlice(t *testing.T) {
v := NullValue([]string{})
if _, ok := v.([]string); !ok {
t.Fatalf("unexpected type %T", v)
}
vv := NullValue(([]string)(nil))
if _, ok := vv.([]string); !ok {
t.Fatalf("unexpected type %T", vv)
}
v := NullValue[[]string]()
vv := NullValue[[]string]()
if reflect.TypeOf(v) != reflect.TypeOf(vv) {
t.Fatal("null values should match for the same types")
}
m := NullValue(map[string]int{})
if _, ok := m.(map[string]int); !ok {
t.Fatalf("unexpected type %T", m)
}
m := NullValue[map[string]int]()
if reflect.TypeOf(v) == reflect.TypeOf(m) {
t.Fatal("null values for string and int should not match")
}
Expand All @@ -83,11 +61,11 @@ func TestIsNullValueMapSlice(t *testing.T) {
if IsNullValue(map[int]string{}) {
t.Fatal("map literal can't be a null value")
}
s := NullValue([]int{}).([]int)
s := NullValue[[]int]()
if !IsNullValue(s) {
t.Fatal("expected null value for s")
}
m := NullValue(map[string]interface{}{}).(map[string]interface{})
m := NullValue[map[string]interface{}]()
if !IsNullValue(m) {
t.Fatal("expected null value for s")
}
Expand All @@ -114,8 +92,8 @@ func TestIsNullValueMapSlice(t *testing.T) {
t.Fatal("unexpected null slice")
}

nf.Map = NullValue(map[string]int{}).(map[string]int)
nf.Slice = NullValue([]string{}).([]string)
nf.Map = NullValue[map[string]int]()
nf.Slice = NullValue[[]string]()
if !IsNullValue(nf.Map) {
t.Fatal("expected null map")
}
Expand Down
2 changes: 1 addition & 1 deletion sdk/azcore/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (w Widget) MarshalJSON() ([]byte, error) {

func ExampleNullValue() {
w := Widget{
Count: azcore.NullValue(0).(*int),
Count: azcore.NullValue[*int](),
}
b, _ := json.Marshal(w)
fmt.Println(string(b))
Expand Down
104 changes: 9 additions & 95 deletions sdk/azcore/to/to.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,102 +6,16 @@

package to

import "time"

// BoolPtr returns a pointer to the provided bool.
func BoolPtr(b bool) *bool {
return &b
}

// Float32Ptr returns a pointer to the provided float32.
func Float32Ptr(i float32) *float32 {
return &i
}

// Float64Ptr returns a pointer to the provided float64.
func Float64Ptr(i float64) *float64 {
return &i
}

// Int32Ptr returns a pointer to the provided int32.
func Int32Ptr(i int32) *int32 {
return &i
}

// Int64Ptr returns a pointer to the provided int64.
func Int64Ptr(i int64) *int64 {
return &i
}

// StringPtr returns a pointer to the provided string.
func StringPtr(s string) *string {
return &s
}

// TimePtr returns a pointer to the provided time.Time.
func TimePtr(t time.Time) *time.Time {
return &t
}

// Int32PtrArray returns an array of *int32 from the specified values.
func Int32PtrArray(vals ...int32) []*int32 {
arr := make([]*int32, len(vals))
for i := range vals {
arr[i] = Int32Ptr(vals[i])
}
return arr
}

// Int64PtrArray returns an array of *int64 from the specified values.
func Int64PtrArray(vals ...int64) []*int64 {
arr := make([]*int64, len(vals))
for i := range vals {
arr[i] = Int64Ptr(vals[i])
}
return arr
}

// Float32PtrArray returns an array of *float32 from the specified values.
func Float32PtrArray(vals ...float32) []*float32 {
arr := make([]*float32, len(vals))
for i := range vals {
arr[i] = Float32Ptr(vals[i])
}
return arr
}

// Float64PtrArray returns an array of *float64 from the specified values.
func Float64PtrArray(vals ...float64) []*float64 {
arr := make([]*float64, len(vals))
for i := range vals {
arr[i] = Float64Ptr(vals[i])
}
return arr
}

// BoolPtrArray returns an array of *bool from the specified values.
func BoolPtrArray(vals ...bool) []*bool {
arr := make([]*bool, len(vals))
for i := range vals {
arr[i] = BoolPtr(vals[i])
}
return arr
}

// StringPtrArray returns an array of *string from the specified values.
func StringPtrArray(vals ...string) []*string {
arr := make([]*string, len(vals))
for i := range vals {
arr[i] = StringPtr(vals[i])
}
return arr
// Ptr returns a pointer to the provided value.
func Ptr[T any](v T) *T {
return &v
}

// TimePtrArray returns an array of *time.Time from the specified values.
func TimePtrArray(vals ...time.Time) []*time.Time {
arr := make([]*time.Time, len(vals))
for i := range vals {
arr[i] = TimePtr(vals[i])
// SliceOfPtrs returns a slice of *T from the specified values.
func SliceOfPtrs[T any](vv ...T) []*T {
slc := make([]*T, len(vv))
for i := range vv {
slc[i] = Ptr(vv[i])
}
return arr
return slc
}
Loading

0 comments on commit 072a1b3

Please sign in to comment.