@@ -69,7 +69,7 @@ use rustc_data_structures::small_c_str::SmallCStr;
69
69
use rustc_errors:: { DiagCtxt , FatalError , Level } ;
70
70
use rustc_fs_util:: { link_or_copy, path_to_c_string} ;
71
71
use rustc_middle:: ty:: TyCtxt ;
72
- use rustc_session:: config:: { self , Lto , OutputType , Passes , SplitDwarfKind , SwitchWithOptPath } ;
72
+ use rustc_session:: config:: { self , AutoDiff , Lto , OutputType , Passes , SplitDwarfKind , SwitchWithOptPath } ;
73
73
use rustc_session:: Session ;
74
74
use rustc_span:: symbol:: sym;
75
75
use rustc_span:: InnerSpan ;
@@ -707,7 +707,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
707
707
708
708
709
709
unsafe fn create_call < ' a > ( tgt : & ' a Value , src : & ' a Value , rev_mode : bool ,
710
- llmod : & ' a llvm:: Module , llcx : & llvm:: Context , size_positions : & [ usize ] ) {
710
+ llmod : & ' a llvm:: Module , llcx : & llvm:: Context , size_positions : & [ usize ] , ad : & [ AutoDiff ] ) {
711
711
712
712
// first, remove all calls from fnc
713
713
let bb = LLVMGetFirstBasicBlock ( tgt) ;
@@ -729,12 +729,7 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
729
729
let last_inst = LLVMRustGetLastInstruction ( bb) . unwrap ( ) ;
730
730
LLVMPositionBuilderAtEnd ( builder, bb) ;
731
731
732
- let safety_run_checks;
733
- if std:: env:: var ( "ENZYME_NO_SAFETY_CHECKS" ) . is_ok ( ) {
734
- safety_run_checks = false ;
735
- } else {
736
- safety_run_checks = true ;
737
- }
732
+ let safety_run_checks = !ad. contains ( & AutoDiff :: NoSafetyChecks ) ;
738
733
739
734
if inner_param_num == outer_param_num {
740
735
call_args = outer_args;
@@ -951,6 +946,7 @@ pub(crate) unsafe fn enzyme_ad(
951
946
diag_handler : & DiagCtxt ,
952
947
item : AutoDiffItem ,
953
948
logic_ref : EnzymeLogicRef ,
949
+ ad : & [ AutoDiff ] ,
954
950
) -> Result < ( ) , FatalError > {
955
951
let autodiff_mode = item. attrs . mode ;
956
952
let rust_name = item. source ;
@@ -1010,16 +1006,16 @@ pub(crate) unsafe fn enzyme_ad(
1010
1006
1011
1007
llvm:: set_strict_aliasing ( false ) ;
1012
1008
1013
- if std :: env :: var ( "ENZYME_PRINT_TA" ) . is_ok ( ) {
1009
+ if ad . contains ( & AutoDiff :: PrintTA ) {
1014
1010
llvm:: set_print_type ( true ) ;
1015
1011
}
1016
- if std :: env :: var ( "ENZYME_PRINT_AA" ) . is_ok ( ) {
1017
- llvm:: set_print_activity ( true ) ;
1012
+ if ad . contains ( & AutoDiff :: PrintTA ) {
1013
+ llvm:: set_print_type ( true ) ;
1018
1014
}
1019
- if std :: env :: var ( "ENZYME_PRINT_PERF" ) . is_ok ( ) {
1015
+ if ad . contains ( & AutoDiff :: PrintPerf ) {
1020
1016
llvm:: set_print_perf ( true ) ;
1021
1017
}
1022
- if std :: env :: var ( "ENZYME_PRINT" ) . is_ok ( ) {
1018
+ if ad . contains ( & AutoDiff :: Print ) {
1023
1019
llvm:: set_print ( true ) ;
1024
1020
}
1025
1021
@@ -1062,7 +1058,7 @@ pub(crate) unsafe fn enzyme_ad(
1062
1058
let f_return_type = LLVMGetReturnType ( LLVMGlobalGetValueType ( res) ) ;
1063
1059
1064
1060
let rev_mode = item. attrs . mode == DiffMode :: Reverse ;
1065
- create_call ( target_fnc, res, rev_mode, llmod, llcx, & size_positions) ;
1061
+ create_call ( target_fnc, res, rev_mode, llmod, llcx, & size_positions, ad ) ;
1066
1062
// TODO: implement drop for wrapper type?
1067
1063
FreeTypeAnalysis ( type_analysis) ;
1068
1064
@@ -1087,7 +1083,9 @@ pub(crate) unsafe fn differentiate(
1087
1083
1088
1084
llvm:: set_strict_aliasing ( false ) ;
1089
1085
1090
- if std:: env:: var ( "ENZYME_LOOSE_TYPES" ) . is_ok ( ) {
1086
+ let ad = & config. autodiff ;
1087
+
1088
+ if ad. contains ( & AutoDiff :: LooseTypes ) {
1091
1089
dbg ! ( "Setting loose types to true" ) ;
1092
1090
llvm:: set_loose_types ( true ) ;
1093
1091
}
@@ -1110,41 +1108,42 @@ pub(crate) unsafe fn differentiate(
1110
1108
// trigger the LLVMEnzyme pass to run on them, if we invoke the opt binary.
1111
1109
// This is super helpfull if we want to create a MWE bug reproducer, e.g. to run in
1112
1110
// Enzyme's compiler explorer. TODO: Can we run llvm-extract on the module to remove all other functions?
1113
- if std :: env :: var ( "ENZYME_OPT" ) . is_ok ( ) {
1111
+ if ad . contains ( & AutoDiff :: OPT ) {
1114
1112
dbg ! ( "Enable extra debug helper to debug Enzyme through the opt plugin" ) ;
1115
1113
crate :: builder:: add_opt_dbg_helper ( llmod, llcx, fn_def, item. attrs . clone ( ) , i) ;
1116
1114
}
1117
1115
}
1118
1116
1119
- if std :: env :: var ( "ENZYME_PRINT_MOD_BEFORE" ) . is_ok ( ) || std :: env :: var ( "ENZYME_OPT" ) . is_ok ( ) {
1117
+ if ad . contains ( & AutoDiff :: PrintModBefore ) || ad . contains ( & AutoDiff :: OPT ) {
1120
1118
unsafe {
1121
1119
LLVMDumpModule ( llmod) ;
1122
1120
}
1123
1121
}
1124
1122
1125
- if std :: env :: var ( "ENZYME_INLINE" ) . is_ok ( ) {
1123
+ if ad . contains ( & AutoDiff :: Inline ) {
1126
1124
dbg ! ( "Setting inline to true" ) ;
1127
1125
llvm:: set_inline ( true ) ;
1128
1126
}
1129
1127
1130
- if std:: env:: var ( "ENZYME_TT_DEPTH" ) . is_ok ( ) {
1131
- let depth = std:: env:: var ( "ENZYME_TT_DEPTH" ) . unwrap ( ) ;
1132
- let depth = depth. parse :: < u64 > ( ) . unwrap ( ) ;
1133
- assert ! ( depth >= 1 ) ;
1134
- llvm:: set_max_int_offset ( depth) ;
1135
- }
1136
- if std:: env:: var ( "ENZYME_TT_WIDTH" ) . is_ok ( ) {
1137
- let width = std:: env:: var ( "ENZYME_TT_WIDTH" ) . unwrap ( ) ;
1138
- let width = width. parse :: < u64 > ( ) . unwrap ( ) ;
1139
- assert ! ( width >= 1 ) ;
1140
- llvm:: set_max_type_offset ( width) ;
1141
- }
1142
-
1143
- if std:: env:: var ( "ENZYME_RUNTIME_ACTIVITY" ) . is_ok ( ) {
1128
+ if ad. contains ( & AutoDiff :: RuntimeActivity ) {
1144
1129
dbg ! ( "Setting runtime activity check to true" ) ;
1145
1130
llvm:: set_runtime_activity_check ( true ) ;
1146
1131
}
1147
1132
1133
+ for val in ad {
1134
+ match & val {
1135
+ AutoDiff :: TTDepth ( depth) => {
1136
+ assert ! ( * depth >= 1 ) ;
1137
+ llvm:: set_max_int_offset ( * depth) ;
1138
+ }
1139
+ AutoDiff :: TTWidth ( width) => {
1140
+ assert ! ( * width >= 1 ) ;
1141
+ llvm:: set_max_type_offset ( * width) ;
1142
+ }
1143
+ _ => { } ,
1144
+ }
1145
+ } ;
1146
+
1148
1147
let differentiate = !diff_items. is_empty ( ) ;
1149
1148
let mut first_order_items: Vec < AutoDiffItem > = vec ! [ ] ;
1150
1149
let mut higher_order_items: Vec < AutoDiffItem > = vec ! [ ] ;
@@ -1157,29 +1156,29 @@ pub(crate) unsafe fn differentiate(
1157
1156
}
1158
1157
}
1159
1158
1160
- let mut fnc_opt = false ;
1161
- if std:: env:: var ( "ENZYME_ENABLE_FNC_OPT" ) . is_ok ( ) {
1162
- dbg ! ( "Enable extra optimizations for Enzyme" ) ;
1163
- fnc_opt = true ;
1164
- }
1159
+
1160
+ let fnc_opt = ad. contains ( & AutoDiff :: EnableFncOpt ) ;
1165
1161
1166
1162
// If a function is a base for some higher order ad, always optimize
1167
1163
let fnc_opt_base = true ;
1168
1164
let logic_ref_opt: EnzymeLogicRef = CreateEnzymeLogic ( fnc_opt_base as u8 ) ;
1169
1165
1170
1166
for item in first_order_items {
1171
- let res = enzyme_ad ( llmod, llcx, & diag_handler, item, logic_ref_opt) ;
1167
+ let res = enzyme_ad ( llmod, llcx, & diag_handler, item, logic_ref_opt, ad ) ;
1172
1168
assert ! ( res. is_ok( ) ) ;
1173
1169
}
1174
1170
1175
1171
// For the rest, follow the user choice on debug vs release.
1176
1172
// Reuse the opt one if possible for better compile time (Enzyme internal caching).
1177
1173
let logic_ref = match fnc_opt {
1178
- true => logic_ref_opt,
1174
+ true => {
1175
+ dbg ! ( "Enable extra optimizations for Enzyme" ) ;
1176
+ logic_ref_opt
1177
+ }
1179
1178
false => CreateEnzymeLogic ( fnc_opt as u8 ) ,
1180
1179
} ;
1181
1180
for item in higher_order_items {
1182
- let res = enzyme_ad ( llmod, llcx, & diag_handler, item, logic_ref) ;
1181
+ let res = enzyme_ad ( llmod, llcx, & diag_handler, item, logic_ref, ad ) ;
1183
1182
assert ! ( res. is_ok( ) ) ;
1184
1183
}
1185
1184
@@ -1212,14 +1211,14 @@ pub(crate) unsafe fn differentiate(
1212
1211
break ;
1213
1212
}
1214
1213
}
1215
- if std :: env :: var ( "ENZYME_PRINT_MOD_AFTER_ENZYME" ) . is_ok ( ) {
1214
+ if ad . contains ( & AutoDiff :: PrintModAfterEnzyme ) {
1216
1215
unsafe {
1217
1216
LLVMDumpModule ( llmod) ;
1218
1217
}
1219
1218
}
1220
1219
1221
1220
1222
- if std :: env :: var ( "ENZYME_NO_MOD_OPT_AFTER" ) . is_ok ( ) || !differentiate {
1221
+ if ad . contains ( & AutoDiff :: NoModOptAfter ) || !differentiate {
1223
1222
trace ! ( "Skipping module optimization after automatic differentiation" ) ;
1224
1223
} else {
1225
1224
if let Some ( opt_level) = config. opt_level {
@@ -1231,18 +1230,18 @@ pub(crate) unsafe fn differentiate(
1231
1230
} ;
1232
1231
let mut first_run = false ;
1233
1232
dbg ! ( "Running Module Optimization after differentiation" ) ;
1234
- if std :: env :: var ( "ENZYME_NO_VEC_UNROLL" ) . is_ok ( ) {
1233
+ if ad . contains ( & AutoDiff :: NoVecUnroll ) {
1235
1234
// disables vectorization and loop unrolling
1236
1235
first_run = true ;
1237
1236
}
1238
- if std :: env :: var ( "ENZYME_ALT_PIPELINE" ) . is_ok ( ) {
1237
+ if ad . contains ( & AutoDiff :: AltPipeline ) {
1239
1238
dbg ! ( "Running first postAD optimization" ) ;
1240
1239
first_run = true ;
1241
1240
}
1242
1241
let noop = false ;
1243
1242
llvm_optimize ( cgcx, & diag_handler, module, config, opt_level, opt_stage, first_run, noop) ?;
1244
1243
}
1245
- if std :: env :: var ( "ENZYME_ALT_PIPELINE" ) . is_ok ( ) {
1244
+ if ad . contains ( & AutoDiff :: AltPipeline ) {
1246
1245
dbg ! ( "Running Second postAD optimization" ) ;
1247
1246
if let Some ( opt_level) = config. opt_level {
1248
1247
let opt_stage = match cgcx. lto {
@@ -1253,7 +1252,7 @@ pub(crate) unsafe fn differentiate(
1253
1252
} ;
1254
1253
let mut first_run = false ;
1255
1254
dbg ! ( "Running Module Optimization after differentiation" ) ;
1256
- if std :: env :: var ( "ENZYME_NO_VEC_UNROLL" ) . is_ok ( ) {
1255
+ if ad . contains ( & AutoDiff :: NoVecUnroll ) {
1257
1256
// enables vectorization and loop unrolling
1258
1257
first_run = false ;
1259
1258
}
@@ -1263,7 +1262,7 @@ pub(crate) unsafe fn differentiate(
1263
1262
}
1264
1263
}
1265
1264
1266
- if std :: env :: var ( "ENZYME_PRINT_MOD_AFTER_OPTS" ) . is_ok ( ) {
1265
+ if ad . contains ( & AutoDiff :: PrintModAfterOpts ) {
1267
1266
unsafe {
1268
1267
LLVMDumpModule ( llmod) ;
1269
1268
}
@@ -1341,15 +1340,16 @@ pub(crate) unsafe fn optimize(
1341
1340
_ if cgcx. opts . cg . linker_plugin_lto . enabled ( ) => llvm:: OptStage :: PreLinkThinLTO ,
1342
1341
_ => llvm:: OptStage :: PreLinkNoLTO ,
1343
1342
} ;
1343
+
1344
1344
// Second run only relevant for AD
1345
1345
let first_run = true ;
1346
- let noop;
1347
- if std :: env :: var ( "ENZYME_ALT_PIPELINE" ) . is_ok ( ) {
1348
- noop = true ;
1349
- dbg ! ( "Skipping PreAD optimization" ) ;
1350
- } else {
1351
- noop = false ;
1352
- }
1346
+ let noop = false ;
1347
+ // if ad.contains(&AutoDiff::AltPipeline ) {
1348
+ // noop = true;
1349
+ // dbg!("Skipping PreAD optimization");
1350
+ // } else {
1351
+ // noop = false;
1352
+ // }
1353
1353
return llvm_optimize ( cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop) ;
1354
1354
}
1355
1355
Ok ( ( ) )
0 commit comments