Skip to content

Commit fd01d98

Browse files
remove unused LinearInputEncoder and fix casing in development instructions (#320)
1 parent 49394b0 commit fd01d98

File tree

2 files changed

+1
-39
lines changed

2 files changed

+1
-39
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ Not effective:
265265
python -m venv venv
266266
source venv/bin/activate # On Windows: venv\Scripts\activate
267267
git clone https://github.com/PriorLabs/TabPFN.git
268-
cd tabpfn
268+
cd TabPFN
269269
pip install -e ".[dev]"
270270
pre-commit install
271271
```

src/tabpfn/model/encoders.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -241,44 +241,6 @@ def forward(self, input: dict, **kwargs: Any) -> torch.Tensor:
241241
return input[self.output_key] if self.output_key is not None else input
242242

243243

244-
class LinearInputEncoder(nn.Module):
245-
"""A simple linear input encoder."""
246-
247-
def __init__(
248-
self,
249-
num_features: int,
250-
emsize: int,
251-
replace_nan_by_zero: bool = False,
252-
bias: bool = True,
253-
):
254-
"""Initialize the LinearInputEncoder.
255-
256-
Args:
257-
num_features: The number of input features.
258-
emsize: The embedding size, i.e. the number of output features.
259-
replace_nan_by_zero: Whether to replace NaN values in the input by zero.
260-
bias: Whether to use a bias term in the linear layer.
261-
"""
262-
super().__init__()
263-
self.layer = nn.Linear(num_features, emsize, bias=bias)
264-
self.replace_nan_by_zero = replace_nan_by_zero
265-
266-
def forward(self, *x: torch.Tensor, **kwargs: Any) -> tuple[torch.Tensor]:
267-
"""Apply the linear transformation to the input.
268-
269-
Args:
270-
*x: The input tensors to concatenate and transform.
271-
**kwargs: Unused keyword arguments.
272-
273-
Returns:
274-
A tuple containing the transformed tensor.
275-
"""
276-
x = torch.cat(x, dim=-1) # type: ignore
277-
if self.replace_nan_by_zero:
278-
x = torch.nan_to_num(x, nan=0.0) # type: ignore
279-
return (self.layer(x),)
280-
281-
282244
class SeqEncStep(nn.Module):
283245
"""Abstract base class for sequential encoder steps.
284246

0 commit comments

Comments (0)