Welcome to jax_toolkit

A collection of JAX functions for common machine learning and deep learning tasks.

Installation

pip install jax_toolkit

Or, to also install the additional loss function utilities:

pip install "jax_toolkit[losses_utils]"

The quotes stop shells such as zsh from treating the square brackets as a glob pattern.
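To confirm that pip installed the package, you can query its installed version. The sketch below uses only the standard library's importlib.metadata. It assumes the distribution name is jax_toolkit, as in the pip command. The helper name installed_version is hypothetical and is not part of jax_toolkit.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str = "jax_toolkit") -> Optional[str]:
    """Return the installed version string of a distribution, or None if absent.

    Hypothetical helper for checking an install; not part of jax_toolkit.
    """
    try:
        # Look up the distribution's metadata in the current environment
        return version(dist_name)
    except PackageNotFoundError:
        # No distribution with this name is installed in this environment
        return None


if __name__ == "__main__":
    found = installed_version()
    print(found if found is not None else "jax_toolkit is not installed")
```

Run it with the same interpreter that pip installed into. Otherwise the lookup searches a different environment.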