Merge pull request #2143 from huggingface/fix_asymm_set_grad_enable
Fix #2132: remove use of torch._C.set_grad_enabled in favor of the public torch.set_grad_enabled. Line endings were messed up too.
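For context, torch._C.set_grad_enabled reaches into PyTorch's private C bindings, while torch.set_grad_enabled is the documented public API; the public form can also be used as a context manager that restores the previous grad mode on exit. A minimal standalone sketch of both forms (illustrative, not part of this diff):

import torch

x = torch.ones(3, requires_grad=True)

# Plain-call form: flips autograd's global grad mode until it is
# explicitly changed back.
torch.set_grad_enabled(False)
y = x * 2
torch.set_grad_enabled(True)

# Context-manager form: restores the previous grad mode on exit,
# even if the block raises.
with torch.set_grad_enabled(False):
    z = x * 2

assert not y.requires_grad and not z.requires_grad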
rwightman authored Apr 9, 2024
2 parents 9531eb7 + 5c5ae8d commit f5ea076
Showing 1 changed file: timm/loss/asymmetric_loss.py (97 additions, 97 deletions). The line-ending fix rewrites every line of the file, which is why the counts cover all 97 lines; the substantive change is the two set_grad_enabled calls marked with -/+ below.
@@ -1,97 +1,97 @@
 import torch
 import torch.nn as nn


 class AsymmetricLossMultiLabel(nn.Module):
     def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
         super(AsymmetricLossMultiLabel, self).__init__()

         self.gamma_neg = gamma_neg
         self.gamma_pos = gamma_pos
         self.clip = clip
         self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
         self.eps = eps

     def forward(self, x, y):
         """
         Parameters
         ----------
         x: input logits
         y: targets (multi-label binarized vector)
         """

         # Calculating Probabilities
         x_sigmoid = torch.sigmoid(x)
         xs_pos = x_sigmoid
         xs_neg = 1 - x_sigmoid

         # Asymmetric Clipping
         if self.clip is not None and self.clip > 0:
             xs_neg = (xs_neg + self.clip).clamp(max=1)

         # Basic CE calculation
         los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
         los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
         loss = los_pos + los_neg

         # Asymmetric Focusing
         if self.gamma_neg > 0 or self.gamma_pos > 0:
             if self.disable_torch_grad_focal_loss:
-                torch._C.set_grad_enabled(False)
+                torch.set_grad_enabled(False)
             pt0 = xs_pos * y
             pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1 - p
             pt = pt0 + pt1
             one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
             one_sided_w = torch.pow(1 - pt, one_sided_gamma)
             if self.disable_torch_grad_focal_loss:
-                torch._C.set_grad_enabled(True)
+                torch.set_grad_enabled(True)
             loss *= one_sided_w

         return -loss.sum()


 class AsymmetricLossSingleLabel(nn.Module):
     def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'):
         super(AsymmetricLossSingleLabel, self).__init__()

         self.eps = eps
         self.logsoftmax = nn.LogSoftmax(dim=-1)
         self.targets_classes = []  # prevent repeated GPU memory allocation
         self.gamma_pos = gamma_pos
         self.gamma_neg = gamma_neg
         self.reduction = reduction

     def forward(self, inputs, target, reduction=None):
         """
         Parameters
         ----------
         inputs: input logits
         target: targets (1-hot vector)
         """

         num_classes = inputs.size()[-1]
         log_preds = self.logsoftmax(inputs)
         self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)

         # ASL weights
         targets = self.targets_classes
         anti_targets = 1 - targets
         xs_pos = torch.exp(log_preds)
         xs_neg = 1 - xs_pos
         xs_pos = xs_pos * targets
         xs_neg = xs_neg * anti_targets
         asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
                                  self.gamma_pos * targets + self.gamma_neg * anti_targets)
         log_preds = log_preds * asymmetric_w

         if self.eps > 0:  # label smoothing
             self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)

         # loss calculation
         loss = -self.targets_classes.mul(log_preds)

         loss = loss.sum(dim=-1)
         if self.reduction == 'mean':
             loss = loss.mean()

         return loss
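A side note on the pattern the commit touches: the code flips the global grad mode and then unconditionally re-enables it, which can clobber a caller's torch.no_grad() context. A hypothetical helper (not in timm, shown only as an illustrative alternative) scopes the disable with a context manager instead, matching the original behavior only on the disable_torch_grad_focal_loss=True path:

import torch

def asymmetric_focusing_weights(xs_pos, xs_neg, y, gamma_pos, gamma_neg):
    # Hypothetical helper: computes the ASL focusing weights with grad
    # disabled only inside this scope, so the caller's grad mode is
    # restored automatically on exit.
    with torch.no_grad():
        pt = xs_pos * y + xs_neg * (1 - y)  # pt = p if t > 0 else 1 - p
        one_sided_gamma = gamma_pos * y + gamma_neg * (1 - y)
        return torch.pow(1 - pt, one_sided_gamma)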

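Finally, a minimal usage sketch for the two modules; the batch size, class count, and the timm.loss import path are assumptions for illustration:

import torch
from timm.loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel

logits = torch.randn(8, 10)  # assumed: batch of 8, 10 classes

# Multi-label: targets are a {0, 1} matrix with one column per label.
multi_targets = torch.randint(0, 2, (8, 10)).float()
print(AsymmetricLossMultiLabel()(logits, multi_targets))

# Single-label: targets are class indices; forward() one-hots them.
single_targets = torch.randint(0, 10, (8,))
print(AsymmetricLossSingleLabel()(logits, single_targets))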