Skip to content

Commit ace35c5

Browse files
committed
compiler: zero struct padding during map operations
Fixes #3358
1 parent 1f0bf9b commit ace35c5

File tree

1 file changed

+75
-1
lines changed

1 file changed

+75
-1
lines changed

compiler/map.go

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ func (b *builder) createMapLookup(keyType, valueType types.Type, m, key llvm.Val
8989
// growth.
9090
mapKeyAlloca, mapKeyPtr, mapKeySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
9191
b.CreateStore(key, mapKeyAlloca)
92+
b.zeroUndefBytes(keyType, mapKeyAlloca)
9293
// Fetch the value from the hashmap.
9394
params := []llvm.Value{m, mapKeyPtr, mapValuePtr, mapValueSize}
9495
commaOkValue = b.createRuntimeCall("hashmapBinaryGet", params, "")
@@ -133,6 +134,7 @@ func (b *builder) createMapUpdate(keyType types.Type, m, key, value llvm.Value,
133134
// key can be compared with runtime.memequal
134135
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
135136
b.CreateStore(key, keyAlloca)
137+
b.zeroUndefBytes(keyType, keyAlloca)
136138
params := []llvm.Value{m, keyPtr, valuePtr}
137139
b.createRuntimeCall("hashmapBinarySet", params, "")
138140
b.emitLifetimeEnd(keyPtr, keySize)
@@ -161,6 +163,7 @@ func (b *builder) createMapDelete(keyType types.Type, m, key llvm.Value, pos tok
161163
} else if hashmapIsBinaryKey(keyType) {
162164
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
163165
b.CreateStore(key, keyAlloca)
166+
b.zeroUndefBytes(keyType, keyAlloca)
164167
params := []llvm.Value{m, keyPtr}
165168
b.createRuntimeCall("hashmapBinaryDelete", params, "")
166169
b.emitLifetimeEnd(keyPtr, keySize)
@@ -240,7 +243,8 @@ func (b *builder) createMapIteratorNext(rangeVal ssa.Value, llvmRangeVal, it llv
240243
}
241244

242245
// Returns true if this key type does not contain strings, interfaces etc., so
243-
// can be compared with runtime.memequal.
246+
// can be compared with runtime.memequal. Note that padding bytes are undef
247+
// and can alter two "equal" structs being equal when compared with memequal.
244248
func hashmapIsBinaryKey(keyType types.Type) bool {
245249
switch keyType := keyType.(type) {
246250
case *types.Basic:
@@ -263,3 +267,73 @@ func hashmapIsBinaryKey(keyType types.Type) bool {
263267
return false
264268
}
265269
}
270+
271+
func (b *builder) zeroUndefBytes(typ types.Type, ptr llvm.Value) error {
272+
// We know that hashmapIsBinaryKey is true, so we only have to handle those types that can show up there.
273+
// To zero all undefined bytes, we iterate over all the fields in the type. For each element, compute the
274+
// offset of that element. If it's Basic type, there are no internal padding bytes. For compound types, we recurse to ensure
275+
// we handle nested types. Next, we determine if there are any padding bytes before the next
276+
// element and zero those as well.
277+
278+
switch typ := typ.Underlying().(type) {
279+
case *types.Basic:
280+
// no padding bytes
281+
return nil
282+
case *types.Pointer:
283+
// mo padding bytes
284+
return nil
285+
case *types.Named:
286+
// zero underlying type
287+
return b.zeroUndefBytes(typ.Underlying(), ptr)
288+
case *types.Array:
289+
llvmArrayType := b.getLLVMType(typ)
290+
llvmElemType := b.getLLVMType(typ.Elem())
291+
base := ptr
292+
293+
for i := int64(0); i < typ.Len(); i++ {
294+
// for each element, first clear any undef bytes in the element itself
295+
idx := llvm.ConstInt(b.uintptrType, uint64(i), false)
296+
ptr := b.CreateGEP(llvmArrayType, base, []llvm.Value{idx}, "")
297+
298+
// recursively zero any padding bytes in this element
299+
b.zeroUndefBytes(typ.Elem(), ptr)
300+
301+
// check for padding between elements
302+
// TODO(dgryski): typeSizeEqualStoreSize ?
303+
if allocSize, storeSize := b.targetData.TypeAllocSize(llvmElemType), b.targetData.TypeStoreSize(llvmElemType); allocSize != storeSize {
304+
n := llvm.ConstInt(b.uintptrType, allocSize-storeSize, false)
305+
llvmSize := llvm.ConstInt(b.uintptrType, storeSize, false)
306+
paddingStart := b.CreateGEP(b.ctx.Int8Type(), ptr, []llvm.Value{llvmSize}, "")
307+
b.createRuntimeCall("memzero", []llvm.Value{paddingStart, n}, "")
308+
}
309+
}
310+
311+
case *types.Struct:
312+
llvmStructType := b.getLLVMType(typ)
313+
base := ptr
314+
315+
for i := 0; i < typ.NumFields(); i++ {
316+
offset := b.targetData.ElementOffset(llvmStructType, int(i))
317+
llvmOffset := llvm.ConstInt(b.uintptrType, offset, false)
318+
ptr := b.CreateGEP(b.ctx.Int8Type(), base, []llvm.Value{llvmOffset}, "")
319+
320+
// zero any undef bytes in this field
321+
fieldType := typ.Field(i).Type()
322+
b.zeroUndefBytes(fieldType, ptr)
323+
324+
// zero any undef bytes before the next field, if any
325+
if i < typ.NumFields()-1 {
326+
nextOffset := b.targetData.ElementOffset(llvmStructType, i+1)
327+
llvmElemType := b.getLLVMType(fieldType)
328+
if storeSize := b.targetData.TypeStoreSize(llvmElemType); (nextOffset - offset) != storeSize {
329+
n := llvm.ConstInt(b.uintptrType, (nextOffset-offset)-storeSize, false)
330+
llvmSize := llvm.ConstInt(b.uintptrType, storeSize, false)
331+
paddingStart := b.CreateGEP(b.ctx.Int8Type(), ptr, []llvm.Value{llvmSize}, "")
332+
b.createRuntimeCall("memzero", []llvm.Value{paddingStart, n}, "")
333+
}
334+
}
335+
}
336+
}
337+
338+
return nil
339+
}

0 commit comments

Comments
 (0)