Fix CUDA inference

This commit is contained in:
Ibai
2022-04-08 12:00:32 +09:00
parent 494d12893b
commit 3396d2fd06
2 changed files with 17 additions and 10 deletions
+1 -2
View File
@@ -18,7 +18,6 @@ class PositionEncodingSine(nn.Module):
We will remove the buggy impl after re-training all variants of our released models.
"""
super().__init__()
pe = torch.zeros((d_model, *max_shape))
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
@@ -39,4 +38,4 @@ class PositionEncodingSine(nn.Module):
Args:
x: [N, C, H, W]
"""
return x + self.pe[:, :, :x.size(2), :x.size(3)]
return x + self.pe[:, :, :x.size(2), :x.size(3)].to(x.device)