Skip to content

Commit 374cd89

Browse files
authored
Fix Memleak in ASVGD (#2003)
* added ASVGDState which tracks the annealing for ASVGD; added loss_temperature and repulsion_temperature to the SteinVIState; renamed _svgd_loss_and_grads => _loss_and_grads; added loss and repulsion temperature to signature of _loss_and_grads * fixed minor issues; reproduces results from previous version * added division by zero guard for ASVGD with num_cycles=1 and removed leak pytest.mark * removed ASVGD memleak test from CI
1 parent 109632f commit 374cd89

File tree

3 files changed

+194
-56
lines changed

3 files changed

+194
-56
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,6 @@ jobs:
8585
env:
8686
JAX_CHECK_TRACER_LEAKS: 1
8787
run: |
88-
pytest -vs test/contrib/einstein/test_steinvi.py::test_run_smoke -k ASVGD
8988
pytest -vs test/contrib/test_infer_discrete.py::test_scan_hmm_smoke
9089
pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit
9190
pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke

numpyro/contrib/einstein/steinvi.py

Lines changed: 186 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,10 @@
2323
from numpyro.infer.util import transform_fn
2424
from numpyro.util import fori_collect
2525

26-
SteinVIState = namedtuple("SteinVIState", ["optim_state", "rng_key"])
26+
SteinVIState = namedtuple(
27+
"SteinVIState",
28+
["optim_state", "rng_key", "loss_temperature", "repulsion_temperature"],
29+
)
2730
SteinVIRunResult = namedtuple("SteinRunResult", ["params", "state", "losses"])
2831

2932

@@ -231,7 +234,15 @@ def local_trace(key):
231234

232235
return vmap(local_trace)(random.split(particle_seed, self.num_stein_particles))
233236

234-
def _svgd_loss_and_grads(self, rng_key, unconstr_params, *args, **kwargs):
237+
def _svgd_loss_and_grads(
238+
self,
239+
rng_key,
240+
unconstr_params,
241+
loss_temperature,
242+
repulsion_temperature,
243+
*args,
244+
**kwargs,
245+
):
235246
# Separate model and guide parameters, since only guide parameters are updated using Stein
236247
# Split parameters into model and guide components - only unflagged guide parameters are
237248
# optimized via Stein forces.
@@ -262,7 +273,7 @@ def particle_transform_fn(particle):
262273
ctparticle, _ = ravel_pytree(ctparams)
263274
return ctparticle
264275

265-
model = handlers.scale(self._inference_model, self.loss_temperature)
276+
model = handlers.scale(self._inference_model, loss_temperature)
266277

267278
def stein_loss_fn(key, particle, particle_idx):
268279
return self.stein_loss.particle_loss(
@@ -312,10 +323,9 @@ def body(attr_force, state, y):
312323
# Third term of eq. 9 from https://arxiv.org/pdf/2410.22948.
313324
repulsive_force = vmap(
314325
lambda y: jnp.mean(
315-
vmap(
316-
lambda x: self.repulsion_temperature
317-
* self._kernel_grad(kernel, x, y)
318-
)(stein_particles),
326+
vmap(lambda x: repulsion_temperature * self._kernel_grad(kernel, x, y))(
327+
stein_particles
328+
),
319329
axis=0,
320330
)
321331
)(stein_particles)
@@ -420,7 +430,12 @@ def init(self, rng_key, *args, **kwargs):
420430
)
421431

422432
self.kernel_fn.init(kernel_seed, stein_particles.shape)
423-
return SteinVIState(self.optim.init(params), rng_key)
433+
return SteinVIState(
434+
self.optim.init(params),
435+
rng_key,
436+
self.loss_temperature,
437+
self.repulsion_temperature,
438+
)
424439

425440
def get_params(self, state: SteinVIState):
426441
"""Gets values at `param` sites of the `model` and `guide`.
@@ -444,16 +459,28 @@ def update(self, state: SteinVIState, *args, **kwargs) -> SteinVIState:
444459
params = self.optim.get_params(state.optim_state)
445460
optim_state = state.optim_state
446461
loss_val, grads = self._svgd_loss_and_grads(
447-
rng_key_step, params, *args, **kwargs, **self.static_kwargs
462+
rng_key_step,
463+
params,
464+
state.loss_temperature,
465+
state.repulsion_temperature,
466+
*args,
467+
**kwargs,
468+
**self.static_kwargs,
448469
)
449470
optim_state = self.optim.update(grads, optim_state)
450-
return SteinVIState(optim_state, rng_key), loss_val
471+
return SteinVIState(
472+
optim_state, rng_key, state.loss_temperature, state.repulsion_temperature
473+
), loss_val
451474

452475
def setup_run(self, rng_key, num_steps, args, init_state, kwargs):
453476
if init_state is None:
454477
state = self.init(rng_key, *args, **kwargs)
455478
else:
479+
assert isinstance(init_state, ASVGDState), (
480+
"The init_state much be an instance of ASVGDState"
481+
)
456482
state = init_state
483+
457484
loss = self.evaluate(state, *args, **kwargs)
458485

459486
info_init = (state, loss)
@@ -526,7 +553,13 @@ def evaluate(self, state: SteinVIState, *args, **kwargs):
526553
_, _, rng_key_eval = random.split(state.rng_key, num=3)
527554
params = self.optim.get_params(state.optim_state)
528555
normed_stein_force, _ = self._svgd_loss_and_grads(
529-
rng_key_eval, params, *args, **kwargs, **self.static_kwargs
556+
rng_key_eval,
557+
params,
558+
state.loss_temperature,
559+
state.repulsion_temperature,
560+
*args,
561+
**kwargs,
562+
**self.static_kwargs,
530563
)
531564
return normed_stein_force
532565

@@ -620,6 +653,12 @@ def __init__(
620653
)
621654

622655

656+
ASVGDState = namedtuple(
657+
"ASVGDState",
658+
["step_count", "num_steps", "num_cycles", "transition_speed", "steinvi_state"],
659+
)
660+
661+
623662
class ASVGD(SVGD):
624663
"""Annealing Stein variational gradient descent [1].
625664
@@ -693,12 +732,17 @@ def __init__(
693732
kernel_fn,
694733
num_stein_particles=10,
695734
num_cycles=10,
696-
trans_speed=10,
735+
transition_speed=10,
697736
guide_kwargs={},
698737
**static_kwargs,
699738
):
739+
assert num_cycles > 0, f"The number of cycles must be >0. Got {num_cycles}."
740+
assert transition_speed > 0, (
741+
f"The transtion speed must be >0. Got {transition_speed}."
742+
)
743+
700744
self.num_cycles = num_cycles
701-
self.trans_speed = trans_speed
745+
self.transition_speed = transition_speed
702746

703747
super().__init__(
704748
model,
@@ -710,63 +754,155 @@ def __init__(
710754
)
711755

712756
@staticmethod
713-
def _cyclical_annealing(num_steps: int, num_cycles: int, trans_speed: int):
757+
def _cyclical_annealing(
758+
step_count: int, num_steps: int, num_cycles: int, trans_speed: int
759+
):
714760
"""Cyclical annealing schedule as in eq. 4 of [1].
715761
716762
**References** (MLA)
717763
718764
1. D'Angelo, Francesco, and Vincent Fortuin. "Annealed Stein Variational Gradient Descent."
719765
Third Symposium on Advances in Approximate Bayesian Inference, 2021.
720766
767+
:param step_count: The current number of steps taken.
721768
:param num_steps: The total number of steps. Corresponds to $T$ in eq. 4 of [1].
722769
:param num_cycles: The total number of cycles. Corresponds to $C$ in eq. 4 of [1].
723770
:param trans_speed: Speed of transition between two phases. Corresponds to $p$ in eq. 4 of [1].
724771
"""
725-
norm = float(num_steps + 1) / float(num_cycles)
772+
norm = (num_steps + 1) / num_cycles
726773
cycle_len = num_steps // num_cycles
727-
last_start = (num_cycles - 1) * cycle_len
728774

729-
def cycle_fn(t):
730-
last_cycle = t // last_start
731-
return (1 - last_cycle) * (
732-
((t % cycle_len) + 1) / norm
733-
) ** trans_speed + last_cycle
775+
# Safegaurd against num_cycles=1, which would cause division by zero
776+
last_start = jnp.maximum((num_cycles - 1) * cycle_len, 1.0)
777+
last_cycle = step_count // last_start
734778

735-
return cycle_fn
779+
return (1 - last_cycle) * (
780+
((step_count % cycle_len) + 1) / norm
781+
) ** trans_speed + last_cycle
736782

737-
def setup_run(self, rng_key, num_steps, args, init_state, kwargs):
738-
cyc_fn = ASVGD._cyclical_annealing(num_steps, self.num_cycles, self.trans_speed)
739-
740-
(
741-
istep,
742-
idiag,
743-
icol,
744-
iext,
745-
iinit,
746-
) = super().setup_run(
747-
rng_key,
748-
num_steps,
749-
args,
750-
init_state,
751-
kwargs,
783+
def init(self, rng_key, num_steps, *args, **kwargs):
784+
"""Register random variable transformations, constraints and determine initialize positions of the particles.
785+
786+
:param jax.random.PRNGKey rng_key: Random number generator seed.
787+
:param args: Positional arguments to the model and guide.
788+
:param num_steps: Totat number of steps in the optimization.
789+
:param kwargs: Keyword arguments to the model and guide.
790+
:return: Initial :data:`ASVGDState`.
791+
"""
792+
# Sets initial loss temperature to 1, the temperature is adjusted by calls to `self.update``.
793+
steinvi_state = super().init(rng_key, *args, **kwargs)
794+
return ASVGDState(
795+
step_count=0.0,
796+
num_steps=float(num_steps),
797+
num_cycles=float(self.num_cycles),
798+
transition_speed=float(self.transition_speed),
799+
steinvi_state=steinvi_state,
752800
)
753801

802+
def get_params(self, state: ASVGDState):
803+
return super().get_params(state.steinvi_state)
804+
805+
def update(self, state: ASVGDState, *args, **kwargs) -> ASVGDState:
806+
step_count, num_steps, num_cycles, transition_speed, steinvi_state = state
807+
808+
# Compute the loss temperature
809+
loss_temperature = ASVGD._cyclical_annealing(
810+
step_count, num_steps, num_cycles, transition_speed
811+
)
812+
813+
steinvi_state = SteinVIState(
814+
rng_key=steinvi_state.rng_key,
815+
optim_state=steinvi_state.optim_state,
816+
loss_temperature=loss_temperature,
817+
repulsion_temperature=steinvi_state.repulsion_temperature,
818+
)
819+
new_steinvi_state, loss_val = super().update(steinvi_state, *args, **kwargs)
820+
821+
new_asvgd_state = ASVGDState(
822+
step_count + 1,
823+
state.num_steps,
824+
state.num_cycles,
825+
state.transition_speed,
826+
steinvi_state=new_steinvi_state,
827+
)
828+
return new_asvgd_state, loss_val
829+
830+
def evaluate(self, state: ASVGDState, *args, **kwargs):
831+
"""Take a single step of Stein (possibly on a batch / minibatch of data).
832+
833+
:param ASVGDState state: Current state of inference.
834+
:param args: Positional arguments to the model and guide.
835+
:param kwargs: Keyword arguments to the model and guide.
836+
:return: Normed Stein force.
837+
"""
838+
839+
return super().evaluate(state.steinvi_state, *args, **kwargs)
840+
841+
def setup_run(self, rng_key, num_steps, args, init_state, kwargs):
842+
if init_state is None:
843+
state = self.init(rng_key, num_steps, *args, **kwargs)
844+
else:
845+
assert isinstance(init_state, ASVGDState), (
846+
"The init_state much be an instance of ASVGDState"
847+
)
848+
state = init_state
849+
850+
loss = self.evaluate(state, *args, **kwargs)
851+
852+
info_init = (state, loss)
853+
754854
def step(info):
755-
t, iinfo = info[0], info[-1]
756-
self.loss_temperature = cyc_fn(t) / float(self.num_stein_particles)
757-
return (t + 1, istep(iinfo))
855+
state, loss = info
856+
return self.update(state, *args, **kwargs) # uses closure!
857+
858+
def collect(info):
859+
_, loss = info
860+
return loss
861+
862+
def extract(info):
863+
state, _ = info
864+
return state
758865

759866
def diagnostic(info):
760-
_, iinfo = info
761-
return idiag(iinfo)
867+
_, loss = info
868+
return f"Stein force {loss:.2f}."
762869

763-
def collect(info):
764-
_, iinfo = info
765-
return icol(iinfo)
870+
return step, diagnostic, collect, extract, info_init
766871

767-
def extract_state(info):
768-
_, iinfo = info
769-
return iext(iinfo)
872+
def run(
873+
self,
874+
rng_key,
875+
num_steps,
876+
*args,
877+
progress_bar=True,
878+
init_state=None,
879+
**kwargs,
880+
):
881+
"""Run ASVGD inference.
770882
771-
info_init = (0, iinit)
772-
return step, diagnostic, collect, extract_state, info_init
883+
:param jax.random.PRNGKey rng_key: Random number generator seed.
884+
:param int num_steps: Number of steps to optimize.
885+
:param *args: Positional arguments to the model and guide.
886+
:param bool progress_bar: Use a progress bar. Default is `True`.
887+
Inference is faster with `False`.
888+
:param SteinVIState init_state: Initial state of inference.
889+
Default is ``None``, which will initialize using init before running inference.
890+
:param **kwargs: Keyword arguments to the model and guide.
891+
"""
892+
step, diagnostic, collect, extract, init_info = self.setup_run(
893+
rng_key, num_steps, args, init_state, kwargs
894+
)
895+
896+
auxiliaries, last_res = fori_collect(
897+
0,
898+
num_steps,
899+
step,
900+
init_info,
901+
progbar=progress_bar,
902+
transform=collect,
903+
return_last_val=True,
904+
diagnostics_fn=diagnostic if progress_bar else None,
905+
)
906+
907+
state = extract(last_res)
908+
return SteinVIRunResult(self.get_params(state), state, auxiliaries)

test/contrib/einstein/test_steinvi.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from collections import namedtuple
55
from functools import partial
6-
import os
76
import string
87

98
import numpy as np
@@ -120,13 +119,17 @@ def model(features, labels):
120119
@pytest.mark.parametrize("kernel", KERNELS)
121120
@pytest.mark.parametrize("problem", (uniform_normal, regression))
122121
@pytest.mark.parametrize("method", ("ASVGD", "SVGD", "SteinVI"))
123-
@pytest.mark.xfail(
124-
os.getenv("JAX_CHECK_TRACER_LEAKS") == "1", reason="Expected tracer leak"
125-
)
126122
def test_run_smoke(kernel, problem, method):
127123
true_coefs, data, model = problem()
128124
if method == "ASVGD":
129-
stein = ASVGD(model, Adam(1e-1), kernel, num_stein_particles=1)
125+
stein = ASVGD(
126+
model,
127+
Adam(1e-1),
128+
kernel,
129+
num_stein_particles=1,
130+
num_cycles=1,
131+
transition_speed=1,
132+
)
130133
if method == "SVGD":
131134
stein = SVGD(model, Adam(1e-1), kernel, num_stein_particles=1)
132135
if method == "SteinVI":

0 commit comments

Comments
 (0)