
Dice loss backward function is wrong? #2

Open
ljpadam opened this issue Jul 15, 2017 · 10 comments

@ljpadam

ljpadam commented Jul 15, 2017

Hi,
In the backward function of the dice loss class, is this statement wrong?

```python
grad_input = torch.cat((torch.mul(dDice, -grad_output[0]),
                        torch.mul(dDice, grad_output[0])), 0)
```

I think it should expand dDice from one dimension to two dimensions, and then concatenate the two parts along the second dimension.
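For illustration, a minimal sketch of the shape difference; the `(N,)` shape of `dDice` and the two-channel input are assumptions about the original code:

```python
import torch

N = 5
dDice = torch.randn(N)        # assumed: one gradient value per voxel/sample
go = torch.tensor(1.0)        # stand-in for the scalar grad_output[0]

# Current statement: concatenating the 1-D tensors along dim 0 gives a (2N,) vector.
flat = torch.cat((torch.mul(dDice, -go), torch.mul(dDice, go)), 0)
print(flat.shape)             # torch.Size([10])

# Suggested change: expand to (N, 1) first, then concatenate along the second
# dimension, giving an (N, 2) matrix that matches a two-channel softmax input.
stacked = torch.cat((torch.mul(dDice.unsqueeze(1), -go),
                     torch.mul(dDice.unsqueeze(1), go)), dim=1)
print(stacked.shape)          # torch.Size([5, 2])
```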

@anewlearner

@ljpadam
Hi, Adam LI.
I ran into a bug here. Could you show your version of the backward function for the dice loss?
Thank you.

@JDwangmo

The dice loss backward may be wrong. Should the statement be changed to this?

```python
grad_input = torch.cat((torch.mul(dDice, grad_output[0]),
                        torch.mul(dDice, -grad_output[0])), dim=1)
```

@hfarhidzadeh

@ljpadam @anewlearner @JDwangmo Did you guys find a proper implementation that works fine?

@hfarhidzadeh

hfarhidzadeh commented Jan 15, 2018

Found the right implementation. Just use these two lines in the backward function:

```python
dDice = torch.add(torch.mul(gt, 2), torch.mul(pred, -4))
grad_input = torch.cat((torch.mul(torch.unsqueeze(dDice, 1), grad_output[0]),
                        torch.mul(torch.unsqueeze(dDice, 1), -grad_output[0])), dim=1)
```
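For context, a minimal sketch of how those two lines might sit inside the backward pass, assuming the forward pass saved the two-channel softmax output `input`, a target already divided by `union` (`gt`), and the scalars `intersect` and `union`, as the traceback quoted later in this thread suggests. The function name and signature are hypothetical, not the repository's exact code, and `grad_output[0]` from the snippet above is written here as a plain scalar `grad_output`:

```python
import torch

def dice_backward(input, gt, intersect, union, grad_output):
    """Sketch: gradient of the Dice coefficient w.r.t. a two-channel softmax input.

    input:       (N, 2) softmax probabilities saved by the forward pass
    gt:          (N,)   target already divided by `union` (assumption)
    intersect,
    union:       scalars saved by the forward pass
    grad_output: scalar gradient flowing back into the loss
    """
    IoU2 = intersect / (union * union)
    pred = torch.mul(input[:, 1], IoU2)
    dDice = torch.add(torch.mul(gt, 2), torch.mul(pred, -4))
    # Expand dDice to (N, 1) and build one column per channel, so grad_input has
    # the same (N, 2) shape as the forward input; the two softmax channels move
    # in opposite directions, hence the sign flip between the columns.
    grad_input = torch.cat((torch.mul(torch.unsqueeze(dDice, 1), grad_output),
                            torch.mul(torch.unsqueeze(dDice, 1), -grad_output)),
                           dim=1)
    return grad_input
```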

@ghost

ghost commented Jan 16, 2018

@hamidfarhid Thanks, it works

@xubaoquan33

xubaoquan33 commented May 15, 2018

@Victor-2015 @hfarhidzadeh Could you show your dice loss code? The above code can't run because I hit another problem:

```
/data1/xbqfile/vnet.pytorch-master/torchbiomed/loss.py in backward(self, grad_output)
     55     IoU2 = intersect/(union*union)
     56     pred = torch.mul(input[:, 1], IoU2)
---> 57     dDice = torch.add(torch.mul(gt, 2), torch.mul(pred, -4))
     58     #grad_input = torch.cat((torch.mul(dDice, -grad_output[0]),
     59     #                        torch.mul(dDice, grad_output[0])), 0)

RuntimeError: arguments are located on different GPUs at /pytorch/torch/lib/THC/generated/../generic/THCTensorMathPointwise.cu:269
```

Do you know how to solve it?

@xubaoquan33

xubaoquan33 commented May 15, 2018

I found the solution to my problem.
Previously I had gpu_ids=[2,3]; I changed it to gpu_ids=[0,1] and the bug disappeared.
It may be a PyTorch bug: the error seems to happen only when device_ids[0] is not 0.
I had already changed every xx.cuda() to xx.cuda(device=gpu_ids[0]) but still got the error, so I changed gpu_ids and it works.
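For anyone who still wants to use GPUs 2 and 3, a common workaround (my suggestion, not something from this repository) is to remap them with CUDA_VISIBLE_DEVICES before CUDA is initialized, so the first entry of gpu_ids can stay 0:

```python
import os

# Make physical GPUs 2 and 3 appear as logical devices 0 and 1.
# This must be set before torch initializes CUDA (i.e. before the first .cuda() call).
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

import torch
import torch.nn as nn

# model = nn.DataParallel(model, device_ids=[0, 1])  # logical ids after the remapping
# tensors created with .cuda() now land on physical GPU 2 by default
```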

@JasonLeeSJTU

@hfarhidzadeh Hi, why do we need to cat the two parts? Why not just return dDice * grad_output?
Thanks.

@hfarhidzadeh

hfarhidzadeh commented Jun 15, 2019


I think the way they implemented it, in the end we have one matrix with two columns, not a vector. This is off the top of my head and I don't remember the details. :)

@JasonLeeSJTU


Thanks. 👍
