Skip to content

Commit b358c2f

Browse files
committed
test: fix error in test
1 parent 4d9744d commit b358c2f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/vi.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,17 +89,17 @@ def test_vi_convergence(self):
8989
dim = 2
9090

9191
text_full = []
92-
for i in range(100):
92+
for i in range(10):
9393
text_full += text
9494

9595
e0 = Embedding(vocabulary=vocabulary, dimensionality=dim)
9696
init_mean = False
9797
init_std = 0.2
98-
q_mu0, q_std_e0 = mean_field_vi(e, text, model="sgns", evaluate=False, ws=ws, batch_size=batch_size, init_mean=init_mean, init_std=init_std, epochs=5)
98+
q_mu0, q_std_e0 = mean_field_vi(e0, text, model="sgns", evaluate=False, ws=ws, batch_size=batch_size, init_mean=init_mean, init_std=init_std, epochs=5)
9999

100100
e = Embedding(vocabulary=vocabulary, dimensionality=dim)
101101
q_mu, q_std_e = mean_field_vi(e, text_full, model="sgns", evaluate=False, ws=ws, batch_size=batch_size, init_mean=init_mean, init_std=init_std, epochs=5)
102-
self.assertGreater(np.mean(q_std_e.theta), np.mean(q_std_e0.theta))
102+
self.assertGreater(np.mean(q_std_e0.theta), np.mean(q_std_e.theta))
103103

104104

105105
if __name__ == '__main__':

0 commit comments

Comments
 (0)