@@ -81,6 +81,13 @@ enum HTLCInterceptedAction {
81
81
ForwardPayment ( ForwardPaymentAction ) ,
82
82
}
83
83
84
+ /// Possible actions that need to be taken when a payment is forwarded.
85
+ #[ derive( Debug , PartialEq ) ]
86
+ enum PaymentForwardedAction {
87
+ ForwardPayment ( ForwardPaymentAction ) ,
88
+ ForwardHTLCs ( ForwardHTLCsAction ) ,
89
+ }
90
+
84
91
/// The forwarding of a payment while skimming the JIT channel opening fee.
85
92
#[ derive( Debug , PartialEq ) ]
86
93
struct ForwardPaymentAction ( ChannelId , FeePayment ) ;
@@ -318,17 +325,36 @@ impl OutboundJITChannelState {
318
325
}
319
326
320
327
fn payment_forwarded (
321
- & mut self ,
322
- ) -> Result < ( Self , Option < ForwardHTLCsAction > ) , ChannelStateError > {
328
+ & mut self , skimmed_fee_msat : Option < u64 > ,
329
+ ) -> Result < ( Self , Option < PaymentForwardedAction > ) , ChannelStateError > {
323
330
match self {
324
331
OutboundJITChannelState :: PendingPaymentForward {
325
- payment_queue, channel_id, ..
332
+ payment_queue,
333
+ channel_id,
334
+ opening_fee_msat,
326
335
} => {
327
336
let mut payment_queue_lock = payment_queue. lock ( ) . unwrap ( ) ;
328
- let payment_forwarded =
329
- OutboundJITChannelState :: PaymentForwarded { channel_id : * channel_id } ;
330
- let forward_htlcs = ForwardHTLCsAction ( * channel_id, payment_queue_lock. clear ( ) ) ;
331
- Ok ( ( payment_forwarded, Some ( forward_htlcs) ) )
337
+
338
+ let skimmed_fee_msat = skimmed_fee_msat. unwrap_or ( 0 ) ;
339
+ let remaining_fee = opening_fee_msat. saturating_sub ( skimmed_fee_msat) ;
340
+
341
+ if remaining_fee > 0 {
342
+ let ( state, payment_action) = try_get_payment (
343
+ Arc :: clone ( payment_queue) ,
344
+ payment_queue_lock,
345
+ * channel_id,
346
+ remaining_fee,
347
+ ) ;
348
+ Ok ( ( state, payment_action. map ( |pa| PaymentForwardedAction :: ForwardPayment ( pa) ) ) )
349
+ } else {
350
+ let payment_forwarded =
351
+ OutboundJITChannelState :: PaymentForwarded { channel_id : * channel_id } ;
352
+ let forward_htlcs = ForwardHTLCsAction ( * channel_id, payment_queue_lock. clear ( ) ) ;
353
+ Ok ( (
354
+ payment_forwarded,
355
+ Some ( PaymentForwardedAction :: ForwardHTLCs ( forward_htlcs) ) ,
356
+ ) )
357
+ }
332
358
} ,
333
359
OutboundJITChannelState :: PaymentForwarded { channel_id } => {
334
360
let payment_forwarded =
@@ -362,6 +388,10 @@ impl OutboundJITChannel {
362
388
}
363
389
}
364
390
391
+ pub fn has_paid_fee ( & self ) -> bool {
392
+ matches ! ( self . state, OutboundJITChannelState :: PaymentForwarded { .. } )
393
+ }
394
+
365
395
fn htlc_intercepted (
366
396
& mut self , htlc : InterceptedHTLC ,
367
397
) -> Result < Option < HTLCInterceptedAction > , LightningError > {
@@ -385,8 +415,10 @@ impl OutboundJITChannel {
385
415
Ok ( action)
386
416
}
387
417
388
- fn payment_forwarded ( & mut self ) -> Result < Option < ForwardHTLCsAction > , LightningError > {
389
- let ( new_state, action) = self . state . payment_forwarded ( ) ?;
418
+ fn payment_forwarded (
419
+ & mut self , skimmed_fee_msat : Option < u64 > ,
420
+ ) -> Result < Option < PaymentForwardedAction > , LightningError > {
421
+ let ( new_state, action) = self . state . payment_forwarded ( skimmed_fee_msat) ?;
390
422
self . state = new_state;
391
423
Ok ( action)
392
424
}
@@ -812,7 +844,9 @@ where
812
844
/// greater or equal to 0.0.107.
813
845
///
814
846
/// [`Event::PaymentForwarded`]: lightning::events::Event::PaymentForwarded
815
- pub fn payment_forwarded ( & self , next_channel_id : ChannelId ) -> Result < ( ) , APIError > {
847
+ pub fn payment_forwarded (
848
+ & self , next_channel_id : ChannelId , skimmed_fee_msat : Option < u64 > ,
849
+ ) -> Result < bool , APIError > {
816
850
if let Some ( counterparty_node_id) =
817
851
self . peer_by_channel_id . read ( ) . unwrap ( ) . get ( & next_channel_id)
818
852
{
@@ -826,8 +860,10 @@ where
826
860
if let Some ( jit_channel) =
827
861
peer_state. outbound_channels_by_intercept_scid . get_mut ( & intercept_scid)
828
862
{
829
- match jit_channel. payment_forwarded ( ) {
830
- Ok ( Some ( ForwardHTLCsAction ( channel_id, htlcs) ) ) => {
863
+ match jit_channel. payment_forwarded ( skimmed_fee_msat) {
864
+ Ok ( Some ( PaymentForwardedAction :: ForwardHTLCs (
865
+ ForwardHTLCsAction ( channel_id, htlcs) ,
866
+ ) ) ) => {
831
867
for htlc in htlcs {
832
868
self . channel_manager . get_cm ( ) . forward_intercepted_htlc (
833
869
htlc. intercept_id ,
@@ -837,6 +873,29 @@ where
837
873
) ?;
838
874
}
839
875
} ,
876
+ Ok ( Some ( PaymentForwardedAction :: ForwardPayment (
877
+ ForwardPaymentAction (
878
+ channel_id,
879
+ FeePayment { htlcs, opening_fee_msat } ,
880
+ ) ,
881
+ ) ) ) => {
882
+ let amounts_to_forward_msat =
883
+ calculate_amount_to_forward_per_htlc (
884
+ & htlcs,
885
+ opening_fee_msat,
886
+ ) ;
887
+
888
+ for ( intercept_id, amount_to_forward_msat) in
889
+ amounts_to_forward_msat
890
+ {
891
+ self . channel_manager . get_cm ( ) . forward_intercepted_htlc (
892
+ intercept_id,
893
+ & channel_id,
894
+ * counterparty_node_id,
895
+ amount_to_forward_msat,
896
+ ) ?;
897
+ }
898
+ } ,
840
899
Ok ( None ) => { } ,
841
900
Err ( e) => {
842
901
return Err ( APIError :: APIMisuseError {
@@ -847,6 +906,7 @@ where
847
906
} )
848
907
} ,
849
908
}
909
+ return Ok ( jit_channel. has_paid_fee ( ) ) ;
850
910
}
851
911
} else {
852
912
return Err ( APIError :: APIMisuseError {
@@ -862,7 +922,7 @@ where
862
922
}
863
923
}
864
924
865
- Ok ( ( ) )
925
+ Ok ( false )
866
926
}
867
927
868
928
/// Forward [`Event::ChannelReady`] event parameters into this function.
@@ -1409,12 +1469,18 @@ mod tests {
1409
1469
}
1410
1470
state = new_state;
1411
1471
}
1472
+
1473
+ // TODO: how do I get the expected skimmed amount here
1474
+
1412
1475
// Payment completes, queued payments get forwarded.
1413
1476
{
1414
- let ( new_state, action) = state. payment_forwarded ( ) . unwrap ( ) ;
1477
+ let ( new_state, action) = state. payment_forwarded ( Some ( 100_000_000 ) ) . unwrap ( ) ;
1415
1478
assert ! ( matches!( new_state, OutboundJITChannelState :: PaymentForwarded { .. } ) ) ;
1416
1479
match action {
1417
- Some ( ForwardHTLCsAction ( channel_id, htlcs) ) => {
1480
+ Some ( PaymentForwardedAction :: ForwardHTLCs ( ForwardHTLCsAction (
1481
+ channel_id,
1482
+ htlcs,
1483
+ ) ) ) => {
1418
1484
assert_eq ! ( channel_id, ChannelId ( [ 200 ; 32 ] ) ) ;
1419
1485
assert_eq ! (
1420
1486
htlcs,
@@ -1550,12 +1616,18 @@ mod tests {
1550
1616
}
1551
1617
state = new_state;
1552
1618
}
1619
+
1620
+ // TODO: how do I grab the expected skimmed fee amount here.
1621
+
1553
1622
// Payment completes, queued payments get forwarded.
1554
1623
{
1555
- let ( new_state, action) = state. payment_forwarded ( ) . unwrap ( ) ;
1624
+ let ( new_state, action) = state. payment_forwarded ( Some ( 100_000_000 ) ) . unwrap ( ) ;
1556
1625
assert ! ( matches!( new_state, OutboundJITChannelState :: PaymentForwarded { .. } ) ) ;
1557
1626
match action {
1558
- Some ( ForwardHTLCsAction ( channel_id, htlcs) ) => {
1627
+ Some ( PaymentForwardedAction :: ForwardHTLCs ( ForwardHTLCsAction (
1628
+ channel_id,
1629
+ htlcs,
1630
+ ) ) ) => {
1559
1631
assert_eq ! ( channel_id, ChannelId ( [ 200 ; 32 ] ) ) ;
1560
1632
assert_eq ! (
1561
1633
htlcs,
0 commit comments