Skip to content

Commit 5117d04

Browse files
authored
EA: use is_mutation_free_argtype for the escapability check (#56028)
EA has been using `isbitstype` for type-level escapability checks, but a better criterion (`is_mutation_free`) is available these days, so we would like to use that instead.
1 parent e516e4c commit 5117d04

File tree

2 files changed

+63
-63
lines changed

2 files changed

+63
-63
lines changed

base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,10 @@ using ._TOP_MOD: # Base definitions
2424
isempty, ismutabletype, keys, last, length, max, min, missing, pop!, push!, pushfirst!,
2525
unwrap_unionall, !, !=, !==, &, *, +, -, :, <, <<, =>, >, |, , , , , , , ,
2626
using Core.Compiler: # Core.Compiler specific definitions
27-
Bottom, IRCode, IR_FLAG_NOTHROW, InferenceResult, SimpleInferenceLattice,
27+
AbstractLattice, Bottom, IRCode, IR_FLAG_NOTHROW, InferenceResult, SimpleInferenceLattice,
2828
argextype, fieldcount_noerror, hasintersect, has_flag, intrinsic_nothrow,
29-
is_meta_expr_head, isbitstype, isexpr, println, setfield!_nothrow, singleton_type,
30-
try_compute_field, try_compute_fieldidx, widenconst, , AbstractLattice
29+
is_meta_expr_head, is_mutation_free_argtype, isexpr, println, setfield!_nothrow,
30+
singleton_type, try_compute_field, try_compute_fieldidx, widenconst,
3131

3232
include(x) = _TOP_MOD.include(@__MODULE__, x)
3333
if _TOP_MOD === Core.Compiler
@@ -859,7 +859,7 @@ function add_escape_change!(astate::AnalysisState, @nospecialize(x), xinfo::Esca
859859
xinfo ===&& return nothing # performance optimization
860860
xidx = iridx(x, astate.estate)
861861
if xidx !== nothing
862-
if force || !isbitstype(widenconst(argextype(x, astate.ir)))
862+
if force || !is_mutation_free_argtype(argextype(x, astate.ir))
863863
push!(astate.changes, EscapeChange(xidx, xinfo))
864864
end
865865
end
@@ -869,7 +869,7 @@ end
869869
function add_liveness_change!(astate::AnalysisState, @nospecialize(x), livepc::Int)
870870
xidx = iridx(x, astate.estate)
871871
if xidx !== nothing
872-
if !isbitstype(widenconst(argextype(x, astate.ir)))
872+
if !is_mutation_free_argtype(argextype(x, astate.ir))
873873
push!(astate.changes, LivenessChange(xidx, livepc))
874874
end
875875
end

test/compiler/EscapeAnalysis/EscapeAnalysis.jl

Lines changed: 58 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ end
290290

291291
let # typeassert
292292
result = code_escapes((Any,)) do x
293-
y = x::String
293+
y = x::Base.RefValue{Any}
294294
return y
295295
end
296296
r = only(findall(isreturn, result.ir.stmts.stmt))
@@ -305,11 +305,6 @@ end
305305
r = only(findall(isreturn, result.ir.stmts.stmt))
306306
@test has_return_escape(result.state[Argument(2)], r)
307307
@test !has_all_escape(result.state[Argument(2)])
308-
309-
result = code_escapes((Module,)) do m
310-
isdefined(m, 10) # throws
311-
end
312-
@test has_thrown_escape(result.state[Argument(2)])
313308
end
314309
end
315310

@@ -685,8 +680,8 @@ end
685680
@test has_all_escape(result.state[Argument(2)])
686681
end
687682
let result = @eval EATModule() begin
688-
const Rx = SafeRef{String}("Rx")
689-
$code_escapes((String,)) do s
683+
const Rx = SafeRef{Any}(nothing)
684+
$code_escapes((Base.RefValue{String},)) do s
690685
setfield!(Rx, :x, s)
691686
Core.sizeof(Rx[])
692687
end
@@ -712,7 +707,7 @@ end
712707
# ------------
713708

714709
# field escape should propagate to :new arguments
715-
let result = code_escapes((String,)) do a
710+
let result = code_escapes((Base.RefValue{String},)) do a
716711
o = SafeRef(a)
717712
Core.donotdelete(o)
718713
return o[]
@@ -722,7 +717,7 @@ end
722717
@test has_return_escape(result.state[Argument(2)], r)
723718
@test is_load_forwardable(result.state[SSAValue(i)])
724719
end
725-
let result = code_escapes((String,)) do a
720+
let result = code_escapes((Base.RefValue{String},)) do a
726721
t = SafeRef((a,))
727722
f = t[][1]
728723
return f
@@ -731,9 +726,8 @@ end
731726
r = only(findall(isreturn, result.ir.stmts.stmt))
732727
@test has_return_escape(result.state[Argument(2)], r)
733728
@test is_load_forwardable(result.state[SSAValue(i)])
734-
result.state[SSAValue(i)].AliasInfo
735729
end
736-
let result = code_escapes((String, String)) do a, b
730+
let result = code_escapes((Base.RefValue{String}, Base.RefValue{String})) do a, b
737731
obj = SafeRefs(a, b)
738732
Core.donotdelete(obj)
739733
fld1 = obj[1]
@@ -748,31 +742,31 @@ end
748742
end
749743

750744
# field escape should propagate to `setfield!` argument
751-
let result = code_escapes((String,)) do a
752-
o = SafeRef("foo")
745+
let result = code_escapes((Base.RefValue{String},)) do a
746+
o = SafeRef(Ref("foo"))
753747
Core.donotdelete(o)
754748
o[] = a
755749
return o[]
756750
end
757-
i = only(findall(isnew, result.ir.stmts.stmt))
751+
i = last(findall(isnew, result.ir.stmts.stmt))
758752
r = only(findall(isreturn, result.ir.stmts.stmt))
759753
@test has_return_escape(result.state[Argument(2)], r)
760754
@test is_load_forwardable(result.state[SSAValue(i)])
761755
end
762756
# propagate escape information imposed on return value of `setfield!` call
763-
let result = code_escapes((String,)) do a
764-
obj = SafeRef("foo")
757+
let result = code_escapes((Base.RefValue{String},)) do a
758+
obj = SafeRef(Ref("foo"))
765759
Core.donotdelete(obj)
766760
return (obj[] = a)
767761
end
768-
i = only(findall(isnew, result.ir.stmts.stmt))
762+
i = last(findall(isnew, result.ir.stmts.stmt))
769763
r = only(findall(isreturn, result.ir.stmts.stmt))
770764
@test has_return_escape(result.state[Argument(2)], r)
771765
@test is_load_forwardable(result.state[SSAValue(i)])
772766
end
773767

774768
# nested allocations
775-
let result = code_escapes((String,)) do a
769+
let result = code_escapes((Base.RefValue{String},)) do a
776770
o1 = SafeRef(a)
777771
o2 = SafeRef(o1)
778772
return o2[]
@@ -787,7 +781,7 @@ end
787781
end
788782
end
789783
end
790-
let result = code_escapes((String,)) do a
784+
let result = code_escapes((Base.RefValue{String},)) do a
791785
o1 = (a,)
792786
o2 = (o1,)
793787
return o2[1]
@@ -802,7 +796,7 @@ end
802796
end
803797
end
804798
end
805-
let result = code_escapes((String,)) do a
799+
let result = code_escapes((Base.RefValue{String},)) do a
806800
o1 = SafeRef(a)
807801
o2 = SafeRef(o1)
808802
o1′ = o2[]
@@ -844,7 +838,7 @@ end
844838
@test has_return_escape(result.state[SSAValue(i)], r)
845839
end
846840
end
847-
let result = code_escapes((String,)) do x
841+
let result = code_escapes((Base.RefValue{String},)) do x
848842
o = Ref(x)
849843
Core.donotdelete(o)
850844
broadcast(identity, o)
@@ -892,7 +886,7 @@ end
892886
end
893887
end
894888
# when ϕ-node merges values with different types
895-
let result = code_escapes((Bool,String,String,String)) do cond, x, y, z
889+
let result = code_escapes((Bool,Base.RefValue{String},Base.RefValue{String},Base.RefValue{String})) do cond, x, y, z
896890
local out
897891
if cond
898892
ϕ = SafeRef(x)
@@ -904,7 +898,7 @@ end
904898
end
905899
r = only(findall(isreturn, result.ir.stmts.stmt))
906900
t = only(findall(iscall((result.ir, throw)), result.ir.stmts.stmt))
907-
ϕ = only(findall(==(Union{SafeRef{String},SafeRefs{String,String}}), result.ir.stmts.type))
901+
ϕ = only(findall(==(Union{SafeRef{Base.RefValue{String}},SafeRefs{Base.RefValue{String},Base.RefValue{String}}}), result.ir.stmts.type))
908902
@test has_return_escape(result.state[Argument(3)], r) # x
909903
@test !has_return_escape(result.state[Argument(4)], r) # y
910904
@test has_return_escape(result.state[Argument(5)], r) # z
@@ -1038,7 +1032,7 @@ end
10381032
end
10391033
# alias via typeassert
10401034
let result = code_escapes((Any,)) do a
1041-
r = a::String
1035+
r = a::Base.RefValue{String}
10421036
return r
10431037
end
10441038
r = only(findall(isreturn, result.ir.stmts.stmt))
@@ -1077,11 +1071,11 @@ end
10771071
@test has_all_escape(result.state[Argument(3)]) # a
10781072
end
10791073
# alias via ϕ-node
1080-
let result = code_escapes((Bool,String)) do cond, x
1074+
let result = code_escapes((Bool,Base.RefValue{String})) do cond, x
10811075
if cond
1082-
ϕ2 = ϕ1 = SafeRef("foo")
1076+
ϕ2 = ϕ1 = SafeRef(Ref("foo"))
10831077
else
1084-
ϕ2 = ϕ1 = SafeRef("bar")
1078+
ϕ2 = ϕ1 = SafeRef(Ref("bar"))
10851079
end
10861080
ϕ2[] = x
10871081
return ϕ1[]
@@ -1094,14 +1088,16 @@ end
10941088
@test is_load_forwardable(result.state[SSAValue(i)])
10951089
end
10961090
for i in findall(isnew, result.ir.stmts.stmt)
1097-
@test is_load_forwardable(result.state[SSAValue(i)])
1091+
if result.ir[SSAValue(i)][:type] <: SafeRef
1092+
@test is_load_forwardable(result.state[SSAValue(i)])
1093+
end
10981094
end
10991095
end
1100-
let result = code_escapes((Bool,Bool,String)) do cond1, cond2, x
1096+
let result = code_escapes((Bool,Bool,Base.RefValue{String})) do cond1, cond2, x
11011097
if cond1
1102-
ϕ2 = ϕ1 = SafeRef("foo")
1098+
ϕ2 = ϕ1 = SafeRef(Ref("foo"))
11031099
else
1104-
ϕ2 = ϕ1 = SafeRef("bar")
1100+
ϕ2 = ϕ1 = SafeRef(Ref("bar"))
11051101
end
11061102
cond2 && (ϕ2[] = x)
11071103
return ϕ1[]
@@ -1114,12 +1110,14 @@ end
11141110
@test is_load_forwardable(result.state[SSAValue(i)])
11151111
end
11161112
for i in findall(isnew, result.ir.stmts.stmt)
1117-
@test is_load_forwardable(result.state[SSAValue(i)])
1113+
if result.ir[SSAValue(i)][:type] <: SafeRef
1114+
@test is_load_forwardable(result.state[SSAValue(i)])
1115+
end
11181116
end
11191117
end
11201118
# alias via π-node
11211119
let result = code_escapes((Any,)) do x
1122-
if isa(x, String)
1120+
if isa(x, Base.RefValue{String})
11231121
return x
11241122
end
11251123
throw("error!")
@@ -1213,7 +1211,7 @@ end
12131211

12141212
# conservatively handle unknown field:
12151213
# all fields should be escaped, but the allocation itself doesn't need to be escaped
1216-
let result = code_escapes((String, Symbol)) do a, fld
1214+
let result = code_escapes((Base.RefValue{String}, Symbol)) do a, fld
12171215
obj = SafeRef(a)
12181216
return getfield(obj, fld)
12191217
end
@@ -1222,7 +1220,7 @@ end
12221220
@test has_return_escape(result.state[Argument(2)], r) # a
12231221
@test !is_load_forwardable(result.state[SSAValue(i)]) # obj
12241222
end
1225-
let result = code_escapes((String, String, Symbol)) do a, b, fld
1223+
let result = code_escapes((Base.RefValue{String}, Base.RefValue{String}, Symbol)) do a, b, fld
12261224
obj = SafeRefs(a, b)
12271225
return getfield(obj, fld) # should escape both `a` and `b`
12281226
end
@@ -1232,7 +1230,7 @@ end
12321230
@test has_return_escape(result.state[Argument(3)], r) # b
12331231
@test !is_load_forwardable(result.state[SSAValue(i)]) # obj
12341232
end
1235-
let result = code_escapes((String, String, Int)) do a, b, idx
1233+
let result = code_escapes((Base.RefValue{String}, Base.RefValue{String}, Int)) do a, b, idx
12361234
obj = SafeRefs(a, b)
12371235
return obj[idx] # should escape both `a` and `b`
12381236
end
@@ -1242,33 +1240,33 @@ end
12421240
@test has_return_escape(result.state[Argument(3)], r) # b
12431241
@test !is_load_forwardable(result.state[SSAValue(i)]) # obj
12441242
end
1245-
let result = code_escapes((String, String, Symbol)) do a, b, fld
1246-
obj = SafeRefs("a", "b")
1243+
let result = code_escapes((Base.RefValue{String}, Base.RefValue{String}, Symbol)) do a, b, fld
1244+
obj = SafeRefs(Ref("a"), Ref("b"))
12471245
setfield!(obj, fld, a)
12481246
return obj[2] # should escape `a`
12491247
end
1250-
i = only(findall(isnew, result.ir.stmts.stmt))
1248+
i = last(findall(isnew, result.ir.stmts.stmt))
12511249
r = only(findall(isreturn, result.ir.stmts.stmt))
12521250
@test has_return_escape(result.state[Argument(2)], r) # a
12531251
@test !has_return_escape(result.state[Argument(3)], r) # b
12541252
@test !is_load_forwardable(result.state[SSAValue(i)]) # obj
12551253
end
1256-
let result = code_escapes((String, Symbol)) do a, fld
1257-
obj = SafeRefs("a", "b")
1254+
let result = code_escapes((Base.RefValue{String}, Symbol)) do a, fld
1255+
obj = SafeRefs(Ref("a"), Ref("b"))
12581256
setfield!(obj, fld, a)
12591257
return obj[1] # this should escape `a`
12601258
end
1261-
i = only(findall(isnew, result.ir.stmts.stmt))
1259+
i = last(findall(isnew, result.ir.stmts.stmt))
12621260
r = only(findall(isreturn, result.ir.stmts.stmt))
12631261
@test has_return_escape(result.state[Argument(2)], r) # a
12641262
@test !is_load_forwardable(result.state[SSAValue(i)]) # obj
12651263
end
1266-
let result = code_escapes((String, String, Int)) do a, b, idx
1267-
obj = SafeRefs("a", "b")
1264+
let result = code_escapes((Base.RefValue{String}, Base.RefValue{String}, Int)) do a, b, idx
1265+
obj = SafeRefs(Ref("a"), Ref("b"))
12681266
obj[idx] = a
12691267
return obj[2] # should escape `a`
12701268
end
1271-
i = only(findall(isnew, result.ir.stmts.stmt))
1269+
i = last(findall(isnew, result.ir.stmts.stmt))
12721270
r = only(findall(isreturn, result.ir.stmts.stmt))
12731271
@test has_return_escape(result.state[Argument(2)], r) # a
12741272
@test !has_return_escape(result.state[Argument(3)], r) # b
@@ -1280,7 +1278,7 @@ end
12801278

12811279
let result = @eval EATModule() begin
12821280
@noinline getx(obj) = obj[]
1283-
$code_escapes((String,)) do a
1281+
$code_escapes((Base.RefValue{String},)) do a
12841282
obj = SafeRef(a)
12851283
fld = getx(obj)
12861284
return fld
@@ -1294,8 +1292,8 @@ end
12941292
end
12951293

12961294
# TODO interprocedural alias analysis
1297-
let result = code_escapes((SafeRef{String},)) do s
1298-
s[] = "bar"
1295+
let result = code_escapes((SafeRef{Base.RefValue{String}},)) do s
1296+
s[] = Ref("bar")
12991297
global GV = s[]
13001298
nothing
13011299
end
@@ -1335,7 +1333,7 @@ end
13351333
let result = @eval EATModule() begin
13361334
@noinline mysetindex!(x, a) = x[1] = a
13371335
const Ax = Vector{Any}(undef, 1)
1338-
$code_escapes((String,)) do s
1336+
$code_escapes((Base.RefValue{String},)) do s
13391337
mysetindex!(Ax, s)
13401338
end
13411339
end
@@ -1391,11 +1389,11 @@ end
13911389
end
13921390

13931391
# handle conflicting field information correctly
1394-
let result = code_escapes((Bool,String,String,)) do cnd, baz, qux
1392+
let result = code_escapes((Bool,Base.RefValue{String},Base.RefValue{String},)) do cnd, baz, qux
13951393
if cnd
1396-
o = SafeRef("foo")
1394+
o = SafeRef(Ref("foo"))
13971395
else
1398-
o = SafeRefs("bar", baz)
1396+
o = SafeRefs(Ref("bar"), baz)
13991397
r = getfield(o, 2)
14001398
end
14011399
if cnd
@@ -1409,12 +1407,14 @@ end
14091407
@test has_return_escape(result.state[Argument(3)], r) # baz
14101408
@test has_return_escape(result.state[Argument(4)], r) # qux
14111409
for new in findall(isnew, result.ir.stmts.stmt)
1412-
@test is_load_forwardable(result.state[SSAValue(new)])
1410+
if !(result.ir[SSAValue(new)][:type] <: Base.RefValue)
1411+
@test is_load_forwardable(result.state[SSAValue(new)])
1412+
end
14131413
end
14141414
end
1415-
let result = code_escapes((Bool,String,String,)) do cnd, baz, qux
1415+
let result = code_escapes((Bool,Base.RefValue{String},Base.RefValue{String},)) do cnd, baz, qux
14161416
if cnd
1417-
o = SafeRefs("foo", "bar")
1417+
o = SafeRefs(Ref("foo"), Ref("bar"))
14181418
r = setfield!(o, 2, baz)
14191419
else
14201420
o = SafeRef(qux)
@@ -2141,9 +2141,9 @@ end
21412141
# propagate escapes imposed on call arguments
21422142
@noinline broadcast_noescape2(b) = broadcast(identity, b)
21432143
let result = code_escapes() do
2144-
broadcast_noescape2(Ref("Hi"))
2144+
broadcast_noescape2(Ref(Ref("Hi")))
21452145
end
2146-
i = only(findall(isnew, result.ir.stmts.stmt))
2146+
i = last(findall(isnew, result.ir.stmts.stmt))
21472147
@test_broken !has_return_escape(result.state[SSAValue(i)]) # TODO interprocedural alias analysis
21482148
@test !has_thrown_escape(result.state[SSAValue(i)])
21492149
end

0 commit comments

Comments
 (0)