The thumbnail image was created with ChatGPT-4o.

This post explains how to set up a GPU-enabled JAX environment on an HPC cluster, specifically the systems run by Penn State University's Institute for Computational and Data Sciences (ICDS).



1. Load the relevant modules

module load anaconda
module load cuda/12.6.0
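To confirm that the CUDA module took effect, one option is to ask nvcc which toolkit release it belongs to. The Python sketch below assumes the cuda/12.6.0 module puts nvcc on PATH, which is typical for module systems but not guaranteed on every cluster (running "module list" shows what is actually loaded). The helper name cuda_toolkit_release is hypothetical.

```python
import re
import shutil
import subprocess


def cuda_toolkit_release():
    """Return the CUDA release reported by nvcc (e.g. "12.6"), or None.

    Assumes the loaded cuda module places nvcc on PATH; returns None if
    nvcc cannot be found or its output has no recognizable release string.
    """
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    out = subprocess.run(
        [nvcc, "--version"], capture_output=True, text=True, check=True
    ).stdout
    # nvcc prints a line such as "Cuda compilation tools, release 12.6, V12.6.20"
    match = re.search(r"release (\d+\.\d+)", out)
    return match.group(1) if match else None


if __name__ == "__main__":
    release = cuda_toolkit_release()
    print(release if release else "nvcc not found on PATH; is the cuda module loaded?")
```

After the module load above, this should report a 12.x release; None usually means the module was not loaded in the current shell.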


2. Create and activate a conda environment

Recent JAX releases no longer support Python 3.9, so use a newer Python version:

conda create --name jax python=3.12
conda activate jax
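After activation, a quick check like the sketch below can confirm that python now resolves to the new environment rather than the base Anaconda installation. With the commands above, conda_env should read jax and prefix should end in envs/jax (the parent directory depends on where conda stores environments on the cluster). The function name describe_interpreter is hypothetical.

```python
import os
import sys


def describe_interpreter():
    """Summarize which Python interpreter and conda environment are active."""
    return {
        "python_version": "{}.{}.{}".format(*sys.version_info[:3]),
        # Root directory of the environment this interpreter belongs to.
        "prefix": sys.prefix,
        # conda sets CONDA_DEFAULT_ENV to the active environment's name on activation.
        "conda_env": os.environ.get("CONDA_DEFAULT_ENV"),
    }


if __name__ == "__main__":
    for key, value in describe_interpreter().items():
        print(f"{key}: {value}")
```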


3. Install GPU-enabled JAX

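Follow the official installation guide (https://docs.jax.dev/en/latest/installation.html). Current JAX releases publish their CUDA 12 wheels on PyPI, so the older "jax[cuda12_pip]" extra and the -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html index are no longer needed:

pip install --upgrade "jax[cuda12]"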
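Before testing on a GPU, it can help to confirm that pip actually installed the CUDA plugin and not just a CPU-only JAX. This sketch uses importlib.metadata from the standard library; the distribution names jax-cuda12-plugin and jax-cuda12-pjrt are an assumption based on how recent JAX releases package their CUDA 12 support, and installed_versions is a hypothetical helper.

```python
from importlib import metadata

# Distributions a CUDA 12 JAX install is expected to contain
# (assumption: names as published on PyPI for recent JAX releases).
EXPECTED = ("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt")


def installed_versions(names=EXPECTED):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:20s} {version or 'NOT INSTALLED'}")
```

If jax and jaxlib are present but the plugin entries show NOT INSTALLED, pip most likely resolved a CPU-only build.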


4. Test the installation

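Start Python on a node that has a GPU (login nodes typically do not, so use an interactive or batch GPU job) and list the devices JAX can see:

python
>>> import jax
>>> jax.devices()

On a working setup the output lists one or more CUDA devices, for example [CudaDevice(id=0)] (the exact name varies between JAX versions). If only CpuDevice entries appear, JAX did not find a usable GPU.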
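jax.devices() only lists devices; a slightly stronger check is to run a small jitted computation and report which backend executed it. This is a minimal sketch assuming it is run inside the activated environment on a GPU node; gpu_smoke_test is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def gpu_smoke_test(n=1024):
    """Run a small jitted matmul and report the backend, devices, and result."""

    @jax.jit
    def matmul_sum(x):
        return jnp.sum(x @ x.T)

    x = jnp.ones((n, n), dtype=jnp.float32)
    # block_until_ready waits for the asynchronous computation to finish.
    result = matmul_sum(x).block_until_ready()
    return jax.default_backend(), jax.devices(), float(result)


if __name__ == "__main__":
    backend, devices, value = gpu_smoke_test()
    print("backend:", backend)  # "gpu" when a CUDA device was used, else e.g. "cpu"
    print("devices:", devices)
    print("sum:", value)  # n**3 for an n x n matrix of ones
```

A backend of cpu means JAX fell back to the CPU, usually because no GPU is visible on the node or the CUDA plugin is missing.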