I am trying to install JAX with GPU support on a powerful, dedicated Linux server, but I am stuck in what feels like a Catch-22: every official installation method I have tried fails in a different way, and each one ends with JAX falling back to the CPU.
I am looking for a definitive, foolproof set of commands to get a working GPU installation.
System Specifications:
- OS: Ubuntu 18.04 LTS
- GPU: 8x NVIDIA Quadro RTX 8000
- NVIDIA Driver: 550.144.03
- CUDA Version (reported by driver): 12.4
- Python: 3.10 (managed by Conda)
What I Have Tried
I created a fresh conda environment for each attempt to rule out conflicts between them.
Attempt #1: The Standard Recommended Method
This is the officially recommended command:
conda create -n jax_test python=3.10 -y
conda activate jax_test
pip install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --no-cache-dir
- Expected Result: A large, multi-gigabyte `jaxlib` wheel with the CUDA libraries included is downloaded and installed.
- Actual Result: `pip` consistently ignores the `[cuda12_pip]` extra, downloads the small CPU-only `jaxlib` wheel (89.9 MB), and prints a warning. The verification command confirms the failure:
WARNING: jax 0.6.2 does not provide the extra 'cuda12-pip'
Downloading jaxlib-0.6.2-cp310-cp310-manylinux2014_x86_64.whl (89.9 MB)
...
$ python -c "import jax; print(jax.devices())"
WARNING: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]
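The warning in this log is pip reporting that the installed `jax` distribution does not declare an extra named `cuda12-pip` (pip normalizes `cuda12_pip` to that spelling); pip then ignores the extra and installs only the base dependencies, which matches the CPU-only download. A minimal sketch for listing which extras the installed `jax` actually declares, using only the standard library's `importlib.metadata`. The helper names `normalize_extra`, `declared_extras` and `provides_extra` are hypothetical, and the normalization rule assumed is the PEP 685 one (lowercase, runs of `-`, `_` and `.` collapsed to a single `-`).

```python
import re
from importlib import metadata
from typing import List


def normalize_extra(name: str) -> str:
    """Normalize an extra name as pip reports it (assumed PEP 685 rule):
    lowercase, with runs of '-', '_' and '.' collapsed to a single '-'."""
    return re.sub(r"[-_.]+", "-", name).lower()


def declared_extras(dist_name: str) -> List[str]:
    """Return the normalized extras an installed distribution declares in its
    'Provides-Extra' metadata, or an empty list if it is not installed."""
    try:
        meta = metadata.metadata(dist_name)
    except metadata.PackageNotFoundError:
        return []
    raw = meta.get_all("Provides-Extra") or []  # get_all returns None if absent
    return sorted({normalize_extra(extra) for extra in raw})


def provides_extra(dist_name: str, extra: str) -> bool:
    """True if the installed distribution declares the given extra."""
    return normalize_extra(extra) in declared_extras(dist_name)


if __name__ == "__main__":
    # Run inside the conda environment from Attempt #1.
    print("jax declares [cuda12_pip]:", provides_extra("jax", "cuda12_pip"))
    print("extras declared by jax:", declared_extras("jax"))
```

Any extra missing from that list is one pip will skip with exactly the kind of warning shown above.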
Attempt #2: The Direct URL Method
This is a common workaround that forces installation of a specific GPU wheel directly from its URL.
# In a clean environment...
pip install "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.23+cuda12.cudnn88-cp310-cp310-manylinux2014_x86_64.whl"
pip install jax==0.4.23 "numpy<2.0"
- Expected Result: The specified all-in-one GPU `jaxlib` wheel is installed.
- Actual Result: The installation fails because the URL is dead (HTTP 404). Google has apparently removed older files from its storage bucket.
ERROR: HTTP error 404 while getting https://.../jaxlib-0.4.23...whl
This shows that relying on specific, old URLs is not a stable solution.
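For comparison, the two `jaxlib` wheel filenames that appear in these attempts can be split into their PEP 427 fields. The small sketch below uses hypothetical helpers `parse_wheel_filename` and `local_version`, and it ignores the optional build tag for brevity. It shows that the CPU wheel pip chose in Attempt #1 and the dead GPU wheel URL both target `cp310` and `manylinux2014_x86_64`. The GPU build is distinguished only by the PEP 440 local version label `+cuda12.cudnn88`.

```python
from typing import NamedTuple


class WheelTags(NamedTuple):
    name: str
    version: str
    python: str
    abi: str
    platform: str


def parse_wheel_filename(filename: str) -> WheelTags:
    """Split a wheel filename into its PEP 427 fields.

    Layout: {name}-{version}(-{build})?-{python}-{abi}-{platform}.whl
    The optional build tag is accepted but discarded in this sketch."""
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename!r}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) not in (5, 6):
        raise ValueError(f"unexpected wheel filename: {filename!r}")
    name, version = parts[0], parts[1]
    python, abi, platform = parts[-3:]
    return WheelTags(name, version, python, abi, platform)


def local_version(version: str) -> str:
    """Return the PEP 440 local version label (the part after '+'), or ''."""
    return version.partition("+")[2]


if __name__ == "__main__":
    for fn in (
        "jaxlib-0.6.2-cp310-cp310-manylinux2014_x86_64.whl",                # Attempt #1
        "jaxlib-0.4.23+cuda12.cudnn88-cp310-cp310-manylinux2014_x86_64.whl",  # Attempt #2
    ):
        tags = parse_wheel_filename(fn)
        print(tags.version, tags.python, tags.platform,
              "local label:", local_version(tags.version) or "(none)")
```

PyPI does not accept uploads whose version carries a local label, so wheels named this way can only come from a separate index or direct URL such as the one in Attempt #2. That is also why the attempt pins `jax==0.4.23` by hand to match the `jaxlib` version.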
Attempt #3: The Staged Plugin Method
This involves installing the CUDA plugin first, then JAX.
# In a clean environment...
pip install --upgrade "jax-cuda12-plugin"
pip install jax
- Expected Result: JAX installs and then uses the `jaxlib` provided by the plugin.
- Actual Result: This leads to a dependency conflict. Installing `jax` reinstalls the CPU-only `jaxlib` over the plugin's libraries, which brings back the same "fallback to CPU" problem as Attempt #1.
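The staged approach depends on several separately versioned distributions staying in step. Below is a minimal diagnostic sketch that reads installed versions with `importlib.metadata` and reports anything missing or mismatched. It assumes, without confirmation from the post, that `jax-cuda12-plugin` and its dependency `jax-cuda12-pjrt` must carry the same version as `jaxlib`. The package list in `GPU_STACK` and the helpers `installed_versions` and `find_mismatches` are hypothetical names for illustration.

```python
from importlib import metadata
from typing import Dict, Iterable, List, Optional

# Assumed set of distributions in a plugin-based CUDA install of JAX.
GPU_STACK = ("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt")
# Assumption: these plugin packages must match the jaxlib version exactly.
PLUGIN_PACKAGES = ("jax-cuda12-plugin", "jax-cuda12-pjrt")


def installed_versions(names: Iterable[str] = GPU_STACK) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    result: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            result[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            result[name] = None
    return result


def find_mismatches(versions: Dict[str, Optional[str]]) -> List[str]:
    """List missing distributions, then plugin packages whose version
    differs from jaxlib's (under the lockstep assumption above)."""
    problems = [f"{name} is not installed" for name, ver in versions.items() if ver is None]
    base = versions.get("jaxlib")
    if base is not None:
        for name in PLUGIN_PACKAGES:
            ver = versions.get(name)
            if ver is not None and ver != base:
                problems.append(f"{name} {ver} does not match jaxlib {base}")
    return problems


if __name__ == "__main__":
    issues = find_mismatches(installed_versions())
    print("\n".join(issues) if issues else "versions look consistent")
```

Running it in the Attempt #3 environment after `pip install jax` would show whether the reinstall left `jaxlib` and the plugin packages at different versions.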
My Question
I am completely stuck.
- The standard installer fails to select the GPU package.
- Direct URL installation fails because the files are gone.
- The plugin-based method creates a dependency conflict that breaks the installation.
Given my server specifications (Ubuntu 18.04, CUDA 12.4-compatible driver), what is the current, definitive, and guaranteed-to-work set of commands for installing a version of JAX that actually uses the GPU?