FF-jax [Paper]
Unofficial JAX implementation of the Forward-Forward algorithm.
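The core idea of Forward-Forward can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not code from this repository. Each layer is trained with its own local loss on the "goodness" (sum of squared activations) of positive and negative inputs, and no gradient flows between layers. The names `init_layer`, `layer_forward`, `goodness`, `local_loss`, `train_step` and `train_stack`, the threshold of 2.0 and the learning rate are all hypothetical. Random arrays stand in for the positive and negative inputs, which the paper builds by embedding a correct or an incorrect label into each image.

```python
import jax
import jax.numpy as jnp

THRESHOLD = 2.0       # hypothetical goodness threshold
LEARNING_RATE = 0.03  # hypothetical step size for plain SGD


def init_layer(key, n_in, n_out):
    # Dense layer initialized with one of the listed initializers
    w = jax.nn.initializers.lecun_normal()(key, (n_in, n_out))
    return {"w": w, "b": jnp.zeros((n_out,))}


def layer_forward(params, x):
    # Length-normalize the input so only the direction of the incoming
    # activity, not its goodness, reaches this layer
    x = x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + 1e-8)
    return jax.nn.relu(x @ params["w"] + params["b"])


def goodness(h):
    # Goodness of a layer = sum of squared activations per example
    return jnp.sum(h ** 2, axis=-1)


def local_loss(params, x_pos, x_neg):
    g_pos = goodness(layer_forward(params, x_pos))
    g_neg = goodness(layer_forward(params, x_neg))
    # Push positive goodness above THRESHOLD and negative goodness below it
    return jnp.mean(jax.nn.softplus(THRESHOLD - g_pos)
                    + jax.nn.softplus(g_neg - THRESHOLD))


@jax.jit
def train_step(params, x_pos, x_neg):
    # Gradient of this layer's own loss only; nothing is backpropagated
    # from later layers
    loss, grads = jax.value_and_grad(local_loss)(params, x_pos, x_neg)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g,
                                    params, grads)
    return params, loss


def train_stack(layers, x_pos, x_neg):
    # Update each layer, then feed its outputs to the next layer as inputs
    new_layers, losses = [], []
    for params in layers:
        params, loss = train_step(params, x_pos, x_neg)
        new_layers.append(params)
        losses.append(loss)
        x_pos, x_neg = layer_forward(params, x_pos), layer_forward(params, x_neg)
    return new_layers, losses


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
layers = [init_layer(k1, 784, 500), init_layer(k2, 500, 500)]
x_pos = jax.random.normal(k3, (32, 784))  # placeholder positive batch
x_neg = jax.random.normal(k4, (32, 784))  # placeholder negative batch
layers, losses = train_stack(layers, x_pos, x_neg)
```

Plain SGD is used here only to keep the sketch short; the repository lists several optimizers as configurable hyperparameters.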
# clone the repository
git clone https://github.com/dslisleedh/FF-jax.git
cd FF-jax
# create environment
conda create --name <env> --file requirements.txt
conda activate <env>
# if this does not work, install the following packages manually:
# jax, jaxlib (https://github.com/google/jax#installation)
# einops, tensorflow, tensorflow_datasets, tqdm, hydra-core, hydra-colorlog, omegaconf, gin-config
# run training
python train.py
Training settings can be changed in ./config/hparams.gin. config.yaml is only used by Hydra to create and set the working directory.
Hyperparameters
- Losses
  - mse_loss
  - softplus_loss (used in the original paper)
  - probabilistic_loss
  - symba_loss
  - swish_symba_loss
- Optimizers
  - SGD
  - MomentumSGD
  - NesterovMomentumSGD
  - AdaGrad
  - RMSProp
  - Adam
  - AdaBelief
- Initializers
  - jax.nn.initializers.lecun_normal
  - jax.nn.initializers.glorot_normal
  - jax.nn.initializers.he_normal
  - jax.nn.initializers.variance_scaling
- Others, such as n_layers, n_units, ...
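Two of the listed losses can be sketched as functions of per-example goodness values. This is an interpretation under assumptions, not the repository's definitions. `softplus_loss` follows the paper's logistic form, where the probability that an input is positive is the sigmoid of its goodness minus a threshold. `symba_loss` follows the SymBa formulation, which penalizes the gap between positive and negative goodness directly. The default `threshold` and `alpha` values are hypothetical, and `mse_loss`, `probabilistic_loss` and `swish_symba_loss` are not reproduced here.

```python
import jax
import jax.numpy as jnp


def softplus_loss(g_pos, g_neg, threshold=2.0):
    # p(positive | x) = sigmoid(g - threshold) and -log sigmoid(z) = softplus(-z),
    # so this is the mean negative log-likelihood of labelling positives as
    # positive and negatives as negative
    return (jnp.mean(jax.nn.softplus(threshold - g_pos))
            + jnp.mean(jax.nn.softplus(g_neg - threshold)))


def symba_loss(g_pos, g_neg, alpha=4.0):
    # Threshold-free: penalize a small (or negative) gap g_pos - g_neg;
    # alpha controls how sharply the gap is rewarded
    return jnp.mean(jax.nn.softplus(-alpha * (g_pos - g_neg)))


g_pos = jnp.array([5.0, 4.0])  # goodness of positive samples
g_neg = jnp.array([0.5, 1.0])  # goodness of negative samples
print(float(softplus_loss(g_pos, g_neg)), float(symba_loss(g_pos, g_neg)))
```

The practical difference is that `softplus_loss` judges each sample against a fixed threshold, while `symba_loss` only looks at the positive-negative gap, so it expects paired batches of equal size.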
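The listed initializers are all factories in `jax.nn.initializers`: calling one returns an `init(key, shape)` function. The sketch below assumes a hypothetical `make_init` helper that maps a config name to an initializer, which is not necessarily how this repository wires them up. It also checks that `he_normal` is a special case of `variance_scaling` (scale 2.0, `fan_in`, truncated normal), which is why `variance_scaling` needs its arguments spelled out.

```python
import jax
import jax.numpy as jnp

# Hypothetical name -> factory table mirroring the initializer list
INITIALIZERS = {
    "lecun_normal": jax.nn.initializers.lecun_normal,
    "glorot_normal": jax.nn.initializers.glorot_normal,
    "he_normal": jax.nn.initializers.he_normal,
    "variance_scaling": jax.nn.initializers.variance_scaling,
}


def make_init(name, **kwargs):
    # variance_scaling requires scale, mode and distribution;
    # the other three work with their defaults
    return INITIALIZERS[name](**kwargs)


key = jax.random.PRNGKey(0)
shape = (784, 500)  # (fan_in, fan_out) of a dense layer
w_he = make_init("he_normal")(key, shape)
w_vs = make_init("variance_scaling", scale=2.0, mode="fan_in",
                 distribution="truncated_normal")(key, shape)
print(float(w_he.std()), float(jnp.sqrt(2.0 / shape[0])))
```

`lecun_normal` and `glorot_normal` are likewise `variance_scaling` with scale 1.0 and mode `fan_in` or `fan_avg`, respectively.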
TODO
- Add local convolution with peer normalization
- Add an online training model?