Skip to content

Commit 85f0f4b

Browse files
yashk2810Flax Authors
authored andcommitted
Use .format in place of .layout.
JAX is undergoing a rename of the contents of jax.experimental.layouts in preparation for its graduation from experimental: "Formats" are replacing "layouts", and specifically Layout -> Format "Layouts" are replacing "device local layouts", and specifically DeviceLocalLayout -> Layout This is an incremental update carrying out #1. PiperOrigin-RevId: 773116112
1 parent 21e64ec commit 85f0f4b

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

flax/nnx/transforms/compilation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ def output_shardings(self): # PyTree[sharding.Sharding]
599599

600600
@property
601601
def input_layouts(self):
602-
return self.compiled.input_layouts
602+
return self.compiled.input_formats
603603

604604

605605
@dataclasses.dataclass(frozen=True, slots=True)

0 commit comments

Comments
 (0)