-
Notifications
You must be signed in to change notification settings - Fork 2.2k
AMP support for SendPaymentV2 #5159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7795353
a9f19b1
6474b25
41ae353
06f045f
e1399fb
c1e82e5
2d397b1
5531b81
6104d12
f07c9d0
56a2c65
0b9137c
8f57dcf
c4fc72d
a2a61a1
13c0012
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,165 @@ | ||
| package amp | ||
|
|
||
| import ( | ||
| "crypto/rand" | ||
| "encoding/binary" | ||
| "fmt" | ||
| "sync" | ||
|
|
||
| "github.com/lightningnetwork/lnd/lntypes" | ||
| "github.com/lightningnetwork/lnd/lnwire" | ||
| "github.com/lightningnetwork/lnd/record" | ||
| "github.com/lightningnetwork/lnd/routing/shards" | ||
| ) | ||
|
|
||
| // Shard is an implementation of the shards.PaymentShards interface specific | ||
| // to AMP payments. | ||
| type Shard struct { | ||
| child *Child | ||
| mpp *record.MPP | ||
| amp *record.AMP | ||
| } | ||
|
|
||
| // A compile time check to ensure Shard implements the shards.PaymentShard | ||
| // interface. | ||
| var _ shards.PaymentShard = (*Shard)(nil) | ||
|
|
||
| // Hash returns the hash used for the HTLC representing this AMP shard. | ||
| func (s *Shard) Hash() lntypes.Hash { | ||
| return s.child.Hash | ||
| } | ||
|
|
||
| // MPP returns any extra MPP records that should be set for the final hop on | ||
| // the route used by this shard. | ||
| func (s *Shard) MPP() *record.MPP { | ||
| return s.mpp | ||
| } | ||
|
|
||
| // AMP returns any extra AMP records that should be set for the final hop on | ||
| // the route used by this shard. | ||
| func (s *Shard) AMP() *record.AMP { | ||
| return s.amp | ||
| } | ||
|
|
||
| // ShardTracker is an implementation of the shards.ShardTracker interface | ||
| // that is able to generate payment shards according to the AMP splitting | ||
| // algorithm. It can be used to generate new hashes to use for HTLCs, and also | ||
| // cancel shares used for failed payment shards. | ||
| type ShardTracker struct { | ||
| setID [32]byte | ||
| paymentAddr [32]byte | ||
| totalAmt lnwire.MilliSatoshi | ||
|
|
||
| sharer Sharer | ||
|
|
||
| shards map[uint64]*Child | ||
| sync.Mutex | ||
| } | ||
|
|
||
| // A compile time check to ensure ShardTracker implements the | ||
| // shards.ShardTracker interface. | ||
| var _ shards.ShardTracker = (*ShardTracker)(nil) | ||
|
|
||
| // NewShardTracker creates a new shard tracker to use for AMP payments. The | ||
| // root shard, setID, payment address and total amount must be correctly set in | ||
| // order for the TLV options to include with each shard to be created | ||
| // correctly. | ||
| func NewShardTracker(root, setID, payAddr [32]byte, | ||
| totalAmt lnwire.MilliSatoshi) *ShardTracker { | ||
|
|
||
| // Create a new seed sharer from this root. | ||
| rootShare := Share(root) | ||
| rootSharer := SeedSharerFromRoot(&rootShare) | ||
|
|
||
| return &ShardTracker{ | ||
| setID: setID, | ||
| paymentAddr: payAddr, | ||
| totalAmt: totalAmt, | ||
| sharer: rootSharer, | ||
| shards: make(map[uint64]*Child), | ||
| } | ||
| } | ||
|
|
||
| // NewShard registers a new attempt with the ShardTracker and returns a | ||
| // new shard representing this attempt. This attempt's shard should be canceled | ||
| // if it ends up not being used by the overall payment, i.e. if the attempt | ||
| // fails. | ||
| func (s *ShardTracker) NewShard(pid uint64, last bool) (shards.PaymentShard, | ||
| error) { | ||
|
|
||
| s.Lock() | ||
| defer s.Unlock() | ||
|
|
||
| // Use a random child index. | ||
| var childIndex [4]byte | ||
|
||
| if _, err := rand.Read(childIndex[:]); err != nil { | ||
| return nil, err | ||
| } | ||
| idx := binary.BigEndian.Uint32(childIndex[:]) | ||
|
|
||
| // Depending on whether we are requesting the last shard or not, either | ||
| // split the current share into two, or get a Child directly from the | ||
| // current sharer. | ||
| var child *Child | ||
| if last { | ||
Roasbeef marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| child = s.sharer.Child(idx) | ||
|
|
||
| // If this was the last shard, set the current share to the | ||
| // zero share to indicate we cannot split it further. | ||
| s.sharer = s.sharer.Zero() | ||
| } else { | ||
Roasbeef marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| left, sharer, err := s.sharer.Split() | ||
|
||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| s.sharer = sharer | ||
| child = left.Child(idx) | ||
| } | ||
|
|
||
| // Track the new child and return the shard. | ||
| s.shards[pid] = child | ||
|
|
||
| mpp := record.NewMPP(s.totalAmt, s.paymentAddr) | ||
| amp := record.NewAMP( | ||
|
||
| child.ChildDesc.Share, s.setID, child.ChildDesc.Index, | ||
| ) | ||
|
|
||
| return &Shard{ | ||
| child: child, | ||
| mpp: mpp, | ||
| amp: amp, | ||
| }, nil | ||
| } | ||
|
|
||
| // CancelShard cancel's the shard corresponding to the given attempt ID. | ||
| func (s *ShardTracker) CancelShard(pid uint64) error { | ||
| s.Lock() | ||
| defer s.Unlock() | ||
|
|
||
| c, ok := s.shards[pid] | ||
| if !ok { | ||
| return fmt.Errorf("pid not found") | ||
| } | ||
| delete(s.shards, pid) | ||
|
|
||
| // Now that we are canceling this shard, we XOR the share back into our | ||
| // current share. | ||
| s.sharer = s.sharer.Merge(c) | ||
| return nil | ||
| } | ||
|
|
||
| // GetHash retrieves the hash used by the shard of the given attempt ID. This | ||
| // will return an error if the attempt ID is unknown. | ||
| func (s *ShardTracker) GetHash(pid uint64) (lntypes.Hash, error) { | ||
| s.Lock() | ||
| defer s.Unlock() | ||
|
|
||
| c, ok := s.shards[pid] | ||
| if !ok { | ||
| return lntypes.Hash{}, fmt.Errorf("AMP shard for attempt %v "+ | ||
| "not found", pid) | ||
| } | ||
|
|
||
| return c.Hash, nil | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| package amp_test | ||
|
|
||
| import ( | ||
| "crypto/rand" | ||
| "testing" | ||
|
|
||
| "github.com/lightningnetwork/lnd/amp" | ||
| "github.com/lightningnetwork/lnd/lnwire" | ||
| "github.com/lightningnetwork/lnd/routing/shards" | ||
| "github.com/stretchr/testify/require" | ||
| ) | ||
|
|
||
| // TestAMPShardTracker tests that we can derive and cancel shards at will using | ||
| // the AMP shard tracker. | ||
| func TestAMPShardTracker(t *testing.T) { | ||
|
||
| var root, setID, payAddr [32]byte | ||
| _, err := rand.Read(root[:]) | ||
| require.NoError(t, err) | ||
|
|
||
| _, err = rand.Read(setID[:]) | ||
| require.NoError(t, err) | ||
|
|
||
| _, err = rand.Read(payAddr[:]) | ||
| require.NoError(t, err) | ||
|
|
||
| var totalAmt lnwire.MilliSatoshi = 1000 | ||
|
|
||
| // Create an AMP shard tracker using the random data we just generated. | ||
| tracker := amp.NewShardTracker(root, setID, payAddr, totalAmt) | ||
|
|
||
| // Trying to retrieve a hash for id 0 should result in an error. | ||
| _, err = tracker.GetHash(0) | ||
| require.Error(t, err) | ||
|
|
||
| // We start by creating 20 shards. | ||
| const numShards = 20 | ||
|
|
||
| var shards []shards.PaymentShard | ||
| for i := uint64(0); i < numShards; i++ { | ||
| s, err := tracker.NewShard(i, i == numShards-1) | ||
| require.NoError(t, err) | ||
|
|
||
| // Check that the shards have their payloads set as expected. | ||
| require.Equal(t, setID, s.AMP().SetID()) | ||
| require.Equal(t, totalAmt, s.MPP().TotalMsat()) | ||
| require.Equal(t, payAddr, s.MPP().PaymentAddr()) | ||
|
|
||
| shards = append(shards, s) | ||
| } | ||
|
|
||
| // Make sure we can retrieve the hash for all of them. | ||
| for i := uint64(0); i < numShards; i++ { | ||
| hash, err := tracker.GetHash(i) | ||
| require.NoError(t, err) | ||
| require.Equal(t, shards[i].Hash(), hash) | ||
| } | ||
|
|
||
| // Now cancel half of the shards. | ||
| j := 0 | ||
| for i := uint64(0); i < numShards; i++ { | ||
| if i%2 == 0 { | ||
| err := tracker.CancelShard(i) | ||
| require.NoError(t, err) | ||
| continue | ||
| } | ||
|
|
||
| // Keep shard. | ||
| shards[j] = shards[i] | ||
| j++ | ||
| } | ||
| shards = shards[:j] | ||
|
|
||
| // Get a new last shard. | ||
| s, err := tracker.NewShard(numShards, true) | ||
| require.NoError(t, err) | ||
| shards = append(shards, s) | ||
|
|
||
| // Finally make sure these shards together can be used to reconstruct | ||
| // the children. | ||
| childDescs := make([]amp.ChildDesc, len(shards)) | ||
| for i, s := range shards { | ||
| childDescs[i] = amp.ChildDesc{ | ||
| Share: s.AMP().RootShare(), | ||
| Index: s.AMP().ChildIndex(), | ||
| } | ||
| } | ||
|
|
||
| // Using the child descriptors, reconstruct the children. | ||
| children := amp.ReconstructChildren(childDescs...) | ||
|
|
||
| // Validate that the derived child preimages match the hash of each shard. | ||
| for i, child := range children { | ||
| require.Equal(t, shards[i].Hash(), child.Hash) | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.