diff --git a/make.go b/make.go index 81e1de9..4e764f4 100644 --- a/make.go +++ b/make.go @@ -9,12 +9,13 @@ package rapid import ( "fmt" "reflect" + "strings" ) // Make creates a generator of values of type V, using reflection to infer the required structure. func Make[V any]() *Generator[V] { var zero V - gen := newMakeGen(reflect.TypeOf(zero)) + gen := newMakeGen(reflect.TypeOf(zero), nil) return newGenerator[V](&makeGen[V]{ gen: gen, }) @@ -33,8 +34,8 @@ func (g *makeGen[V]) value(t *T) V { return g.gen.value(t).(V) } -func newMakeGen(typ reflect.Type) *Generator[any] { - gen, mayNeedCast := newMakeKindGen(typ) +func newMakeGen(typ reflect.Type, overrides []*Generator[any]) *Generator[any] { + gen, mayNeedCast := newMakeKindGen(typ, overrides) if !mayNeedCast || typ.String() == typ.Kind().String() { return gen // fast path with less reflect } @@ -55,7 +56,33 @@ func (g *castGen) value(t *T) any { return reflect.ValueOf(v).Convert(g.typ).Interface() } -func newMakeKindGen(typ reflect.Type) (gen *Generator[any], mayNeedCast bool) { +func makeLinkName(t reflect.Type) string { + path := t.PkgPath() + name := t.Name() + if len(path) == 0 && len(name) == 0 { + return "" + } + + if len(path) == 0 { + return name + } + + return fmt.Sprintf("%s.%s", t.PkgPath(), t.Name()) +} + +func newMakeKindGen(typ reflect.Type, overrides []*Generator[any]) (gen *Generator[any], mayNeedCast bool) { + // First, check if any overrides apply + for _, override := range overrides { + // My kingdom for https://github.com/golang/go/issues/54393 + tt := reflect.TypeOf(override.impl).Elem() + + // Types that are parameterized generically expose "link name" (see https://github.com/golang/go/issues/55924) + target := fmt.Sprintf("[%s]", makeLinkName(typ)) + if strings.Contains(tt.Name(), target) { + return override, false // TODO: No idea if false is right here + } + } + switch typ.Kind() { case reflect.Bool: return Bool().AsAny(), true @@ -86,25 +113,25 @@ func newMakeKindGen(typ reflect.Type) (gen *Generator[any], mayNeedCast bool) { case reflect.Float64: return Float64().AsAny(), true case reflect.Array: - return genAnyArray(typ), false + return genAnyArray(typ, overrides), false case reflect.Map: - return genAnyMap(typ), false + return genAnyMap(typ, overrides), false case reflect.Pointer: - return Deferred(func() *Generator[any] { return genAnyPointer(typ) }), false + return Deferred(func() *Generator[any] { return genAnyPointer(typ, overrides) }), false case reflect.Slice: - return genAnySlice(typ), false + return genAnySlice(typ, overrides), false case reflect.String: return String().AsAny(), true case reflect.Struct: - return genAnyStruct(typ), false + return genAnyStruct(typ, overrides), false default: panic(fmt.Sprintf("unsupported type kind for Make: %v", typ.Kind())) } } -func genAnyPointer(typ reflect.Type) *Generator[any] { +func genAnyPointer(typ reflect.Type, overrides []*Generator[any]) *Generator[any] { elem := typ.Elem() - elemGen := newMakeGen(elem) + elemGen := newMakeGen(elem, overrides) const pNonNil = 0.5 return Custom[any](func(t *T) any { @@ -119,9 +146,9 @@ func genAnyPointer(typ reflect.Type) *Generator[any] { }) } -func genAnyArray(typ reflect.Type) *Generator[any] { +func genAnyArray(typ reflect.Type, overrides []*Generator[any]) *Generator[any] { count := typ.Len() - elemGen := newMakeGen(typ.Elem()) + elemGen := newMakeGen(typ.Elem(), overrides) return Custom[any](func(t *T) any { a := reflect.Indirect(reflect.New(typ)) @@ -137,8 +164,8 @@ func genAnyArray(typ reflect.Type) *Generator[any] { }) } -func genAnySlice(typ reflect.Type) *Generator[any] { - elemGen := newMakeGen(typ.Elem()) +func genAnySlice(typ reflect.Type, overrides []*Generator[any]) *Generator[any] { + elemGen := newMakeGen(typ.Elem(), overrides) return Custom[any](func(t *T) any { repeat := newRepeat(-1, -1, -1, elemGen.String()) @@ -151,9 +178,9 @@ func genAnySlice(typ reflect.Type) *Generator[any] { }) } -func genAnyMap(typ reflect.Type) *Generator[any] { - keyGen := newMakeGen(typ.Key()) - valGen := newMakeGen(typ.Elem()) +func genAnyMap(typ reflect.Type, overrides []*Generator[any]) *Generator[any] { + keyGen := newMakeGen(typ.Key(), overrides) + valGen := newMakeGen(typ.Elem(), overrides) return Custom[any](func(t *T) any { label := keyGen.String() + "," + valGen.String() @@ -172,11 +199,11 @@ func genAnyMap(typ reflect.Type) *Generator[any] { }) } -func genAnyStruct(typ reflect.Type) *Generator[any] { +func genAnyStruct(typ reflect.Type, overrides []*Generator[any]) *Generator[any] { numFields := typ.NumField() fieldGens := make([]*Generator[any], numFields) for i := 0; i < numFields; i++ { - fieldGens[i] = newMakeGen(typ.Field(i).Type) + fieldGens[i] = newMakeGen(typ.Field(i).Type, overrides) } return Custom[any](func(t *T) any { diff --git a/make_variant.go b/make_variant.go new file mode 100644 index 0000000..6de8f06 --- /dev/null +++ b/make_variant.go @@ -0,0 +1,22 @@ +// Copyright 2022 Gregory Petrosyan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rapid + +import ( + "reflect" +) + +// https://stackoverflow.com/questions/73864711/get-type-parameter-from-a-generic-struct-using-reflection + +// MakeVariant creates a generator of values of type V, using reflection to infer the required structure. +func MakeVariant[V any](overrides ...*Generator[any]) *Generator[V] { + var zero V + gen := newMakeGen(reflect.TypeOf(zero), overrides) + return newGenerator[V](&makeGen[V]{ + gen: gen, + }) +} diff --git a/make_variant_test.go b/make_variant_test.go new file mode 100644 index 0000000..2ca54cf --- /dev/null +++ b/make_variant_test.go @@ -0,0 +1,50 @@ +// Copyright 2022 Gregory Petrosyan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rapid_test + +import ( + "fmt" + "testing" + "time" + + "pgregory.net/rapid" +) + +type S struct { + F1 string + Int int + T time.Time + TPtr *time.Time +} + +func (s S) String() string { + return fmt.Sprintf("%s, %d, %s", s.F1, s.Int, s.T.String()) +} + +func Test(t *testing.T) { + now := time.Now() + + strGen := rapid.Just("Hello") + intGen := rapid.IntRange(0, 100) + timeGen := rapid.Just(now) + sGen := rapid.MakeVariant[S](strGen.AsAny(), intGen.AsAny(), timeGen.AsAny()) + s := sGen.Example(1) + + if s.F1 != "Hello" { + t.Errorf("Unexpected string value") + } + if s.Int > 100 || s.Int < 0 { + t.Errorf("Unexpected int value") + } + if !s.T.Equal(now) { + t.Errorf("Unexpected time.Time value") + } + if s.TPtr != nil && !s.TPtr.Equal(now) { + t.Errorf("Unexpected time.Time ptr value") + } + fmt.Println(s.String()) +}