diff --git a/sklearn/ensemble/_gradient_boosting.pyx b/sklearn/ensemble/_gradient_boosting.pyx index 6224dee324a57..82537e56f3738 100644 --- a/sklearn/ensemble/_gradient_boosting.pyx +++ b/sklearn/ensemble/_gradient_boosting.pyx @@ -10,8 +10,8 @@ from scipy.sparse import issparse from sklearn.utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint8_t # Note: _tree uses cimport numpy, cnp.import_array, so we need to include # numpy headers in the build configuration of this extension -from sklearn.tree._tree cimport Node from sklearn.tree._tree cimport Tree +from sklearn.tree._utils cimport Node from sklearn.tree._utils cimport safe_realloc diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 7410fd91c89e8..0b02e01e3896e 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -124,6 +124,10 @@ class BaseDecisionTree(MultiOutputMixin, BaseEstimator, metaclass=ABCMeta): "min_impurity_decrease": [Interval(Real, 0.0, None, closed="left")], "ccp_alpha": [Interval(Real, 0.0, None, closed="left")], "monotonic_cst": ["array-like", None], + "categorical_features": [ + "array-like", + None, + ], } @abstractmethod @@ -143,6 +147,7 @@ def __init__( class_weight=None, ccp_alpha=0.0, monotonic_cst=None, + categorical_features=None, ): self.criterion = criterion self.splitter = splitter @@ -157,6 +162,7 @@ def __init__( self.class_weight = class_weight self.ccp_alpha = ccp_alpha self.monotonic_cst = monotonic_cst + self.categorical_features = categorical_features def get_depth(self): """Return the depth of the decision tree. @@ -258,13 +264,18 @@ def _fit( missing_values_in_feature_mask = ( self._compute_missing_values_in_feature_mask(X) ) - if issparse(X): + is_sparse_X = issparse(X) + if is_sparse_X: X.sort_indices() if X.indices.dtype != np.intc or X.indptr.dtype != np.intc: raise ValueError( "No support for np.int64 index based sparse matrices" ) + if is_sparse_X and self.categorical_features is not None: + raise NotImplementedError( + "Categorical features not supported with sparse inputs" + ) if self.criterion == "poisson": if np.any(y < 0): @@ -430,6 +441,26 @@ def _fit( # *positive class*, all signs must be flipped. monotonic_cst *= -1 + self.is_categorical_, n_categories_in_feature = ( + self._check_categorical_features(X, monotonic_cst) + ) + + has_categorical = bool(np.any(self.is_categorical_)) + if has_categorical and self.splitter == "random": + raise ValueError( + "Categorical features are not supported with splitter='random'. " + "Use splitter='best' instead." + ) + if has_categorical and self.n_outputs_ > 1: + raise ValueError( + "Categorical features are not supported with multi-output targets." + ) + if has_categorical and is_classifier(self) and np.any(self.n_classes_ > 2): + raise ValueError( + "Categorical features are only supported for binary classification. " + f"Found {self.n_classes_.max()} classes." 
+ ) + if not isinstance(self.splitter, Splitter): splitter = SPLITTERS[self.splitter]( criterion, @@ -441,13 +472,19 @@ def _fit( ) if is_classifier(self): - self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_) + self.tree_ = Tree( + self.n_features_in_, + self.n_classes_, + self.n_outputs_, + self.is_categorical_, + ) else: self.tree_ = Tree( self.n_features_in_, # TODO: tree shouldn't need this in this case np.array([1] * self.n_outputs_, dtype=np.intp), self.n_outputs_, + self.is_categorical_, ) # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise @@ -471,7 +508,14 @@ def _fit( self.min_impurity_decrease, ) - builder.build(self.tree_, X, y, sample_weight, missing_values_in_feature_mask) + builder.build( + self.tree_, + X, + y, + sample_weight, + missing_values_in_feature_mask, + n_categories_in_feature, + ) if self.n_outputs_ == 1 and is_classifier(self): self.n_classes_ = self.n_classes_[0] @@ -481,6 +525,79 @@ def _fit( return self + def _check_categorical_features(self, X, monotonic_cst): + """Check and validate categorical features in X + + Parameters + ---------- + X : {array-like} of shape (n_samples, n_features) + Input data (after `validate_data` was called) + + Return + ------ + is_categorical : ndarray of shape (n_features,), dtype=bool + Boolean mask indicating whether each feature is categorical. + n_categories_in_feature : ndarray of shape (n_features,), dtype=intp + For categorical features, stores ``max(X[:, idx]) + 1``. For + non-categorical features, stores ``-1``. + """ + n_features = X.shape[1] + categorical_features = np.asarray(self.categorical_features) + + if self.categorical_features is None or categorical_features.size == 0: + is_categorical = np.zeros(n_features, dtype=bool) + elif categorical_features.dtype.kind not in ("i", "b"): + raise ValueError( + "categorical_features must be an array-like of bool or int, " + f"got: {categorical_features.dtype.name}." + ) + elif categorical_features.dtype.kind == "i": + # check for categorical features as indices + if ( + np.max(categorical_features) >= n_features + or np.min(categorical_features) < 0 + ): + raise ValueError( + "categorical_features set as integer " + "indices must be in [0, n_features - 1]" + ) + is_categorical = np.zeros(n_features, dtype=bool) + is_categorical[categorical_features] = True + else: + if categorical_features.shape[0] != n_features: + raise ValueError( + "categorical_features set as a boolean mask " + "must have shape (n_features,), got: " + f"{categorical_features.shape}" + ) + is_categorical = categorical_features + + n_categories_in_feature = np.full(self.n_features_in_, -1, dtype=np.intp) + MAX_NC = 64 # TODO import from somewhere + base_msg = ( + f"Values for categorical features should be integers in [0, {MAX_NC - 1}]." + ) + for idx in np.where(is_categorical)[0]: + if np.isnan(X[:, idx]).any(): + raise ValueError( + "Missing values are not supported in categorical features" + ) + if not np.allclose(X[:, idx].astype(np.intp), X[:, idx]): + raise ValueError(f"{base_msg} Found non-integer values.") + if X[:, idx].min() < 0: + raise ValueError(f"{base_msg} Found negative values.") + X_idx_max = X[:, idx].max() + if X_idx_max >= MAX_NC: + raise ValueError(f"{base_msg} Found {X_idx_max}.") + n_categories_in_feature[idx] = X_idx_max + 1 + if monotonic_cst is not None and monotonic_cst[idx] != 0: + raise ValueError( + "A categorical feature cannot have a non-null monotonic" + " constraint. 
" + ) + + return is_categorical, n_categories_in_feature + def _validate_X_predict(self, X, check_input): """Validate the training data on predict (probabilities).""" if check_input: @@ -619,13 +736,16 @@ def _prune_tree(self): # build pruned tree if is_classifier(self): n_classes = np.atleast_1d(self.n_classes_) - pruned_tree = Tree(self.n_features_in_, n_classes, self.n_outputs_) + pruned_tree = Tree( + self.n_features_in_, n_classes, self.n_outputs_, self.is_categorical_ + ) else: pruned_tree = Tree( self.n_features_in_, # TODO: the tree shouldn't need this param np.array([1] * self.n_outputs_, dtype=np.intp), self.n_outputs_, + self.is_categorical_, ) _build_pruned_tree_ccp(pruned_tree, self.tree_, self.ccp_alpha) @@ -858,6 +978,19 @@ class DecisionTreeClassifier(ClassifierMixin, BaseDecisionTree): .. versionadded:: 1.4 + categorical_features : array-like of int or bool of shape (n_features,) or + (n_categorical_features,), default=None + Indicates which features are treated as categorical. + + - If array-like of int, the entries are feature indices. + - If array-like of bool, it is a boolean mask over features. + + Categorical features are only supported for dense inputs + and single-output targets. + Values of categorical features must be contiguous integers in ``[0, 63]`` + (missing values are not supported). + Categorical features cannot have non-zero monotonic constraints. + Attributes ---------- classes_ : ndarray of shape (n_classes,) or list of ndarray @@ -975,6 +1108,7 @@ def __init__( class_weight=None, ccp_alpha=0.0, monotonic_cst=None, + categorical_features=None, ): super().__init__( criterion=criterion, @@ -990,6 +1124,7 @@ def __init__( min_impurity_decrease=min_impurity_decrease, monotonic_cst=monotonic_cst, ccp_alpha=ccp_alpha, + categorical_features=categorical_features, ) @_fit_context(prefer_skip_nested_validation=True) @@ -1250,6 +1385,21 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): .. versionadded:: 1.4 + categorical_features : array-like of int or bool of shape (n_features,) or + (n_categorical_features,), default=None + Indicates which features are treated as categorical. + + - If array-like of int, the entries are feature indices. + - If array-like of bool, it is a boolean mask over features. + + Categorical features are only supported for dense inputs + and single-output targets. + Values of categorical features must be contiguous integers in ``[0, 63]`` + (missing values are not supported). + Categorical features cannot have non-zero monotonic constraints. + + When these constraints are not met, ``fit`` will raise an error. + Attributes ---------- feature_importances_ : ndarray of shape (n_features,) @@ -1353,6 +1503,7 @@ def __init__( min_impurity_decrease=0.0, ccp_alpha=0.0, monotonic_cst=None, + categorical_features=None, ): if isinstance(criterion, str) and criterion == "friedman_mse": # TODO(1.11): remove support of "friedman_mse" criterion. 
@@ -1377,6 +1528,7 @@ def __init__( min_impurity_decrease=min_impurity_decrease, ccp_alpha=ccp_alpha, monotonic_cst=monotonic_cst, + categorical_features=categorical_features, ) @_fit_context(prefer_skip_nested_validation=True) diff --git a/sklearn/tree/_partitioner.pxd b/sklearn/tree/_partitioner.pxd index 27f304261650a..94314af5285cc 100644 --- a/sklearn/tree/_partitioner.pxd +++ b/sklearn/tree/_partitioner.pxd @@ -6,9 +6,10 @@ from cython cimport floating from sklearn.utils._typedefs cimport ( - float32_t, float64_t, int8_t, int32_t, intp_t, uint8_t, uint32_t + float32_t, float64_t, int8_t, int32_t, intp_t, uint8_t, uint32_t, uint64_t ) from sklearn.tree._splitter cimport SplitRecord +from sklearn.tree._utils cimport SplitValue # Mitigate precision differences between 32 bit and 64 bit @@ -31,10 +32,12 @@ cdef const float32_t FEATURE_THRESHOLD = 1e-7 # cdef intp_t end # cdef intp_t n_missing # cdef const uint8_t[::1] missing_values_in_feature_mask +# cdef intp_t n_categories -# cdef void sort_samples_and_feature_values( +# cdef bint sort_samples_and_feature_values( # self, intp_t current_feature # ) noexcept nogil +# cdef void shift_missing_to_the_left(self) noexcept nogil # cdef void init_node_split( # self, # intp_t start, @@ -51,15 +54,19 @@ cdef const float32_t FEATURE_THRESHOLD = 1e-7 # intp_t* p_prev, # intp_t* p # ) noexcept nogil +# cdef inline SplitValue pos_to_threshold( +# self, intp_t p_prev, intp_t p +# ) noexcept nogil # cdef intp_t partition_samples( # self, -# float64_t current_threshold +# float64_t current_threshold, +# bint missing_go_to_left # ) noexcept nogil # cdef void partition_samples_final( # self, -# float64_t best_threshold, +# SplitValue split_value, # intp_t best_feature, -# bint best_missing_go_to_left +# bint best_missing_go_to_left, # ) noexcept nogil @@ -69,16 +76,26 @@ cdef class DensePartitioner: Note that this partitioner is agnostic to the splitting strategy (best vs. random). 
""" cdef const float32_t[:, :] X + cdef const float64_t[:, :] y + cdef const float64_t[::1] sample_weight cdef intp_t[::1] samples cdef float32_t[::1] feature_values cdef intp_t start cdef intp_t end cdef intp_t n_missing cdef const uint8_t[::1] missing_values_in_feature_mask + cdef const intp_t[::1] n_categories_in_feature cdef bint missing_on_the_left + cdef intp_t n_categories cdef char[::1] swap_buffer - cdef void sort_samples_and_feature_values( + cdef intp_t[::1] counts + cdef float64_t[::1] weighted_counts + cdef float64_t[::1] means + cdef intp_t[::1] sorted_cat + cdef intp_t[::1] offsets + + cdef bint sort_samples_and_feature_values( self, intp_t current_feature ) noexcept nogil cdef void shift_missing_to_the_left(self) noexcept nogil @@ -98,6 +115,9 @@ cdef class DensePartitioner: intp_t* p_prev, intp_t* p ) noexcept nogil + cdef inline SplitValue pos_to_threshold( + self, intp_t p_prev, intp_t p + ) noexcept nogil cdef intp_t partition_samples( self, float64_t current_threshold, @@ -105,10 +125,12 @@ cdef class DensePartitioner: ) noexcept nogil cdef void partition_samples_final( self, - float64_t best_threshold, + SplitValue split_value, intp_t best_feature, bint best_missing_go_to_left, ) noexcept nogil + cdef void _breiman_sort_categories(self, intp_t nc) noexcept nogil + cdef inline uint64_t _split_pos_to_bitset(self, intp_t p, intp_t nc) noexcept nogil cdef class SparsePartitioner: @@ -133,8 +155,9 @@ cdef class SparsePartitioner: cdef intp_t end cdef intp_t n_missing cdef const uint8_t[::1] missing_values_in_feature_mask + cdef intp_t n_categories - cdef void sort_samples_and_feature_values( + cdef bint sort_samples_and_feature_values( self, intp_t current_feature ) noexcept nogil cdef void shift_missing_to_the_left(self) noexcept nogil @@ -154,6 +177,9 @@ cdef class SparsePartitioner: intp_t* p_prev, intp_t* p ) noexcept nogil + cdef inline SplitValue pos_to_threshold( + self, intp_t p_prev, intp_t p + ) noexcept nogil cdef intp_t partition_samples( self, float64_t current_threshold, @@ -161,7 +187,7 @@ cdef class SparsePartitioner: ) noexcept nogil cdef void partition_samples_final( self, - float64_t best_threshold, + SplitValue split_value, intp_t best_feature, bint best_missing_go_to_left, ) noexcept nogil diff --git a/sklearn/tree/_partitioner.pyx b/sklearn/tree/_partitioner.pyx index 51d416cf22f0e..803631e143c08 100644 --- a/sklearn/tree/_partitioner.pyx +++ b/sklearn/tree/_partitioner.pyx @@ -13,7 +13,7 @@ and sparse data stored in a Compressed Sparse Column (CSC) format. 
from cython cimport final from libc.math cimport isnan, log2 from libc.stdlib cimport qsort -from libc.string cimport memcpy, memmove +from libc.string cimport memcpy, memmove, memset import numpy as np cimport numpy as cnp @@ -27,9 +27,12 @@ from scipy.sparse import issparse # in SparsePartitioner cdef float32_t EXTRACT_NNZ_SWITCH = 0.1 +cdef float64_t INFINITY = np.inf # Allow for 32 bit float comparisons cdef float32_t INFINITY_32t = np.inf +cdef intp_t MAX_N_CAT = 64 + @final cdef class DensePartitioner: @@ -40,31 +43,47 @@ cdef class DensePartitioner: def __init__( self, const float32_t[:, :] X, + const float64_t[:, :] y, + const float64_t[::1] sample_weight, intp_t[::1] samples, float32_t[::1] feature_values, const uint8_t[::1] missing_values_in_feature_mask, + const intp_t[::1] n_categories_in_feature, ): self.X = X + self.y = y + self.sample_weight = sample_weight self.samples = samples self.feature_values = feature_values self.missing_values_in_feature_mask = missing_values_in_feature_mask + self.n_categories_in_feature = n_categories_in_feature + self.missing_on_the_left = False + self.n_categories = 0 buffer_size = samples.size * max(samples.itemsize, feature_values.itemsize) self.swap_buffer = np.empty(buffer_size, dtype=np.uint8) + # for breiman shortcut: + self.counts = np.empty(MAX_N_CAT, dtype=np.intp) + self.weighted_counts = np.empty(MAX_N_CAT, dtype=np.float64) + self.means = np.empty(MAX_N_CAT, dtype=np.float64) + self.sorted_cat = np.empty(MAX_N_CAT, dtype=np.intp) + self.offsets = np.empty(MAX_N_CAT, dtype=np.intp) + cdef inline void init_node_split(self, intp_t start, intp_t end) noexcept nogil: """Initialize splitter at the beginning of node_split.""" self.start = start self.end = end self.n_missing = 0 - cdef inline void sort_samples_and_feature_values( + cdef inline bint sort_samples_and_feature_values( self, intp_t current_feature ) noexcept nogil: """Simultaneously sort based on the feature_values. - Missing values are stored at the end of feature_values. - The number of missing values observed in feature_values is stored - in self.n_missing. + For numerical features, this is a standard sort. For categorical + features, samples are reordered using the Breiman ordering shortcut. + + Returns ``True`` when the feature is constant at the current node. """ cdef: intp_t i, current_end @@ -101,9 +120,91 @@ cdef class DensePartitioner: for i in range(self.start, self.end): feature_values[i] = X[samples[i], current_feature] - sort(&feature_values[self.start], &samples[self.start], self.end - self.start - n_missing) self.missing_on_the_left = False self.n_missing = n_missing + self.n_categories = self.n_categories_in_feature[current_feature] + if n_missing == self.end - self.start: + return True + if self.n_categories <= 0: + # not a categorical feature + sort(&feature_values[self.start], &samples[self.start], self.end - self.start - n_missing) + if n_missing > 0: + return False + return feature_values[self.end - n_missing - 1] <= feature_values[self.start] + FEATURE_THRESHOLD + else: + self._breiman_sort_categories(self.n_categories) + return feature_values[self.start] == feature_values[self.end - 1] + + cdef void _breiman_sort_categories(self, intp_t nc) noexcept nogil: + """ + Order self.sorted_cat by ascending average target value + and order self.features_values & self.samples such that + - self.features_values is ordered according to the order of sorted_cat + - the relation `self.features_values[p] = self.X[self.samples[p], f]` is + preserved + + E.g. 
sorted_cat is [2 0 1]
+             feature_values is [2 2 2 0 0 1 1 1 1]
+
+        This ordering ensures the optimal split will be among the candidate splits
+        evaluated by the splitter (this is called the Breiman shortcut).
+
+        Time complexity: O(n + nc log nc)
+        """
+        cdef:
+            intp_t* counts = &self.counts[0]
+            float64_t* weighted_counts = &self.weighted_counts[0]
+            float64_t* means = &self.means[0]
+            intp_t* sorted_cat = &self.sorted_cat[0]
+            intp_t* offsets = &self.offsets[0]
+            float32_t* feature_values = &self.feature_values[0]
+            intp_t* samples = &self.samples[0]
+            intp_t c, r, p, new_p
+            float64_t w = 1.
+
+        memset(means, 0, nc * sizeof(float64_t))
+        memset(counts, 0, nc * sizeof(intp_t))
+        memset(weighted_counts, 0, nc * sizeof(float64_t))
+
+        # compute counts, weighted_counts and means
+        for p in range(self.start, self.end):
+            c = feature_values[p]
+            counts[c] += 1
+            if self.sample_weight is not None:
+                w = self.sample_weight[samples[p]]
+            means[c] += w * self.y[samples[p], 0]
+            self.weighted_counts[c] += w
+
+        for c in range(nc):
+            if weighted_counts[c] > 0:
+                means[c] /= weighted_counts[c]
+
+        # sorted_cat[i] = i-th category when sorted by ascending mean
+        for c in range(nc):
+            sorted_cat[c] = c
+        sort(means, sorted_cat, nc)
+
+        # build offsets such that:
+        # offsets[c] = self.start + sum( counts[x] for all x s.t. rank(x) <= rank(c) ) - 1
+        cdef intp_t offset = 0
+        for r in range(nc):
+            c = sorted_cat[r]
+            offset += counts[c]
+            offsets[c] = self.start + offset - 1
+
+        # sort feature_values & samples in-place such that
+        # they are ordered by the mean of the category
+        # while ensuring samples of the same categories are contiguous
+        p = self.start
+        while p < self.end:
+            c = feature_values[p]
+            new_p = offsets[c]
+            if new_p > p:
+                swap(feature_values, samples, p, new_p)
+                # swap preserves invariant: feature[p] = X[samples[p], f]
+                offsets[c] -= 1
+            else:
+                p += 1
 
     cdef void shift_missing_to_the_left(self) noexcept nogil:
         """Moves missing values from the right to the left.
@@ -173,12 +274,20 @@ cdef class DensePartitioner:
         cdef intp_t end_non_missing = (
             self.end if self.missing_on_the_left
             else self.end - self.n_missing)
+        cdef float32_t c
 
         if p[0] == end_non_missing and not self.missing_on_the_left:
             # skip the missing values up to the end
             # (which will end the for loop in the best split function)
             p[0] = self.end
             p_prev[0] = self.end
+        elif self.n_categories > 0:
+            c = self.feature_values[p[0]]
+            p[0] += 1
+            while p[0] < end_non_missing and self.feature_values[p[0]] == c:
+                p[0] += 1
+
+            # p_prev is unused in this case
         else:
             if self.missing_on_the_left and p[0] == self.start:
                 # skip the missing values up to the first non-missing value:
@@ -191,6 +300,39 @@ cdef class DensePartitioner:
             p[0] += 1
             p_prev[0] = p[0] - 1
 
+    cdef inline SplitValue pos_to_threshold(
+        self, intp_t p_prev, intp_t p
+    ) noexcept nogil:
+        """Convert a split position into a concrete split value.
+
+        For numerical features, this returns the usual mid-point threshold.
+        For categorical features, it converts the split position into a bitset
+        over categories.
+ """ + cdef SplitValue split + cdef intp_t end_non_missing = ( + self.end if self.missing_on_the_left + else self.end - self.n_missing) + + if self.n_categories > 0: + split.cat_split = self._split_pos_to_bitset(p, self.n_categories) + return split + + if p == end_non_missing and not self.missing_on_the_left: + # split with the right node being only the missing values + split.threshold = INFINITY + return split + + # split between two non-missing values + # sum of halves is used to avoid infinite value + split.threshold = ( + self.feature_values[p_prev] / 2.0 + self.feature_values[p] / 2.0 + ) + if split.threshold == INFINITY or split.threshold == -INFINITY: + split.threshold = self.feature_values[p_prev] + + return split + cdef inline intp_t partition_samples( self, float64_t threshold, @@ -224,7 +366,7 @@ cdef class DensePartitioner: cdef inline void partition_samples_final( self, - float64_t best_threshold, + SplitValue split_value, intp_t best_feature, bint best_missing_go_to_left ) noexcept nogil: @@ -239,21 +381,30 @@ cdef class DensePartitioner: intp_t partition_end = self.end intp_t* samples = &self.samples[0] float32_t current_value - bint go_to_left + bint is_cat = self.n_categories_in_feature[best_feature] > 0 while partition_start < partition_end: current_value = self.X[samples[partition_start], best_feature] - go_to_left = ( - best_missing_go_to_left if isnan(current_value) - else current_value <= best_threshold - ) - if go_to_left: + if goes_left(split_value, best_missing_go_to_left, is_cat, current_value): partition_start += 1 else: partition_end -= 1 samples[partition_start], samples[partition_end] = ( samples[partition_end], samples[partition_start]) + cdef inline uint64_t _split_pos_to_bitset(self, intp_t p, intp_t nc) noexcept nogil: + """Convert a split position ``p`` into a categorical bitset.""" + cdef uint64_t bitset = 0 + cdef intp_t r, c + cdef intp_t offset = 0 + for r in range(nc): + c = self.sorted_cat[r] + bitset |= (1) << c + offset += self.counts[c] + if offset >= p: + break + return bitset + @final cdef class SparsePartitioner: @@ -292,6 +443,7 @@ cdef class SparsePartitioner: self.index_to_samples[samples[p]] = p self.missing_values_in_feature_mask = missing_values_in_feature_mask + self.n_categories = 0 cdef inline void init_node_split(self, intp_t start, intp_t end) noexcept nogil: """Initialize splitter at the beginning of node_split.""" @@ -300,11 +452,14 @@ cdef class SparsePartitioner: self.is_samples_sorted = 0 self.n_missing = 0 - cdef inline void sort_samples_and_feature_values( + cdef inline bint sort_samples_and_feature_values( self, intp_t current_feature ) noexcept nogil: - """Simultaneously sort based on the feature_values.""" + """Simultaneously sort based on the feature_values. + + Returns ``True`` when the feature is constant at the current node. + """ cdef: float32_t[::1] feature_values = self.feature_values intp_t[::1] index_to_samples = self.index_to_samples @@ -339,6 +494,8 @@ cdef class SparsePartitioner: # number of missing values for current_feature self.n_missing = 0 + return feature_values[self.end - 1] <= feature_values[self.start] + FEATURE_THRESHOLD + cdef void shift_missing_to_the_left(self) noexcept nogil: pass # Missing values are not supported for sparse data. 
@@ -405,6 +562,21 @@ cdef class SparsePartitioner:
             p_prev[0] = p[0]
             p[0] = p_next
 
+    cdef inline SplitValue pos_to_threshold(
+        self, intp_t p_prev, intp_t p
+    ) noexcept nogil:
+        """Convert a split position into a numerical threshold."""
+        cdef SplitValue split
+        # split between two non-missing values
+        # sum of halves is used to avoid infinite value
+        split.threshold = (
+            self.feature_values[p_prev] / 2.0 + self.feature_values[p] / 2.0
+        )
+        if split.threshold == INFINITY or split.threshold == -INFINITY:
+            split.threshold = self.feature_values[p_prev]
+
+        return split
+
     cdef inline intp_t partition_samples(
         self,
         float64_t current_threshold,
@@ -415,13 +587,13 @@ cdef class SparsePartitioner:
 
     cdef inline void partition_samples_final(
         self,
-        float64_t best_threshold,
+        SplitValue split_value,
         intp_t best_feature,
         bint missing_go_to_left
     ) noexcept nogil:
         """Partition samples for X at the best_threshold and best_feature."""
         self.extract_nnz(best_feature)
-        self._partition(best_threshold)
+        self._partition(split_value.threshold)
 
     cdef inline intp_t _partition(self, float64_t threshold) noexcept nogil:
         """Partition samples[start:end] based on threshold."""
@@ -511,6 +683,22 @@ cdef class SparsePartitioner:
                                       &self.end_negative, &self.start_positive)
 
 
+cdef inline bint goes_left(
+    SplitValue split_value, bint missing_go_to_left, bint is_categorical, float32_t value
+) noexcept nogil:
+    """Return whether ``value`` should go to the left child.
+
+    This helper centralizes the split semantics for numerical, categorical,
+    and missing values.
+    """
+    if isnan(value):
+        return missing_go_to_left
+    elif is_categorical:
+        return split_value.cat_split & ((<uint64_t>1) << (<intp_t>value))
+    else:
+        return value <= split_value.threshold
+
+
 cdef int compare_SIZE_t(const void* a, const void* b) noexcept nogil:
     """Comparison function for sort.
 
@@ -670,7 +858,7 @@ def _py_sort(float32_t[::1] feature_values, intp_t[::1] samples, intp_t n):
 
 # Sort n-element arrays pointed to by feature_values and samples, simultaneously,
 # by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997).
-cdef void sort(floating* feature_values, intp_t* samples, intp_t n) noexcept nogil:
+cdef inline void sort(floating* feature_values, intp_t* samples, intp_t n) noexcept nogil:
     if n == 0:
         return
     cdef intp_t maxd = 2 * <intp_t>log2(n)
diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd
index 511ecdc655663..9525a04eb19f0 100644
--- a/sklearn/tree/_splitter.pxd
+++ b/sklearn/tree/_splitter.pxd
@@ -8,6 +8,7 @@ from sklearn.utils._typedefs cimport (
 )
 from sklearn.tree._criterion cimport Criterion
 from sklearn.tree._tree cimport ParentInfo
+from sklearn.tree._utils cimport SplitValue
 
 
 cdef struct SplitRecord:
@@ -16,7 +17,7 @@ cdef struct SplitRecord:
     intp_t pos             # Split samples array at the given position,
    #                       # i.e. count of samples below threshold for feature.
    #                       # pos is >= end if the node is a leaf.
-    float64_t threshold    # Threshold to split at.
+    SplitValue value       # Threshold/Bitset to split at.
     float64_t improvement  # Impurity improvement given parent node.
     float64_t impurity_left   # Impurity of the left split.
     float64_t impurity_right  # Impurity of the right split.
@@ -84,6 +85,7 @@ cdef class Splitter: const float64_t[:, ::1] y, const float64_t[:] sample_weight, const uint8_t[::1] missing_values_in_feature_mask, + const intp_t[::1] n_categories_in_feature, ) except -1 cdef int node_reset( diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index 11f1b204a1210..c08ab11992654 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -50,7 +50,6 @@ cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil self.impurity_right = INFINITY self.pos = start_pos self.feature = 0 - self.threshold = 0. self.improvement = -INFINITY self.missing_go_to_left = False @@ -129,6 +128,7 @@ cdef class Splitter: const float64_t[:, ::1] y, const float64_t[:] sample_weight, const uint8_t[::1] missing_values_in_feature_mask, + const intp_t[::1] n_categories_in_feature, ) except -1: """Initialize the splitter. @@ -152,8 +152,9 @@ cdef class Splitter: are assumed to have uniform weight. This is represented as a Cython memoryview. - has_missing : bool - At least one missing values is in X. + n_categories_in_feature : ndarray, dtype=intp_t + Per-feature number of categories for categorical features, and + ``-1`` for numerical features. """ self.rand_r_state = self.random_state.randint(0, RAND_R_MAX) @@ -281,7 +282,6 @@ cdef inline int node_split_best( # Find the best split cdef intp_t start = splitter.start cdef intp_t end = splitter.end - cdef intp_t end_non_missing cdef intp_t n_missing = 0 cdef bint has_missing = 0 cdef intp_t n_searches @@ -292,7 +292,6 @@ cdef inline int node_split_best( cdef intp_t[::1] constant_features = splitter.constant_features cdef intp_t n_features = splitter.n_features - cdef float32_t[::1] feature_values = splitter.feature_values cdef intp_t max_features = splitter.max_features cdef intp_t min_samples_leaf = splitter.min_samples_leaf cdef float64_t min_weight_leaf = splitter.min_weight_leaf @@ -308,6 +307,7 @@ cdef inline int node_split_best( cdef intp_t f_i = n_features cdef intp_t f_j + cdef bint is_constant cdef intp_t p cdef intp_t p_prev @@ -367,19 +367,10 @@ cdef inline int node_split_best( f_j += n_found_constants # f_j in the interval [n_total_constants, f_i[ current_split.feature = features[f_j] - partitioner.sort_samples_and_feature_values(current_split.feature) + is_constant = partitioner.sort_samples_and_feature_values(current_split.feature) n_missing = partitioner.n_missing - end_non_missing = end - n_missing - if ( - # All values for this feature are missing, or - end_non_missing == start or - # This feature is considered constant (max - min <= FEATURE_THRESHOLD) - (( - feature_values[end_non_missing - 1] - <= feature_values[start] + FEATURE_THRESHOLD - ) and n_missing == 0) - ): + if is_constant: # We consider this feature constant in this case. # Since finding a split among constant feature is not valuable, # we do not consider this feature for splitting. 
@@ -448,20 +439,7 @@ cdef inline int node_split_best( if current_proxy_improvement > best_proxy_improvement: best_proxy_improvement = current_proxy_improvement - if p == end_non_missing and not missing_go_to_left: - # split with the right node being only the missing values - current_split.threshold = INFINITY - else: - # split between two non-missing values - # sum of halves is used to avoid infinite value - current_split.threshold = ( - feature_values[p_prev] / 2.0 + feature_values[p] / 2.0 - ) - if ( - current_split.threshold == INFINITY or - current_split.threshold == -INFINITY - ): - current_split.threshold = feature_values[p_prev] + current_split.value = partitioner.pos_to_threshold(p_prev, p) # if there are no missing values in the training data, during # test time, we send missing values to the branch that contains @@ -476,7 +454,7 @@ cdef inline int node_split_best( # Reorganize into samples[start:best_split.pos] + samples[best_split.pos:end] if best_split.pos < end: partitioner.partition_samples_final( - best_split.threshold, + best_split.value, best_split.feature, best_split.missing_go_to_left ) @@ -635,7 +613,7 @@ cdef inline int node_split_random( has_missing = n_missing != 0 # Draw a random threshold - current_split.threshold = rand_uniform( + current_split.value.threshold = rand_uniform( min_feature_value, max_feature_value, random_state, @@ -655,12 +633,12 @@ cdef inline int node_split_random( else: missing_go_to_left = 0 - if current_split.threshold == max_feature_value: - current_split.threshold = min_feature_value + if current_split.value.threshold == max_feature_value: + current_split.value.threshold = min_feature_value # Partition current_split.pos = partitioner.partition_samples( - current_split.threshold, missing_go_to_left + current_split.value.threshold, missing_go_to_left ) n_left = current_split.pos - start @@ -711,7 +689,7 @@ cdef inline int node_split_random( if best_split.pos < end: if current_split.feature != best_split.feature: partitioner.partition_samples_final( - best_split.threshold, + best_split.value, best_split.feature, best_split.missing_go_to_left ) @@ -752,10 +730,12 @@ cdef class BestSplitter(Splitter): const float64_t[:, ::1] y, const float64_t[:] sample_weight, const uint8_t[::1] missing_values_in_feature_mask, + const intp_t[::1] n_categories_in_feature, ) except -1: - Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) + Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask, n_categories_in_feature) self.partitioner = DensePartitioner( - X, self.samples, self.feature_values, missing_values_in_feature_mask + X, y, sample_weight, self.samples, self.feature_values, + missing_values_in_feature_mask, n_categories_in_feature ) cdef int node_split( @@ -780,8 +760,9 @@ cdef class BestSparseSplitter(Splitter): const float64_t[:, ::1] y, const float64_t[:] sample_weight, const uint8_t[::1] missing_values_in_feature_mask, + const intp_t[::1] n_categories_in_feature, ) except -1: - Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) + Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask, n_categories_in_feature) self.partitioner = SparsePartitioner( X, self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask ) @@ -808,10 +789,12 @@ cdef class RandomSplitter(Splitter): const float64_t[:, ::1] y, const float64_t[:] sample_weight, const uint8_t[::1] missing_values_in_feature_mask, + const intp_t[::1] n_categories_in_feature, ) except -1: - 
Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) + Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask, n_categories_in_feature) self.partitioner = DensePartitioner( - X, self.samples, self.feature_values, missing_values_in_feature_mask + X, y, sample_weight, self.samples, self.feature_values, + missing_values_in_feature_mask, n_categories_in_feature ) cdef int node_split( @@ -836,8 +819,9 @@ cdef class RandomSparseSplitter(Splitter): const float64_t[:, ::1] y, const float64_t[:] sample_weight, const uint8_t[::1] missing_values_in_feature_mask, + const intp_t[::1] n_categories_in_feature, ) except -1: - Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask) + Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask, n_categories_in_feature) self.partitioner = SparsePartitioner( X, self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask ) diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 593f8d0c5f542..fcd5dceb4fae7 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -6,22 +6,13 @@ import numpy as np cimport numpy as cnp -from sklearn.utils._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint8_t, uint32_t +from sklearn.utils._typedefs cimport ( + float32_t, float64_t, intp_t, int32_t, uint8_t, uint32_t, uint64_t +) from sklearn.tree._splitter cimport Splitter from sklearn.tree._splitter cimport SplitRecord - -cdef struct Node: - # Base storage structure for the nodes in a Tree object - - intp_t left_child # id of the left child of the node - intp_t right_child # id of the right child of the node - intp_t feature # Feature used for splitting the node - float64_t threshold # Threshold value at the node - float64_t impurity # Impurity of the node (i.e., the value of the criterion) - intp_t n_node_samples # Number of samples at the node - float64_t weighted_n_node_samples # Weighted number of samples at the node - uint8_t missing_go_to_left # Whether features have missing values +from sklearn.tree._utils cimport Node, SplitValue cdef struct ParentInfo: @@ -44,6 +35,9 @@ cdef class Tree: cdef public intp_t n_outputs # Number of outputs in y cdef public intp_t max_n_classes # max(n_classes) + # FIXME: change to uint8_t: but the error it triggers might be a Cython bug + cdef intp_t* is_categorical # Shape (n_features,) + # Inner structures: values are stored separately from node structure, # since size is determined at runtime. 
cdef public intp_t max_depth # Max depth of the tree @@ -55,8 +49,8 @@ cdef class Tree: # Methods cdef intp_t _add_node(self, intp_t parent, bint is_left, bint is_leaf, - intp_t feature, float64_t threshold, float64_t impurity, - intp_t n_node_samples, + intp_t feature, float64_t threshold, uint64_t cat_split, + float64_t impurity, intp_t n_node_samples, float64_t weighted_n_node_samples, uint8_t missing_go_to_left) except -1 nogil cdef int _resize(self, intp_t capacity) except -1 nogil diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 7044673189fb6..4f220bfd9da2c 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -145,6 +145,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): const float64_t[:, ::1] y, const float64_t[:] sample_weight=None, const uint8_t[::1] missing_values_in_feature_mask=None, + const intp_t [::1] n_categories_in_feature=None, ): """Build a decision tree from the training set (X, y).""" @@ -170,7 +171,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef float64_t min_impurity_decrease = self.min_impurity_decrease # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight, missing_values_in_feature_mask) + splitter.init(X, y, sample_weight, missing_values_in_feature_mask, n_categories_in_feature) cdef intp_t start cdef intp_t end @@ -254,7 +255,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): min_impurity_decrease)) node_id = tree._add_node(parent, is_left, is_leaf, split.feature, - split.threshold, parent_record.impurity, + split.value.threshold, split.value.cat_split, + parent_record.impurity, n_node_samples, weighted_n_node_samples, split.missing_go_to_left) @@ -400,6 +402,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): const float64_t[:, ::1] y, const float64_t[:] sample_weight=None, const uint8_t[::1] missing_values_in_feature_mask=None, + const intp_t[::1] n_categories_in_feature=None, ): """Build a decision tree from the training set (X, y).""" @@ -411,7 +414,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef intp_t max_leaf_nodes = self.max_leaf_nodes # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight, missing_values_in_feature_mask) + splitter.init(X, y, sample_weight, missing_values_in_feature_mask, n_categories_in_feature) cdef vector[FrontierRecord] frontier cdef FrontierRecord record @@ -467,6 +470,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): node.right_child = _TREE_LEAF node.feature = _TREE_UNDEFINED node.threshold = _TREE_UNDEFINED + # node.categorical_bitset = _TREE_UNDEFINED else: # Node is expandable @@ -611,8 +615,9 @@ cdef class BestFirstTreeBuilder(TreeBuilder): node_id = tree._add_node(parent - tree.nodes if parent != NULL else _TREE_UNDEFINED, - is_left, is_leaf, - split.feature, split.threshold, parent_record.impurity, + is_left, is_leaf, split.feature, + split.value.threshold, split.value.cat_split, + parent_record.impurity, n_node_samples, weighted_n_node_samples, split.missing_go_to_left) if node_id == INTPTR_MAX: @@ -746,6 +751,10 @@ cdef class Tree: def threshold(self): return self._get_node_ndarray()['threshold'][:self.node_count] + @property + def categorical_bitset(self): + return self._get_node_ndarray()['categorical_bitset'][:self.node_count] + @property def impurity(self): return self._get_node_ndarray()['impurity'][:self.node_count] @@ -768,7 +777,7 @@ cdef class Tree: # TODO: Convert n_classes to cython.integral memory view once # https://github.com/cython/cython/issues/5243 is fixed - def __cinit__(self, intp_t 
n_features, cnp.ndarray n_classes, intp_t n_outputs):
+    def __cinit__(self, intp_t n_features, cnp.ndarray n_classes, intp_t n_outputs, cnp.ndarray is_categorical):
         """Constructor."""
         cdef intp_t dummy = 0
         size_t_dtype = np.array(dummy).dtype
@@ -788,6 +797,16 @@ cdef class Tree:
         for k in range(n_outputs):
             self.n_classes[k] = n_classes[k]
 
+        self.is_categorical = NULL
+        safe_realloc(&self.is_categorical, n_features)
+        if is_categorical is None:
+            for f in range(n_features):
+                self.is_categorical[f] = False
+        else:
+            is_categorical = is_categorical.astype(np.intp)
+            for f in range(n_features):
+                self.is_categorical[f] = is_categorical[f]
+
         # Inner structures
         self.max_depth = 0
         self.node_count = 0
@@ -798,15 +817,19 @@ cdef class Tree:
     def __dealloc__(self):
         """Destructor."""
         # Free all inner structures
+        free(self.is_categorical)
         free(self.n_classes)
         free(self.value)
         free(self.nodes)
 
     def __reduce__(self):
         """Reduce re-implementation, for pickling."""
-        return (Tree, (self.n_features,
-                       sizet_ptr_to_ndarray(self.n_classes, self.n_outputs),
-                       self.n_outputs), self.__getstate__())
+        return (Tree, (
+            self.n_features,
+            sizet_ptr_to_ndarray(self.n_classes, self.n_outputs),
+            self.n_outputs,
+            sizet_ptr_to_ndarray(self.is_categorical, self.n_features)
+        ), self.__getstate__())
 
     def __getstate__(self):
         """Getstate re-implementation, for pickling."""
@@ -895,8 +918,8 @@ cdef class Tree:
         return 0
 
     cdef intp_t _add_node(self, intp_t parent, bint is_left, bint is_leaf,
-                          intp_t feature, float64_t threshold, float64_t impurity,
-                          intp_t n_node_samples,
+                          intp_t feature, float64_t threshold, uint64_t cat_split,
+                          float64_t impurity, intp_t n_node_samples,
                           float64_t weighted_n_node_samples,
                           uint8_t missing_go_to_left) except -1 nogil:
         """Add a node to the tree.
@@ -927,11 +950,17 @@ cdef class Tree:
             node.right_child = _TREE_LEAF
             node.feature = _TREE_UNDEFINED
             node.threshold = _TREE_UNDEFINED
+            # node.categorical_bitset = _TREE_UNDEFINED
 
         else:
             # left_child and right_child will be set later
             node.feature = feature
-            node.threshold = threshold
+            if self.is_categorical[feature]:
+                node.threshold = -INFINITY
+                node.categorical_bitset = cat_split
+            else:
+                node.threshold = threshold
+                node.categorical_bitset = 0
             node.missing_go_to_left = missing_go_to_left
 
         self.node_count += 1
@@ -981,13 +1010,18 @@ cdef class Tree:
                 node = self.nodes
                 # While node not a leaf
                 while node.left_child != _TREE_LEAF:
-                    X_i_node_feature = X_ndarray[i, node.feature]
                     # ... and node.right_child != _TREE_LEAF:
+                    X_i_node_feature = X_ndarray[i, node.feature]
                     if isnan(X_i_node_feature):
                         if node.missing_go_to_left:
                             node = &self.nodes[node.left_child]
                         else:
                             node = &self.nodes[node.right_child]
+                    elif self.is_categorical[node.feature]:
+                        if node.categorical_bitset & ((<uint64_t>1) << (<intp_t>X_i_node_feature)):
+                            node = &self.nodes[node.left_child]
+                        else:
+                            node = &self.nodes[node.right_child]
                     elif X_i_node_feature <= node.threshold:
                         node = &self.nodes[node.left_child]
                     else:
@@ -1109,13 +1143,18 @@ cdef class Tree:
                     # ... and node.right_child != _TREE_LEAF:
                     indices[indptr[i + 1]] = <intp_t>(node - self.nodes)
                     indptr[i + 1] += 1
                     X_i_node_feature = X_ndarray[i, node.feature]
                     if isnan(X_i_node_feature):
                         if node.missing_go_to_left:
                             node = &self.nodes[node.left_child]
                         else:
                             node = &self.nodes[node.right_child]
+                    elif self.is_categorical[node.feature]:
+                        if node.categorical_bitset & ((<uint64_t>1) << (<intp_t>X_ndarray[i, node.feature])):
+                            node = &self.nodes[node.left_child]
+                        else:
+                            node = &self.nodes[node.right_child]
                     elif X_i_node_feature <= node.threshold:
                         node = &self.nodes[node.left_child]
                     else:
@@ -1396,6 +1434,7 @@ cdef class Tree:
 
             if is_target_feature:
                 # In this case, we push left or right child on stack
+                # TODO: handle categorical (and missing?)
                 if X[sample_idx, feature_idx] <= current_node.threshold:
                     node_idx_stack[stack_size] = current_node.left_child
                 else:
@@ -1936,7 +1975,8 @@ cdef void _build_pruned_tree(
                 break
 
             new_node_id = tree._add_node(
-                parent, is_left, is_leaf, node.feature, node.threshold,
+                parent, is_left, is_leaf, node.feature,
+                node.threshold, node.categorical_bitset,
                 node.impurity, node.n_node_samples,
                 node.weighted_n_node_samples, node.missing_go_to_left)
 
diff --git a/sklearn/tree/_utils.pxd b/sklearn/tree/_utils.pxd
index 97f8d60645b04..4857edc070202 100644
--- a/sklearn/tree/_utils.pxd
+++ b/sklearn/tree/_utils.pxd
@@ -4,9 +4,34 @@
 # See _utils.pyx for details.
 
 cimport numpy as cnp
-from sklearn.tree._tree cimport Node
 from sklearn.neighbors._quad_tree cimport Cell
-from sklearn.utils._typedefs cimport float32_t, float64_t, intp_t, uint8_t, int32_t, uint32_t
+from sklearn.utils._typedefs cimport (
+    float32_t, float64_t, intp_t, uint8_t, int32_t, uint32_t, uint64_t
+)
+
+
+ctypedef union SplitValue:
+    # Union type to generalize the concept of a threshold to categorical
+    # features. The floating point view, i.e. ``split_value.threshold`` is used
+    # for numerical features, where feature values less than or equal to the
+    # threshold go left, and values greater than the threshold go right.
+ + float64_t threshold + uint64_t cat_split # bitset + + +cdef struct Node: + # Base storage structure for the nodes in a Tree object + + intp_t left_child # id of the left child of the node + intp_t right_child # id of the right child of the node + intp_t feature # Feature used for splitting the node + float64_t threshold # Threshold value at the node, for continuous split (-INF otherwise) + uint64_t categorical_bitset # Bitset for categorical split (0 otherwise) + float64_t impurity # Impurity of the node (i.e., the value of the criterion) + intp_t n_node_samples # Number of samples at the node + float64_t weighted_n_node_samples # Weighted number of samples at the node + uint8_t missing_go_to_left # Whether features have missing values cdef enum: diff --git a/sklearn/tree/tests/test_split.py b/sklearn/tree/tests/test_split.py index cd5a56eaf7601..fde9b23e0e341 100644 --- a/sklearn/tree/tests/test_split.py +++ b/sklearn/tree/tests/test_split.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from itertools import product +from itertools import chain, combinations, product from operator import itemgetter import numpy as np @@ -21,6 +21,7 @@ REG_CRITERIONS = ("squared_error", "absolute_error", "poisson") + CLF_TREES = { "DecisionTreeClassifier": DecisionTreeClassifier, "ExtraTreeClassifier": ExtraTreeClassifier, @@ -32,10 +33,45 @@ } +def powerset(iterable): + """returns all the subsets of `iterable`.""" + s = list(iterable) + return chain.from_iterable( + combinations(s, r) for r in range(1, (len(s) + 1) // 2 + 1) + ) + + +def bitset_to_tuple(v: np.uint64): + return tuple(c for c in range(64) if v & (1 << c)) + + +@dataclass +class Split: + feature: int + threshold: float | tuple + missing_left: bool = False + + @property + def is_categorical(self): + return isinstance(self.threshold, tuple) + + @classmethod + def from_tree(cls, tree): + ftr = int(tree.tree_.feature[0]) + if tree.is_categorical_[ftr]: + cat_bitset = np.uint64(tree.tree_.categorical_bitset[0]) + threshold = bitset_to_tuple(cat_bitset) + else: + threshold = tree.tree_.threshold[0] + missing_left = bool(tree.tree_.missing_go_to_left[0]) + return cls(ftr, threshold, missing_left) + + @dataclass class NaiveSplitter: criterion: str - n_classes: int = 0 + n_classes: int + is_categorical: np.ndarray def compute_node_value_and_impurity(self, y, w): sum_weights = np.sum(w) @@ -63,20 +99,24 @@ def compute_node_value_and_impurity(self, y, w): raise ValueError(f"Unknown criterion: {self.criterion}") return pred, loss * sum_weights - def compute_split_nodes(self, X, y, w, feature, threshold=None, missing_left=False): - x = X[:, feature] - go_left = x <= threshold - if missing_left: + def compute_split_nodes(self, X, y, w, split): + x = X[:, split.feature] + if split.is_categorical: + x = x.astype(int) + cat_go_left = np.zeros(max(max(x), max(split.threshold)) + 1, dtype=bool) + cat_go_left[list(split.threshold)] = True + go_left = cat_go_left[x] + else: + go_left = x <= split.threshold + if split.missing_left: go_left |= np.isnan(x) return ( self.compute_node_value_and_impurity(y[go_left], w[go_left]), self.compute_node_value_and_impurity(y[~go_left], w[~go_left]), ) - def compute_split_impurity( - self, X, y, w, feature, threshold=None, missing_left=False - ): - nodes = self.compute_split_nodes(X, y, w, feature, threshold, missing_left) + def compute_split_impurity(self, X, y, w, split): + nodes = self.compute_split_nodes(X, y, w, split) (_, left_impurity), (_, right_impurity) = nodes return left_impurity + right_impurity @@ 
-85,21 +125,15 @@ def _generate_all_splits(self, X): x = X[:, f] nan_mask = np.isnan(x) thresholds = np.unique(x[~nan_mask]) + if self.is_categorical[f]: + thresholds = list(powerset(int(th) for th in thresholds)) for th in thresholds: - yield { - "feature": f, - "threshold": th, - "missing_left": False, - } + yield Split(f, th) if not nan_mask.any(): continue for th in [*thresholds, -np.inf]: # include -inf to test the split with only NaNs on the left node - yield { - "feature": f, - "threshold": th, - "missing_left": True, - } + yield Split(f, th, missing_left=True) def best_split_naive(self, X, y, w): splits = list(self._generate_all_splits(X)) @@ -107,17 +141,25 @@ def best_split_naive(self, X, y, w): return (np.inf, None) split_impurities = [ - self.compute_split_impurity(X, y, w, **split) for split in splits + self.compute_split_impurity(X, y, w, split) for split in splits ] return min(zip(split_impurities, splits), key=itemgetter(0)) +def to_categorical(x, nc, rng): + q = np.linspace(0, 1, num=nc + 1)[1:-1] + quantiles = np.quantile(x, q) + cats = np.searchsorted(quantiles, x) + return rng.permutation(nc)[cats] + + def make_simple_dataset( n, d, with_nans, is_sparse, + is_categorical, is_clf, n_classes, rng, @@ -126,6 +168,9 @@ def make_simple_dataset( y = rng.random(n) + X_dense.sum(axis=1) w = rng.integers(0, 5, size=n) if rng.uniform() < 0.5 else rng.random(n) + for idx in np.where(is_categorical)[0]: + nc = rng.integers(2, 6) # cant go to high or test will be too slow + X_dense[:, idx] = to_categorical(X_dense[:, idx], nc, rng) with_duplicates = rng.integers(2) == 0 if with_duplicates: X_dense = X_dense.round(1 if n < 50 else 2) @@ -160,13 +205,18 @@ def make_simple_dataset( ], ) @pytest.mark.parametrize( - "sparse, missing_values", - [(False, False), (True, False), (False, True)], - ids=["dense-without_missing", "sparse-without_missing", "dense-with_missing"], + "sparse, missing_values, categorical", + [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1)], + ids=["dense", "sparse", "dense-with_missing", "dense-categorical"], ) -def test_split_impurity(Tree, criterion, sparse, missing_values, global_random_seed): +def test_split_impurity( + Tree, criterion, sparse, missing_values, categorical, global_random_seed +): is_clf = criterion in CLF_CRITERIONS + if categorical and "Extra" in Tree.__name__: + pytest.skip("Categorical features not implemented for the random splitter") + rng = np.random.default_rng(global_random_seed) ns = [5] * 5 + [10] * 5 + [20, 30, 50, 100] @@ -174,17 +224,24 @@ def test_split_impurity(Tree, criterion, sparse, missing_values, global_random_s for it, n in enumerate(ns): d = rng.integers(1, 4) n_classes = rng.integers(2, 5) # only used for classification - X_dense, X, y, w = make_simple_dataset( - n, d, missing_values, sparse, is_clf, n_classes, rng + + tree_kwargs = dict( + criterion=criterion, max_depth=1, random_state=global_random_seed ) + if categorical: + is_categorical = rng.random(d) < 0.5 + n_classes = 2 + tree_kwargs["categorical_features"] = is_categorical + else: + is_categorical = np.zeros(d, dtype=bool) - naive_splitter = NaiveSplitter(criterion, n_classes) + tree = Tree(**tree_kwargs) + naive_splitter = NaiveSplitter(criterion, n_classes, is_categorical) - tree = Tree( - criterion=criterion, - max_depth=1, - random_state=global_random_seed, + X_dense, X, y, w = make_simple_dataset( + n, d, missing_values, sparse, is_categorical, is_clf, n_classes, rng ) + tree.fit(X, y, sample_weight=w) actual_impurity = tree.tree_.impurity * 
tree.tree_.weighted_n_node_samples actual_value = tree.tree_.value[:, 0] @@ -206,12 +263,8 @@ def test_split_impurity(Tree, criterion, sparse, missing_values, global_random_s continue # Check children impurity: - actual_split = { - "feature": int(tree.tree_.feature[0]), - "threshold": tree.tree_.threshold[0], - "missing_left": bool(tree.tree_.missing_go_to_left[0]), - } - nodes = naive_splitter.compute_split_nodes(X_dense, y, w, **actual_split) + actual_split = Split.from_tree(tree) + nodes = naive_splitter.compute_split_nodes(X_dense, y, w, actual_split) (left_val, left_impurity), (right_val, right_impurity) = nodes assert_allclose(left_impurity, actual_impurity[1], atol=1e-12) assert_allclose(right_impurity, actual_impurity[2], atol=1e-12) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index fa6632cfddf73..c5d57a066012d 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2133,11 +2133,21 @@ def test_criterion_entropy_same_as_log_loss(Tree, n_classes): assert_allclose(tree_log_loss.predict(X), tree_entropy.predict(X)) +def to_categorical(x, nc): + q = np.linspace(0, 1, num=nc + 1)[1:-1] + quantiles = np.quantile(x, q) + cats = np.searchsorted(quantiles, x) + return np.random.permutation(nc)[cats] + + def test_different_endianness_pickle(): - X, y = datasets.make_classification(random_state=0) + X, y = datasets.make_classification(random_state=0, n_redundant=0, shuffle=False) + X[:, 0] = to_categorical(X[:, 0], 50) - clf = DecisionTreeClassifier(random_state=0, max_depth=3) + clf = DecisionTreeClassifier(random_state=0, max_depth=3, categorical_features=[0]) clf.fit(X, y) + assert 0 < clf.feature_importances_[0] < 1 + # ^ ensures some splits are categorical, some are continuous score = clf.score(X, y) def reduce_ndarray(arr): @@ -2160,9 +2170,12 @@ def get_pickle_non_native_endianness(): def test_different_endianness_joblib_pickle(): X, y = datasets.make_classification(random_state=0) + X[:, 0] = to_categorical(X[:, 0], 50) - clf = DecisionTreeClassifier(random_state=0, max_depth=3) + clf = DecisionTreeClassifier(random_state=0, max_depth=3, categorical_features=[0]) clf.fit(X, y) + assert 0 < clf.feature_importances_[0] < 1 + # ^ ensures some splits are categorical, some are continuous score = clf.score(X, y) class NonNativeEndiannessNumpyPickler(NumpyPickler): @@ -2221,13 +2234,15 @@ def get_different_alignment_node_ndarray(node_ndarray): def reduce_tree_with_different_bitness(tree): new_dtype = np.int64 if _IS_32BIT else np.int32 - tree_cls, (n_features, n_classes, n_outputs), state = tree.__reduce__() + tree_cls, (n_features, n_classes, n_outputs, is_categorical), state = ( + tree.__reduce__() + ) new_n_classes = n_classes.astype(new_dtype, casting="same_kind") new_state = state.copy() new_state["nodes"] = get_different_bitness_node_ndarray(new_state["nodes"]) - return (tree_cls, (n_features, new_n_classes, n_outputs), new_state) + return (tree_cls, (n_features, new_n_classes, n_outputs, is_categorical), new_state) def test_different_bitness_pickle(): @@ -2866,7 +2881,9 @@ def test_build_pruned_tree_py(): tree.fit(iris.data, iris.target) n_classes = np.atleast_1d(tree.n_classes_) - pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_) + pruned_tree = CythonTree( + tree.n_features_in_, n_classes, tree.n_outputs_, tree.is_categorical_ + ) # only keep the root note leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8) @@ -2880,7 +2897,9 @@ def test_build_pruned_tree_py(): 
assert_array_equal(tree.tree_.value[0], pruned_tree.value[0]) # now keep all the leaves - pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_) + pruned_tree = CythonTree( + tree.n_features_in_, n_classes, tree.n_outputs_, tree.is_categorical_ + ) leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8) leave_in_subtree[1:] = 1 @@ -2898,7 +2917,9 @@ def test_build_pruned_tree_infinite_loop(): tree = DecisionTreeClassifier(random_state=0, max_depth=1) tree.fit(iris.data, iris.target) n_classes = np.atleast_1d(tree.n_classes_) - pruned_tree = CythonTree(tree.n_features_in_, n_classes, tree.n_outputs_) + pruned_tree = CythonTree( + tree.n_features_in_, n_classes, tree.n_outputs_, tree.is_categorical_ + ) # only keeping one child as a leaf results in an improper tree leave_in_subtree = np.zeros(tree.tree_.node_count, dtype=np.uint8) @@ -3047,6 +3068,26 @@ def test_friedman_mse_deprecation(): _ = DecisionTreeRegressor(criterion="friedman_mse") +@pytest.mark.parametrize("Tree", [DecisionTreeClassifier, DecisionTreeRegressor]) +def test_categorical(Tree): + rng = np.random.default_rng(3) + n = 40 + c = rng.integers(0, 20, size=n) + y = c % 2 + + X = rng.random((n, 3)) + X[:, 0] = c + + tree = Tree(categorical_features=[0], max_depth=1, random_state=8) + # assert perfect tree was reached in one split + assert tree.fit(X, y).score(X, y) == 1 + assert tree.feature_importances_[0] == 1 + + # assert it's not the case without using categorical_features + tree = Tree(max_depth=1) + assert tree.fit(X, y).score(X, y) < 1 + + @pytest.mark.parametrize( "X,y", [ @@ -3066,3 +3107,32 @@ def test_random_splitter_missing_values_uses_non_missing_min_max(X, y): assert np.isfinite(threshold) assert non_missing.min() <= threshold <= non_missing.max() + + +def test_categorical_random_splitter_raises(): + X = np.array([[0], [1], [2], [3]], dtype=np.float32) + y = np.array([0, 1, 0, 1]) + + tree = DecisionTreeClassifier( + splitter="random", categorical_features=[0], random_state=0 + ) + with pytest.raises(ValueError, match="splitter='random'"): + tree.fit(X, y) + + +def test_categorical_multiclass_classification_raises(): + X = np.array([[0], [1], [2], [3]], dtype=np.float32) + y = np.array([0, 1, 2, 0]) + + tree = DecisionTreeClassifier(categorical_features=[0], random_state=0) + with pytest.raises(ValueError, match="binary classification"): + tree.fit(X, y) + + +def test_categorical_multioutput_raises(): + X = np.array([[0], [1], [2], [3]], dtype=np.float32) + y = np.array([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]) + + tree = DecisionTreeRegressor(categorical_features=[0], random_state=0) + with pytest.raises(ValueError, match="multi-output"): + tree.fit(X, y)
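
A minimal usage sketch of the `categorical_features` parameter this patch adds, assuming a build of scikit-learn that includes the branch; the synthetic data, column layout and hyper-parameters below are illustrative only.

import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
n = 200
color = rng.integers(0, 5, size=n)       # categorical codes, must stay in [0, 63]
size = rng.random(n)                     # ordinary numerical feature
X = np.column_stack([color, size]).astype(np.float64)
y = (np.isin(color, [1, 3]) ^ (size > 0.7)).astype(int)   # binary target only

clf = DecisionTreeClassifier(
    categorical_features=[0],            # or a boolean mask: [True, False]
    max_depth=3,
    random_state=0,
)
clf.fit(X, y)                            # dense X required for categorical features
print(clf.score(X, y))
# Categorical splits are stored as 64-bit bitsets over the category codes:
print(clf.tree_.categorical_bitset[:clf.tree_.node_count])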
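
The dense partitioner handles a categorical feature by ordering its categories by mean target (`_breiman_sort_categories`) and then encoding each candidate prefix of that ordering as a 64-bit bitset (`_split_pos_to_bitset`). The NumPy sketch below only illustrates that idea; it is not the Cython implementation, and the helper name `breiman_candidate_bitsets` is made up for illustration.

import numpy as np

def breiman_candidate_bitsets(x_cat, y, sample_weight=None):
    """Return the candidate left-child bitsets for one categorical feature."""
    x_cat = x_cat.astype(np.intp)
    if sample_weight is None:
        sample_weight = np.ones_like(y, dtype=float)
    cats = np.unique(x_cat)
    # weighted mean of the target per category
    means = np.array([
        np.average(y[x_cat == c], weights=sample_weight[x_cat == c]) for c in cats
    ])
    order = cats[np.argsort(means)]      # categories sorted by ascending mean target
    bitsets = []
    acc = np.uint64(0)
    for c in order[:-1]:                 # the full set is not a split, so stop early
        acc |= np.uint64(1) << np.uint64(c)
        bitsets.append(acc)
    return bitsets

# Example: three categories; category 1 has the highest mean target.
x = np.array([0, 0, 1, 1, 2, 2])
y = np.array([0.0, 1.0, 5.0, 6.0, 2.0, 3.0])
for b in breiman_candidate_bitsets(x, y):
    left = [c for c in range(64) if (b >> np.uint64(c)) & np.uint64(1)]
    print(f"bitset={int(b):#x} -> left categories {left}")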
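
Finally, a small sketch of the input validation performed by `_check_categorical_features`: categorical columns must hold non-negative integer codes below 64, NaNs are rejected, and sparse inputs are not supported. The printed message reflects the error text defined in this patch.

import numpy as np
from sklearn.tree import DecisionTreeRegressor

X = np.array([[0.0], [1.0], [2.5], [3.0]])   # 2.5 is not a valid category code
y = np.array([0.0, 1.0, 0.0, 1.0])

reg = DecisionTreeRegressor(categorical_features=[0])
try:
    reg.fit(X, y)
except ValueError as exc:
    # Values for categorical features should be integers in [0, 63]. Found non-integer values.
    print(exc)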