File tree Expand file tree Collapse file tree 1 file changed +11
-3
lines changed
autoPyTorch/pipeline/components/setup/network_embedding Expand file tree Collapse file tree 1 file changed +11
-3
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Embed the categorical columns of ``x`` and concatenate them with the
    remaining columns.

    Position ``i`` of ``self.embed_features`` flags whether input column ``i``
    is categorical (``True`` -> passed through the matching entity-embedding
    layer in ``self.ee_layers``) or not (``False`` -> forwarded unchanged).

    Time series tasks append the targets to the network input, but those
    target columns are not recorded by autoPyTorch's embeddings
    (``self.embed_features``).  Every column beyond the flagged ones is
    therefore forwarded verbatim as the tail of the output.

    Args:
        x: input tensor of shape ``(..., F + T)`` where
           ``F == len(self.embed_features)`` and ``T >= 0`` are extra
           (e.g. target) columns.

    Returns:
        torch.Tensor: concatenation along the last dimension of the
        per-column results, in the original column order.
    """
    concat_seq = []

    layer_pointer = 0
    # Number of columns covered by the embedding flags; every column from
    # n_features onward (e.g. time series targets) is passed through as-is.
    n_features = len(self.embed_features)
    for x_pointer, embed in enumerate(self.embed_features):
        if not embed:
            # List-index keeps the trailing singleton dim so torch.cat
            # receives rank-consistent pieces.
            concat_seq.append(x[..., [x_pointer]])
            continue
        # Embedding layers require integer category indices.
        current_feature_slice = x[..., x_pointer].to(torch.int)
        concat_seq.append(self.ee_layers[layer_pointer](current_feature_slice))
        layer_pointer += 1

    # BUG FIX: the tail previously started at x_pointer (== n_features - 1
    # after the loop), which duplicated the last feature column in the
    # output.  Slicing from n_features appends only the un-flagged tail
    # (the targets), and degenerates to the whole input when
    # embed_features is empty — no x_pointer pre-assignment needed.
    concat_seq.append(x[..., n_features:])

    return torch.cat(concat_seq, dim=-1)
121
129
122
130
def _create_ee_layers (self ) -> nn .ModuleList :
123
131
# entity embeding layers are Linear Layers
You can’t perform that action at this time.
0 commit comments