Skip to content

POC: Upfront Fees to Mitigate Channel Jamming #7339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions channeldb/invoices.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ const (
htlcAMPType tlv.Type = 19
htlcHashType tlv.Type = 21
htlcPreimageType tlv.Type = 23
upfrontFeeType tlv.Type = 25

// A set of tlv type definitions used to serialize invoice bodiees.
//
Expand Down Expand Up @@ -969,6 +970,13 @@ func serializeHtlcs(w io.Writer,
}
}

if htlc.UpfrontFee != 0 {
fee := uint64(htlc.UpfrontFee)
record := tlv.MakePrimitiveRecord(upfrontFeeType, &fee)

records = append(records, record)
}

// Convert the custom records to tlv.Record types that are ready
// for serialization.
customRecords := tlv.MapToRecords(htlc.CustomRecords)
Expand Down Expand Up @@ -1606,15 +1614,15 @@ func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*invpkg.InvoiceHTLC,

// Decode the contents into the htlc fields.
var (
htlc invpkg.InvoiceHTLC
key models.CircuitKey
chanID uint64
state uint8
acceptTime, resolveTime uint64
amt, mppTotalAmt uint64
amp = &record.AMP{}
hash32 = &[32]byte{}
preimage32 = &[32]byte{}
htlc invpkg.InvoiceHTLC
key models.CircuitKey
chanID uint64
state uint8
acceptTime, resolveTime uint64
amt, mppTotalAmt, upfrontFee uint64
amp = &record.AMP{}
hash32 = &[32]byte{}
preimage32 = &[32]byte{}
)
tlvStream, err := tlv.NewStream(
tlv.MakePrimitiveRecord(chanIDType, &chanID),
Expand All @@ -1634,6 +1642,7 @@ func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
),
tlv.MakePrimitiveRecord(htlcHashType, hash32),
tlv.MakePrimitiveRecord(htlcPreimageType, preimage32),
tlv.MakePrimitiveRecord(upfrontFeeType, &upfrontFee),
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -1665,6 +1674,7 @@ func deserializeHtlcs(r io.Reader) (map[models.CircuitKey]*invpkg.InvoiceHTLC,
htlc.ResolveTime = getNanoTime(resolveTime)
htlc.State = invpkg.HtlcState(state)
htlc.Amt = lnwire.MilliSatoshi(amt)
htlc.UpfrontFee = lnwire.MilliSatoshi(upfrontFee)
htlc.MppTotalAmt = lnwire.MilliSatoshi(mppTotalAmt)
if amp != nil && hash != nil {
htlc.AMP = &invpkg.InvoiceHtlcAMPData{
Expand Down Expand Up @@ -1750,7 +1760,12 @@ func updateHtlcsAmp(invoice *invpkg.Invoice,
}
}

// Update state with the amount paid. We include our upfront fee in
// the amount paid because the sender has pushed this amount to us with
// the HTLC.
ampState.AmtPaid += htlc.Amt
ampState.AmtPaid += htlc.UpfrontFee

ampState.InvoiceKeys[circuitKey] = struct{}{}

// Due to the way maps work, we need to read out the value, update it,
Expand Down Expand Up @@ -1942,6 +1957,7 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *invpkg.SetID, invoices,

htlc := &invpkg.InvoiceHTLC{
Amt: htlcUpdate.Amt,
UpfrontFee: htlcUpdate.UpfrontFee,
MppTotalAmt: htlcUpdate.MppTotalAmt,
Expiry: htlcUpdate.Expiry,
AcceptHeight: uint32(htlcUpdate.AcceptHeight),
Expand Down Expand Up @@ -2116,6 +2132,11 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *invpkg.SetID, invoices,

amtPaid += htlc.Amt
}

// We do however include the amount pushed to us in the
// htlc's upfront fee, since we have received this
// amount unconditionally on receipt of the htlc.
amtPaid += htlc.UpfrontFee
} else {
// For AMP invoices, since we won't always be reading
// out the total invoice set each time, we'll instead
Expand All @@ -2132,6 +2153,7 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *invpkg.SetID, invoices,
invoiceStateReady {

amtPaid += htlc.Amt
amtPaid += htlc.UpfrontFee
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions channeldb/payments.go
Original file line number Diff line number Diff line change
Expand Up @@ -1176,6 +1176,13 @@ func serializeHop(w io.Writer, h *route.Hop) error {
records = append(records, record.NewMetadataRecord(&h.Metadata))
}

upfrontFee, set := h.UpfrontFeeToForward.Value()
if set {
u64Fee := uint64(upfrontFee)
record := record.NewUpfrontFeeToForwardRecord(&u64Fee)
records = append(records, record)
}

// Final sanity check to absolutely rule out custom records that are not
// custom and write into the standard range.
if err := h.CustomRecords.Validate(); err != nil {
Expand Down
4 changes: 3 additions & 1 deletion contractcourt/htlc_incoming_contest_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,9 @@ func (h *htlcIncomingContestResolver) decodePayload() (*hop.Payload,
return nil, nil, err
}

payload, err := iterator.HopPayload()
// TODO - we need to be able to recover the update_add_htlc
// (specifically its extra data) here.
payload, err := iterator.HopPayload(nil)
if err != nil {
return nil, nil, err
}
Expand Down
6 changes: 6 additions & 0 deletions feature/default_sets.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,10 @@ var defaultSetDesc = setDesc{
SetInit: {}, // I
SetNodeAnn: {}, // N
},
lnwire.UpfrontFeeOptional: {
SetInit: {}, // I
SetNodeAnn: {}, // N
SetInvoice: {}, // 9
SetInvoiceAmp: {}, // 9A,
},
}
4 changes: 4 additions & 0 deletions htlcswitch/hop/forwarding_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ type ForwardingInfo struct {
// node should forward to the next hop.
AmountToForward lnwire.MilliSatoshi

// UpfrontFeeToForward is the amount that should be pushed to the
// receiving node to add the incoming htlc to their commitment.
UpfrontFeeToForward *lnwire.UpfrontFee

// OutgoingCTLV is the specified value of the CTLV timelock to be used
// in the outgoing HTLC.
OutgoingCTLV uint32
Expand Down
30 changes: 26 additions & 4 deletions htlcswitch/hop/iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ type Iterator interface {
// information encoded within the returned ForwardingInfo is to be used
// by each hop to authenticate the information given to it by the prior
// hop. The payload will also contain any additional TLV fields provided
// by the sender.
HopPayload() (*Payload, error)
//
// The extra data that was transmitted with the update_add_htlc message
// that provided the payload is passed in to allow validation to
// compare the payload to extra information provided in the add.
HopPayload(*ExtraAddData) (*Payload, error)

// EncodeNextHop encodes the onion packet destined for the next hop
// into the passed io.Writer.
Expand All @@ -35,6 +38,13 @@ type Iterator interface {
lnwire.FailCode)
}

// ExtraAddData contains the additional data that was attached to a
// update_add_htlc that is relevant to validation.
type ExtraAddData struct {
// UpfrontFee is an optional upfront fee set in the update_add_htlc.
UpfrontFee *lnwire.UpfrontFee
}

// sphinxHopIterator is the Sphinx implementation of hop iterator which uses
// onion routing to encode the payment route in such a way so that node might
// see only the next hop in the route..
Expand Down Expand Up @@ -78,7 +88,9 @@ func (r *sphinxHopIterator) EncodeNextHop(w io.Writer) error {
// also contain any additional TLV fields provided by the sender.
//
// NOTE: Part of the HopIterator interface.
func (r *sphinxHopIterator) HopPayload() (*Payload, error) {
func (r *sphinxHopIterator) HopPayload(data *ExtraAddData) (*Payload,
error) {

switch r.processedPacket.Payload.Type {

// If this is the legacy payload, then we'll extract the information
Expand All @@ -90,9 +102,19 @@ func (r *sphinxHopIterator) HopPayload() (*Payload, error) {
// Otherwise, if this is the TLV payload, then we'll make a new stream
// to decode only what we need to make routing decisions.
case sphinx.PayloadTLV:
return NewPayloadFromReader(bytes.NewReader(
payload, parsed, err := NewPayloadFromReader(bytes.NewReader(
r.processedPacket.Payload.Payload,
))
if err != nil {
return nil, err
}

err = ValidateTLVPayload(payload, parsed, data)
if err != nil {
return nil, err
}

return payload, nil

default:
return nil, fmt.Errorf("unknown sphinx payload type: %v",
Expand Down
Loading