Skip to content

Commit 62e9764

Browse files
[ADD] fit pipeline honoring API constraints with tests (#348)
* Add fit pipeline with tests
* Add documentation for get dataset
* update documentation
* fix tests
* remove permutation importance from visualisation example
* change disable_file_output
* add
* fix flake
* fix test and examples
* change type of disable_file_output
* Address comments from eddie
* fix docstring in api
* fix tests for base api
* fix tests for base api
* fix tests after rebase
* reduce dataset size in example
* remove optional from doc string
* Handle unsuccessful fitting of pipeline better
* fix flake in tests
* change to default configuration for documentation
* add warning for no ensemble created when y_optimization in disable_file_output
* reduce budget for single configuration
* address comments from eddie
* address comments from shuhei
* Add autoPyTorchEnum
* fix flake in tests
* address comments from shuhei
* Apply suggestions from code review Co-authored-by: nabenabe0928 <[email protected]>
* fix flake
* use **dataset_kwargs
* fix flake
* change to enforce keyword args Co-authored-by: nabenabe0928 <[email protected]>
1 parent aa927a3 commit 62e9764

18 files changed

+1164
-205
lines changed

autoPyTorch/api/base_task.py

Lines changed: 397 additions & 50 deletions
Large diffs are not rendered by default.

autoPyTorch/api/tabular_classification.py

Lines changed: 128 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
import os
2-
import uuid
3-
from typing import Any, Callable, Dict, List, Optional, Union
1+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
42

53
import numpy as np
64

@@ -13,11 +11,13 @@
1311
TASK_TYPES_TO_STRING,
1412
)
1513
from autoPyTorch.data.tabular_validator import TabularInputValidator
14+
from autoPyTorch.datasets.base_dataset import BaseDatasetPropertiesType
1615
from autoPyTorch.datasets.resampling_strategy import (
1716
CrossValTypes,
1817
HoldoutValTypes,
1918
)
2019
from autoPyTorch.datasets.tabular_dataset import TabularDataset
20+
from autoPyTorch.evaluation.utils import DisableFileOutputParameters
2121
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
2222
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
2323

@@ -54,13 +54,16 @@ class TabularClassificationTask(BaseTask):
5454
delete_tmp_folder_after_terminate (bool):
5555
Determines whether to delete the temporary directory,
5656
when finished
57-
include_components (Optional[Dict]):
58-
If None, all possible components are used.
59-
Otherwise specifies set of components to use.
60-
exclude_components (Optional[Dict]):
61-
If None, all possible components are used.
62-
Otherwise specifies set of components not to use.
63-
Incompatible with include components.
57+
include_components (Optional[Dict[str, Any]]):
58+
Dictionary containing components to include. Key is the node
59+
name and Value is an Iterable of the names of the components
60+
to include. Only these components will be present in the
61+
search space.
62+
exclude_components (Optional[Dict[str, Any]]):
63+
Dictionary containing components to exclude. Key is the node
64+
name and Value is an Iterable of the names of the components
65+
to exclude. All except these components will be present in
66+
the search space.
6467
search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
6568
search space updates that can be used to modify the search
6669
space of particular components or choice modules of the pipeline
@@ -78,8 +81,8 @@ def __init__(
7881
output_directory: Optional[str] = None,
7982
delete_tmp_folder_after_terminate: bool = True,
8083
delete_output_folder_after_terminate: bool = True,
81-
include_components: Optional[Dict] = None,
82-
exclude_components: Optional[Dict] = None,
84+
include_components: Optional[Dict[str, Any]] = None,
85+
exclude_components: Optional[Dict[str, Any]] = None,
8386
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
8487
resampling_strategy_args: Optional[Dict[str, Any]] = None,
8588
backend: Optional[Backend] = None,
@@ -106,18 +109,109 @@ def __init__(
106109
task_type=TASK_TYPES_TO_STRING[TABULAR_CLASSIFICATION],
107110
)
108111

109-
def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularClassificationPipeline:
112+
def build_pipeline(
113+
self,
114+
dataset_properties: Dict[str, BaseDatasetPropertiesType],
115+
include_components: Optional[Dict[str, Any]] = None,
116+
exclude_components: Optional[Dict[str, Any]] = None,
117+
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
118+
) -> TabularClassificationPipeline:
110119
"""
111-
Build pipeline according to current task and for the passed dataset properties
120+
Build pipeline according to current task
121+
and for the passed dataset properties
112122
113123
Args:
114-
dataset_properties (Dict[str,Any])
124+
dataset_properties (Dict[str, BaseDatasetPropertiesType]):
125+
Characteristics of the dataset to guide the pipeline
126+
choices of components
127+
include_components (Optional[Dict[str, Any]]):
128+
Dictionary containing components to include. Key is the node
129+
name and Value is an Iterable of the names of the components
130+
to include. Only these components will be present in the
131+
search space.
132+
exclude_components (Optional[Dict[str, Any]]):
133+
Dictionary containing components to exclude. Key is the node
134+
name and Value is an Iterable of the names of the components
135+
to exclude. All except these components will be present in
136+
the search space.
137+
search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
138+
Search space updates that can be used to modify the search
139+
space of particular components or choice modules of the pipeline
115140
116141
Returns:
117-
TabularClassificationPipeline:
118-
Pipeline compatible with the given dataset properties.
142+
TabularClassificationPipeline
143+
144+
"""
145+
return TabularClassificationPipeline(dataset_properties=dataset_properties,
146+
include=include_components,
147+
exclude=exclude_components,
148+
search_space_updates=search_space_updates)
149+
150+
def _get_dataset_input_validator(
151+
self,
152+
X_train: Union[List, pd.DataFrame, np.ndarray],
153+
y_train: Union[List, pd.DataFrame, np.ndarray],
154+
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
155+
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
156+
resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None,
157+
resampling_strategy_args: Optional[Dict[str, Any]] = None,
158+
dataset_name: Optional[str] = None,
159+
) -> Tuple[TabularDataset, TabularInputValidator]:
119160
"""
120-
return TabularClassificationPipeline(dataset_properties=dataset_properties)
161+
Returns an object of `TabularDataset` and an object of
162+
`TabularInputValidator` according to the current task.
163+
164+
Args:
165+
X_train (Union[List, pd.DataFrame, np.ndarray]):
166+
Training feature set.
167+
y_train (Union[List, pd.DataFrame, np.ndarray]):
168+
Training target set.
169+
X_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
170+
Testing feature set
171+
y_test (Optional[Union[List, pd.DataFrame, np.ndarray]]):
172+
Testing target set
173+
resampling_strategy (Optional[Union[CrossValTypes, HoldoutValTypes]]):
174+
Strategy to split the training data. If None, uses
175+
HoldoutValTypes.holdout_validation.
176+
resampling_strategy_args (Optional[Dict[str, Any]]):
177+
arguments required for the chosen resampling strategy. If None, uses
178+
the default values provided in DEFAULT_RESAMPLING_PARAMETERS
179+
in ```datasets/resampling_strategy.py```.
180+
dataset_name (Optional[str]):
181+
name of the dataset, used as experiment name.
182+
Returns:
183+
TabularDataset:
184+
the dataset object.
185+
TabularInputValidator:
186+
the input validator fitted on the data.
187+
"""
188+
189+
resampling_strategy = resampling_strategy if resampling_strategy is not None else self.resampling_strategy
190+
resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is not None else \
191+
self.resampling_strategy_args
192+
193+
# Create a validator object to make sure that the data provided by
194+
# the user matches the autopytorch requirements
195+
InputValidator = TabularInputValidator(
196+
is_classification=True,
197+
logger_port=self._logger_port,
198+
)
199+
200+
# Fit an input validator to check the provided data
201+
# Also, an encoder is fit to both train and test data,
202+
# to prevent unseen categories during inference
203+
InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
204+
205+
dataset = TabularDataset(
206+
X=X_train, Y=y_train,
207+
X_test=X_test, Y_test=y_test,
208+
validator=InputValidator,
209+
resampling_strategy=resampling_strategy,
210+
resampling_strategy_args=resampling_strategy_args,
211+
dataset_name=dataset_name
212+
)
213+
214+
return dataset, InputValidator
121215

122216
def search(
123217
self,
@@ -138,7 +232,7 @@ def search(
138232
get_smac_object_callback: Optional[Callable] = None,
139233
all_supported_metrics: bool = True,
140234
precision: int = 32,
141-
disable_file_output: List = [],
235+
disable_file_output: Optional[List[Union[str, DisableFileOutputParameters]]] = None,
142236
load_models: bool = True,
143237
portfolio_selection: Optional[str] = None,
144238
) -> 'BaseTask':
@@ -237,10 +331,10 @@ def search(
237331
precision (int: default=32):
238332
Numeric precision used when loading ensemble data.
239333
Can be either '16', '32' or '64'.
240-
disable_file_output (Union[bool, List]):
241-
If True, disable model and prediction output.
242-
Can also be used as a list to pass more fine-grained
243-
information on what to save. Allowed elements in the list are:
334+
disable_file_output (Optional[List[Union[str, DisableFileOutputParameters]]]):
335+
Used as a list to pass more fine-grained
336+
information on what to save. Must be a member of `DisableFileOutputParameters`.
337+
Allowed elements in the list are:
244338
245339
+ `y_optimization`:
246340
do not save the predictions for the optimization set,
@@ -253,6 +347,9 @@ def search(
253347
pipelines fit on each fold.
254348
+ `y_test`:
255349
do not save the predictions for the test set.
350+
+ `all`:
351+
do not save any of the above.
352+
For more information check `autoPyTorch.evaluation.utils.DisableFileOutputParameters`.
256353
load_models (bool: default=True):
257354
Whether to load the models after fitting AutoPyTorch.
258355
portfolio_selection (Optional[str]):
@@ -269,32 +366,15 @@ def search(
269366
self
270367
271368
"""
272-
if dataset_name is None:
273-
dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
274369

275-
# we have to create a logger for at this point for the validator
276-
self._logger = self._get_logger(dataset_name)
277-
278-
# Create a validator object to make sure that the data provided by
279-
# the user matches the autopytorch requirements
280-
self.InputValidator = TabularInputValidator(
281-
is_classification=True,
282-
logger_port=self._logger_port,
283-
)
284-
285-
# Fit a input validator to check the provided data
286-
# Also, an encoder is fit to both train and test data,
287-
# to prevent unseen categories during inference
288-
self.InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
289-
290-
self.dataset = TabularDataset(
291-
X=X_train, Y=y_train,
292-
X_test=X_test, Y_test=y_test,
293-
validator=self.InputValidator,
294-
dataset_name=dataset_name,
370+
self.dataset, self.InputValidator = self._get_dataset_input_validator(
371+
X_train=X_train,
372+
y_train=y_train,
373+
X_test=X_test,
374+
y_test=y_test,
295375
resampling_strategy=self.resampling_strategy,
296376
resampling_strategy_args=self.resampling_strategy_args,
297-
)
377+
dataset_name=dataset_name)
298378

299379
return self._search(
300380
dataset=self.dataset,
@@ -333,7 +413,7 @@ def predict(
333413
"""
334414
if self.InputValidator is None or not self.InputValidator._is_fitted:
335415
raise ValueError("predict() is only supported after calling search. Kindly call first "
336-
"the estimator fit() method.")
416+
"the estimator search() method.")
337417

338418
X_test = self.InputValidator.feature_validator.transform(X_test)
339419
predicted_probabilities = super().predict(X_test, batch_size=batch_size,
@@ -353,6 +433,6 @@ def predict_proba(self,
353433
batch_size: Optional[int] = None, n_jobs: int = 1) -> np.ndarray:
354434
if self.InputValidator is None or not self.InputValidator._is_fitted:
355435
raise ValueError("predict() is only supported after calling search. Kindly call first "
356-
"the estimator fit() method.")
436+
"the estimator search() method.")
357437
X_test = self.InputValidator.feature_validator.transform(X_test)
358438
return super().predict(X_test, batch_size=batch_size, n_jobs=n_jobs)

0 commit comments

Comments
 (0)