Commit 252a92c

fix phylstm
1 parent a76f61e commit 252a92c

1 file changed: +48 -7 lines changed


ppsci/arch/phylstm.py

Lines changed: 48 additions & 7 deletions
@@ -145,21 +145,39 @@ def forward(self, x):
         return result_dict
 
     def _forward_type_2(self, x):
-        output = self.lstm_model(x["ag"])
+        # output = self.lstm_model(x["ag"])
+        output = x["ag"]
+        for layer in self.lstm_model:
+            output = layer(output)
+            if isinstance(output, tuple):
+                output = output[0]
+
         eta_pred = output[:, :, 0 : self.output_size]
         eta_dot_pred = output[:, :, self.output_size : 2 * self.output_size]
         g_pred = output[:, :, 2 * self.output_size :]
 
         # for ag_c
-        output_c = self.lstm_model(x["ag_c"])
+        # output_c = self.lstm_model(x["ag_c"])
+        output_c = x["ag_c"]
+        for layer in self.lstm_model:
+            output_c = layer(output_c)
+            if isinstance(output_c, tuple):
+                output_c = output_c[0]
+
         eta_pred_c = output_c[:, :, 0 : self.output_size]
         eta_dot_pred_c = output_c[:, :, self.output_size : 2 * self.output_size]
         g_pred_c = output_c[:, :, 2 * self.output_size :]
         eta_t_pred_c = paddle.matmul(x["phi"], eta_pred_c)
         eta_tt_pred_c = paddle.matmul(x["phi"], eta_dot_pred_c)
         eta_dot1_pred_c = eta_dot_pred_c[:, :, 0:1]
         tmp = paddle.concat([eta_pred_c, eta_dot1_pred_c, g_pred_c], 2)
-        f = self.lstm_model_f(tmp)
+        # f = self.lstm_model_f(tmp)
+        f = tmp
+        for layer in self.lstm_model_f:
+            f = layer(f)
+            if isinstance(f, tuple):
+                f = f[0]
+
         lift_pred_c = eta_tt_pred_c + f
 
         return {
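Why the loops: paddle.nn.LSTM returns a tuple (outputs, final_states) rather than a single tensor, so the layer stack can no longer be invoked with one direct call; each hunk applies the layers one at a time and unwraps the tuple before the next layer sees it. A minimal sketch of the pattern, assuming lstm_model is a paddle.nn.LayerList holding an LSTM followed by a Linear head (the sizes below are illustrative, not taken from phylstm.py):

    import paddle
    import paddle.nn as nn

    # Illustrative stand-in for self.lstm_model; sizes are assumptions.
    layers = nn.LayerList([
        nn.LSTM(input_size=1, hidden_size=100),
        nn.Linear(100, 3),
    ])

    x = paddle.randn([8, 50, 1])  # (batch, time steps, features)

    out = x
    for layer in layers:
        out = layer(out)
        # nn.LSTM returns (outputs, (h, c)); keep only the output
        # sequence so the next layer receives a plain tensor.
        if isinstance(out, tuple):
            out = out[0]

    print(out.shape)  # [8, 50, 3]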
@@ -173,12 +191,24 @@ def _forward_type_2(self, x):
 
     def _forward_type_3(self, x):
         # physics informed neural networks
-        output = self.lstm_model(x["ag"])
+        # output = self.lstm_model(x["ag"])
+        output = x["ag"]
+        for layer in self.lstm_model:
+            output = layer(output)
+            if isinstance(output, tuple):
+                output = output[0]
+
         eta_pred = output[:, :, 0 : self.output_size]
         eta_dot_pred = output[:, :, self.output_size : 2 * self.output_size]
         g_pred = output[:, :, 2 * self.output_size :]
 
-        output_c = self.lstm_model(x["ag_c"])
+        # output_c = self.lstm_model(x["ag_c"])
+        output_c = x["ag_c"]
+        for layer in self.lstm_model:
+            output_c = layer(output_c)
+            if isinstance(output_c, tuple):
+                output_c = output_c[0]
+
         eta_pred_c = output_c[:, :, 0 : self.output_size]
         eta_dot_pred_c = output_c[:, :, self.output_size : 2 * self.output_size]
         g_pred_c = output_c[:, :, 2 * self.output_size :]
@@ -187,11 +217,22 @@ def _forward_type_3(self, x):
         eta_tt_pred_c = paddle.matmul(x["phi"], eta_dot_pred_c)
         g_t_pred_c = paddle.matmul(x["phi"], g_pred_c)
 
-        f = self.lstm_model_f(paddle.concat([eta_pred_c, eta_dot_pred_c, g_pred_c], 2))
+        # f = self.lstm_model_f(paddle.concat([eta_pred_c, eta_dot_pred_c, g_pred_c], 2))
+        f = paddle.concat([eta_pred_c, eta_dot_pred_c, g_pred_c], 2)
+        for layer in self.lstm_model_f:
+            f = layer(f)
+            if isinstance(f, tuple):
+                f = f[0]
+
         lift_pred_c = eta_tt_pred_c + f
 
         eta_dot1_pred_c = eta_dot_pred_c[:, :, 0:1]
-        g_dot_pred_c = self.lstm_model_g(paddle.concat([eta_dot1_pred_c, g_pred_c], 2))
+        # g_dot_pred_c = self.lstm_model_g(paddle.concat([eta_dot1_pred_c, g_pred_c], 2))
+        g_dot_pred_c = paddle.concat([eta_dot1_pred_c, g_pred_c], 2)
+        for layer in self.lstm_model_g:
+            g_dot_pred_c = layer(g_dot_pred_c)
+            if isinstance(g_dot_pred_c, tuple):
+                g_dot_pred_c = g_dot_pred_c[0]
 
         return {
             "eta_pred": eta_pred,
