
Refactor some std code that works with pointer offsets #100823


Merged (4 commits, Sep 24, 2022)
Changes from 3 commits
124 changes: 62 additions & 62 deletions library/core/src/fmt/num.rs
@@ -211,7 +211,7 @@ macro_rules! impl_Display {
fn $name(mut n: $u, is_nonnegative: bool, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// 2^128 is about 3*10^38, so 39 gives an extra byte of space
let mut buf = [MaybeUninit::<u8>::uninit(); 39];
let mut curr = buf.len() as isize;
let mut curr = buf.len();
let buf_ptr = MaybeUninit::slice_as_mut_ptr(&mut buf);
let lut_ptr = DEC_DIGITS_LUT.as_ptr();

@@ -228,7 +228,7 @@ macro_rules! impl_Display {

// eagerly decode 4 characters at a time
while n >= 10000 {
let rem = (n % 10000) as isize;
let rem = (n % 10000) as usize;
n /= 10000;

let d1 = (rem / 100) << 1;
@@ -238,37 +238,37 @@ macro_rules! impl_Display {
// We are allowed to copy to `buf_ptr[curr..curr + 3]` here since
// otherwise `curr < 0`. But then `n` was originally at least `10000^10`
// which is `10^40 > 2^128 > n`.
ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2);
ptr::copy_nonoverlapping(lut_ptr.offset(d2), buf_ptr.offset(curr + 2), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d1), buf_ptr.add(curr), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d2), buf_ptr.add(curr + 2), 2);
}

// if we reach here numbers are <= 9999, so at most 4 chars long
let mut n = n as isize; // possibly reduce 64bit math
let mut n = n as usize; // possibly reduce 64bit math

// decode 2 more chars, if > 2 chars
if n >= 100 {
let d1 = (n % 100) << 1;
n /= 100;
curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d1), buf_ptr.add(curr), 2);
}

// decode last 1 or 2 chars
if n < 10 {
curr -= 1;
*buf_ptr.offset(curr) = (n as u8) + b'0';
*buf_ptr.add(curr) = (n as u8) + b'0';
} else {
let d1 = n << 1;
curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d1), buf_ptr.add(curr), 2);
}
}

// SAFETY: `curr` > 0 (since we made `buf` large enough), and all the chars are valid
// UTF-8 since `DEC_DIGITS_LUT` is
let buf_slice = unsafe {
str::from_utf8_unchecked(
slice::from_raw_parts(buf_ptr.offset(curr), buf.len() - curr as usize))
slice::from_raw_parts(buf_ptr.add(curr), buf.len() - curr))
};
f.pad_integral(is_nonnegative, "", buf_slice)
}
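For readers skimming the diff: the core of this PR is replacing `ptr.offset(i as isize)` with `ptr.add(i)` once the index is tracked as a `usize`. A minimal standalone sketch of that equivalence (my illustration, not code from the PR; `buf` and `curr` are stand-ins for the formatting buffer and its index):

```rust
fn main() {
    let buf = [1u8, 2, 3, 4, 5];
    let p = buf.as_ptr();
    let curr: usize = 2;
    // SAFETY: `curr` is within the bounds of `buf`, so both offsets are in bounds.
    unsafe {
        // `add(count)` is documented as equivalent to `offset(count as isize)`,
        // so with an in-bounds usize index the two spellings read the same byte.
        assert_eq!(*p.offset(curr as isize), *p.add(curr));
    }
}
```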
@@ -339,18 +339,18 @@ macro_rules! impl_Exp {
// Since `curr` always decreases by the number of digits copied, this means
// that `curr >= 0`.
let mut buf = [MaybeUninit::<u8>::uninit(); 40];
let mut curr = buf.len() as isize; //index for buf
let mut curr = buf.len(); //index for buf
let buf_ptr = MaybeUninit::slice_as_mut_ptr(&mut buf);
let lut_ptr = DEC_DIGITS_LUT.as_ptr();

// decode 2 chars at a time
while n >= 100 {
let d1 = ((n % 100) as isize) << 1;
let d1 = ((n % 100) << 1) as usize;
Member commented:

For safety, can you switch this to be more similar to the old one?

Suggested change:
-                let d1 = ((n % 100) << 1) as usize;
+                let d1 = ((n % 100) as usize) << 1;

The macro defining this function is instantiated with i8:

impl_Exp!(
    i8, u8, i16, u16, i32, u32, i64, u64, usize, isize
        as u64 via to_u64 named exp_u64
);

meaning that n could be i8, so if n was 99_i8 then the shift would overflow if it's done before the cast.

(Can that ever happen? No idea. Up to you if you want to try to write a test to trigger it somehow. But I think it's worth changing the code to obviously not be a problem just in case.)

Member commented:

Oh, wait, no, we only get here if n >= 100, but (127_i8 % 100) << 1 is at most 54, so I guess it'd be fine.

That's still a bit too subtle for my taste, though, so I'd rather just shift in the bigger type.

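For illustration (this snippet is mine, not part of the review thread; 99_i8 is the assumed worst case): shifting in the wider type sidesteps the question of whether the doubled value still fits in the narrow signed type.

```rust
fn main() {
    let n: i8 = 99; // assumed worst case for an i8 input
    // Cast first, then shift as usize: 99 * 2 == 198 fits trivially.
    let cast_then_shift = ((n % 100) as usize) << 1;
    assert_eq!(cast_then_shift, 198);
    // Shift first in i8: 198 does not fit, so the bit pattern wraps into the sign bit...
    assert_eq!((n % 100) << 1, -58i8);
    // ...and the later cast sign-extends that negative value into a huge index.
    assert_eq!(((n % 100) << 1) as usize, usize::MAX - 57);
}
```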
curr -= 2;
// SAFETY: `d1 <= 198`, so we can copy from `lut_ptr[d1..d1 + 2]` since
// `DEC_DIGITS_LUT` has a length of 200.
unsafe {
ptr::copy_nonoverlapping(lut_ptr.offset(d1), buf_ptr.offset(curr), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d1), buf_ptr.add(curr), 2);
}
n /= 100;
exponent += 2;
@@ -362,7 +362,7 @@ macro_rules! impl_Exp {
curr -= 1;
// SAFETY: Safe since `40 > curr >= 0` (see comment)
unsafe {
*buf_ptr.offset(curr) = (n as u8 % 10_u8) + b'0';
*buf_ptr.add(curr) = (n as u8 % 10_u8) + b'0';
}
n /= 10;
exponent += 1;
@@ -372,18 +372,18 @@ macro_rules! impl_Exp {
curr -= 1;
// SAFETY: Safe since `40 > curr >= 0`
unsafe {
*buf_ptr.offset(curr) = b'.';
*buf_ptr.add(curr) = b'.';
}
}

// SAFETY: Safe since `40 > curr >= 0`
let buf_slice = unsafe {
// decode last character
curr -= 1;
*buf_ptr.offset(curr) = (n as u8) + b'0';
*buf_ptr.add(curr) = (n as u8) + b'0';

let len = buf.len() - curr as usize;
slice::from_raw_parts(buf_ptr.offset(curr), len)
slice::from_raw_parts(buf_ptr.add(curr), len)
};

// stores 'e' (or 'E') and the up to 2-digit exponent
@@ -392,13 +392,13 @@ macro_rules! impl_Exp {
// SAFETY: In either case, `exp_buf` is written within bounds and `exp_ptr[..len]`
// is contained within `exp_buf` since `len <= 3`.
let exp_slice = unsafe {
*exp_ptr.offset(0) = if upper { b'E' } else { b'e' };
*exp_ptr.add(0) = if upper { b'E' } else { b'e' };
let len = if exponent < 10 {
*exp_ptr.offset(1) = (exponent as u8) + b'0';
*exp_ptr.add(1) = (exponent as u8) + b'0';
2
} else {
let off = exponent << 1;
ptr::copy_nonoverlapping(lut_ptr.offset(off), exp_ptr.offset(1), 2);
ptr::copy_nonoverlapping(lut_ptr.add(off), exp_ptr.add(1), 2);
3
};
slice::from_raw_parts(exp_ptr, len)
@@ -479,7 +479,7 @@ mod imp {
impl_Exp!(i128, u128 as u128 via to_u128 named exp_u128);

/// Helper function for writing a u64 into `buf` going from last to first, with `curr`.
fn parse_u64_into<const N: usize>(mut n: u64, buf: &mut [MaybeUninit<u8>; N], curr: &mut isize) {
fn parse_u64_into<const N: usize>(mut n: u64, buf: &mut [MaybeUninit<u8>; N], curr: &mut usize) {
let buf_ptr = MaybeUninit::slice_as_mut_ptr(buf);
let lut_ptr = DEC_DIGITS_LUT.as_ptr();
assert!(*curr > 19);
@@ -494,73 +494,73 @@ fn parse_u64_into<const N: usize>(mut n: u64, buf: &mut [MaybeUninit<u8>; N], cu
n /= 1e16 as u64;

// Some of these are nops but it looks more elegant this way.
let d1 = ((to_parse / 1e14 as u64) % 100) << 1;
let d2 = ((to_parse / 1e12 as u64) % 100) << 1;
let d3 = ((to_parse / 1e10 as u64) % 100) << 1;
let d4 = ((to_parse / 1e8 as u64) % 100) << 1;
let d5 = ((to_parse / 1e6 as u64) % 100) << 1;
let d6 = ((to_parse / 1e4 as u64) % 100) << 1;
let d7 = ((to_parse / 1e2 as u64) % 100) << 1;
let d8 = ((to_parse / 1e0 as u64) % 100) << 1;
let d1 = (((to_parse / 1e14 as u64) % 100) << 1) as usize;
let d2 = (((to_parse / 1e12 as u64) % 100) << 1) as usize;
let d3 = (((to_parse / 1e10 as u64) % 100) << 1) as usize;
let d4 = (((to_parse / 1e8 as u64) % 100) << 1) as usize;
let d5 = (((to_parse / 1e6 as u64) % 100) << 1) as usize;
let d6 = (((to_parse / 1e4 as u64) % 100) << 1) as usize;
let d7 = (((to_parse / 1e2 as u64) % 100) << 1) as usize;
let d8 = (((to_parse / 1e0 as u64) % 100) << 1) as usize;

*curr -= 16;

ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(*curr + 0), 2);
ptr::copy_nonoverlapping(lut_ptr.offset(d2 as isize), buf_ptr.offset(*curr + 2), 2);
ptr::copy_nonoverlapping(lut_ptr.offset(d3 as isize), buf_ptr.offset(*curr + 4), 2);
ptr::copy_nonoverlapping(lut_ptr.offset(d4 as isize), buf_ptr.offset(*curr + 6), 2);
ptr::copy_nonoverlapping(lut_ptr.offset(d5 as isize), buf_ptr.offset(*curr + 8), 2);
ptr::copy_nonoverlapping(lut_ptr.offset(d6 as isize), buf_ptr.offset(*curr + 10), 2);
ptr::copy_nonoverlapping(lut_ptr.offset(d7 as isize), buf_ptr.offset(*curr + 12), 2);
ptr::copy_nonoverlapping(lut_ptr.offset(d8 as isize), buf_ptr.offset(*curr + 14), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d1), buf_ptr.add(*curr + 0), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d2), buf_ptr.add(*curr + 2), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d3), buf_ptr.add(*curr + 4), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d4), buf_ptr.add(*curr + 6), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d5), buf_ptr.add(*curr + 8), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d6), buf_ptr.add(*curr + 10), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d7), buf_ptr.add(*curr + 12), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d8), buf_ptr.add(*curr + 14), 2);
}
if n >= 1e8 as u64 {
let to_parse = n % 1e8 as u64;
n /= 1e8 as u64;

// Some of these are nops but it looks more elegant this way.
let d1 = ((to_parse / 1e6 as u64) % 100) << 1;
let d2 = ((to_parse / 1e4 as u64) % 100) << 1;
let d3 = ((to_parse / 1e2 as u64) % 100) << 1;
let d4 = ((to_parse / 1e0 as u64) % 100) << 1;
let d1 = (((to_parse / 1e6 as u64) % 100) << 1) as usize;
let d2 = (((to_parse / 1e4 as u64) % 100) << 1) as usize;
let d3 = (((to_parse / 1e2 as u64) % 100) << 1) as usize;
let d4 = (((to_parse / 1e0 as u64) % 100) << 1) as usize;
*curr -= 8;

ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(*curr + 0), 2);
ptr::copy_nonoverlapping(lut_ptr.offset(d2 as isize), buf_ptr.offset(*curr + 2), 2);
ptr::copy_nonoverlapping(lut_ptr.offset(d3 as isize), buf_ptr.offset(*curr + 4), 2);
ptr::copy_nonoverlapping(lut_ptr.offset(d4 as isize), buf_ptr.offset(*curr + 6), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d1), buf_ptr.add(*curr + 0), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d2), buf_ptr.add(*curr + 2), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d3), buf_ptr.add(*curr + 4), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d4), buf_ptr.add(*curr + 6), 2);
}
// `n` < 1e8 < (1 << 32)
let mut n = n as u32;
if n >= 1e4 as u32 {
let to_parse = n % 1e4 as u32;
n /= 1e4 as u32;

let d1 = (to_parse / 100) << 1;
let d2 = (to_parse % 100) << 1;
let d1 = ((to_parse / 100) << 1) as usize;
Member commented:

pondering: I'm not sure if moving the casts up to the dN variables helps. Based on the structure here, where these various blocks are working in particular bitwidths based on the range of n, I think it's kinda nice to have the dN variables still be in those types. Which would then imply keeping the cast in the pointer operations, just as .add(d1 as usize) instead of offset + isize. How do you feel about that? I'm not sure how strongly I feel here, so absolutely feel free to argue either way 🙃

(The curr being usize is definitely a nice change, though.)

Member (author) commented:

I actually agree, dN are only used once, so it doesn't really matter where the cast is, and with the cast in add it looks nicer.

let d2 = ((to_parse % 100) << 1) as usize;
*curr -= 4;

ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(*curr + 0), 2);
ptr::copy_nonoverlapping(lut_ptr.offset(d2 as isize), buf_ptr.offset(*curr + 2), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d1), buf_ptr.add(*curr + 0), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d2), buf_ptr.add(*curr + 2), 2);
}

// `n` < 1e4 < (1 << 16)
let mut n = n as u16;
if n >= 100 {
let d1 = (n % 100) << 1;
let d1 = ((n % 100) << 1) as usize;
n /= 100;
*curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(*curr), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d1), buf_ptr.add(*curr), 2);
}

// decode last 1 or 2 chars
if n < 10 {
*curr -= 1;
*buf_ptr.offset(*curr) = (n as u8) + b'0';
*buf_ptr.add(*curr) = (n as u8) + b'0';
} else {
let d1 = n << 1;
let d1 = (n << 1) as usize;
Member commented:

aside: I guess this one is also overflowing when n is 99_i8, even before your change, but it just overflows into the sign bit which I guess doesn't matter in the end after the casts?

Member (author) commented:

n is cast to u16 on line 548, so no overflows here :)

Overflow into a sign bit would be pretty bad:

[src/main.rs:2] 99i8 << 1 = -58
[src/main.rs:2] (99i8 << 1) as usize = 18446744073709551558

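A quick check of the author's point, with assumed values (this snippet is mine, not from the thread): by the time the final branch of parse_u64_into runs, n has been reduced below 100 and narrowed to u16, so the shift stays well inside the type and the usize cast only zero-extends.

```rust
fn main() {
    let n: u16 = 99; // assumed worst case for the final one-or-two-digit branch
    // Unsigned and far from u16::MAX, so `n << 1` cannot wrap into a sign bit,
    // and the cast to usize is a plain zero-extension.
    assert_eq!((n << 1) as usize, 198);
}
```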
*curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.offset(d1 as isize), buf_ptr.offset(*curr), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d1), buf_ptr.add(*curr), 2);
}
}
}
@@ -593,21 +593,21 @@ impl fmt::Display for i128 {
fn fmt_u128(n: u128, is_nonnegative: bool, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// 2^128 is about 3*10^38, so 39 gives an extra byte of space
let mut buf = [MaybeUninit::<u8>::uninit(); 39];
let mut curr = buf.len() as isize;
let mut curr = buf.len();

let (n, rem) = udiv_1e19(n);
parse_u64_into(rem, &mut buf, &mut curr);

if n != 0 {
// 0 pad up to point
let target = (buf.len() - 19) as isize;
let target = buf.len() - 19;
// SAFETY: Guaranteed that we wrote at most 19 bytes, and there must be space
// remaining since it has length 39
unsafe {
ptr::write_bytes(
MaybeUninit::slice_as_mut_ptr(&mut buf).offset(target),
MaybeUninit::slice_as_mut_ptr(&mut buf).add(target),
b'0',
(curr - target) as usize,
curr - target,
);
}
curr = target;
@@ -616,16 +616,16 @@ fn fmt_u128(n: u128, is_nonnegative: bool, f: &mut fmt::Formatter<'_>) -> fmt::R
parse_u64_into(rem, &mut buf, &mut curr);
// Should this following branch be annotated with unlikely?
if n != 0 {
let target = (buf.len() - 38) as isize;
let target = buf.len() - 38;
// The raw `buf_ptr` pointer is only valid until `buf` is used the next time,
// buf `buf` is not used in this scope so we are good.
let buf_ptr = MaybeUninit::slice_as_mut_ptr(&mut buf);
// SAFETY: At this point we wrote at most 38 bytes, pad up to that point,
// There can only be at most 1 digit remaining.
unsafe {
ptr::write_bytes(buf_ptr.offset(target), b'0', (curr - target) as usize);
ptr::write_bytes(buf_ptr.add(target), b'0', curr - target);
curr = target - 1;
*buf_ptr.offset(curr) = (n as u8) + b'0';
*buf_ptr.add(curr) = (n as u8) + b'0';
}
}
}
@@ -634,8 +634,8 @@ fn fmt_u128(n: u128, is_nonnegative: bool, f: &mut fmt::Formatter<'_>) -> fmt::R
// UTF-8 since `DEC_DIGITS_LUT` is
let buf_slice = unsafe {
str::from_utf8_unchecked(slice::from_raw_parts(
MaybeUninit::slice_as_mut_ptr(&mut buf).offset(curr),
buf.len() - curr as usize,
MaybeUninit::slice_as_mut_ptr(&mut buf).add(curr),
buf.len() - curr,
))
};
f.pad_integral(is_nonnegative, "", buf_slice)
4 changes: 2 additions & 2 deletions library/core/src/slice/memchr.rs
@@ -124,8 +124,8 @@ pub fn memrchr(x: u8, text: &[u8]) -> Option<usize> {
// SAFETY: offset starts at len - suffix.len(), as long as it is greater than
// min_aligned_offset (prefix.len()) the remaining distance is at least 2 * chunk_bytes.
unsafe {
let u = *(ptr.offset(offset as isize - 2 * chunk_bytes as isize) as *const Chunk);
let v = *(ptr.offset(offset as isize - chunk_bytes as isize) as *const Chunk);
let u = *(ptr.add(offset - 2 * chunk_bytes) as *const Chunk);
let v = *(ptr.add(offset - chunk_bytes) as *const Chunk);

// Break if there is a matching byte.
let zu = contains_zero_byte(u ^ repeated_x);
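Since the hunk above leans on `contains_zero_byte`, here is a sketch of the usual SWAR zero-byte test it refers to (my reconstruction on u64 words; the real helper in library/core/src/slice/memchr.rs operates on usize-sized chunks and may differ in detail):

```rust
/// Nonzero result means `x` contains at least one zero byte (the classic
/// "haszero" word trick; a per-byte false positive can only occur when a lower
/// byte is already zero, so the boolean answer stays exact).
fn contains_zero_byte(x: u64) -> bool {
    const LO: u64 = 0x0101_0101_0101_0101;
    const HI: u64 = 0x8080_8080_8080_8080;
    (x.wrapping_sub(LO) & !x & HI) != 0
}

fn main() {
    let needle = b'q';
    let repeated = u64::from_ne_bytes([needle; 8]);
    // Stand-ins for chunks read via `ptr.add(..)` in memrchr.
    let hit = u64::from_ne_bytes(*b"aqcdefgh");
    let miss = u64::from_ne_bytes(*b"abcdefgh");
    // XOR with the repeated needle zeroes exactly the matching byte, if any.
    assert!(contains_zero_byte(hit ^ repeated));
    assert!(!contains_zero_byte(miss ^ repeated));
}
```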
28 changes: 14 additions & 14 deletions library/std/src/sys/sgx/abi/usercalls/alloc.rs
@@ -316,9 +316,9 @@ where
// | small1 | Chunk smaller than 8 bytes
// +--------+
fn region_as_aligned_chunks(ptr: *const u8, len: usize) -> (usize, usize, usize) {
Member commented:

pondering (not this PR): this looks an awful lot like .align_to(8).

let small0_size = if ptr as usize % 8 == 0 { 0 } else { 8 - ptr as usize % 8 };
let small1_size = (len - small0_size as usize) % 8;
let big_size = len - small0_size as usize - small1_size as usize;
let small0_size = if ptr.is_aligned_to(8) { 0 } else { 8 - ptr.addr() % 8 };
let small1_size = (len - small0_size) % 8;
let big_size = len - small0_size - small1_size;

(small0_size, big_size, small1_size)
}
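Picking up the `.align_to(8)` pondering from the review comment above, a quick sketch of the analogy (my illustration, not part of the PR; the exact head/tail lengths depend on where the buffer happens to land in memory):

```rust
fn main() {
    let bytes = [0u8; 29];
    // SAFETY: any initialized bytes are valid as u64, so reinterpreting the
    // aligned middle of the slice is fine.
    let (head, body, tail) = unsafe { bytes.align_to::<u64>() };
    // head/tail are the unaligned byte fragments and body is the 8-byte-aligned
    // middle: the same (small0, big, small1) split region_as_aligned_chunks computes.
    assert_eq!(head.len() + body.len() * 8 + tail.len(), bytes.len());
    println!("small0 = {}, big = {}, small1 = {}", head.len(), body.len() * 8, tail.len());
}
```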
@@ -364,8 +364,8 @@ pub(crate) unsafe fn copy_to_userspace(src: *const u8, dst: *mut u8, len: usize)
mfence
lfence
",
val = in(reg_byte) *src.offset(off as isize),
dst = in(reg) dst.offset(off as isize),
val = in(reg_byte) *src.add(off),
dst = in(reg) dst.add(off),
seg_sel = in(reg) &mut seg_sel,
options(nostack, att_syntax)
);
@@ -378,8 +378,8 @@ pub(crate) unsafe fn copy_to_userspace(src: *const u8, dst: *mut u8, len: usize)
assert!(is_enclave_range(src, len));
assert!(is_user_range(dst, len));
assert!(len < isize::MAX as usize);
assert!(!(src as usize).overflowing_add(len).1);
assert!(!(dst as usize).overflowing_add(len).1);
assert!(!src.addr().overflowing_add(len).1);
assert!(!dst.addr().overflowing_add(len).1);

if len < 8 {
// Can't align on 8 byte boundary: copy safely byte per byte
@@ -404,17 +404,17 @@

unsafe {
// Copy small0
copy_bytewise_to_userspace(src, dst, small0_size as _);
copy_bytewise_to_userspace(src, dst, small0_size);

// Copy big
let big_src = src.offset(small0_size as _);
let big_dst = dst.offset(small0_size as _);
copy_quadwords(big_src as _, big_dst, big_size);
let big_src = src.add(small0_size);
let big_dst = dst.add(small0_size);
copy_quadwords(big_src, big_dst, big_size);

// Copy small1
let small1_src = src.offset(big_size as isize + small0_size as isize);
let small1_dst = dst.offset(big_size as isize + small0_size as isize);
copy_bytewise_to_userspace(small1_src, small1_dst, small1_size as _);
let small1_src = src.add(big_size + small0_size);
let small1_dst = dst.add(big_size + small0_size);
copy_bytewise_to_userspace(small1_src, small1_dst, small1_size);
}
}
}