Commit 7615ee7

botev authored and KfacJaxDev committed
* Adding README for the examples.
* Changing some default config settings for examples.

PiperOrigin-RevId: 439402725
1 parent f1f58b1 · commit 7615ee7

File tree

4 files changed (+59 -2 lines):
* examples/README.md
* examples/autoencoder_mnist/pipeline.py
* examples/lrelunet101_imagenet/pipeline.py
* examples/resnet50_imagenet/pipeline.py

examples/README.md
+56

@@ -0,0 +1,56 @@
# KFAC-JAX Examples

This folder contains code with common functionality shared by all of the
examples, as well as the individual example subfolders.
Each example has the following structure:

* `experiment.py` contains the model definition, loss definition, and the
  pipeline experiment class.
* `pipeline.py` contains the hyper-parameter configuration (a minimal
  sketch of this pattern is shown after this list).
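The `get_config()` entry point in each `pipeline.py` builds an `ml_collections` `ConfigDict`, as the diffs later in this commit show. A minimal sketch of that pattern, using only keys that appear in those diffs (real configs carry many more options):

```python
# Minimal sketch of the get_config() pattern used by the example
# pipelines. The keys shown here are copied from the diffs in this
# commit; real configs carry many more options.
from ml_collections import config_dict


def get_config() -> config_dict.ConfigDict:
  config = config_dict.ConfigDict()
  config.checkpoint_dir = "/tmp/kfac_jax_jaxline/"
  config.train_checkpoint_all_hosts = False

  # Experiment config.
  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              l2_reg=1e-5,
              training=dict(
                  steps=200_000,
                  epochs=None,
              ),
          )
      )
  )
  return config
```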
To run the examples you will need to install additional dependencies via:

```shell
$ pip install -r examples/requirements.txt
```

To run an example, simply do:

```shell
$ python example_name/pipeline.py
```
## Autoencoder on MNIST

This example uses the K-FAC optimizer to perform deterministic (i.e. full-batch)
training of a deep autoencoder on MNIST.
The default configuration uses the automatic learning rate, momentum, and
damping adaptation techniques from the original K-FAC paper.
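These adaptation techniques correspond to `use_adaptive_*` flags on the optimizer, two of which are visible in the `lrelunet101_imagenet` diff below. A hedged sketch of constructing such an optimizer; the loss function is a hypothetical stand-in, and the exact `kfac_jax.Optimizer` signature should be checked against the library:

```python
# Sketch only: a K-FAC optimizer with the automatic learning rate,
# momentum, and damping adaptation enabled. The loss below is a
# hypothetical stand-in for a real model's loss.
import jax
import jax.numpy as jnp
import kfac_jax


def loss_fn(params, batch):
  preds = params["w"] * batch["x"]
  # K-FAC needs the loss registered to build its curvature estimate.
  kfac_jax.register_squared_error_loss(preds, batch["y"])
  return jnp.mean((preds - batch["y"]) ** 2)


optimizer = kfac_jax.Optimizer(
    value_and_grad_func=jax.value_and_grad(loss_fn),
    l2_reg=0.0,
    use_adaptive_learning_rate=True,
    use_adaptive_momentum=True,
    use_adaptive_damping=True,
    initial_damping=150.0,  # hypothetical starting value
)
```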
## Classifier on MNIST

This example uses the K-FAC optimizer to perform deterministic (i.e. full-batch)
training of a very small convolutional network for MNIST classification.
The default configuration uses the automatic learning rate, momentum, and
damping adaptation techniques from the original K-FAC paper.
## ResNet50 on ImageNet

This example uses the K-FAC optimizer to perform stochastic training (with a
fixed batch size) of a ResNet50 network for ImageNet classification.
The default configuration uses the automatic damping adaptation technique from
the original K-FAC paper.
The momentum is fixed at `0.9` and the learning rate follows an ad-hoc schedule.
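When the adaptive techniques are disabled, the learning rate and momentum are instead supplied at each update. A rough sketch of a training step under that convention, assuming an `optimizer` built as in the earlier sketch and a user-defined `schedule`; the `step()` argument names are assumptions to verify against `kfac_jax`:

```python
# Sketch: one update with a fixed momentum and an externally computed
# learning rate, as the ResNet50 config implies. The step() argument
# names are assumptions; verify against kfac_jax. This is the case
# where the loss function carries no extra state.
def train_step(optimizer, params, opt_state, rng, batch, global_step, schedule):
  params, opt_state, stats = optimizer.step(
      params,
      opt_state,
      rng,
      batch=batch,
      learning_rate=schedule(global_step),  # ad-hoc schedule
      momentum=0.9,                         # fixed value from the config
  )
  return params, opt_state, stats
```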
## ResNet101 with TAT on ImageNet

This example uses the K-FAC optimizer to perform stochastic training (with a
fixed batch size) of a ResNet101 network for ImageNet classification,
with no residual connections or normalization layers, as in the
[TAT paper].
The default configuration uses a fixed damping of `0.001`.
The momentum is fixed at `0.9` and the learning rate follows a cosine decay
schedule.

[TAT paper]: https://arxiv.org/abs/2203.08120

examples/autoencoder_mnist/pipeline.py
+1

@@ -38,6 +38,7 @@ def get_config() -> config_dict.ConfigDict:
   config.checkpoint_dir = "/tmp/kfac_jax_jaxline/"
   config.train_checkpoint_all_hosts = False
 
+  # Experiment config.
   config.experiment_kwargs = config_dict.ConfigDict(
       dict(
           config=dict(

examples/lrelunet101_imagenet/pipeline.py
+1 -1

@@ -77,7 +77,7 @@ def get_config() -> config_dict.ConfigDict:
             use_adaptive_momentum=False,
             use_adaptive_damping=False,
             learning_rate_schedule=dict(
-                initial_learning_rate=0.1,
+                initial_learning_rate=3e-4,
                 warmup_epochs=5,
                 name="cosine",
             ),
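The `name="cosine"` schedule with warmup in this config corresponds to a standard warmup-then-cosine-decay curve. A sketch using `optax.warmup_cosine_decay_schedule` purely for illustration; optax and the step counts are assumptions, and the example may implement its schedule differently:

```python
# Sketch: a warmup + cosine decay learning-rate schedule matching the
# config above (peak 3e-4, 5 warmup epochs). steps_per_epoch and the
# total training length are hypothetical values.
import optax

steps_per_epoch = 10_000  # hypothetical
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=3e-4,
    warmup_steps=5 * steps_per_epoch,
    decay_steps=90 * steps_per_epoch,  # hypothetical total length
)
lr = schedule(1_000)  # maps a step index to a learning rate
```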

examples/resnet50_imagenet/pipeline.py
+1 -1

@@ -42,7 +42,7 @@ def get_config() -> config_dict.ConfigDict:
   config.experiment_kwargs = config_dict.ConfigDict(
       dict(
           config=dict(
-              l2_reg=0.0,
+              l2_reg=1e-5,
               training=dict(
                   steps=200_000,
                   epochs=None,
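The `l2_reg` value controls a standard L2 penalty on the parameters. A sketch of the usual form of that term for a JAX pytree of parameters; whether `kfac_jax` applies exactly this factor of `0.5` is an assumption to check against the library:

```python
# Sketch: the L2 regularization term that an l2_reg setting of this
# kind typically controls, written for a JAX pytree of parameters.
import jax
import jax.numpy as jnp


def l2_penalty(params, l2_reg=1e-5):
  # 0.5 * l2_reg * (sum of squared entries over all parameter leaves).
  leaves = jax.tree_util.tree_leaves(params)
  return 0.5 * l2_reg * sum(jnp.sum(jnp.square(p)) for p in leaves)
```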
