File tree Expand file tree Collapse file tree 3 files changed +8
-2
lines changed
.kokoro/github/ubuntu/gpu Expand file tree Collapse file tree 3 files changed +8
-2
lines changed Original file line number Diff line number Diff line change @@ -18,6 +18,7 @@ python3 --version
18
18
# Check cuda
19
19
nvidia-smi
20
20
nvcc --version
21
+ echo " LD_LIBRARY_PATH before ${LD_LIBRARY_PATH} "
21
22
22
23
cd " src/github/keras"
23
24
pip install -U pip setuptools
43
44
44
45
if [ " $KERAS_BACKEND " == " jax" ]
45
46
then
47
+ export XLA_PYTHON_CLIENT_MEM_FRACTION=.5
48
+ export JAX_TRACEBACK_FILTERING=off
49
+
46
50
echo " JAX backend detected."
47
51
pip install -r requirements-jax-cuda.txt --progress-bar off --timeout 1000
48
52
pip uninstall -y keras keras-nightly
49
53
python3 -c ' import jax;print(jax.__version__);print(jax.default_backend())'
50
54
# Raise error if GPU is not detected.
51
55
python3 -c ' import jax;assert jax.default_backend().lower() == "gpu"'
52
56
57
+ echo " LD_LIBRARY_PATH after ${LD_LIBRARY_PATH} "
58
+
53
59
# TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted
54
60
# TODO: keras/trainers/data_adapters/py_dataset_adapter_test.py::PyDatasetAdapterTest::test_basic_flow0 Fatal Python error: Aborted
55
61
# keras/backend/jax/distribution_lib_test.py is configured for CPU test for now.
Original file line number Diff line number Diff line change 7
7
}
8
8
}
9
9
10
- env_vars: {
10
+ env_vars {
11
11
key: " KERAS_BACKEND"
12
12
value: " jax"
13
13
}
Original file line number Diff line number Diff line change 7
7
}
8
8
}
9
9
10
- env_vars: {
10
+ env_vars {
11
11
key: " KERAS_BACKEND"
12
12
value: " jax"
13
13
}
You can’t perform that action at this time.
0 commit comments