Replies: 1 comment
-
Through some trial and error, I may have found a solution. My new Flax code:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn
from einops import rearrange


def make_initializer(out_channels, in_channels, kernel_size, groups):
    # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
    k = groups / (in_channels * jnp.prod(jnp.array(kernel_size)))
    scale = jnp.sqrt(k)

    def init_fn(key, shape, dtype):
        return jax.random.uniform(key, shape, minval=-scale, maxval=scale, dtype=dtype)

    return init_fn


class CustomConv1d(nn.Conv):
    @nn.compact
    def __call__(self, x):
        # note: we just ignore whatever self.kernel_init is
        kernel_init = make_initializer(
            self.features, x.shape[-1], self.kernel_size, self.feature_group_count
        )
        if self.use_bias:
            # note: we just ignore whatever self.bias_init is
            bias_init = make_initializer(
                self.features, x.shape[-1], self.kernel_size, self.feature_group_count
            )
        else:
            bias_init = None
        return nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            input_dilation=self.input_dilation,
            kernel_dilation=self.kernel_dilation,
            feature_group_count=self.feature_group_count,
            use_bias=self.use_bias,
            mask=self.mask,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision,
            kernel_init=kernel_init,
            bias_init=bias_init,
        )(x)


class LeakyReLU(nn.Module):
    negative_slope: float = 0.01

    @nn.compact
    def __call__(self, x):
        return nn.leaky_relu(x, negative_slope=self.negative_slope)


def WNConv2d(scale_init, *args, **kwargs):
    conv = nn.WeightNorm(CustomConv1d(*args, **kwargs), scale_init=scale_init)
    return conv


class MPD(nn.Module):
    period: int

    def pad_to_period(self, x):
        t = x.shape[-1]
        x = jnp.pad(x, pad_width=((0, 0), (0, 0), (0, self.period - t % self.period)), mode='reflect')
        return x

    @nn.compact
    def __call__(self, x):
        convs = [
            WNConv2d(nn.initializers.constant(1 / jnp.sqrt(3)), features=32, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
            WNConv2d(nn.initializers.constant(1 / jnp.sqrt(3)), features=128, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
            WNConv2d(nn.initializers.constant(1 / jnp.sqrt(3)), features=512, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
            WNConv2d(nn.initializers.constant(1 / jnp.sqrt(3)), features=1024, kernel_size=(5, 1), strides=(3, 1), padding=((2, 2), (0, 0))),
            WNConv2d(nn.initializers.constant(1 / jnp.sqrt(3)), features=1024, kernel_size=(5, 1), strides=(1, 1), padding=((2, 2), (0, 0))),
            WNConv2d(nn.initializers.constant(1 / jnp.sqrt(3)), features=1, kernel_size=(3, 1), strides=(1, 1), padding=((1, 1), (0, 0))),
        ]
        fmap = []

        x = self.pad_to_period(x)
        x = rearrange(x, "b c (l p) -> b l p c", p=self.period)

        for i, layer in enumerate(convs):
            x = layer(x)
            if i != (len(convs) - 1):
                x = LeakyReLU(negative_slope=0.1)(x)
            fmap.append(x)

        return fmap


def summary_stats(name, x):
    print(f'Stats for {name}:')
    print('shape:', list(x.shape))
    print(f'mean: {jnp.mean(x):,.5f} min: {jnp.min(x):,.5f} max: {jnp.max(x):,.5f} std: {jnp.std(x):,.5f}')


key = jax.random.PRNGKey(1)
B, C, T = 1, 1, 44100
x = jnp.zeros((B, C, T))
period = 2
model = MPD(period)
fmaps, variables = model.init_with_output({"params": key}, x)

# Print summary stats for each feature map
for i, fmap in enumerate(fmaps):
    summary_stats(f"fmap {i}", fmap)
    print()

params = variables["params"]
for i in range(6):
    params[f"WeightNorm_{i}"][f"CustomConv1d_{i}/Conv_0/kernel/scale"]
    params[f"CustomConv1d_{i}"]["Conv_0"]["bias"]
    params[f"CustomConv1d_{i}"]["Conv_0"]["kernel"]

print(model.tabulate({"params": key}, x, console_kwargs={"width": 400}))
```

New output:
And another randomly sampled PyTorch output:
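As a quick sanity check (a minimal sketch added here for illustration, assuming the `make_initializer` from the code above is in scope and `torch` is installed), one can draw a kernel from the custom initializer and compare its spread with a freshly constructed `torch.nn.Conv2d`, whose weights default to U(-sqrt(k), sqrt(k)) with k = groups / (C_in * prod(kernel_size)):

```python
# Sketch only: compare the custom Flax initializer against PyTorch's default
# Conv2d init for the first MPD layer (1 input channel, 32 features, (5, 1) kernel).
import torch

in_ch, out_ch, ksize = 1, 32, (5, 1)

# PyTorch side: the default kaiming_uniform_(a=sqrt(5)) reduces to
# U(-sqrt(k), sqrt(k)) with k = groups / (in_ch * prod(ksize)).
w_torch = torch.nn.Conv2d(in_ch, out_ch, ksize).weight.detach().numpy()

# Flax side: sample a kernel of the shape nn.Conv would use, (*ksize, in_ch, out_ch).
init_fn = make_initializer(out_ch, in_ch, ksize, groups=1)
w_flax = init_fn(jax.random.PRNGKey(0), (*ksize, in_ch, out_ch), jnp.float32)

print("torch min/max/std:", w_torch.min(), w_torch.max(), w_torch.std())
print("flax  min/max/std:", float(w_flax.min()), float(w_flax.max()), float(w_flax.std()))
```

If the initializers really do agree, both sides should report min/max of roughly ±0.447 (= 1/sqrt(5)) and a std near 0.26 for this layer.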
-
I'm trying to port some PyTorch code to Flax. The model involves Conv2d layers wrapped with weight_norm, with LeakyReLU activations after every layer except the last. I've confirmed that the parameter counts and input/output shapes match between PyTorch and Flax, yet the mean/min/max/std of the outputs seem off. Can someone help me identify what went wrong in the port? I suspect the issue is related to weight initialization (see #4091).
Here's the PyTorch code:
and PyTorch output:
Here's the Flax code:
and the Flax output:
To me, the most glaring differences in the outputs are the `max:` values, even when changing JAX seeds. Again, here's the PyTorch output: and the Flax output:
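One way to localize the mismatch (a sketch added for illustration, not from the original post) is to flatten both parameter trees and compare per-parameter summary statistics layer by layer; here `params` is the Flax params dict from `model.init`, and `torch_model` is a hypothetical handle to the equivalent PyTorch module:

```python
# Sketch only: dump per-parameter stats on both sides and diff them by eye.
import numpy as np
from flax import traverse_util

# Flax side: `params` as returned by model.init / init_with_output.
for name, p in traverse_util.flatten_dict(params, sep="/").items():
    p = np.asarray(p)
    print(f"flax  {name}: shape={p.shape} min={p.min():.5f} max={p.max():.5f} std={p.std():.5f}")

# PyTorch side: `torch_model` is a hypothetical name for the equivalent nn.Module.
for name, p in torch_model.named_parameters():
    p = p.detach().numpy()
    print(f"torch {name}: shape={tuple(p.shape)} min={p.min():.5f} max={p.max():.5f} std={p.std():.5f}")
```

If every parameter's distribution lines up but the output stats still differ, the discrepancy is more likely in the forward pass (padding, reshaping, or activation slopes) than in the initialization.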