
Commit f406c60

refactor: use TF2.16+ optimizer structure
1 parent 9cc5491 commit f406c60

2 files changed: +25 additions, -14 deletions

probabilistic_word_embeddings/estimation.py

Lines changed: 22 additions & 11 deletions
@@ -68,6 +68,10 @@ def map_estimate(embedding, data=None, ns_data=None, data_generator=None, N=None
         ns_data = data
 
     opt = tf.keras.optimizers.Adam(learning_rate=0.001)
+    opt_theta = opt.add_variable_from_reference(embedding.theta, "theta")#, initial_value=embedding.theta)
+    opt.build([opt_theta])
+    opt_theta.assign(embedding.theta)
+
     e = embedding
     if valid_data is not None:
         if not isinstance(valid_data, tf.Tensor):
@@ -107,13 +111,17 @@ def map_estimate(embedding, data=None, ns_data=None, data_generator=None, N=None
             i,j,x = next(data_generator)
         else:
             i,j,x = generate_batch(data, model=model, ws=ws, ns=ns, batch_size=batch_size, start_ix=start_ix, ns_data=ns_data)
-        if model == "sgns":
-            objective = lambda: - tf.reduce_sum(sgns_likelihood(e, i, j, x=x)) - e.log_prob(batch_size, N)
-        elif model == "cbow":
-            objective = lambda: - tf.reduce_sum(cbow_likelihood(e, i, j, x=x)) - e.log_prob(batch_size, N)
-        _ = opt.minimize(objective, [embedding.theta])
+        with tf.GradientTape() as tape:
+            if model == "sgns":
+                objective = - tf.reduce_sum(sgns_likelihood(e, i, j, x=x)) - e.log_prob(batch_size, N)
+            elif model == "cbow":
+                objective = - tf.reduce_sum(cbow_likelihood(e, i, j, x=x)) - e.log_prob(batch_size, N)
+        d_l_d_theta = - tape.gradient(objective, e.theta)
+
+        opt.update_step(d_l_d_theta, opt_theta, 0.001)
+        embedding.theta.assign(opt_theta)
         if training_loss:
-            epoch_training_loss.append(objective() / len(i))
+            epoch_training_loss.append(objective / len(i))
         batch_no = len(epoch_training_loss)
         if batch_no % 250 == 0:
             logger.log(logging.TRAIN, f"Epoch {epoch} mean training loss after {batch_no} batches: {np.mean(epoch_training_loss)}")
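
The map_estimate hunks above replace the closure-based opt.minimize(objective, [embedding.theta]) call with an explicit gradient-tape loop, since Keras 3 (bundled with TensorFlow 2.16+) drives manual updates through optimizer-owned variables. Below is a minimal standalone sketch of that pattern; the variable theta and the quadratic loss are illustrative placeholders, not the library's own code.

    import tensorflow as tf

    theta = tf.Variable(tf.random.normal([100, 25]))  # placeholder parameter

    opt = tf.keras.optimizers.Adam(learning_rate=0.001)
    # Register an optimizer-owned copy of the variable, build the Adam
    # moment slots for it, then copy the current values in.
    opt_theta = opt.add_variable_from_reference(theta, "theta")
    opt.build([opt_theta])
    opt_theta.assign(theta)

    for _ in range(100):
        with tf.GradientTape() as tape:
            loss = tf.reduce_sum(theta ** 2)  # stand-in objective
        grad = tape.gradient(loss, theta)
        # In Keras 3, update_step takes the learning rate explicitly.
        opt.update_step(grad, opt_theta, 0.001)
        theta.assign(opt_theta)  # write the update back to the model variable

Note that calling update_step directly skips the bookkeeping apply_gradients would normally do, such as incrementing optimizer.iterations, which Adam's bias correction reads.
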
@@ -155,7 +163,7 @@ def mean_field_vi(embedding, data=None, data_generator=None, N=None, model="cbow
     if model not in ["sgns", "cbow"]:
         raise ValueError("model must be 'sgns' or 'cbow'")
 
-    optimizer = tf.keras.optimizers.experimental.Adam(learning_rate=0.001)
+    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
     e = embedding
 
     if words_to_fix_rotation:
@@ -177,10 +185,13 @@ def mean_field_vi(embedding, data=None, data_generator=None, N=None, model="cbow
     logger.info(f"Init std: {init_std}")
     q_std_log = tf.Variable(init_std)
 
-    opt_mean_var = optimizer.add_variable_from_reference(q_mean, "q_mean", initial_value=q_mean)
-    opt_std_var = optimizer.add_variable_from_reference(q_std_log, "q_std_log", initial_value=q_std_log)
+    opt_mean_var = optimizer.add_variable_from_reference(q_mean, "q_mean")
+    opt_std_var = optimizer.add_variable_from_reference(q_std_log, "q_std_log")
     optimizer.build([opt_mean_var, opt_std_var])
 
+    opt_mean_var.assign(q_mean)
+    opt_std_var.assign(q_std_log)
+
     elbos = []
     for epoch in range(epochs):
         logger.log(logging.TRAIN, f"Epoch {epoch}")
@@ -216,8 +227,8 @@ def mean_field_vi(embedding, data=None, data_generator=None, N=None, model="cbow
         # Add the entropy term
         d_l_q_std_log = d_l_q_std_log - tf.ones(d_l_q_std_log.shape, dtype=tf.float64)
 
-        optimizer.update_step(d_l_d_q_mean, opt_mean_var)
-        optimizer.update_step(d_l_q_std_log, opt_std_var)
+        optimizer.update_step(d_l_d_q_mean, opt_mean_var, 0.001)
+        optimizer.update_step(d_l_q_std_log, opt_std_var, 0.001)
 
 
         std_numerical_stability_constant = 10.0
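
The mean_field_vi hunks apply the same migration to two variational parameters at once and drop the retired tf.keras.optimizers.experimental namespace. Two Keras 3 API differences drive the edits: add_variable_from_reference no longer accepts initial_value (new optimizer variables start at zeros, so the starting values are copied in with assign after build), and update_step now takes the learning rate as an explicit third argument. A condensed sketch, with placeholder shapes and gradients:

    import tensorflow as tf

    q_mean = tf.Variable(tf.zeros([10, 4], dtype=tf.float64))
    q_std_log = tf.Variable(tf.zeros([10, 4], dtype=tf.float64))

    # Old: tf.keras.optimizers.experimental.Adam(learning_rate=0.001)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    # Old: add_variable_from_reference(q_mean, "q_mean", initial_value=q_mean)
    # New: no initial_value kwarg, so assign the starting values after build().
    opt_mean_var = optimizer.add_variable_from_reference(q_mean, "q_mean")
    opt_std_var = optimizer.add_variable_from_reference(q_std_log, "q_std_log")
    optimizer.build([opt_mean_var, opt_std_var])
    opt_mean_var.assign(q_mean)
    opt_std_var.assign(q_std_log)

    grad_mean = tf.ones_like(q_mean)  # placeholder gradients
    grad_std = tf.ones_like(q_std_log)
    # Old: optimizer.update_step(grad_mean, opt_mean_var)
    optimizer.update_step(grad_mean, opt_mean_var, 0.001)
    optimizer.update_step(grad_std, opt_std_var, 0.001)
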

pyproject.toml

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "probabilistic-word-embeddings"
-version = "1.17.1"
+version = "2.0.0rc1"
 description = "Probabilistic Word Embeddings for Python"
 authors = ["Your Name <[email protected]>"]
 license = "MIT"
@@ -9,8 +9,8 @@ documentation = "https://ninpnin.github.io/probabilistic-word-embeddings/"
 
 [tool.poetry.dependencies]
 python = "^3.7"
-tensorflow = "<= 2.15.1"
-tensorflow-probability = "<= 0.23"
+tensorflow = "^2.16.1"
+tensorflow-probability = "*"
 progressbar2 = "*"
 networkx = "*"
 pandas = "*"
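
The dependency bump is what gates the new API: the optimizer calls above exist only in Keras 3, which ships with TensorFlow 2.16+. A hypothetical runtime guard (not part of this commit) that a downstream user could add:

    import tensorflow as tf
    from packaging.version import Version

    # Hypothetical check, not part of the commit: the Keras 3 optimizer
    # API used in estimation.py requires TensorFlow >= 2.16.1.
    if Version(tf.__version__) < Version("2.16.1"):
        raise ImportError(
            "probabilistic-word-embeddings 2.0.0rc1 requires tensorflow>=2.16.1; "
            f"found {tf.__version__}"
        )
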
