Commit

efficientvit (mit) msa attention q/k/v ops need to be in float32 to train w/o NaN
rwightman committed Aug 20, 2023
1 parent e6aeb91 commit dc18cda
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions timm/models/efficientvit_mit.py
@@ -250,9 +250,13 @@ def forward(self, x):
         k = self.kernel_func(k)
         v = F.pad(v, (0, 1), mode="constant", value=1.)

-        kv = k.transpose(-1, -2) @ v
-        out = q @ kv
-        out = out[..., :-1] / (out[..., -1:] + self.eps)
+        dtype = v.dtype
+        q, k, v = q.float(), k.float(), v.float()
+        with torch.amp.autocast(device_type=v.device.type, enabled=False):
+            kv = k.transpose(-1, -2) @ v
+            out = q @ kv
+            out = out[..., :-1] / (out[..., -1:] + self.eps)
+        out = out.to(dtype)

         # final projection
         out = out.transpose(-1, -2).reshape(B, -1, H, W)
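
The commit does not say why the reduced-precision path produces NaN. One plausible mechanism, offered here as an assumption rather than the author's diagnosis, is that the unnormalized sums in `k.transpose(-1, -2) @ v` exceed the float16 maximum (65504), so both the numerator and the padded normalizer column become `inf`, and `inf / inf` is NaN. The sketch below emulates this with the standard library only: `to_half` and `linear_attention` are hypothetical helpers, the inputs are made up, and Python floats (doubles) stand in for float32.

```python
import math
import struct


def to_half(x: float) -> float:
    """Round x to the nearest float16 value; overflow becomes +/-inf."""
    if math.isnan(x) or math.isinf(x):
        return x
    try:
        # struct's "e" format is IEEE 754 binary16.
        return struct.unpack("e", struct.pack("e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def linear_attention(q, ks, vs, eps=1e-5, cast=float):
    """Linear attention for one query with scalar features.

    Mirrors the structure in the diff: v is padded with a constant 1, so the
    same sum over tokens yields a numerator (k * v) and a normalizer (k * 1).
    `cast` is applied after each operation to emulate the working precision.
    """
    num = 0.0
    den = 0.0
    for k, v in zip(ks, vs):
        num = cast(num + cast(k * v))  # kv[..., :-1]
        den = cast(den + k)            # kv[..., -1:], from the padded ones
    num = cast(q * num)                # out = q @ kv
    den = cast(q * den)
    return cast(num / cast(den + eps))


# Illustrative inputs: 4096 tokens with non-negative (post-ReLU style) values.
ks = [8.0] * 4096
vs = [4.0] * 4096
q = 8.0

full = linear_attention(q, ks, vs, cast=float)    # stand-in for float32
half = linear_attention(q, ks, vs, cast=to_half)  # emulated float16

print("full precision:", full)
print("emulated float16:", half)
```

In this toy setting the full-precision result stays finite while the emulated float16 path overflows, which is consistent with the fix above: cast `q`, `k`, `v` to float32, disable autocast for the matmuls and the division, then cast the result back to the original dtype.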
