Skip to content

Commit a8bc71a

Browse files
committed
add if @generated ... else ... end inside functions to provide optional optimizers
use meta nodes instead of `stagedfunction` expression head
1 parent 5fa5c7e commit a8bc71a

23 files changed

+398
-243
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ New language features
1919
* The macro call syntax `@macroname[args]` is now available and is parsed
2020
as `@macroname([args])` ([#23519]).
2121

22+
* The construct `if @generated ...; else ...; end` can be used to provide both
23+
`@generated` and normal implementations of part of a function. Surrounding code
24+
will be common to both versions ([#23168]).
25+
2226
Language changes
2327
----------------
2428

base/boot.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,4 +432,32 @@ show(@nospecialize a) = show(STDOUT, a)
432432
print(@nospecialize a...) = print(STDOUT, a...)
433433
println(@nospecialize a...) = println(STDOUT, a...)
434434

435+
struct GeneratedFunctionStub
436+
gen
437+
argnames
438+
spnames
439+
line
440+
file
441+
end
442+
443+
# invoke and wrap the results of @generated
444+
function (g::GeneratedFunctionStub)(args...)
445+
body = g.gen(args...)
446+
if body isa CodeInfo
447+
return body
448+
end
449+
lam = Expr(:lambda, g.argnames,
450+
Expr(Symbol("scope-block"),
451+
Expr(:block,
452+
LineNumberNode(g.line, g.file),
453+
Expr(:meta, :push_loc, g.file, Symbol("@generated body")),
454+
Expr(:return, body),
455+
Expr(:meta, :pop_loc))))
456+
if g.spnames === nothing
457+
return lam
458+
else
459+
return Expr(Symbol("with-static-parameters"), lam, g.spnames...)
460+
end
461+
end
462+
435463
ccall(:jl_set_istopmod, Void, (Any, Bool), Core, true)

base/docs/Docs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,7 @@ finddoc(λ, def) = false
642642

643643
# Predicates and helpers for `docm` expression selection:
644644

645-
const FUNC_HEADS = [:function, :stagedfunction, :macro, :(=)]
645+
const FUNC_HEADS = [:function, :macro, :(=)]
646646
const BINDING_HEADS = [:typealias, :const, :global, :(=)] # deprecation: remove `typealias` post-0.6
647647
# For the special `:@mac` / `:(Base.@mac)` syntax for documenting a macro after definition.
648648
isquotedmacrocall(x) =

base/expr.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -332,10 +332,23 @@ function remove_linenums!(ex::Expr)
332332
return ex
333333
end
334334

335+
macro generated()
336+
return Expr(:generated)
337+
end
338+
335339
macro generated(f)
336-
if isa(f, Expr) && (f.head === :function || is_short_function_def(f))
337-
f.head = :stagedfunction
338-
return Expr(:escape, f)
340+
if isa(f, Expr) && (f.head === :function || is_short_function_def(f))
341+
body = f.args[2]
342+
lno = body.args[1]
343+
return Expr(:escape,
344+
Expr(f.head, f.args[1],
345+
Expr(:block,
346+
lno,
347+
Expr(:if, Expr(:generated),
348+
body,
349+
Expr(:block,
350+
Expr(:meta, :generated_only),
351+
Expr(:return, nothing))))))
339352
else
340353
error("invalid syntax; @generated must be used with a function definition")
341354
end

base/linalg/bidiag.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -573,12 +573,18 @@ _valuefields(::Type{<:AbstractTriangular}) = [:data]
573573

574574
const SpecialArrays = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal,AbstractTriangular}
575575

576-
@generated function fillslots!(A::SpecialArrays, x)
577-
ex = :(xT = convert(eltype(A), x))
578-
for field in _valuefields(A)
579-
ex = :($ex; fill!(A.$field, xT))
576+
function fillslots!(A::SpecialArrays, x)
577+
xT = convert(eltype(A), x)
578+
if @generated
579+
quote
580+
$([ :(fill!(A.$field, xT)) for field in _valuefields(A) ]...)
581+
end
582+
else
583+
for field in _valuefields(A)
584+
fill!(getfield(A, field), xT)
585+
end
580586
end
581-
:($ex;return A)
587+
return A
582588
end
583589

584590
# for historical reasons:

base/methodshow.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ function argtype_decl(env, n, sig::DataType, i::Int, nargs, isva::Bool) # -> (ar
4242
return s, string_with_env(env, t)
4343
end
4444

45+
function method_argnames(m::Method)
46+
if !isdefined(m, :source) && isdefined(m, :generator)
47+
return m.generator.argnames
48+
end
49+
argnames = Vector{Any}(m.nargs)
50+
ccall(:jl_fill_argnames, Void, (Any, Any), m.source, argnames)
51+
return argnames
52+
end
53+
4554
function arg_decl_parts(m::Method)
4655
tv = Any[]
4756
sig = m.sig
@@ -52,8 +61,7 @@ function arg_decl_parts(m::Method)
5261
file = m.file
5362
line = m.line
5463
if isdefined(m, :source) || isdefined(m, :generator)
55-
argnames = Vector{Any}(m.nargs)
56-
ccall(:jl_fill_argnames, Void, (Any, Any), isdefined(m, :source) ? m.source : m.generator.inferred, argnames)
64+
argnames = method_argnames(m)
5765
show_env = ImmutableDict{Symbol, Any}()
5866
for t in tv
5967
show_env = ImmutableDict(show_env, :unionall_env => t)

base/multidimensional.jl

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -549,14 +549,11 @@ end
549549
@noinline throw_checksize_error(A, sz) = throw(DimensionMismatch("output array is the wrong size; expected $sz, got $(size(A))"))
550550

551551
## setindex! ##
552-
@generated function _setindex!(l::IndexStyle, A::AbstractArray, x, I::Union{Real, AbstractArray}...)
553-
N = length(I)
554-
quote
555-
@_inline_meta
556-
@boundscheck checkbounds(A, I...)
557-
_unsafe_setindex!(l, _maybe_reshape(l, A, I...), x, I...)
558-
A
559-
end
552+
function _setindex!(l::IndexStyle, A::AbstractArray, x, I::Union{Real, AbstractArray}...)
553+
@_inline_meta
554+
@boundscheck checkbounds(A, I...)
555+
_unsafe_setindex!(l, _maybe_reshape(l, A, I...), x, I...)
556+
A
560557
end
561558

562559
_iterable(v::AbstractArray) = v
@@ -916,28 +913,29 @@ function copy!(dest::AbstractArray{T,N}, src::AbstractArray{T,N}) where {T,N}
916913
dest
917914
end
918915

919-
@generated function copy!(dest::AbstractArray{T1,N},
920-
Rdest::CartesianRange{N},
921-
src::AbstractArray{T2,N},
922-
Rsrc::CartesianRange{N}) where {T1,T2,N}
923-
quote
924-
isempty(Rdest) && return dest
925-
if size(Rdest) != size(Rsrc)
926-
throw(ArgumentError("source and destination must have same size (got $(size(Rsrc)) and $(size(Rdest)))"))
916+
function copy!(dest::AbstractArray{T1,N}, Rdest::CartesianRange{N},
917+
src::AbstractArray{T2,N}, Rsrc::CartesianRange{N}) where {T1,T2,N}
918+
isempty(Rdest) && return dest
919+
if size(Rdest) != size(Rsrc)
920+
throw(ArgumentError("source and destination must have same size (got $(size(Rsrc)) and $(size(Rdest)))"))
921+
end
922+
@boundscheck checkbounds(dest, first(Rdest))
923+
@boundscheck checkbounds(dest, last(Rdest))
924+
@boundscheck checkbounds(src, first(Rsrc))
925+
@boundscheck checkbounds(src, last(Rsrc))
926+
ΔI = first(Rdest) - first(Rsrc)
927+
if @generated
928+
quote
929+
@nloops $N i (n->Rsrc.indices[n]) begin
930+
@inbounds @nref($N,dest,n->i_n+ΔI[n]) = @nref($N,src,i)
931+
end
927932
end
928-
@boundscheck checkbounds(dest, first(Rdest))
929-
@boundscheck checkbounds(dest, last(Rdest))
930-
@boundscheck checkbounds(src, first(Rsrc))
931-
@boundscheck checkbounds(src, last(Rsrc))
932-
ΔI = first(Rdest) - first(Rsrc)
933-
# TODO: restore when #9080 is fixed
934-
# for I in Rsrc
935-
# @inbounds dest[I+ΔI] = src[I]
936-
@nloops $N i (n->Rsrc.indices[n]) begin
937-
@inbounds @nref($N,dest,n->i_n+ΔI[n]) = @nref($N,src,i)
933+
else
934+
for I in Rsrc
935+
@inbounds dest[I + ΔI] = src[I]
938936
end
939-
dest
940937
end
938+
dest
941939
end
942940

943941
"""

base/reflection.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,8 @@ function length(mt::MethodTable)
738738
end
739739
isempty(mt::MethodTable) = (mt.defs === nothing)
740740

741-
uncompressed_ast(m::Method) = uncompressed_ast(m, isdefined(m, :source) ? m.source : m.generator.inferred)
741+
uncompressed_ast(m::Method) = isdefined(m,:source) ? uncompressed_ast(m, m.source) :
742+
error("Method is @generated; try `code_lowered` instead.")
742743
uncompressed_ast(m::Method, s::CodeInfo) = s
743744
uncompressed_ast(m::Method, s::Array{UInt8,1}) = ccall(:jl_uncompress_ast, Any, (Any, Any), m, s)::CodeInfo
744745
uncompressed_ast(m::Core.MethodInstance) = uncompressed_ast(m.def)
@@ -852,7 +853,7 @@ code_native(::IO, ::Any, ::Symbol) = error("illegal code_native call") # resolve
852853

853854
# give a decent error message if we try to instantiate a staged function on non-leaf types
854855
function func_for_method_checked(m::Method, @nospecialize types)
855-
if isdefined(m,:generator) && !isdefined(m,:source) && !_isleaftype(types)
856+
if isdefined(m,:generator) && !_isleaftype(types)
856857
error("cannot call @generated function `", m, "` ",
857858
"with abstract argument types: ", types)
858859
end

base/sysimg.jl

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -236,21 +236,27 @@ include("broadcast.jl")
236236
using .Broadcast
237237

238238
# define the real ntuple functions
239-
@generated function ntuple(f::F, ::Val{N}) where {F,N}
240-
Core.typeassert(N, Int)
241-
(N >= 0) || return :(throw($(ArgumentError(string("tuple length should be ≥0, got ", N)))))
242-
return quote
243-
$(Expr(:meta, :inline))
244-
@nexprs $N i -> t_i = f(i)
245-
@ncall $N tuple t
239+
@inline function ntuple(f::F, ::Val{N}) where {F,N}
240+
N::Int
241+
(N >= 0) || throw(ArgumentError(string("tuple length should be ≥0, got ", N)))
242+
if @generated
243+
quote
244+
@nexprs $N i -> t_i = f(i)
245+
@ncall $N tuple t
246+
end
247+
else
248+
Tuple(f(i) for i = 1:N)
246249
end
247250
end
248-
@generated function fill_to_length(t::Tuple, val, ::Val{N}) where {N}
249-
M = length(t.parameters)
250-
M > N && return :(throw($(ArgumentError("input tuple of length $M, requested $N"))))
251-
return quote
252-
$(Expr(:meta, :inline))
253-
(t..., $(Any[ :val for i = (M + 1):N ]...))
251+
@inline function fill_to_length(t::Tuple, val, ::Val{N}) where {N}
252+
M = length(t)
253+
M > N && throw(ArgumentError("input tuple of length $M, requested $N"))
254+
if @generated
255+
quote
256+
(t..., $(fill(:val, N-length(t.parameters))...))
257+
end
258+
else
259+
(t..., fill(val, N-M)...)
254260
end
255261
end
256262

doc/src/manual/metaprogramming.md

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,17 +1012,16 @@ syntax tree.
10121012
A very special macro is `@generated`, which allows you to define so-called *generated functions*.
10131013
These have the capability to generate specialized code depending on the types of their arguments
10141014
with more flexibility and/or less code than what can be achieved with multiple dispatch. While
1015-
macros work with expressions at parsing-time and cannot access the types of their inputs, a generated
1015+
macros work with expressions at parse time and cannot access the types of their inputs, a generated
10161016
function gets expanded at a time when the types of the arguments are known, but the function is
10171017
not yet compiled.
10181018

10191019
Instead of performing some calculation or action, a generated function declaration returns a quoted
10201020
expression which then forms the body for the method corresponding to the types of the arguments.
1021-
When called, the body expression is first evaluated and compiled, then the returned expression
1022-
is compiled and run. To make this efficient, the result is often cached. And to make this inferable,
1023-
only a limited subset of the language is usable. Thus, generated functions provide a flexible
1024-
framework to move work from run-time to compile-time, at the expense of greater restrictions on
1025-
the allowable constructs.
1021+
When a generated function is called, the expression it returns is compiled and then run.
1022+
To make this efficient, the result is usually cached. And to make this inferable, only a limited
1023+
subset of the language is usable. Thus, generated functions provide a flexible way to move work from
1024+
run time to compile time, at the expense of greater restrictions on allowed constructs.
10261025

10271026
When defining generated functions, there are four main differences to ordinary functions:
10281027

@@ -1038,7 +1037,7 @@ When defining generated functions, there are four main differences to ordinary f
10381037
This means they can only read global constants, and cannot have any side effects.
10391038
In other words, they must be completely pure.
10401039
Due to an implementation limitation, this also means that they currently cannot define a closure
1041-
or untyped generator.
1040+
or generator.
10421041

10431042
It's easiest to illustrate this with an example. We can declare a generated function `foo` as
10441043

@@ -1053,9 +1052,8 @@ foo (generic function with 1 method)
10531052
Note that the body returns a quoted expression, namely `:(x * x)`, rather than just the value
10541053
of `x * x`.
10551054

1056-
From the caller's perspective, they are very similar to regular functions; in fact, you don't
1057-
have to know if you're calling a regular or generated function - the syntax and result of the
1058-
call is just the same. Let's see how `foo` behaves:
1055+
From the caller's perspective, this is identical to a regular function; in fact, you don't
1056+
have to know whether you're calling a regular or generated function. Let's see how `foo` behaves:
10591057

10601058
```jldoctest generated
10611059
julia> x = foo(2); # note: output is from println() statement in the body
@@ -1199,7 +1197,7 @@ end and at the call site; however, *don't copy them*, for the following reasons:
11991197
when, how often or how many times these side-effects will occur
12001198
* the `bar` function solves a problem that is better solved with multiple dispatch - defining `bar(x) = x`
12011199
and `bar(x::Integer) = x ^ 2` will do the same thing, but it is both simpler and faster.
1202-
* the `baz` function is pathologically insane
1200+
* the `baz` function is pathological
12031201

12041202
Note that the set of operations that should not be attempted in a generated function is unbounded,
12051203
and the runtime system can currently only detect a subset of the invalid operations. There are
@@ -1317,3 +1315,50 @@ the two tuples, multiplication and addition/subtraction. All the looping is perf
13171315
and we avoid looping during execution entirely. Thus, we only loop *once per type*, in this case
13181316
once per `N` (except in edge cases where the function is generated more than once - see disclaimer
13191317
above).
1318+
1319+
### Optionally-generated functions
1320+
1321+
Generated functions can achieve high efficiency at run time, but come with a compile time cost:
1322+
a new function body must be generated for every combination of concrete argument types.
1323+
Typically, Julia is able to compile "generic" versions of functions that will work for any
1324+
arguments, but with generated functions this is impossible.
1325+
This means that programs making heavy use of generated functions might be impossible to
1326+
statically compile.
1327+
1328+
To solve this problem, the language provides syntax for writing normal, non-generated
1329+
alternative implementations of generated functions.
1330+
Applied to the `sub2ind` example above, it would look like this:
1331+
1332+
```julia
1333+
function sub2ind_gen(dims::NTuple{N}, I::Integer...) where N
1334+
if N != length(I)
1335+
throw(ArgumentError("Number of dimensions must match number of indices."))
1336+
end
1337+
if @generated
1338+
ex = :(I[$N] - 1)
1339+
for i = (N - 1):-1:1
1340+
ex = :(I[$i] - 1 + dims[$i] * $ex)
1341+
end
1342+
return :($ex + 1)
1343+
else
1344+
ind = I[N] - 1
1345+
for i = (N - 1):-1:1
1346+
ind = I[i] - 1 + dims[i]*ind
1347+
end
1348+
return ind + 1
1349+
end
1350+
end
1351+
```
1352+
1353+
Internally, this code creates two implementations of the function: a generated one where
1354+
the first block in `if @generated` is used, and a normal one where the `else` block is used.
1355+
Notice that we added an error check to the top of the function.
1356+
This code will be common to both versions, and is run-time code in both versions
1357+
(in other words, it will be quoted and returned as an expression from the generated version).
1358+
Inside the `then` part of the `if @generated` block, code has the same semantics as other
1359+
generated functions: argument names refer to types, and the code should return an expression.
1360+
1361+
In this style of definition, the code generation feature is essentially an optional
1362+
optimization.
1363+
The compiler will use it if convenient, but otherwise may choose to use the normal
1364+
implementation instead.

src/ast.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ jl_sym_t *meta_sym; jl_sym_t *compiler_temp_sym;
5555
jl_sym_t *inert_sym; jl_sym_t *vararg_sym;
5656
jl_sym_t *unused_sym; jl_sym_t *static_parameter_sym;
5757
jl_sym_t *polly_sym; jl_sym_t *inline_sym;
58-
jl_sym_t *propagate_inbounds_sym;
58+
jl_sym_t *propagate_inbounds_sym; jl_sym_t *generated_sym;
59+
jl_sym_t *generated_only_sym;
5960
jl_sym_t *isdefined_sym; jl_sym_t *nospecialize_sym;
6061
jl_sym_t *macrocall_sym;
6162
jl_sym_t *hygienicscope_sym;
@@ -343,6 +344,8 @@ void jl_init_frontend(void)
343344
hygienicscope_sym = jl_symbol("hygienic-scope");
344345
gc_preserve_begin_sym = jl_symbol("gc_preserve_begin");
345346
gc_preserve_end_sym = jl_symbol("gc_preserve_end");
347+
generated_sym = jl_symbol("generated");
348+
generated_only_sym = jl_symbol("generated_only");
346349
}
347350

348351
JL_DLLEXPORT void jl_lisp_prompt(void)

0 commit comments

Comments
 (0)