Skip to content
This repository has been archived by the owner on Aug 2, 2021. It is now read-only.

Commit

Permalink
p2p/protocols, p2p/testing; conditional propagagation of context
Browse files Browse the repository at this point in the history
  • Loading branch information
zelig committed Aug 14, 2019
1 parent 3bbdca3 commit b8b6b9a
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 117 deletions.
79 changes: 79 additions & 0 deletions p2p/protocols/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package protocols

import (
"bufio"
"bytes"
"context"
"io/ioutil"

"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethersphere/swarm/spancontext"
opentracing "github.com/opentracing/opentracing-go"
)

// msgWithContext is used to propagate marshalled context alongside message payloads
type msgWithContext struct {
Context []byte
Msg []byte
}

func encodeWithContext(ctx context.Context, msg interface{}) (interface{}, int, error) {
var b bytes.Buffer
writer := bufio.NewWriter(&b)
tracer := opentracing.GlobalTracer()
sctx := spancontext.FromContext(ctx)
if sctx != nil {
err := tracer.Inject(
sctx,
opentracing.Binary,
writer)
if err != nil {
return nil, 0, err
}
}
writer.Flush()
msgBytes, err := rlp.EncodeToBytes(msg)
if err != nil {
return nil, 0, err
}

return &msgWithContext{
Context: b.Bytes(),
Msg: msgBytes,
}, len(msgBytes), nil
}

func decodeWithContext(msg p2p.Msg) (context.Context, []byte, error) {
var wmsg msgWithContext
err := msg.Decode(&wmsg)
if err != nil {
return nil, nil, err
}

ctx := context.Background()

if len(wmsg.Context) == 0 {
return ctx, wmsg.Msg, nil
}

tracer := opentracing.GlobalTracer()
sctx, err := tracer.Extract(opentracing.Binary, bytes.NewReader(wmsg.Context))
if err != nil {
return nil, nil, err
}
ctx = spancontext.WithContext(ctx, sctx)
return ctx, wmsg.Msg, nil
}

func encodeWithoutContext(ctx context.Context, msg interface{}) (interface{}, int, error) {
return msg, 0, nil
}

func decodeWithoutContext(msg p2p.Msg) (context.Context, []byte, error) {
b, err := ioutil.ReadAll(msg.Payload)
if err != nil {
return nil, nil, err
}
return context.Background(), b, nil
}
118 changes: 40 additions & 78 deletions p2p/protocols/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ devp2p subprotocols by abstracting away code standardly shared by protocols.
package protocols

import (
"bufio"
"bytes"
"context"
"fmt"
"io"
Expand All @@ -42,9 +40,7 @@ import (
"github.com/ethereum/go-ethereum/metrics"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethersphere/swarm/spancontext"
"github.com/ethersphere/swarm/tracing"
opentracing "github.com/opentracing/opentracing-go"
)

// error codes used by this protocol scheme
Expand Down Expand Up @@ -115,13 +111,6 @@ func errorf(code int, format string, params ...interface{}) *Error {
}
}

// WrappedMsg is used to propagate marshalled context alongside message payloads
type WrappedMsg struct {
Context []byte
Size uint32
Payload []byte
}

//For accounting, the design is to allow the Spec to describe which and how its messages are priced
//To access this functionality, we provide a Hook interface which will call accounting methods
//NOTE: there could be more such (horizontal) hooks in the future
Expand Down Expand Up @@ -157,6 +146,10 @@ type Spec struct {
initOnce sync.Once
codes map[reflect.Type]uint64
types map[uint64]reflect.Type

// if the protocol allows for extending the p2p msg to propagate context
// even if set to true context will propagate only if the remote peer supports it
DisableContext bool
}

func (s *Spec) init() {
Expand Down Expand Up @@ -208,17 +201,27 @@ type Peer struct {
*p2p.Peer // the p2p.Peer object representing the remote
rw p2p.MsgReadWriter // p2p.MsgReadWriter to send messages to and read messages from
spec *Spec
encode func(context.Context, interface{}) (interface{}, int, error)
decode func(p2p.Msg) (context.Context, []byte, error)
}

// NewPeer constructs a new peer
// this constructor is called by the p2p.Protocol#Run function
// the first two arguments are the arguments passed to p2p.Protocol.Run function
// the third argument is the Spec describing the protocol
func NewPeer(p *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer {
func NewPeer(peer *p2p.Peer, rw p2p.MsgReadWriter, spec *Spec) *Peer {
encode := encodeWithContext
decode := decodeWithContext
if spec.DisableContext || !tracing.Enabled {
encode = encodeWithoutContext
decode = decodeWithoutContext
}
return &Peer{
Peer: p,
rw: rw,
spec: spec,
Peer: peer,
rw: rw,
spec: spec,
encode: encode,
decode: decode,
}
}

Expand All @@ -234,7 +237,6 @@ func (p *Peer) Run(handler func(ctx context.Context, msg interface{}) error) err
metrics.GetOrRegisterCounter("peer.handleincoming.error", nil).Inc(1)
log.Error("peer.handleIncoming", "err", err)
}

return err
}
}
Expand All @@ -256,51 +258,32 @@ func (p *Peer) Send(ctx context.Context, msg interface{}) error {
metrics.GetOrRegisterCounter("peer.send", nil).Inc(1)
metrics.GetOrRegisterCounter(fmt.Sprintf("peer.send.%T", msg), nil).Inc(1)

var b bytes.Buffer
if tracing.Enabled {
writer := bufio.NewWriter(&b)

tracer := opentracing.GlobalTracer()

sctx := spancontext.FromContext(ctx)

if sctx != nil {
err := tracer.Inject(
sctx,
opentracing.Binary,
writer)
if err != nil {
return err
}
}

writer.Flush()
code, found := p.spec.GetCode(msg)
if !found {
return errorf(ErrInvalidMsgType, "%v", code)
}

r, err := rlp.EncodeToBytes(msg)
wmsg, size, err := p.encode(ctx, msg)
if err != nil {
return err
}

wmsg := WrappedMsg{
Context: b.Bytes(),
Size: uint32(len(r)),
Payload: r,
// if size is not set by the wrapper, need to serialise
if size == 0 {
r, err := rlp.EncodeToBytes(msg)
if err != nil {
return err
}
size = len(r)
}

//if the accounting hook is set, call it
// if the accounting hook is set, call it
if p.spec.Hook != nil {
err := p.spec.Hook.Send(p, wmsg.Size, msg)
err = p.spec.Hook.Send(p, uint32(size), msg)
if err != nil {
p.Drop()
return err
}
}

code, found := p.spec.GetCode(msg)
if !found {
return errorf(ErrInvalidMsgType, "%v", code)
}
return p2p.Send(p.rw, code, wmsg)
}

Expand All @@ -324,44 +307,23 @@ func (p *Peer) handleIncoming(handle func(ctx context.Context, msg interface{})
return errorf(ErrMsgTooLong, "%v > %v", msg.Size, p.spec.MaxMsgSize)
}

// unmarshal wrapped msg, which might contain context
var wmsg WrappedMsg
err = msg.Decode(&wmsg)
if err != nil {
log.Error(err.Error())
return err
}

ctx := context.Background()

// if tracing is enabled and the context coming within the request is
// not empty, try to unmarshal it
if tracing.Enabled && len(wmsg.Context) > 0 {
var sctx opentracing.SpanContext

tracer := opentracing.GlobalTracer()
sctx, err = tracer.Extract(
opentracing.Binary,
bytes.NewReader(wmsg.Context))
if err != nil {
log.Error(err.Error())
return err
}

ctx = spancontext.WithContext(ctx, sctx)
}

val, ok := p.spec.NewMsg(msg.Code)
if !ok {
return errorf(ErrInvalidMsgCode, "%v", msg.Code)
}
if err := rlp.DecodeBytes(wmsg.Payload, val); err != nil {

ctx, msgBytes, err := p.decode(msg)
if err != nil {
return errorf(ErrDecode, "%v err=%v", msg.Code, err)
}

if err := rlp.DecodeBytes(msgBytes, val); err != nil {
return errorf(ErrDecode, "<= %v: %v", msg, err)
}

//if the accounting hook is set, call it
// if the accounting hook is set, call it
if p.spec.Hook != nil {
err := p.spec.Hook.Receive(p, wmsg.Size, val)
err := p.spec.Hook.Receive(p, uint32(len(msgBytes)), val)
if err != nil {
return err
}
Expand Down
33 changes: 12 additions & 21 deletions p2p/protocols/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/enode"
"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethersphere/swarm/log"
p2ptest "github.com/ethersphere/swarm/p2p/testing"
)

Expand Down Expand Up @@ -249,9 +250,7 @@ func TestProtocolHook(t *testing.T) {
runFunc := func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
peer := NewPeer(p, rw, spec)
ctx := context.TODO()
err := peer.Send(ctx, &dummyMsg{
Content: "handshake"})

err := peer.Send(ctx, &dummyMsg{Content: "handshake"})
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -281,7 +280,9 @@ func TestProtocolHook(t *testing.T) {
if err != nil {
t.Fatal(err)
}

testHook.mu.Lock()
log.Warn("SENT msg:", "msg", testHook.msg)
if testHook.msg == nil || testHook.msg.(*dummyMsg).Content != "handshake" {
t.Fatal("Expected msg to be set, but it is not")
}
Expand All @@ -291,8 +292,8 @@ func TestProtocolHook(t *testing.T) {
if testHook.peer == nil {
t.Fatal("Expected peer to be set, is nil")
}
if peerId := testHook.peer.ID(); peerId != tester.Nodes[0].ID() && peerId != tester.Nodes[1].ID() {
t.Fatalf("Expected peer ID to be set correctly, but it is not (got %v, exp %v or %v", peerId, tester.Nodes[0].ID(), tester.Nodes[1].ID())
if peerID := testHook.peer.ID(); peerID != tester.Nodes[0].ID() && peerID != tester.Nodes[1].ID() {
t.Fatalf("Expected peer ID to be set correctly, but it is not (got %v, exp %v or %v", peerID, tester.Nodes[0].ID(), tester.Nodes[1].ID())
}
if testHook.size != 11 { //11 is the length of the encoded message
t.Fatalf("Expected size to be %d, but it is %d ", 1, testHook.size)
Expand All @@ -309,11 +310,10 @@ func TestProtocolHook(t *testing.T) {
},
})

<-testHook.waitC

if err != nil {
t.Fatal(err)
}
<-testHook.waitC

testHook.mu.Lock()
if testHook.msg == nil || testHook.msg.(*dummyMsg).Content != "response" {
Expand Down Expand Up @@ -600,24 +600,15 @@ func (d *dummyRW) WriteMsg(msg p2p.Msg) error {
}

func (d *dummyRW) ReadMsg() (p2p.Msg, error) {
enc := bytes.NewReader(d.getDummyMsg())
r, err := rlp.EncodeToBytes(d.msg)
if err != nil {
return p2p.Msg{}, err
}
enc := bytes.NewReader(r)
return p2p.Msg{
Code: d.code,
Size: d.size,
Payload: enc,
ReceivedAt: time.Now(),
}, nil
}

func (d *dummyRW) getDummyMsg() []byte {
r, _ := rlp.EncodeToBytes(d.msg)
var b bytes.Buffer
wmsg := WrappedMsg{
Context: b.Bytes(),
Size: uint32(len(r)),
Payload: r,
}
rr, _ := rlp.EncodeToBytes(wmsg)
d.size = uint32(len(rr))
return rr
}
Loading

0 comments on commit b8b6b9a

Please sign in to comment.