Fix CUDA inference
This commit is contained in:
@@ -18,7 +18,6 @@ class PositionEncodingSine(nn.Module):
|
||||
We will remove the buggy impl after re-training all variants of our released models.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
pe = torch.zeros((d_model, *max_shape))
|
||||
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
|
||||
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
|
||||
@@ -39,4 +38,4 @@ class PositionEncodingSine(nn.Module):
|
||||
Args:
|
||||
x: [N, C, H, W]
|
||||
"""
|
||||
return x + self.pe[:, :, :x.size(2), :x.size(3)]
|
||||
return x + self.pe[:, :, :x.size(2), :x.size(3)].to(x.device)
|
||||
Reference in New Issue
Block a user