@@ -313,6 +313,7 @@ Instruction(is::InstructionStream) = Instruction(is, add_new_idx!(is))
313
313
fldarray = getfield (getfield (node, :data ), fld)
314
314
fldidx = getfield (node, :idx )
315
315
(fld === :line ) && return (fldarray[3 fldidx- 2 ], fldarray[3 fldidx- 1 ], fldarray[3 fldidx- 0 ])
316
+ (1 ≤ fldidx ≤ length (fldarray)) || throw (InvalidIRError ())
316
317
return fldarray[fldidx]
317
318
end
318
319
@inline function setindex! (node:: Instruction , @nospecialize (val), fld:: Symbol )
@@ -481,11 +482,16 @@ function block_for_inst(ir::IRCode, inst::Int)
481
482
end
482
483
483
484
function getindex (ir:: IRCode , s:: SSAValue )
485
+ id = s. id
486
+ (id ≥ 1 ) || throw (InvalidIRError ())
484
487
nstmts = length (ir. stmts)
485
- if s . id <= nstmts
486
- return ir. stmts[s . id]
488
+ if id <= nstmts
489
+ return ir. stmts[id]
487
490
else
488
- return ir. new_nodes. stmts[s. id - nstmts]
491
+ id -= nstmts
492
+ stmts = ir. new_nodes. stmts
493
+ (id ≤ length (stmts)) || throw (InvalidIRError ())
494
+ return stmts[id]
489
495
end
490
496
end
491
497
@@ -801,12 +807,13 @@ end
801
807
types (ir:: Union{IRCode, IncrementalCompact} ) = TypesView (ir)
802
808
803
809
function getindex (compact:: IncrementalCompact , ssa:: SSAValue )
804
- @assert ssa. id < compact. result_idx
810
+ ( 1 ≤ ssa. id ≤ compact. result_idx) || throw ( InvalidIRError ())
805
811
return compact. result[ssa. id]
806
812
end
807
813
808
814
function getindex (compact:: IncrementalCompact , ssa:: OldSSAValue )
809
815
id = ssa. id
816
+ (id ≥ 1 ) || throw (InvalidIRError ())
810
817
if id < compact. idx
811
818
new_idx = compact. ssa_rename[id]:: Int
812
819
return compact. result[new_idx]
@@ -818,12 +825,15 @@ function getindex(compact::IncrementalCompact, ssa::OldSSAValue)
818
825
return compact. ir. new_nodes. stmts[id]
819
826
end
820
827
id -= length (compact. ir. new_nodes)
828
+ (id ≤ length (compact. pending_nodes. stmts)) || throw (InvalidIRError ())
821
829
return compact. pending_nodes. stmts[id]
822
830
end
823
831
824
832
function getindex (compact:: IncrementalCompact , ssa:: NewSSAValue )
825
833
if ssa. id < 0
826
- return compact. new_new_nodes. stmts[- ssa. id]
834
+ stmts = compact. new_new_nodes. stmts
835
+ (- ssa. id ≤ length (stmts)) || throw (InvalidIRError ())
836
+ return stmts[- ssa. id]
827
837
else
828
838
return compact[SSAValue (ssa. id)]
829
839
end
@@ -1069,6 +1079,7 @@ function getindex(view::TypesView, v::OldSSAValue)
1069
1079
id = v. id
1070
1080
ir = view. ir. ir
1071
1081
stmts = ir. stmts
1082
+ (id ≥ 1 ) || throw (InvalidIRError ())
1072
1083
if id <= length (stmts)
1073
1084
return stmts[id][:type ]
1074
1085
end
@@ -1077,7 +1088,9 @@ function getindex(view::TypesView, v::OldSSAValue)
1077
1088
return ir. new_nodes. stmts[id][:type ]
1078
1089
end
1079
1090
id -= length (ir. new_nodes)
1080
- return view. ir. pending_nodes. stmts[id][:type ]
1091
+ stmts = view. ir. pending_nodes. stmts
1092
+ (id ≤ length (stmts)) || throw (InvalidIRError ())
1093
+ return stmts[id][:type ]
1081
1094
end
1082
1095
1083
1096
function kill_current_use! (compact:: IncrementalCompact , @nospecialize (val))
@@ -1204,20 +1217,27 @@ end
1204
1217
1205
1218
getindex (view:: TypesView , idx:: SSAValue ) = getindex (view, idx. id)
1206
1219
function getindex (view:: TypesView , idx:: Int )
1220
+ (idx ≥ 1 ) || throw (InvalidIRError ())
1207
1221
if isa (view. ir, IncrementalCompact) && idx < view. ir. result_idx
1208
1222
return view. ir. result[idx][:type ]
1209
1223
elseif isa (view. ir, IncrementalCompact) && view. ir. renamed_new_nodes
1210
1224
if idx <= length (view. ir. result)
1211
1225
return view. ir. result[idx][:type ]
1212
1226
else
1213
- return view. ir. new_new_nodes. stmts[idx - length (view. ir. result)][:type ]
1227
+ idx -= length (view. ir. result)
1228
+ stmts = view. ir. new_new_nodes. stmts
1229
+ (idx ≤ length (stmts)) || throw (InvalidIRError ())
1230
+ return stmts[idx][:type ]
1214
1231
end
1215
1232
else
1216
1233
ir = isa (view. ir, IncrementalCompact) ? view. ir. ir : view. ir
1217
1234
if idx <= length (ir. stmts)
1218
1235
return ir. stmts[idx][:type ]
1219
1236
else
1220
- return ir. new_nodes. stmts[idx - length (ir. stmts)][:type ]
1237
+ idx -= length (ir. stmts)
1238
+ stmts = ir. new_nodes. stmts
1239
+ (idx ≤ length (stmts)) || throw (InvalidIRError ())
1240
+ return stmts[idx][:type ]
1221
1241
end
1222
1242
end
1223
1243
end
0 commit comments