diff --git a/base/inference.jl b/base/inference.jl index c2f25b7ad78fb..c02a7a713a854 100644 --- a/base/inference.jl +++ b/base/inference.jl @@ -633,6 +633,13 @@ const limit_tuple_type_n = function (t::Tuple, lim::Int) return t end +function func_for_method(m::Method, tt) + if !m.isstaged + return m.func.code + end + (ccall(:jl_instantiate_staged,Any,(Any,Any),m,tt)).code +end + function abstract_call_gf(f, fargs, argtypes, e) if length(argtypes)>1 && (argtypes[1] <: Tuple) && argtypes[2]===Int # allow tuple indexing functions to take advantage of constant @@ -704,7 +711,13 @@ function abstract_call_gf(f, fargs, argtypes, e) end end for (m::Tuple) in x - linfo = m[3].func.code + local linfo + try + linfo = func_for_method(m[3],argtypes) + catch + rettype = Any + break + end sig = m[1] lsig = length(m[3].sig) # limit argument type tuple based on size of definition signature. @@ -749,7 +762,12 @@ function invoke_tfunc(f, types, argtypes) return Any end for (m::Tuple) in applicable - linfo = m[3].func.code + local linfo + try + linfo = func_for_method(m[3],types) + catch + return Any + end if typeseq(m[1],types) tvars = m[2][1:2:end] (ti, env) = ccall(:jl_match_method, Any, (Any,Any,Any), @@ -2072,7 +2090,13 @@ function inlineable(f, e::Expr, atypes, sv, enclosing_ast) return NF end meth = meth[1]::Tuple - linfo = meth[3].func.code + + local linfo + try + linfo = func_for_method(meth[3],atypes) + catch + return NF + end ## This code tries to limit the argument list length only when it is ## growing due to recursion. @@ -3036,7 +3060,7 @@ end function code_typed(f::Callable, types::(Type...)) asts = {} for x in _methods(f,types,-1) - linfo = x[3].func.code + linfo = func_for_method(x[3],types) (tree, ty) = typeinf(linfo, x[1], x[2]) if !isa(tree,Expr) push!(asts, ccall(:jl_uncompress_ast, Any, (Any,Any), linfo, tree)) @@ -3050,7 +3074,7 @@ end function return_types(f::Callable, types) rt = {} for x in _methods(f,types,-1) - linfo = x[3].func.code + linfo = func_for_method(x[3],types) (tree, ty) = typeinf(linfo, x[1], x[2]) push!(rt, ty) end diff --git a/contrib/julia.xml b/contrib/julia.xml index 78286dcd70110..4f97dfa2199c5 100644 --- a/contrib/julia.xml +++ b/contrib/julia.xml @@ -37,6 +37,7 @@ do for function + stagedfunction if immutable let diff --git a/src/alloc.c b/src/alloc.c index 6a747e9a247f2..4db66e169ffde 100644 --- a/src/alloc.c +++ b/src/alloc.c @@ -92,6 +92,7 @@ jl_sym_t *global_sym; jl_sym_t *tuple_sym; jl_sym_t *dot_sym; jl_sym_t *newvar_sym; jl_sym_t *boundscheck_sym; jl_sym_t *copyast_sym; jl_sym_t *simdloop_sym; jl_sym_t *meta_sym; +jl_sym_t *arrow_sym; jl_sym_t *ldots_sym; typedef struct { int64_t a; diff --git a/src/codegen.cpp b/src/codegen.cpp index 4a5e5cfb65580..f3a035a388eaa 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -2800,9 +2800,9 @@ static Value *emit_expr(jl_value_t *expr, jl_codectx_t *ctx, bool isboxed, make_gcroot(a1, ctx); Value *a2 = boxed(emit_expr(args[2], ctx),ctx); make_gcroot(a2, ctx); - Value *mdargs[5] = { name, bp, literal_pointer_val(bnd), a1, a2 }; + Value *mdargs[6] = { name, bp, literal_pointer_val(bnd), a1, a2, literal_pointer_val(args[3]) }; ctx->argDepth = last_depth; - return builder.CreateCall(prepare_call(jlmethod_func), ArrayRef(&mdargs[0], 5)); + return builder.CreateCall(prepare_call(jlmethod_func), ArrayRef(&mdargs[0], 6)); } else if (head == const_sym) { jl_sym_t *sym = (jl_sym_t*)args[0]; @@ -4368,6 +4368,7 @@ static void init_julia_llvm_env(Module *m) mdargs.push_back(jl_pvalue_llvmt); mdargs.push_back(jl_pvalue_llvmt); mdargs.push_back(jl_pvalue_llvmt); + mdargs.push_back(jl_pvalue_llvmt); jlmethod_func = Function::Create(FunctionType::get(jl_pvalue_llvmt, mdargs, false), Function::ExternalLinkage, diff --git a/src/gf.c b/src/gf.c index 38c3b95acf780..1a264f45b41da 100644 --- a/src/gf.c +++ b/src/gf.c @@ -274,8 +274,7 @@ static jl_function_t *jl_method_table_assoc_exact(jl_methtable_t *mt, mt_assoc_lkup: while (ml != JL_NULL) { size_t lensig = jl_tuple_len(ml->sig); - if ((lensig == n || ml->va) && - !(lensig > n && n != lensig-1)) { + if (lensig == n || (ml->va && lensig <= n+1)) { if (cache_match(args, n, (jl_tuple_t*)ml->sig, ml->va, lensig)) { return ml->func; } @@ -323,7 +322,7 @@ jl_function_t *jl_reinstantiate_method(jl_function_t *f, jl_lambda_info_t *li) static jl_methlist_t *jl_method_list_insert(jl_methlist_t **pml, jl_tuple_t *type, jl_function_t *method, jl_tuple_t *tvars, - int check_amb); + int check_amb, int8_t isstaged); static jl_function_t *jl_method_cache_insert(jl_methtable_t *mt, jl_tuple_t *type, @@ -356,7 +355,7 @@ jl_function_t *jl_method_cache_insert(jl_methtable_t *mt, jl_tuple_t *type, } } ml_do_insert: - return jl_method_list_insert(pml, type, method, jl_null, 0)->func; + return jl_method_list_insert(pml, type, method, jl_null, 0, 0)->func; } extern jl_function_t *jl_typeinf_func; @@ -454,7 +453,7 @@ static jl_value_t *ml_matches(jl_methlist_t *ml, jl_value_t *type, static jl_function_t *cache_method(jl_methtable_t *mt, jl_tuple_t *type, jl_function_t *method, jl_tuple_t *decl, - jl_tuple_t *sparams) + jl_tuple_t *sparams, int isstaged) { size_t i; int need_guard_entries = 0; @@ -705,7 +704,7 @@ static jl_function_t *cache_method(jl_methtable_t *mt, jl_tuple_t *type, // in general, here we want to find the biggest type that's not a // supertype of any other method signatures. so far we are conservative // and the types we find should be bigger. - if (jl_tuple_len(type) > mt->max_args && + if (!isstaged && jl_tuple_len(type) > mt->max_args && jl_is_vararg_type(jl_tupleref(decl,jl_tuple_len(decl)-1))) { size_t nspec = mt->max_args + 2; jl_tuple_t *limited = jl_alloc_tuple(nspec); @@ -930,6 +929,43 @@ static jl_value_t *lookup_match(jl_value_t *a, jl_value_t *b, jl_tuple_t **penv, return ti; } +DLLEXPORT jl_function_t *jl_instantiate_staged(jl_methlist_t *m, jl_tuple_t *tt) +{ + jl_lambda_info_t *newlinfo = NULL; + jl_value_t *code = NULL; + jl_expr_t *ex = NULL; + jl_expr_t *oldast = NULL; + jl_function_t *func = NULL; + JL_GC_PUSH4(&code, &newlinfo, &ex, &oldast); + if (jl_is_expr(m->func->linfo->ast)) + oldast = (jl_expr_t*)m->func->linfo; + else + oldast = (jl_expr_t*)jl_uncompress_ast(m->func->linfo, m->func->linfo->ast); + assert(oldast->head == lambda_sym); + ex = jl_exprn(arrow_sym, 2); + jl_array_t *oldargnames = (jl_array_t*)jl_cellref(oldast->args,0); + jl_expr_t *argnames = jl_exprn(tuple_sym, jl_array_len(oldargnames)); + jl_cellset(ex->args, 0, argnames); + for (size_t i = 0; i < jl_array_len(oldargnames); ++i) { + jl_value_t *arg = jl_cellref(oldargnames,i); + if (jl_is_expr(arg)) { + assert(((jl_expr_t*)arg)->head == colons_sym); + arg = jl_cellref(((jl_expr_t*)arg)->args,0); + assert(jl_is_symbol(arg)); + jl_expr_t *dd_expr = jl_exprn(ldots_sym,1); + jl_cellset(dd_expr->args,0,arg); + jl_cellset(argnames->args,i,dd_expr); + } else { + assert(jl_is_symbol(arg)); + jl_cellset(argnames->args,i,arg); + } + } + jl_cellset(ex->args, 1, jl_apply(m->func, tt->data, jl_tuple_len(tt))); + func = (jl_function_t*)jl_toplevel_eval(jl_expand((jl_value_t*)ex)); + JL_GC_POP(); + return func; +} + static jl_function_t *jl_mt_assoc_by_type(jl_methtable_t *mt, jl_tuple_t *tt, int cache, int inexact) { jl_methlist_t *m = mt->defs; @@ -972,17 +1008,27 @@ static jl_function_t *jl_mt_assoc_by_type(jl_methtable_t *mt, jl_tuple_t *tt, in m = m->next; } + jl_function_t *func = NULL; if (ti == (jl_value_t*)jl_bottom_type) { - JL_GC_POP(); if (m != JL_NULL) { + func = m->func; + if (m->isstaged) + func = jl_instantiate_staged(m,tt); + JL_GC_POP(); if (!cache) - return m->func; - return cache_method(mt, tt, m->func, (jl_tuple_t*)m->sig, jl_null); + return func; + return cache_method(mt, tt, func, (jl_tuple_t*)m->sig, jl_null, m->isstaged); } + JL_GC_POP(); return jl_bottom_func; } assert(jl_is_tuple(env)); + func = m->func; + + if (m->isstaged) + func = jl_instantiate_staged(m,tt); + // don't bother computing this if no arguments are tuples for(i=0; i < jl_tuple_len(tt); i++) { if (jl_is_tuple(jl_tupleref(tt,i))) @@ -999,9 +1045,9 @@ static jl_function_t *jl_mt_assoc_by_type(jl_methtable_t *mt, jl_tuple_t *tt, in assert(jl_is_tuple(newsig)); jl_function_t *nf; if (!cache) - nf = m->func; + nf = func; else - nf = cache_method(mt, tt, m->func, newsig, env); + nf = cache_method(mt, tt, func, newsig, env, 0); JL_GC_POP(); return nf; } @@ -1138,7 +1184,7 @@ static int has_unions(jl_tuple_t *type) static jl_methlist_t *jl_method_list_insert(jl_methlist_t **pml, jl_tuple_t *type, jl_function_t *method, jl_tuple_t *tvars, - int check_amb) + int check_amb, int8_t isstaged) { jl_methlist_t *l, **pl; @@ -1170,6 +1216,7 @@ jl_methlist_t *jl_method_list_insert(jl_methlist_t **pml, jl_tuple_t *type, l->va = (jl_tuple_len(type) > 0 && jl_is_vararg_type(jl_tupleref(type,jl_tuple_len(type)-1))) ? 1 : 0; + l->isstaged = isstaged; l->invokes = (struct _jl_methtable_t *)JL_NULL; l->func = method; JL_SIGATOMIC_END(); @@ -1197,6 +1244,7 @@ jl_methlist_t *jl_method_list_insert(jl_methlist_t **pml, jl_tuple_t *type, newrec->va = (jl_tuple_len(type) > 0 && jl_is_vararg_type(jl_tupleref(type,jl_tuple_len(type)-1))) ? 1 : 0; + newrec->isstaged = isstaged; newrec->func = method; newrec->invokes = (struct _jl_methtable_t*)JL_NULL; newrec->next = l; @@ -1249,12 +1297,13 @@ static void remove_conflicting(jl_methlist_t **pl, jl_value_t *type) } jl_methlist_t *jl_method_table_insert(jl_methtable_t *mt, jl_tuple_t *type, - jl_function_t *method, jl_tuple_t *tvars) + jl_function_t *method, jl_tuple_t *tvars, + int8_t isstaged) { if (jl_tuple_len(tvars) == 1) tvars = (jl_tuple_t*)jl_t0(tvars); JL_SIGATOMIC_BEGIN(); - jl_methlist_t *ml = jl_method_list_insert(&mt->defs,type,method,tvars,1); + jl_methlist_t *ml = jl_method_list_insert(&mt->defs,type,method,tvars,1,isstaged); // invalidate cached methods that overlap this definition remove_conflicting(&mt->cache, (jl_value_t*)type); if (mt->cache_arg1 != JL_NULL) { @@ -1660,7 +1709,7 @@ jl_value_t *jl_gf_invoke(jl_function_t *gf, jl_tuple_t *types, if (m->invokes == JL_NULL) { m->invokes = new_method_table(mt->name); // this private method table has just this one definition - jl_method_list_insert(&m->invokes->defs,m->sig,m->func,m->tvars,0); + jl_method_list_insert(&m->invokes->defs,m->sig,m->func,m->tvars,0,0); } tt = arg_type_tuple(args, nargs); @@ -1684,7 +1733,7 @@ jl_value_t *jl_gf_invoke(jl_function_t *gf, jl_tuple_t *types, jl_tuple_len(tpenv)/2); } } - mfunc = cache_method(m->invokes, tt, m->func, newsig, tpenv); + mfunc = cache_method(m->invokes, tt, m->func, newsig, tpenv, 0); JL_GC_POP(); } @@ -1721,7 +1770,7 @@ DLLEXPORT jl_function_t *jl_new_gf_internal(jl_value_t *env) } void jl_add_method(jl_function_t *gf, jl_tuple_t *types, jl_function_t *meth, - jl_tuple_t *tvars) + jl_tuple_t *tvars, int8_t isstaged) { assert(jl_is_function(gf)); assert(jl_is_tuple(types)); @@ -1729,7 +1778,7 @@ void jl_add_method(jl_function_t *gf, jl_tuple_t *types, jl_function_t *meth, assert(jl_is_mtable(jl_gf_mtable(gf))); if (meth->linfo != NULL) meth->linfo->name = jl_gf_name(gf); - (void)jl_method_table_insert(jl_gf_mtable(gf), types, meth, tvars); + (void)jl_method_table_insert(jl_gf_mtable(gf), types, meth, tvars, isstaged); } DLLEXPORT jl_tuple_t *jl_match_method(jl_value_t *type, jl_value_t *sig, diff --git a/src/interpreter.c b/src/interpreter.c index f62446357e589..0f412890e27e9 100644 --- a/src/interpreter.c +++ b/src/interpreter.c @@ -285,7 +285,7 @@ static jl_value_t *eval(jl_value_t *e, jl_value_t **locals, size_t nl) jl_check_static_parameter_conflicts((jl_lambda_info_t*)args[2], (jl_tuple_t*)jl_t1(atypes), fname); } meth = eval(args[2], locals, nl); - jl_method_def(fname, bp, b, (jl_tuple_t*)atypes, (jl_function_t*)meth); + jl_method_def(fname, bp, b, (jl_tuple_t*)atypes, (jl_function_t*)meth, args[3]); JL_GC_POP(); return *bp; } diff --git a/src/jltypes.c b/src/jltypes.c index 70875ea48a72d..3f8b72f245263 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -3001,10 +3001,10 @@ void jl_init_types(void) jl_method_type = jl_new_datatype(jl_symbol("Method"), jl_any_type, jl_null, - jl_tuple(6, jl_symbol("sig"), jl_symbol("va"), + jl_tuple(7, jl_symbol("sig"), jl_symbol("va"), jl_symbol("isstaged"), jl_symbol("tvars"), jl_symbol("func"), jl_symbol("invokes"), jl_symbol("next")), - jl_tuple(6, jl_tuple_type, jl_bool_type, + jl_tuple(7, jl_tuple_type, jl_bool_type, jl_bool_type, jl_tuple_type, jl_any_type, jl_any_type, jl_any_type), 0, 1); @@ -3262,6 +3262,8 @@ void jl_init_types(void) copyast_sym = jl_symbol("copyast"); simdloop_sym = jl_symbol("simdloop"); meta_sym = jl_symbol("meta"); + arrow_sym = jl_symbol("->"); + ldots_sym = jl_symbol("..."); } #ifdef __cplusplus diff --git a/src/julia-parser.scm b/src/julia-parser.scm index 32999aacc3d91..5b9459aef8edf 100644 --- a/src/julia-parser.scm +++ b/src/julia-parser.scm @@ -90,7 +90,7 @@ (define operator? (Set operators)) (define reserved-words '(begin while if for try return break continue - function macro quote let local global const + stagedfunction function macro quote let local global const abstract typealias type bitstype immutable ccall do module baremodule using import export importall)) @@ -1074,7 +1074,7 @@ (if const `(const ,expr) expr))) - ((function macro) + ((stagedfunction function macro) (let* ((paren (eqv? (require-token s) #\()) (sig (parse-call s)) (def (if (or (symbol? sig) diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 44201134d8228..35e8c351fa34e 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -437,7 +437,7 @@ (pair? (caddr e)) (eq? (car (caddr e)) 'quote) (symbol? (cadr (caddr e)))))) -(define (method-def-expr- name sparams argl body) +(define (method-def-expr- name sparams argl body isstaged) (receive (names bounds) (sparam-name-bounds sparams '() '()) (begin @@ -455,12 +455,12 @@ (let* ((types (llist-types argl)) (body (method-lambda-expr argl body))) (if (null? sparams) - `(method ,name (tuple (tuple ,@types) (tuple)) ,body) + `(method ,name (tuple (tuple ,@types) (tuple)) ,body ,isstaged) `(method ,name (call (lambda ,names (tuple (tuple ,@types) (tuple ,@names))) ,@(symbols->typevars names bounds #t)) - ,body)))))) + ,body ,isstaged)))))) (define (vararg? x) (and (pair? x) (eq? (car x) '...))) (define (trans? x) (and (pair? x) (eq? (car x) '|.'|))) @@ -469,7 +469,7 @@ (define (const-default? x) (or (number? x) (string? x) (char? x) (and (pair? x) (eq? (car x) 'quote)))) -(define (keywords-method-def-expr name sparams argl body) +(define (keywords-method-def-expr name sparams argl body isstaged) (let* ((kargl (cdar argl)) ;; keyword expressions (= k v) (pargl (cdr argl)) ;; positional args (body (if (and (pair? body) (eq? (car body) 'block)) @@ -533,7 +533,7 @@ `(block ,@(if (null? lno) '() (list (append (car lno) (list (undot-name name))))) - ,@stmts)) + ,@stmts) isstaged) ;; call with no keyword args ,(method-def-expr- @@ -551,7 +551,7 @@ ,@(if (null? restkw) '() '((cell1d))) ,@(map arg-name pargl) ,@(if (null? vararg) '() - (list `(... ,(arg-name (car vararg))))))))) + (list `(... ,(arg-name (car vararg)))))))) isstaged) ;; call with unsorted keyword args. this sorts and re-dispatches. ,(method-def-expr- @@ -628,11 +628,12 @@ ,@(if (null? restkw) '() (list rkw)) ,@(map arg-name pargl) ,@(if (null? vararg) '() - (list `(... ,(arg-name (car vararg))))))))) + (list `(... ,(arg-name (car vararg)))))))) + isstaged) ;; return primary function ,name)))) -(define (optional-positional-defs name sparams req opt dfl body overall-argl . kw) +(define (optional-positional-defs name sparams req opt dfl body isstaged overall-argl . kw) (let ((lno (if (and (pair? body) (pair? (cdr body)) (pair? (cadr body)) (eq? (caadr body) 'line)) (list (cadr body)) @@ -668,11 +669,11 @@ `(block ,@lno (call ,name ,@kw ,@(map arg-name passed) ,@vals))))) - (method-def-expr name sp (append kw passed) body))) + (method-def-expr name sp (append kw passed) body isstaged))) (iota (length opt))) - ,(method-def-expr name sparams overall-argl body)))) + ,(method-def-expr name sparams overall-argl body isstaged)))) -(define (method-def-expr name sparams argl body) +(define (method-def-expr name sparams argl body isstaged) (if (any kwarg? argl) ;; has optional positional args (begin @@ -697,20 +698,20 @@ (check-kw-args (cdr kw)) (receive (vararg req) (separate vararg? argl) - (optional-positional-defs name sparams req opt dfl body + (optional-positional-defs name sparams req opt dfl body isstaged (cons kw (append req opt vararg)) `(parameters (... ,(gensy)))))) ;; optional positional only (receive (vararg req) (separate vararg? argl) - (optional-positional-defs name sparams req opt dfl body + (optional-positional-defs name sparams req opt dfl body isstaged (append req opt vararg))))))) (if (has-parameters? argl) ;; keywords only (begin (check-kw-args (cdar argl)) - (keywords-method-def-expr name sparams argl body)) + (keywords-method-def-expr name sparams argl body isstaged)) ;; neither - (method-def-expr- name sparams argl body)))) + (method-def-expr- name sparams argl body isstaged)))) (define (struct-def-expr name params super fields mut) (receive @@ -796,9 +797,15 @@ (pattern-lambda (function (call (curly name . p) . sig) body) `(function (call (curly ,(if (eq? name Tname) iname name) ,@p) ,@sig) ,(ctor-body body))) + (pattern-lambda (stagedfunction (call (curly name . p) . sig) body) + `(stagedfunction (call (curly ,(if (eq? name Tname) iname name) ,@p) ,@sig) + ,(ctor-body body))) (pattern-lambda (function (call name . sig) body) `(function (call ,(if (eq? name Tname) iname name) ,@sig) ,(ctor-body body))) + (pattern-lambda (stagedfunction (call name . sig) body) + `(stagedfunction (call ,(if (eq? name Tname) iname name) ,@sig) + ,(ctor-body body))) (pattern-lambda (= (call (curly name . p) . sig) body) `(= (call (curly ,(if (eq? name Tname) iname name) ,@p) ,@sig) ,(ctor-body body))) @@ -985,6 +992,7 @@ (cond ((or (atom? e) (quoted? e)) e) ((or (eq? (car e) 'lambda) (eq? (car e) 'function) + (eq? (car e) 'stagedfunction) (eq? (car e) '->)) e) ((eq? (car e) 'return) `(block ,@(if ret `((= ,ret true)) '()) @@ -998,7 +1006,7 @@ ((quoted? e) e) (else (case (car e) - ((function) + ((function stagedfunction) (let ((name (cadr e))) (if (pair? name) (if (eq? (car name) 'call) @@ -1008,11 +1016,11 @@ (method-def-expr (cadr (cadr name)) (cddr (cadr name)) (fix-arglist (cddr name)) - (caddr e)) + (caddr e) (eq? (car e) 'stagedfunction)) (method-def-expr (cadr name) '() (fix-arglist (cddr name)) - (caddr e)))) + (caddr e) (eq? (car e) 'stagedfunction)))) (if (eq? (car name) 'tuple) (expand-binding-forms `(-> ,name ,(caddr e))) @@ -2848,6 +2856,7 @@ So far only the second case can actually occur. (define (free-vars e) (table.keys (free-vars- e (table) *free-vars-secret-value*))) +(define (caddddr x) (car (cdr (cdr (cdr (cdr x)))))) ; convert each lambda's (locals ...) to ; ((localvars...) var-info-lst captured-var-infos) ; where var-info-lst is a list of var-info records @@ -2951,7 +2960,8 @@ So far only the second case can actually occur. (vinfo:set-iasg! vi #t))))) `(method ,(cadr e) ,(analyze-vars (caddr e) env captvars) - ,(analyze-vars (cadddr e) env captvars))) + ,(analyze-vars (cadddr e) env captvars) + ,(caddddr e))) (else (cons (car e) (map (lambda (x) (analyze-vars x env captvars)) (cdr e))))))) diff --git a/src/julia.h b/src/julia.h index 13a04464728a5..32c603c5f9623 100644 --- a/src/julia.h +++ b/src/julia.h @@ -261,6 +261,7 @@ typedef struct _jl_methlist_t { JL_DATA_TYPE jl_tuple_t *sig; int8_t va; + int8_t isstaged; jl_tuple_t *tvars; jl_function_t *func; // cache of specializations of this method for invoke(), i.e. @@ -422,7 +423,7 @@ extern jl_sym_t *compositetype_sym; extern jl_sym_t *type_goto_sym; extern jl_sym_t *global_sym; extern jl_sym_t *tuple_sym; extern jl_sym_t *boundscheck_sym; extern jl_sym_t *copyast_sym; extern jl_sym_t *simdloop_sym; extern jl_sym_t *meta_sym; - +extern jl_sym_t *arrow_sym; extern jl_sym_t *ldots_sym; // object accessors ----------------------------------------------------------- @@ -672,9 +673,9 @@ DLLEXPORT jl_sym_t *jl_get_root_symbol(void); jl_expr_t *jl_exprn(jl_sym_t *head, size_t n); jl_function_t *jl_new_generic_function(jl_sym_t *name); void jl_add_method(jl_function_t *gf, jl_tuple_t *types, jl_function_t *meth, - jl_tuple_t *tvars); + jl_tuple_t *tvars, int8_t isstaged); DLLEXPORT jl_value_t *jl_method_def(jl_sym_t *name, jl_value_t **bp, jl_binding_t *bnd, - jl_tuple_t *argtypes, jl_function_t *f); + jl_tuple_t *argtypes, jl_function_t *f, jl_value_t *isstaged); DLLEXPORT jl_value_t *jl_box_bool(int8_t x); DLLEXPORT jl_value_t *jl_box_int8(int32_t x); DLLEXPORT jl_value_t *jl_box_uint8(uint32_t x); diff --git a/src/toplevel.c b/src/toplevel.c index 4662a78837fba..600609d2b1a39 100644 --- a/src/toplevel.c +++ b/src/toplevel.c @@ -642,7 +642,7 @@ static int type_contains(jl_value_t *ty, jl_value_t *x) void print_func_loc(JL_STREAM *s, jl_lambda_info_t *li); DLLEXPORT jl_value_t *jl_method_def(jl_sym_t *name, jl_value_t **bp, jl_binding_t *bnd, - jl_tuple_t *argtypes, jl_function_t *f) + jl_tuple_t *argtypes, jl_function_t *f, jl_value_t *isstaged) { // argtypes is a tuple ((types...), (typevars...)) jl_tuple_t *t = (jl_tuple_t*)jl_t1(argtypes); @@ -702,7 +702,7 @@ DLLEXPORT jl_value_t *jl_method_def(jl_sym_t *name, jl_value_t **bp, jl_binding_ assert(jl_is_tuple(argtypes)); assert(jl_is_tuple(t)); - jl_add_method((jl_function_t*)gf, argtypes, f, t); + jl_add_method((jl_function_t*)gf, argtypes, f, t, isstaged == jl_true); if (jl_boot_file_loaded && f->linfo && f->linfo->ast && jl_is_expr(f->linfo->ast)) { jl_lambda_info_t *li = f->linfo; diff --git a/test/runtests.jl b/test/runtests.jl index 960232a961fed..dc4d747e6ed0c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,7 +9,7 @@ testnames = [ "floatapprox", "readdlm", "reflection", "regex", "float16", "combinatorics", "sysinfo", "rounding", "ranges", "mod2pi", "euler", "show", "lineedit", "replcompletions", "repl", "test", "examples", "goto", - "llvmcall", "grisu", "nullable", "meta" + "llvmcall", "grisu", "nullable", "meta", "staged" ] @unix_only push!(testnames, "unicode") diff --git a/test/staged.jl b/test/staged.jl new file mode 100644 index 0000000000000..85e8c4c049f7d --- /dev/null +++ b/test/staged.jl @@ -0,0 +1,78 @@ +stagedfunction staged_t1(a,b) + if a == Int + return :(a+b) + else + return :(a*b) + end +end + +@test staged_t1(1,2) == 3 +@test staged_t1(1.0,0.5) == 0.5 +@test staged_t1(1,0.5) == 1.5 + +tinline(a,b) = staged_t1(a,b) + +@test !isa(tinline(1,2),Expr) +@test tinline(1,0.5) == 1.5 + +stagedfunction splat(a,b...) + :( ($a,$b,a,b) ) +end + +@test splat(1,2,3) == (Int,(Int,Int),1,(2,3)) + +stagediobuf = IOBuffer() +stagedfunction splat2(a...) + print(stagediobuf, a) + :(nothing) +end + +const intstr = @sprintf("%s", Int) +splat2(1) +@test takebuf_string(stagediobuf) == "($intstr,)" +splat2(1,3) +@test takebuf_string(stagediobuf) == "($intstr,$intstr)" +splat2(5,2) +@test takebuf_string(stagediobuf) == "" +splat2(1:3,5.2) +@test takebuf_string(stagediobuf) == "(UnitRange{$intstr},Float64)" +splat2(3,5:2:7) +@test takebuf_string(stagediobuf) == "($intstr,StepRange{$intstr,$intstr})" +splat2(1,2,3,4) +@test takebuf_string(stagediobuf) == "($intstr,$intstr,$intstr,$intstr)" +splat2(1,2,3) +@test takebuf_string(stagediobuf) == "($intstr,$intstr,$intstr)" +splat2(1:5, 3, 3:3) +@test takebuf_string(stagediobuf) == "(UnitRange{$intstr},$intstr,UnitRange{$intstr})" +splat2(1:5, 3, 3:3) +@test takebuf_string(stagediobuf) == "" +splat2(1:5, 3:3, 3) +@test takebuf_string(stagediobuf) == "(UnitRange{$intstr},UnitRange{$intstr},$intstr)" +splat2(1:5, 3:3) +@test takebuf_string(stagediobuf) == "(UnitRange{$intstr},UnitRange{$intstr})" +splat2(3, 3:5) +@test takebuf_string(stagediobuf) == "($intstr,UnitRange{$intstr})" + + +A = rand(5,5,3); +B = slice(A, 1:3, 2, 1:3); +stagedfunction mygetindex(S::SubArray, indexes::Real...) + T, N, A, I = S.parameters + if N != length(indexes) + error("Wrong number of indexes supplied") + end + NP = length(I) + indexexprs = Array(Expr, NP) + j = 1 + for i = 1:NP + if I[i] == Int + indexexprs[i] = :(S.indexes[$i]) + else + indexexprs[i] = :(S.indexes[$i][indexes[$j]]) + j += 1 + end + end + ex = :(S.parent[$(indexexprs...)]) + ex +end +@test mygetindex(B,2,2) == A[2,2,2]