Article · Open access · Peer-reviewed

Slax: a composable JAX library for rapid and flexible prototyping of spiking neural networks

2025; IOP Publishing; Language: English

10.1088/2634-4386/ada988

ISSN

2634-4386

Authors

Thomas M. Summe, Siddharth Joshi

Topic(s)

Neural dynamics and brain function

Abstract

Spiking Neural Networks (SNNs) offer rich temporal dynamics and unique capabilities, but their training presents challenges. While backpropagation through time (BPTT) with surrogate gradients is the de facto standard for training SNNs, it scales poorly with long time sequences. Alternative learning rules and algorithms could help further develop models and systems across the spectrum of performance, bio-plausibility, and complexity. However, implementing and evaluating these alternatives at scale is cumbersome and error-prone, requiring repeated reimplementations. To address this, we introduce Slax, a JAX-based library designed to accelerate SNN algorithm design and evaluation. Slax is compatible with the broader JAX and Flax ecosystem and provides optimized implementations of diverse training algorithms, enabling direct performance comparisons. Its toolkit includes methods to visualize and debug algorithms through loss landscapes, gradient similarities, and other metrics of model behavior during training. By streamlining the implementation and evaluation of novel SNN learning algorithms, Slax aims to facilitate research and development in this promising field.
