
I am trying to install JAX with GPU support on a powerful, dedicated Linux server, but I am stuck in what feels like a Catch-22. Every official installation method fails in a different way, and each time JAX ends up falling back to the CPU.

I am looking for a definitive, foolproof set of commands to get a working GPU installation.

System Specifications:

  • OS: Ubuntu 18.04 LTS
  • GPU: 8x NVIDIA Quadro RTX 8000
  • NVIDIA Driver: 550.144.03
  • CUDA Version (reported by driver): 12.4
  • Python: 3.10 (managed by Conda)

What I Have Tried

I meticulously created a fresh conda environment for each attempt to rule out conflicts.

Attempt #1: The Standard Recommended Method

This is the officially recommended command.

conda create -n jax_test python=3.10 -y
conda activate jax_test
pip install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --no-cache-dir
  • Expected Result: A large, multi-gigabyte jaxlib wheel with CUDA libraries included should be downloaded and installed.
  • Actual Result: pip consistently ignores the [cuda12_pip] extra, downloads the small CPU-only jaxlib wheel (89.9 MB), and prints a warning. The verification command confirms the failure:
WARNING: jax 0.6.2 does not provide the extra 'cuda12-pip'
Downloading jaxlib-0.6.2-cp310-cp310-manylinux2014_x86_64.whl (89.9 MB)
...
$ python -c "import jax; print(jax.devices())"
WARNING: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]

Attempt #2: The Direct URL Method

This is a common workaround that forces pip to install a specific GPU wheel.

# In a clean environment...
pip install "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.23+cuda12.cudnn88-cp310-cp310-manylinux2014_x86_64.whl"
pip install jax==0.4.23 "numpy<2.0"
  • Expected Result: The specific, all-in-one GPU jaxlib is installed.
  • Actual Result: The installation fails because the URL is dead. Google has apparently removed older files from their storage.
ERROR: HTTP error 404 while getting https://.../jaxlib-0.4.23...whl

This shows that relying on specific, old URLs is not a stable solution.

Attempt #3: The Staged Plugin Method

This involves installing the CUDA plugin first, then JAX.

# In a clean environment...
pip install --upgrade "jax-cuda12-plugin"
pip install jax
  • Expected Result: JAX should install and then use the jaxlib provided by the plugin.
  • Actual Result: This leads to a dependency conflict. Installing jax reinstalls the wrong, CPU-only jaxlib over the plugin's libraries, which brings back the same "fallback to CPU" problem as Attempt #1.
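One way to find out which JAX packages actually ended up in an environment after a sequence like this is to list the installed version of each JAX-related distribution. The sketch below uses only the standard library's importlib.metadata. It assumes that the CUDA 12 plugin is published as jax-cuda12-plugin and jax-cuda12-pjrt, which are the package names used by recent releases, and that these packages are released with version numbers matching jaxlib. The helper names installed_versions and plugin_mismatches are hypothetical. A plugin whose version differs from jaxlib's is one plausible reason JAX would skip the GPU backend, so treat a mismatch as a hint rather than a diagnosis.

```python
"""List installed JAX-related distributions and flag plugin/jaxlib mismatches."""
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names for a CUDA 12 setup in recent JAX releases.
JAX_DISTRIBUTIONS = ("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt")
CUDA_PLUGINS = ("jax-cuda12-plugin", "jax-cuda12-pjrt")


def installed_versions(names=JAX_DISTRIBUTIONS):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def plugin_mismatches(versions):
    """Return the installed CUDA plugin packages whose version differs from jaxlib's."""
    base = versions.get("jaxlib")
    if base is None:
        # Without jaxlib there is nothing to compare against.
        return []
    return [
        name
        for name in CUDA_PLUGINS
        if versions.get(name) is not None and versions[name] != base
    ]


if __name__ == "__main__":
    versions = installed_versions()
    for name, ver in versions.items():
        print(f"{name:20} {ver or 'not installed'}")
    for name in plugin_mismatches(versions):
        print(f"Mismatch: {name} {versions[name]} vs jaxlib {versions['jaxlib']}")
```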

My Question

I am completely stuck.

  • The standard installer fails to select the GPU package.
  • Direct URL installation fails because the files are gone.
  • The plugin-based method creates a dependency conflict that breaks the installation.

Given my server specifications (Ubuntu 18.04, CUDA 12.4 compatible driver), what is the current, definitive, and guaranteed-to-work set of commands to install a version of JAX that successfully uses the GPU?

1 Answer



I would stick with the standard installation method. In your case, this is the relevant warning:

WARNING: jax 0.6.2 does not provide the extra 'cuda12-pip'

You're using an old installation command: the cuda12-pip extra was removed in JAX v0.6.0. JAX's current installation instructions recommend this:

pip install -U "jax[cuda12]"

If you use this command, you should get the correct GPU-specific installation of the latest JAX version compatible with your platform and the Python version you're using (JAX v0.6.2 in the case of Linux / Python 3.10).
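After installing, you can confirm that the GPU plugin was actually picked up by asking JAX which backend it selected and then running a small computation on the first device. This is a minimal sketch, not something taken from the JAX documentation. The function name gpu_smoke_test and the matrix size are arbitrary choices. On a working CUDA install the reported backend should be "gpu", while "cpu" means the plugin was not loaded.

```python
"""Quick check: which backend did JAX pick, and can it run a matmul there?"""
import importlib.util
import math


def gpu_smoke_test(size=256):
    """Report JAX's default backend and devices, and verify a small matmul."""
    if importlib.util.find_spec("jax") is None:
        return {"installed": False}

    import jax
    import jax.numpy as jnp

    devices = jax.devices()  # devices of the default backend
    # Place an all-ones matrix on the first device; every entry of x @ x equals `size`.
    x = jax.device_put(jnp.ones((size, size), dtype=jnp.float32), devices[0])
    total = float((x @ x).sum())  # expected total: size ** 3
    return {
        "installed": True,
        "backend": jax.default_backend(),  # "gpu" if the CUDA plugin loaded
        "devices": [str(d) for d in devices],
        "matmul_ok": math.isclose(total, float(size**3), rel_tol=1e-5),
    }


if __name__ == "__main__":
    report = gpu_smoke_test()
    print(report)
    if not report["installed"]:
        print("jax is not installed in this environment.")
    elif report["backend"] != "gpu":
        print("JAX fell back to a non-GPU backend; the CUDA plugin was not used.")
```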


I noticed in some of your examples you're pinning old JAX versions (v0.4.23). If you're interested in installing jaxlib wheels for older JAX versions, there are some tips at https://docs.jax.dev/en/latest/installation.html#installing-older-jaxlib-wheels. If you're specifically having issues related to installing older jaxlib versions, I'd suggest posting another question on that topic, being very clear about which JAX version you're trying to install.
