[spike] evaluate + prototype interaction of unified memory abstraction with custom_ops #1556
# Revised get_paged in functional.py
from math import prod

import torch

def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
    num_bytes = dtype.itemsize * prod(shape)
    # Use the C++ extension instead of direct lib calls
    tensor = cpp_extension.get_managed_tensor(
        num_bytes,
        list(shape),
        dtype,
    )
    tensor.is_paged = True
    tensor.page_deviceid = device.index
    return tensor

// Enhanced C++ implementation in pythonInterface.cpp
torch::Tensor get_managed_tensor(
    size_t nb_bytes,
    c10::IntArrayRef sizes,
    c10::ScalarType dtype
) {
    void* cuda_ptr;
    CUDA_CHECK(cudaMallocManaged(&cuda_ptr, nb_bytes, cudaMemAttachHost));
    auto options = torch::TensorOptions()
        .device(torch::kCUDA)  // Critical: report a CUDA device so ops dispatch to CUDA kernels
        .dtype(dtype)
        .requires_grad(false);
    return torch::from_blob(
        cuda_ptr,
        sizes,
        // Deleter frees the unified-memory allocation when the tensor is released
        [](void* ptr) { CUDA_CHECK(cudaFree(ptr)); },
        options
    );
}
# The tensor will report device=cuda
t = get_paged(1024, dtype=torch.float32)
print(t.device) # Output: cuda:0
# Operations will dispatch to CUDA kernels
# Unified memory handles page migration automatically
y = t @ t.T # Dispatches to CUDA GEMM
This approach satisfies all requirements while maintaining compatibility with existing optimizer infrastructure. The key innovation is creating CUDA device tensors that transparently use unified memory, enabling proper dispatch while retaining paged memory benefits.
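For reference, the "direct lib calls" path that the C++ extension replaces can be sketched in pure Python via ctypes. This is an illustrative sketch, not the issue's actual implementation: the `libcudart.so` soname and the `cudaMemAttachHost == 2` flag value come from the CUDA runtime API, and `managed_num_bytes` is a hypothetical helper mirroring get_paged's size computation.

```python
import ctypes
from math import prod

def managed_num_bytes(shape, itemsize):
    """Bytes needed for a tensor of `shape` whose elements are `itemsize` bytes,
    matching the dtype.itemsize * prod(shape) computation in get_paged."""
    return itemsize * prod(shape)

def cuda_malloc_managed(num_bytes):
    """Allocate unified memory by calling cudaMallocManaged directly through
    ctypes. Requires the CUDA runtime library to be present at runtime."""
    cudart = ctypes.CDLL("libcudart.so")  # assumption: Linux CUDA runtime soname
    ptr = ctypes.c_void_p()
    # cudaMemAttachHost == 2 in the CUDA runtime API, as used in the C++ code above
    err = cudart.cudaMallocManaged(ctypes.byref(ptr), ctypes.c_size_t(num_bytes), 2)
    if err != 0:
        raise RuntimeError(f"cudaMallocManaged failed with error code {err}")
    return ptr

# A float32 tensor of shape (1024,) needs 4 * 1024 bytes
print(managed_num_bytes((1024,), 4))  # 4096
```

The C++ extension route was preferred over this because torch::from_blob attaches a deleter and device metadata in one step, whereas a raw ctypes pointer still has to be wrapped into a tensor and freed manually.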
We closed this line of investigation for two reasons: the proposed solution introduces a dependency on the torch C++ library, which we don't want, and no implementation other than AMD actually wants to use unified memory so far; AMD is dispatched under the "cuda" dispatch key anyway. It therefore doesn't make sense to investigate this further.
Unified memory isn't supported in PyTorch and was considered a potential blocker for the custom ops refactor.
We found a workaround at the time, with a simple viability proof.
However, it's not clear how this fits together with the currently open PR #1544 and RFC #1545; this needs to be fleshed out.
Questions: