Skip to content

Commit 1f83693

Browse files
authored
Rust flags (#159)
* add RUSTFLAGS version of enzyme flags * Remove old env arg checks, use flags now * small fixups
1 parent 0bd1b5d commit 1f83693

File tree

5 files changed

+146
-55
lines changed

5 files changed

+146
-55
lines changed

compiler/rustc_codegen_llvm/src/back/write.rs

+54-54
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ use rustc_data_structures::small_c_str::SmallCStr;
6969
use rustc_errors::{DiagCtxt, FatalError, Level};
7070
use rustc_fs_util::{link_or_copy, path_to_c_string};
7171
use rustc_middle::ty::TyCtxt;
72-
use rustc_session::config::{self, Lto, OutputType, Passes, SplitDwarfKind, SwitchWithOptPath};
72+
use rustc_session::config::{self, AutoDiff, Lto, OutputType, Passes, SplitDwarfKind, SwitchWithOptPath};
7373
use rustc_session::Session;
7474
use rustc_span::symbol::sym;
7575
use rustc_span::InnerSpan;
@@ -707,7 +707,7 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
707707

708708

709709
unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
710-
llmod: &'a llvm::Module, llcx: &llvm::Context, size_positions: &[usize]) {
710+
llmod: &'a llvm::Module, llcx: &llvm::Context, size_positions: &[usize], ad: &[AutoDiff]) {
711711

712712
// first, remove all calls from fnc
713713
let bb = LLVMGetFirstBasicBlock(tgt);
@@ -729,12 +729,7 @@ unsafe fn create_call<'a>(tgt: &'a Value, src: &'a Value, rev_mode: bool,
729729
let last_inst = LLVMRustGetLastInstruction(bb).unwrap();
730730
LLVMPositionBuilderAtEnd(builder, bb);
731731

732-
let safety_run_checks;
733-
if std::env::var("ENZYME_NO_SAFETY_CHECKS").is_ok() {
734-
safety_run_checks = false;
735-
} else {
736-
safety_run_checks = true;
737-
}
732+
let safety_run_checks = !ad.contains(&AutoDiff::NoSafetyChecks);
738733

739734
if inner_param_num == outer_param_num {
740735
call_args = outer_args;
@@ -951,6 +946,7 @@ pub(crate) unsafe fn enzyme_ad(
951946
diag_handler: &DiagCtxt,
952947
item: AutoDiffItem,
953948
logic_ref: EnzymeLogicRef,
949+
ad: &[AutoDiff],
954950
) -> Result<(), FatalError> {
955951
let autodiff_mode = item.attrs.mode;
956952
let rust_name = item.source;
@@ -1010,16 +1006,16 @@ pub(crate) unsafe fn enzyme_ad(
10101006

10111007
llvm::set_strict_aliasing(false);
10121008

1013-
if std::env::var("ENZYME_PRINT_TA").is_ok() {
1009+
if ad.contains(&AutoDiff::PrintTA) {
10141010
llvm::set_print_type(true);
10151011
}
1016-
if std::env::var("ENZYME_PRINT_AA").is_ok() {
1017-
llvm::set_print_activity(true);
1012+
if ad.contains(&AutoDiff::PrintTA) {
1013+
llvm::set_print_type(true);
10181014
}
1019-
if std::env::var("ENZYME_PRINT_PERF").is_ok() {
1015+
if ad.contains(&AutoDiff::PrintPerf) {
10201016
llvm::set_print_perf(true);
10211017
}
1022-
if std::env::var("ENZYME_PRINT").is_ok() {
1018+
if ad.contains(&AutoDiff::Print) {
10231019
llvm::set_print(true);
10241020
}
10251021

@@ -1062,7 +1058,7 @@ pub(crate) unsafe fn enzyme_ad(
10621058
let f_return_type = LLVMGetReturnType(LLVMGlobalGetValueType(res));
10631059

10641060
let rev_mode = item.attrs.mode == DiffMode::Reverse;
1065-
create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions);
1061+
create_call(target_fnc, res, rev_mode, llmod, llcx, &size_positions, ad);
10661062
// TODO: implement drop for wrapper type?
10671063
FreeTypeAnalysis(type_analysis);
10681064

@@ -1087,7 +1083,9 @@ pub(crate) unsafe fn differentiate(
10871083

10881084
llvm::set_strict_aliasing(false);
10891085

1090-
if std::env::var("ENZYME_LOOSE_TYPES").is_ok() {
1086+
let ad = &config.autodiff;
1087+
1088+
if ad.contains(&AutoDiff::LooseTypes) {
10911089
dbg!("Setting loose types to true");
10921090
llvm::set_loose_types(true);
10931091
}
@@ -1110,41 +1108,42 @@ pub(crate) unsafe fn differentiate(
11101108
// trigger the LLVMEnzyme pass to run on them, if we invoke the opt binary.
11111109
// This is super helpfull if we want to create a MWE bug reproducer, e.g. to run in
11121110
// Enzyme's compiler explorer. TODO: Can we run llvm-extract on the module to remove all other functions?
1113-
if std::env::var("ENZYME_OPT").is_ok() {
1111+
if ad.contains(&AutoDiff::OPT) {
11141112
dbg!("Enable extra debug helper to debug Enzyme through the opt plugin");
11151113
crate::builder::add_opt_dbg_helper(llmod, llcx, fn_def, item.attrs.clone(), i);
11161114
}
11171115
}
11181116

1119-
if std::env::var("ENZYME_PRINT_MOD_BEFORE").is_ok() || std::env::var("ENZYME_OPT").is_ok(){
1117+
if ad.contains(&AutoDiff::PrintModBefore) || ad.contains(&AutoDiff::OPT) {
11201118
unsafe {
11211119
LLVMDumpModule(llmod);
11221120
}
11231121
}
11241122

1125-
if std::env::var("ENZYME_INLINE").is_ok() {
1123+
if ad.contains(&AutoDiff::Inline) {
11261124
dbg!("Setting inline to true");
11271125
llvm::set_inline(true);
11281126
}
11291127

1130-
if std::env::var("ENZYME_TT_DEPTH").is_ok() {
1131-
let depth = std::env::var("ENZYME_TT_DEPTH").unwrap();
1132-
let depth = depth.parse::<u64>().unwrap();
1133-
assert!(depth >= 1);
1134-
llvm::set_max_int_offset(depth);
1135-
}
1136-
if std::env::var("ENZYME_TT_WIDTH").is_ok() {
1137-
let width = std::env::var("ENZYME_TT_WIDTH").unwrap();
1138-
let width = width.parse::<u64>().unwrap();
1139-
assert!(width >= 1);
1140-
llvm::set_max_type_offset(width);
1141-
}
1142-
1143-
if std::env::var("ENZYME_RUNTIME_ACTIVITY").is_ok() {
1128+
if ad.contains(&AutoDiff::RuntimeActivity) {
11441129
dbg!("Setting runtime activity check to true");
11451130
llvm::set_runtime_activity_check(true);
11461131
}
11471132

1133+
for val in ad {
1134+
match &val {
1135+
AutoDiff::TTDepth(depth) => {
1136+
assert!(*depth >= 1);
1137+
llvm::set_max_int_offset(*depth);
1138+
}
1139+
AutoDiff::TTWidth(width) => {
1140+
assert!(*width >= 1);
1141+
llvm::set_max_type_offset(*width);
1142+
}
1143+
_ => {},
1144+
}
1145+
};
1146+
11481147
let differentiate = !diff_items.is_empty();
11491148
let mut first_order_items: Vec<AutoDiffItem> = vec![];
11501149
let mut higher_order_items: Vec<AutoDiffItem> = vec![];
@@ -1157,29 +1156,29 @@ pub(crate) unsafe fn differentiate(
11571156
}
11581157
}
11591158

1160-
let mut fnc_opt = false;
1161-
if std::env::var("ENZYME_ENABLE_FNC_OPT").is_ok() {
1162-
dbg!("Enable extra optimizations for Enzyme");
1163-
fnc_opt = true;
1164-
}
1159+
1160+
let fnc_opt = ad.contains(&AutoDiff::EnableFncOpt);
11651161

11661162
// If a function is a base for some higher order ad, always optimize
11671163
let fnc_opt_base = true;
11681164
let logic_ref_opt: EnzymeLogicRef = CreateEnzymeLogic(fnc_opt_base as u8);
11691165

11701166
for item in first_order_items {
1171-
let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref_opt);
1167+
let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref_opt, ad);
11721168
assert!(res.is_ok());
11731169
}
11741170

11751171
// For the rest, follow the user choice on debug vs release.
11761172
// Reuse the opt one if possible for better compile time (Enzyme internal caching).
11771173
let logic_ref = match fnc_opt {
1178-
true => logic_ref_opt,
1174+
true => {
1175+
dbg!("Enable extra optimizations for Enzyme");
1176+
logic_ref_opt
1177+
}
11791178
false => CreateEnzymeLogic(fnc_opt as u8),
11801179
};
11811180
for item in higher_order_items {
1182-
let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref);
1181+
let res = enzyme_ad(llmod, llcx, &diag_handler, item, logic_ref, ad);
11831182
assert!(res.is_ok());
11841183
}
11851184

@@ -1212,14 +1211,14 @@ pub(crate) unsafe fn differentiate(
12121211
break;
12131212
}
12141213
}
1215-
if std::env::var("ENZYME_PRINT_MOD_AFTER_ENZYME").is_ok() {
1214+
if ad.contains(&AutoDiff::PrintModAfterEnzyme) {
12161215
unsafe {
12171216
LLVMDumpModule(llmod);
12181217
}
12191218
}
12201219

12211220

1222-
if std::env::var("ENZYME_NO_MOD_OPT_AFTER").is_ok() || !differentiate {
1221+
if ad.contains(&AutoDiff::NoModOptAfter) || !differentiate {
12231222
trace!("Skipping module optimization after automatic differentiation");
12241223
} else {
12251224
if let Some(opt_level) = config.opt_level {
@@ -1231,18 +1230,18 @@ pub(crate) unsafe fn differentiate(
12311230
};
12321231
let mut first_run = false;
12331232
dbg!("Running Module Optimization after differentiation");
1234-
if std::env::var("ENZYME_NO_VEC_UNROLL").is_ok() {
1233+
if ad.contains(&AutoDiff::NoVecUnroll) {
12351234
// disables vectorization and loop unrolling
12361235
first_run = true;
12371236
}
1238-
if std::env::var("ENZYME_ALT_PIPELINE").is_ok() {
1237+
if ad.contains(&AutoDiff::AltPipeline) {
12391238
dbg!("Running first postAD optimization");
12401239
first_run = true;
12411240
}
12421241
let noop = false;
12431242
llvm_optimize(cgcx, &diag_handler, module, config, opt_level, opt_stage, first_run, noop)?;
12441243
}
1245-
if std::env::var("ENZYME_ALT_PIPELINE").is_ok() {
1244+
if ad.contains(&AutoDiff::AltPipeline) {
12461245
dbg!("Running Second postAD optimization");
12471246
if let Some(opt_level) = config.opt_level {
12481247
let opt_stage = match cgcx.lto {
@@ -1253,7 +1252,7 @@ pub(crate) unsafe fn differentiate(
12531252
};
12541253
let mut first_run = false;
12551254
dbg!("Running Module Optimization after differentiation");
1256-
if std::env::var("ENZYME_NO_VEC_UNROLL").is_ok() {
1255+
if ad.contains(&AutoDiff::NoVecUnroll) {
12571256
// enables vectorization and loop unrolling
12581257
first_run = false;
12591258
}
@@ -1263,7 +1262,7 @@ pub(crate) unsafe fn differentiate(
12631262
}
12641263
}
12651264

1266-
if std::env::var("ENZYME_PRINT_MOD_AFTER_OPTS").is_ok() {
1265+
if ad.contains(&AutoDiff::PrintModAfterOpts) {
12671266
unsafe {
12681267
LLVMDumpModule(llmod);
12691268
}
@@ -1341,15 +1340,16 @@ pub(crate) unsafe fn optimize(
13411340
_ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO,
13421341
_ => llvm::OptStage::PreLinkNoLTO,
13431342
};
1343+
13441344
// Second run only relevant for AD
13451345
let first_run = true;
1346-
let noop;
1347-
if std::env::var("ENZYME_ALT_PIPELINE").is_ok() {
1348-
noop = true;
1349-
dbg!("Skipping PreAD optimization");
1350-
} else {
1351-
noop = false;
1352-
}
1346+
let noop = false;
1347+
//if ad.contains(&AutoDiff::AltPipeline) {
1348+
// noop = true;
1349+
// dbg!("Skipping PreAD optimization");
1350+
//} else {
1351+
// noop = false;
1352+
//}
13531353
return llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop);
13541354
}
13551355
Ok(())

compiler/rustc_codegen_ssa/src/back/write.rs

+2
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ pub struct ModuleConfig {
118118
pub inline_threshold: Option<u32>,
119119
pub emit_lifetime_markers: bool,
120120
pub llvm_plugins: Vec<String>,
121+
pub autodiff: Vec<config::AutoDiff>,
121122
}
122123

123124
impl ModuleConfig {
@@ -259,6 +260,7 @@ impl ModuleConfig {
259260
inline_threshold: sess.opts.cg.inline_threshold,
260261
emit_lifetime_markers: sess.emit_lifetime_markers(),
261262
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
263+
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
262264
}
263265
}
264266

compiler/rustc_interface/src/tests.rs

+1
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,7 @@ fn test_unstable_options_tracking_hash() {
729729

730730
// Make sure that changing a [TRACKED] option changes the hash.
731731
// tidy-alphabetical-start
732+
tracked!(autodiff, vec![String::from("ad_flags")]);
732733
tracked!(allow_features, Some(vec![String::from("lang_items")]));
733734
tracked!(always_encode_mir, true);
734735
tracked!(asm_comments, true);

compiler/rustc_session/src/config.rs

+50-1
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,53 @@ pub enum InstrumentCoverage {
174174
Off,
175175
}
176176

177+
/// The different settings that the `-Z ad` flag can have.
178+
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
179+
pub enum AutoDiff {
180+
/// Print TypeAnalysis information
181+
PrintTA,
182+
/// Print ActivityAnalysis Information
183+
PrintAA,
184+
/// Print Performance Warnings from Enzyme
185+
PrintPerf,
186+
/// Combines the three print flags above.
187+
Print,
188+
/// Print the whole module, before running opts.
189+
PrintModBefore,
190+
/// Print the whole module just before we pass it to Enzyme.
191+
/// For Debug purpose, prefer the OPT flag below
192+
PrintModAfterOpts,
193+
/// Print the module after Enzyme differentiated everything.
194+
PrintModAfterEnzyme,
195+
196+
/// Enzyme's loose type debug helper (can cause incorrect gradients)
197+
LooseTypes,
198+
/// Output a Module using __enzyme calls to prepare it for opt + enzyme pass usage
199+
OPT,
200+
201+
/// TypeTree options
202+
/// TODO: Figure out how to let users construct these,
203+
/// or whether we want to leave this option in the first place.
204+
TTWidth(u64),
205+
TTDepth(u64),
206+
207+
/// More flags
208+
NoModOptAfter,
209+
/// Tell Enzyme to run LLVM Opts on each function it generated. By default off,
210+
/// since we already optimize the whole module after Enzyme is done.
211+
EnableFncOpt,
212+
NoVecUnroll,
213+
/// Obviously unsafe, disable the length checks that we have for shadow args.
214+
NoSafetyChecks,
215+
RuntimeActivity,
216+
/// Runs Enzyme specific Inlining
217+
Inline,
218+
/// Runs Optimization twice after AD, and zero times after.
219+
/// This is mainly for Benchmarking purpose to show that
220+
/// compiler based AD has a performance benefit. TODO: fix
221+
AltPipeline,
222+
}
223+
177224
/// Settings for `-Z instrument-xray` flag.
178225
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
179226
pub struct InstrumentXRay {
@@ -3229,8 +3276,9 @@ pub(crate) mod dep_tracking {
32293276
LinkerPluginLto, LocationDetail, LtoCli, NextSolverConfig, OomStrategy, OptLevel,
32303277
OutFileName, OutputType, OutputTypes, Polonius, RemapPathScopeComponents, ResolveDocLinks,
32313278
SourceFileHashAlgorithm, SplitDwarfKind, SwitchWithOptPath, SymbolManglingVersion,
3232-
TrimmedDefPaths, WasiExecModel,
3279+
TrimmedDefPaths, WasiExecModel, AutoDiff,
32333280
};
3281+
//use crate::config::AutoDiff;
32343282
use crate::lint;
32353283
use crate::utils::NativeLib;
32363284
use rustc_data_structures::fx::FxIndexMap;
@@ -3285,6 +3333,7 @@ pub(crate) mod dep_tracking {
32853333
}
32863334

32873335
impl_dep_tracking_hash_via_hash!(
3336+
AutoDiff,
32883337
bool,
32893338
usize,
32903339
NonZeroUsize,

0 commit comments

Comments
 (0)