
Commit 876cd49

fix phylstm (PaddlePaddle#1161)
1 parent 58dafba commit 876cd49
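
Context for the change below (an inference from the diff, not part of the commit message): the model containers (`self.lstm_model`, `self.lstm_model_f`, `self.lstm_model_g`) include recurrent layers, and `paddle.nn.LSTM` returns a `(output, (h, c))` tuple rather than a plain tensor, so invoking a container in one call passes that tuple on to the next layer and breaks the forward pass. The fix unrolls each container, applies its layers one at a time, and keeps only the tensor part of any tuple. A minimal runnable sketch of the pattern, with illustrative layer sizes that are assumptions rather than the real phylstm.py configuration:

    import paddle
    import paddle.nn as nn

    # Hypothetical stand-in for one of the containers in phylstm.py.
    model = nn.LayerList(
        [
            nn.LSTM(input_size=1, hidden_size=100),    # returns (output, (h, c))
            nn.LSTM(input_size=100, hidden_size=100),  # returns (output, (h, c))
            nn.Linear(100, 3),                         # returns a plain tensor
        ]
    )

    x = paddle.randn([8, 50, 1])  # (batch, time_steps, features)
    out = x
    for layer in model:
        out = layer(out)
        if isinstance(out, tuple):  # keep only the output tensor, drop states
            out = out[0]
    print(out.shape)  # [8, 50, 3]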

1 file changed: +41 −7

ppsci/arch/phylstm.py

Lines changed: 41 additions & 7 deletions
@@ -145,21 +145,36 @@ def forward(self, x):
         return result_dict
 
     def _forward_type_2(self, x):
-        output = self.lstm_model(x["ag"])
+        output = x["ag"]
+        for layer in self.lstm_model:
+            output = layer(output)
+            if isinstance(output, tuple):
+                output = output[0]
+
         eta_pred = output[:, :, 0 : self.output_size]
         eta_dot_pred = output[:, :, self.output_size : 2 * self.output_size]
         g_pred = output[:, :, 2 * self.output_size :]
 
         # for ag_c
-        output_c = self.lstm_model(x["ag_c"])
+        output_c = x["ag_c"]
+        for layer in self.lstm_model:
+            output_c = layer(output_c)
+            if isinstance(output_c, tuple):
+                output_c = output_c[0]
+
         eta_pred_c = output_c[:, :, 0 : self.output_size]
         eta_dot_pred_c = output_c[:, :, self.output_size : 2 * self.output_size]
         g_pred_c = output_c[:, :, 2 * self.output_size :]
         eta_t_pred_c = paddle.matmul(x["phi"], eta_pred_c)
         eta_tt_pred_c = paddle.matmul(x["phi"], eta_dot_pred_c)
         eta_dot1_pred_c = eta_dot_pred_c[:, :, 0:1]
         tmp = paddle.concat([eta_pred_c, eta_dot1_pred_c, g_pred_c], 2)
-        f = self.lstm_model_f(tmp)
+        f = tmp
+        for layer in self.lstm_model_f:
+            f = layer(f)
+            if isinstance(f, tuple):
+                f = f[0]
+
         lift_pred_c = eta_tt_pred_c + f
 
         return {
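
One note on the slicing that recurs in both forward paths: the network's last feature dimension stacks three blocks of width `self.output_size`, read back as `eta`, `eta_dot`, and `g` (in the PhyLSTM formulation, the response, its rate, and the hysteretic term). A sketch with illustrative sizes:

    import paddle

    output_size = 3  # illustrative; the real value comes from the model config
    output = paddle.randn([8, 50, 3 * output_size])

    eta_pred = output[:, :, 0:output_size]                      # first block
    eta_dot_pred = output[:, :, output_size : 2 * output_size]  # second block
    g_pred = output[:, :, 2 * output_size :]                    # third block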
@@ -173,12 +188,22 @@ def _forward_type_2(self, x):
 
     def _forward_type_3(self, x):
         # physics informed neural networks
-        output = self.lstm_model(x["ag"])
+        output = x["ag"]
+        for layer in self.lstm_model:
+            output = layer(output)
+            if isinstance(output, tuple):
+                output = output[0]
+
         eta_pred = output[:, :, 0 : self.output_size]
         eta_dot_pred = output[:, :, self.output_size : 2 * self.output_size]
         g_pred = output[:, :, 2 * self.output_size :]
 
-        output_c = self.lstm_model(x["ag_c"])
+        output_c = x["ag_c"]
+        for layer in self.lstm_model:
+            output_c = layer(output_c)
+            if isinstance(output_c, tuple):
+                output_c = output_c[0]
+
         eta_pred_c = output_c[:, :, 0 : self.output_size]
         eta_dot_pred_c = output_c[:, :, self.output_size : 2 * self.output_size]
         g_pred_c = output_c[:, :, 2 * self.output_size :]
@@ -187,11 +212,20 @@ def _forward_type_3(self, x):
         eta_tt_pred_c = paddle.matmul(x["phi"], eta_dot_pred_c)
         g_t_pred_c = paddle.matmul(x["phi"], g_pred_c)
 
-        f = self.lstm_model_f(paddle.concat([eta_pred_c, eta_dot_pred_c, g_pred_c], 2))
+        f = paddle.concat([eta_pred_c, eta_dot_pred_c, g_pred_c], 2)
+        for layer in self.lstm_model_f:
+            f = layer(f)
+            if isinstance(f, tuple):
+                f = f[0]
+
         lift_pred_c = eta_tt_pred_c + f
 
         eta_dot1_pred_c = eta_dot_pred_c[:, :, 0:1]
-        g_dot_pred_c = self.lstm_model_g(paddle.concat([eta_dot1_pred_c, g_pred_c], 2))
+        g_dot_pred_c = paddle.concat([eta_dot1_pred_c, g_pred_c], 2)
+        for layer in self.lstm_model_g:
+            g_dot_pred_c = layer(g_dot_pred_c)
+            if isinstance(g_dot_pred_c, tuple):
+                g_dot_pred_c = g_dot_pred_c[0]
 
         return {
             "eta_pred": eta_pred,
