jax-js

Numerical / GPU foundations for the web

Have you ever wanted to use NumPy or PyTorch in your browser?

npm install @jax-js/jax

jax-js is a new machine learning framework that brings JAX-style, high-performance CPU and GPU kernels to JavaScript. Run neural networks, image algorithms, simulations, and many kinds of numerical applications without leaving the frontend.
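
To give a feel for the API, here is a minimal sketch of NumPy-style array code running in the browser. It assumes a JAX-like surface (a `numpy` namespace) exported from @jax-js/jax; the exact export and function names are illustrative, not authoritative — check the jax-js docs for the real ones.

```ts
// Minimal sketch: NumPy-style arrays in the browser.
// API names (`numpy`, `linspace`, `sin`, ...) are assumed from the NumPy/JAX
// analogy; consult the jax-js documentation for the exact exports.
import { numpy as np } from "@jax-js/jax";

const x = np.linspace(0, 2 * Math.PI, 1024); // 1024 evenly spaced points
const y = np.sin(x);                         // elementwise sine over the array
const energy = np.sum(np.multiply(y, y));    // reduce to a scalar

console.log(energy); // result stays on the selected backend until read back
```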

Demos: Julia set · MNIST · Llama chatbot · Bayesian regression

jax-js is likely the most portable ML framework, since it runs anywhere a browser does (Chrome, Firefox, Safari, iOS, and Android). It's also simple yet optimized, with a lightweight ML compiler and GPU kernel scheduler inspired by tinygrad.

[Figure: FLOPS performance comparison across the CPU, Wasm, WebGL, and WebGPU backends]

The ML compiler achieves best-in-class performance on the web by taking a different approach from most runtimes: it JIT-compiles array programs into WebAssembly, WebGL, and WebGPU kernels. Thanks to flexible operator fusion, it is up to 5x faster than other browser ML frameworks such as TensorFlow.js (Google).
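
The sketch below illustrates why fusion matters. It assumes `jit` and `numpy` exports in the JAX style; those names are assumptions based on the JAX analogy, not confirmed API.

```ts
// Sketch of operator fusion (the `jit` and `numpy` exports are assumed here
// from the JAX analogy; consult the jax-js docs for exact names).
import { numpy as np, jit } from "@jax-js/jax";

// Three elementwise ops: square, negate, exp. A non-fusing runtime would run
// three kernels and materialize two temporaries; a fusing compiler can lower
// the whole expression to a single WebGPU/WebGL/Wasm kernel.
const gaussian = jit((x: any) => np.exp(np.negative(np.square(x))));

const xs = np.linspace(-3, 3, 1_000_000);
console.log(gaussian(xs)); // traced on first call, fused kernel reused afterward
```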

Full performance benchmarks are coming soon, though see here for some initial results.

Libraries like NumPy, PyTorch, and JAX revolutionized numerical computing in Python. jax-js aspires to provide a similar foundation for JavaScript, the world's largest software platform.

It's still in development, but you can try it in the jax-js REPL.