diff --git a/message/v1/message.go b/message/v1/message.go index e8b458a4..c3b726cf 100644 --- a/message/v1/message.go +++ b/message/v1/message.go @@ -9,6 +9,7 @@ import ( blocks "github.com/ipfs/go-block-format" "github.com/ipfs/go-cid" "github.com/ipld/go-ipld-prime/datamodel" + "github.com/ipld/go-ipld-prime/node/basicnode" pool "github.com/libp2p/go-buffer-pool" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" @@ -245,7 +246,6 @@ func (mh *MessageHandler) newMessageFromProto(p peer.ID, pbm *pb.Message_V1_0_0) return message.NewMessage(requests, responses, blks), nil } -// TODO: is this a breaking protocol change? force all extension data into dag-cbor? func toEncodedExtensions(part MessagePartWithExtensions) (map[string][]byte, error) { names := part.ExtensionNames() out := make(map[string][]byte, len(names)) @@ -264,19 +264,22 @@ func toEncodedExtensions(part MessagePartWithExtensions) (map[string][]byte, err return out, nil } -// TODO: is this a breaking protocol change? force all extension data into dag-cbor? func fromEncodedExtensions(in map[string][]byte) ([]graphsync.ExtensionData, error) { if in == nil { return []graphsync.ExtensionData{}, nil } out := make([]graphsync.ExtensionData, 0, len(in)) - for name, data := range in { - if len(data) == 0 { + for name, byts := range in { + if len(byts) == 0 { out = append(out, graphsync.ExtensionData{Name: graphsync.ExtensionName(name), Data: nil}) } else { - data, err := ipldutil.DecodeNode(data) + data, err := ipldutil.DecodeNode(byts) if err != nil { - return nil, err + // Backward-compatibility for extensions that may not be encoding data as + // DAG-CBOR: if we don't find valid DAG-CBOR, we just turn the whole thing + // into a Bytes node and let the extension handler deal with that, they + // can decode it however they want. + data = basicnode.NewBytes(byts) } out = append(out, graphsync.ExtensionData{Name: graphsync.ExtensionName(name), Data: data}) } diff --git a/message/v1/message_test.go b/message/v1/message_test.go index 1655d415..e71be996 100644 --- a/message/v1/message_test.go +++ b/message/v1/message_test.go @@ -445,3 +445,63 @@ func TestKnownFuzzIssues(t *testing.T) { require.Equal(t, msg1, msg2) } } + +func TestRequestExtensionPlainBytes(t *testing.T) { + originalDataNode := basicnode.NewString(string([]byte{0xca, 0xfe, 0xbe, 0xef})) + expectedDataNode := basicnode.NewBytes([]byte{0xbe, 0xef, 0xca, 0xfe, 0x00}) + + id := graphsync.NewRequestID() + extensionName := graphsync.ExtensionName("graphsync/plainbytes") + extension := graphsync.ExtensionData{ + Name: extensionName, + Data: originalDataNode, + } + + builder := message.NewBuilder() + builder.AddRequest(message.NewUpdateRequest(id, extension)) + gsm, err := builder.Build() + require.NoError(t, err) + + requests := gsm.Requests() + require.Len(t, requests, 1, "did not add cancel request") + request := requests[0] + require.Equal(t, id, request.ID()) + require.True(t, request.IsUpdate()) + require.False(t, request.IsCancel()) + extensionData, found := request.Extension(extensionName) + require.True(t, found) + require.Equal(t, extension.Data, extensionData) + + mh := NewMessageHandler() + + buf := new(bytes.Buffer) + err = mh.ToNet(peer.ID("foo"), gsm, buf) + require.NoError(t, err, "did not serialize protobuf message") + + // We've now captured a legitimate message with nicely encoded extension data + // in proper DAG-CBOR, as per the (new) expectation that extension data will + // always be DAG-CBOR. + // But since we want to also handle the possibility of arbirary bytes for + // extensions that may already exist that are using arbirary bytes and their + // own encoding formats, we'll strip out the CBOR tag and rearrange the bytes + // a little and confirm that it comes back out as a Bytes datamodel.Node + // (i.e. leaving it up to the extension author to extract their bytes out + // of that and still use their encoding format). + newmsg := bytes.Replace(buf.Bytes(), []byte{0x64, 0xca, 0xfe, 0xbe, 0xef}, []byte{0xbe, 0xef, 0xca, 0xfe, 0x00}, 1) + + deserialized, err := mh.FromNet(peer.ID("foo"), bytes.NewReader(newmsg)) + require.NoError(t, err, "did not deserialize protobuf message") + + deserializedRequests := deserialized.Requests() + require.Len(t, deserializedRequests, 1, "did not add request to deserialized message") + deserializedRequest := deserializedRequests[0] + extensionData, found = deserializedRequest.Extension(extensionName) + require.Equal(t, request.ID(), deserializedRequest.ID()) + require.Equal(t, request.IsCancel(), deserializedRequest.IsCancel()) + require.Equal(t, request.IsUpdate(), deserializedRequest.IsUpdate()) + require.Equal(t, request.Priority(), deserializedRequest.Priority()) + require.Equal(t, request.Root().String(), deserializedRequest.Root().String()) + require.Equal(t, request.Selector(), deserializedRequest.Selector()) + require.True(t, found) + require.Equal(t, expectedDataNode, extensionData) +}