I am implementing the Random Ferns Algorithm for classification. For simplicity, imagine a single decision tree with only a single node. As input we have one feature value and the label for each sample in the dataset.
The function should work for any number of classes (i.e. `len(set(labels))`). The output is the feature threshold that leads to the best split. I plan to implement further impurity measures, such as misclassification rate or entropy, later on.
For those interested in the topic, here is a link to a short introductory presentation (PDF): classification trees and node splitting.
My current implementation works fine, yet I am sure there is plenty of room for improvement. If you have questions regarding functionality, please ask. I have added comments explaining what remains to be done.
Example input:
test_features = [1,2,3,3,4]
test_labels = [0,0,1,1,1]
Example output:
3
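The output can be checked by hand: at threshold 3 both children are pure, so the weighted Gini impurity is zero, while every other threshold leaves at least one mixed child. A quick sketch of that arithmetic for two candidate thresholds:

```python
# Hand-check of the example: weighted Gini impurity for two candidate thresholds.
def gini(labels):
    n = len(labels)
    return 1 - sum((labels.count(c) / n) ** 2 for c in set(labels))

labels = [0, 0, 1, 1, 1]

# threshold 3: left = labels of values < 3, right = labels of values >= 3
left, right = [0, 0], [1, 1, 1]
weighted_3 = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
# both children are pure, so weighted_3 == 0.0

# threshold 2: left = [0], right = [0, 1, 1, 1]
left, right = [0], [0, 1, 1, 1]
weighted_2 = (len(left) * gini(left) + len(right) * gini(right)) / len(labels)
# weighted_2 == 0.3, worse than the pure split at threshold 3
```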
Code as follows:
```python
from collections import Counter

import numpy as np


def best_split(feature_values, labels):
    # training: for each node/feature, determine the threshold of the best split
    impurity = []
    # the only relevant candidates for a threshold are the feature values themselves
    possible_thresholds = sorted(set(feature_values))
    for threshold in possible_thresholds:
        # split the node's samples based on the threshold;
        # left is empty when threshold == min(feature_values), right is never empty
        right = [label for value, label in zip(feature_values, labels) if value >= threshold]
        left = [label for value, label in zip(feature_values, labels) if value < threshold]
        # label distribution of each split; Counter handles unsorted labels,
        # whereas itertools.groupby only groups *consecutive* equal elements
        right_distribution = list(Counter(right).values())
        left_distribution = list(Counter(left).values())
        # Gini impurity of each split based on its label distribution;
        # an empty split is defined to have impurity 0
        gini_right = 1 - np.sum((np.array(right_distribution) / len(right)) ** 2)
        gini_left = 1 - np.sum((np.array(left_distribution) / len(left)) ** 2) if left else 0
        # weighted total impurity of the split
        gini_split = (len(right) * gini_right + len(left) * gini_left) / len(labels)
        impurity.append(gini_split)
    # return the threshold with the lowest weighted impurity --> best split threshold
    return possible_thresholds[impurity.index(min(impurity))]
```
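Since I mention entropy as a planned impurity measure, here is a hedged sketch of what that variant could look like; `best_split_entropy` is a hypothetical name, and the threshold search is the same as in the Gini version, only the scoring function changes:

```python
from collections import Counter
import math


def entropy(labels):
    # Shannon entropy of a label multiset; an empty split is defined as 0.0
    n = len(labels)
    if n == 0:
        return 0.0
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def best_split_entropy(feature_values, labels):
    # same candidate search as the Gini version, scored by weighted entropy
    best_threshold, best_score = None, float("inf")
    for threshold in sorted(set(feature_values)):
        right = [l for v, l in zip(feature_values, labels) if v >= threshold]
        left = [l for v, l in zip(feature_values, labels) if v < threshold]
        score = (len(right) * entropy(right) + len(left) * entropy(left)) / len(labels)
        if score < best_score:
            best_threshold, best_score = threshold, score
    return best_threshold
```

On the example data this returns the same threshold, 3, since both impurity measures are minimized by pure children.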
This function is used in the training of the Random Ferns class as follows:

```python
def train(self, patches, labels):
    self.classes = list(set(labels))
    # a uniform distribution for each leaf is assumed:
    # for each fern, for each feature combination (leaf), for each class
    # there is a posterior probability; these are stored in a list of lists
    # of lists named 'posterior'. List comprehensions (rather than replication
    # with *) ensure every leaf gets its own independent distribution object.
    initial_distribution = [1 / len(self.classes)] * len(self.classes)
    self.posterior = [[list(initial_distribution) for _ in range(2 ** self.fernsize)]
                      for _ in range(self.number_of_ferns)]
    # determine the best threshold for each feature using best_split
    all_thresholds = []
    for fern in self.ferns:
        fern_thresholds = []
        for feature_params in fern:
            # feature() extracts the values of one specific feature
            # (determined by feature_params) from each patch in patches
            feature_values = feature(patches, feature_params)
            fern_thresholds.append(best_split(feature_values, labels))
        all_thresholds.append(fern_thresholds)
    self.threshold = all_thresholds
```
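One subtlety worth noting when initializing a nested structure like `posterior`: replicating a list with `*` copies references to the same inner object, so updating one leaf's distribution would silently update all of them; independent inner lists require a comprehension. A minimal demonstration of the difference:

```python
# Hypothetical shapes: 2 leaves, 2 classes, uniform initial distribution.
shared = [[0.5, 0.5]] * 2  # both rows are the SAME list object
shared[0][0] = 1.0
# the change shows up in every row: shared == [[1.0, 0.5], [1.0, 0.5]]

independent = [[0.5, 0.5] for _ in range(2)]  # each row is a fresh list
independent[0][0] = 1.0
# only the first row changes: independent == [[1.0, 0.5], [0.5, 0.5]]
```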