Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert to package, JSON null funcs to generics #16973

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions sdk/azcore/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@ type TokenCredential = shared.TokenCredential
// holds sentinel values used to send nulls
var nullables map[reflect.Type]interface{} = map[reflect.Type]interface{}{}

func typeOfT[T any]() reflect.Type {
// you can't, at present, obtain the type of
// a type parameter, so this is the trick
return reflect.TypeOf((*T)(nil)).Elem()
}

// NullValue is used to send an explicit 'null' within a request.
// This is typically used in JSON-MERGE-PATCH operations to delete a value.
func NullValue(v interface{}) interface{} {
t := reflect.TypeOf(v)
if k := t.Kind(); k != reflect.Ptr && k != reflect.Slice && k != reflect.Map {
// t is not of pointer type, make it be of pointer type
t = reflect.PtrTo(t)
}
func NullValue[T any]() T {
t := typeOfT[T]()
v, found := nullables[t]
if !found {
var o reflect.Value
Expand All @@ -48,18 +50,14 @@ func NullValue(v interface{}) interface{} {
nullables[t] = v
}
// return the sentinel object
return v
return v.(T)
}

// IsNullValue returns true if the field contains a null sentinel value.
// This is used by custom marshallers to properly encode a null value.
func IsNullValue(v interface{}) bool {
func IsNullValue[T any](v T) bool {
// see if our map has a sentinel object for this *T
t := reflect.TypeOf(v)
if k := t.Kind(); k != reflect.Ptr && k != reflect.Slice && k != reflect.Map {
// v isn't a pointer type so it can never be a null
return false
}
if o, found := nullables[t]; found {
o1 := reflect.ValueOf(o)
v1 := reflect.ValueOf(v)
Expand Down
42 changes: 10 additions & 32 deletions sdk/azcore/core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,11 @@ import (
)

func TestNullValue(t *testing.T) {
v := NullValue("")
if _, ok := v.(*string); !ok {
t.Fatalf("unexpected type %T", v)
}
vv := NullValue((*string)(nil))
if _, ok := vv.(*string); !ok {
t.Fatalf("unexpected type %T", vv)
}
v := NullValue[*string]()
vv := NullValue[*string]()
if v != vv {
t.Fatal("null values should match for the same types")
}
i := NullValue(1)
if _, ok := i.(*int); !ok {
t.Fatalf("unexpected type %T", v)
}
if v == i {
t.Fatal("null values for string and int should not match")
}
}

func TestIsNullValue(t *testing.T) {
Expand All @@ -44,7 +31,7 @@ func TestIsNullValue(t *testing.T) {
if IsNullValue(i) {
t.Fatal("i isn't a null value")
}
i = NullValue(0).(*int)
i = NullValue[*int]()
if !IsNullValue(i) {
t.Fatal("expected null value for i")
}
Expand All @@ -56,21 +43,12 @@ func TestIsNullValue(t *testing.T) {
}

func TestNullValueMapSlice(t *testing.T) {
v := NullValue([]string{})
if _, ok := v.([]string); !ok {
t.Fatalf("unexpected type %T", v)
}
vv := NullValue(([]string)(nil))
if _, ok := vv.([]string); !ok {
t.Fatalf("unexpected type %T", vv)
}
v := NullValue[[]string]()
vv := NullValue[[]string]()
if reflect.TypeOf(v) != reflect.TypeOf(vv) {
t.Fatal("null values should match for the same types")
}
m := NullValue(map[string]int{})
if _, ok := m.(map[string]int); !ok {
t.Fatalf("unexpected type %T", m)
}
m := NullValue[map[string]int]()
if reflect.TypeOf(v) == reflect.TypeOf(m) {
t.Fatal("null values for string and int should not match")
}
Expand All @@ -83,11 +61,11 @@ func TestIsNullValueMapSlice(t *testing.T) {
if IsNullValue(map[int]string{}) {
t.Fatal("map literal can't be a null value")
}
s := NullValue([]int{}).([]int)
s := NullValue[[]int]()
if !IsNullValue(s) {
t.Fatal("expected null value for s")
}
m := NullValue(map[string]interface{}{}).(map[string]interface{})
m := NullValue[map[string]interface{}]()
if !IsNullValue(m) {
t.Fatal("expected null value for s")
}
Expand All @@ -114,8 +92,8 @@ func TestIsNullValueMapSlice(t *testing.T) {
t.Fatal("unexpected null slice")
}

nf.Map = NullValue(map[string]int{}).(map[string]int)
nf.Slice = NullValue([]string{}).([]string)
nf.Map = NullValue[map[string]int]()
nf.Slice = NullValue[[]string]()
if !IsNullValue(nf.Map) {
t.Fatal("expected null map")
}
Expand Down
2 changes: 1 addition & 1 deletion sdk/azcore/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (w Widget) MarshalJSON() ([]byte, error) {

func ExampleNullValue() {
w := Widget{
Count: azcore.NullValue(0).(*int),
Count: azcore.NullValue[*int](),
}
b, _ := json.Marshal(w)
fmt.Println(string(b))
Expand Down
104 changes: 9 additions & 95 deletions sdk/azcore/to/to.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,102 +6,16 @@

package to

import "time"

// BoolPtr returns a pointer to the provided bool.
func BoolPtr(b bool) *bool {
return &b
}

// Float32Ptr returns a pointer to the provided float32.
func Float32Ptr(i float32) *float32 {
return &i
}

// Float64Ptr returns a pointer to the provided float64.
func Float64Ptr(i float64) *float64 {
return &i
}

// Int32Ptr returns a pointer to the provided int32.
func Int32Ptr(i int32) *int32 {
return &i
}

// Int64Ptr returns a pointer to the provided int64.
func Int64Ptr(i int64) *int64 {
return &i
}

// StringPtr returns a pointer to the provided string.
func StringPtr(s string) *string {
return &s
}

// TimePtr returns a pointer to the provided time.Time.
func TimePtr(t time.Time) *time.Time {
return &t
}

// Int32PtrArray returns an array of *int32 from the specified values.
func Int32PtrArray(vals ...int32) []*int32 {
arr := make([]*int32, len(vals))
for i := range vals {
arr[i] = Int32Ptr(vals[i])
}
return arr
}

// Int64PtrArray returns an array of *int64 from the specified values.
func Int64PtrArray(vals ...int64) []*int64 {
arr := make([]*int64, len(vals))
for i := range vals {
arr[i] = Int64Ptr(vals[i])
}
return arr
}

// Float32PtrArray returns an array of *float32 from the specified values.
func Float32PtrArray(vals ...float32) []*float32 {
arr := make([]*float32, len(vals))
for i := range vals {
arr[i] = Float32Ptr(vals[i])
}
return arr
}

// Float64PtrArray returns an array of *float64 from the specified values.
func Float64PtrArray(vals ...float64) []*float64 {
arr := make([]*float64, len(vals))
for i := range vals {
arr[i] = Float64Ptr(vals[i])
}
return arr
}

// BoolPtrArray returns an array of *bool from the specified values.
func BoolPtrArray(vals ...bool) []*bool {
arr := make([]*bool, len(vals))
for i := range vals {
arr[i] = BoolPtr(vals[i])
}
return arr
}

// StringPtrArray returns an array of *string from the specified values.
func StringPtrArray(vals ...string) []*string {
arr := make([]*string, len(vals))
for i := range vals {
arr[i] = StringPtr(vals[i])
}
return arr
// Ptr returns a pointer to the provided value.
func Ptr[T any](v T) *T {
return &v
}

// TimePtrArray returns an array of *time.Time from the specified values.
func TimePtrArray(vals ...time.Time) []*time.Time {
arr := make([]*time.Time, len(vals))
for i := range vals {
arr[i] = TimePtr(vals[i])
// SliceOfPtrs returns a slice of *T from the specified values.
func SliceOfPtrs[T any](vv ...T) []*T {
slc := make([]*T, len(vv))
for i := range vv {
slc[i] = Ptr(vv[i])
}
return arr
return slc
}
Loading