@@ -241,44 +241,6 @@ def forward(self, input: dict, **kwargs: Any) -> torch.Tensor:
         return input[self.output_key] if self.output_key is not None else input


-class LinearInputEncoder(nn.Module):
-    """A simple linear input encoder."""
-
-    def __init__(
-        self,
-        num_features: int,
-        emsize: int,
-        replace_nan_by_zero: bool = False,
-        bias: bool = True,
-    ):
-        """Initialize the LinearInputEncoder.
-
-        Args:
-            num_features: The number of input features.
-            emsize: The embedding size, i.e. the number of output features.
-            replace_nan_by_zero: Whether to replace NaN values in the input by zero.
-            bias: Whether to use a bias term in the linear layer.
-        """
-        super().__init__()
-        self.layer = nn.Linear(num_features, emsize, bias=bias)
-        self.replace_nan_by_zero = replace_nan_by_zero
-
-    def forward(self, *x: torch.Tensor, **kwargs: Any) -> tuple[torch.Tensor]:
-        """Apply the linear transformation to the input.
-
-        Args:
-            *x: The input tensors to concatenate and transform.
-            **kwargs: Unused keyword arguments.
-
-        Returns:
-            A tuple containing the transformed tensor.
-        """
-        x = torch.cat(x, dim=-1)  # type: ignore
-        if self.replace_nan_by_zero:
-            x = torch.nan_to_num(x, nan=0.0)  # type: ignore
-        return (self.layer(x),)
-
-
 class SeqEncStep(nn.Module):
     """Abstract base class for sequential encoder steps.

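For context, a minimal usage sketch of the removed LinearInputEncoder: it concatenates its positional tensor arguments along the last dimension, optionally zeroes out NaNs, and applies a single linear projection to the embedding size. This assumes the class definition shown in the diff above is still importable; the shapes and values below are illustrative assumptions, not taken from the repository.

import torch

# Hypothetical usage of the removed LinearInputEncoder (definition shown in the diff above).
encoder = LinearInputEncoder(num_features=10, emsize=128, replace_nan_by_zero=True)

x = torch.randn(32, 10)        # illustrative batch: 32 rows, 10 features
x[0, 0] = float("nan")         # NaNs are replaced by zero before the linear projection
(embedding,) = encoder(x)      # forward returns a one-element tuple
print(embedding.shape)         # torch.Size([32, 128])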