Skip to content

Commit 31c2469

Browse files
vtjnashquinnj
authored andcommitted
simple pass at giving union fields an optimized layout
unlike codegen, only bitstypes (!isptr) fields are permitted in the union and the offset count starts from 0 instead of 1 but otherwise the tindex counter is compatible
1 parent 5851a26 commit 31c2469

File tree

7 files changed

+203
-59
lines changed

7 files changed

+203
-59
lines changed

src/cgutils.cpp

Lines changed: 72 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,6 +1131,33 @@ static Value *emit_bounds_check(jl_codectx_t &ctx, const jl_cgval_t &ainfo, jl_v
11311131

11321132
// --- loading and storing ---
11331133

1134+
static Value *compute_box_tindex(Value *datatype, jl_value_t *supertype, jl_value_t *ut, jl_codectx_t *ctx)
1135+
{
1136+
Value *tindex = ConstantInt::get(T_int8, 0);
1137+
unsigned counter = 0;
1138+
for_each_uniontype_small(
1139+
[&](unsigned idx, jl_datatype_t *jt) {
1140+
if (jl_subtype((jl_value_t*)jt, supertype)) {
1141+
Value *cmp = builder.CreateICmpEQ(literal_pointer_val((jl_value_t*)jt), datatype);
1142+
tindex = builder.CreateSelect(cmp, ConstantInt::get(T_int8, idx), tindex);
1143+
}
1144+
},
1145+
ut,
1146+
counter);
1147+
return tindex;
1148+
}
1149+
1150+
// get the runtime tindex value
1151+
static Value *compute_tindex_unboxed(const jl_cgval_t &val, jl_value_t *typ, jl_codectx_t *ctx)
1152+
{
1153+
if (val.constant)
1154+
return ConstantInt::get(T_int8, get_box_tindex((jl_datatype_t*)jl_typeof(val.constant), typ));
1155+
if (val.isboxed)
1156+
return compute_box_tindex(emit_typeof_boxed(val, ctx), val.typ, typ, ctx);
1157+
assert(val.TIndex);
1158+
return builder.CreateAnd(val.TIndex, ConstantInt::get(T_int8, 0x7f));
1159+
}
1160+
11341161
// If given alignment is 0 and LLVM's assumed alignment for a load/store via ptr
11351162
// might be stricter than the Julia alignment for jltype, return the alignment of jltype.
11361163
// Otherwise return the given alignment.
@@ -1436,6 +1463,9 @@ static jl_cgval_t emit_getfield_knownidx(jl_codectx_t &ctx, const jl_cgval_t &st
14361463
addr = ctx.builder.CreateStructGEP(lt, ptr, idx);
14371464
}
14381465
}
1466+
int align = jl_field_offset(jt, idx);
1467+
align |= 16;
1468+
align &= -align;
14391469
if (jl_field_isptr(jt, idx)) {
14401470
bool maybe_null = idx >= (unsigned)jt->ninitialized;
14411471
Instruction *Load = maybe_mark_load_dereferenceable(
@@ -1447,6 +1477,29 @@ static jl_cgval_t emit_getfield_knownidx(jl_codectx_t &ctx, const jl_cgval_t &st
14471477
null_pointer_check(ctx, fldv);
14481478
return mark_julia_type(ctx, fldv, true, jfty, strct.gcroot || !strct.isimmutable);
14491479
}
1480+
else if (jl_is_uniontype(jfty)) {
1481+
int fsz = jl_field_size(jt, idx);
1482+
Value *ptindex = builder.CreateGEP(LLVM37_param(T_int8) emit_bitcast(addr, T_pint8), ConstantInt::get(T_size, fsz - 1));
1483+
Value *tindex = builder.CreateNUWAdd(ConstantInt::get(T_int8, 1), builder.CreateLoad(ptindex));
1484+
bool isimmutable = strct.isimmutable;
1485+
Value *gcroot = strct.gcroot;
1486+
if (jt->mutabl) {
1487+
// move value to an immutable stack slot
1488+
Type *AT = ArrayType::get(IntegerType::get(jl_LLVMContext, 8 * align), (fsz + align - 2) / align);
1489+
AllocaInst *lv = emit_static_alloca(AT, ctx);
1490+
if (align > 1)
1491+
lv->setAlignment(align);
1492+
Value *nbytes = ConstantInt::get(T_size, fsz - 1);
1493+
builder.CreateMemCpy(lv, addr, nbytes, align);
1494+
addr = lv;
1495+
isimmutable = true;
1496+
gcroot = NULL;
1497+
}
1498+
jl_cgval_t fieldval = mark_julia_slot(addr, jfty, tindex, strct.tbaa);
1499+
fieldval.isimmutable = isimmutable;
1500+
fieldval.gcroot = gcroot;
1501+
return fieldval;
1502+
}
14501503
else if (!jt->mutabl) {
14511504
// just compute the pointer and let user load it when necessary
14521505
jl_cgval_t fieldval = mark_julia_slot(addr, jfty, NULL, strct.tbaa);
@@ -2065,7 +2118,7 @@ static void emit_unionmove(jl_codectx_t &ctx, Value *dest, const jl_cgval_t &src
20652118
jl_value_t *typ = src.constant ? jl_typeof(src.constant) : src.typ;
20662119
Type *store_ty = julia_type_to_llvm(typ);
20672120
assert(skip || jl_isbits(typ));
2068-
if (jl_isbits(typ)) {
2121+
if (jl_isbits(typ) && jl_datatype_size(typ) > 0) {
20692122
if (!src.ispointer() || src.constant) {
20702123
emit_unbox(ctx, store_ty, src, typ, dest, isVolatile);
20712124
}
@@ -2236,11 +2289,24 @@ static void emit_setfield(jl_codectx_t &ctx,
22362289
emit_checked_write_barrier(ctx, boxed(ctx, strct), r);
22372290
}
22382291
else {
2239-
int align = jl_field_offset(sty, idx0);
2240-
align |= 16;
2241-
align &= -align;
2242-
typed_store(ctx, addr, ConstantInt::get(T_size, 0), rhs, jfty,
2243-
strct.tbaa, data_pointer(ctx, strct, T_pjlvalue), align);
2292+
if (jl_is_uniontype(jfty)) {
2293+
int fsz = jl_field_size(sty, idx0);
2294+
// compute tindex from rhs
2295+
jl_cgval_t rhs_union = convert_julia_type(rhs, jfty, ctx);
2296+
Value *ptindex = builder.CreateGEP(LLVM37_param(T_int8) emit_bitcast(addr, T_pint8), ConstantInt::get(T_size, fsz - 1));
2297+
Value *tindex = compute_tindex_unboxed(rhs_union, jfty, ctx);
2298+
tindex = builder.CreateNUWSub(tindex, ConstantInt::get(T_int8, 1));
2299+
builder.CreateStore(tindex, ptindex);
2300+
// copy data
2301+
emit_unionmove(addr, rhs, NULL, false, NULL, ctx);
2302+
}
2303+
else {
2304+
int align = jl_field_offset(sty, idx0);
2305+
align |= 16;
2306+
align &= -align;
2307+
typed_store(addr, ConstantInt::get(T_size, 0), rhs, jfty, ctx,
2308+
strct.tbaa, data_pointer(ctx, strct, T_pjlvalue), align);
2309+
}
22442310
}
22452311
}
22462312
else {

src/codegen.cpp

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,8 @@ static inline jl_cgval_t update_julia_type(jl_codectx_t &ctx, const jl_cgval_t &
743743
return jl_cgval_t(v, typ, NULL);
744744
}
745745

746+
static jl_cgval_t convert_julia_type(const jl_cgval_t &v, jl_value_t *typ, jl_codectx_t *ctx, bool needsroot = true);
747+
746748
// --- allocating local variables ---
747749

748750
static jl_sym_t *slot_symbol(jl_codectx_t &ctx, int s)
@@ -823,7 +825,7 @@ static void jl_rethrow_with_add(const char *fmt, ...)
823825
}
824826

825827
// given a value marked with type `v.typ`, compute the mapping and/or boxing to return a value of type `typ`
826-
static jl_cgval_t convert_julia_type(jl_codectx_t &ctx, const jl_cgval_t &v, jl_value_t *typ, bool needsroot = true)
828+
static jl_cgval_t convert_julia_type(jl_codectx_t &ctx, const jl_cgval_t &v, jl_value_t *typ, bool needsroot)
827829
{
828830
if (typ == (jl_value_t*)jl_typeofbottom_type)
829831
return ghostValue(typ); // normalize TypeofBottom to Type{Union{}}
@@ -3544,33 +3546,6 @@ static Value *try_emit_union_alloca(jl_codectx_t &ctx, jl_uniontype_t *ut, bool
35443546
return NULL;
35453547
}
35463548

3547-
static Value *compute_box_tindex(jl_codectx_t &ctx, Value *datatype, jl_value_t *supertype, jl_value_t *ut)
3548-
{
3549-
Value *tindex = ConstantInt::get(T_int8, 0);
3550-
unsigned counter = 0;
3551-
for_each_uniontype_small(
3552-
[&](unsigned idx, jl_datatype_t *jt) {
3553-
if (jl_subtype((jl_value_t*)jt, supertype)) {
3554-
Value *cmp = ctx.builder.CreateICmpEQ(maybe_decay_untracked(literal_pointer_val(ctx, (jl_value_t*)jt)), datatype);
3555-
tindex = ctx.builder.CreateSelect(cmp, ConstantInt::get(T_int8, idx), tindex);
3556-
}
3557-
},
3558-
ut,
3559-
counter);
3560-
return tindex;
3561-
}
3562-
3563-
// get the runtime tindex value
3564-
static Value *compute_tindex_unboxed(jl_codectx_t &ctx, const jl_cgval_t &val, jl_value_t *typ)
3565-
{
3566-
if (val.constant)
3567-
return ConstantInt::get(T_int8, get_box_tindex((jl_datatype_t*)jl_typeof(val.constant), typ));
3568-
if (val.isboxed)
3569-
return compute_box_tindex(ctx, emit_typeof_boxed(ctx, val), val.typ, typ);
3570-
assert(val.TIndex);
3571-
return ctx.builder.CreateAnd(val.TIndex, ConstantInt::get(T_int8, 0x7f));
3572-
}
3573-
35743549
static void emit_assignment(jl_codectx_t &ctx, jl_value_t *l, jl_value_t *r)
35753550
{
35763551
if (jl_is_ssavalue(l)) {

src/datatype.c

Lines changed: 81 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,38 @@ unsigned jl_special_vector_alignment(size_t nfields, jl_value_t *t)
216216
return alignment;
217217
}
218218

219+
static int jl_layout_isbits(jl_value_t *ty)
220+
{
221+
if (jl_isbits(ty) && jl_is_leaf_type(ty)) {
222+
if (((jl_datatype_t*)ty)->layout) // layout check handles possible layout recursion
223+
return 1;
224+
}
225+
return 0;
226+
}
227+
228+
static unsigned jl_union_isbits(jl_value_t *ty, size_t *nbytes, size_t *align)
229+
{
230+
if (jl_is_uniontype(ty)) {
231+
unsigned na = jl_union_isbits(((jl_uniontype_t*)ty)->a, nbytes, align);
232+
if (na == 0)
233+
return 0;
234+
unsigned nb = jl_union_isbits(((jl_uniontype_t*)ty)->b, nbytes, align);
235+
if (nb == 0)
236+
return 0;
237+
return na + nb;
238+
}
239+
if (jl_layout_isbits(ty)) {
240+
size_t sz = jl_datatype_size(ty);
241+
size_t al = ((jl_datatype_t*)ty)->layout->alignment;
242+
if (*nbytes < sz)
243+
*nbytes = sz;
244+
if (*align < al)
245+
*align = al;
246+
return 1;
247+
}
248+
return 0;
249+
}
250+
219251
void jl_compute_field_offsets(jl_datatype_t *st)
220252
{
221253
size_t sz = 0, alignm = 1;
@@ -277,16 +309,22 @@ void jl_compute_field_offsets(jl_datatype_t *st)
277309

278310
for (size_t i = 0; i < nfields; i++) {
279311
jl_value_t *ty = jl_field_type(st, i);
280-
size_t fsz, al;
281-
if (jl_isbits(ty) && jl_is_leaf_type(ty) && ((jl_datatype_t*)ty)->layout) {
282-
fsz = jl_datatype_size(ty);
312+
size_t fsz = 0, al = 0;
313+
unsigned countbits = jl_union_isbits(ty, &fsz, &al);
314+
if (countbits > 0 && countbits < 127) {
283315
// Should never happen
284316
if (__unlikely(fsz > max_size))
285317
goto throw_ovf;
286318
al = jl_datatype_align(ty);
287319
desc[i].isptr = 0;
288-
if (((jl_datatype_t*)ty)->layout->haspadding)
320+
if (jl_is_uniontype(ty)) {
289321
haspadding = 1;
322+
fsz += 1; // selector byte
323+
}
324+
else { // isbits struct
325+
if (((jl_datatype_t*)ty)->layout->haspadding)
326+
haspadding = 1;
327+
}
290328
}
291329
else {
292330
fsz = sizeof(void*);
@@ -312,7 +350,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
312350
goto throw_ovf;
313351
sz += fsz;
314352
}
315-
if (homogeneous && lastty!=NULL && jl_is_tuple_type(st)) {
353+
if (homogeneous && lastty != NULL && jl_is_tuple_type(st)) {
316354
// Some tuples become LLVM vectors with stronger alignment than what was calculated above.
317355
unsigned al = jl_special_vector_alignment(nfields, lastty);
318356
assert(al % alignm == 0);
@@ -326,10 +364,12 @@ void jl_compute_field_offsets(jl_datatype_t *st)
326364
if (st->size > sz)
327365
haspadding = 1;
328366
st->layout = jl_get_layout(nfields, alignm, haspadding, desc);
329-
if (descsz >= jl_page_size) free(desc);
367+
if (descsz >= jl_page_size)
368+
free(desc);
330369
return;
331370
throw_ovf:
332-
if (descsz >= jl_page_size) free(desc);
371+
if (descsz >= jl_page_size)
372+
free(desc);
333373
jl_throw(jl_overflow_exception);
334374
}
335375

@@ -704,46 +744,69 @@ JL_DLLEXPORT jl_value_t *jl_get_nth_field(jl_value_t *v, size_t i)
704744
{
705745
jl_datatype_t *st = (jl_datatype_t*)jl_typeof(v);
706746
assert(i < jl_datatype_nfields(st));
707-
size_t offs = jl_field_offset(st,i);
708-
if (jl_field_isptr(st,i)) {
747+
size_t offs = jl_field_offset(st, i);
748+
if (jl_field_isptr(st, i)) {
709749
return *(jl_value_t**)((char*)v + offs);
710750
}
711-
return jl_new_bits(jl_field_type(st,i), (char*)v + offs);
751+
jl_value_t *ty = jl_field_type(st, i);
752+
if (jl_is_uniontype(ty)) {
753+
uint8_t sel = ((uint8_t*)v)[offs + jl_field_size(st, i) - 1];
754+
ty = jl_nth_union_component(ty, sel);
755+
}
756+
return jl_new_bits(ty, (char*)v + offs);
712757
}
713758

714759
JL_DLLEXPORT jl_value_t *jl_get_nth_field_checked(jl_value_t *v, size_t i)
715760
{
716761
jl_datatype_t *st = (jl_datatype_t*)jl_typeof(v);
717762
if (i >= jl_datatype_nfields(st))
718-
jl_bounds_error_int(v, i+1);
719-
size_t offs = jl_field_offset(st,i);
720-
if (jl_field_isptr(st,i)) {
763+
jl_bounds_error_int(v, i + 1);
764+
size_t offs = jl_field_offset(st, i);
765+
if (jl_field_isptr(st, i)) {
721766
jl_value_t *fval = *(jl_value_t**)((char*)v + offs);
722767
if (fval == NULL)
723768
jl_throw(jl_undefref_exception);
724769
return fval;
725770
}
726-
return jl_new_bits(jl_field_type(st,i), (char*)v + offs);
771+
jl_value_t *ty = jl_field_type(st, i);
772+
if (jl_is_uniontype(ty)) {
773+
size_t fsz = jl_field_size(st, i);
774+
uint8_t sel = ((uint8_t*)v)[offs + fsz - 1];
775+
ty = jl_nth_union_component(ty, sel);
776+
if (jl_is_datatype_singleton((jl_datatype_t*)ty))
777+
return ((jl_datatype_t*)ty)->instance;
778+
}
779+
return jl_new_bits(ty, (char*)v + offs);
727780
}
728781

729782
JL_DLLEXPORT void jl_set_nth_field(jl_value_t *v, size_t i, jl_value_t *rhs)
730783
{
731784
jl_datatype_t *st = (jl_datatype_t*)jl_typeof(v);
732-
size_t offs = jl_field_offset(st,i);
733-
if (jl_field_isptr(st,i)) {
785+
size_t offs = jl_field_offset(st, i);
786+
if (jl_field_isptr(st, i)) {
734787
*(jl_value_t**)((char*)v + offs) = rhs;
735788
if (rhs != NULL) jl_gc_wb(v, rhs);
736789
}
737790
else {
791+
jl_value_t *ty = jl_field_type(st, i);
792+
if (jl_is_uniontype(ty)) {
793+
uint8_t *psel = &((uint8_t*)v)[offs + jl_field_size(st, i) - 1];
794+
unsigned nth = 0;
795+
if (!jl_find_union_component(ty, jl_typeof(rhs), &nth))
796+
assert(0 && "invalid field assignment to isbits union");
797+
*psel = nth;
798+
if (jl_is_datatype_singleton((jl_datatype_t*)ty))
799+
return;
800+
}
738801
jl_assign_bits((char*)v + offs, rhs);
739802
}
740803
}
741804

742805
JL_DLLEXPORT int jl_field_isdefined(jl_value_t *v, size_t i)
743806
{
744807
jl_datatype_t *st = (jl_datatype_t*)jl_typeof(v);
745-
size_t offs = jl_field_offset(st,i);
746-
if (jl_field_isptr(st,i)) {
808+
size_t offs = jl_field_offset(st, i);
809+
if (jl_field_isptr(st, i)) {
747810
return *(jl_value_t**)((char*)v + offs) != NULL;
748811
}
749812
return 1;

0 commit comments

Comments
 (0)