Skip to content

Commit

Permalink
Add required default case to generated Map_XXX() and `Switch_XXX(…
Browse files Browse the repository at this point in the history
…)` functions for one-of fields.
  • Loading branch information
jmalloc committed Jul 11, 2024
1 parent 294ef09 commit e78e786
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 123 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ The format is based on [Keep a Changelog], and this project adheres to
[Keep a Changelog]: https://keepachangelog.com/en/1.0.0/
[Semantic Versioning]: https://semver.org/spec/v2.0.0.html

## [0.3.0] - 2024-07-11

### Changed

- **[BC]** Added a mandatory `default_` parameter to the `Map_XXX()` and
`Switch_XXX()` functions generated for one-of fields.

## [0.2.4] - 2024-07-11

### Fixed
Expand Down Expand Up @@ -105,6 +112,7 @@ The format is based on [Keep a Changelog], and this project adheres to
[0.2.2]: https://github.com/dogmatiq/primo/releases/tag/v0.2.2
[0.2.3]: https://github.com/dogmatiq/primo/releases/tag/v0.2.3
[0.2.4]: https://github.com/dogmatiq/primo/releases/tag/v0.2.4
[0.3.0]: https://github.com/dogmatiq/primo/releases/tag/v0.3.0

<!-- version template
## [0.0.1] - YYYY-MM-DD
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/dogmatiq/primo

go 1.21
go 1.22

require (
github.com/dave/jennifer v1.7.0
Expand Down
47 changes: 25 additions & 22 deletions internal/generator/exhaustiveswitch/oneof.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package exhaustiveswitch

import (
"fmt"

"github.com/dave/jennifer/jen"
"github.com/dogmatiq/primo/internal/generator/internal/scope"
)
Expand Down Expand Up @@ -31,7 +29,7 @@ func generateOneOfSwitch(code *jen.File, g *scope.OneOfGroup) {
)
code.Comment("")
code.Commentf(
"It panics if x.%s is nil.",
"It calls the function associated with the default case if x.%s is nil.",
g.GoFieldName,
)

Expand All @@ -56,6 +54,12 @@ func generateOneOfSwitch(code *jen.File, g *scope.OneOfGroup) {
)
}

code.
Line().
Id("default_").
Func().
Params()

code.Line()
},
).
Expand Down Expand Up @@ -90,15 +94,8 @@ func generateOneOfSwitch(code *jen.File, g *scope.OneOfGroup) {

code.
Default().
Panic(
jen.Lit(
fmt.Sprintf(
"%s: x.%s is nil",
funcName,
g.GoFieldName,
),
),
)
Id("default_").
Call()
},
),
)
Expand All @@ -119,8 +116,11 @@ func generateOneOfMap(code *jen.File, g *scope.OneOfGroup) {
"It invokes the function that corresponds to the value of x.%s,",
g.GoFieldName,
)
code.Comment(
"and returns that function's result. It calls the function associated with",
)
code.Commentf(
"and returns that function's result. It panics if x.%s is nil.",
"the default case if x.%s is nil.",
g.GoFieldName,
)

Expand Down Expand Up @@ -151,6 +151,15 @@ func generateOneOfMap(code *jen.File, g *scope.OneOfGroup) {
)
}

code.
Line().
Id("default_").
Func().
Params().
Params(
jen.Id("T"),
)

code.Line()
},
).
Expand Down Expand Up @@ -189,15 +198,9 @@ func generateOneOfMap(code *jen.File, g *scope.OneOfGroup) {

code.
Default().
Panic(
jen.Lit(
fmt.Sprintf(
"%s: x.%s is nil",
funcName,
g.GoFieldName,
),
),
)
Return().
Id("default_").
Call()
},
),
)
Expand Down
173 changes: 73 additions & 100 deletions internal/test/exhaustiveswitch/oneof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ func TestOneOfSwitch(t *testing.T) {
func(v int32) { called = true },
func(m string) { panic("unexpected log operation") },
func(*Record_NamingCollision) { panic("unexpected NamingCollision operation") },
func() { panic("unexpected default case") },
)

if !called {
Expand All @@ -37,60 +38,46 @@ func TestOneOfSwitch(t *testing.T) {
)

t.Run(
"it panics if the field is nil",
"it calls the function associated with the default case if the one-of field is nil",
func(t *testing.T) {
t.Parallel()

rec := &Record{}

defer func() {
got := recover()
want := "Switch_Record_Operation: x.Operation is nil"

if got != want {
t.Fatalf(
"unexpected panic message: got %q, want %q",
got,
want,
)
}
}()

Switch_Record_Operation(
rec,
func(int32) { panic("unexpected increment operation") },
func(int32) { panic("unexpected decrement operation") },
func(string) { panic("unexpected log operation") },
func(*Record_NamingCollision) { panic("unexpected NamingCollision operation") },
)
},
)

t.Run(
"it panics if the message is nil",
func(t *testing.T) {
t.Parallel()

defer func() {
got := recover()
want := "Switch_Record_Operation: x.Operation is nil"

if got != want {
t.Fatalf(
"unexpected panic message: got %q, want %q",
got,
want,
)
}
}()
cases := []struct {
name string
record *Record
}{
{
"nil message",
nil,
},
{
"nil field",
&Record{},
},
}

Switch_Record_Operation(
nil,
func(int32) { panic("unexpected increment operation") },
func(int32) { panic("unexpected decrement operation") },
func(string) { panic("unexpected log operation") },
func(*Record_NamingCollision) { panic("unexpected NamingCollision operation") },
)
for _, c := range cases {
t.Run(
c.name,
func(t *testing.T) {
t.Parallel()

called := false

Switch_Record_Operation(
c.record,
func(int32) { panic("unexpected increment operation") },
func(int32) { panic("unexpected decrement operation") },
func(string) { panic("unexpected log operation") },
func(*Record_NamingCollision) { panic("unexpected NamingCollision operation") },
func() { called = true },
)

if !called {
t.Fatalf("expected case function to be called")
}
})
}
},
)
}
Expand All @@ -116,6 +103,7 @@ func TestOneOfMap(t *testing.T) {
func(v int32) int32 { return v - 1 },
func(m string) int32 { panic("unexpected log operation") },
func(*Record_NamingCollision) int32 { panic("unexpected NamingCollision operation") },
func() int32 { panic("unexpected default case") },
)

if got != want {
Expand All @@ -125,60 +113,45 @@ func TestOneOfMap(t *testing.T) {
)

t.Run(
"it panics if the field is nil",
"it calls the function associated with the default case if the one-of field is nil",
func(t *testing.T) {
t.Parallel()

rec := &Record{}

defer func() {
got := recover()
want := "Map_Record_Operation: x.Operation is nil"

if got != want {
t.Fatalf(
"unexpected panic message: got %q, want %q",
got,
want,
)
}
}()

Map_Record_Operation(
rec,
func(int32) error { panic("unexpected increment operation") },
func(int32) error { panic("unexpected decrement operation") },
func(string) error { panic("unexpected log operation") },
func(*Record_NamingCollision) error { panic("unexpected NamingCollision operation") },
)
},
)

t.Run(
"it panics if the message is nil",
func(t *testing.T) {
t.Parallel()
cases := []struct {
name string
record *Record
}{
{
"nil message",
nil,
},
{
"nil field",
&Record{},
},
}

defer func() {
got := recover()
want := "Map_Record_Operation: x.Operation is nil"

if got != want {
t.Fatalf(
"unexpected panic message: got %q, want %q",
got,
want,
)
}
}()

Map_Record_Operation(
nil,
func(int32) error { panic("unexpected increment operation") },
func(int32) error { panic("unexpected decrement operation") },
func(string) error { panic("unexpected log operation") },
func(*Record_NamingCollision) error { panic("unexpected NamingCollision operation") },
)
for _, c := range cases {
t.Run(
c.name,
func(t *testing.T) {
t.Parallel()

want := int32(123)
got := Map_Record_Operation(
c.record,
func(int32) int32 { panic("unexpected increment operation") },
func(int32) int32 { panic("unexpected decrement operation") },
func(string) int32 { panic("unexpected log operation") },
func(*Record_NamingCollision) int32 { panic("unexpected NamingCollision operation") },
func() int32 { return 123 },
)

if got != want {
t.Fatalf("unexpected result: got %q, want %q", got, want)
}
})
}
},
)
}

0 comments on commit e78e786

Please sign in to comment.