Skip to content

Commit fab2928

Browse files
authored
Fix callback issue in can (#265)
1 parent f700f6c commit fab2928

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

skada/deep/_divergence.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ def CAN(
255255
class_threshold=3,
256256
sigmas=None,
257257
base_criterion=None,
258+
callbacks=None,
258259
**kwargs,
259260
):
260261
"""Contrastive Adaptation Network (CAN) domain adaptation method.
@@ -281,6 +282,8 @@ def CAN(
281282
base_criterion : torch criterion (class)
282283
The base criterion used to compute the loss with source
283284
labels. If None, the default is `torch.nn.CrossEntropyLoss`.
285+
callbacks : list, optional
286+
List of callbacks to be used during training.
284287
285288
References
286289
----------
@@ -292,6 +295,14 @@ def CAN(
292295
if base_criterion is None:
293296
base_criterion = torch.nn.CrossEntropyLoss()
294297

298+
if callbacks is None:
299+
callbacks = [ComputeSourceCentroids()]
300+
else:
301+
if isinstance(callbacks, list):
302+
callbacks.append(ComputeSourceCentroids())
303+
else:
304+
callbacks = [callbacks, ComputeSourceCentroids()]
305+
295306
net = DomainAwareNet(
296307
module=DomainAwareModule,
297308
module__base_module=module,
@@ -305,7 +316,7 @@ def CAN(
305316
class_threshold=class_threshold,
306317
sigmas=sigmas,
307318
),
308-
callbacks=[ComputeSourceCentroids()],
319+
callbacks=callbacks,
309320
**kwargs,
310321
)
311322
return net

skada/deep/tests/test_deep_divergence.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#
55
# License: BSD 3-Clause
66
import pytest
7+
from skorch.callbacks import EpochScoring
78

89
pytest.importorskip("torch")
910

@@ -149,3 +150,48 @@ def test_can(sigmas, distance_threshold, class_threshold):
149150

150151
history = method.history_
151152
assert history[0]["train_loss"] > history[-1]["train_loss"]
153+
154+
155+
def test_can_with_custom_callbacks():
156+
module = ToyModule2D()
157+
module.eval()
158+
159+
n_samples = 10
160+
dataset = make_shifted_datasets(
161+
n_samples_source=n_samples,
162+
n_samples_target=n_samples,
163+
shift="concept_drift",
164+
noise=0.1,
165+
random_state=42,
166+
return_dataset=True,
167+
)
168+
169+
# Create a custom callback
170+
custom_callback = EpochScoring(scoring="accuracy", lower_is_better=False)
171+
172+
method = CAN(
173+
ToyModule2D(),
174+
reg=0.01,
175+
layer_name="dropout",
176+
batch_size=10,
177+
max_epochs=10,
178+
train_split=None,
179+
callbacks=[custom_callback], # Pass the custom callback
180+
)
181+
182+
X, y, sample_domain = dataset.pack_train(as_sources=["s"], as_targets=["t"])
183+
method.fit(X.astype(np.float32), y, sample_domain)
184+
185+
X_test, y_test, sample_domain_test = dataset.pack_test(as_targets=["t"])
186+
187+
y_pred = method.predict(X_test.astype(np.float32), sample_domain_test)
188+
189+
assert y_pred.shape[0] == X_test.shape[0]
190+
191+
history = method.history_
192+
assert history[0]["train_loss"] > history[-1]["train_loss"]
193+
194+
# Check if both custom callback and ComputeSourceCentroids are present
195+
callback_classes = [cb.__class__.__name__ for cb in method.callbacks]
196+
assert "EpochScoring" in callback_classes
197+
assert "ComputeSourceCentroids" in callback_classes

0 commit comments

Comments
 (0)