Skip to content

Commit 8404b16

Browse files
yashk2810Flax Authors
authored andcommitted
Use .input_formats and .output_formats in place of .input_layouts and .output_layouts respectively.
JAX is undergoing a rename of the contents of jax.experimental.layouts in preparation for its graduation from experimental: "Formats" are replacing "layouts", and specifically Layout -> Format "Layouts" are replacing "device local layouts", and specifically DeviceLocalLayout -> Layout This is an incremental update carrying out #1. PiperOrigin-RevId: 773109310
1 parent 21e64ec commit 8404b16

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

flax/nnx/transforms/compilation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ def output_shardings(self): # PyTree[sharding.Sharding]
599599

600600
@property
601601
def input_layouts(self):
602-
return self.compiled.input_layouts
602+
return self.compiled.input_formats
603603

604604

605605
@dataclasses.dataclass(frozen=True, slots=True)

0 commit comments

Comments
 (0)