Check out an even simpler quickstart-repo, including this quickstart-notebook
Check out our demo colab
Check out our talk at the ASAP online seminar!
What is this repo? A way to parallelize "inherently sequential" processes like nonlinear recurrent neural networks!
This repository contains code for the paper "Towards Scalable and Stable Parallelization of Nonlinear RNNs," published at NeurIPS 2024. The paper is available on arXiv (arXiv:2407.19115).
This line of work develops techniques to parallelize "inherently sequential" processes, such as nonlinear RNNs, over the sequence length. This matters because evaluating an inherently sequential process one step at a time over a long sequence leaves most of a GPU idle. The supposedly "inherently sequential" nature of nonlinear RNNs is one of the main reasons they were passed over in favor of transformers (which are embarrassingly parallel) or linear RNNs (such as Mamba or Gated DeltaNet, which can be parallelized over the sequence length). However, the "ungulate" algorithms (DEER and ELK) also allow nonlinear RNNs to be parallelized over the sequence length. They recast evaluating the RNN as solving a system of nonlinear equations for the whole trajectory at once, using Newton-style iterations whose linearized updates can be computed with a parallel scan. This improves GPU utilization and opens nonlinear RNNs up to longer-sequence modeling.
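A minimal sketch of the DEER idea in JAX, assuming a hypothetical tanh RNN cell `step` (not taken from this repo). Each Newton iteration linearizes the recurrence around the current guess for the whole trajectory, which yields a linear recurrence s_t = J_t s_{t-1} + b_t. Composing the affine maps (J_t, b_t) is associative, so `jax.lax.associative_scan` can solve it in parallel over the sequence length. The names `step`, `sequential_rnn`, `compose`, and `deer_rnn` are illustrative only and are not the API of this codebase or of the DEER codebase.

```python
import jax
import jax.numpy as jnp


def step(s, x, W, U):
    """Hypothetical nonlinear RNN cell: s_t = tanh(W s_{t-1} + U x_t)."""
    return jnp.tanh(W @ s + U @ x)


def sequential_rnn(s0, xs, W, U):
    """Reference: evaluate the RNN one step at a time with lax.scan."""
    def body(s, x):
        s_new = step(s, x, W, U)
        return s_new, s_new
    _, states = jax.lax.scan(body, s0, xs)
    return states


def compose(elem_i, elem_j):
    """Compose affine maps s -> A s + b, applying elem_i first, then elem_j.

    Inputs carry a leading batch axis, as lax.associative_scan requires.
    """
    A_i, b_i = elem_i
    A_j, b_j = elem_j
    return A_j @ A_i, jnp.einsum("...ij,...j->...i", A_j, b_i) + b_j


def deer_rnn(s0, xs, W, U, num_iters):
    """DEER-style Newton iterations over the whole trajectory (full Jacobians)."""
    T, D = xs.shape[0], s0.shape[0]

    def newton_step(states, _):
        # s_{t-1} under the current guess, with the known s0 prepended.
        prev = jnp.concatenate([s0[None], states[:-1]], axis=0)
        fs = jax.vmap(step, in_axes=(0, 0, None, None))(prev, xs, W, U)
        # Jacobians J_t = d step / d s at (s_{t-1}, x_t), shape (T, D, D).
        Js = jax.vmap(jax.jacfwd(step), in_axes=(0, 0, None, None))(prev, xs, W, U)
        # Linearized recurrence: s_t = J_t s_{t-1} + (f_t - J_t s_{t-1}^guess).
        bs = fs - jnp.einsum("tij,tj->ti", Js, prev)
        # Parallel prefix over affine maps gives s_t = A_cum_t s0 + b_cum_t.
        A_cum, b_cum = jax.lax.associative_scan(compose, (Js, bs))
        return jnp.einsum("tij,j->ti", A_cum, s0) + b_cum, None

    init = jnp.zeros((T, D))  # initial guess for all T states
    states, _ = jax.lax.scan(newton_step, init, None, length=num_iters)
    return states


# Usage on random (illustrative) parameters and inputs.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
D, N, T = 4, 3, 16
W = 0.5 * jax.random.normal(k1, (D, D)) / jnp.sqrt(D)
U = jax.random.normal(k2, (D, N))
xs = jax.random.normal(k3, (T, N))
s0 = jnp.zeros(D)

ref = sequential_rnn(s0, xs, W, U)
out = deer_rnn(s0, xs, W, U, num_iters=T)
print("max abs error vs sequential:", float(jnp.max(jnp.abs(out - ref))))
```

Because the first state of every iterate is computed exactly, each iteration fixes at least one more state, so `num_iters = T` always recovers the sequential result up to floating-point error; in practice Newton's method can converge in far fewer iterations. The full (T, D, D) Jacobian tensor is the memory cost that quasi-DEER targets.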
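The primary contributions of our paper are the quasi-DEER and ELK algorithms for parallelizing nonlinear RNNs over the sequence length. Quasi-DEER is scalable: it replaces the Jacobians in DEER's Newton iterations with their diagonals, reducing memory and compute. ELK, which stands for "Evaluating Levenberg-Marquardt with Kalman," is stable: it adds a Levenberg-Marquardt trust region to the Newton iterations and computes each damped update with a parallel Kalman smoother. We also contribute quasi-ELK, which is both scalable and stable.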
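A minimal sketch of the quasi-DEER update, again assuming a hypothetical tanh RNN cell `step` (not taken from this repo). The iteration has the same structure as DEER, but each Jacobian J_t is replaced by its diagonal d_t, so the linear recurrence s_t = d_t * s_{t-1} + b_t is elementwise and the scan stores only (T, D) arrays instead of (T, D, D). The helper names `diag_jacobian`, `compose_diag`, and `quasi_deer_rnn` are illustrative and are not this codebase's API.

```python
import jax
import jax.numpy as jnp


def step(s, x, W, U):
    """Hypothetical nonlinear RNN cell: s_t = tanh(W s_{t-1} + U x_t)."""
    return jnp.tanh(W @ s + U @ x)


def diag_jacobian(s, x, W, U):
    """Diagonal of d step / d s without forming the full D x D Jacobian.

    For tanh(W s + U x) the Jacobian is diag(1 - h^2) W, so its diagonal
    is (1 - h^2) * diag(W). This shortcut is specific to this cell.
    """
    h = step(s, x, W, U)
    return (1.0 - h**2) * jnp.diag(W)


def compose_diag(elem_i, elem_j):
    """Compose elementwise affine maps s -> a * s + b (elem_i applied first)."""
    a_i, b_i = elem_i
    a_j, b_j = elem_j
    return a_j * a_i, a_j * b_i + b_j


def quasi_deer_rnn(s0, xs, W, U, num_iters):
    """Quasi-DEER iterations: Newton-like updates with diagonal Jacobians."""
    T, D = xs.shape[0], s0.shape[0]

    def iteration(states, _):
        # s_{t-1} under the current guess, with the known s0 prepended.
        prev = jnp.concatenate([s0[None], states[:-1]], axis=0)
        fs = jax.vmap(step, in_axes=(0, 0, None, None))(prev, xs, W, U)
        ds = jax.vmap(diag_jacobian, in_axes=(0, 0, None, None))(prev, xs, W, U)
        bs = fs - ds * prev
        # Elementwise parallel prefix: s_t = a_cum_t * s0 + b_cum_t.
        a_cum, b_cum = jax.lax.associative_scan(compose_diag, (ds, bs))
        return a_cum * s0 + b_cum, None

    init = jnp.zeros((T, D))
    states, _ = jax.lax.scan(iteration, init, None, length=num_iters)
    return states


def sequential_rnn(s0, xs, W, U):
    """Reference: evaluate the RNN one step at a time."""
    def body(s, x):
        s_new = step(s, x, W, U)
        return s_new, s_new
    _, states = jax.lax.scan(body, s0, xs)
    return states


# Usage on random (illustrative) parameters and inputs.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
D, N, T = 4, 3, 16
W = 0.5 * jax.random.normal(k1, (D, D)) / jnp.sqrt(D)
U = jax.random.normal(k2, (D, N))
xs = jax.random.normal(k3, (T, N))
s0 = jnp.zeros(D)

ref = sequential_rnn(s0, xs, W, U)
out = quasi_deer_rnn(s0, xs, W, U, num_iters=T)
print("max abs error vs sequential:", float(jnp.max(jnp.abs(out - ref))))
```

Since the update is no longer an exact Newton step, quasi-DEER may need more iterations than DEER; the payoff is O(TD) instead of O(TD^2) storage for the linearization. A generic cell would need some other cheap way to obtain the Jacobian diagonal.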
Our paper and codebase build on the work of Y. H. Lim et al., "Parallelizing non-linear sequential models over the sequence length" (paper, codebase), published at ICLR 2024. This work from Machine Discovery Ltd is licensed under the BSD 3-Clause License. The files in our codebase taken from DEER are deer.py and the folder fig3. Much of the setup in qdeer_profile_exps_figs_2_5_6 also comes from the analogous experiments in DEER, though the memory profiling code is our own contribution.
We recommend using a virtual environment with Python 3.12.1.
Within that virtual environment, first install JAX with
pip install --upgrade pip
pip install -U "jax[cuda12]"
After installing JAX, pip install the package with
pip install --upgrade -e .[cr]
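After installation, a quick sanity check along these lines (a sketch, not a script shipped with this repo) confirms that JAX imports, reports which backend it found, and can run a small computation. On a GPU machine where the CUDA 12 wheels installed correctly, the default backend should be reported as "gpu" rather than "cpu".

```python
import jax
import jax.numpy as jnp

print("JAX version:", jax.__version__)
backend = jax.default_backend()  # e.g. "gpu" with working CUDA wheels, else "cpu"
print("Default backend:", backend)
print("Devices:", jax.devices())

# Small matmul on the default device: each entry of ones @ ones is 1024.
x = jnp.ones((1024, 1024))
check = float((x @ x)[0, 0])
print("Matmul check:", check)
```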
We originally wrote the paper's code in Python 3.9, but this required running in a Singularity container and using old versions of JAX. The code is much easier to use with Python 3.12.1 and the provided setup.py.
Google Colab runs Python 3.10 (at the time of writing).
To install in Google Colab, git clone this repo and cd into elk. Then import jax and run pip install --upgrade -e .[flex]
This is a living repo that will change as we develop more applications. See the tag v1.0.0 for the version corresponding to our NeurIPS paper.
Please also star this repo if you find the code interesting or useful!
@inproceedings{gonzalez2024scalable,
title={Towards Scalable and Stable Parallelization of Nonlinear RNNs},
author={Xavier Gonzalez and Andrew Warrington and Jimmy T. H. Smith and Scott W. Linderman},
booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
year={2024},
url={https://doi.org/10.48550/arXiv.2407.19115},
}
