23
23
from numpyro .infer .util import transform_fn
24
24
from numpyro .util import fori_collect
25
25
26
# State threaded through the SteinVI optimization steps.
SteinVIState = namedtuple(
    "SteinVIState",
    (
        "optim_state",  # optimizer state, as produced by `optim.init` / `optim.update`
        "rng_key",  # PRNG key carried from one step to the next
        "loss_temperature",  # scale applied to the model when computing the Stein loss
        "repulsion_temperature",  # multiplier on the kernel-gradient (repulsive) term
    ),
)
27
30
# Result of an inference run: the final parameters, the final state and the collected losses.
# The typename must equal the module-level name: pickle locates the class by its qualified
# name, so a mismatching typename makes instances impossible to pickle.
SteinVIRunResult = namedtuple("SteinVIRunResult", ["params", "state", "losses"])
28
31
29
32
@@ -231,7 +234,15 @@ def local_trace(key):
231
234
232
235
return vmap (local_trace )(random .split (particle_seed , self .num_stein_particles ))
233
236
234
- def _svgd_loss_and_grads (self , rng_key , unconstr_params , * args , ** kwargs ):
237
+ def _svgd_loss_and_grads (
238
+ self ,
239
+ rng_key ,
240
+ unconstr_params ,
241
+ loss_temperature ,
242
+ repulsion_temperature ,
243
+ * args ,
244
+ ** kwargs ,
245
+ ):
235
246
# Separate model and guide parameters, since only guide parameters are updated using Stein
236
247
# Split parameters into model and guide components - only unflagged guide parameters are
237
248
# optimized via Stein forces.
@@ -262,7 +273,7 @@ def particle_transform_fn(particle):
262
273
ctparticle , _ = ravel_pytree (ctparams )
263
274
return ctparticle
264
275
265
- model = handlers .scale (self ._inference_model , self . loss_temperature )
276
+ model = handlers .scale (self ._inference_model , loss_temperature )
266
277
267
278
def stein_loss_fn (key , particle , particle_idx ):
268
279
return self .stein_loss .particle_loss (
@@ -312,10 +323,9 @@ def body(attr_force, state, y):
312
323
# Third term of eq. 9 from https://arxiv.org/pdf/2410.22948.
313
324
repulsive_force = vmap (
314
325
lambda y : jnp .mean (
315
- vmap (
316
- lambda x : self .repulsion_temperature
317
- * self ._kernel_grad (kernel , x , y )
318
- )(stein_particles ),
326
+ vmap (lambda x : repulsion_temperature * self ._kernel_grad (kernel , x , y ))(
327
+ stein_particles
328
+ ),
319
329
axis = 0 ,
320
330
)
321
331
)(stein_particles )
@@ -420,7 +430,12 @@ def init(self, rng_key, *args, **kwargs):
420
430
)
421
431
422
432
self .kernel_fn .init (kernel_seed , stein_particles .shape )
423
- return SteinVIState (self .optim .init (params ), rng_key )
433
+ return SteinVIState (
434
+ self .optim .init (params ),
435
+ rng_key ,
436
+ self .loss_temperature ,
437
+ self .repulsion_temperature ,
438
+ )
424
439
425
440
def get_params (self , state : SteinVIState ):
426
441
"""Gets values at `param` sites of the `model` and `guide`.
@@ -444,16 +459,28 @@ def update(self, state: SteinVIState, *args, **kwargs) -> SteinVIState:
444
459
params = self .optim .get_params (state .optim_state )
445
460
optim_state = state .optim_state
446
461
loss_val , grads = self ._svgd_loss_and_grads (
447
- rng_key_step , params , * args , ** kwargs , ** self .static_kwargs
462
+ rng_key_step ,
463
+ params ,
464
+ state .loss_temperature ,
465
+ state .repulsion_temperature ,
466
+ * args ,
467
+ ** kwargs ,
468
+ ** self .static_kwargs ,
448
469
)
449
470
optim_state = self .optim .update (grads , optim_state )
450
- return SteinVIState (optim_state , rng_key ), loss_val
471
+ return SteinVIState (
472
+ optim_state , rng_key , state .loss_temperature , state .repulsion_temperature
473
+ ), loss_val
451
474
452
475
def setup_run (self , rng_key , num_steps , args , init_state , kwargs ):
453
476
if init_state is None :
454
477
state = self .init (rng_key , * args , ** kwargs )
455
478
else :
479
+ assert isinstance (init_state , ASVGDState ), (
480
+ "The init_state much be an instance of ASVGDState"
481
+ )
456
482
state = init_state
483
+
457
484
loss = self .evaluate (state , * args , ** kwargs )
458
485
459
486
info_init = (state , loss )
@@ -526,7 +553,13 @@ def evaluate(self, state: SteinVIState, *args, **kwargs):
526
553
_ , _ , rng_key_eval = random .split (state .rng_key , num = 3 )
527
554
params = self .optim .get_params (state .optim_state )
528
555
normed_stein_force , _ = self ._svgd_loss_and_grads (
529
- rng_key_eval , params , * args , ** kwargs , ** self .static_kwargs
556
+ rng_key_eval ,
557
+ params ,
558
+ state .loss_temperature ,
559
+ state .repulsion_temperature ,
560
+ * args ,
561
+ ** kwargs ,
562
+ ** self .static_kwargs ,
530
563
)
531
564
return normed_stein_force
532
565
@@ -620,6 +653,12 @@ def __init__(
620
653
)
621
654
622
655
656
# State of annealed SVGD: annealing-schedule bookkeeping wrapped around a SteinVI state.
ASVGDState = namedtuple(
    "ASVGDState",
    (
        "step_count",  # number of update steps taken so far
        "num_steps",  # total number of steps of the annealing schedule
        "num_cycles",  # number of annealing cycles
        "transition_speed",  # exponent used by the annealing schedule
        "steinvi_state",  # the wrapped SteinVIState
    ),
)
660
+
661
+
623
662
class ASVGD (SVGD ):
624
663
"""Annealing Stein variational gradient descent [1].
625
664
@@ -693,12 +732,17 @@ def __init__(
693
732
kernel_fn ,
694
733
num_stein_particles = 10 ,
695
734
num_cycles = 10 ,
696
- trans_speed = 10 ,
735
+ transition_speed = 10 ,
697
736
guide_kwargs = {},
698
737
** static_kwargs ,
699
738
):
739
+ assert num_cycles > 0 , f"The number of cycles must be >0. Got { num_cycles } ."
740
+ assert transition_speed > 0 , (
741
+ f"The transtion speed must be >0. Got { transition_speed } ."
742
+ )
743
+
700
744
self .num_cycles = num_cycles
701
- self .trans_speed = trans_speed
745
+ self .transition_speed = transition_speed
702
746
703
747
super ().__init__ (
704
748
model ,
@@ -710,63 +754,155 @@ def __init__(
710
754
)
711
755
712
756
@staticmethod
def _cyclical_annealing(
    step_count: int, num_steps: int, num_cycles: int, trans_speed: int
):
    """Cyclical annealing schedule as in eq. 4 of [1].

    Before the last cycle starts the temperature is
    ``(((step_count % cycle_len) + 1) / norm) ** trans_speed``; from the start of the last
    cycle onwards it is 1.

    **References** (MLA)

    1. D'Angelo, Francesco, and Vincent Fortuin. "Annealed Stein Variational Gradient Descent."
        Third Symposium on Advances in Approximate Bayesian Inference, 2021.

    :param step_count: The current number of steps taken.
    :param num_steps: The total number of steps. Corresponds to $T$ in eq. 4 of [1].
    :param num_cycles: The total number of cycles. Corresponds to $C$ in eq. 4 of [1].
    :param trans_speed: Speed of transition between two phases. Corresponds to $p$ in eq. 4 of [1].
    :return: The loss temperature for step ``step_count``.
    """
    norm = (num_steps + 1) / num_cycles
    # NOTE(review): cycle_len is zero when num_steps < num_cycles, which leaves the modulo
    # below undefined - confirm callers guarantee num_steps >= num_cycles.
    cycle_len = num_steps // num_cycles

    # Safeguard against num_cycles=1, which would otherwise cause a division by zero.
    last_start = jnp.maximum((num_cycles - 1) * cycle_len, 1.0)

    # Indicator of "the last cycle has started". The raw quotient exceeds 1 when
    # num_cycles == 1 (last_start is forced to 1) or when num_steps is not a multiple of
    # num_cycles, so it is clamped to keep the blend below between the schedule and 1.
    in_last_cycle = jnp.minimum(step_count // last_start, 1.0)

    return (1 - in_last_cycle) * (
        ((step_count % cycle_len) + 1) / norm
    ) ** trans_speed + in_last_cycle
736
782
737
- def setup_run (self , rng_key , num_steps , args , init_state , kwargs ):
738
- cyc_fn = ASVGD ._cyclical_annealing (num_steps , self .num_cycles , self .trans_speed )
739
-
740
- (
741
- istep ,
742
- idiag ,
743
- icol ,
744
- iext ,
745
- iinit ,
746
- ) = super ().setup_run (
747
- rng_key ,
748
- num_steps ,
749
- args ,
750
- init_state ,
751
- kwargs ,
783
def init(self, rng_key, num_steps, *args, **kwargs):
    """Initialize the wrapped SteinVI state together with the annealing bookkeeping.

    :param jax.random.PRNGKey rng_key: Random number generator seed.
    :param num_steps: Total number of steps in the optimization.
    :param args: Positional arguments to the model and guide.
    :param kwargs: Keyword arguments to the model and guide.
    :return: Initial :data:`ASVGDState`.
    """
    # The loss temperature stored in the inner state is only a starting value;
    # `self.update` recomputes it from the schedule on every call.
    inner_state = super().init(rng_key, *args, **kwargs)

    # Schedule quantities are kept as floats, in ASVGDState field order.
    schedule = (
        0.0,
        float(num_steps),
        float(self.num_cycles),
        float(self.transition_speed),
    )
    return ASVGDState(*schedule, inner_state)
753
801
802
def get_params(self, state: ASVGDState):
    """Gets values at `param` sites of the `model` and `guide`.

    :param ASVGDState state: Current state of inference; only its wrapped SteinVI state is read.
    :return: Whatever the parent's ``get_params`` returns for the wrapped state.
    """
    inner_state = state.steinvi_state
    return super().get_params(inner_state)
804
+
805
def update(self, state: ASVGDState, *args, **kwargs) -> ASVGDState:
    """Take a single annealed Stein step.

    The loss temperature is recomputed from the cyclical schedule, written into the wrapped
    SteinVI state, and the parent's ``update`` is run on that state.

    :param ASVGDState state: Current state of inference.
    :param args: Positional arguments to the model and guide.
    :param kwargs: Keyword arguments to the model and guide.
    :return: Tuple of the next :data:`ASVGDState` and the loss value of the step.
    """
    # Temperature for the current step.
    # NOTE(review): the previous closure-based implementation divided the schedule value by
    # `num_stein_particles`; confirm that dropping the division here is intentional.
    annealed_temperature = ASVGD._cyclical_annealing(
        state.step_count, state.num_steps, state.num_cycles, state.transition_speed
    )

    # Only the loss temperature changes; key, optimizer state and repulsion temperature
    # are carried over as they are.
    annealed_inner = state.steinvi_state._replace(loss_temperature=annealed_temperature)
    next_inner, loss_val = super().update(annealed_inner, *args, **kwargs)

    next_state = state._replace(
        step_count=state.step_count + 1,
        steinvi_state=next_inner,
    )
    return next_state, loss_val
829
+
830
def evaluate(self, state: ASVGDState, *args, **kwargs):
    """Evaluate the wrapped SteinVI state (possibly on a batch / minibatch of data).

    Delegates to the parent's ``evaluate``, so the temperatures stored in the wrapped state
    are the ones used; the annealing schedule is not consulted here.

    :param ASVGDState state: Current state of inference.
    :param args: Positional arguments to the model and guide.
    :param kwargs: Keyword arguments to the model and guide.
    :return: Normed Stein force.
    """
    inner_state = state.steinvi_state
    return super().evaluate(inner_state, *args, **kwargs)
840
+
841
def setup_run(self, rng_key, num_steps, args, init_state, kwargs):
    """Build the callbacks and the initial carry consumed by :meth:`run`.

    :param jax.random.PRNGKey rng_key: Random number generator seed; only used when
        ``init_state`` is ``None``.
    :param int num_steps: Number of steps to optimize; forwarded to ``init``.
    :param tuple args: Positional arguments to the model and guide.
    :param init_state: An :data:`ASVGDState` to start from, or ``None`` to initialize afresh.
    :param dict kwargs: Keyword arguments to the model and guide.
    :return: Tuple ``(step, diagnostic, collect, extract, info_init)`` where ``info_init`` is
        the initial ``(state, loss)`` pair and the callables operate on such pairs.
    :raises AssertionError: If ``init_state`` is given but is not an :data:`ASVGDState`.
    """
    if init_state is None:
        state = self.init(rng_key, num_steps, *args, **kwargs)
    else:
        assert isinstance(init_state, ASVGDState), (
            "The init_state must be an instance of ASVGDState"
        )
        state = init_state

    # The carry threaded through the loop is always a (state, loss) pair.
    info_init = (state, self.evaluate(state, *args, **kwargs))

    def step(info):
        current_state, _ = info
        # args/kwargs are captured from the enclosing call.
        return self.update(current_state, *args, **kwargs)

    def collect(info):
        _, loss = info
        return loss

    def extract(info):
        final_state, _ = info
        return final_state

    def diagnostic(info):
        _, loss = info
        return f"Stein force {loss:.2f}."

    return step, diagnostic, collect, extract, info_init
766
871
767
- def extract_state (info ):
768
- _ , iinfo = info
769
- return iext (iinfo )
872
def run(
    self,
    rng_key,
    num_steps,
    *args,
    progress_bar=True,
    init_state=None,
    **kwargs,
):
    """Run ASVGD inference.

    :param jax.random.PRNGKey rng_key: Random number generator seed.
    :param int num_steps: Number of steps to optimize.
    :param args: Positional arguments to the model and guide.
    :param bool progress_bar: Use a progress bar. Default is `True`.
        Inference is faster with `False`.
    :param ASVGDState init_state: Initial state of inference.
        Default is ``None``, which will initialize using init before running inference.
    :param kwargs: Keyword arguments to the model and guide.
    :return: :data:`SteinVIRunResult` holding the final parameters, the final state and the
        values collected at each step.
    """
    step_fn, diagnostic_fn, collect_fn, extract_fn, carry = self.setup_run(
        rng_key, num_steps, args, init_state, kwargs
    )

    # Diagnostics are only needed when a progress bar is shown.
    shown_diagnostics = diagnostic_fn if progress_bar else None

    collected, final_carry = fori_collect(
        0,
        num_steps,
        step_fn,
        carry,
        progbar=progress_bar,
        transform=collect_fn,
        return_last_val=True,
        diagnostics_fn=shown_diagnostics,
    )

    final_state = extract_fn(final_carry)
    return SteinVIRunResult(self.get_params(final_state), final_state, collected)
0 commit comments