Commit 8a5bef5

Fix JAX GPU tests.
Using XLA_PYTHON_CLIENT_MEM_FRACTION.
1 parent e4bca84 commit 8a5bef5

3 files changed: 25 additions, 2 deletions

.kokoro/github/ubuntu/gpu/build.sh

Lines changed: 3 additions & 0 deletions

@@ -18,6 +18,7 @@ python3 --version
 # Check cuda
 nvidia-smi
 nvcc --version
+echo "LD_LIBRARY_PATH before ${LD_LIBRARY_PATH}"
 
 cd "src/github/keras"
 pip install -U pip setuptools
@@ -50,6 +51,8 @@ then
 # Raise error if GPU is not detected.
 python3 -c 'import jax;assert jax.default_backend().lower() == "gpu"'
 
+echo "LD_LIBRARY_PATH after ${LD_LIBRARY_PATH}"
+
 # TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted
 # TODO: keras/trainers/data_adapters/py_dataset_adapter_test.py::PyDatasetAdapterTest::test_basic_flow0 Fatal Python error: Aborted
 # keras/backend/jax/distribution_lib_test.py is configured for CPU test for now.
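
The two added `echo` lines log `LD_LIBRARY_PATH` at two points in the script. A minimal sketch of the same logging as a helper follows, assuming the script could run under `set -u` (the diff does not show whether it does), in which case expanding an unset variable would abort the build. `log_ld_path` is a hypothetical name, not part of the commit.

```shell
#!/usr/bin/env bash
set -u

# Hypothetical helper (not in the commit): print LD_LIBRARY_PATH with a label.
# Falls back to "<unset>" when the variable is unset or empty, so `set -u`
# does not abort the script.
log_ld_path() {
  local label="$1"
  echo "LD_LIBRARY_PATH ${label} ${LD_LIBRARY_PATH:-<unset>}"
}

log_ld_path "before"
# ... install steps would run here ...
log_ld_path "after"
```

Calling the helper before and after the install steps gives the same two log lines as the commit, with an explicit marker when the variable is missing.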

.kokoro/github/ubuntu/gpu/jax/continuous.cfg

Lines changed: 11 additions & 1 deletion

@@ -7,10 +7,20 @@ action {
 }
 }
 
-env_vars: {
+env_vars {
 key: "KERAS_BACKEND"
 value: "jax"
 }
 
+env_vars {
+key: "XLA_PYTHON_CLIENT_MEM_FRACTION"
+value: "0.5"
+}
+
+env_vars {
+key: "JAX_TRACEBACK_FILTERING"
+value: "off"
+}
+
 # Set timeout to 60 mins from default 180 mins
 timeout_mins: 60

.kokoro/github/ubuntu/gpu/jax/presubmit.cfg

Lines changed: 11 additions & 1 deletion

@@ -7,10 +7,20 @@ action {
 }
 }
 
-env_vars: {
+env_vars {
 key: "KERAS_BACKEND"
 value: "jax"
 }
 
+env_vars {
+key: "XLA_PYTHON_CLIENT_MEM_FRACTION"
+value: "0.5"
+}
+
+env_vars {
+key: "JAX_TRACEBACK_FILTERING"
+value: "off"
+}
+
 # Set timeout to 60 mins from default 180 mins
 timeout_mins: 60
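
To reproduce these settings outside Kokoro, the three `env_vars` entries can be exported in a shell before running the tests. This is a sketch, not part of the commit. `XLA_PYTHON_CLIENT_MEM_FRACTION` sets the fraction of GPU memory that the JAX/XLA client preallocates, and `JAX_TRACEBACK_FILTERING=off` disables JAX's traceback filtering so internal frames appear in error output.

```shell
# Mirror the env_vars from the Kokoro JAX GPU configs in a local shell.
export KERAS_BACKEND="jax"
# Preallocate half of the GPU memory instead of JAX's default fraction.
export XLA_PYTHON_CLIENT_MEM_FRACTION="0.5"
# Show full, unfiltered JAX tracebacks.
export JAX_TRACEBACK_FILTERING="off"

# Print the resulting settings so they show up in the job log.
env | grep -E '^(KERAS_BACKEND|XLA_PYTHON_CLIENT_MEM_FRACTION|JAX_TRACEBACK_FILTERING)=' | sort
```

The variables must be set before the Python process imports `jax`, since the memory fraction is read when the GPU client is initialized.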
