diff --git a/lib/xdrgen/generators/go.rb b/lib/xdrgen/generators/go.rb index 1509ddec7..58186c0a3 100644 --- a/lib/xdrgen/generators/go.rb +++ b/lib/xdrgen/generators/go.rb @@ -353,22 +353,26 @@ def render_struct_encode_to_interface(out, struct) def render_union_encode_to_interface(out, union) name = name(union) out.puts "// EncodeTo encodes this value using the Encoder." - out.puts "func (s #{name}) EncodeTo(e *xdr.Encoder) error {" - out.puts " _, err := e.EncodeInt(int32(s.#{name(union.discriminant)}))" - out.puts " if err != nil {" - out.puts " return err" - out.puts " }" - switch_for(out, union, "s.#{name(union.discriminant)}") do |arm, kase| + out.puts "func (u #{name}) EncodeTo(e *xdr.Encoder) error {" + out.puts " var err error" + render_encode_to_body(out, "u.#{name(union.discriminant)}", union.discriminant.type, self_encode: false) + switch_for(out, union, "u.#{name(union.discriminant)}") do |arm, kase| out2 = StringIO.new if arm.void? - "// Void" + out2.puts "// Void" else mn = name(arm) - render_encode_to_body(out2, "(*s.#{mn})", arm.type, self_encode: false) - out2.string + render_encode_to_body(out2, "(*u.#{mn})", arm.type, self_encode: false) end + out2.puts "return nil" + out2.string end - out.puts " return err" + + # when the default arm is not present, we must render the failure case + unless union.default_arm.present? + out.puts " return fmt.Errorf(\"#{name(union.discriminant)} (#{reference union.discriminant.type}) switch value '%d' is not valid for union #{name}\", u.#{name(union.discriminant)})" + end + out.puts "}" out.break end @@ -376,13 +380,16 @@ def render_union_encode_to_interface(out, union) def render_enum_encode_to_interface(out, typedef) name = name(typedef) type = typedef - out.puts "// EncodeTo encodes this value using the Encoder." - out.puts "func (s #{name}) EncodeTo(e *xdr.Encoder) error {" - out.puts " var err error" - render_encode_to_body(out, "s", type, self_encode: true) - out.puts " return nil" - out.puts "}" - out.break + out.puts <<-EOS.strip_heredoc + // EncodeTo encodes this value using the Encoder. + func (e #{name}) EncodeTo(enc *xdr.Encoder) error { + if _, ok := #{private_name type}Map[int32(e)]; !ok { + return fmt.Errorf("'%d' is not a valid #{name} enum value", e) + } + _, err := enc.EncodeInt(int32(e)) + return err + } + EOS end def render_typedef_encode_to_interface(out, typedef) @@ -432,41 +439,42 @@ def render_binary_interface(out, name) # xdr.Encoder, and a variable defined by `name` that is the value to # encode. def render_encode_to_body(out, var, type, self_encode:) + def check_error(str) + <<-EOS.strip_heredoc + if #{str}; err != nil { + return err + } + EOS + end optional = type.sub_type == :optional if optional - out.puts " _, err = e.EncodeBool(#{var} != nil)" - out.puts " if err != nil {" - out.puts " return err" - out.puts " }" + out.puts check_error "_, err = e.EncodeBool(#{var} != nil)" out.puts " if #{var} != nil {" var = "(*#{var})" end case type when AST::Typespecs::UnsignedHyper - out.puts " _, err = e.EncodeUhyper(uint64(#{var}))" + out.puts check_error " _, err = e.EncodeUhyper(uint64(#{var}))" when AST::Typespecs::Hyper - out.puts " _, err = e.EncodeHyper(int64(#{var}))" + out.puts check_error "_, err = e.EncodeHyper(int64(#{var}))" when AST::Typespecs::UnsignedInt - out.puts " _, err = e.EncodeUint(uint32(#{var}))" - when AST::Typespecs::Int, AST::Definitions::Enum - out.puts " _, err = e.EncodeInt(int32(#{var}))" + out.puts check_error "_, err = e.EncodeUint(uint32(#{var}))" + when AST::Typespecs::Int + out.puts (check_error "_, err = e.EncodeInt(int32(#{var}))") when AST::Typespecs::String - out.puts " _, err = e.EncodeString(string(#{var}))" + out.puts check_error "_, err = e.EncodeString(string(#{var}))" when AST::Typespecs::Opaque if type.fixed? - out.puts " _, err = e.EncodeFixedOpaque(#{var}[:])" + out.puts check_error "_, err = e.EncodeFixedOpaque(#{var}[:])" else - out.puts " _, err = e.EncodeOpaque(#{var}[:])" + out.puts check_error "_, err = e.EncodeOpaque(#{var}[:])" end when AST::Typespecs::Simple case type.sub_type when :simple, :optional optional_within = type.is_a?(AST::Identifier) && type.resolved_type.sub_type == :optional if optional_within - out.puts " _, err = e.EncodeBool(#{var} != nil)" - out.puts " if err != nil {" - out.puts " return err" - out.puts " }" + out.puts check_error "_, err = e.EncodeBool(#{var} != nil)" out.puts " if #{var} != nil {" var = "(*#{var})" end @@ -481,7 +489,7 @@ def render_encode_to_body(out, var, type, self_encode:) end var = newvar end - out.puts " err = #{var}.EncodeTo(e)" + out.puts check_error " err = #{var}.EncodeTo(e)" if optional_within out.puts " }" end @@ -490,41 +498,26 @@ def render_encode_to_body(out, var, type, self_encode:) element_var = "#{var}[i]" optional_within = type.is_a?(AST::Identifier) && type.resolved_type.sub_type == :optional if optional_within - out.puts " _, err = e.EncodeBool(#{element_var} != nil)" - out.puts " if err != nil {" - out.puts " return err" - out.puts " }" + out.puts check_error "_, err = e.EncodeBool(#{element_var} != nil)" out.puts " if #{element_var} != nil {" var = "(*#{element_var})" end - out.puts " err = #{element_var}.EncodeTo(e)" - out.puts " if err != nil {" - out.puts " return err" - out.puts " }" + out.puts check_error "err = #{element_var}.EncodeTo(e)" if optional_within out.puts " }" end out.puts " }" when :var_array - out.puts " _, err = e.EncodeUint(uint32(len(#{var})))" - out.puts " if err != nil {" - out.puts " return err" - out.puts " }" + out.puts check_error "_, err = e.EncodeUint(uint32(len(#{var})))" out.puts " for i := 0; i < len(#{var}); i++ {" element_var = "#{var}[i]" optional_within = type.is_a?(AST::Identifier) && type.resolved_type.sub_type == :optional if optional_within - out.puts " _, err = e.EncodeBool(#{element_var} != nil)" - out.puts " if err != nil {" - out.puts " return err" - out.puts " }" + out.puts check_error "_, err = e.EncodeBool(#{element_var} != nil)" out.puts " if #{element_var} != nil {" var = "(*#{element_var})" end - out.puts " err = #{element_var}.EncodeTo(e)" - out.puts " if err != nil {" - out.puts " return err" - out.puts " }" + out.puts check_error "err = #{element_var}.EncodeTo(e)" if optional_within out.puts " }" end @@ -534,19 +527,16 @@ def render_encode_to_body(out, var, type, self_encode:) end when AST::Definitions::Base if self_encode - out.puts " err = #{name type}(#{var}).EncodeTo(e)" + out.puts check_error "err = #{name type}(#{var}).EncodeTo(e)" else - out.puts " err = #{var}.EncodeTo(e)" + out.puts check_error "err = #{var}.EncodeTo(e)" end else - out.puts " _, err = e.Encode(#{var})" + out.puts check_error "_, err = e.Encode(#{var})" end if optional out.puts " }" end - out.puts " if err != nil {" - out.puts " return err" - out.puts " }" end def render_xdr_type_interface(out, name)