tests/unit/test_kernels.py::TestPeriodicKernel::test_gradients[True-True-divergence_x_grad_y] fails with large numerical errors on Jax 0.5.0 #1003

Open
rg936672 opened this issue Mar 21, 2025 · 0 comments
Labels
bug Something isn't working

@rg936672 (Contributor)

What's the problem?

As the title says, the test fails on Jax 0.5.0 with a large numerical error: the autodiff result (-102.96) disagrees with the analytic reference (-17.63) by roughly a factor of six (full log below).

How can we reproduce the issue?

  1. Remove the conditional skip on the test (introduced in Re-enable Jax 0.5.x #999).
  2. Run `pytest .\tests\unit\test_kernels.py::TestPeriodicKernel::test_gradients[True-True-divergence_x_grad_y]` (a sketch of what this parametrisation computes follows these steps).
  3. Observe the failure below.
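For reference, here is a minimal standalone sketch (not coreax code) of the quantity under test. The kernel parameters are copied from the failing test instance in the log below, but the textbook exp-sine-squared form and the evaluation points are assumptions for illustration and may not match coreax's exact parameterisation. The autodiff value can then be checked against an independent finite-difference estimate:

```python
import jax
import jax.numpy as jnp

# Parameters from the failing PeriodicKernel instance in the log below.
length_scale = 0.33313825726509094
output_scale = 0.23848214745521545
periodicity = 0.5968074202537537

def periodic_kernel(x, y):
    # Textbook exp-sine-squared periodic kernel on scalar inputs
    # (assumed form; coreax's parameterisation may differ).
    return output_scale * jnp.exp(
        -2.0 * jnp.sin(jnp.pi * jnp.abs(x - y) / periodicity) ** 2
        / length_scale**2
    )

# For scalar inputs, "divergence_x_grad_y" reduces to d^2 k / (dx dy).
autodiff_div = jax.grad(jax.grad(periodic_kernel, argnums=1), argnums=0)

x, y = 0.3, 0.9  # arbitrary evaluation points
print(autodiff_div(x, y))

# Independent check: mixed second derivative via central differences.
eps = 1e-4
fd = (
    periodic_kernel(x + eps, y + eps)
    - periodic_kernel(x + eps, y - eps)
    - periodic_kernel(x - eps, y + eps)
    + periodic_kernel(x - eps, y - eps)
) / (4.0 * eps**2)
print(fd)
```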

Python version

3.13

Package version

0.4.0

Operating system

Windows

Other packages

None; this is with a `uv sync` from the lockfile as of c11cd38e0559c4e4d042e75ce7a704adda5c0f6c.

Relevant log output

============================= test session starts =============================
platform win32 -- Python 3.13.2, pytest-8.3.5, pluggy-1.5.0
rootdir: C:\Users\rg936672\dev\coreax
configfile: pyproject.toml
plugins: anyio-4.9.0, jaxtyping-0.2.38, cov-6.0.0, rerunfailures-15.0, xdist-3.6.1
collected 1 item

tests\unit\test_kernels.py F                                             [100%]

================================== FAILURES ===================================
______ TestPeriodicKernel.test_gradients[True-True-divergence_x_grad_y] _______

self = <tests.unit.test_kernels.TestPeriodicKernel object at 0x0000022FD1CE7F20>
gradient_problem = (array([[0.1271346 , 0.82849865],
       [0.1090742 , 0.67314784],
       [0.02745512, 0.416371  ],
       [0.89221347...2806, 0.42603548],
       [0.34778401, 0.92303771],
       [0.01197228, 0.48235223],
       [0.90783867, 0.47335769]]))
kernel = PeriodicKernel(
  length_scale=0.33313825726509094,
  output_scale=0.23848214745521545,
  periodicity=0.5968074202537537
)
mode = 'divergence_x_grad_y', elementwise = True, auto_diff = True

    @pytest.mark.parametrize("mode", ["grad_x", "grad_y", "divergence_x_grad_y"])
    @pytest.mark.parametrize("elementwise", [False, True])
    @pytest.mark.parametrize("auto_diff", [False, True])
    def test_gradients(
        self,
        gradient_problem: tuple[Array, Array],
        kernel: _ScalarValuedKernel,
        mode: Literal["grad_x", "grad_y", "divergence_x_grad_y"],
        elementwise: bool,
        auto_diff: bool,
    ):
        """Test computation of the kernel gradients."""
    
        x, y = gradient_problem
        test_mode = mode
        reference_mode = "expected_" + mode
        if elementwise:
            test_mode += "_elementwise"
            x, y = x[:, 0], y[:, 0]
        expected_output = getattr(self, reference_mode)(x, y, kernel)
        if elementwise:
            expected_output = expected_output.squeeze()
        if auto_diff:
            if isinstance(kernel, (AdditiveKernel, ProductKernel, PowerKernel)):
                pytest.skip(
                    "Autodiff of Additive and Product kernels is tested implicitly."
                )
            # Access overridden parent methods that use auto-differentiation
            autodiff_kernel = super(type(kernel), kernel)
            output = getattr(autodiff_kernel, test_mode)(x, y)
        else:
            output = getattr(kernel, test_mode)(x, y)
>       np.testing.assert_allclose(output, expected_output, atol=1e-3, rtol=1e-4)

tests\unit\test_kernels.py:221: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = (<function assert_allclose.<locals>.compare at 0x0000022FD2597F60>, array(-102.96228, dtype=float32), array(-17.63162294))
kwds = {'equal_nan': True, 'err_msg': '', 'header': 'Not equal to tolerance rtol=0.0001, atol=0.001', 'strict': False, ...}

    @wraps(func)
    def inner(*args, **kwds):
        with self._recreate_cm():
>           return func(*args, **kwds)
E           AssertionError: 
E           Not equal to tolerance rtol=0.0001, atol=0.001
E           
E           Mismatched elements: 1 / 1 (100%)
E           Max absolute difference among violations: 85.33065733
E           Max relative difference among violations: 4.83963715
E            ACTUAL: array(-102.96228, dtype=float32)
E            DESIRED: array(-17.631623)

..\..\AppData\Roaming\uv\python\cpython-3.13.2-windows-x86_64-none\Lib\contextlib.py:85: AssertionError
=========================== short test summary info ===========================
FAILED tests/unit/test_kernels.py::TestPeriodicKernel::test_gradients[True-True-divergence_x_grad_y]
============================== 1 failed in 2.43s ==============================
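For context on the `auto_diff=True` path in the traceback: `super(type(kernel), kernel)` returns a proxy that starts method lookup at the parent class, so the generic autodiff implementation runs instead of the subclass's closed-form override. A minimal sketch of that dispatch pattern (class and method names here are hypothetical, not coreax's):

```python
class BaseKernel:
    def grad_x_elementwise(self, x, y):
        return "autodiff fallback"

class PeriodicLike(BaseKernel):
    def grad_x_elementwise(self, x, y):
        return "closed-form override"

k = PeriodicLike()
print(k.grad_x_elementwise(0.0, 0.0))                  # closed-form override
print(super(type(k), k).grad_x_elementwise(0.0, 0.0))  # autodiff fallback
```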