
Commit 43299a9

Dataset size reduction fixed, updated TargetValidator to match signatures (#1250)
* Moved to new splitter, moved to util file
* flake8'd
* Fixed errors, added test specifically for CustomStratifiedShuffleSplit
* flake8'd
* Updated docstring
* Updated types in docstring
* reduce_dataset_size_if_too_large supports more types
* flake8'd
* flake8'd
* Updated docstring
* Separated out the data subsampling into individual functions
* Improved typing from AutoML.fit to reduce_dataset_size_if_too_large
* flake8'd
* subsample tested
* Finished testing and flake8'd
* Cleaned up transform function that was touched
* ^
* Removed double typing
* Cleaned up typing of convert_if_sparse
* Cleaned up splitters and added size test
* Cleaned up doc in data
* Removed a rogue line that had been added
* Test fix
* flake8'd
* Typo fix
* Fixed ordering of things
* Fixed typing and tests of target_validator fit, transform, inv_transform
* Updated doc
* Updated return type
* Removed elif guard
* Removed extraneous overload
* Updated return type of feature validator
* Type fixes for target validator fit
* flake8'd
* Fixed error message string and automl sparse y tests
* Flake8'd
* Fix sort indices
* list type to List
* Removed unneeded comment
* Updated comment to make it more clear
* Comment update
* Fixed warning message for reduce_dataset_if_too_large
* Fix test
* Added check for error message in tests
* Test updates
* Fix error msg
* Re-include csr y in test
* Reintroduced explicit subsample values test
* flake8'd
* Missed an uncomment
* Update the comment for test of splitters
* Updated warning message in CustomSplitter
* Update comment in test
* Update tests
* Removed overloads
* Narrowed type of subsample
* Removed overload import
* Fix `todense` giving np.matrix, using `toarray`
* Made subsampling a little less aggressive
* Changed multiplier back to 10
* Allow argument to specify how auto-sklearn handles compressing dataset size (#1341)
* Added dataset_compression parameter and validation
* Fix docstring
* Updated docstring for `resampling_strategy`
* Updated param def and memory_allocation can now be absolute
* Insert newline
* Fix params into one line
* Fix indentation in docs
* Fix import breaks
* Allow absolute memory_allocation
* Tests
* Update test for precision omitted from methods
* Update test for auto-sklearn 2 with same args
* Update to use TypedDict for better Mypy parsing
* Added arg to auto-sklearn 2
* Updated tests to remove some warnings
* flake8'd
* Fix broken link?
* Remove TypedDict as it's not supported in Python 3.7
* Missing import
* Review changes
* Fix magic mock for python < 3.9
* Fixed bad merge
1 parent bd8d521 commit 43299a9

File tree

13 files changed: +1881 −616 lines


autosklearn/automl.py  +88 −158
@@ -9,7 +9,7 @@
 import os
 import sys
 import time
-from typing import Any, Dict, Optional, List, Tuple, Union
+from typing import Any, Dict, Mapping, Optional, List, Tuple, Union, cast
 import uuid
 import unittest.mock
 import tempfile
@@ -52,6 +52,13 @@
 from autosklearn.evaluation.abstract_evaluator import _fit_and_suppress_warnings
 from autosklearn.evaluation.train_evaluator import TrainEvaluator, _fit_with_budget
 from autosklearn.metrics import calculate_metric
+from autosklearn.util.data import (
+    reduce_dataset_size_if_too_large,
+    supported_precision_reductions,
+    validate_dataset_compression_arg,
+    default_dataset_compression_arg,
+    DatasetCompressionSpec,
+)
 from autosklearn.util.stopwatch import StopWatch
 from autosklearn.util.logging_ import (
     setup_logger,
@@ -118,7 +125,7 @@ def _model_predict(
        The predictions produced by the model
    """
    # Copy the array and ensure is has the attr 'shape'
-    X_ = np.asarray(X) if isinstance(X, list) else X.copy()
+    X_ = np.asarray(X) if isinstance(X, List) else X.copy()

    assert X_.shape[0] >= 1, f"X must have more than 1 sample but has {X_.shape[0]}"

@@ -160,34 +167,36 @@ def _model_predict(

 class AutoML(BaseEstimator):

-    def __init__(self,
-                 time_left_for_this_task,
-                 per_run_time_limit,
-                 temporary_directory: Optional[str] = None,
-                 delete_tmp_folder_after_terminate: bool = True,
-                 initial_configurations_via_metalearning=25,
-                 ensemble_size=1,
-                 ensemble_nbest=1,
-                 max_models_on_disc=1,
-                 seed=1,
-                 memory_limit=3072,
-                 metadata_directory=None,
-                 debug_mode=False,
-                 include: Optional[Dict[str, List[str]]] = None,
-                 exclude: Optional[Dict[str, List[str]]] = None,
-                 resampling_strategy='holdout-iterative-fit',
-                 resampling_strategy_arguments=None,
-                 n_jobs=None,
-                 dask_client: Optional[dask.distributed.Client] = None,
-                 precision=32,
-                 disable_evaluator_output=False,
-                 get_smac_object_callback=None,
-                 smac_scenario_args=None,
-                 logging_config=None,
-                 metric=None,
-                 scoring_functions=None,
-                 get_trials_callback=None
-                 ):
+    def __init__(
+        self,
+        time_left_for_this_task,
+        per_run_time_limit,
+        temporary_directory: Optional[str] = None,
+        delete_tmp_folder_after_terminate: bool = True,
+        initial_configurations_via_metalearning=25,
+        ensemble_size=1,
+        ensemble_nbest=1,
+        max_models_on_disc=1,
+        seed=1,
+        memory_limit=3072,
+        metadata_directory=None,
+        debug_mode=False,
+        include=None,
+        exclude=None,
+        resampling_strategy='holdout-iterative-fit',
+        resampling_strategy_arguments=None,
+        n_jobs=None,
+        dask_client: Optional[dask.distributed.Client] = None,
+        precision=32,
+        disable_evaluator_output=False,
+        get_smac_object_callback=None,
+        smac_scenario_args=None,
+        logging_config=None,
+        metric=None,
+        scoring_functions=None,
+        get_trials_callback=None,
+        dataset_compression: Union[bool, Mapping[str, Any]] = True
+    ):
         super(AutoML, self).__init__()
         self.configuration_space = None
         self._backend: Optional[Backend] = None
@@ -217,10 +226,10 @@ def __init__(self,
         self.precision = precision
         self._disable_evaluator_output = disable_evaluator_output
         # Check arguments prior to doing anything!
-        if not isinstance(self._disable_evaluator_output, (bool, list)):
+        if not isinstance(self._disable_evaluator_output, (bool, List)):
             raise ValueError('disable_evaluator_output must be of type bool '
                              'or list.')
-        if isinstance(self._disable_evaluator_output, list):
+        if isinstance(self._disable_evaluator_output, List):
             allowed_elements = ['model', 'cv_model', 'y_optimization', 'y_test', 'y_valid']
             for element in self._disable_evaluator_output:
                 if element not in allowed_elements:
@@ -232,6 +241,18 @@ def __init__(self,
         self._smac_scenario_args = smac_scenario_args
         self.logging_config = logging_config

+        # Validate dataset_compression and set its values
+        self._dataset_compression: Optional[DatasetCompressionSpec]
+        if isinstance(dataset_compression, bool):
+            if dataset_compression is True:
+                self._dataset_compression = default_dataset_compression_arg
+            else:
+                self._dataset_compression = None
+        else:
+            self._dataset_compression = validate_dataset_compression_arg(
+                dataset_compression, memory_limit=self._memory_limit
+            )
+
         self._datamanager = None
         self._dataset_name = None
         self._feat_type = None
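
For orientation, a usage sketch of the new argument as it would be passed in from the public estimator. That `AutoSklearnClassifier` forwards `dataset_compression` unchanged to `AutoML`, and the exact keys and defaults shown, are assumptions drawn from the validation above rather than documented API:

    from autosklearn.classification import AutoSklearnClassifier

    # Sketch only: True selects default_dataset_compression_arg, False disables
    # compression, and a mapping spells out the budget and the reductions to try.
    automl = AutoSklearnClassifier(
        time_left_for_this_task=120,
        memory_limit=3072,
        dataset_compression={
            "memory_allocation": 0.1,              # assumed: fraction of memory_limit (or absolute MB)
            "methods": ["precision", "subsample"],
        },
    )
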
@@ -490,10 +511,10 @@ def _supports_task_type(cls, task_type: str) -> bool:
     def fit(
         self,
         X: SUPPORTED_FEAT_TYPES,
-        y: Union[SUPPORTED_TARGET_TYPES, spmatrix],
+        y: SUPPORTED_TARGET_TYPES,
         task: Optional[int] = None,
         X_test: Optional[SUPPORTED_FEAT_TYPES] = None,
-        y_test: Optional[Union[SUPPORTED_TARGET_TYPES, spmatrix]] = None,
+        y_test: Optional[SUPPORTED_TARGET_TYPES] = None,
         feat_type: Optional[List[str]] = None,
         dataset_name: Optional[str] = None,
         only_return_configuration_space: bool = False,
@@ -509,8 +530,8 @@ def fit(
         #
         # `task: Optional[int]` and `is_classification`
         #
-        # `AutoML` tries to identify the task itself with
-        # `sklearn.type_of_target`, leaving little for the subclasses to do.
+        # `AutoML` tries to identify the task itself with `sklearn.type_of_target`,
+        # leaving little for the subclasses to do.
         # Except this failes when type_of_target(y) == "multiclass".
         #
         # "multiclass" be mean either REGRESSION or MULTICLASS_CLASSIFICATION,
@@ -588,6 +609,8 @@ def fit(
        self

        """
+        if (X_test is not None) ^ (y_test is not None):
+            raise ValueError("Must provide both X_test and y_test together")

        # AutoSklearn does not handle sparse y for now
        y = convert_if_sparse(y)
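
The `^` in the added guard is boolean XOR, so only the two mixed cases are rejected. A tiny standalone illustration of the condition (the variable names are made up for the example):

    for x_given, y_given in [(True, True), (True, False), (False, True), (False, False)]:
        # Mixed cases trigger the ValueError above; both-present and both-absent pass.
        print(x_given, y_given, "rejected" if x_given ^ y_given else "accepted")
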
@@ -639,17 +662,35 @@
         self.InputValidator.fit(X_train=X, y_train=y, X_test=X_test, y_test=y_test)
         X, y = self.InputValidator.transform(X, y)

-        if X_test is not None:
+        if X_test is not None and y_test is not None:
             X_test, y_test = self.InputValidator.transform(X_test, y_test)

-        X, y = self.subsample_if_too_large(
-            X=X,
-            y=y,
-            logger=self._logger,
-            seed=self._seed,
-            memory_limit=self._memory_limit,
-            task=self._task,
-        )
+        # We don't support size reduction on pandas type object yet
+        if (
+            self._dataset_compression is not None
+            and not isinstance(X, pd.DataFrame)
+            and not (isinstance(y, pd.Series) or isinstance(y, pd.DataFrame))
+        ):
+            methods = self._dataset_compression["methods"]
+            memory_allocation = self._dataset_compression["memory_allocation"]
+
+            # Remove precision reduction if we can't perform it
+            if (
+                X.dtype not in supported_precision_reductions
+                and "precision" in cast(List[str], methods)  # Removable with TypedDict
+            ):
+                methods = [method for method in methods if method != "precision"]
+
+            with warnings_to(self._logger):
+                X, y = reduce_dataset_size_if_too_large(
+                    X=X,
+                    y=y,
+                    memory_limit=self._memory_limit,
+                    is_classification=is_classification,
+                    random_state=self._seed,
+                    operations=methods,
+                    memory_allocation=memory_allocation
+                )

         # Check the re-sampling strategy
         try:
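
As a rough mental model of what the helper does with `operations` and `memory_allocation`, here is a simplified, self-contained sketch. It is not the actual `reduce_dataset_size_if_too_large` implementation; the budget handling and the float64-to-float32-only cast are assumptions made for illustration:

    import numpy as np
    from sklearn.model_selection import train_test_split

    def reduce_if_too_large_sketch(X, y, memory_allocation_mb, is_classification, random_state=1):
        # "precision": cast float64 down to float32 to roughly halve the footprint.
        if X.dtype == np.float64:
            X = X.astype(np.float32)

        # "subsample": if the array still exceeds the budget, keep a stratified
        # (classification) or random (regression) fraction of the rows.
        size_mb = X.nbytes / (1024 ** 2)
        if size_mb > memory_allocation_mb:
            keep = memory_allocation_mb / size_mb
            X, _, y, _ = train_test_split(
                X, y,
                train_size=keep,
                random_state=random_state,
                stratify=y if is_classification else None,
            )
        return X, y
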
@@ -1042,117 +1083,6 @@ def _check_resampling_strategy(

         return

-    @staticmethod
-    def subsample_if_too_large(
-        X: SUPPORTED_FEAT_TYPES,
-        y: SUPPORTED_TARGET_TYPES,
-        logger,
-        seed: int,
-        memory_limit: int,
-        task: int,
-    ):
-        if memory_limit and isinstance(X, np.ndarray):
-
-            if X.dtype == np.float32:
-                multiplier = 4
-            elif X.dtype in (np.float64, float):
-                multiplier = 8
-            elif (
-                # In spite of the names, np.float96 and np.float128
-                # provide only as much precision as np.longdouble,
-                # that is, 80 bits on most x86 machines and 64 bits
-                # in standard Windows builds.
-                (hasattr(np, 'float128') and X.dtype == np.float128)
-                or (hasattr(np, 'float96') and X.dtype == np.float96)
-            ):
-                multiplier = 16
-            else:
-                # Just assuming some value - very unlikely
-                multiplier = 8
-                logger.warning('Unknown dtype for X: %s, assuming it takes 8 bit/number',
-                               str(X.dtype))
-
-            megabytes = X.shape[0] * X.shape[1] * multiplier / 1024 / 1024
-            if memory_limit <= megabytes * 10 and X.dtype != np.float32:
-                cast_to = {
-                    8: np.float32,
-                    16: np.float64,
-                }.get(multiplier, np.float32)
-                logger.warning(
-                    'Dataset too large for memory limit %dMB, reducing the precision from %s to %s',
-                    memory_limit,
-                    X.dtype,
-                    cast_to,
-                )
-                X = X.astype(cast_to)
-
-            megabytes = X.shape[0] * X.shape[1] * multiplier / 1024 / 1024
-            if memory_limit <= megabytes * 10:
-                new_num_samples = int(
-                    memory_limit / (10 * X.shape[1] * multiplier / 1024 / 1024)
-                )
-                logger.warning(
-                    'Dataset too large for memory limit %dMB, reducing number of samples from '
-                    '%d to %d.',
-                    memory_limit,
-                    X.shape[0],
-                    new_num_samples,
-                )
-                if task in CLASSIFICATION_TASKS:
-                    # Identify if it has unique labels and allow for
-                    # stratification, with unique labels in training set
-                    values, idxs, counts = np.unique(y, axis=0,
-                                                     return_index=True,
-                                                     return_counts=True)
-                    unique_labels = {
-                        idx: value
-                        for value, idx, count in zip(values, idxs, counts)
-                        if count == 1
-                    }
-
-                    # If there are unique labeled elements, remove them and
-                    # place them back in later
-                    if len(unique_labels) > 0:
-                        idxs_of_unique = np.asarray(list(unique_labels.keys()))
-                        unique_X = X[idxs_of_unique]
-                        unique_y = y[idxs_of_unique]
-
-                        # NOTE optimization
-                        # If this ever turns out to be slow, this actually
-                        # copies the entire array. There might be a better
-                        # solution but it will probably require a lot more
-                        # manual work in how splitting is done.
-                        X = np.delete(X, idxs_of_unique, axis=0)
-                        y = np.delete(y, idxs_of_unique, axis=0)
-
-                        X, _, y, _ = sklearn.model_selection.train_test_split(
-                            X, y,
-                            train_size=new_num_samples - len(unique_y),
-                            random_state=seed,
-                            stratify=y,
-                        )
-
-                        X = np.append(X, unique_X, axis=0)
-                        y = np.append(y, unique_y, axis=0)
-
-                    # Otherwise we should be able to stratify as normal
-                    else:
-                        X, _, y, _ = sklearn.model_selection.train_test_split(
-                            X, y,
-                            train_size=new_num_samples,
-                            random_state=seed,
-                            stratify=y,
-                        )
-                elif task in REGRESSION_TASKS:
-                    X, _, y, _ = sklearn.model_selection.train_test_split(
-                        X, y,
-                        train_size=new_num_samples,
-                        random_state=seed,
-                    )
-                else:
-                    raise ValueError(task)
-        return X, y
-
     def refit(self, X, y):
         # AutoSklearn does not handle sparse y for now
         y = convert_if_sparse(y)
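
The deleted heuristic sized the array from its dtype width and only intervened when the memory limit was within 10x of that estimate. A worked example with the same arithmetic (the 3072 MB limit is the `memory_limit` default above; the array shape is made up):

    rows, cols, bytes_per_value = 1_000_000, 100, 8           # float64 -> multiplier 8
    megabytes = rows * cols * bytes_per_value / 1024 / 1024   # ~762.9 MB

    memory_limit = 3072                                        # MB
    needs_reduction = memory_limit <= megabytes * 10           # 3072 <= ~7629 -> True

    # Rows the old code would have kept when subsampling:
    new_num_samples = int(memory_limit / (10 * cols * bytes_per_value / 1024 / 1024))
    # -> 402_653
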
@@ -1247,7 +1177,7 @@ def fit_pipeline(
        is_classification: bool
            Whether the task is for classification or regression. This affects
            how the targets are treated
-        feat_type : list, optional (default=None)
+        feat_type : List, optional (default=None)
            List of str of `len(X.shape[1])` describing the attribute type.
            Possible types are `Categorical` and `Numerical`. `Categorical`
            attributes will be automatically One-Hot encoded. The values
@@ -1536,7 +1466,7 @@ def _load_models(self):
            raise ValueError('No models fitted!')

        elif self._disable_evaluator_output is False or \
-                (isinstance(self._disable_evaluator_output, list) and
+                (isinstance(self._disable_evaluator_output, List) and
                 'model' not in self._disable_evaluator_output):
            model_names = self._backend.list_all_models(self._seed)

0 commit comments