Skip to content

Commit 401ffec

Browse files
committed
first test loss
1 parent 1f3655c commit 401ffec

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

test/test_loss.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
"""
4+
Created on Thu Oct 12 11:24:47 2017
5+
6+
@author: Rachid Riad
7+
"""
8+
9+
import pytest
10+
from abnet3.loss import coscos2, cosmargin
11+
import torch
12+
from torch.autograd import Variable
13+
import numpy as np
14+
15+
losses = {
16+
'coscos2': coscos2,
17+
'cosmargin': cosmargin
18+
}
19+
20+
params = [a for a in losses ]
21+
22+
@pytest.mark.parametrize('loss_func,', params)
23+
def test_forward(loss_func):
24+
N_batch = 16
25+
x1 = Variable(torch.randn(N_batch, 10))
26+
x2 = Variable(torch.randn(N_batch, 10))
27+
y = Variable(torch.from_numpy(np.random.choice([1,-1],N_batch)))
28+
loss = losses[loss_func]()
29+
res = loss.forward(x1,x2,y)
30+
assert res.dim() == 1, 'fail for {}'.format(loss_func)
31+

0 commit comments

Comments
 (0)