
Commit d429721

learned embeddings for multi-dimensional tensors
1 parent 49d49c2 commit d429721

1 file changed: +11 -3 lines changed

autoPyTorch/pipeline/components/setup/network_embedding/LearnedEntityEmbedding.py

Lines changed: 11 additions & 3 deletions
@@ -108,16 +108,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         concat_seq = []
 
         layer_pointer = 0
+        # Time series tasks need to add targets to the embeddings. However, the target information is not recorded
+        # by autoPyTorch's embeddings. Therefore, we need to append the target parts to `concat_seq` manually; they are
+        # the last few dimensions of the input x.
+        # x_pointer is assigned 0 beforehand to handle the case where self.embed_features has length 0
+        x_pointer = 0
         for x_pointer, embed in enumerate(self.embed_features):
-            current_feature_slice = x[:, x_pointer]
             if not embed:
-                concat_seq.append(current_feature_slice.view(-1, 1))
+                current_feature_slice = x[..., [x_pointer]]
+                concat_seq.append(current_feature_slice)
                 continue
+            current_feature_slice = x[..., x_pointer]
             current_feature_slice = current_feature_slice.to(torch.int)
             concat_seq.append(self.ee_layers[layer_pointer](current_feature_slice))
             layer_pointer += 1
 
-        return torch.cat(concat_seq, dim=1)
+        concat_seq.append(x[..., x_pointer:])
+
+        return torch.cat(concat_seq, dim=-1)
 
     def _create_ee_layers(self) -> nn.ModuleList:
         # entity embedding layers are Linear Layers
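
The gist of the change: the old forward pass indexed with `x[:, x_pointer]` and reshaped with `.view(-1, 1)`, which hard-codes a 2D `[batch, features]` input, while time series tasks feed 3D `[batch, time, features]` tensors. Ellipsis indexing and `dim=-1` make the same code rank-agnostic. The snippet below is a standalone illustration of that indexing behavior, not code from the repository; the tensor shapes and the toy `nn.Embedding` are assumptions for demonstration.

import torch
import torch.nn as nn

# Not from the repository: a minimal check of why `...` indexing and dim=-1
# generalize the embedding forward pass to multi-dimensional tensors.

x2d = torch.randn(4, 5)      # tabular input: [batch, features]
x3d = torch.randn(4, 3, 5)   # time series input: [batch, time, features]

# old pattern: x[:, i].view(-1, 1) assumes exactly two dimensions
print(x2d[:, 2].view(-1, 1).shape)   # torch.Size([4, 1])

# new pattern: `[i]` keeps the trailing dim, `...` absorbs any leading dims
print(x2d[..., [2]].shape)           # torch.Size([4, 1])
print(x3d[..., [2]].shape)           # torch.Size([4, 3, 1])

# nn.Embedding maps an integer tensor of any shape S to S + (embed_dim,),
# so x[..., i] feeds it correctly for both ranks
emb = nn.Embedding(num_embeddings=10, embedding_dim=6)
codes3d = torch.randint(0, 10, (4, 3, 5))
print(emb(codes3d[..., 2]).shape)    # torch.Size([4, 3, 6])

# concatenating on dim=-1 (instead of dim=1) stitches numerical slices,
# embedded slices, and the trailing target columns back together for any rank
out = torch.cat([x3d[..., [0]], emb(codes3d[..., 1]), x3d[..., 3:]], dim=-1)
print(out.shape)                     # torch.Size([4, 3, 9])

Initializing `x_pointer = 0` before the loop matters for the final `x[..., x_pointer:]` append: when `self.embed_features` is empty the loop body never runs, and without the default the name would be unbound.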
