Adding README for the examples. #7


Merged 1 commit on Apr 4, 2022
56 changes: 56 additions & 0 deletions examples/README.md
@@ -0,0 +1,56 @@
# KFAC-JAX Examples

This folder contains code with common functionality shared by all of the
examples, as well as the individual example subfolders.
Each example has the following structure:
* `experiment.py` contains the model definition, the loss definition, and the
experiment class used by the training pipeline.
* `pipeline.py` contains the hyper-parameter configuration (a minimal sketch
follows this list).
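
For orientation, here is a minimal, hypothetical sketch of what `get_config()`
in a `pipeline.py` might look like. It assumes the config object comes from
`ml_collections.config_dict` and combines fields that appear in the diffs in
this pull request; the real pipeline files contain many more options:

```python
# Illustrative sketch only, not the actual example code.
from ml_collections import config_dict


def get_config() -> config_dict.ConfigDict:
  """Returns a toy experiment configuration."""
  config = config_dict.ConfigDict()
  config.checkpoint_dir = "/tmp/kfac_jax_jaxline/"
  config.train_checkpoint_all_hosts = False

  # Experiment config.
  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              l2_reg=1e-5,
              training=dict(
                  steps=200_000,
                  epochs=None,
              ),
          )
      )
  )
  return config
```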


To run the examples, you will need to install additional dependencies via:

```shell
$ pip install -r examples/requirements.txt
```

To run an example, simply run:

```shell
$ python example_name/pipeline.py
```

## Autoencoder on MNIST

This example uses the K-FAC optimizer to perform deterministic (i.e. full batch)
training of a deep autoencoder on MNIST.
The default configuration uses the automatic learning rate, momentum, and
damping adaptation techniques from the original K-FAC paper.
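
As a rough sketch, the adaptive options are typically enabled when
constructing the optimizer. The toy loss function, model, and damping value
below are assumptions for illustration, not the example's actual code:

```python
import jax
import jax.numpy as jnp
import kfac_jax


def loss_fn(params, batch):
  # Toy squared-error loss on a linear model, registered so that K-FAC can
  # model its curvature. Purely illustrative.
  x, y = batch
  preds = x @ params["w"]
  kfac_jax.register_squared_error_loss(preds, y)
  return jnp.mean(jnp.sum((preds - y) ** 2, axis=-1))


optimizer = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(loss_fn),
    l2_reg=0.0,
    use_adaptive_learning_rate=True,  # learning rate chosen automatically
    use_adaptive_momentum=True,       # momentum chosen automatically
    use_adaptive_damping=True,        # damping adapted during training
    initial_damping=1.0,
    multi_device=False,
)
```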

## Classifier on MNIST

This example uses the K-FAC optimizer to perform deterministic (i.e. full batch)
training of a very small convolutional network for MNIST classification.
The default configuration uses the automatic learning rate, momentum, and
damping adaptation techniques from the original K-FAC paper.
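
For a classification loss like the one used here, the relevant registration
call is for softmax cross-entropy. A hypothetical sketch (the function name
and shapes are assumptions, not the example's code):

```python
import jax
import jax.numpy as jnp
import kfac_jax


def classification_loss(logits, labels):
  # Register the softmax cross-entropy loss so that the K-FAC curvature
  # estimator recognises it, then return its mean value.
  kfac_jax.register_softmax_cross_entropy_loss(logits, labels)
  log_probs = jax.nn.log_softmax(logits)
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1))
```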

## Resnet50 on ImageNet

This example uses the K-FAC optimizer to perform stochastic training (with
fixed batch size) of a Resnet50 network for ImageNet classification.
The default configuration uses the automatic damping adaptation technique from
the original K-FAC paper.
The momentum is fixed at `0.9` and the learning rate follows an ad-hoc schedule.
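
As an illustration of what such an ad-hoc schedule looks like as a function of
the global step, here is a hypothetical stepwise decay; the boundaries and
values below are made up for the sketch and are not the ones used by the
example:

```python
import jax.numpy as jnp


def stepwise_learning_rate(global_step):
  # Hypothetical stepwise decay: drop the learning rate by 10x at each
  # boundary. Boundaries and values are illustrative only.
  boundaries = jnp.array([30_000, 60_000, 80_000])
  values = jnp.array([1e-1, 1e-2, 1e-3, 1e-4])
  return values[jnp.sum(global_step >= boundaries)]
```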


## Resnet101 with TAT on ImageNet

This example uses the K-FAC optimizer to perform stochastic training (with
fixed batch size) of a Resnet101 network for ImageNet classification,
with no residual connections or normalization layers, following the
[TAT paper].
The default configuration uses a fixed damping of `0.001`.
The momentum is fixed at `0.9` and the learning rate follows a cosine decay
schedule.

[TAT paper]: https://arxiv.org/abs/2203.08120
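
A cosine decay schedule with warmup can be expressed, for example, with
`optax`. The peak value below matches the `initial_learning_rate=3e-4` from
the config diff in this pull request, while the step counts are assumptions
for illustration:

```python
import optax

steps_per_epoch = 10_000  # assumption for illustration

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=3e-4,                   # initial_learning_rate in the config
    warmup_steps=5 * steps_per_epoch,  # warmup_epochs=5 in the config
    decay_steps=90 * steps_per_epoch,  # total training steps (assumed)
)

# schedule(step) returns the learning rate for a given global step.
```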
1 change: 1 addition & 0 deletions examples/autoencoder_mnist/pipeline.py
@@ -38,6 +38,7 @@ def get_config() -> config_dict.ConfigDict:
config.checkpoint_dir = "/tmp/kfac_jax_jaxline/"
config.train_checkpoint_all_hosts = False

# Experiment config.
config.experiment_kwargs = config_dict.ConfigDict(
dict(
config=dict(
2 changes: 1 addition & 1 deletion examples/lrelunet101_imagenet/pipeline.py
@@ -77,7 +77,7 @@ def get_config() -> config_dict.ConfigDict:
use_adaptive_momentum=False,
use_adaptive_damping=False,
learning_rate_schedule=dict(
- initial_learning_rate=0.1,
+ initial_learning_rate=3e-4,
warmup_epochs=5,
name="cosine",
),
2 changes: 1 addition & 1 deletion examples/resnet50_imagenet/pipeline.py
@@ -42,7 +42,7 @@ def get_config() -> config_dict.ConfigDict:
config.experiment_kwargs = config_dict.ConfigDict(
dict(
config=dict(
- l2_reg=0.0,
+ l2_reg=1e-5,
training=dict(
steps=200_000,
epochs=None,