Welcome to jax_toolkit

A collection of JAX functions for common machine learning and deep learning tasks.

Installation

pip install jax_toolkit

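Or, to also install the optional loss-function utilities (the quotes stop shells such as zsh from treating the brackets as a glob pattern):

pip install "jax_toolkit[losses_utils]"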
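To check that the install worked, you can query the installed version with the standard library's importlib.metadata module (Python 3.8+). This is a minimal sketch that assumes the distribution is registered under the name jax_toolkit. The helper installed_version is hypothetical and is not part of jax_toolkit.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str = "jax_toolkit") -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent.

    Hypothetical helper for illustration; it only reads package metadata
    and does not import jax_toolkit itself.
    """
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        # Raised when no distribution with this name is installed.
        return None


if __name__ == "__main__":
    version = installed_version()
    print(version if version is not None else "jax_toolkit is not installed")
```

Running the script prints the installed version string, or a notice if pip did not install the package into the active environment.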