Skip to content

WIP: Staged functions #7474

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Sep 24, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,13 @@ const limit_tuple_type_n = function (t::Tuple, lim::Int)
return t
end

function func_for_method(m::Method, tt)
if !m.isstaged
return m.func.code
end
(ccall(:jl_instantiate_staged,Any,(Any,Any),m,tt)).code
end

function abstract_call_gf(f, fargs, argtypes, e)
if length(argtypes)>1 && (argtypes[1] <: Tuple) && argtypes[2]===Int
# allow tuple indexing functions to take advantage of constant
Expand Down Expand Up @@ -704,7 +711,13 @@ function abstract_call_gf(f, fargs, argtypes, e)
end
end
for (m::Tuple) in x
linfo = m[3].func.code
local linfo
try
linfo = func_for_method(m[3],argtypes)
catch
rettype = Any
break
end
sig = m[1]
lsig = length(m[3].sig)
# limit argument type tuple based on size of definition signature.
Expand Down Expand Up @@ -749,7 +762,12 @@ function invoke_tfunc(f, types, argtypes)
return Any
end
for (m::Tuple) in applicable
linfo = m[3].func.code
local linfo
try
linfo = func_for_method(m[3],types)
catch
return Any
end
if typeseq(m[1],types)
tvars = m[2][1:2:end]
(ti, env) = ccall(:jl_match_method, Any, (Any,Any,Any),
Expand Down Expand Up @@ -2072,7 +2090,13 @@ function inlineable(f, e::Expr, atypes, sv, enclosing_ast)
return NF
end
meth = meth[1]::Tuple
linfo = meth[3].func.code

local linfo
try
linfo = func_for_method(meth[3],atypes)
catch
return NF
end

## This code tries to limit the argument list length only when it is
## growing due to recursion.
Expand Down Expand Up @@ -3036,7 +3060,7 @@ end
function code_typed(f::Callable, types::(Type...))
asts = {}
for x in _methods(f,types,-1)
linfo = x[3].func.code
linfo = func_for_method(x[3],types)
(tree, ty) = typeinf(linfo, x[1], x[2])
if !isa(tree,Expr)
push!(asts, ccall(:jl_uncompress_ast, Any, (Any,Any), linfo, tree))
Expand All @@ -3050,7 +3074,7 @@ end
function return_types(f::Callable, types)
rt = {}
for x in _methods(f,types,-1)
linfo = x[3].func.code
linfo = func_for_method(x[3],types)
(tree, ty) = typeinf(linfo, x[1], x[2])
push!(rt, ty)
end
Expand Down
1 change: 1 addition & 0 deletions contrib/julia.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
<item> do </item>
<item> for </item>
<item> function </item>
<item> stagedfunction </item>
<item> if </item>
<item> immutable </item>
<item> let </item>
Expand Down
1 change: 1 addition & 0 deletions src/alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ jl_sym_t *global_sym; jl_sym_t *tuple_sym;
jl_sym_t *dot_sym; jl_sym_t *newvar_sym;
jl_sym_t *boundscheck_sym; jl_sym_t *copyast_sym;
jl_sym_t *simdloop_sym; jl_sym_t *meta_sym;
jl_sym_t *arrow_sym; jl_sym_t *ldots_sym;

typedef struct {
int64_t a;
Expand Down
5 changes: 3 additions & 2 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2800,9 +2800,9 @@ static Value *emit_expr(jl_value_t *expr, jl_codectx_t *ctx, bool isboxed,
make_gcroot(a1, ctx);
Value *a2 = boxed(emit_expr(args[2], ctx),ctx);
make_gcroot(a2, ctx);
Value *mdargs[5] = { name, bp, literal_pointer_val(bnd), a1, a2 };
Value *mdargs[6] = { name, bp, literal_pointer_val(bnd), a1, a2, literal_pointer_val(args[3]) };
ctx->argDepth = last_depth;
return builder.CreateCall(prepare_call(jlmethod_func), ArrayRef<Value*>(&mdargs[0], 5));
return builder.CreateCall(prepare_call(jlmethod_func), ArrayRef<Value*>(&mdargs[0], 6));
}
else if (head == const_sym) {
jl_sym_t *sym = (jl_sym_t*)args[0];
Expand Down Expand Up @@ -4368,6 +4368,7 @@ static void init_julia_llvm_env(Module *m)
mdargs.push_back(jl_pvalue_llvmt);
mdargs.push_back(jl_pvalue_llvmt);
mdargs.push_back(jl_pvalue_llvmt);
mdargs.push_back(jl_pvalue_llvmt);
jlmethod_func =
Function::Create(FunctionType::get(jl_pvalue_llvmt, mdargs, false),
Function::ExternalLinkage,
Expand Down
85 changes: 67 additions & 18 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,7 @@ static jl_function_t *jl_method_table_assoc_exact(jl_methtable_t *mt,
mt_assoc_lkup:
while (ml != JL_NULL) {
size_t lensig = jl_tuple_len(ml->sig);
if ((lensig == n || ml->va) &&
!(lensig > n && n != lensig-1)) {
if (lensig == n || (ml->va && lensig <= n+1)) {
if (cache_match(args, n, (jl_tuple_t*)ml->sig, ml->va, lensig)) {
return ml->func;
}
Expand Down Expand Up @@ -323,7 +322,7 @@ jl_function_t *jl_reinstantiate_method(jl_function_t *f, jl_lambda_info_t *li)
static
jl_methlist_t *jl_method_list_insert(jl_methlist_t **pml, jl_tuple_t *type,
jl_function_t *method, jl_tuple_t *tvars,
int check_amb);
int check_amb, int8_t isstaged);

static
jl_function_t *jl_method_cache_insert(jl_methtable_t *mt, jl_tuple_t *type,
Expand Down Expand Up @@ -356,7 +355,7 @@ jl_function_t *jl_method_cache_insert(jl_methtable_t *mt, jl_tuple_t *type,
}
}
ml_do_insert:
return jl_method_list_insert(pml, type, method, jl_null, 0)->func;
return jl_method_list_insert(pml, type, method, jl_null, 0, 0)->func;
}

extern jl_function_t *jl_typeinf_func;
Expand Down Expand Up @@ -454,7 +453,7 @@ static jl_value_t *ml_matches(jl_methlist_t *ml, jl_value_t *type,

static jl_function_t *cache_method(jl_methtable_t *mt, jl_tuple_t *type,
jl_function_t *method, jl_tuple_t *decl,
jl_tuple_t *sparams)
jl_tuple_t *sparams, int isstaged)
{
size_t i;
int need_guard_entries = 0;
Expand Down Expand Up @@ -705,7 +704,7 @@ static jl_function_t *cache_method(jl_methtable_t *mt, jl_tuple_t *type,
// in general, here we want to find the biggest type that's not a
// supertype of any other method signatures. so far we are conservative
// and the types we find should be bigger.
if (jl_tuple_len(type) > mt->max_args &&
if (!isstaged && jl_tuple_len(type) > mt->max_args &&
jl_is_vararg_type(jl_tupleref(decl,jl_tuple_len(decl)-1))) {
size_t nspec = mt->max_args + 2;
jl_tuple_t *limited = jl_alloc_tuple(nspec);
Expand Down Expand Up @@ -930,6 +929,43 @@ static jl_value_t *lookup_match(jl_value_t *a, jl_value_t *b, jl_tuple_t **penv,
return ti;
}

DLLEXPORT jl_function_t *jl_instantiate_staged(jl_methlist_t *m, jl_tuple_t *tt)
{
jl_lambda_info_t *newlinfo = NULL;
jl_value_t *code = NULL;
jl_expr_t *ex = NULL;
jl_expr_t *oldast = NULL;
jl_function_t *func = NULL;
JL_GC_PUSH4(&code, &newlinfo, &ex, &oldast);
if (jl_is_expr(m->func->linfo->ast))
oldast = (jl_expr_t*)m->func->linfo;
else
oldast = (jl_expr_t*)jl_uncompress_ast(m->func->linfo, m->func->linfo->ast);
assert(oldast->head == lambda_sym);
ex = jl_exprn(arrow_sym, 2);
jl_array_t *oldargnames = (jl_array_t*)jl_cellref(oldast->args,0);
jl_expr_t *argnames = jl_exprn(tuple_sym, jl_array_len(oldargnames));
jl_cellset(ex->args, 0, argnames);
for (size_t i = 0; i < jl_array_len(oldargnames); ++i) {
jl_value_t *arg = jl_cellref(oldargnames,i);
if (jl_is_expr(arg)) {
assert(((jl_expr_t*)arg)->head == colons_sym);
arg = jl_cellref(((jl_expr_t*)arg)->args,0);
assert(jl_is_symbol(arg));
jl_expr_t *dd_expr = jl_exprn(ldots_sym,1);
jl_cellset(dd_expr->args,0,arg);
jl_cellset(argnames->args,i,dd_expr);
} else {
assert(jl_is_symbol(arg));
jl_cellset(argnames->args,i,arg);
}
}
jl_cellset(ex->args, 1, jl_apply(m->func, tt->data, jl_tuple_len(tt)));
func = (jl_function_t*)jl_toplevel_eval(jl_expand((jl_value_t*)ex));
JL_GC_POP();
return func;
}

static jl_function_t *jl_mt_assoc_by_type(jl_methtable_t *mt, jl_tuple_t *tt, int cache, int inexact)
{
jl_methlist_t *m = mt->defs;
Expand Down Expand Up @@ -972,17 +1008,27 @@ static jl_function_t *jl_mt_assoc_by_type(jl_methtable_t *mt, jl_tuple_t *tt, in
m = m->next;
}

jl_function_t *func = NULL;
if (ti == (jl_value_t*)jl_bottom_type) {
JL_GC_POP();
if (m != JL_NULL) {
func = m->func;
if (m->isstaged)
func = jl_instantiate_staged(m,tt);
JL_GC_POP();
if (!cache)
return m->func;
return cache_method(mt, tt, m->func, (jl_tuple_t*)m->sig, jl_null);
return func;
return cache_method(mt, tt, func, (jl_tuple_t*)m->sig, jl_null, m->isstaged);
}
JL_GC_POP();
return jl_bottom_func;
}

assert(jl_is_tuple(env));
func = m->func;

if (m->isstaged)
func = jl_instantiate_staged(m,tt);

// don't bother computing this if no arguments are tuples
for(i=0; i < jl_tuple_len(tt); i++) {
if (jl_is_tuple(jl_tupleref(tt,i)))
Expand All @@ -999,9 +1045,9 @@ static jl_function_t *jl_mt_assoc_by_type(jl_methtable_t *mt, jl_tuple_t *tt, in
assert(jl_is_tuple(newsig));
jl_function_t *nf;
if (!cache)
nf = m->func;
nf = func;
else
nf = cache_method(mt, tt, m->func, newsig, env);
nf = cache_method(mt, tt, func, newsig, env, 0);
JL_GC_POP();
return nf;
}
Expand Down Expand Up @@ -1138,7 +1184,7 @@ static int has_unions(jl_tuple_t *type)
static
jl_methlist_t *jl_method_list_insert(jl_methlist_t **pml, jl_tuple_t *type,
jl_function_t *method, jl_tuple_t *tvars,
int check_amb)
int check_amb, int8_t isstaged)
{
jl_methlist_t *l, **pl;

Expand Down Expand Up @@ -1170,6 +1216,7 @@ jl_methlist_t *jl_method_list_insert(jl_methlist_t **pml, jl_tuple_t *type,
l->va = (jl_tuple_len(type) > 0 &&
jl_is_vararg_type(jl_tupleref(type,jl_tuple_len(type)-1))) ?
1 : 0;
l->isstaged = isstaged;
l->invokes = (struct _jl_methtable_t *)JL_NULL;
l->func = method;
JL_SIGATOMIC_END();
Expand Down Expand Up @@ -1197,6 +1244,7 @@ jl_methlist_t *jl_method_list_insert(jl_methlist_t **pml, jl_tuple_t *type,
newrec->va = (jl_tuple_len(type) > 0 &&
jl_is_vararg_type(jl_tupleref(type,jl_tuple_len(type)-1))) ?
1 : 0;
newrec->isstaged = isstaged;
newrec->func = method;
newrec->invokes = (struct _jl_methtable_t*)JL_NULL;
newrec->next = l;
Expand Down Expand Up @@ -1249,12 +1297,13 @@ static void remove_conflicting(jl_methlist_t **pl, jl_value_t *type)
}

jl_methlist_t *jl_method_table_insert(jl_methtable_t *mt, jl_tuple_t *type,
jl_function_t *method, jl_tuple_t *tvars)
jl_function_t *method, jl_tuple_t *tvars,
int8_t isstaged)
{
if (jl_tuple_len(tvars) == 1)
tvars = (jl_tuple_t*)jl_t0(tvars);
JL_SIGATOMIC_BEGIN();
jl_methlist_t *ml = jl_method_list_insert(&mt->defs,type,method,tvars,1);
jl_methlist_t *ml = jl_method_list_insert(&mt->defs,type,method,tvars,1,isstaged);
// invalidate cached methods that overlap this definition
remove_conflicting(&mt->cache, (jl_value_t*)type);
if (mt->cache_arg1 != JL_NULL) {
Expand Down Expand Up @@ -1660,7 +1709,7 @@ jl_value_t *jl_gf_invoke(jl_function_t *gf, jl_tuple_t *types,
if (m->invokes == JL_NULL) {
m->invokes = new_method_table(mt->name);
// this private method table has just this one definition
jl_method_list_insert(&m->invokes->defs,m->sig,m->func,m->tvars,0);
jl_method_list_insert(&m->invokes->defs,m->sig,m->func,m->tvars,0,0);
}

tt = arg_type_tuple(args, nargs);
Expand All @@ -1684,7 +1733,7 @@ jl_value_t *jl_gf_invoke(jl_function_t *gf, jl_tuple_t *types,
jl_tuple_len(tpenv)/2);
}
}
mfunc = cache_method(m->invokes, tt, m->func, newsig, tpenv);
mfunc = cache_method(m->invokes, tt, m->func, newsig, tpenv, 0);
JL_GC_POP();
}

Expand Down Expand Up @@ -1721,15 +1770,15 @@ DLLEXPORT jl_function_t *jl_new_gf_internal(jl_value_t *env)
}

void jl_add_method(jl_function_t *gf, jl_tuple_t *types, jl_function_t *meth,
jl_tuple_t *tvars)
jl_tuple_t *tvars, int8_t isstaged)
{
assert(jl_is_function(gf));
assert(jl_is_tuple(types));
assert(jl_is_func(meth));
assert(jl_is_mtable(jl_gf_mtable(gf)));
if (meth->linfo != NULL)
meth->linfo->name = jl_gf_name(gf);
(void)jl_method_table_insert(jl_gf_mtable(gf), types, meth, tvars);
(void)jl_method_table_insert(jl_gf_mtable(gf), types, meth, tvars, isstaged);
}

DLLEXPORT jl_tuple_t *jl_match_method(jl_value_t *type, jl_value_t *sig,
Expand Down
2 changes: 1 addition & 1 deletion src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ static jl_value_t *eval(jl_value_t *e, jl_value_t **locals, size_t nl)
jl_check_static_parameter_conflicts((jl_lambda_info_t*)args[2], (jl_tuple_t*)jl_t1(atypes), fname);
}
meth = eval(args[2], locals, nl);
jl_method_def(fname, bp, b, (jl_tuple_t*)atypes, (jl_function_t*)meth);
jl_method_def(fname, bp, b, (jl_tuple_t*)atypes, (jl_function_t*)meth, args[3]);
JL_GC_POP();
return *bp;
}
Expand Down
6 changes: 4 additions & 2 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3001,10 +3001,10 @@ void jl_init_types(void)

jl_method_type =
jl_new_datatype(jl_symbol("Method"), jl_any_type, jl_null,
jl_tuple(6, jl_symbol("sig"), jl_symbol("va"),
jl_tuple(7, jl_symbol("sig"), jl_symbol("va"), jl_symbol("isstaged"),
jl_symbol("tvars"), jl_symbol("func"),
jl_symbol("invokes"), jl_symbol("next")),
jl_tuple(6, jl_tuple_type, jl_bool_type,
jl_tuple(7, jl_tuple_type, jl_bool_type, jl_bool_type,
jl_tuple_type, jl_any_type,
jl_any_type, jl_any_type),
0, 1);
Expand Down Expand Up @@ -3262,6 +3262,8 @@ void jl_init_types(void)
copyast_sym = jl_symbol("copyast");
simdloop_sym = jl_symbol("simdloop");
meta_sym = jl_symbol("meta");
arrow_sym = jl_symbol("->");
ldots_sym = jl_symbol("...");
}

#ifdef __cplusplus
Expand Down
4 changes: 2 additions & 2 deletions src/julia-parser.scm
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
(define operator? (Set operators))

(define reserved-words '(begin while if for try return break continue
function macro quote let local global const
stagedfunction function macro quote let local global const
abstract typealias type bitstype immutable ccall do
module baremodule using import export importall))

Expand Down Expand Up @@ -1074,7 +1074,7 @@
(if const
`(const ,expr)
expr)))
((function macro)
((stagedfunction function macro)
(let* ((paren (eqv? (require-token s) #\())
(sig (parse-call s))
(def (if (or (symbol? sig)
Expand Down
Loading