File tree 1 file changed +10
-7
lines changed
1 file changed +10
-7
lines changed Original file line number Diff line number Diff line change @@ -341,16 +341,19 @@ in skorch. Here is an example:
341
341
342
342
.. code :: python
343
343
344
- class MyNet ( NeuralNetClassifier ):
345
- def check_data (self , X , y ):
346
- super ().check_data(X, y )
344
+ class InputShapeSetter ( skorch . callbacks . Callback ):
345
+ def on_train_begin (self , net , X , y ):
346
+ net.set_params( module__input_dim = X.shape[ 1 ] )
347
347
348
- if self .module_.input_units != X.shape[1 ]:
349
- self .set_params(module__input_units = X.shape[1 ])
350
- self .initialize()
348
+
349
+ net = skorch.NeuralNetClassifier(
350
+ ClassifierModule,
351
+ callbacks = [InputShapeSetter()],
352
+ )
351
353
352
354
This assumes that your module accepts an argument called
353
355
``input_units ``, which determines the number of units of the input
354
356
layer, and that the number of features can be determined by
355
357
``X.shape[1] ``. If those assumptions are not true for your case,
356
- adjust the code accordingly.
358
+ adjust the code accordingly. A fully working example can be found
359
+ on `stackoverflow <https://stackoverflow.com/a/60170023/1643939 >`_.
You can’t perform that action at this time.
0 commit comments