-
Notifications
You must be signed in to change notification settings - Fork 2
/
loss.py
26 lines (25 loc) · 773 Bytes
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    '''Focal loss for multi-class classification.

    FL(p_t) = -(1 - p_t)^gamma * log(p_t), optionally rescaled per class
    by `weight` (the "alpha" balancing term). With gama=0 and weight=None
    this reduces to standard cross-entropy.
    '''

    def __init__(self, gama=2., size_average=True, weight=None):
        '''
        gama: focusing parameter gamma >= 0 (0 -> plain cross-entropy)
        size_average: if True, return the batch mean; otherwise the batch sum
        weight: size(C) per-class rescaling weights, or None
        '''
        super(FocalLoss, self).__init__()
        self.gama = gama
        self.size_average = size_average
        self.weight = weight

    def forward(self, inputs, targets):
        '''
        inputs: size(N,C) raw (unnormalized) class scores
        targets: size(N) integer class indices
        returns: scalar loss tensor
        '''
        # Unweighted log p_t for the focusing factor. If the class weight
        # were folded into this term, exp(log_P) would no longer be the true
        # probability and (1 - p_t)^gamma would be distorted.
        log_P = -F.cross_entropy(inputs, targets, reduction='none')
        P = torch.exp(log_P)
        # Class-weighted cross-entropy term: weight[target] * (-log p_t).
        ce = F.cross_entropy(inputs, targets, weight=self.weight,
                             reduction='none')
        batch_loss = torch.pow(1 - P, self.gama).mul(ce)
        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss