 from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import broadcast_params
-from pytensor.tensor.shape import Shape, Shape_i
+from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor,
     AdvancedSubtensor1,
@@ -115,23 +115,10 @@ def local_dimshuffle_rv_lift(fgraph, node):

     For example, ``normal(mu, std).T == normal(mu.T, std.T)``.

-    The basic idea behind this rewrite is that we need to separate the
-    ``DimShuffle``-ing into distinct ``DimShuffle``s that each occur in two
-    distinct sub-spaces: the (set of independent) parameters and ``size``
-    (i.e. replications) sub-spaces.
-
-    If a ``DimShuffle`` exchanges dimensions across those two sub-spaces, then we
-    don't do anything.
-
-    Otherwise, if the ``DimShuffle`` only exchanges dimensions within each of
-    those sub-spaces, we can break it apart and apply the parameter-space
-    ``DimShuffle`` to the distribution parameters, and then apply the
-    replications-space ``DimShuffle`` to the ``size`` tuple. The latter is a
-    particularly simple rearranging of a tuple, but the former requires a
-    little more work.
-
-    TODO: Currently, multivariate support for this rewrite is disabled.
+    This rewrite is only applicable when the ``DimShuffle`` operation does
+    not affect support dimensions.
+
+    TODO: Support dimension dropping

     """

     ds_op = node.op
@@ -142,128 +129,67 @@ def local_dimshuffle_rv_lift(fgraph, node):
     base_rv = node.inputs[0]
     rv_node = base_rv.owner

-    if not (
-        rv_node and isinstance(rv_node.op, RandomVariable) and rv_node.op.ndim_supp == 0
-    ):
+    if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False

-    # If no one else is using the underlying `RandomVariable`, then we can
-    # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(base_rv, node, fgraph):
+    # DimShuffles that drop dimensions are not supported yet
+    if ds_op.drop:
         return False

     rv_op = rv_node.op
     rng, size, dtype, *dist_params = rv_node.inputs
+    rv = rv_node.default_output()

-    # We need to know the dimensions that were *not* added by the `size`
-    # parameter (i.e. the dimensions corresponding to independent variates with
-    # different parameter values)
-    num_ind_dims = None
-    if len(dist_params) == 1:
-        num_ind_dims = dist_params[0].ndim
-    else:
-        # When there is more than one distribution parameter, assume that all
-        # of them will broadcast to the maximum number of dimensions
-        num_ind_dims = max(d.ndim for d in dist_params)
-
-    # If the indices in `ds_new_order` are entirely within the replication
-    # indices group or the independent variates indices group, then we can apply
-    # this rewrite.
-
-    ds_new_order = ds_op.new_order
-    # Create a map from old index order to new/`DimShuffled` index order
-    dim_orders = [(n, d) for n, d in enumerate(ds_new_order) if isinstance(d, int)]
-
-    # Find the index at which the replications/independents split occurs
-    reps_ind_split_idx = len(dim_orders) - (num_ind_dims + rv_op.ndim_supp)
-
-    ds_reps_new_dims = dim_orders[:reps_ind_split_idx]
-    ds_ind_new_dims = dim_orders[reps_ind_split_idx:]
-    ds_in_ind_space = ds_ind_new_dims and all(
-        d >= reps_ind_split_idx for n, d in ds_ind_new_dims
-    )
+    # Check that the DimShuffle does not affect support dims
+    supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
+    shuffled_dims = {dim for i, dim in enumerate(ds_op.shuffle) if dim != i}
+    augmented_dims = set(d - rv_op.ndim_supp for d in ds_op.augment)
+    if (shuffled_dims | augmented_dims) & supp_dims:
+        return False

-    if ds_in_ind_space or (not ds_ind_new_dims and not ds_reps_new_dims):
+    # If no one else is using the underlying RandomVariable, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(base_rv, node, fgraph):
+        return False

-        # Update the `size` array to reflect the `DimShuffle`d dimensions,
-        # since the trailing dimensions in `size` represent the independent
-        # variates dimensions (for univariate distributions, at least)
-        has_size = get_vector_length(size) > 0
-        new_size = (
-            [constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order]
-            if has_size
-            else size
+    batched_dims = rv.ndim - rv_op.ndim_supp
+    batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)
+
+    # Make size explicit
+    missing_size_dims = batched_dims - get_vector_length(size)
+    if missing_size_dims > 0:
+        full_size = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
+        size = full_size[:missing_size_dims] + tuple(size)
+
+    # Update the size to reflect the DimShuffled dimensions
+    new_size = [
+        constant(1, dtype="int64") if o == "x" else size[o]
+        for o in batched_dims_ds_order
+    ]
+
+    # Update the params to reflect the DimShuffled dimensions
+    new_dist_params = []
+    for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
+        # Add broadcastable dimensions to the parameters that would have been expanded by the size
+        padleft = batched_dims - (param.ndim - param_ndim_supp)
+        if padleft > 0:
+            param = shape_padleft(param, padleft)
+
+        # Add the parameter support dimension indexes to the batched dimensions DimShuffle
+        param_new_order = batched_dims_ds_order + tuple(
+            range(batched_dims, batched_dims + param_ndim_supp)
         )
+        new_dist_params.append(param.dimshuffle(param_new_order))

-        # Compute the new axes parameter(s) for the `DimShuffle` that will be
-        # applied to the `RandomVariable` parameters (they need to be offset)
-        if ds_ind_new_dims:
-            rv_params_new_order = [
-                d - reps_ind_split_idx if isinstance(d, int) else d
-                for d in ds_new_order[ds_ind_new_dims[0][0] :]
-            ]
-
-            if not has_size and len(ds_new_order[: ds_ind_new_dims[0][0]]) > 0:
-                # Additional broadcast dimensions need to be added to the
-                # independent dimensions (i.e. parameters), since there's no
-                # `size` to which they can be added
-                rv_params_new_order = (
-                    list(ds_new_order[: ds_ind_new_dims[0][0]]) + rv_params_new_order
-                )
-        else:
-            # This case is reached when, for example, `ds_new_order` only
-            # consists of new broadcastable dimensions (i.e. `"x"`s)
-            rv_params_new_order = ds_new_order
-
-        # Lift the `DimShuffle`s into the parameters
-        # NOTE: The parameters might not be broadcasted against each other, so
-        # we can only apply the parts of the `DimShuffle` that are relevant.
-        new_dist_params = []
-        for d in dist_params:
-            if d.ndim < len(ds_ind_new_dims):
-                _rv_params_new_order = [
-                    o
-                    for o in rv_params_new_order
-                    if (isinstance(o, int) and o < d.ndim) or o == "x"
-                ]
-            else:
-                _rv_params_new_order = rv_params_new_order
-
-            new_dist_params.append(
-                type(ds_op)(d.type.broadcastable, _rv_params_new_order)(d)
-            )
-        new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
-
-        if config.compute_test_value != "off":
-            compute_test_value(new_node)
-
-        out = new_node.outputs[1]
-        if base_rv.name:
-            out.name = f"{base_rv.name}_lifted"
-        return [out]
+    new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)

-    ds_in_reps_space = ds_reps_new_dims and all(
-        d < reps_ind_split_idx for n, d in ds_reps_new_dims
-    )
-
-    if ds_in_reps_space:
-        # Update the `size` array to reflect the `DimShuffle`d dimensions.
-        # There should be no need to `DimShuffle` now.
-        new_size = [
-            constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order
-        ]
-
-        new_node = rv_op.make_node(rng, new_size, dtype, *dist_params)
-
-        if config.compute_test_value != "off":
-            compute_test_value(new_node)
-
-        out = new_node.outputs[1]
-        if base_rv.name:
-            out.name = f"{base_rv.name}_lifted"
-        return [out]
+    if config.compute_test_value != "off":
+        compute_test_value(new_node)

-    return False
+    out = new_node.outputs[1]
+    if base_rv.name:
+        out.name = f"{base_rv.name}_lifted"
+    return [out]


 @node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
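
As context for the diff above, here is a minimal sketch of the behavior this rewrite targets. It uses the public `pytensor.tensor.random.normal` API; whether the lift actually fires on a given graph depends on the rewrite being registered in the compilation mode, so treat this as illustrative rather than a guaranteed transformation:

```python
import pytensor
import pytensor.tensor as pt

mu = pt.matrix("mu")
std = pt.matrix("std")

# normal is univariate (ndim_supp == 0), so every output dimension is a
# batch dimension and a transpose only permutes batch dimensions.
x = pt.random.normal(mu, std)

# local_dimshuffle_rv_lift may rewrite normal(mu, std).T into
# normal(mu.T, std.T), moving the DimShuffle onto the parameters.
fn = pytensor.function([mu, std], x.T)
```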
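Conversely, a sketch of the case the new `supp_dims` check guards against, assuming `pytensor.tensor.random.dirichlet` (which has `ndim_supp == 1`): a `DimShuffle` that touches a support dimension should make the rewrite bail out with `return False`, leaving the graph unchanged:

```python
import pytensor.tensor as pt

alphas = pt.matrix("alphas")     # e.g. shape (5, 3)

# dirichlet is multivariate (ndim_supp == 1): the last axis of each
# draw is the support dimension.
x = pt.random.dirichlet(alphas)  # draws have shape (5, 3)

# x.T moves the support dimension (axis 1) to axis 0, so the shuffled
# dims intersect supp_dims and the lift does not apply here.
y = x.T
```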