20
20
use std:: collections:: { HashMap , HashSet } ;
21
21
use std:: fmt:: { self , Debug , Display , Formatter } ;
22
22
use std:: hash:: { Hash , Hasher } ;
23
- use std:: sync:: Arc ;
23
+ use std:: sync:: { Arc , OnceLock } ;
24
24
25
25
use super :: dml:: CopyTo ;
26
26
use super :: DdlStatement ;
@@ -45,7 +45,8 @@ use crate::{
45
45
46
46
use arrow:: datatypes:: { DataType , Field , Schema , SchemaRef } ;
47
47
use datafusion_common:: tree_node:: {
48
- Transformed , TransformedResult , TreeNode , TreeNodeRecursion , TreeNodeVisitor ,
48
+ Transformed , TransformedIterator , TransformedResult , TreeNode , TreeNodeRecursion ,
49
+ TreeNodeVisitor ,
49
50
} ;
50
51
use datafusion_common:: {
51
52
aggregate_functional_dependencies, internal_err, plan_err, Column , Constraints ,
@@ -1131,6 +1132,202 @@ impl LogicalPlan {
1131
1132
} ) ?;
1132
1133
Ok ( ( ) )
1133
1134
}
1135
+ }
1136
+
1137
+ // TODO put this somewhere better than here
1138
+
1139
+ /// A placeholder plan that is temporarily left in an `Arc` while the children of a
1140
+ /// [`LogicalPlan`] are rewritten. This is necessary to ensure that the `LogicalPlan`
1141
+ /// is always in a valid state (from the Rust perspective). Created lazily on first use.
1142
+ static PLACEHOLDER : OnceLock < Arc < LogicalPlan > > = OnceLock :: new ( ) ;
1143
+
1144
+ /// its inputs, so this code would not be needed. However, for now we try and
1145
+ /// unwrap the `Arc` which avoids `clone`ing in most cases.
1146
+ ///
1147
+ /// On error, node be left with a placeholder logical plan
1148
+ fn rewrite_arc < F > (
1149
+ node : & mut Arc < LogicalPlan > ,
1150
+ mut f : F ,
1151
+ ) -> Result < Transformed < & mut Arc < LogicalPlan > > >
1152
+ where
1153
+ F : FnMut ( LogicalPlan ) -> Result < Transformed < LogicalPlan > > ,
1154
+ {
1155
+ // We need to leave a valid node in the Arc, while we rewrite the existing
1156
+ // one, so use a single global static placeholder node
1157
+ let mut new_node = PLACEHOLDER
1158
+ . get_or_init ( || {
1159
+ Arc :: new ( LogicalPlan :: EmptyRelation ( EmptyRelation {
1160
+ produce_one_row : false ,
1161
+ schema : DFSchemaRef :: new ( DFSchema :: empty ( ) ) ,
1162
+ } ) )
1163
+ } )
1164
+ . clone ( ) ;
1165
+
1166
+ // take the old value out of the Arc
1167
+ std:: mem:: swap ( node, & mut new_node) ;
1168
+
1169
+ // try to update existing node, if it isn't shared with others
1170
+ let new_node = Arc :: try_unwrap ( new_node)
1171
+ // if None is returned, there is another reference to this
1172
+ // LogicalPlan, so we must clone instead
1173
+ . unwrap_or_else ( |node| node. as_ref ( ) . clone ( ) ) ;
1174
+
1175
+ // apply the actual transform
1176
+ let result = f ( new_node) ?;
1177
+
1178
+ // put the new value back into the Arc
1179
+ let mut new_node = Arc :: new ( result. data ) ;
1180
+ std:: mem:: swap ( node, & mut new_node) ;
1181
+
1182
+ // return the `node` back
1183
+ Ok ( Transformed :: new ( node, result. transformed , result. tnr ) )
1184
+ }
1185
+
1186
+ /// Rewrite the arc and discard the contents of Transformed
1187
+ fn rewrite_arc_no_data < F > ( node : & mut Arc < LogicalPlan > , f : F ) -> Result < Transformed < ( ) > >
1188
+ where
1189
+ F : FnMut ( LogicalPlan ) -> Result < Transformed < LogicalPlan > > ,
1190
+ {
1191
+ rewrite_arc ( node, f) . map ( |res| res. discard_data ( ) )
1192
+ }
1193
+
1194
+ /// Rewrites all inputs for an Extension node "in place"
1195
+ /// (it currently has to copy values because there are no APIs for in place modification)
1196
+ ///
1197
+ /// Should be removed when we have an API for in place modifications of the
1198
+ /// extension to avoid these copies
1199
+ fn rewrite_extension_inputs < F > (
1200
+ node : & mut Arc < dyn UserDefinedLogicalNode > ,
1201
+ f : F ,
1202
+ ) -> Result < Transformed < ( ) > >
1203
+ where
1204
+ F : FnMut ( LogicalPlan ) -> Result < Transformed < LogicalPlan > > ,
1205
+ {
1206
+ let Transformed {
1207
+ data : new_inputs,
1208
+ transformed,
1209
+ tnr,
1210
+ } = node
1211
+ . inputs ( )
1212
+ . into_iter ( )
1213
+ . cloned ( )
1214
+ . map_until_stop_and_collect ( f) ?;
1215
+
1216
+ let exprs = node. expressions ( ) ;
1217
+ let mut new_node = node. from_template ( & exprs, & new_inputs) ;
1218
+ std:: mem:: swap ( node, & mut new_node) ;
1219
+ Ok ( Transformed {
1220
+ data : ( ) ,
1221
+ transformed,
1222
+ tnr,
1223
+ } )
1224
+ }
1225
+
1226
+ impl LogicalPlan {
1227
+ /// applies `f` to each input of this plan node, rewriting them *in place.*
1228
+ ///
1229
+ /// # Notes
1230
+ /// Inputs include both direct children as well as any embedded subquery
1231
+ /// `LogicalPlan`s, for example such as are in [`Expr::Exists`].
1232
+ ///
1233
+ /// If `f` returns an `Err`, that Err is returned, and the inputs are left
1234
+ /// in a partially modified state
1235
+ pub ( crate ) fn rewrite_children < F > ( & mut self , mut f : F ) -> Result < Transformed < ( ) > >
1236
+ where
1237
+ F : FnMut ( Self ) -> Result < Transformed < Self > > ,
1238
+ {
1239
+ let children_result = match self {
1240
+ LogicalPlan :: Projection ( Projection { input, .. } ) => {
1241
+ rewrite_arc_no_data ( input, & mut f)
1242
+ }
1243
+ LogicalPlan :: Filter ( Filter { input, .. } ) => {
1244
+ rewrite_arc_no_data ( input, & mut f)
1245
+ }
1246
+ LogicalPlan :: Repartition ( Repartition { input, .. } ) => {
1247
+ rewrite_arc_no_data ( input, & mut f)
1248
+ }
1249
+ LogicalPlan :: Window ( Window { input, .. } ) => {
1250
+ rewrite_arc_no_data ( input, & mut f)
1251
+ }
1252
+ LogicalPlan :: Aggregate ( Aggregate { input, .. } ) => {
1253
+ rewrite_arc_no_data ( input, & mut f)
1254
+ }
1255
+ LogicalPlan :: Sort ( Sort { input, .. } ) => rewrite_arc_no_data ( input, & mut f) ,
1256
+ LogicalPlan :: Join ( Join { left, right, .. } ) => {
1257
+ let results = [ left, right]
1258
+ . into_iter ( )
1259
+ . map_until_stop_and_collect ( |input| rewrite_arc ( input, & mut f) ) ?;
1260
+ Ok ( results. discard_data ( ) )
1261
+ }
1262
+ LogicalPlan :: CrossJoin ( CrossJoin { left, right, .. } ) => {
1263
+ let results = [ left, right]
1264
+ . into_iter ( )
1265
+ . map_until_stop_and_collect ( |input| rewrite_arc ( input, & mut f) ) ?;
1266
+ Ok ( results. discard_data ( ) )
1267
+ }
1268
+ LogicalPlan :: Limit ( Limit { input, .. } ) => rewrite_arc_no_data ( input, & mut f) ,
1269
+ LogicalPlan :: Subquery ( Subquery { subquery, .. } ) => {
1270
+ rewrite_arc_no_data ( subquery, & mut f)
1271
+ }
1272
+ LogicalPlan :: SubqueryAlias ( SubqueryAlias { input, .. } ) => {
1273
+ rewrite_arc_no_data ( input, & mut f)
1274
+ }
1275
+ LogicalPlan :: Extension ( extension) => {
1276
+ rewrite_extension_inputs ( & mut extension. node , & mut f)
1277
+ }
1278
+ LogicalPlan :: Union ( Union { inputs, .. } ) => {
1279
+ let results = inputs
1280
+ . iter_mut ( )
1281
+ . map_until_stop_and_collect ( |input| rewrite_arc ( input, & mut f) ) ?;
1282
+ Ok ( results. discard_data ( ) )
1283
+ }
1284
+ LogicalPlan :: Distinct (
1285
+ Distinct :: All ( input) | Distinct :: On ( DistinctOn { input, .. } ) ,
1286
+ ) => rewrite_arc_no_data ( input, & mut f) ,
1287
+ LogicalPlan :: Explain ( explain) => {
1288
+ rewrite_arc_no_data ( & mut explain. plan , & mut f)
1289
+ }
1290
+ LogicalPlan :: Analyze ( analyze) => {
1291
+ rewrite_arc_no_data ( & mut analyze. input , & mut f)
1292
+ }
1293
+ LogicalPlan :: Dml ( write) => rewrite_arc_no_data ( & mut write. input , & mut f) ,
1294
+ LogicalPlan :: Copy ( copy) => rewrite_arc_no_data ( & mut copy. input , & mut f) ,
1295
+ LogicalPlan :: Ddl ( ddl) => {
1296
+ if let Some ( input) = ddl. input_mut ( ) {
1297
+ rewrite_arc_no_data ( input, & mut f)
1298
+ } else {
1299
+ Ok ( Transformed :: no ( ( ) ) )
1300
+ }
1301
+ }
1302
+ LogicalPlan :: Unnest ( Unnest { input, .. } ) => {
1303
+ rewrite_arc_no_data ( input, & mut f)
1304
+ }
1305
+ LogicalPlan :: Prepare ( Prepare { input, .. } ) => {
1306
+ rewrite_arc_no_data ( input, & mut f)
1307
+ }
1308
+ LogicalPlan :: RecursiveQuery ( RecursiveQuery {
1309
+ static_term,
1310
+ recursive_term,
1311
+ ..
1312
+ } ) => {
1313
+ let results = [ static_term, recursive_term]
1314
+ . into_iter ( )
1315
+ . map_until_stop_and_collect ( |input| rewrite_arc ( input, & mut f) ) ?;
1316
+ Ok ( results. discard_data ( ) )
1317
+ }
1318
+ // plans without inputs
1319
+ LogicalPlan :: TableScan { .. }
1320
+ | LogicalPlan :: Statement { .. }
1321
+ | LogicalPlan :: EmptyRelation { .. }
1322
+ | LogicalPlan :: Values { .. }
1323
+ | LogicalPlan :: DescribeTable ( _) => Ok ( Transformed :: no ( ( ) ) ) ,
1324
+ } ?;
1325
+
1326
+ // after visiting the actual children we we need to visit any subqueries
1327
+ // that are inside the expressions
1328
+ // children_result.and_then(|| self.rewrite_subqueries(&mut f))
1329
+ Ok ( children_result)
1330
+ }
1134
1331
1135
1332
/// Return a `LogicalPlan` with all placeholders (e.g $1 $2,
1136
1333
/// ...) replaced with corresponding values provided in
0 commit comments