Skip to content

Commit 560e721

Browse files
authored
Update dynamic input FAQ with callback solution (#592)
Update FAQ with callback solution From https://stackoverflow.com/a/60170023/1643939
1 parent dc70fb4 commit 560e721

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

docs/user/FAQ.rst

+10-7
Original file line numberDiff line numberDiff line change
@@ -341,16 +341,19 @@ in skorch. Here is an example:
341341

342342
.. code:: python
343343
344-
class MyNet(NeuralNetClassifier):
345-
def check_data(self, X, y):
346-
super().check_data(X, y)
344+
class InputShapeSetter(skorch.callbacks.Callback):
345+
def on_train_begin(self, net, X, y):
346+
net.set_params(module__input_dim=X.shape[1])
347347
348-
if self.module_.input_units != X.shape[1]:
349-
self.set_params(module__input_units=X.shape[1])
350-
self.initialize()
348+
349+
net = skorch.NeuralNetClassifier(
350+
ClassifierModule,
351+
callbacks=[InputShapeSetter()],
352+
)
351353
352354
This assumes that your module accepts an argument called
353355
``input_units``, which determines the number of units of the input
354356
layer, and that the number of features can be determined by
355357
``X.shape[1]``. If those assumptions are not true for your case,
356-
adjust the code accordingly.
358+
adjust the code accordingly. A fully working example can be found
359+
on `stackoverflow <https://stackoverflow.com/a/60170023/1643939>`_.

0 commit comments

Comments
 (0)