Skip to content

Commit 8c7d277

Browse files
committed
Fix nimarb#23: Improve calc_loss's numerical stability using cross entropy.
1 parent 66c9a9e commit 8c7d277

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

Diff for: pytorch_influence_functions/influence_function.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,10 @@ def calc_loss(y, t):
6767
# if dim == [0, 1, 3] then dim=0; else dim=1
6868
####################
6969
# y = torch.nn.functional.log_softmax(y, dim=0)
70-
y = torch.nn.functional.log_softmax(y)
71-
loss = torch.nn.functional.nll_loss(
72-
y, t, weight=None, reduction='mean')
70+
# y = torch.nn.functional.log_softmax(y)
71+
# loss = torch.nn.functional.nll_loss(
72+
# y, t, weight=None, reduction='mean')
73+
loss = torch.nn.functional.cross_entropy(y, t, weight=None, reduction="mean")
7374
return loss
7475

7576

0 commit comments

Comments
 (0)