jax-js
Numerical / GPU foundations for the web
Have you ever wanted to use NumPy or PyTorch in your browser?
npm install @jax-js/jax
jax-js is a new machine learning framework that brings JAX-style, high-performance CPU and GPU kernels to JavaScript. Run neural
networks, image algorithms, simulations, and many kinds of numerical
applications without leaving the frontend.
jax-js is likely the most portable ML framework, since it runs
anywhere a browser can run (Chrome, Firefox, Safari, iOS, and Android). It's
also simple yet optimized: it includes a lightweight ML compiler and a GPU
kernel scheduler inspired by tinygrad.
[Figure: performance graph of FLOPS by backend: CPU, Wasm, WebGL, WebGPU]
The ML compiler achieves best-in-class performance on the web because it takes a different approach from most runtimes: it JIT-compiles array programs into WebAssembly, WebGL, and WebGPU kernels. Thanks to flexible operator fusion, it is up to 5x faster than other browser ML frameworks such as TensorFlow.js (Google).
Performance benchmarks are coming soon, though some initial results are already available.
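To make the idea of operator fusion concrete, here is a minimal sketch in plain TypeScript. It does not use the jax-js API, and it is not how the jax-js compiler is implemented; the function names (mul, add, relu, unfused, fused) are hypothetical and exist only for illustration. It computes relu(a * x + b) twice: once as three separate elementwise passes that each allocate a temporary array, and once as a single fused loop.

```typescript
// Illustrative only: plain TypeScript, not the jax-js API.

// Unfused: each operator is its own pass and allocates its own output array.
function mul(x: Float64Array, a: number): Float64Array {
  const out = new Float64Array(x.length);
  for (let i = 0; i < x.length; i++) out[i] = x[i] * a;
  return out;
}

function add(x: Float64Array, b: number): Float64Array {
  const out = new Float64Array(x.length);
  for (let i = 0; i < x.length; i++) out[i] = x[i] + b;
  return out;
}

function relu(x: Float64Array): Float64Array {
  const out = new Float64Array(x.length);
  for (let i = 0; i < x.length; i++) out[i] = Math.max(x[i], 0);
  return out;
}

// Three passes over memory, two intermediate arrays.
function unfused(x: Float64Array, a: number, b: number): Float64Array {
  return relu(add(mul(x, a), b));
}

// Fused: the same arithmetic in one pass, with no intermediate arrays.
function fused(x: Float64Array, a: number, b: number): Float64Array {
  const out = new Float64Array(x.length);
  for (let i = 0; i < x.length; i++) out[i] = Math.max(x[i] * a + b, 0);
  return out;
}

const input = Float64Array.from([-2, -1, 0, 1, 2]);
console.log(unfused(input, 2, 1)); // values: 0, 0, 1, 3, 5
console.log(fused(input, 2, 1));   // values: 0, 0, 1, 3, 5
```

Both versions return the same values, but the fused one reads and writes each element once. Elementwise operations tend to be limited by memory bandwidth, so a compiler that generates fused kernels can save a large share of the memory traffic.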
Libraries like NumPy, PyTorch, and JAX revolutionized numerical computing in
Python. jax-js aspires to provide a similar foundation for JavaScript, the
world's largest software platform.
It's still in development, but you can try it in the jax-js REPL.