diff --git a/.github/workflows/gpu_tests.yml b/.github/workflows/gpu_tests.yml new file mode 100644 index 0000000..3fcfc36 --- /dev/null +++ b/.github/workflows/gpu_tests.yml @@ -0,0 +1,48 @@ +name: GPU jobs + +on: [ push, pull_request ] + +permissions: + contents: read # to fetch code (actions/checkout) + +env: + CCACHE_DIR: "${{ github.workspace }}/.ccache" + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + cupy_tests: + name: CuPy GPU + runs-on: ghcr.io/cirruslabs/ubuntu-runner-amd64-gpu:22.04 + steps: + - name: Checkout xsf repo + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + with: + submodules: recursive + + - name: Setup compiler cache + uses: cirruslabs/cache@v4 #caa3ad0624c6c2acd8ba50ad452d1f44bba078bb # v4 + with: + path: ${{ env.CCACHE_DIR }} + # Make primary key unique by using `run_id`, this ensures the cache + # is always saved at the end. + key: ${{ runner.os }}-gpu-ccache-${{ github.run_id }} + restore-keys: | + ${{ runner.os }}-gpu-ccache + + - name: run nvidia-smi + run: nvidia-smi + + - name: run nvidia-smi --query + run: nvidia-smi --query + + - uses: prefix-dev/setup-pixi@ba3bb36eb2066252b2363392b7739741bb777659 # v0.8.1 + with: + pixi-version: v0.39.2 + manifest-path: pixi.toml + cache: false + + - name: Run CuPy tests + run: pixi run test-cupy diff --git a/include/xsf/config.h b/include/xsf/config.h index 2fed489..ec7f0bb 100644 --- a/include/xsf/config.h +++ b/include/xsf/config.h @@ -108,7 +108,7 @@ XSF_HOST_DEVICE inline bool signbit(double x) { return cuda::std::signbit(x); } XSF_HOST_DEVICE inline double hypot(double x, double y) { return cuda::std::hypot(x, y); } // Fallback to global namespace for functions unsupported on NVRTC -#ifndef _LIBCUDACXX_COMPILER_NVRTC +#ifndef __CUDACC_RTC__ XSF_HOST_DEVICE inline double ceil(double x) { return cuda::std::ceil(x); } XSF_HOST_DEVICE inline double floor(double x) { return 
cuda::std::floor(x); } XSF_HOST_DEVICE inline double round(double x) { return cuda::std::round(x); } @@ -210,8 +210,13 @@ using enable_if = cuda::std::enable_if; template using decay = cuda::std::decay; -template -using invoke_result = cuda::std::invoke_result; +template +struct invoke_result { + using type = decltype(cuda::std::declval()()); +}; + +template +using invoke_result_t = typename invoke_result::type; template using pair = cuda::std::pair; diff --git a/pixi.lock b/pixi.lock index 45d702d..9dc4928 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1,5 +1,70 @@ version: 6 environments: + cupy-tests: + channels: + - url: https://prefix.dev/conda-forge/ + packages: + linux-64: + - conda: https://prefix.dev/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2 + - conda: https://prefix.dev/conda-forge/linux-64/_openmp_mutex-4.5-2_gnu.tar.bz2 + - conda: https://prefix.dev/conda-forge/linux-64/bzip2-1.0.8-h4bc722e_7.conda + - conda: https://prefix.dev/conda-forge/linux-64/ca-certificates-2025.1.31-hbcca054_0.conda + - conda: https://prefix.dev/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/cuda-cccl_linux-64-12.9.27-ha770c72_0.conda + - conda: https://prefix.dev/conda-forge/noarch/cuda-cudart-dev_linux-64-12.9.79-h3f2d84a_0.conda + - conda: https://prefix.dev/conda-forge/noarch/cuda-cudart-static_linux-64-12.9.79-h3f2d84a_0.conda + - conda: https://prefix.dev/conda-forge/noarch/cuda-cudart_linux-64-12.9.79-h3f2d84a_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/cuda-nvrtc-12.9.86-h5888daf_0.conda + - conda: https://prefix.dev/conda-forge/noarch/cuda-version-12.9-h4f385c5_3.conda + - conda: https://prefix.dev/conda-forge/linux-64/cupy-13.4.1-py312h78400a1_1.conda + - conda: https://prefix.dev/conda-forge/linux-64/cupy-core-13.4.1-py312h007fbcc_1.conda + - conda: https://prefix.dev/conda-forge/noarch/exceptiongroup-1.3.0-pyhd8ed1ab_0.conda + - conda: 
https://prefix.dev/conda-forge/linux-64/fastrlock-0.8.3-py312h6edf5ed_1.conda + - conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/linux-64/ld_impl_linux-64-2.43-h712a8e2_4.conda + - conda: https://prefix.dev/conda-forge/linux-64/libblas-3.9.0-31_h59b9bed_openblas.conda + - conda: https://prefix.dev/conda-forge/linux-64/libcblas-3.9.0-31_he106b2a_openblas.conda + - conda: https://prefix.dev/conda-forge/linux-64/libcublas-12.9.1.4-h9ab20c4_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/libcufft-11.4.1.4-h5888daf_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/libcurand-10.3.10.19-h9ab20c4_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/libcusolver-11.7.5.82-h9ab20c4_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/libcusparse-12.5.10.65-h5888daf_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/libexpat-2.7.0-h5888daf_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/libffi-3.4.6-h2dba641_1.conda + - conda: https://prefix.dev/conda-forge/linux-64/libgcc-14.2.0-h767d61c_2.conda + - conda: https://prefix.dev/conda-forge/linux-64/libgcc-ng-14.2.0-h69a702a_2.conda + - conda: https://prefix.dev/conda-forge/linux-64/libgfortran-14.2.0-h69a702a_2.conda + - conda: https://prefix.dev/conda-forge/linux-64/libgfortran5-14.2.0-hf1ad2bd_2.conda + - conda: https://prefix.dev/conda-forge/linux-64/libgomp-14.2.0-h767d61c_2.conda + - conda: https://prefix.dev/conda-forge/linux-64/liblapack-3.9.0-31_h7ac8fdf_openblas.conda + - conda: https://prefix.dev/conda-forge/linux-64/liblzma-5.8.1-hb9d3cd8_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/libnsl-2.0.1-hd590300_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/libnvjitlink-12.9.86-h5888daf_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/libopenblas-0.3.29-pthreads_h94d23a6_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/libsqlite-3.50.1-hee588c1_0.conda + - 
conda: https://prefix.dev/conda-forge/linux-64/libstdcxx-14.2.0-h8f9b012_2.conda + - conda: https://prefix.dev/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda + - conda: https://prefix.dev/conda-forge/linux-64/libzlib-1.3.1-hb9d3cd8_2.conda + - conda: https://prefix.dev/conda-forge/linux-64/ncurses-6.5-h2d0b736_3.conda + - conda: https://prefix.dev/conda-forge/linux-64/numpy-2.2.6-py312h72c5963_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/openssl-3.5.0-h7b32b05_0.conda + - conda: https://prefix.dev/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda + - conda: https://prefix.dev/conda-forge/noarch/pip-25.1.1-pyh8b19718_0.conda + - conda: https://prefix.dev/conda-forge/noarch/pluggy-1.6.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/py-1.11.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/pygments-2.19.1-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/pytest-8.4.0-pyhd8ed1ab_0.conda + - conda: https://prefix.dev/conda-forge/noarch/pytest-forked-1.6.0-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/linux-64/python-3.12.11-h9e4cc4f_0_cpython.conda + - conda: https://prefix.dev/conda-forge/noarch/python_abi-3.12-7_cp312.conda + - conda: https://prefix.dev/conda-forge/linux-64/readline-8.2-h8c095d6_2.conda + - conda: https://prefix.dev/conda-forge/noarch/setuptools-80.9.0-pyhff2d567_0.conda + - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_hd72426e_102.conda + - conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda + - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.14.0-pyhe01879c_0.conda + - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda + - conda: https://prefix.dev/conda-forge/noarch/wheel-0.45.1-pyhd8ed1ab_1.conda default: channels: - url: https://prefix.dev/conda-forge/ @@ -2432,6 +2497,15 @@ packages: 
license_family: BSD size: 14271971 timestamp: 1740468587344 +- conda: https://prefix.dev/conda-forge/noarch/colorama-0.4.6-pyhd8ed1ab_1.conda + sha256: ab29d57dc70786c1269633ba3dff20288b81664d3ff8d21af995742e2bb03287 + md5: 962b9857ee8e7018c22f2776ffa0b2d7 + depends: + - python >=3.9 + license: BSD-3-Clause + license_family: BSD + size: 27011 + timestamp: 1733218222191 - conda: https://prefix.dev/conda-forge/osx-64/compiler-rt-18.1.8-h1020d70_1.conda sha256: 30bd259ad8909c02ee9da8b13bf7c9f6dc0f4d6fa3c5d1cd82213180ca5f9c03 md5: bc1714a1e73be18e411cff30dc1fe011 @@ -2480,6 +2554,109 @@ packages: license_family: APACHE size: 10583287 timestamp: 1725258124186 +- conda: https://prefix.dev/conda-forge/noarch/cuda-cccl_linux-64-12.9.27-ha770c72_0.conda + sha256: 2ee3b9564ca326226e5cda41d11b251482df8e7c757e333d28ec75213c75d126 + md5: 87ff6381e33b76e5b9b179a2cdd005ec + depends: + - cuda-version >=12.9,<12.10.0a0 + license: LicenseRef-NVIDIA-End-User-License-Agreement + size: 1150650 + timestamp: 1746189825236 +- conda: https://prefix.dev/conda-forge/noarch/cuda-cudart-dev_linux-64-12.9.79-h3f2d84a_0.conda + sha256: ffe86ed0144315b276f18020d836c8ef05bf971054cf7c3eb167af92494080d5 + md5: 86e40eb67d83f1a58bdafdd44e5a77c6 + depends: + - cuda-cccl_linux-64 + - cuda-cudart-static_linux-64 + - cuda-cudart_linux-64 + - cuda-version >=12.9,<12.10.0a0 + license: LicenseRef-NVIDIA-End-User-License-Agreement + size: 389140 + timestamp: 1749218427266 +- conda: https://prefix.dev/conda-forge/noarch/cuda-cudart-static_linux-64-12.9.79-h3f2d84a_0.conda + sha256: d435f8a19b59b52ce460ee3a6bfd877288a0d1d645119a6ba60f1c3627dc5032 + md5: b87bf315d81218dd63eb46cc1eaef775 + depends: + - cuda-version >=12.9,<12.10.0a0 + license: LicenseRef-NVIDIA-End-User-License-Agreement + size: 1148889 + timestamp: 1749218381225 +- conda: https://prefix.dev/conda-forge/noarch/cuda-cudart_linux-64-12.9.79-h3f2d84a_0.conda + sha256: 6cde0ace2b995b49d0db2eefb7bc30bf00ffc06bb98ef7113632dec8f8907475 + md5: 
64508631775fbbf9eca83c84b1df0cae + depends: + - cuda-version >=12.9,<12.10.0a0 + license: LicenseRef-NVIDIA-End-User-License-Agreement + size: 197249 + timestamp: 1749218394213 +- conda: https://prefix.dev/conda-forge/linux-64/cuda-nvrtc-12.9.86-h5888daf_0.conda + sha256: 4d339c411c23d40ff3a8671284e476a31b31273b1a4d29c680c01940a559bd95 + md5: 9c52e4389e54d4f5800b23512e479479 + depends: + - __glibc >=2.17,<3.0.a0 + - cuda-version >=12.9,<12.10.0a0 + - libgcc >=13 + - libstdcxx >=13 + license: LicenseRef-NVIDIA-End-User-License-Agreement + size: 67183992 + timestamp: 1749221543691 +- conda: https://prefix.dev/conda-forge/noarch/cuda-version-12.9-h4f385c5_3.conda + sha256: 5f5f428031933f117ff9f7fcc650e6ea1b3fef5936cf84aa24af79167513b656 + md5: b6d5d7f1c171cbd228ea06b556cfa859 + constrains: + - cudatoolkit 12.9|12.9.* + - __cuda >=12 + license: LicenseRef-NVIDIA-End-User-License-Agreement + size: 21578 + timestamp: 1746134436166 +- conda: https://prefix.dev/conda-forge/linux-64/cupy-13.4.1-py312h78400a1_1.conda + sha256: a80d9c747675e417a1bdbd997784b1b6939a5b586b67f27e63f4d480d0d2a0bc + md5: 23a0a0c65bb4edf97dd318158db6b033 + depends: + - cuda-cudart-dev_linux-64 + - cuda-nvrtc + - cuda-version >=12,<13.0a0 + - cupy-core 13.4.1 py312h007fbcc_1 + - libcublas + - libcufft + - libcurand + - libcusolver + - libcusparse + - python >=3.12,<3.13.0a0 + - python_abi 3.12.* *_cp312 + license: MIT + license_family: MIT + size: 357438 + timestamp: 1749506001576 +- conda: https://prefix.dev/conda-forge/linux-64/cupy-core-13.4.1-py312h007fbcc_1.conda + sha256: b38b71ef4ba305a7ac6802b0199b5a2db6b76548026fb04a81da86c824a25074 + md5: 1783b7098b5151908dd4af4116347b5a + depends: + - __glibc >=2.17,<3.0.a0 + - fastrlock >=0.8.3,<0.9.0a0 + - libgcc >=13 + - libstdcxx >=13 + - numpy >=1.22,<2.3 + - python >=3.12,<3.13.0a0 + - python_abi 3.12.* *_cp312 + constrains: + - cuda-version >=12,<13.0a0 + - cupy >=13.4.1,<13.5.0a0 + - scipy >=1.7,<1.17 + - libcurand >=10,<11.0a0 + - __cuda >=12.0 + 
- cutensor >=2.2.0.0,<3.0a0 + - libcublas >=12,<13.0a0 + - libcusparse >=12,<13.0a0 + - optuna ~=3.0 + - libcufft >=11,<12.0a0 + - cuda-nvrtc >=12,<13.0a0 + - nccl >=2.27.3.1,<3.0a0 + - libcusolver >=11,<12.0a0 + license: MIT + license_family: MIT + size: 49560329 + timestamp: 1749505893183 - conda: https://prefix.dev/conda-forge/linux-64/cxx-compiler-1.9.0-h1a2810e_0.conda sha256: 5efc51b8e7d87fc5380f00ace9f9c758142eade520a63d3631d2616d1c1b25f9 md5: 1ce8b218d359d9ed0ab481f2a3f3c512 @@ -2520,6 +2697,29 @@ packages: license_family: BSD size: 6528 timestamp: 1736437098756 +- conda: https://prefix.dev/conda-forge/noarch/exceptiongroup-1.3.0-pyhd8ed1ab_0.conda + sha256: ce61f4f99401a4bd455b89909153b40b9c823276aefcbb06f2044618696009ca + md5: 72e42d28960d875c7654614f8b50939a + depends: + - python >=3.9 + - typing_extensions >=4.6.0 + license: MIT and PSF-2.0 + size: 21284 + timestamp: 1746947398083 +- conda: https://prefix.dev/conda-forge/linux-64/fastrlock-0.8.3-py312h6edf5ed_1.conda + sha256: 260589d271cfdd4bf04d084084123be3e49e9017da159f27bea5dc8617eaada6 + md5: 2e401040f77cf54d8d5e1f0417dcf0b2 + depends: + - python + - __glibc >=2.17,<3.0.a0 + - libgcc >=13 + - libstdcxx >=13 + - libgcc >=13 + - python_abi 3.12.* *_cp312 + license: MIT + license_family: MIT + size: 41705 + timestamp: 1734873425804 - conda: https://prefix.dev/conda-forge/linux-64/gcc-13.3.0-h9576a4e_2.conda sha256: 300f077029e7626d69cc250a69acd6018c1fced3f5bf76adf37854f3370d2c45 md5: d92e51bf4b6bdbfe45e5884fb0755afe @@ -2714,6 +2914,15 @@ packages: license_family: MIT size: 11857802 timestamp: 1720853997952 +- conda: https://prefix.dev/conda-forge/noarch/iniconfig-2.0.0-pyhd8ed1ab_1.conda + sha256: 0ec8f4d02053cd03b0f3e63168316530949484f80e16f5e2fb199a1d117a89ca + md5: 6837f3eff7dcea42ecd714ce1ac2b108 + depends: + - python >=3.9 + license: MIT + license_family: MIT + size: 11474 + timestamp: 1733223232820 - conda: https://prefix.dev/conda-forge/noarch/kernel-headers_linux-64-3.10.0-he073ed8_18.conda 
sha256: a922841ad80bd7b222502e65c07ecb67e4176c4fa5b03678a005f39fcc98be4b md5: ad8527bf134a90e1c9ed35fa0b64318c @@ -3557,6 +3766,23 @@ packages: license_family: APACHE size: 372641 timestamp: 1744943697025 +- conda: https://prefix.dev/conda-forge/linux-64/libblas-3.9.0-31_h59b9bed_openblas.conda + build_number: 31 + sha256: 9839fc4ac0cbb0aa3b9eea520adfb57311838959222654804e58f6f2d1771db5 + md5: 728dbebd0f7a20337218beacffd37916 + depends: + - libopenblas >=0.3.29,<0.3.30.0a0 + - libopenblas >=0.3.29,<1.0a0 + constrains: + - liblapacke =3.9.0=31*_openblas + - liblapack =3.9.0=31*_openblas + - blas =2.131=openblas + - mkl <2025 + - libcblas =3.9.0=31*_openblas + license: BSD-3-Clause + license_family: BSD + size: 16859 + timestamp: 1740087969120 - conda: https://prefix.dev/conda-forge/linux-64/libbrotlicommon-1.1.0-hb9d3cd8_2.conda sha256: d9db2de60ea917298e658143354a530e9ca5f9c63471c65cf47ab39fd2f429e3 md5: 41b599ed2b02abcfdd84302bff174b23 @@ -3682,6 +3908,20 @@ packages: license_family: MIT size: 245929 timestamp: 1725268238259 +- conda: https://prefix.dev/conda-forge/linux-64/libcblas-3.9.0-31_he106b2a_openblas.conda + build_number: 31 + sha256: ede8545011f5b208b151fe3e883eb4e31d495ab925ab7b9ce394edca846e0c0d + md5: abb32c727da370c481a1c206f5159ce9 + depends: + - libblas 3.9.0 31_h59b9bed_openblas + constrains: + - liblapacke =3.9.0=31*_openblas + - liblapack =3.9.0=31*_openblas + - blas =2.131=openblas + license: BSD-3-Clause + license_family: BSD + size: 16796 + timestamp: 1740087984429 - conda: https://prefix.dev/conda-forge/osx-64/libclang-cpp18.1-18.1.8-default_h3571c67_9.conda sha256: a3453cf08393f4a369a70795036d60dd8ea0de1efbf683594cbcaba49d8e3e74 md5: ef1a444913775b76f3391431967090a9 @@ -3776,6 +4016,40 @@ packages: license_family: BSD size: 25694 timestamp: 1633684287072 +- conda: https://prefix.dev/conda-forge/linux-64/libcublas-12.9.1.4-h9ab20c4_0.conda + sha256: 38bc99de89687ec391750dc603203364bdedfb92c600dcb2916dd3cd8558f5f5 + md5: 
605f995d88cdb64714bd9979aadc7cd4 + depends: + - __glibc >=2.28,<3.0.a0 + - cuda-nvrtc + - cuda-version >=12.9,<12.10.0a0 + - libgcc >=13 + - libstdcxx >=13 + license: LicenseRef-NVIDIA-End-User-License-Agreement + size: 467680700 + timestamp: 1749227622432 +- conda: https://prefix.dev/conda-forge/linux-64/libcufft-11.4.1.4-h5888daf_0.conda + sha256: fb4d2b0c23104d2c42400a3f69f311f087a3b71ab9c9c36bb249919e599b7e8d + md5: 2da1a83a3b1951e7e8d1c9c3d1340c41 + depends: + - __glibc >=2.17,<3.0.a0 + - cuda-version >=12.9,<12.10.0a0 + - libgcc >=13 + - libstdcxx >=13 + license: LicenseRef-NVIDIA-End-User-License-Agreement + size: 162077229 + timestamp: 1749221627451 +- conda: https://prefix.dev/conda-forge/linux-64/libcurand-10.3.10.19-h9ab20c4_0.conda + sha256: c4576976b8b5ceb060b32d24fc08db5253606256c3c99b42ace343e9be2229db + md5: c745bc0dd1f066e6752c8b2909216b62 + depends: + - __glibc >=2.28,<3.0.a0 + - cuda-version >=12.9,<12.10.0a0 + - libgcc >=13 + - libstdcxx >=13 + license: LicenseRef-NVIDIA-End-User-License-Agreement + size: 46161381 + timestamp: 1746193213392 - conda: https://prefix.dev/conda-forge/linux-64/libcurl-8.13.0-h332b0f4_0.conda sha256: 38e528acfaa0276b7052f4de44271ff9293fdb84579650601a8c49dac171482a md5: cbdc92ac0d93fe3c796e36ad65c7905c @@ -3836,6 +4110,32 @@ packages: license_family: MIT size: 357142 timestamp: 1743602240803 +- conda: https://prefix.dev/conda-forge/linux-64/libcusolver-11.7.5.82-h9ab20c4_0.conda + sha256: fadacf0aacead8bb6264c4bce4051f4ef7830c218a4e867a67c02d3c4b28bd08 + md5: ecaa51e8bc0039aab1ac44c1270c70b8 + depends: + - __glibc >=2.28,<3.0.a0 + - cuda-version >=12.9,<12.10.0a0 + - libcublas >=12.9.1.4,<12.10.0a0 + - libcusparse >=12.5.10.65,<12.6.0a0 + - libgcc >=13 + - libnvjitlink >=12.9.86,<12.10.0a0 + - libstdcxx >=13 + license: LicenseRef-NVIDIA-End-User-License-Agreement + size: 205162082 + timestamp: 1749232252911 +- conda: https://prefix.dev/conda-forge/linux-64/libcusparse-12.5.10.65-h5888daf_0.conda + sha256: 
2e69a61c10633651c80dee982d7e46ed5aef6c06ee47622188403d6b9f99b889 + md5: 662ed6e77f131380286d772f6a364ac2 + depends: + - __glibc >=2.17,<3.0.a0 + - cuda-version >=12.9,<12.10.0a0 + - libgcc >=13 + - libnvjitlink >=12.9.86,<12.10.0a0 + - libstdcxx >=13 + license: LicenseRef-NVIDIA-End-User-License-Agreement + size: 208848587 + timestamp: 1749224709022 - conda: https://prefix.dev/conda-forge/osx-64/libcxx-20.1.3-hf95d169_0.conda sha256: a4b493e0f76b20ff14e0f1f93c92882663c4f23c4488d8de3f6bbf1311b9c41e md5: 022f109787a9624301ddbeb39519ff13 @@ -4016,6 +4316,16 @@ packages: license_family: MIT size: 140896 timestamp: 1743432122520 +- conda: https://prefix.dev/conda-forge/linux-64/libffi-3.4.6-h2dba641_1.conda + sha256: 764432d32db45466e87f10621db5b74363a9f847d2b8b1f9743746cd160f06ab + md5: ede4673863426c0883c0063d853bbd85 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=13 + license: MIT + license_family: MIT + size: 57433 + timestamp: 1743434498161 - conda: https://prefix.dev/conda-forge/linux-64/libgcc-14.2.0-h767d61c_2.conda sha256: 3a572d031cb86deb541d15c1875aaa097baefc0c580b54dc61f5edab99215792 md5: ef504d1acbd74b7cc6849ef8af47dd03 @@ -4460,6 +4770,20 @@ packages: license: LGPL-2.1-or-later size: 78921 timestamp: 1739039271409 +- conda: https://prefix.dev/conda-forge/linux-64/liblapack-3.9.0-31_h7ac8fdf_openblas.conda + build_number: 31 + sha256: f583661921456e798aba10972a8abbd9d33571c655c1f66eff450edc9cbefcf3 + md5: 452b98eafe050ecff932f0ec832dd03f + depends: + - libblas 3.9.0 31_h59b9bed_openblas + constrains: + - libcblas =3.9.0=31*_openblas + - liblapacke =3.9.0=31*_openblas + - blas =2.131=openblas + license: BSD-3-Clause + license_family: BSD + size: 16790 + timestamp: 1740087997375 - conda: https://prefix.dev/conda-forge/linux-64/libllvm18-18.1.8-ha7bfdaf_3.conda sha256: de23835ab90e90b4dec9960f69c56a629189bb266d0d9aabac3bac26f1a4a836 md5: de2f6ca3a6e411376ccc56398550f7e0 @@ -4621,6 +4945,40 @@ packages: license_family: MIT size: 566719 timestamp: 
1729572385640 +- conda: https://prefix.dev/conda-forge/linux-64/libnsl-2.0.1-hd590300_0.conda + sha256: 26d77a3bb4dceeedc2a41bd688564fe71bf2d149fdcf117049970bc02ff1add6 + md5: 30fd6e37fe21f86f4bd26d6ee73eeec7 + depends: + - libgcc-ng >=12 + license: LGPL-2.1-only + license_family: GPL + size: 33408 + timestamp: 1697359010159 +- conda: https://prefix.dev/conda-forge/linux-64/libnvjitlink-12.9.86-h5888daf_0.conda + sha256: 2df595ff4cd599446ed7ca01cdfaccc6bc8de89de45b834dd8d5b044ef1d0aea + md5: 7bc06365942b9e4a037746c182feff4d + depends: + - __glibc >=2.17,<3.0.a0 + - cuda-version >=12,<12.10.0a0 + - libgcc >=13 + - libstdcxx >=13 + license: LicenseRef-NVIDIA-End-User-License-Agreement + size: 30525691 + timestamp: 1749219248901 +- conda: https://prefix.dev/conda-forge/linux-64/libopenblas-0.3.29-pthreads_h94d23a6_0.conda + sha256: cc5389ea254f111ef17a53df75e8e5209ef2ea6117e3f8aced88b5a8e51f11c4 + md5: 0a4d0252248ef9a0f88f2ba8b8a08e12 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=14 + - libgfortran + - libgfortran5 >=14.2.0 + constrains: + - openblas >=0.3.29,<0.3.30.0a0 + license: BSD-3-Clause + license_family: BSD + size: 5919288 + timestamp: 1739825731827 - conda: https://prefix.dev/conda-forge/linux-64/libopentelemetry-cpp-1.20.0-hd1b1c89_0.conda sha256: 11ba93b440f3332499801b8f9580cea3dc19c3aa440c4deb30fd8be302a71c7f md5: e1185384cc23e3bbf85486987835df94 @@ -4880,6 +5238,16 @@ packages: license_family: GPL size: 4155341 timestamp: 1740240344242 +- conda: https://prefix.dev/conda-forge/linux-64/libsqlite-3.50.1-hee588c1_0.conda + sha256: cd15ab1b9f0d53507e7ad7a01e52f6756ab3080bf623ab0e438973b6e4dba3c0 + md5: 96a7e36bff29f1d0ddf5b771e0da373a + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=13 + - libzlib >=1.3.1,<2.0a0 + license: Unlicense + size: 919819 + timestamp: 1749232795476 - conda: https://prefix.dev/conda-forge/linux-64/libssh2-1.11.1-hf672d98_0.conda sha256: 0407ac9fda2bb67e11e357066eff144c845801d00b5f664efbc48813af1e7bb9 md5: 
be2de152d8073ef1c01b7728475f2fe7 @@ -5047,6 +5415,15 @@ packages: license_family: MIT size: 85371 timestamp: 1737244781933 +- conda: https://prefix.dev/conda-forge/linux-64/libuuid-2.38.1-h0b41bf4_0.conda + sha256: 787eb542f055a2b3de553614b25f09eefb0a0931b0c87dbcce6efdfd92f04f18 + md5: 40b61aab5c7ba9ff276c41cfffe6b80b + depends: + - libgcc-ng >=12 + license: BSD-3-Clause + license_family: BSD + size: 33601 + timestamp: 1680112270483 - conda: https://prefix.dev/conda-forge/linux-64/libuv-1.50.0-hb9d3cd8_0.conda sha256: b4a8890023902aef9f1f33e3e35603ad9c2f16c21fdb58e968fa6c1bd3e94c0b md5: 771ee65e13bc599b0b62af5359d80169 @@ -5407,6 +5784,24 @@ packages: license_family: MIT size: 136487 timestamp: 1744445244122 +- conda: https://prefix.dev/conda-forge/linux-64/numpy-2.2.6-py312h72c5963_0.conda + sha256: c3b3ff686c86ed3ec7a2cc38053fd6234260b64286c2bd573e436156f39d14a7 + md5: 17fac9db62daa5c810091c2882b28f45 + depends: + - __glibc >=2.17,<3.0.a0 + - libblas >=3.9.0,<4.0a0 + - libcblas >=3.9.0,<4.0a0 + - libgcc >=13 + - liblapack >=3.9.0,<4.0a0 + - libstdcxx >=13 + - python >=3.12,<3.13.0a0 + - python_abi 3.12.* *_cp312 + constrains: + - numpy-base <0a0 + license: BSD-3-Clause + license_family: BSD + size: 8490501 + timestamp: 1747545073507 - conda: https://prefix.dev/conda-forge/linux-64/openssl-3.5.0-h7b32b05_0.conda sha256: 38285d280f84f1755b7c54baf17eccf2e3e696287954ce0adca16546b85ee62c md5: bb539841f2a3fde210f387d00ed4bb9d @@ -5516,6 +5911,16 @@ packages: license_family: Apache size: 1103840 timestamp: 1741889978401 +- conda: https://prefix.dev/conda-forge/noarch/packaging-25.0-pyh29332c3_1.conda + sha256: 289861ed0c13a15d7bbb408796af4de72c2fe67e2bcb0de98f4c3fce259d7991 + md5: 58335b26c38bf4a20f399384c33cbcf9 + depends: + - python >=3.8 + - python + license: Apache-2.0 + license_family: APACHE + size: 62477 + timestamp: 1745345660407 - conda: https://prefix.dev/conda-forge/linux-64/pcre2-10.44-hba22ea6_2.conda sha256: 
1087716b399dab91cc9511d6499036ccdc53eb29a288bebcb19cf465c51d7c0d md5: df359c09c41cd186fffb93a2d87aa6f5 @@ -5581,6 +5986,26 @@ packages: license: GPL-1.0-or-later OR Artistic-1.0-Perl size: 28889712 timestamp: 1703310809518 +- conda: https://prefix.dev/conda-forge/noarch/pip-25.1.1-pyh8b19718_0.conda + sha256: ebfa591d39092b111b9ebb3210eb42251be6da89e26c823ee03e5e838655a43e + md5: 32d0781ace05105cc99af55d36cbec7c + depends: + - python >=3.9,<3.13.0a0 + - setuptools + - wheel + license: MIT + license_family: MIT + size: 1242995 + timestamp: 1746249983238 +- conda: https://prefix.dev/conda-forge/noarch/pluggy-1.6.0-pyhd8ed1ab_0.conda + sha256: a8eb555eef5063bbb7ba06a379fa7ea714f57d9741fe0efdb9442dbbc2cccbcc + md5: 7da7ccd349dbf6487a7778579d2bb971 + depends: + - python >=3.9 + license: MIT + license_family: MIT + size: 24246 + timestamp: 1747339794916 - conda: https://prefix.dev/conda-forge/linux-64/prometheus-cpp-1.3.0-ha5d0236_0.conda sha256: 013669433eb447548f21c3c6b16b2ed64356f726b5f77c1b39d5ba17a8a4b8bc md5: a83f6a2fdc079e643237887a37460668 @@ -5621,6 +6046,89 @@ packages: license_family: MIT size: 173220 timestamp: 1730769371051 +- conda: https://prefix.dev/conda-forge/noarch/py-1.11.0-pyhd8ed1ab_1.conda + sha256: f2660eb121032dcbe1f3f5d53a120625698ca6602f32a2aba131bb1023286722 + md5: 9eb1496f8aa577322f293ee0c72983fd + depends: + - python >=3.9 + license: MIT + license_family: MIT + size: 80791 + timestamp: 1734003519402 +- conda: https://prefix.dev/conda-forge/noarch/pygments-2.19.1-pyhd8ed1ab_0.conda + sha256: 28a3e3161390a9d23bc02b4419448f8d27679d9e2c250e29849e37749c8de86b + md5: 232fb4577b6687b2d503ef8e254270c9 + depends: + - python >=3.9 + license: BSD-2-Clause + license_family: BSD + size: 888600 + timestamp: 1736243563082 +- conda: https://prefix.dev/conda-forge/noarch/pytest-8.4.0-pyhd8ed1ab_0.conda + sha256: f8c5a65ff4216f7c0a9be1708be1ee1446ad678da5a01eeb2437551156e32a06 + md5: 516d31f063ce7e49ced17f105b63a1f1 + depends: + - colorama >=0.4 + - 
exceptiongroup >=1 + - iniconfig >=1 + - packaging >=20 + - pluggy >=1.5,<2 + - pygments >=2.7.2 + - python >=3.9 + - tomli >=1 + constrains: + - pytest-faulthandler >=2 + license: MIT + license_family: MIT + size: 275014 + timestamp: 1748907618871 +- conda: https://prefix.dev/conda-forge/noarch/pytest-forked-1.6.0-pyhd8ed1ab_1.conda + sha256: df40eeea0e36ecd95fc9edf41e55dbcea28d92fcdb6aa53bcec51aaf5bdf4ae6 + md5: 2222c712a38755af8870642c17beabc1 + depends: + - py + - pytest >=3.10 + - python >=3.9 + license: MIT + license_family: MIT + size: 10953 + timestamp: 1734551290813 +- conda: https://prefix.dev/conda-forge/linux-64/python-3.12.11-h9e4cc4f_0_cpython.conda + sha256: 6cca004806ceceea9585d4d655059e951152fc774a471593d4f5138e6a54c81d + md5: 94206474a5608243a10c92cefbe0908f + depends: + - __glibc >=2.17,<3.0.a0 + - bzip2 >=1.0.8,<2.0a0 + - ld_impl_linux-64 >=2.36.1 + - libexpat >=2.7.0,<3.0a0 + - libffi >=3.4.6,<3.5.0a0 + - libgcc >=13 + - liblzma >=5.8.1,<6.0a0 + - libnsl >=2.0.1,<2.1.0a0 + - libsqlite >=3.50.0,<4.0a0 + - libuuid >=2.38.1,<3.0a0 + - libxcrypt >=4.4.36 + - libzlib >=1.3.1,<2.0a0 + - ncurses >=6.5,<7.0a0 + - openssl >=3.5.0,<4.0a0 + - readline >=8.2,<9.0a0 + - tk >=8.6.13,<8.7.0a0 + - tzdata + constrains: + - python_abi 3.12.* *_cp312 + license: Python-2.0 + size: 31445023 + timestamp: 1749050216615 +- conda: https://prefix.dev/conda-forge/noarch/python_abi-3.12-7_cp312.conda + build_number: 7 + sha256: a1bbced35e0df66cc713105344263570e835625c28d1bdee8f748f482b2d7793 + md5: 0dfcdc155cf23812a0c9deada86fb723 + constrains: + - python 3.12.* *_cpython + license: BSD-3-Clause + license_family: BSD + size: 6971 + timestamp: 1745258861359 - conda: https://prefix.dev/conda-forge/linux-64/re2-2024.07.02-h9925aae_3.conda sha256: 66d34e3b4881f856486d11914392c585713100ca547ccfc0947f3a4765c2c486 md5: 6f445fb139c356f903746b2b91bbe786 @@ -5657,6 +6165,16 @@ packages: license_family: BSD size: 220297 timestamp: 1741121702233 +- conda: 
https://prefix.dev/conda-forge/linux-64/readline-8.2-h8c095d6_2.conda + sha256: 2d6d0c026902561ed77cd646b5021aef2d4db22e57a5b0178dfc669231e06d2c + md5: 283b96675859b20a825f8fa30f311446 + depends: + - libgcc >=13 + - ncurses >=6.5,<7.0a0 + license: GPL-3.0-only + license_family: GPL + size: 282480 + timestamp: 1740379431762 - conda: https://prefix.dev/conda-forge/linux-64/rhash-1.4.5-hb9d3cd8_0.conda sha256: 04677caac29ec64a5d41d0cca8dbec5f60fa166d5458ff5a4393e4dc08a4799e md5: 9af0e7981755f09c81421946c4bcea04 @@ -5696,6 +6214,15 @@ packages: license_family: Apache size: 352907 timestamp: 1743805258946 +- conda: https://prefix.dev/conda-forge/noarch/setuptools-80.9.0-pyhff2d567_0.conda + sha256: 972560fcf9657058e3e1f97186cc94389144b46dbdf58c807ce62e83f977e863 + md5: 4de79c071274a53dcaf2a8c749d1499e + depends: + - python >=3.9 + license: MIT + license_family: MIT + size: 748788 + timestamp: 1748804951958 - conda: https://prefix.dev/conda-forge/osx-64/sigtool-0.1.3-h88f4db0_0.tar.bz2 sha256: 46fdeadf8f8d725819c4306838cdfd1099cd8fe3e17bd78862a5dfdcd6de61cf md5: fbfb84b9de9a6939cb165c02c69b1865 @@ -5788,6 +6315,36 @@ packages: license_family: MIT size: 207679 timestamp: 1725491499758 +- conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_hd72426e_102.conda + sha256: a84ff687119e6d8752346d1d408d5cf360dee0badd487a472aa8ddedfdc219e1 + md5: a0116df4f4ed05c303811a837d5b39d8 + depends: + - __glibc >=2.17,<3.0.a0 + - libgcc >=13 + - libzlib >=1.3.1,<2.0a0 + license: TCL + license_family: BSD + size: 3285204 + timestamp: 1748387766691 +- conda: https://prefix.dev/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda + sha256: 18636339a79656962723077df9a56c0ac7b8a864329eb8f847ee3d38495b863e + md5: ac944244f1fed2eb49bae07193ae8215 + depends: + - python >=3.9 + license: MIT + license_family: MIT + size: 19167 + timestamp: 1733256819729 +- conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.14.0-pyhe01879c_0.conda + sha256: 
8561db52f278c5716b436da6d4ee5521712a49e8f3c70fcae5350f5ebb4be41c + md5: 2adcd9bb86f656d3d43bf84af59a1faf + depends: + - python >=3.9 + - python + license: PSF-2.0 + license_family: PSF + size: 50978 + timestamp: 1748959427551 - conda: https://prefix.dev/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda sha256: 5aaa366385d716557e365f0a4e9c3fca43ba196872abbbe3d56bb610d131e192 md5: 4222072737ccff51314b5ece9c7d6f5a @@ -5853,6 +6410,15 @@ packages: license_family: MIT size: 219013 timestamp: 1719460515960 +- conda: https://prefix.dev/conda-forge/noarch/wheel-0.45.1-pyhd8ed1ab_1.conda + sha256: 1b34021e815ff89a4d902d879c3bd2040bc1bd6169b32e9427497fa05c55f1ce + md5: 75cb7132eb58d97896e173ef12ac9986 + depends: + - python >=3.9 + license: MIT + license_family: MIT + size: 62931 + timestamp: 1733130309598 - conda: https://prefix.dev/conda-forge/linux-64/zlib-1.3.1-hb9d3cd8_2.conda sha256: 5d7c0e5f0005f74112a34a7425179f4eb6e73c92f5d109e6af4ddeca407c92ab md5: c9f075ab2f33b3bbee9e62d4ad0a6cd8 diff --git a/pixi.toml b/pixi.toml index b72942d..12ddc67 100644 --- a/pixi.toml +++ b/pixi.toml @@ -130,8 +130,43 @@ coverage.cmd = [ coverage.depends-on = ["tests-ci"] coverage.cwd = "." +# CuPy tests +[feature.cupy-tests] +platforms = ["linux-64"] + +[feature.cupy-tests.dependencies] +python = ">=3.12.0,<3.13" +pip = "*" +setuptools = "*" +cupy = "*" +pytest = "*" +pytest-forked = "*" + +[feature.cupy-tests.tasks] +# Since CuPy tests are only available on Linux, we can use bash like +# this to only clone xsref if it isn't already there and checked out +# at the proper tag. +clone-xsref-test-cupy.cmd = """ +bash -c ' +if [ -d xsref ]; then + tag=$(git -C xsref describe --tags --exact-match 2>/dev/null || true) +fi +if [ \"$tag\" != v0.0.0 ]; then + rm -rf xsref + git clone --branch v0.0.0 --depth 1 https://github.com/scipy/xsref.git +fi +' +""" +clone-xsref-test-cupy.cwd = "." +install-xsref-test-cupy.cmd = "pip install ." 
+install-xsref-test-cupy.cwd = "xsref" +install-xsref-test-cupy.depends-on = ["clone-xsref-test-cupy"] +test-cupy.cmd = "pytest --forked python_tests/test_cupy.py" +test-cupy.cwd = "." +test-cupy.depends-on = ["install-xsref-test-cupy"] [environments] default = { features = ["build", "tests"], solve-group = "default" } tests-ci = { features = ["build", "tests", "tests-ci", "coverage"], solve-group = "default" } lint = { features = ["clang-format"], solve-group = "default" } +cupy-tests = { features = ["cupy-tests"], solve-group = "default" } diff --git a/python_tests/.gitignore b/python_tests/.gitignore new file mode 100644 index 0000000..2cd3f63 --- /dev/null +++ b/python_tests/.gitignore @@ -0,0 +1,195 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock +#poetry.toml + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. 
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Abstra +# Abstra is an AI-powered process automation framework. +# Ignore directories containing user credentials, local state, and settings. +# Learn more at https://abstra.io/docs +.abstra/ + +# Visual Studio Code +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore +# and can be added to the global gitignore or merged into this file. However, if you prefer, +# you could uncomment the following to ignore the entire vscode folder +# .vscode/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + +# Cursor +# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to +# exclude from AI features like autocomplete and code analysis. 
Recommended for sensitive data
+# refer to https://docs.cursor.com/context/ignore-files
+.cursorignore
+.cursorindexingignore
diff --git a/python_tests/test_cupy.py b/python_tests/test_cupy.py
new file mode 100644
index 0000000..3505645
--- /dev/null
+++ b/python_tests/test_cupy.py
@@ -0,0 +1,488 @@
+"""Test xsf functions work on GPU through CuPy.
+
+Beyond cupy and pytest, also requires xsref package to be installed.
+Should be run with pytest-isolate or pytest-forked to isolate tests
+in separate processes since memory corruption on GPU can cause
+failures to occur in unrelated tests.
+
+Run through pixi with ``pixi run test-cupy``; the GPU CI workflow invokes
+the same task.
+"""
+
+import os
+import numpy as np
+import polars as pl
+import pytest
+import shutil
+import tempfile
+
+from glob import glob
+from packaging.version import Version
+from pathlib import Path
+
+from xsref.float_tools import extended_relative_error
+from xsref.tables import get_input_rows, get_output_rows, get_in_out_types
+
+
+#------------------------------------------------------------------------------
+# Check if a module is present to be used in tests
+#
+# Copied from
+# https://github.com/scipy/scipy/blob/1cbfa1c894557041d9825d6754a2c48fc3bec484/scipy/special/_testutils.py
+#------------------------------------------------------------------------------
+
+class MissingModule:
+    """Placeholder for an optional module that could not be imported."""
+
+    def __init__(self, name):
+        self.name = name
+
+
+def check_version(module, min_ver):
+    """Return a skip mark unless ``module`` is installed at >= ``min_ver``."""
+    if type(module) is MissingModule:
+        return pytest.mark.skip(reason=f"{module.name} is not installed")
+    return pytest.mark.skipif(
+        Version(module.__version__) < Version(min_ver),
+        reason=f"{module.__name__} version >= {min_ver} required"
+    )
+
+#------------------------------------------------------------------------------
+
+try:
+    import cupy  # type: ignore
+except (ImportError, AttributeError):
+    cupy = MissingModule('cupy')
+
+
+@pytest.fixture(scope="function", autouse=True)
+def manage_cupy_cache():
+    """Point ``CUPY_CACHE_DIR`` at a throwaway directory for each test.
+
+    This keeps the regular kernel cache from being polluted by these tests.
+    The previous setting is restored and the temporary cache removed in
+    teardown.
+    """
+    temp_cache_dir = tempfile.mkdtemp()
+    original_cache_dir = os.environ.get('CUPY_CACHE_DIR', None)
+    os.environ['CUPY_CACHE_DIR'] = temp_cache_dir
+
+    yield
+
+    if original_cache_dir is not None:
+        os.environ['CUPY_CACHE_DIR'] = original_cache_dir
+    else:
+        del os.environ['CUPY_CACHE_DIR']
+    shutil.rmtree(temp_cache_dir)
+
+
+def _get_cols_helper(table_path, xp):
+    """Load a reference table and return its columns as ``xp`` arrays.
+
+    The table kind is taken from the file name prefix (``In_``, ``Out_`` or
+    ``Err_``, matched case-insensitively). A single-column table is returned
+    as a bare array rather than a 1-tuple.
+    """
+    # This is more complicated than need be due to an oversight in xsref
+    # functions for getting tables. The table metadata for input, output,
+    # and err tables has the input and output types, but does not have
+    # the column types of the current table, and does not contain the info
+    # of whether current table is an input, output, or err table. It also
+    # provides separate functions to get input rows from an input table,
+    # output rows from an output table, and provides no function to get
+    # rows from an err table (though this last task is easy to do because
+    # there are no complex types involved in an err table).
+    # TODO: Update xsref tables to contain the table's column types
+    # in the metadata, and provide a single function for getting rows
+    # from tables as lists of tuples that uses this metadata.
+    table_name = table_path.name.lower()
+    if table_name.startswith("in_"):
+        rows = get_input_rows(table_path)
+    elif table_name.startswith("out_"):
+        rows = get_output_rows(table_path)
+    elif table_name.startswith("err_"):
+        table = pl.read_parquet(table_path)
+        rows = table.to_numpy()
+    else:
+        # Fail loudly rather than hit an unbound ``rows`` below.
+        raise ValueError(f"Unrecognized table name: {table_path.name}")
+    result = tuple(xp.asarray(col) for col in zip(*rows))
+    if len(result) == 1:
+        result = result[0]
+    return result
+
+
+def get_cols_as_cupy(table_path):
+    """Return the columns of the table at ``table_path`` as CuPy arrays."""
+    return _get_cols_helper(table_path, cupy)
+
+
+def get_cols_as_numpy(table_path):
+    """Return the columns of the table at ``table_path`` as NumPy arrays."""
+    return _get_cols_helper(table_path, np)
+
+
+HERE = Path(__file__)
+
+
+def get_tables_for_func(func_name):
+    """Return ``(input, output, err)`` table path triples for ``func_name``.
+
+    Tables are looked up in ``xsref/tables/scipy_special_tests/<func_name>``
+    two directory levels above this file. Input tables are sorted so that
+    the parametrized test order does not depend on directory listing order.
+    """
+    tables_path = (
+        HERE.parent.parent.resolve() / "xsref" / "tables" / "scipy_special_tests"
+    )
+    tables_path /= func_name
+    input_tables = sorted(tables_path.glob("In_*.parquet"))
+    output_tables = [
+        path.parent / path.name.replace("In_", "Out_") for path in input_tables
+    ]
+    err_tables = []
+    for path in input_tables:
+        types = path.name.removesuffix(".parquet").replace("In_", "")
+        name = f"Err_{types}_other.parquet"
+        err_tables.append(path.parent / name)
+    return list(zip(input_tables, output_tables, err_tables))
+
+
+def get_preamble(header):
+    """Return an ``#include`` line for ``header`` under the include dir."""
+    header_path = (HERE.parent.parent / "include" / Path(header)).resolve()
+    return f'#include "{header_path}"'
+
+
+@pytest.mark.usefixtures("manage_cupy_cache")
+@check_version(cupy, "13.0.0")
+class TestCuPy:
+    """Compile xsf functions as CuPy ufuncs and compare to xsref tables.
+
+    Each test builds a ufunc from an xsf header, evaluates it on an input
+    table, and checks the extended relative error against the reference
+    output using the tolerances from the matching err table.
+    """
+
+    def _adjust_tol(self, tol, *, wiggle=16):
+        """Scale ``tol`` by ``wiggle`` after flooring it at machine epsilon."""
+        return wiggle * np.maximum(tol, np.finfo(tol.dtype).eps)
+
+    @pytest.mark.parametrize(
+        "tables_paths", get_tables_for_func("beta")
+    )
+    def test_beta(self, tables_paths):
+        input_path, output_path, tol_path = tables_paths
+        beta = cupy._core.create_ufunc(
+            "cupyx_scipy_beta",
+            ("ll->d", "LL->d", "ee->d", "ff->f", "dd->d"),
+            "out0 = out0_type(xsf::beta(in0, in1));",
+            preamble=get_preamble("xsf/beta.h"),
+        )
+
+        a, b = get_cols_as_cupy(input_path)
+        out = cupy.asnumpy(beta(a, b))
+
+        desired = get_cols_as_numpy(output_path)
+        rtol = get_cols_as_numpy(tol_path)
+        assert np.all(
+            extended_relative_error(out, desired) <= self._adjust_tol(rtol)
+        )
+
+    @pytest.mark.parametrize(
+        "tables_paths", get_tables_for_func("binom")
+    )
+    def test_binom(self, tables_paths):
+        input_path, output_path, tol_path = tables_paths
+        binom = cupy._core.create_ufunc(
+            "cupyx_scipy_binom",
+            ("ff->f", "dd->d"),
+            "out0 = out0_type(xsf::binom(in0, in1));",
+            preamble=get_preamble("xsf/binom.h"),
+        )
+
+        n, k = get_cols_as_cupy(input_path)
+        out = cupy.asnumpy(binom(n, k))
+
+        desired = get_cols_as_numpy(output_path)
+        rtol = get_cols_as_numpy(tol_path)
+        assert np.all(
+            extended_relative_error(out, desired) <= self._adjust_tol(rtol)
+        )
+
+    @pytest.mark.parametrize(
+        "tables_paths", get_tables_for_func("digamma")
+    )
+    def test_digamma(self, tables_paths):
+        input_path, output_path, tol_path = tables_paths
+        digamma = cupy._core.create_ufunc(
+            'cupyx_scipy_special_digamma',
+            (
+                ('l->d', 'out0 = xsf::digamma(double(in0))'),
+                ('e->d', 'out0 = xsf::digamma(double(in0))'),
+                'f->f',
+                'd->d',
+                'F->F',
+                'D->D',
+            ),
+            'out0 = xsf::digamma(in0)',
+            preamble=get_preamble("xsf/digamma.h")
+        )
+
+        x = get_cols_as_cupy(input_path)
+        out = cupy.asnumpy(digamma(x))
+
+        desired = get_cols_as_numpy(output_path)
+        rtol = get_cols_as_numpy(tol_path)
+        assert np.all(
+            extended_relative_error(out, desired) <= self._adjust_tol(rtol)
+        )
+
+    @pytest.mark.parametrize(
+        "tables_paths", get_tables_for_func("ellipkinc")
+    )
+    def test_ellipkinc(self, tables_paths):
+        input_path, output_path, tol_path = tables_paths
+        ellipkinc = cupy._core.create_ufunc(
+            'cupyx_scipy_special_ellipkinc', ('ff->f', 'dd->d'),
+            'out0 = xsf::cephes::ellik(in0, in1)',
+            preamble=get_preamble("xsf/cephes/ellik.h"),
+        )
+
+        phi, m = get_cols_as_cupy(input_path)
+        out = cupy.asnumpy(ellipkinc(phi, m))
+
+        desired = get_cols_as_numpy(output_path)
+        rtol = get_cols_as_numpy(tol_path)
+        assert np.all(
+            extended_relative_error(out, desired) <= self._adjust_tol(rtol)
+        )
+
+    @pytest.mark.parametrize(
+        "tables_paths", get_tables_for_func("ellipeinc")
+    )
+    def test_ellipeinc(self, tables_paths):
+        input_path, output_path, tol_path = tables_paths
+        ellipeinc = cupy._core.create_ufunc(
+            'cupyx_scipy_special_ellipeinc', ('ff->f', 'dd->d'),
+            'out0 = xsf::cephes::ellie(in0, in1)',
+            preamble=get_preamble("xsf/cephes/ellie.h"),
+        )
+
+        phi, m = get_cols_as_cupy(input_path)
+        out = cupy.asnumpy(ellipeinc(phi, m))
+
+        desired = get_cols_as_numpy(output_path)
+        rtol = get_cols_as_numpy(tol_path)
+        assert np.all(
+            extended_relative_error(out, desired) <= self._adjust_tol(rtol)
+        )
+
+    @pytest.mark.parametrize(
+        "tables_paths", get_tables_for_func("expn")
+    )
+    def test_expn(self, tables_paths):
+        input_path, output_path, tol_path = tables_paths
+        expn = cupy._core.create_ufunc(
+            'cupyx_scipy_special_expn',
+            ('ff->f', 'dd->d'),
+            'out0 = xsf::cephes::expn(in0, in1)',
+            preamble=get_preamble("xsf/cephes/expn.h"),
+        )
+
+        x, n = get_cols_as_cupy(input_path)
+        out = cupy.asnumpy(expn(x, n))
+
+        desired = get_cols_as_numpy(output_path)
+        rtol = get_cols_as_numpy(tol_path)
+        assert np.all(
+            extended_relative_error(out, desired) <= self._adjust_tol(rtol)
+        )
+
+    @pytest.mark.parametrize(
+        "tables_paths", get_tables_for_func("gdtrib")
+    )
+    @pytest.mark.xfail(reason="Requires cpp_std >= 17")
+    def test_gdtrib(self, tables_paths):
+        input_path, output_path, tol_path = tables_paths
+        gdtrib = cupy._core.create_ufunc(
+            'cupyx_scipy_special_gdtrib',
+            ('fff->f', 'ddd->d'),
+            'out0 = xsf::gdtrib(in0, in1, in2)',
+            preamble=get_preamble("xsf/cdflib.h"),
+        )
+
+        a, p, x = get_cols_as_cupy(input_path)
+        out = cupy.asnumpy(gdtrib(a, p, x))
+
+        desired = get_cols_as_numpy(output_path)
+        rtol = get_cols_as_numpy(tol_path)
+        assert np.all(
+            extended_relative_error(out, desired) <= self._adjust_tol(rtol)
+        )
+
+    @pytest.mark.parametrize(
+        "tables_paths", get_tables_for_func("hyp2f1")
+    )
+    def test_hyp2f1(self, tables_paths):
+        input_path, output_path, tol_path = tables_paths
+        hyp2f1 = cupy._core.create_ufunc(
+            'cupyx_scipy_special_hyp2f1',
+            ('ffff->f', 'dddd->d', 'fffF->F', 'dddD->D'),
+            'out0 = xsf::hyp2f1(in0, in1, in2, in3)',
+            preamble=get_preamble("xsf/hyp2f1.h"),
+        )
+
+        a, b, c, z = get_cols_as_cupy(input_path)
+        if not cupy.iscomplexobj(z):
+            pytest.xfail(
+                "Real valued hyp2f1 currently broken on GPU due to use of"
+                " recursion."
+            )
+        out = cupy.asnumpy(hyp2f1(a, b, c, z))
+
+        desired = get_cols_as_numpy(output_path)
+        rtol = get_cols_as_numpy(tol_path)
+        assert np.all(
+            extended_relative_error(out, desired) <= self._adjust_tol(rtol)
+        )
+
+    @pytest.mark.parametrize(
+        "tables_paths", get_tables_for_func("lambertw")
+    )
+    def test_lambertw(self, tables_paths):
+        input_path, output_path, tol_path = tables_paths
+        _lambertw_scalar = cupy._core.create_ufunc(
+            "cupyx_scipy_lambertw_scalar",
+            # NOTE(review): "Fif->f" has a real output type; should it be "F"?
+            ("Dld->D", "Fif->f"),
+            "out0 = xsf::lambertw(in0, in1, in2)",
+            preamble=get_preamble("xsf/lambertw.h"),
+        )
+
+        # A parameter called tol, not to be confused with the rtol for assessing
+        # accuracy.
+        z, k, tol = get_cols_as_cupy(input_path)
+        out = cupy.asnumpy(_lambertw_scalar(z, k, tol))
+
+        desired = get_cols_as_numpy(output_path)
+        rtol = get_cols_as_numpy(tol_path)
+        assert np.all(
+            extended_relative_error(out, desired) <= self._adjust_tol(rtol)
+        )
+
+    @pytest.mark.parametrize(
+        "tables_paths", get_tables_for_func("sici")
+    )
+    def test_sici(self, tables_paths):
+        input_path, output_path, tol_path = tables_paths
+        sici = cupy._core.create_ufunc(
+            'cupyx_scipy_special_sici',
+            (
+                (
+                    'f->ff',
+                    '''
+                    float si, ci;
+                    xsf::sici(in0, si, ci);
+                    out0 = si; out1 = ci;
+                    ''',
+                ),
+                (
+                    'd->dd',
+                    '''
+                    double si, ci;
+                    xsf::sici(in0, si, ci);
+                    out0 = si; out1 = ci;
+                    ''',
+                ),
+                (
+                    'F->FF',
+                    '''
+                    complex<float> si, ci;
+                    xsf::sici(in0, si, ci);
+                    out0 = si; out1 = ci;
+                    ''',
+                ),
+                (
+                    'D->DD',
+                    '''
+                    complex<double> si, ci;
+                    xsf::sici(in0, si, ci);
+                    out0 = si; out1 = ci;
+                    ''',
+                ),
+            ),
+            preamble=get_preamble("xsf/sici.h"),
+        )
+
+        x = get_cols_as_cupy(input_path)
+        if cupy.iscomplexobj(x):
+            pytest.xfail("Known bug, returning nan instead of a complex infinity.")
+        out0, out1 = map(cupy.asnumpy, sici(x))
+
+        desired0, desired1 = get_cols_as_numpy(output_path)
+        rtol0, rtol1 = get_cols_as_numpy(tol_path)
+        assert np.all(
+            extended_relative_error(out0, desired0) <= self._adjust_tol(rtol0)
+        )
+        assert np.all(
+            extended_relative_error(out1, desired1) <= self._adjust_tol(rtol1)
+        )
+
+    @pytest.mark.parametrize(
+        "tables_paths", get_tables_for_func("shichi")
+    )
+    def test_shichi(self, tables_paths):
+        input_path, output_path, tol_path = tables_paths
+        shichi = cupy._core.create_ufunc(
+            'cupyx_scipy_special_shichi',
+            (
+                (
+                    'f->ff',
+                    '''
+                    float shi, chi;
+                    xsf::shichi(in0, shi, chi);
+                    out0 = shi; out1 = chi;
+                    ''',
+                ),
+                (
+                    'd->dd',
+                    '''
+                    double shi, chi;
+                    xsf::shichi(in0, shi, chi);
+                    out0 = shi; out1 = chi;
+                    ''',
+                ),
+                (
+                    'F->FF',
+                    '''
+                    complex<float> shi, chi;
+                    xsf::shichi(in0, shi, chi);
+                    out0 = shi; out1 = chi;
+                    ''',
+                ),
+                (
+                    'D->DD',
+                    '''
+                    complex<double> shi, chi;
+                    xsf::shichi(in0, shi, chi);
+                    out0 = shi; out1 = chi;
+                    ''',
+                ),
+            ),
+            preamble=get_preamble("xsf/sici.h"),
+        )
+
+        x = get_cols_as_cupy(input_path)
+        if cupy.iscomplexobj(x):
+            pytest.xfail("Known bug, returning nan instead of a complex infinity.")
+        out0, out1 = map(cupy.asnumpy, shichi(x))
+
+        desired0, desired1 = get_cols_as_numpy(output_path)
+        rtol0, rtol1 = get_cols_as_numpy(tol_path)
+        assert np.all(
+            extended_relative_error(out0, desired0) <= self._adjust_tol(rtol0)
+        )
+        assert np.all(
+            extended_relative_error(out1, desired1) <= self._adjust_tol(rtol1)
+        )