Skip to content

Commit e3da816

Browse files
james-martensKfacJaxDev
authored and
KfacJaxDev
committed
- Adding feature to distribute preconditioner application and computation operations across devices.
- Device syncing of factors in TwoKroneckerFactored and Full blocks now happens after every call to update_curvature_matrix_estimate. This fixes a bug which caused devices to possibly desync when using damping adaptation with the approximate quadratic model. However, device communication will now happen more often, which may degrade performance. - Changing a bunch of methods to be pure functions that don't modify their input state trees. This shouldn't change any behavior and is just a code quality improvement. - Using jax.pure_callback instead of host_callback.call to better support TPUs. - Adding newlines between lines of code to improve readability. PiperOrigin-RevId: 476188081
1 parent 9eb5917 commit e3da816

8 files changed

+761
-80
lines changed

kfac_jax/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from kfac_jax._src import utils
2525

2626

27-
__version__ = "0.0.2"
27+
__version__ = "0.0.3"
2828

2929
# Patches Second Moments
3030
patches_moments = patches_second_moment.patches_moments

0 commit comments

Comments
 (0)