Skip to content
This repository was archived by the owner on Jan 6, 2025. It is now read-only.

Commit ec6f314

Browse files
check if fee has been collected
1 parent c5d65d4 commit ec6f314

File tree

1 file changed

+89
-17
lines changed

1 file changed

+89
-17
lines changed

src/lsps2/service.rs

Lines changed: 89 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,13 @@ enum HTLCInterceptedAction {
8181
ForwardPayment(ForwardPaymentAction),
8282
}
8383

84+
/// Possible actions that need to be taken when a payment is forwarded.
85+
#[derive(Debug, PartialEq)]
86+
enum PaymentForwardedAction {
87+
ForwardPayment(ForwardPaymentAction),
88+
ForwardHTLCs(ForwardHTLCsAction),
89+
}
90+
8491
/// The forwarding of a payment while skimming the JIT channel opening fee.
8592
#[derive(Debug, PartialEq)]
8693
struct ForwardPaymentAction(ChannelId, FeePayment);
@@ -318,17 +325,36 @@ impl OutboundJITChannelState {
318325
}
319326

320327
fn payment_forwarded(
321-
&mut self,
322-
) -> Result<(Self, Option<ForwardHTLCsAction>), ChannelStateError> {
328+
&mut self, skimmed_fee_msat: Option<u64>,
329+
) -> Result<(Self, Option<PaymentForwardedAction>), ChannelStateError> {
323330
match self {
324331
OutboundJITChannelState::PendingPaymentForward {
325-
payment_queue, channel_id, ..
332+
payment_queue,
333+
channel_id,
334+
opening_fee_msat,
326335
} => {
327336
let mut payment_queue_lock = payment_queue.lock().unwrap();
328-
let payment_forwarded =
329-
OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id };
330-
let forward_htlcs = ForwardHTLCsAction(*channel_id, payment_queue_lock.clear());
331-
Ok((payment_forwarded, Some(forward_htlcs)))
337+
338+
let skimmed_fee_msat = skimmed_fee_msat.unwrap_or(0);
339+
let remaining_fee = opening_fee_msat.saturating_sub(skimmed_fee_msat);
340+
341+
if remaining_fee > 0 {
342+
let (state, payment_action) = try_get_payment(
343+
Arc::clone(payment_queue),
344+
payment_queue_lock,
345+
*channel_id,
346+
remaining_fee,
347+
);
348+
Ok((state, payment_action.map(|pa| PaymentForwardedAction::ForwardPayment(pa))))
349+
} else {
350+
let payment_forwarded =
351+
OutboundJITChannelState::PaymentForwarded { channel_id: *channel_id };
352+
let forward_htlcs = ForwardHTLCsAction(*channel_id, payment_queue_lock.clear());
353+
Ok((
354+
payment_forwarded,
355+
Some(PaymentForwardedAction::ForwardHTLCs(forward_htlcs)),
356+
))
357+
}
332358
},
333359
OutboundJITChannelState::PaymentForwarded { channel_id } => {
334360
let payment_forwarded =
@@ -362,6 +388,10 @@ impl OutboundJITChannel {
362388
}
363389
}
364390

391+
pub fn has_paid_fee(&self) -> bool {
392+
matches!(self.state, OutboundJITChannelState::PaymentForwarded { .. })
393+
}
394+
365395
fn htlc_intercepted(
366396
&mut self, htlc: InterceptedHTLC,
367397
) -> Result<Option<HTLCInterceptedAction>, LightningError> {
@@ -385,8 +415,10 @@ impl OutboundJITChannel {
385415
Ok(action)
386416
}
387417

388-
fn payment_forwarded(&mut self) -> Result<Option<ForwardHTLCsAction>, LightningError> {
389-
let (new_state, action) = self.state.payment_forwarded()?;
418+
fn payment_forwarded(
419+
&mut self, skimmed_fee_msat: Option<u64>,
420+
) -> Result<Option<PaymentForwardedAction>, LightningError> {
421+
let (new_state, action) = self.state.payment_forwarded(skimmed_fee_msat)?;
390422
self.state = new_state;
391423
Ok(action)
392424
}
@@ -812,7 +844,9 @@ where
812844
/// greater or equal to 0.0.107.
813845
///
814846
/// [`Event::PaymentForwarded`]: lightning::events::Event::PaymentForwarded
815-
pub fn payment_forwarded(&self, next_channel_id: ChannelId) -> Result<(), APIError> {
847+
pub fn payment_forwarded(
848+
&self, next_channel_id: ChannelId, skimmed_fee_msat: Option<u64>,
849+
) -> Result<bool, APIError> {
816850
if let Some(counterparty_node_id) =
817851
self.peer_by_channel_id.read().unwrap().get(&next_channel_id)
818852
{
@@ -826,8 +860,10 @@ where
826860
if let Some(jit_channel) =
827861
peer_state.outbound_channels_by_intercept_scid.get_mut(&intercept_scid)
828862
{
829-
match jit_channel.payment_forwarded() {
830-
Ok(Some(ForwardHTLCsAction(channel_id, htlcs))) => {
863+
match jit_channel.payment_forwarded(skimmed_fee_msat) {
864+
Ok(Some(PaymentForwardedAction::ForwardHTLCs(
865+
ForwardHTLCsAction(channel_id, htlcs),
866+
))) => {
831867
for htlc in htlcs {
832868
self.channel_manager.get_cm().forward_intercepted_htlc(
833869
htlc.intercept_id,
@@ -837,6 +873,29 @@ where
837873
)?;
838874
}
839875
},
876+
Ok(Some(PaymentForwardedAction::ForwardPayment(
877+
ForwardPaymentAction(
878+
channel_id,
879+
FeePayment { htlcs, opening_fee_msat },
880+
),
881+
))) => {
882+
let amounts_to_forward_msat =
883+
calculate_amount_to_forward_per_htlc(
884+
&htlcs,
885+
opening_fee_msat,
886+
);
887+
888+
for (intercept_id, amount_to_forward_msat) in
889+
amounts_to_forward_msat
890+
{
891+
self.channel_manager.get_cm().forward_intercepted_htlc(
892+
intercept_id,
893+
&channel_id,
894+
*counterparty_node_id,
895+
amount_to_forward_msat,
896+
)?;
897+
}
898+
},
840899
Ok(None) => {},
841900
Err(e) => {
842901
return Err(APIError::APIMisuseError {
@@ -847,6 +906,7 @@ where
847906
})
848907
},
849908
}
909+
return Ok(jit_channel.has_paid_fee());
850910
}
851911
} else {
852912
return Err(APIError::APIMisuseError {
@@ -862,7 +922,7 @@ where
862922
}
863923
}
864924

865-
Ok(())
925+
Ok(false)
866926
}
867927

868928
/// Forward [`Event::ChannelReady`] event parameters into this function.
@@ -1409,12 +1469,18 @@ mod tests {
14091469
}
14101470
state = new_state;
14111471
}
1472+
1473+
// TODO: how do I get the expected skimmed amount here
1474+
14121475
// Payment completes, queued payments get forwarded.
14131476
{
1414-
let (new_state, action) = state.payment_forwarded().unwrap();
1477+
let (new_state, action) = state.payment_forwarded(Some(100_000_000)).unwrap();
14151478
assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. }));
14161479
match action {
1417-
Some(ForwardHTLCsAction(channel_id, htlcs)) => {
1480+
Some(PaymentForwardedAction::ForwardHTLCs(ForwardHTLCsAction(
1481+
channel_id,
1482+
htlcs,
1483+
))) => {
14181484
assert_eq!(channel_id, ChannelId([200; 32]));
14191485
assert_eq!(
14201486
htlcs,
@@ -1550,12 +1616,18 @@ mod tests {
15501616
}
15511617
state = new_state;
15521618
}
1619+
1620+
// TODO: how do I grab the expected skimmed fee amount here.
1621+
15531622
// Payment completes, queued payments get forwarded.
15541623
{
1555-
let (new_state, action) = state.payment_forwarded().unwrap();
1624+
let (new_state, action) = state.payment_forwarded(Some(100_000_000)).unwrap();
15561625
assert!(matches!(new_state, OutboundJITChannelState::PaymentForwarded { .. }));
15571626
match action {
1558-
Some(ForwardHTLCsAction(channel_id, htlcs)) => {
1627+
Some(PaymentForwardedAction::ForwardHTLCs(ForwardHTLCsAction(
1628+
channel_id,
1629+
htlcs,
1630+
))) => {
15591631
assert_eq!(channel_id, ChannelId([200; 32]));
15601632
assert_eq!(
15611633
htlcs,

0 commit comments

Comments
 (0)