From c4d4c66cb1f4aaefdddd2b67bf0ebf155880a4e7 Mon Sep 17 00:00:00 2001 From: Arushi-Gupta13 Date: Thu, 6 Feb 2025 22:17:00 +0530 Subject: [PATCH] Added GPU Support --- src/discretize.jl | 1 - src/eltype_matching.jl | 13 +++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/discretize.jl b/src/discretize.jl index 43653ba7c..ffd73770c 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -321,7 +321,6 @@ function get_numeric_integral(pinnrep::PINNRepresentation) return integration_arr end end - """ prob = symbolic_discretize(pde_system::PDESystem, discretization::AbstractPINN) diff --git a/src/eltype_matching.jl b/src/eltype_matching.jl index ca5aeecd7..d178488db 100644 --- a/src/eltype_matching.jl +++ b/src/eltype_matching.jl @@ -1,8 +1,17 @@ struct EltypeAdaptor{T} end -(l::EltypeAdaptor)(x) = fmap(Adapt.adapt(l), x) +function ensure_same_device(x, device) + if (typeof(x) != device) && !(x isa Number) + error("Device mismatch detected. Ensure all data is on the same device.") + end + return x +end + + +(l::EltypeAdaptor)(x) = fmap(y -> ensure_same_device(y, l), x) + function (l::EltypeAdaptor)(x::AbstractArray{T}) where {T} - return (isbitstype(T) || T <: Number) ? Adapt.adapt(l, x) : map(l, x) + return (isbitstype(T) || T <: Number) ? x : map(y -> ensure_same_device(y, l), x) end function Adapt.adapt_storage(::EltypeAdaptor{T}, x::AbstractArray) where {T}