Skip to content

Commit ef8995e

Browse files
committed
Fix JAX GPU tests.
Using XLA_PYTHON_CLIENT_MEM_FRACTION.
1 parent e4bca84 commit ef8995e

File tree

3 files changed

+23
-0
lines changed

3 files changed

+23
-0
lines changed

.kokoro/github/ubuntu/gpu/build.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ python3 --version
1818
# Check cuda
1919
nvidia-smi
2020
nvcc --version
21+
echo "LD_LIBRARY_PATH before ${LD_LIBRARY_PATH}"
2122

2223
cd "src/github/keras"
2324
pip install -U pip setuptools
@@ -50,6 +51,8 @@ then
5051
# Raise error if GPU is not detected.
5152
python3 -c 'import jax;assert jax.default_backend().lower() == "gpu"'
5253

54+
echo "LD_LIBRARY_PATH after ${LD_LIBRARY_PATH}"
55+
5356
# TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted
5457
# TODO: keras/trainers/data_adapters/py_dataset_adapter_test.py::PyDatasetAdapterTest::test_basic_flow0 Fatal Python error: Aborted
5558
# keras/backend/jax/distribution_lib_test.py is configured for CPU test for now.

.kokoro/github/ubuntu/gpu/jax/continuous.cfg

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,15 @@ env_vars: {
1212
value: "jax"
1313
}
1414

15+
env_vars: {
16+
key: "XLA_PYTHON_CLIENT_MEM_FRACTION"
17+
value: ".5"
18+
}
19+
20+
env_vars: {
21+
key: "JAX_TRACEBACK_FILTERING"
22+
value: "off"
23+
}
24+
1525
# Set timeout to 60 mins from default 180 mins
1626
timeout_mins: 60

.kokoro/github/ubuntu/gpu/jax/presubmit.cfg

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,15 @@ env_vars: {
1212
value: "jax"
1313
}
1414

15+
env_vars: {
16+
key: "XLA_PYTHON_CLIENT_MEM_FRACTION"
17+
value: ".5"
18+
}
19+
20+
env_vars: {
21+
key: "JAX_TRACEBACK_FILTERING"
22+
value: "off"
23+
}
24+
1525
# Set timeout to 60 mins from default 180 mins
1626
timeout_mins: 60

0 commit comments

Comments
 (0)