diff --git a/lib/xdrgen/generators/go.rb b/lib/xdrgen/generators/go.rb index 6b5f7e13c..1498488aa 100644 --- a/lib/xdrgen/generators/go.rb +++ b/lib/xdrgen/generators/go.rb @@ -551,11 +551,9 @@ def render_union_decode_from_interface(out, union) name = name(union) out.puts "// DecodeFrom decodes this value using the Decoder." out.puts "func (u *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {" - out.puts " disc, n, err := d.DecodeInt()" - out.puts " if err != nil {" - out.puts " return 0, err" - out.puts " }" - out.puts " u.#{name(union.discriminant)} = #{reference union.discriminant.type}(disc)" + out.puts " var err error" + out.puts " var n, nTmp int" + render_decode_from_body(out, "u.#{name(union.discriminant)}", union.discriminant.type, declared_variables: [], self_encode: false) switch_for(out, union, "u.#{name(union.discriminant)}") do |arm, kase| out2 = StringIO.new if arm.void? @@ -580,16 +578,20 @@ def render_union_decode_from_interface(out, union) def render_enum_decode_from_interface(out, typedef) name = name(typedef) type = typedef - out.puts "// DecodeFrom decodes this value using the Decoder." - out.puts "func (e *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) {" - out.puts " var err error" - out.puts " var n, nTmp int" - out.puts " var i int32" - render_decode_from_body(out, "i", type, declared_variables: [], self_encode: true) - out.puts " *e = #{name}(i)" - out.puts " return n, nil" - out.puts "}" - out.break + out.puts <<-EOS.strip_heredoc + // DecodeFrom decodes this value using the Decoder. + func (e *#{name}) DecodeFrom(d *xdr.Decoder) (int, error) { + v, n, err := d.DecodeInt() + if err != nil { + return n, err + } + if _, ok := #{private_name type}Map[v]; !ok { + return n, fmt.Errorf("'%d' is not a valid #{name} enum value", v) + } + *e = #{name}(v) + return n, nil + } + EOS end def render_typedef_decode_from_interface(out, typedef) @@ -655,7 +657,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:) when AST::Typespecs::UnsignedInt out.puts " #{var}, nTmp, err = d.DecodeUint()" out.puts tail - when AST::Typespecs::Int, AST::Definitions::Enum + when AST::Typespecs::Int out.puts " #{var}, nTmp, err = d.DecodeInt()" out.puts tail when AST::Typespecs::String