diff --git a/chain/exchange/cbor_gen.go b/chain/exchange/cbor_gen.go index e66b6d798c4..71c75869dba 100644 --- a/chain/exchange/cbor_gen.go +++ b/chain/exchange/cbor_gen.go @@ -306,9 +306,9 @@ func (t *Response) UnmarshalCBOR(r io.Reader) (err error) { return nil } -var lengthBufCompactedMessages = []byte{132} +var lengthBufCompactedMessagesCBOR = []byte{132} -func (t *CompactedMessages) MarshalCBOR(w io.Writer) error { +func (t *CompactedMessagesCBOR) MarshalCBOR(w io.Writer) error { if t == nil { _, err := w.Write(cbg.CborNull) return err @@ -316,12 +316,12 @@ func (t *CompactedMessages) MarshalCBOR(w io.Writer) error { cw := cbg.NewCborWriter(w) - if _, err := cw.Write(lengthBufCompactedMessages); err != nil { + if _, err := cw.Write(lengthBufCompactedMessagesCBOR); err != nil { return err } // t.Bls ([]*types.Message) (slice) - if len(t.Bls) > cbg.MaxLength { + if len(t.Bls) > 150000 { return xerrors.Errorf("Slice value in field t.Bls was too long") } @@ -334,7 +334,7 @@ func (t *CompactedMessages) MarshalCBOR(w io.Writer) error { } } - // t.BlsIncludes ([][]uint64) (slice) + // t.BlsIncludes ([]exchange.messageIndices) (slice) if len(t.BlsIncludes) > cbg.MaxLength { return xerrors.Errorf("Slice value in field t.BlsIncludes was too long") } @@ -343,24 +343,13 @@ func (t *CompactedMessages) MarshalCBOR(w io.Writer) error { return err } for _, v := range t.BlsIncludes { - if len(v) > cbg.MaxLength { - return xerrors.Errorf("Slice value in field v was too long") - } - - if err := cw.WriteMajorTypeHeader(cbg.MajArray, uint64(len(v))); err != nil { + if err := v.MarshalCBOR(cw); err != nil { return err } - for _, v := range v { - - if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(v)); err != nil { - return err - } - - } } // t.Secpk ([]*types.SignedMessage) (slice) - if len(t.Secpk) > cbg.MaxLength { + if len(t.Secpk) > 150000 { return xerrors.Errorf("Slice value in field t.Secpk was too long") } @@ -373,7 +362,7 @@ func (t *CompactedMessages) MarshalCBOR(w io.Writer) error { } } - // t.SecpkIncludes ([][]uint64) (slice) + // t.SecpkIncludes ([]exchange.messageIndices) (slice) if len(t.SecpkIncludes) > cbg.MaxLength { return xerrors.Errorf("Slice value in field t.SecpkIncludes was too long") } @@ -382,26 +371,15 @@ func (t *CompactedMessages) MarshalCBOR(w io.Writer) error { return err } for _, v := range t.SecpkIncludes { - if len(v) > cbg.MaxLength { - return xerrors.Errorf("Slice value in field v was too long") - } - - if err := cw.WriteMajorTypeHeader(cbg.MajArray, uint64(len(v))); err != nil { + if err := v.MarshalCBOR(cw); err != nil { return err } - for _, v := range v { - - if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(v)); err != nil { - return err - } - - } } return nil } -func (t *CompactedMessages) UnmarshalCBOR(r io.Reader) (err error) { - *t = CompactedMessages{} +func (t *CompactedMessagesCBOR) UnmarshalCBOR(r io.Reader) (err error) { + *t = CompactedMessagesCBOR{} cr := cbg.NewCborReader(r) @@ -430,7 +408,7 @@ func (t *CompactedMessages) UnmarshalCBOR(r io.Reader) (err error) { return err } - if extra > cbg.MaxLength { + if extra > 150000 { return fmt.Errorf("t.Bls: array too large (%d)", extra) } @@ -471,7 +449,7 @@ func (t *CompactedMessages) UnmarshalCBOR(r io.Reader) (err error) { } } - // t.BlsIncludes ([][]uint64) (slice) + // t.BlsIncludes ([]exchange.messageIndices) (slice) maj, extra, err = cr.ReadHeader() if err != nil { @@ -487,7 +465,7 @@ func (t *CompactedMessages) UnmarshalCBOR(r io.Reader) (err error) { } if extra > 0 { - t.BlsIncludes = make([][]uint64, extra) + t.BlsIncludes = make([]messageIndices, extra) } for i := 0; i < int(extra); i++ { @@ -499,47 +477,13 @@ func (t *CompactedMessages) UnmarshalCBOR(r io.Reader) (err error) { _ = extra _ = err - maj, extra, err = cr.ReadHeader() - if err != nil { - return err - } - - if extra > cbg.MaxLength { - return fmt.Errorf("t.BlsIncludes[i]: array too large (%d)", extra) - } - - if maj != cbg.MajArray { - return fmt.Errorf("expected cbor array") - } - - if extra > 0 { - t.BlsIncludes[i] = make([]uint64, extra) - } - - for j := 0; j < int(extra); j++ { - { - var maj byte - var extra uint64 - var err error - _ = maj - _ = extra - _ = err - - { - - maj, extra, err = cr.ReadHeader() - if err != nil { - return err - } - if maj != cbg.MajUnsignedInt { - return fmt.Errorf("wrong type for uint64 field") - } - t.BlsIncludes[i][j] = uint64(extra) + { - } + if err := t.BlsIncludes[i].UnmarshalCBOR(cr); err != nil { + return xerrors.Errorf("unmarshaling t.BlsIncludes[i]: %w", err) } - } + } } } @@ -550,7 +494,7 @@ func (t *CompactedMessages) UnmarshalCBOR(r io.Reader) (err error) { return err } - if extra > cbg.MaxLength { + if extra > 150000 { return fmt.Errorf("t.Secpk: array too large (%d)", extra) } @@ -591,7 +535,7 @@ func (t *CompactedMessages) UnmarshalCBOR(r io.Reader) (err error) { } } - // t.SecpkIncludes ([][]uint64) (slice) + // t.SecpkIncludes ([]exchange.messageIndices) (slice) maj, extra, err = cr.ReadHeader() if err != nil { @@ -607,7 +551,7 @@ func (t *CompactedMessages) UnmarshalCBOR(r io.Reader) (err error) { } if extra > 0 { - t.SecpkIncludes = make([][]uint64, extra) + t.SecpkIncludes = make([]messageIndices, extra) } for i := 0; i < int(extra); i++ { @@ -619,47 +563,13 @@ func (t *CompactedMessages) UnmarshalCBOR(r io.Reader) (err error) { _ = extra _ = err - maj, extra, err = cr.ReadHeader() - if err != nil { - return err - } - - if extra > cbg.MaxLength { - return fmt.Errorf("t.SecpkIncludes[i]: array too large (%d)", extra) - } - - if maj != cbg.MajArray { - return fmt.Errorf("expected cbor array") - } - - if extra > 0 { - t.SecpkIncludes[i] = make([]uint64, extra) - } - - for j := 0; j < int(extra); j++ { - { - var maj byte - var extra uint64 - var err error - _ = maj - _ = extra - _ = err - - { - - maj, extra, err = cr.ReadHeader() - if err != nil { - return err - } - if maj != cbg.MajUnsignedInt { - return fmt.Errorf("wrong type for uint64 field") - } - t.SecpkIncludes[i][j] = uint64(extra) + { - } + if err := t.SecpkIncludes[i].UnmarshalCBOR(cr); err != nil { + return xerrors.Errorf("unmarshaling t.SecpkIncludes[i]: %w", err) } - } + } } } diff --git a/chain/exchange/client.go b/chain/exchange/client.go index db39628be69..fca8249cef9 100644 --- a/chain/exchange/client.go +++ b/chain/exchange/client.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "fmt" + "io" "math/rand" "time" @@ -23,6 +24,10 @@ import ( "github.com/filecoin-project/lotus/lib/peermgr" ) +// Set the max exchange message size to 120MiB. Purely based on gas numbers, we can include ~8MiB of +// messages per block, so I've set this to 120MiB to be _very_ safe. +const maxExchangeMessageSize = (15 * 8) << 20 + // client implements exchange.Client, using the libp2p ChainExchange protocol // as the fetching mechanism. type client struct { @@ -434,10 +439,11 @@ func (c *client) sendRequestToPeer(ctx context.Context, peer peer.ID, req *Reque log.Warnw("CloseWrite err", "error", err) } - // Read response. + // Read response, limiting the size of the response to maxExchangeMessageSize as we allow a + // lot of messages (10k+) but they'll mostly be quite small. var res Response err = cborutil.ReadCborRPC( - bufio.NewReader(incrt.New(stream, ReadResMinSpeed, ReadResDeadline)), + bufio.NewReader(io.LimitReader(incrt.New(stream, ReadResMinSpeed, ReadResDeadline), maxExchangeMessageSize)), &res) if err != nil { c.peerTracker.logFailure(peer, build.Clock.Since(connectionStart), req.Length) diff --git a/chain/exchange/protocol.go b/chain/exchange/protocol.go index 5e12d31cc29..cd25f4a4350 100644 --- a/chain/exchange/protocol.go +++ b/chain/exchange/protocol.go @@ -154,6 +154,8 @@ type BSTipSet struct { // FIXME: The logic to decompress this structure should belong // // to itself, not to the consumer. +// +// NOTE: Max messages is: BlockMessageLimit (10k) * MaxTipsetSize (15) = 150k type CompactedMessages struct { Bls []*types.Message BlsIncludes [][]uint64 diff --git a/chain/exchange/protocol_encoding.go b/chain/exchange/protocol_encoding.go new file mode 100644 index 00000000000..7df00a639f9 --- /dev/null +++ b/chain/exchange/protocol_encoding.go @@ -0,0 +1,125 @@ +package exchange + +import ( + "fmt" + "io" + + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" + + "github.com/filecoin-project/lotus/build" + types "github.com/filecoin-project/lotus/chain/types" +) + +// Type used for encoding/decoding compacted messages. This is a ustom type as we need custom limits. +// - Max messages is 150,000 as that's 15 times the max block size (in messages). It needs to be +// large enough to cover a full tipset full of full blocks. +type CompactedMessagesCBOR struct { + Bls []*types.Message `cborgen:"maxlen=150000"` + BlsIncludes []messageIndices + + Secpk []*types.SignedMessage `cborgen:"maxlen=150000"` + SecpkIncludes []messageIndices +} + +// Unmarshal into the "decoding" struct, then copy into the actual struct. +func (t *CompactedMessages) UnmarshalCBOR(r io.Reader) (err error) { + var c CompactedMessagesCBOR + if err := c.UnmarshalCBOR(r); err != nil { + return err + } + t.Bls = c.Bls + t.BlsIncludes = make([][]uint64, len(c.BlsIncludes)) + for i, v := range c.BlsIncludes { + t.BlsIncludes[i] = v.v + } + t.Secpk = c.Secpk + t.SecpkIncludes = make([][]uint64, len(c.SecpkIncludes)) + for i, v := range c.SecpkIncludes { + t.SecpkIncludes[i] = v.v + } + return nil +} + +// Copy into the encoding struct, then marshal. +func (t *CompactedMessages) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + + var c CompactedMessagesCBOR + c.Bls = t.Bls + c.BlsIncludes = make([]messageIndices, len(t.BlsIncludes)) + for i, v := range t.BlsIncludes { + c.BlsIncludes[i].v = v + } + c.Secpk = t.Secpk + c.SecpkIncludes = make([]messageIndices, len(t.SecpkIncludes)) + for i, v := range t.SecpkIncludes { + c.SecpkIncludes[i].v = v + } + return c.MarshalCBOR(w) +} + +// this needs to be a struct or cborgen will peak into it and ignore the Unmarshal/Marshal functions +type messageIndices struct { + v []uint64 +} + +func (t *messageIndices) UnmarshalCBOR(r io.Reader) (err error) { + cr := cbg.NewCborReader(r) + + maj, extra, err := cr.ReadHeader() + if err != nil { + return err + } + + if maj != cbg.MajArray { + return fmt.Errorf("cbor input should be of type array") + } + + if extra > build.BlockMessageLimit { + return fmt.Errorf("cbor input had wrong number of fields") + } + + if extra > 0 { + t.v = make([]uint64, extra) + } + + for i := 0; i < int(extra); i++ { + maj, extra, err := cr.ReadHeader() + if err != nil { + return err + } + if maj != cbg.MajUnsignedInt { + return fmt.Errorf("wrong type for uint64 field") + } + t.v[i] = uint64(extra) + + } + return nil +} + +func (t *messageIndices) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + + cw := cbg.NewCborWriter(w) + + if len(t.v) > build.BlockMessageLimit { + return xerrors.Errorf("Slice value in field v was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajArray, uint64(len(t.v))); err != nil { + return err + } + for _, v := range t.v { + if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(v)); err != nil { + return err + } + } + return nil +} diff --git a/gen/main.go b/gen/main.go index 0cd3999c38a..942b3ac2c05 100644 --- a/gen/main.go +++ b/gen/main.go @@ -92,7 +92,7 @@ func main() { err = gen.WriteTupleEncodersToFile("./chain/exchange/cbor_gen.go", "exchange", exchange.Request{}, exchange.Response{}, - exchange.CompactedMessages{}, + exchange.CompactedMessagesCBOR{}, exchange.BSTipSet{}, ) if err != nil {