Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

record+routing/route: add AMP record #3957

Merged
merged 3 commits into from
Feb 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions record/amp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package record

import (
"fmt"
"io"

"github.com/lightningnetwork/lnd/tlv"
)

// AMPOnionType is the type used in the onion to reference the AMP fields:
// root_share, set_id, and child_index.
const AMPOnionType tlv.Type = 10

// AMP is a record that encodes the fields necessary for atomic multi-path
// payments.
type AMP struct {
rootShare [32]byte
setID [32]byte
childIndex uint16
}

// NewAMP generate a new AMP record with the given root_share, set_id, and
// child_index.
func NewAMP(rootShare, setID [32]byte, childIndex uint16) *AMP {
return &AMP{
rootShare: rootShare,
setID: setID,
childIndex: childIndex,
}
}

// RootShare returns the root share contained in the AMP record.
func (a *AMP) RootShare() [32]byte {
return a.rootShare
}

// SetID returns the set id contained in the AMP record.
func (a *AMP) SetID() [32]byte {
return a.setID
}

// ChildIndex returns the child index contained in the AMP record.
func (a *AMP) ChildIndex() uint16 {
return a.childIndex
}

// AMPEncoder writes the AMP record to the provided io.Writer.
func AMPEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*AMP); ok {
if err := tlv.EBytes32(w, &v.rootShare, buf); err != nil {
return err
}

if err := tlv.EBytes32(w, &v.setID, buf); err != nil {
return err
}

return tlv.ETUint16T(w, v.childIndex, buf)
}
return tlv.NewTypeForEncodingErr(val, "AMP")
}

const (
// minAMPLength is the minimum length of a serialized AMP TLV record,
// which occurs when the truncated encoding of child_index takes 0
// bytes, leaving only the root_share and set_id.
minAMPLength = 64

// maxAMPLength is the maximum legnth of a serialized AMP TLV record,
// which occurs when the truncated endoing of a child_index takes 2
// bytes.
maxAMPLength = 66
)

// AMPDecoder reads the AMP record from the provided io.Reader.
func AMPDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*AMP); ok && minAMPLength <= l && l <= maxAMPLength {
if err := tlv.DBytes32(r, &v.rootShare, buf, 32); err != nil {
return err
}

if err := tlv.DBytes32(r, &v.setID, buf, 32); err != nil {
return err
}

return tlv.DTUint16(r, &v.childIndex, buf, l-64)
}
return tlv.NewTypeForDecodingErr(val, "AMP", l, maxAMPLength)
}

// Record returns a tlv.Record that can be used to encode or decode this record.
func (a *AMP) Record() tlv.Record {
return tlv.MakeDynamicRecord(
AMPOnionType, a, a.PayloadSize, AMPEncoder, AMPDecoder,
)
}

// PayloadSize returns the size this record takes up in encoded form.
func (a *AMP) PayloadSize() uint64 {
return 32 + 32 + tlv.SizeTUint16(a.childIndex)
}

// String returns a human-readble description of the amp payload fields.
func (a *AMP) String() string {
return fmt.Sprintf("root_share=%x set_id=%x child_index=%d",
cfromknecht marked this conversation as resolved.
Show resolved Hide resolved
a.rootShare, a.setID, a.childIndex)
}
30 changes: 28 additions & 2 deletions record/record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ type recordEncDecTest struct {
}

var (
testTotal = lnwire.MilliSatoshi(45)
testAddr = [32]byte{0x01, 0x02}
testTotal = lnwire.MilliSatoshi(45)
testAddr = [32]byte{0x01, 0x02}
testShare = [32]byte{0x03, 0x04}
testSetID = [32]byte{0x05, 0x06}
testChildIndex = uint16(17)
)

var recordEncDecTests = []recordEncDecTest{
Expand All @@ -40,6 +43,29 @@ var recordEncDecTests = []recordEncDecTest{
}
},
},
{
name: "amp",
encRecord: func() tlv.RecordProducer {
return record.NewAMP(
testShare, testSetID, testChildIndex,
)
},
decRecord: func() tlv.RecordProducer {
return new(record.AMP)
},
assert: func(t *testing.T, r interface{}) {
amp := r.(*record.AMP)
if amp.RootShare() != testShare {
t.Fatal("incorrect root share")
}
if amp.SetID() != testSetID {
t.Fatal("incorrect set id")
}
if amp.ChildIndex() != testChildIndex {
t.Fatal("incorrect child index")
}
},
},
}

// TestRecordEncodeDecode is a generic test framework for custom TLV records. It
Expand Down
25 changes: 25 additions & 0 deletions routing/route/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ var (
// record to an intermediate hop, only final hops can receive MPP
// records.
ErrIntermediateMPPHop = errors.New("cannot send MPP to intermediate")

// ErrAMPMissingMPP is returned when the caller tries to attach an AMP
// record but no MPP record is presented for the final hop.
ErrAMPMissingMPP = errors.New("cannot send AMP without MPP record")
)

// Vertex is a simple alias for the serialization of a compressed Bitcoin
Expand Down Expand Up @@ -111,6 +115,10 @@ type Hop struct {
// only be set for the final hop.
MPP *record.MPP

// AMP encapsulates the data required for option_amp. This field should
// only be set for the final hop.
AMP *record.AMP

// CustomRecords if non-nil are a set of additional TLV records that
// should be included in the forwarding instructions for this node.
CustomRecords record.CustomSet
Expand Down Expand Up @@ -168,6 +176,18 @@ func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error {
}
}

// If an AMP record is destined for this hop, ensure that we only ever
cfromknecht marked this conversation as resolved.
Show resolved Hide resolved
// attach it if we also have an MPP record. We can infer that this is
// already a final hop if MPP is non-nil otherwise we would have exited
// above.
if h.AMP != nil {
if h.MPP != nil {
records = append(records, h.AMP.Record())
} else {
return ErrAMPMissingMPP
}
}

// Append any custom types destined for this hop.
tlvRecords := tlv.MapToRecords(h.CustomRecords)
records = append(records, tlvRecords...)
Expand Down Expand Up @@ -217,6 +237,11 @@ func (h *Hop) PayloadSize(nextChanID uint64) uint64 {
addRecord(record.MPPOnionType, h.MPP.PayloadSize())
}

// Add amp if present.
if h.AMP != nil {
addRecord(record.AMPOnionType, h.AMP.PayloadSize())
}

// Add custom records.
for k, v := range h.CustomRecords {
addRecord(tlv.Type(k), uint64(len(v)))
Expand Down
46 changes: 44 additions & 2 deletions routing/route/route_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ var (
testAddr = [32]byte{0x01, 0x02}
)

// TestMPPHop asserts that a Hop will encode a non-nil to final nodes, and fail
// when trying to send to intermediaries.
// TestMPPHop asserts that a Hop will encode a non-nil MPP to final nodes, and
// fail when trying to send to intermediaries.
func TestMPPHop(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -101,6 +101,47 @@ func TestMPPHop(t *testing.T) {
}
}

// TestAMPHop asserts that a Hop will encode a non-nil AMP to final nodes of an
// MPP record is also present, and fail otherwise.
func TestAMPHop(t *testing.T) {
t.Parallel()

hop := Hop{
ChannelID: 1,
OutgoingTimeLock: 44,
AmtToForward: testAmt,
LegacyPayload: false,
AMP: record.NewAMP([32]byte{}, [32]byte{}, 3),
}

// Encoding an AMP record to an intermediate hop w/o an MPP record
// should result in a failure.
var b bytes.Buffer
err := hop.PackHopPayload(&b, 2)
if err != ErrAMPMissingMPP {
t.Fatalf("expected err: %v, got: %v",
ErrAMPMissingMPP, err)
}

// Encoding an AMP record to a final hop w/o an MPP record should result
// in a failure.
b.Reset()
err = hop.PackHopPayload(&b, 0)
if err != ErrAMPMissingMPP {
t.Fatalf("expected err: %v, got: %v",
ErrAMPMissingMPP, err)
}

// Encoding an AMP record to a final hop w/ an MPP record should be
// successful.
hop.MPP = record.NewMPP(testAmt, testAddr)
b.Reset()
err = hop.PackHopPayload(&b, 0)
if err != nil {
t.Fatalf("expected err: %v, got: %v", nil, err)
}
}

// TestPayloadSize tests the payload size calculation that is provided by Hop
// structs.
func TestPayloadSize(t *testing.T) {
Expand All @@ -123,6 +164,7 @@ func TestPayloadSize(t *testing.T) {
AmtToForward: 1200,
OutgoingTimeLock: 700000,
MPP: record.NewMPP(500, [32]byte{}),
AMP: record.NewAMP([32]byte{}, [32]byte{}, 8),
CustomRecords: map[uint64][]byte{
100000: {1, 2, 3},
1000000: {4, 5},
Expand Down