We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 1f3655c commit 401ffecCopy full SHA for 401ffec
test/test_loss.py
@@ -0,0 +1,31 @@
1
+#!/usr/bin/env python3
2
+# -*- coding: utf-8 -*-
3
+"""
4
+Created on Thu Oct 12 11:24:47 2017
5
+
6
+@author: Rachid Riad
7
8
9
+import pytest
10
+from abnet3.loss import coscos2, cosmargin
11
+import torch
12
+from torch.autograd import Variable
13
+import numpy as np
14
15
+losses = {
16
+ 'coscos2': coscos2,
17
+ 'cosmargin': cosmargin
18
+ }
19
20
+params = [a for a in losses ]
21
22
+@pytest.mark.parametrize('loss_func,', params)
23
+def test_forward(loss_func):
24
+ N_batch = 16
25
+ x1 = Variable(torch.randn(N_batch, 10))
26
+ x2 = Variable(torch.randn(N_batch, 10))
27
+ y = Variable(torch.from_numpy(np.random.choice([1,-1],N_batch)))
28
+ loss = losses[loss_func]()
29
+ res = loss.forward(x1,x2,y)
30
+ assert res.dim() == 1, 'fail for {}'.format(loss_func)
31
0 commit comments