
Commit 58825af

botevKfacJaxDev authored and KfacJaxDev committed

Adding README for the examples.

PiperOrigin-RevId: 439331640

1 parent f1f58b1 commit 58825af

File tree

4 files changed: +57 −2 lines changed


examples/README.md (new file, +54)

# KFAC-JAX Examples

To run the examples you will need to install additional dependencies:

```shell
$ pip install -r examples/requirements.txt
```

This folder contains code with common functionality used in all examples.
Each example follows the same structure:

* `experiment.py` contains the example-specific code, including the model
  definition, the loss definition, and the pipeline experiment class.
* `pipeline.py` contains the example-specific hyper-parameter configuration
  (see the sketch after the run command below).

To run an example simply do:

```shell
$ python ${example_name}/pipeline.py
```
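
For orientation, here is a minimal sketch of what a `pipeline.py`
`get_config` can look like. It is modeled on the `ml_collections`
`config_dict` usage visible in the diffs further down; the overall structure
is hypothetical, and only `checkpoint_dir`, `train_checkpoint_all_hosts`,
`experiment_kwargs`, and `l2_reg` are taken from those diffs:

```python
# A minimal, hypothetical sketch of an example's get_config(), modeled on the
# config_dict usage in the pipeline.py diffs in this commit.
from ml_collections import config_dict


def get_config() -> config_dict.ConfigDict:
  config = config_dict.ConfigDict()

  # Checkpointing settings (taken from the autoencoder_mnist diff).
  config.checkpoint_dir = "/tmp/kfac_jax_jaxline/"
  config.train_checkpoint_all_hosts = False

  # Experiment config.
  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              l2_reg=1e-5,  # value taken from the resnet50_imagenet diff
          )
      )
  )
  return config
```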

## Autoencoder on MNIST

This example demonstrates how to use the optimizer on a deterministic
autoencoder on the MNIST dataset.
The default configuration uses the automatic learning rate, momentum, and
damping adaptations.
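
A rough sketch of the corresponding optimizer settings, assuming the
`use_adaptive_*` flag names from the `lrelunet101_imagenet` diff below also
cover the learning rate (that flag name is an assumption):

```python
# Hypothetical optimizer settings with all three adaptations enabled.
# use_adaptive_momentum and use_adaptive_damping appear in the
# lrelunet101_imagenet diff; use_adaptive_learning_rate is assumed by analogy.
optimizer = dict(
    use_adaptive_learning_rate=True,
    use_adaptive_momentum=True,
    use_adaptive_damping=True,
)
```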

## Classifier on MNIST

This example demonstrates how to use the optimizer on a very small
convolutional classifier on the MNIST dataset.
The default configuration uses the automatic learning rate, momentum, and
damping adaptations.

## Resnet50 on ImageNet

This example demonstrates how to use the optimizer on Resnet50 on the
ImageNet dataset.
Because it is infeasible to run this problem with very large batch sizes, the
default configuration adapts only the damping.
The momentum is fixed at `0.9` and the learning rate follows an ad-hoc
schedule.
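
A matching sketch under the same naming assumptions, with only the damping
adapted; how the fixed `0.9` momentum and the ad-hoc schedule are specified
is not shown in this commit:

```python
# Hypothetical optimizer settings for the described Resnet50 defaults: adapt
# only the damping. The momentum (fixed at 0.9) and the ad-hoc learning rate
# schedule are configured elsewhere and are not shown in this commit.
optimizer = dict(
    use_adaptive_learning_rate=False,  # assumed flag name
    use_adaptive_momentum=False,
    use_adaptive_damping=True,
)
```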

## Resnet101 with TAT on ImageNet

This example demonstrates how to use the optimizer on Resnet101 on the
ImageNet dataset, with no residual connections or normalization layers, as in
the [TAT paper].
The damping is fixed at `0.001`, the momentum at `0.9`, and we use a cosine
learning rate schedule.

[TAT paper]: https://arxiv.org/abs/2203.08120
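
A sketch of these defaults; the `learning_rate_schedule` dict is copied from
the `lrelunet101_imagenet` diff below, while the momentum and damping field
names are illustrative assumptions:

```python
# Sketch of the described TAT defaults. The learning_rate_schedule dict is
# taken from the lrelunet101_imagenet diff below; momentum and initial_damping
# are assumed field names for the fixed values quoted above.
optimizer = dict(
    use_adaptive_momentum=False,
    use_adaptive_damping=False,
    momentum=0.9,          # fixed momentum (assumed field name)
    initial_damping=1e-3,  # damping fixed at 0.001 (assumed field name)
    learning_rate_schedule=dict(
        initial_learning_rate=3e-4,
        warmup_epochs=5,
        name="cosine",
    ),
)
```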

examples/autoencoder_mnist/pipeline.py (+1)

```diff
@@ -38,6 +38,7 @@ def get_config() -> config_dict.ConfigDict:
   config.checkpoint_dir = "/tmp/kfac_jax_jaxline/"
   config.train_checkpoint_all_hosts = False
 
+  # Experiment config.
   config.experiment_kwargs = config_dict.ConfigDict(
       dict(
           config=dict(
```

examples/lrelunet101_imagenet/pipeline.py (+1 −1)

```diff
@@ -77,7 +77,7 @@ def get_config() -> config_dict.ConfigDict:
               use_adaptive_momentum=False,
               use_adaptive_damping=False,
               learning_rate_schedule=dict(
-                  initial_learning_rate=0.1,
+                  initial_learning_rate=3e-4,
                   warmup_epochs=5,
                   name="cosine",
               ),
```

examples/resnet50_imagenet/pipeline.py (+1 −1)

```diff
@@ -42,7 +42,7 @@ def get_config() -> config_dict.ConfigDict:
   config.experiment_kwargs = config_dict.ConfigDict(
       dict(
           config=dict(
-              l2_reg=0.0,
+              l2_reg=1e-5,
               training=dict(
                   steps=200_000,
                   epochs=None,
```
