I'm trying to define a custom splitter using sklearn Classification Trees classes, but I'm getting no results so far. I got no errors but the tree is not developed. How to achieve this?
My strategy is largely inspired by this approach: __cinit__() takes exactly 2 positional arguments when extending a Cython class
Here is the code:
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree._tree import Tree
from sklearn.tree._splitter import BestSplitter
from sklearn.tree._criterion import Gini
from sklearn.tree._classes import DepthFirstTreeBuilder
class CustomBestSplitter(BestSplitter):
""" Custom splitter that only allows splits on even-indexed features """
def __init__(self, *args, **kwargs):
pass
""" Custom splitter that only allows splitting on even-indexed features """
def best_split(self, *args, **kwargs):
best_split = super().best_split(*args, **kwargs)
if best_split is not None:
print(best_split)
feature_index = best_split[0] # Extract feature index from split
if feature_index % 2 != 0: # Enforce even-index features only
return None # Reject the split if it does not satisfy the constraint
return best_split # Otherwise, allow the split
class CustomDepthFirstTreeBuilder(DepthFirstTreeBuilder):
def __init__(self, *args, **kwargs):
pass
def best_tree(self, *args, **kwargs):
best_tree = super().best_tree(*args, **kwargs)
# Build the tree manually
builder.build(tree = self.tree_, X = X, y = y)
return self
class CustomDecisionTreeClassifier(DecisionTreeClassifier):
def __init__(self, criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1,
min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, class_weight=None, ccp_alpha=0.0, monotonic_cst=None):
# Appeler le constructeur de la classe parente
super().__init__(criterion=criterion, splitter=splitter, max_depth=max_depth,
min_samples_split=min_samples_split, min_samples_leaf=min_samples_leaf,
min_weight_fraction_leaf=min_weight_fraction_leaf, max_features=max_features,
random_state=random_state, max_leaf_nodes=max_leaf_nodes,
min_impurity_decrease=min_impurity_decrease, class_weight=class_weight,
ccp_alpha=ccp_alpha, monotonic_cst = monotonic_cst)
pass
def fit(self, X, y, sample_weight=None, check_input=True, X_idx_sorted=None):
""" Override fit to inject custom splitter """
y = y.reshape(-1, 1) if y.ndim == 1 else y
# Compute number of outputs and classes
n_outputs = 1 if len(y.shape) == 1 else y.shape[1]
n_classes = np.array([np.unique(y).shape[0]], dtype=np.intp)
# Create tree structure
self.tree_ = Tree(X.shape[1], n_classes, n_outputs)
# Initialize Gini criterion
criterion = Gini(n_outputs, n_classes)
# Initialize the custom splitter correctly
splitter = CustomBestSplitter(criterion=criterion,
max_features=X.shape[1],
min_samples_leaf=1,
min_weight_leaf=0.0,
random_state=None,
monotonic_cst=None)
# Manually create a tree builder with the custom splitter
builder = CustomDepthFirstTreeBuilder(
splitter=splitter,
min_samples_split=2,
min_samples_leaf=1,
min_weight_leaf=0.0,
max_depth=3,
min_impurity_decrease=0.0
)
return builder
# Generate synthetic data
X = np.random.rand(100, 5) # 100 samples, 5 features
y = np.random.randint(0, 2, 100) # Binary target
# Train the custom decision tree
model = CustomDecisionTreeClassifier(max_depth=3)
model.fit(X, y)