
Commit 2ebfbf1

Expand and simplify local_dimshuffle_rv_lift
* The rewrite no longer bails out when the DimShuffle affects both the parameters' own batch dimensions and the replication dimensions introduced by the size argument. This requires: 1) adding broadcastable dimensions to the parameters, which should be "cost-free" and would need to be done in the `perform` method anyway; and 2) extending size to incorporate the implicit batch dimensions coming from the parameters, which requires computing the shape that results from broadcasting the parameters. It is unclear whether this is less performant, because the `perform` method can now simply broadcast each parameter to the size, instead of having to broadcast the parameters together.
* The rewrite now works with multivariate RVs.
* The rewrite bails out when dimensions are dropped by the DimShuffle. This case was not correctly handled by the previous rewrite.
1 parent 8e61224 commit 2ebfbf1
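
For intuition, here is a minimal NumPy sketch (not part of the commit; the array values and the transpose order are made up for illustration) of the identity the lifted rewrite relies on: permuting the batch dimensions of the draws after broadcasting a parameter to `size` gives the same array as first padding and dimshuffling the parameter and then broadcasting it to the permuted size.

import numpy as np

mu = np.array([[-1.0, 20.0], [300.0, -4000.0]])  # a batch of means, shape (2, 2)
size = (3, 2, 2)                                 # explicit size with one extra leading batch dim

# DimShuffle (1, 2, 0) applied to the batch dimensions after broadcasting ...
lhs = np.broadcast_to(mu, size).transpose(1, 2, 0)

# ... equals padding the parameter to the full batch ndim, dimshuffling it,
# and broadcasting to the permuted size instead.
mu_padded = mu[None, :, :]                       # shape (1, 2, 2), the "cost-free" expansion
permuted_size = tuple(size[o] for o in (1, 2, 0))
rhs = np.broadcast_to(mu_padded.transpose(1, 2, 0), permuted_size)

assert np.array_equal(lhs, rhs)                  # both have shape (2, 2, 3)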

File tree: 2 files changed (+112 −129 lines)

pytensor/tensor/random/rewriting.py

+52 −126
@@ -8,7 +8,7 @@
 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import broadcast_params
-from pytensor.tensor.shape import Shape, Shape_i
+from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor,
     AdvancedSubtensor1,
@@ -115,23 +115,10 @@ def local_dimshuffle_rv_lift(fgraph, node):

     For example, ``normal(mu, std).T == normal(mu.T, std.T)``.

-    The basic idea behind this rewrite is that we need to separate the
-    ``DimShuffle``-ing into distinct ``DimShuffle``s that each occur in two
-    distinct sub-spaces: the (set of independent) parameters and ``size``
-    (i.e. replications) sub-spaces.
-
-    If a ``DimShuffle`` exchanges dimensions across those two sub-spaces, then we
-    don't do anything.
-
-    Otherwise, if the ``DimShuffle`` only exchanges dimensions within each of
-    those sub-spaces, we can break it apart and apply the parameter-space
-    ``DimShuffle`` to the distribution parameters, and then apply the
-    replications-space ``DimShuffle`` to the ``size`` tuple. The latter is a
-    particularly simple rearranging of a tuple, but the former requires a
-    little more work.
-
-    TODO: Currently, multivariate support for this rewrite is disabled.
+    This rewrite is only applicable when the Dimshuffle operation does
+    not affect support dimensions.

+    TODO: Support dimension dropping
     """

     ds_op = node.op
@@ -142,128 +129,67 @@ def local_dimshuffle_rv_lift(fgraph, node):
     base_rv = node.inputs[0]
     rv_node = base_rv.owner

-    if not (
-        rv_node and isinstance(rv_node.op, RandomVariable) and rv_node.op.ndim_supp == 0
-    ):
+    if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False

-    # If no one else is using the underlying `RandomVariable`, then we can
-    # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(base_rv, node, fgraph):
+    # Dimshuffle which drop dimensions not supported yet
+    if ds_op.drop:
         return False

     rv_op = rv_node.op
     rng, size, dtype, *dist_params = rv_node.inputs
+    rv = rv_node.default_output()

-    # We need to know the dimensions that were *not* added by the `size`
-    # parameter (i.e. the dimensions corresponding to independent variates with
-    # different parameter values)
-    num_ind_dims = None
-    if len(dist_params) == 1:
-        num_ind_dims = dist_params[0].ndim
-    else:
-        # When there is more than one distribution parameter, assume that all
-        # of them will broadcast to the maximum number of dimensions
-        num_ind_dims = max(d.ndim for d in dist_params)
-
-    # If the indices in `ds_new_order` are entirely within the replication
-    # indices group or the independent variates indices group, then we can apply
-    # this rewrite.
-
-    ds_new_order = ds_op.new_order
-    # Create a map from old index order to new/`DimShuffled` index order
-    dim_orders = [(n, d) for n, d in enumerate(ds_new_order) if isinstance(d, int)]
-
-    # Find the index at which the replications/independents split occurs
-    reps_ind_split_idx = len(dim_orders) - (num_ind_dims + rv_op.ndim_supp)
-
-    ds_reps_new_dims = dim_orders[:reps_ind_split_idx]
-    ds_ind_new_dims = dim_orders[reps_ind_split_idx:]
-    ds_in_ind_space = ds_ind_new_dims and all(
-        d >= reps_ind_split_idx for n, d in ds_ind_new_dims
-    )
+    # Check that Dimshuffle does not affect support dims
+    supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
+    shuffled_dims = {dim for i, dim in enumerate(ds_op.shuffle) if dim != i}
+    augmented_dims = set(d - rv_op.ndim_supp for d in ds_op.augment)
+    if (shuffled_dims | augmented_dims) & supp_dims:
+        return False

-    if ds_in_ind_space or (not ds_ind_new_dims and not ds_reps_new_dims):
+    # If no one else is using the underlying RandomVariable, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(base_rv, node, fgraph):
+        return False

-        # Update the `size` array to reflect the `DimShuffle`d dimensions,
-        # since the trailing dimensions in `size` represent the independent
-        # variates dimensions (for univariate distributions, at least)
-        has_size = get_vector_length(size) > 0
-        new_size = (
-            [constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order]
-            if has_size
-            else size
+    batched_dims = rv.ndim - rv_op.ndim_supp
+    batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)
+
+    # Make size explicit
+    missing_size_dims = batched_dims - get_vector_length(size)
+    if missing_size_dims > 0:
+        full_size = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
+        size = full_size[:missing_size_dims] + tuple(size)
+
+    # Update the size to reflect the DimShuffled dimensions
+    new_size = [
+        constant(1, dtype="int64") if o == "x" else size[o]
+        for o in batched_dims_ds_order
+    ]
+
+    # Updates the params to reflect the Dimshuffled dimensions
+    new_dist_params = []
+    for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
+        # Add broadcastable dimensions to the parameters that would have been expanded by the size
+        padleft = batched_dims - (param.ndim - param_ndim_supp)
+        if padleft > 0:
+            param = shape_padleft(param, padleft)
+
+        # Add the parameter support dimension indexes to the batched dimensions Dimshuffle
+        param_new_order = batched_dims_ds_order + tuple(
+            range(batched_dims, batched_dims + param_ndim_supp)
        )
+        new_dist_params.append(param.dimshuffle(param_new_order))

-        # Compute the new axes parameter(s) for the `DimShuffle` that will be
-        # applied to the `RandomVariable` parameters (they need to be offset)
-        if ds_ind_new_dims:
-            rv_params_new_order = [
-                d - reps_ind_split_idx if isinstance(d, int) else d
-                for d in ds_new_order[ds_ind_new_dims[0][0] :]
-            ]
-
-            if not has_size and len(ds_new_order[: ds_ind_new_dims[0][0]]) > 0:
-                # Additional broadcast dimensions need to be added to the
-                # independent dimensions (i.e. parameters), since there's no
-                # `size` to which they can be added
-                rv_params_new_order = (
-                    list(ds_new_order[: ds_ind_new_dims[0][0]]) + rv_params_new_order
-                )
-        else:
-            # This case is reached when, for example, `ds_new_order` only
-            # consists of new broadcastable dimensions (i.e. `"x"`s)
-            rv_params_new_order = ds_new_order
-
-        # Lift the `DimShuffle`s into the parameters
-        # NOTE: The parameters might not be broadcasted against each other, so
-        # we can only apply the parts of the `DimShuffle` that are relevant.
-        new_dist_params = []
-        for d in dist_params:
-            if d.ndim < len(ds_ind_new_dims):
-                _rv_params_new_order = [
-                    o
-                    for o in rv_params_new_order
-                    if (isinstance(o, int) and o < d.ndim) or o == "x"
-                ]
-            else:
-                _rv_params_new_order = rv_params_new_order
-
-            new_dist_params.append(
-                type(ds_op)(d.type.broadcastable, _rv_params_new_order)(d)
-            )
-        new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
-
-        if config.compute_test_value != "off":
-            compute_test_value(new_node)
-
-        out = new_node.outputs[1]
-        if base_rv.name:
-            out.name = f"{base_rv.name}_lifted"
-        return [out]
+    new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)

-    ds_in_reps_space = ds_reps_new_dims and all(
-        d < reps_ind_split_idx for n, d in ds_reps_new_dims
-    )
-
-    if ds_in_reps_space:
-        # Update the `size` array to reflect the `DimShuffle`d dimensions.
-        # There should be no need to `DimShuffle` now.
-        new_size = [
-            constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order
-        ]
-
-        new_node = rv_op.make_node(rng, new_size, dtype, *dist_params)
-
-        if config.compute_test_value != "off":
-            compute_test_value(new_node)
-
-        out = new_node.outputs[1]
-        if base_rv.name:
-            out.name = f"{base_rv.name}_lifted"
-        return [out]
+    if config.compute_test_value != "off":
+        compute_test_value(new_node)

-    return False
+    out = new_node.outputs[1]
+    if base_rv.name:
+        out.name = f"{base_rv.name}_lifted"
+    return [out]


 @node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
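
As a sanity check on the new bookkeeping, the following plain-Python sketch (illustrative only, not code from the commit) traces how supp_dims, batched_dims_ds_order, new_size and param_new_order come out for the multivariate_normal test case added below, with mu of shape (2, 2), cov of shape (2, 2), size (3, 2) and new_order (1, 0, "x", 2). For that RV, ndim_supp is 1 and the mean parameter contributes one support dimension.

# Values mirror the (1, 0, "x", 2) multivariate_normal test case below.
ndim_supp = 1
mu_ndims_param = 1               # the mean has one support dimension
rv_ndim = 3                      # size (3, 2) plus one support dim -> draws of shape (3, 2, 2)
new_order = (1, 0, "x", 2)
size = (3, 2)

supp_dims = set(range(rv_ndim - ndim_supp, rv_ndim))                        # {2}
batched_dims = rv_ndim - ndim_supp                                          # 2
batched_dims_ds_order = tuple(o for o in new_order if o not in supp_dims)   # (1, 0, "x")

# The size is permuted/augmented only along the batch dimensions.
new_size = tuple(1 if o == "x" else size[o] for o in batched_dims_ds_order)
assert new_size == (2, 3, 1)

# mu has ndim 2 with one support dim, so one broadcastable batch dim is padded on the left,
# and its support axis index is appended to the batch-dim order.
padleft = batched_dims - (2 - mu_ndims_param)
assert padleft == 1
param_new_order = batched_dims_ds_order + tuple(
    range(batched_dims, batched_dims + mu_ndims_param)
)
assert param_new_order == (1, 0, "x", 2)   # applied to the padded mu of shape (1, 2, 2)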

tests/tensor/random/test_rewriting.py

+60 −3
@@ -9,6 +9,7 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
+from pytensor.tensor import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.random.basic import (
     dirichlet,
@@ -42,7 +43,11 @@ def apply_local_rewrite_to_rv(

     size_at = []
     for s in size:
-        s_at = iscalar()
+        # To test DimShuffle with dropping dims we need that size dimension to be constant
+        if s == 1:
+            s_at = constant(np.array(1, dtype="int32"))
+        else:
+            s_at = iscalar()
         s_at.tag.test_value = s
         size_at.append(s_at)

@@ -314,7 +319,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
         ),
         (
             ("x", 1, 0, 2, "x"),
-            False,
+            True,
             normal,
             (
                 np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
@@ -332,7 +337,30 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
             (3, 2, 2),
             1,
         ),
-        # A multi-dimensional case
+        # Supported multi-dimensional cases
+        (
+            (1, 0, 2),
+            True,
+            multivariate_normal,
+            (
+                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
+                np.eye(2).astype(config.floatX) * 1e-6,
+            ),
+            (3, 2),
+            1e-3,
+        ),
+        (
+            (1, 0, "x", 2),
+            True,
+            multivariate_normal,
+            (
+                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
+                np.eye(2).astype(config.floatX) * 1e-6,
+            ),
+            (3, 2),
+            1e-3,
+        ),
+        # Not supported multi-dimensional cases where dimshuffle affects the support dimensionality
         (
             (0, 2, 1),
             False,
@@ -344,6 +372,35 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
             (3, 2),
             1e-3,
         ),
+        (
+            (0, 1, 2, "x"),
+            False,
+            multivariate_normal,
+            (
+                np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
+                np.eye(2).astype(config.floatX) * 1e-6,
+            ),
+            (3, 2),
+            1e-3,
+        ),
+        pytest.param(
+            (1,),
+            True,
+            normal,
+            (0, 1),
+            (1, 2),
+            1e-3,
+            marks=pytest.mark.xfail(reason="Dropping dimensions not supported yet"),
+        ),
+        pytest.param(
+            (1,),
+            True,
+            normal,
+            ([[0, 0]], 1),
+            (1, 2),
+            1e-3,
+            marks=pytest.mark.xfail(reason="Dropping dimensions not supported yet"),
+        ),
     ],
 )
 @config.change_flags(compute_test_value_opt="raise", compute_test_value="raise")
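
For reference, a rough NumPy shape check (not part of the test suite; names and values are illustrative) of the multivariate cases added above: draws of multivariate_normal with a (2, 2) batch of means and size (3, 2) have shape (3, 2, 2); the lifted orders (1, 0, 2) and (1, 0, "x", 2) only touch the two batch dimensions, whereas (0, 2, 1) and (0, 1, 2, "x") move or extend the trailing support dimension, which is why those cases are expected not to be lifted.

import numpy as np

rng = np.random.default_rng(0)
# Batch of two 2-dimensional means; covariance shared across the batch.
mu = np.array([[-1.0, 20.0], [300.0, -4000.0]])
draws = rng.multivariate_normal(np.zeros(2), np.eye(2), size=(3, 2)) + mu

assert draws.shape == (3, 2, 2)                       # (*size, support)
assert draws.transpose(1, 0, 2).shape == (2, 3, 2)    # (1, 0, 2): batch dims swapped, support dim untouched
assert draws.transpose(0, 2, 1).shape == (3, 2, 2)    # (0, 2, 1): support dim moved -> not liftable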
