change a bunch of stuff, add wip lightning implementation
This commit is contained in:
@@ -86,8 +86,15 @@ class LocalFeatureTransformer(nn.Module):
|
||||
"""
|
||||
assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
|
||||
|
||||
for layer, name in zip(self.layers, self.layer_names):
|
||||
|
||||
# NOTE Workaround for non statically determinable zip
|
||||
# for layer, name in zip(self.layers, self.layer_names):
|
||||
# layer_zip = ((layer, self.layer_names[i]) for i, layer in enumerate(self.layers))
|
||||
# layer_zip = []
|
||||
# for i, layer in enumerate(self.layers):
|
||||
# layer_zip.append((layer, self.layer_names[i]))
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
name = self.layer_names[i]
|
||||
if name == 'self':
|
||||
feat0 = layer(feat0, feat0, mask0, mask0)
|
||||
feat1 = layer(feat1, feat1, mask1, mask1)
|
||||
@@ -97,4 +104,4 @@ class LocalFeatureTransformer(nn.Module):
|
||||
else:
|
||||
raise KeyError
|
||||
|
||||
return feat0, feat1
|
||||
return feat0, feat1
|
||||
|
||||
Reference in New Issue
Block a user