Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

golang: Sanity-check allocations #182

Merged
merged 3 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions lib/xdrgen/generators/go.rb
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def check_error(str)
def render_struct_decode_from_interface(out, struct)
name = name(struct)
out.puts "// DecodeFrom decodes this value using the Decoder."
out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {"
out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint, maxAllocSize int) (int, error) {"
out.puts " if maxDepth == 0 {"
out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)"
out.puts " }"
Expand All @@ -556,7 +556,7 @@ def render_struct_decode_from_interface(out, struct)
def render_union_decode_from_interface(out, union)
name = name(union)
out.puts "// DecodeFrom decodes this value using the Decoder."
out.puts "func (u *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {"
out.puts "func (u *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint, maxAllocSize int) (int, error) {"
out.puts " if maxDepth == 0 {"
out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)"
out.puts " }"
Expand Down Expand Up @@ -589,7 +589,7 @@ def render_enum_decode_from_interface(out, typedef)
type = typedef
out.puts <<-EOS.strip_heredoc
// DecodeFrom decodes this value using the Decoder.
func (e *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
func (e *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint, maxAllocSize int) (int, error) {
if maxDepth == 0 {
return 0, fmt.Errorf("decoding #{name}: %w", ErrMaxDecodingDepthReached)
}
Expand All @@ -611,7 +611,7 @@ def render_typedef_decode_from_interface(out, typedef)
name = name(typedef)
type = typedef.declaration.type
out.puts "// DecodeFrom decodes this value using the Decoder."
out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {"
out.puts "func (s *#{name}) DecodeFrom(d *xdr.Decoder, maxDepth uint, maxAllocSize int) (int, error) {"
out.puts " if maxDepth == 0 {"
out.puts " return 0, fmt.Errorf(\"decoding #{name}: %w\", ErrMaxDecodingDepthReached)"
out.puts " }"
Expand Down Expand Up @@ -681,17 +681,17 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts " #{var}, nTmp, err = d.DecodeBool()"
out.puts tail
when AST::Typespecs::String
arg = "0"
arg = type.decl.resolved_size unless type.decl.resolved_size.nil?
arg = "maxAllocSize"
arg = "mergeMaxAllocSizeAndMaxSize(maxAllocSize, #{type.decl.resolved_size})" unless type.decl.resolved_size.nil?
out.puts " #{var}, nTmp, err = d.DecodeString(#{arg})"
out.puts tail
when AST::Typespecs::Opaque
if type.fixed?
out.puts " nTmp, err = d.DecodeFixedOpaqueInplace(#{var}[:])"
out.puts tail
else
arg = "0"
arg = type.decl.resolved_size unless type.decl.resolved_size.nil?
arg = "maxAllocSize"
arg = "mergeMaxAllocSizeAndMaxSize(maxAllocSize, #{type.decl.resolved_size})" unless type.decl.resolved_size.nil?
out.puts " #{var}, nTmp, err = d.DecodeOpaque(#{arg})"
out.puts tail
end
Expand All @@ -708,7 +708,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts " #{var} = new(#{name type.resolved_type.declaration.type})"
end
var = "(*#{name type})(#{var})" if self_encode
out.puts " nTmp, err = #{var}.DecodeFrom(d, maxDepth)"
out.puts " nTmp, err = #{var}.DecodeFrom(d, maxDepth, maxAllocSize)"
out.puts tail
if optional_within
out.puts " }"
Expand All @@ -725,7 +725,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts " if eb {"
var = "(*#{element_var})"
end
out.puts " nTmp, err = #{element_var}.DecodeFrom(d, maxDepth)"
out.puts " nTmp, err = #{element_var}.DecodeFrom(d, maxDepth, maxAllocSize)"
out.puts tail
if optional_within
out.puts " }"
Expand All @@ -740,6 +740,13 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts " return n, fmt.Errorf(\"decoding #{name type}: data size (%d) exceeds size limit (#{type.decl.resolved_size})\", l)"
out.puts " }"
end
out.puts " if maxAllocSize > 0 {"
out.puts " var phony #{name type}"
out.puts " allocSize := unsafe.Sizeof(phony) * uintptr(l)"
out.puts " if uintptr(maxAllocSize) < allocSize {"
out.puts " return n, fmt.Errorf(\"decoding #{name type}: allocation size (%d) exceeds limit (%d)\", allocSize, maxAllocSize)"
out.puts " }"
out.puts " }"
out.puts " #{var} = nil"
out.puts " if l > 0 {"
out.puts " #{var} = make([]#{name type}, l)"
Expand All @@ -755,7 +762,7 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
out.puts " #{element_var} = new(#{name type.resolved_type.declaration.type})"
var = "(*#{element_var})"
end
out.puts " nTmp, err = #{element_var}.DecodeFrom(d, maxDepth)"
out.puts " nTmp, err = #{element_var}.DecodeFrom(d, maxDepth, maxAllocSize)"
out.puts tail
if optional_within
out.puts " }"
Expand All @@ -767,13 +774,13 @@ def render_decode_from_body(out, var, type, declared_variables:, self_encode:)
end
when AST::Definitions::Base
if self_encode
out.puts " nTmp, err = #{name type}(#{var}).DecodeFrom(d, maxDepth)"
out.puts " nTmp, err = #{name type}(#{var}).DecodeFrom(d, maxDepth, maxAllocSize)"
else
out.puts " nTmp, err = #{var}.DecodeFrom(d, maxDepth)"
out.puts " nTmp, err = #{var}.DecodeFrom(d, maxDepth, maxAllocSize)"
end
out.puts tail
else
out.puts " nTmp, err = d.DecodeWithMaxDepth(&#{var}, maxDepth)"
out.puts " nTmp, err = d.DecodeWithMaxDepth(&#{var}, maxDepth, maxAllocSize)"
out.puts tail
end
if optional
Expand All @@ -794,7 +801,7 @@ def render_binary_interface(out, name)
out.puts "func (s *#{name}) UnmarshalBinary(inp []byte) error {"
out.puts " r := bytes.NewReader(inp)"
out.puts " d := xdr.NewDecoder(r)"
out.puts " _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)"
out.puts " _, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth, 0)"
out.puts " return err"
out.puts "}"
out.break
Expand Down Expand Up @@ -836,12 +843,15 @@ def render_top_matter(out)
"errors"
"io"
"fmt"
"unsafe"

"github.com/stellar/go-xdr/xdr3"
)
EOS
out.break
out.puts <<-EOS.strip_heredoc
// Needed since unsafe is not used in all cases
var _ = unsafe.Sizeof(0)
// XdrFilesSHA256 is the SHA256 hashes of source files.
var XdrFilesSHA256 = map[string]string{
#{@output.relative_source_path_sha256_hashes.map(){ |path, hash| %{"#{path}": "#{hash}",} }.join("\n")}
Expand All @@ -856,19 +866,32 @@ def render_top_matter(out)
}

type decoderFrom interface {
DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error)
DecodeFrom(d *xdr.Decoder, maxDepth uint, maxAllocSize int) (int, error)
}

// Unmarshal reads an xdr element from `r` into `v`.
func Unmarshal(r io.Reader, v interface{}) (int, error) {
return UnmarshalWithMaxAllocSize(r, v, 0)
}

// Unmarshal reads an xdr element from `r` into `v`.
func UnmarshalWithMaxAllocSize(r io.Reader, v interface{}, maxAllocSize int) (int, error) {
if decodable, ok := v.(decoderFrom); ok {
d := xdr.NewDecoder(r)
return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth, maxAllocSize)
}
// delegate to xdr package's Unmarshal
return xdr.Unmarshal(r, v)
}

func mergeMaxAllocSizeAndMaxSize(maxAllocSize int, maxSize int) int {
if maxAllocSize > 0 || maxAllocSize < maxSize {
return maxAllocSize
}
return maxSize
}


// Marshal writes an xdr element `v` into `w`.
func Marshal(w io.Writer, v interface{}) (int, error) {
if _, ok := v.(xdrType); ok {
Expand Down
86 changes: 51 additions & 35 deletions spec/output/generator_spec_go/block_comments.x/MyXDR_generated.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,49 +14,65 @@ import (
"errors"
"io"
"fmt"
"unsafe"

"github.com/stellar/go-xdr/xdr3"
)

// Needed since unsafe is not used in all cases
var _ = unsafe.Sizeof(0)
// XdrFilesSHA256 is the SHA256 hashes of source files.
var XdrFilesSHA256 = map[string]string{
"spec/fixtures/generator/block_comments.x": "e13131bc4134f38da17b9d5e9f67d2695a69ef98e3ef272833f4c18d0cc88a30",
}

var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached")

type xdrType interface {
xdrType()
}

type decoderFrom interface {
DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error)
}

// Unmarshal reads an xdr element from `r` into `v`.
func Unmarshal(r io.Reader, v interface{}) (int, error) {
if decodable, ok := v.(decoderFrom); ok {
d := xdr.NewDecoder(r)
return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
}
// delegate to xdr package's Unmarshal
return xdr.Unmarshal(r, v)
}

// Marshal writes an xdr element `v` into `w`.
func Marshal(w io.Writer, v interface{}) (int, error) {
if _, ok := v.(xdrType); ok {
if bm, ok := v.(encoding.BinaryMarshaler); ok {
b, err := bm.MarshalBinary()
if err != nil {
return 0, err
}
return w.Write(b)
}
}
// delegate to xdr package's Marshal
return xdr.Marshal(w, v)
var ErrMaxDecodingDepthReached = errors.New("maximum decoding depth reached")

type xdrType interface {
xdrType()
}

type decoderFrom interface {
DecodeFrom(d *xdr.Decoder, maxDepth uint, maxAllocSize int) (int, error)
}

// Unmarshal reads an xdr element from `r` into `v`.
func Unmarshal(r io.Reader, v interface{}) (int, error) {
return UnmarshalWithMaxAllocSize(r, v, 0)
}

// Unmarshal reads an xdr element from `r` into `v`.
2opremio marked this conversation as resolved.
Show resolved Hide resolved
func UnmarshalWithMaxAllocSize(r io.Reader, v interface{}, maxAllocSize int) (int, error) {
if decodable, ok := v.(decoderFrom); ok {
d := xdr.NewDecoder(r)
return decodable.DecodeFrom(d, xdr.DecodeDefaultMaxDepth, maxAllocSize)
}
// delegate to xdr package's Unmarshal
return xdr.Unmarshal(r, v)
}

func mergeMaxAllocSizeAndMaxSize(maxAllocSize int, maxSize int) int {
if maxAllocSize > 0 || maxAllocSize < maxSize {
return maxAllocSize
}
return maxSize
}


// Marshal writes an xdr element `v` into `w`.
func Marshal(w io.Writer, v interface{}) (int, error) {
if _, ok := v.(xdrType); ok {
if bm, ok := v.(encoding.BinaryMarshaler); ok {
b, err := bm.MarshalBinary()
if err != nil {
return 0, err
}
return w.Write(b)
}
}
// delegate to xdr package's Marshal
return xdr.Marshal(w, v)
}

// AccountFlags is an XDR Enum defines as:
//
Expand Down Expand Up @@ -95,7 +111,7 @@ func (e AccountFlags) EncodeTo(enc *xdr.Encoder) error {
}
var _ decoderFrom = (*AccountFlags)(nil)
// DecodeFrom decodes this value using the Decoder.
func (e *AccountFlags) DecodeFrom(d *xdr.Decoder, maxDepth uint) (int, error) {
func (e *AccountFlags) DecodeFrom(d *xdr.Decoder, maxDepth uint, maxAllocSize int) (int, error) {
if maxDepth == 0 {
return 0, fmt.Errorf("decoding AccountFlags: %w", ErrMaxDecodingDepthReached)
}
Expand All @@ -122,7 +138,7 @@ func (s AccountFlags) MarshalBinary() ([]byte, error) {
func (s *AccountFlags) UnmarshalBinary(inp []byte) error {
r := bytes.NewReader(inp)
d := xdr.NewDecoder(r)
_, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth)
_, err := s.DecodeFrom(d, xdr.DecodeDefaultMaxDepth, 0)
return err
}

Expand Down
Loading
Loading