Skip to content

Commit 7bd762e

Browse files
committed
Bugfixing
1 parent ca0d13a commit 7bd762e

File tree

9 files changed

+50
-70
lines changed

9 files changed

+50
-70
lines changed

fflows/filters/spectral.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,9 @@ def __init__(self, d, k, FFT, hidden, flip=False, RNN=False):
5050

5151
self.d, self.k = d, k
5252

53-
if FFT:
54-
55-
self.out_size = self.d - self.k + 1
56-
self.pz_size = self.d + 1
57-
self.in_size = self.k
58-
59-
else:
60-
61-
self.out_size = self.d - self.k
62-
self.pz_size = self.d
63-
self.in_size = self.k
53+
self.out_size = self.d - self.k
54+
self.pz_size = self.d
55+
self.in_size = self.k
6456

6557
if flip:
6658

fflows/fourier/transforms.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77
Used as the "Fourier Transform" step in a Fourier Flow, can be preceded by other torch modules.
88
99
"""
10-
11-
from __future__ import absolute_import, division, print_function
12-
1310
import sys
1411
import warnings
1512

@@ -60,10 +57,10 @@ def reconstruct_DFT(x, component="real"):
6057
"""
6158

6259
if component == "real":
63-
x_rec = torch.cat([x[0, :], flip(x[0, :], dim=0)[1:]], dim=0)
60+
x_rec = torch.cat([x[0, :], flip(x[0, :], dim=0)], dim=0)
6461

6562
elif component == "imag":
66-
x_rec = torch.cat([x[1, :], -1 * flip(x[1, :], dim=0)[1:]], dim=0)
63+
x_rec = torch.cat([x[1, :], -1 * flip(x[1, :], dim=0)], dim=0)
6764

6865
return x_rec
6966

@@ -93,7 +90,7 @@ def __init__(self, N_fft=100):
9390
super(DFT, self).__init__()
9491

9592
self.N_fft = N_fft
96-
self.crop_size = int(self.N_fft / 2) + 1
93+
self.crop_size = int(np.ceil(self.N_fft / 2))
9794
base_mu, base_cov = torch.zeros(self.crop_size * 2), torch.eye(
9895
self.crop_size * 2
9996
)

fflows/sequential_flows.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, hidden, n_flows, FFT=True, flip=True, normalize=False):
2626

2727
self.FFT = FFT
2828
self.normalize = normalize
29-
self.n_flows =n_flows
29+
self.n_flows = n_flows
3030
self.hidden = hidden
3131

3232
if flip:
@@ -42,7 +42,7 @@ def forward(self, x):
4242
if self.normalize:
4343
x = (x - self.fft_mean) / (self.fft_std + 1e-8)
4444

45-
x = x.view(-1, self.d + 1)
45+
x = x.view(-1, self.d)
4646

4747
log_jacobs = []
4848

@@ -55,17 +55,12 @@ def forward(self, x):
5555
return x, log_pz, sum(log_jacobs)
5656

5757
def inverse(self, z):
58-
5958
for bijector, f in zip(reversed(self.bijectors), reversed(self.flips)):
60-
6159
z = bijector.inverse(z, flip=f)
6260

6361
if self.FFT:
64-
6562
if self.normalize:
66-
z = z * self.fft_std.view(-1, self.d + 1) + self.fft_mean.view(
67-
-1, self.d + 1
68-
)
63+
z = z * self.fft_std.view(-1, self.d) + self.fft_mean.view(-1, self.d)
6964

7065
z = self.FourierTransform.inverse(z)
7166

@@ -74,10 +69,18 @@ def inverse(self, z):
7469
def fit(self, X, epochs=500, batch_size=128, learning_rate=1e-3, display_step=100):
7570
X_train = torch.from_numpy(np.array(X)).float()
7671

72+
self.carry_flag = False
73+
if np.prod(X_train.shape[1:]) % 2 == 1:
74+
repeat_last = X_train[:, :, -1:]
75+
X_train = torch.concat([X_train, repeat_last], dim=2)
76+
self.carry_flag = True
77+
7778
self.individual_shape = X_train.shape[1:]
7879

7980
self.d = np.prod(self.individual_shape)
80-
self.k = int(np.floor(self.d / 2))
81+
self.k = int(np.ceil(self.d / 2))
82+
83+
assert self.d % 2 == 0
8184

8285
# Prepare models
8386
self.bijectors = nn.ModuleList(
@@ -95,6 +98,8 @@ def fit(self, X, epochs=500, batch_size=128, learning_rate=1e-3, display_step=10
9598

9699
# for normalizing the spectral transforms
97100
X_train_spectral = self.FourierTransform(X_train)[0]
101+
assert X_train_spectral.shape[-1] == self.k
102+
98103
self.fft_mean = torch.mean(X_train_spectral, dim=0)
99104
self.fft_std = torch.std(X_train_spectral, dim=0)
100105
optim = torch.optim.Adam(self.parameters(), lr=learning_rate)
@@ -132,21 +137,18 @@ def fit(self, X, epochs=500, batch_size=128, learning_rate=1e-3, display_step=10
132137
return losses
133138

134139
def sample(self, n_samples):
135-
136-
if self.FFT:
137-
138-
mu, cov = torch.zeros(self.d + 1), torch.eye(self.d + 1)
139-
140-
else:
141-
142-
mu, cov = torch.zeros(self.d), torch.eye(self.d)
140+
mu, cov = torch.zeros(self.d), torch.eye(self.d)
143141

144142
p_Z = MultivariateNormal(mu, cov)
145143
z = p_Z.rsample(sample_shape=(n_samples,))
146144

147145
X_sample = self.inverse(z)
146+
X_sample = X_sample.reshape(-1, *self.individual_shape)
148147

149-
return X_sample.reshape(-1, *self.individual_shape)
148+
if self.carry_flag:
149+
X_sample = X_sample[:, :, :-1]
150+
151+
return X_sample
150152

151153

152154
class RealNVP(nn.Module):
@@ -173,7 +175,7 @@ def forward(self, x):
173175
if self.normalize:
174176
x = (x - self.fft_mean) / (self.fft_std + 1e-8)
175177

176-
x = x.view(-1, self.d + 1)
178+
x = x.view(-1, self.d)
177179

178180
log_jacobs = []
179181

@@ -194,9 +196,7 @@ def inverse(self, z):
194196
if self.FFT:
195197

196198
if self.normalize:
197-
z = z * self.fft_std.view(-1, self.d + 1) + self.fft_mean.view(
198-
-1, self.d + 1
199-
)
199+
z = z * self.fft_std.view(-1, self.d) + self.fft_mean.view(-1, self.d1)
200200

201201
z = self.FourierTransform.inverse(z)
202202

@@ -259,7 +259,7 @@ def sample(self, n_samples):
259259

260260
if self.FFT:
261261

262-
mu, cov = torch.zeros(self.d + 1), torch.eye(self.d + 1)
262+
mu, cov = torch.zeros(self.d), torch.eye(self.d)
263263

264264
else:
265265

@@ -309,7 +309,7 @@ def forward(self, x):
309309
if self.normalize:
310310
x = (x - self.fft_mean) / self.fft_std
311311

312-
x = x.view(-1, self.d + 1)
312+
x = x.view(-1, self.d)
313313

314314
log_jacobs = []
315315

@@ -330,9 +330,7 @@ def inverse(self, z):
330330
if self.FFT:
331331

332332
if self.normalize:
333-
z = z * self.fft_std.view(-1, self.d + 1) + self.fft_mean.view(
334-
-1, self.d + 1
335-
)
333+
z = z * self.fft_std.view(-1, self.d) + self.fft_mean.view(-1, self.d)
336334

337335
z = self.FourierTransform.inverse(z)
338336

@@ -384,7 +382,7 @@ def sample(self, n_samples):
384382

385383
if self.FFT:
386384

387-
mu, cov = torch.zeros(self.d + 1), torch.eye(self.d + 1)
385+
mu, cov = torch.zeros(self.d), torch.eye(self.d)
388386

389387
else:
390388

fflows/utils/data_padding.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
# Copyright (c) 2020, Ahmed M. Alaa
22
# Licensed under the BSD 3-clause license (see LICENSE.txt)
3-
4-
from __future__ import absolute_import, division, print_function
5-
63
import sys
74

85
import numpy as np

fflows/utils/make_data.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
# Copyright (c) 2020, Ahmed M. Alaa
22
# Licensed under the BSD 3-clause license (see LICENSE.txt)
3-
4-
from __future__ import absolute_import, division, print_function
5-
63
import sys
74

85
import numpy as np
@@ -38,7 +35,7 @@ def create_autoregressive_data(
3835
# Create the input features
3936

4037
X = [np.random.normal(X_m, X_v, (seq_len, n_features)) for k in range(n_samples)]
41-
w = np.array([memory_factor ** k for k in range(seq_len)])
38+
w = np.array([memory_factor**k for k in range(seq_len)])
4239

4340
if mode == "noise-sweep":
4441

fflows/utils/spectral.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
# Copyright (c) 2021, Ahmed M. Alaa
22
# Licensed under the BSD 3-clause license (see LICENSE.txt)
3-
4-
5-
from __future__ import absolute_import, division, print_function
6-
73
import numpy as np
84

95

fflows/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.0.2"
1+
__version__ = "0.0.3"

tests/test_fourier_flows.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,21 @@
44
from data import sine_data_generation
55

66

7-
@pytest.mark.parametrize("normalize", [True, False])
7+
@pytest.mark.parametrize("normalize", [False, True])
88
@pytest.mark.parametrize("FFT", [True, False])
99
@pytest.mark.parametrize("flip", [True, False])
10-
def test_training(normalize: bool, FFT: bool, flip: bool) -> None:
11-
T = 11
10+
@pytest.mark.parametrize("T", [12, 11])
11+
@pytest.mark.parametrize("dims", [4, 6, 3])
12+
@pytest.mark.parametrize("n_flows", [5, 10, 15])
13+
def test_training(
14+
normalize: bool, FFT: bool, flip: bool, T: int, dims: int, n_flows: int
15+
) -> None:
1216
n_samples = 100
13-
dims = 3
14-
X = sine_data_generation(no=n_samples, seq_len=T, dim=3)
17+
X = sine_data_generation(no=n_samples, seq_len=T, dim=dims)
1518

1619
ff_params = {
1720
"hidden": 11,
18-
"n_flows": 11,
21+
"n_flows": n_flows,
1922
"normalize": normalize,
2023
"FFT": FFT,
2124
"flip": flip,
@@ -31,10 +34,10 @@ def test_training(normalize: bool, FFT: bool, flip: bool) -> None:
3134
_ = model.fit(X, **train_params)
3235

3336

34-
def test_generation() -> None:
35-
T = 11
37+
@pytest.mark.parametrize("T", [12, 11])
38+
@pytest.mark.parametrize("dims", [4, 3])
39+
def test_generation(T: int, dims: int) -> None:
3640
n_samples = 100
37-
dims = 3
3841
X = sine_data_generation(no=n_samples, seq_len=T, dim=dims)
3942

4043
ff_params = {"hidden": 11, "n_flows": 11, "normalize": False}

tests/test_real_nvp.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66

77
@pytest.mark.parametrize("normalize", [True, False])
8-
def test_training(normalize: bool) -> None:
9-
T = 11
8+
@pytest.mark.parametrize("T", [11, 12])
9+
@pytest.mark.parametrize("dims", [3, 4, 6])
10+
def test_training(normalize: bool, T: int, dims: int) -> None:
1011
n_samples = 100
11-
dims = 3
12-
X = sine_data_generation(no=n_samples, seq_len=T, dim=3)
12+
X = sine_data_generation(no=n_samples, seq_len=T, dim=dims)
1313

1414
model_params = {"hidden": 11, "n_flows": 11, "normalize": normalize}
1515
train_params = {

0 commit comments

Comments
 (0)