CommonLounge Archive

Decision Tree

May 27, 2017

A decision tree is a popular supervised learning algorithm that can be used for both classification and regression problems. It resembles a flowchart and is easy to interpret. The algorithm works by recursively splitting the data based on the value of a feature. After each split, each portion of the data becomes more homogeneous, until eventually most samples in a portion belong to the same class (or, for regression, have similar target values).

Illustrative example

Consider a problem where we want to decide what to do on a particular day. Our decision is based on three inputs (features): “Work to do?”, “Outlook?” and “Friends busy?”. Our decision tree might look something like the figure below.

Source: Python Machine Learning by Sebastian Raschka

Suppose our input is Work to do? = No, Outlook? = Sunny and Friends busy? = No. Then our decision would be Go to beach. To reach it, we start at the top node and, at each node, proceed to the child node that matches the value of that node's feature. Each terminal node (the gray nodes) holds an output value.

Note that it is possible for multiple terminal nodes to output the same value. For example, in the above picture, there are two possible ways to get “Stay in”.
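To make the traversal concrete, here is a minimal Python sketch. The nested-dict encoding, the `predict` helper and the exact branches of the tree are our own assumptions for illustration (chosen to match the walk-through above, including two “Stay in” leaves); the original figure may differ in detail.

```python
# Hypothetical encoding of the example tree: an internal node is a dict with the
# feature it tests and a mapping from feature value to child; a leaf is a string.
TREE = {
    "feature": "Work to do?",
    "children": {
        "Yes": "Stay in",
        "No": {
            "feature": "Outlook?",
            "children": {
                "Sunny": "Go to beach",
                "Overcast": "Go running",
                "Rainy": {
                    "feature": "Friends busy?",
                    "children": {"Yes": "Stay in", "No": "Go to movies"},
                },
            },
        },
    },
}


def predict(node, sample):
    """Start at the root and follow the branch matching the sample until a leaf."""
    while isinstance(node, dict):
        value = sample[node["feature"]]
        node = node["children"][value]
    return node


sample = {"Work to do?": "No", "Outlook?": "Sunny", "Friends busy?": "No"}
print(predict(TREE, sample))  # Go to beach
```

Note that “Friends busy?” is never consulted for this input: only the features on the path from the root to the reached leaf matter.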

Constructing the decision tree

Let's assume that we have a metric that measures how impure a dataset is. The following pseudocode constructs the decision tree:

  1. Start with all the data in one node.
  2. Split the data into two parts, A and B, using the feature (and split value) that yields the largest purity gain (i.e., impurity reduction).
  3. Repeat the splitting on each child node until the nodes are pure, i.e. each contains samples of a single class, or some other stopping criterion is met (for example, a maximum tree depth).
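The three steps above can be sketched in plain Python as follows. This is an illustrative version under our own assumptions, not how any particular library implements it: features are numeric, each split is a threshold test `feature <= t` (so every split produces exactly two parts A and B), impurity is measured with Gini impurity (defined in the next section), candidate thresholds are searched exhaustively, and `max_depth` serves as the extra stopping criterion. The names `gini`, `best_split` and `build_tree` are hypothetical.

```python
from collections import Counter


def gini(labels):
    """Gini impurity: sum over classes of p_i * (1 - p_i)."""
    n = len(labels)
    return sum((c / n) * (1 - c / n) for c in Counter(labels).values())


def best_split(X, y):
    """Return (gain, feature, threshold) of the split with the largest purity gain, or None."""
    best = None
    for f in range(len(X[0])):
        for t in sorted({row[f] for row in X}):
            left = [lab for row, lab in zip(X, y) if row[f] <= t]
            right = [lab for row, lab in zip(X, y) if row[f] > t]
            if not left or not right:
                continue  # both parts A and B must be non-empty
            weighted = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
            gain = gini(y) - weighted
            if best is None or gain > best[0]:
                best = (gain, f, t)
    return best


def build_tree(X, y, depth=0, max_depth=3):
    """Split recursively until a node is pure or the depth limit is reached."""
    majority = Counter(y).most_common(1)[0][0]
    if gini(y) == 0 or depth == max_depth:
        return majority  # leaf node: predict the most common class
    split = best_split(X, y)
    if split is None:
        return majority  # no valid split, e.g. all rows are identical
    _, f, t = split
    go_left = [row[f] <= t for row in X]
    left_X = [row for row, g in zip(X, go_left) if g]
    left_y = [lab for lab, g in zip(y, go_left) if g]
    right_X = [row for row, g in zip(X, go_left) if not g]
    right_y = [lab for lab, g in zip(y, go_left) if not g]
    return {
        "feature": f,
        "threshold": t,
        "left": build_tree(left_X, left_y, depth + 1, max_depth),
        "right": build_tree(right_X, right_y, depth + 1, max_depth),
    }


tree = build_tree([[1.0], [2.0], [3.0], [4.0]], ["a", "a", "b", "b"])
print(tree)  # {'feature': 0, 'threshold': 2.0, 'left': 'a', 'right': 'b'}
```

The search is brute force and greedy: each node picks the locally best split without looking ahead, which is the idea described in the steps above.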

Impurity metrics

To make this procedure concrete, we need to define the impurity metric. Commonly used metrics are Gini impurity (I_G), entropy (I_E) and classification error (I_C). [note 1]

Graph of impurity measures for a dataset with two classes, as a function of the fraction of data points which belong to the first class. Notice that in all cases, impurity is 0 when only one class is present, and is maximum when the classes are split 50-50. Source [1].

The impurity metrics are defined below, where p_1, p_2, …, p_k are the probabilities (proportions) of each of the k target classes in the dataset D.

$$ \begin{aligned} I_G(D) &= \sum_{i=1}^{k} p_i(1-p_i) = 1 - \sum_{i=1}^{k} p_i^2 \\ I_E(D) &= -\sum_{i=1}^{k} p_i\log_2(p_i) \\ I_C(D) &= 1 - \max(p_1, p_2, \ldots, p_k) \end{aligned} $$

Gini impurity is the probability of mislabeling a sample if we predicted its label by randomly sampling from the class distribution of D. Entropy is a standard measure of uncertainty from information theory (classes with p_i = 0 contribute 0 to the sum). Classification error is the error rate we would get if we always predicted the most common class.
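The three formulas translate directly into Python. In this sketch each function takes a list of class probabilities p_1, …, p_k that sum to 1; the function names are our own.

```python
from math import log2


def gini_impurity(probs):
    # I_G = sum_i p_i * (1 - p_i)
    return sum(p * (1 - p) for p in probs)


def entropy(probs):
    # I_E = -sum_i p_i * log2(p_i), written as p * log2(1 / p); classes with p = 0 are skipped
    return sum(p * log2(1 / p) for p in probs if p > 0)


def classification_error(probs):
    # I_C = 1 - max_i p_i
    return 1 - max(probs)


for probs in ([0.5, 0.5], [0.9, 0.1], [1.0, 0.0]):
    # the 50-50 case prints 0.5 1.0 0.5, the maximum of each metric for two classes
    print(probs, round(gini_impurity(probs), 3), round(entropy(probs), 3),
          round(classification_error(probs), 3))
```

All three metrics are 0 for a pure dataset, matching the graph described above.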

Finally, the purity gain from splitting a dataset D_parent on feature f is

$$ \text{PurityGain}(D_{parent}, f) = I(D_{parent}) - \sum_{child=1}^{m} \frac{N_{child}}{N_{parent}} I(D_{child}) $$

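where f is the feature used to perform the split, m is the number of child nodes, D_parent and D_child are the datasets at the parent and child nodes, N_parent and N_child are the numbers of samples in those datasets, and I is the impurity measure (e.g., Gini impurity). That is, purity gain is the difference between the impurity of the parent node and the weighted sum of the child node impurities, where each child is weighted by its share of the parent's samples.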
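As a worked example of the purity gain formula, take a parent node with 5 samples of class A and 5 of class B (Gini impurity 0.5), split into children with 4 A + 1 B and 1 A + 4 B (Gini impurity 0.32 each). The weighted child impurity is 0.5 × 0.32 + 0.5 × 0.32 = 0.32, so the purity gain is 0.5 − 0.32 = 0.18. The sketch below checks this; the labels are made up and the helper names are hypothetical.

```python
from collections import Counter


def gini(labels):
    """Gini impurity I_G of a list of class labels."""
    n = len(labels)
    return sum((c / n) * (1 - c / n) for c in Counter(labels).values())


def purity_gain(parent, children, impurity=gini):
    """I(D_parent) minus the sum of child impurities weighted by N_child / N_parent."""
    n_parent = len(parent)
    weighted = sum(len(child) / n_parent * impurity(child) for child in children)
    return impurity(parent) - weighted


parent = ["A"] * 5 + ["B"] * 5
left = ["A"] * 4 + ["B"]      # 4 A, 1 B
right = ["A"] + ["B"] * 4     # 1 A, 4 B
print(round(purity_gain(parent, [left, right]), 3))  # 0.18
print(purity_gain(parent, [["A"] * 5, ["B"] * 5]))   # 0.5: a perfect split removes all impurity
```

Passing a different `impurity` function (for example an entropy computed from labels) gives the information gain variant of the same formula.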

For computational efficiency, most machine learning libraries implement binary decision trees, in which each parent node has exactly two child nodes.
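In scikit-learn, for example, a fitted tree exposes its structure through the `tree_` attribute, whose `children_left` and `children_right` arrays hold the two child indices of every node (-1 for leaves). The sketch below fits a small tree on the Iris dataset and prints that structure; the dataset and `max_depth=3` are arbitrary choices for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
model = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)

tree = model.tree_  # low-level array representation of the fitted tree
for node in range(tree.node_count):
    left, right = tree.children_left[node], tree.children_right[node]
    if left == -1:  # -1 means "no child", so this node is a leaf
        print(f"node {node}: leaf")
    else:
        print(f"node {node}: feature {tree.feature[node]} <= {tree.threshold[node]:.2f} "
              f"-> children {left} and {right}")
```

Because every internal node has exactly two children, a binary tree with L leaves has 2L − 1 nodes.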

Advantages and disadvantages


Advantages:

  • Easy to understand and interpret
  • Require little data preparation; for example, features do not need to be scaled
  • Can handle both numerical and categorical data

Disadvantages:

  • Prone to overfitting, which can be partially addressed by pruning
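To see overfitting and pruning in practice, the sketch below compares a fully grown tree with one pruned using scikit-learn's minimal cost-complexity pruning (the `ccp_alpha` parameter, available since scikit-learn 0.22). The value `ccp_alpha=0.02` is an arbitrary choice for illustration; in practice it would be tuned, for example with cross-validation. No particular accuracies are claimed here, since they depend on the data split.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Fully grown tree: keeps splitting until every leaf is pure.
full = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
# Same tree, post-pruned with minimal cost-complexity pruning (alpha chosen arbitrarily).
pruned = DecisionTreeClassifier(ccp_alpha=0.02, random_state=42).fit(X_train, y_train)

for name, m in [("full", full), ("pruned", pruned)]:
    print(name, "leaves:", m.get_n_leaves(),
          "train acc:", round(m.score(X_train, y_train), 3),
          "test acc:", round(m.score(X_test, y_test), 3))
```

A large gap between training and test accuracy for the full tree is the usual symptom of overfitting; pruning trades a little training accuracy for a simpler tree.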

Illustration in Scikit-learn

Now let us train a decision tree classifier on the Iris dataset with scikit-learn (sklearn).

# import libs and dataset
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
dataset = load_iris()
X = dataset.data    # feature matrix: 150 samples x 4 features
y = dataset.target  # class labels (0, 1, 2)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
model = DecisionTreeClassifier(criterion="entropy", max_depth=3, random_state=42)
# decision trees are prone to overfitting, so we limit how far the tree can grow;
# restricting growth like this is a form of (pre-)pruning
# here, we control the depth of the tree using the max_depth parameter
# the other common option for criterion is "gini" (the default)
# random_state: makes sure we get the same results each time we run this code
model.fit(X_train, y_train)
# Output (the exact repr depends on your scikit-learn version):
# DecisionTreeClassifier(class_weight=None, criterion='entropy',
#            max_depth=3, max_features=None, max_leaf_nodes=None,
#            min_impurity_split=1e-07, min_samples_leaf=1,
#            min_samples_split=2, min_weight_fraction_leaf=0.0,
#            presort=False, random_state=42, splitter='best')
# test the model: mean accuracy on the held-out test set
model.score(X_test, y_test)
# Output: 0.97777777777777775

We can also visualize the decision tree. For this, you need Graphviz and pydotplus installed on your system.

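# io.StringIO replaces sklearn.externals.six.StringIO, which newer scikit-learn versions no longer ship
from io import StringIO
from sklearn.tree import export_graphviz
import pydotplus
from IPython.display import Image
dot_data = StringIO()
# write the tree in Graphviz DOT format, labelling nodes with feature and class names
export_graphviz(model, out_file=dot_data, feature_names=dataset.feature_names,
                class_names=dataset.target_names, filled=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())  # display the rendered tree in a Jupyter notebook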
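If installing Graphviz is inconvenient, scikit-learn 0.21 and later also provide `sklearn.tree.export_text`, which renders a fitted tree as indented text (and `sklearn.tree.plot_tree`, which draws it with matplotlib). A minimal sketch, refitting the same model as above so that the block is self-contained:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    dataset.data, dataset.target, test_size=0.3, random_state=42)
model = DecisionTreeClassifier(criterion="entropy", max_depth=3, random_state=42)
model.fit(X_train, y_train)

# One line per node: indentation shows depth, and leaves show the predicted class.
report = export_text(model, feature_names=list(dataset.feature_names))
print(report)
```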


References

  1. TPZ: Photometric redshift PDFs and ancillary information by using prediction trees and random forests

© 2016-2022. All rights reserved.