diff --git a/pkg/kv/kvclient/kvcoord/txn_interceptor_committer.go b/pkg/kv/kvclient/kvcoord/txn_interceptor_committer.go index efbf7f87d1ea..e725324d6dd2 100644 --- a/pkg/kv/kvclient/kvcoord/txn_interceptor_committer.go +++ b/pkg/kv/kvclient/kvcoord/txn_interceptor_committer.go @@ -163,7 +163,7 @@ func (tc *txnCommitter) SendLocked( // Make a copy of the EndTxn, since we're going to change it below to // disable the parallel commit. etCpy := *et - ba.Requests[len(ba.Requests)-1].SetInner(&etCpy) + ba.Requests[len(ba.Requests)-1].MustSetInner(&etCpy) et = &etCpy } } diff --git a/pkg/kv/kvclient/kvcoord/txn_interceptor_span_refresher.go b/pkg/kv/kvclient/kvcoord/txn_interceptor_span_refresher.go index 4009a4373f4c..6043188d86f8 100644 --- a/pkg/kv/kvclient/kvcoord/txn_interceptor_span_refresher.go +++ b/pkg/kv/kvclient/kvcoord/txn_interceptor_span_refresher.go @@ -185,7 +185,7 @@ func (sr *txnSpanRefresher) SendLocked( isReissue := et.DeprecatedCanCommitAtHigherTimestamp if isReissue { etCpy := *et - ba.Requests[len(ba.Requests)-1].SetInner(&etCpy) + ba.Requests[len(ba.Requests)-1].MustSetInner(&etCpy) et = &etCpy } et.DeprecatedCanCommitAtHigherTimestamp = ba.CanForwardReadTimestamp diff --git a/pkg/roachpb/api.go b/pkg/roachpb/api.go index aa796141ab51..931f0c1596cb 100644 --- a/pkg/roachpb/api.go +++ b/pkg/roachpb/api.go @@ -597,26 +597,6 @@ func (sr *ReverseScanResponse) Verify(req Request) error { return nil } -// MustSetInner sets the Request contained in the union. It panics if the -// request is not recognized by the union type. The RequestUnion is reset -// before being repopulated. -func (ru *RequestUnion) MustSetInner(args Request) { - ru.Reset() - if !ru.SetInner(args) { - panic(errors.AssertionFailedf("%T excludes %T", ru, args)) - } -} - -// MustSetInner sets the Response contained in the union. It panics if the -// response is not recognized by the union type. The ResponseUnion is reset -// before being repopulated. -func (ru *ResponseUnion) MustSetInner(reply Response) { - ru.Reset() - if !ru.SetInner(reply) { - panic(errors.AssertionFailedf("%T excludes %T", ru, reply)) - } -} - // Method implements the Request interface. func (*GetRequest) Method() Method { return Get } diff --git a/pkg/roachpb/api_test.go b/pkg/roachpb/api_test.go index d74be872d8e0..6c26f6d638ac 100644 --- a/pkg/roachpb/api_test.go +++ b/pkg/roachpb/api_test.go @@ -224,7 +224,7 @@ func TestMustSetInner(t *testing.T) { req := RequestUnion{} res := ResponseUnion{} - // GetRequest is checked first in the generated code for SetInner. + // GetRequest is checked first in the generated code for MustSetInner. req.MustSetInner(&GetRequest{}) res.MustSetInner(&GetResponse{}) req.MustSetInner(&EndTxnRequest{}) diff --git a/pkg/roachpb/batch_generated.go b/pkg/roachpb/batch_generated.go index db4de754ddb7..5f0b865c6e3c 100644 --- a/pkg/roachpb/batch_generated.go +++ b/pkg/roachpb/batch_generated.go @@ -263,8 +263,9 @@ func (ru ResponseUnion) GetInner() Response { } } -// SetInner sets the error in the union. -func (ru *ErrorDetail) SetInner(r error) bool { +// MustSetInner sets the error in the union. +func (ru *ErrorDetail) MustSetInner(r error) { + ru.Reset() var union isErrorDetail_Value switch t := r.(type) { case *NotLeaseHolderError: @@ -324,14 +325,14 @@ func (ru *ErrorDetail) SetInner(r error) bool { case *IndeterminateCommitError: union = &ErrorDetail_IndeterminateCommit{t} default: - return false + panic(fmt.Sprintf("unsupported type %T for %T", r, ru)) } ru.Value = union - return true } -// SetInner sets the Request in the union. -func (ru *RequestUnion) SetInner(r Request) bool { +// MustSetInner sets the Request in the union. +func (ru *RequestUnion) MustSetInner(r Request) { + ru.Reset() var union isRequestUnion_Value switch t := r.(type) { case *GetRequest: @@ -423,14 +424,14 @@ func (ru *RequestUnion) SetInner(r Request) bool { case *AdminVerifyProtectedTimestampRequest: union = &RequestUnion_AdminVerifyProtectedTimestamp{t} default: - return false + panic(fmt.Sprintf("unsupported type %T for %T", r, ru)) } ru.Value = union - return true } -// SetInner sets the Response in the union. -func (ru *ResponseUnion) SetInner(r Response) bool { +// MustSetInner sets the Response in the union. +func (ru *ResponseUnion) MustSetInner(r Response) { + ru.Reset() var union isResponseUnion_Value switch t := r.(type) { case *GetResponse: @@ -520,10 +521,9 @@ func (ru *ResponseUnion) SetInner(r Response) bool { case *AdminVerifyProtectedTimestampResponse: union = &ResponseUnion_AdminVerifyProtectedTimestamp{t} default: - return false + panic(fmt.Sprintf("unsupported type %T for %T", r, ru)) } ru.Value = union - return true } type reqCounts [44]int32 diff --git a/pkg/roachpb/errors.go b/pkg/roachpb/errors.go index 6f1b10fd43a0..6a0a04994974 100644 --- a/pkg/roachpb/errors.go +++ b/pkg/roachpb/errors.go @@ -219,6 +219,9 @@ func (e *internalError) Error() string { } // ErrorDetailInterface is an interface for each error detail. +// These must not be implemented by anything other than our protobuf-backed error details +// as we rely on a 1:1 correspondence between the interface and what can be stored via +// `Error.SetDetail`. type ErrorDetailInterface interface { error protoutil.Message @@ -307,12 +310,7 @@ func (e *Error) SetDetail(detail ErrorDetailInterface) { } else { e.TransactionRestart = TransactionRestart_NONE } - // If the specific error type exists in the detail union, set it. - if !e.Detail.SetInner(detail) { - if e.TransactionRestart != TransactionRestart_NONE { - panic(errors.AssertionFailedf("transactionRestartError %T must be an ErrorDetail", detail)) - } - } + e.Detail.MustSetInner(detail) e.checkTxnStatusValid() } diff --git a/pkg/roachpb/gen_batch.go b/pkg/roachpb/gen_batch.go index 649eea337069..483b7f25013f 100644 --- a/pkg/roachpb/gen_batch.go +++ b/pkg/roachpb/gen_batch.go @@ -108,10 +108,11 @@ func (ru %[1]s) GetInner() %[2]s { `) } -func genSetInner(w io.Writer, unionName, variantName string, variants []variantInfo) { +func genMustSetInner(w io.Writer, unionName, variantName string, variants []variantInfo) { fmt.Fprintf(w, ` -// SetInner sets the %[2]s in the union. -func (ru *%[1]s) SetInner(r %[2]s) bool { +// MustSetInner sets the %[2]s in the union. +func (ru *%[1]s) MustSetInner(r %[2]s) { + ru.Reset() var union is%[1]s_Value switch t := r.(type) { `, unionName, variantName) @@ -123,10 +124,9 @@ func (ru *%[1]s) SetInner(r %[2]s) bool { } fmt.Fprint(w, ` default: - return false + panic(fmt.Sprintf("unsupported type %T for %T", r, ru)) } ru.Value = union - return true } `) } @@ -160,10 +160,10 @@ import ( genGetInner(f, "RequestUnion", "Request", reqVariants) genGetInner(f, "ResponseUnion", "Response", resVariants) - // Generate SetInner methods. - genSetInner(f, "ErrorDetail", "error", errVariants) - genSetInner(f, "RequestUnion", "Request", reqVariants) - genSetInner(f, "ResponseUnion", "Response", resVariants) + // Generate MustSetInner methods. + genMustSetInner(f, "ErrorDetail", "error", errVariants) + genMustSetInner(f, "RequestUnion", "Request", reqVariants) + genMustSetInner(f, "ResponseUnion", "Response", resVariants) fmt.Fprintf(f, ` type reqCounts [%d]int32