A decision tree is a popular supervised learning algorithm that can be used for both classification and regression problems. A decision tree resembles a flowchart and is easy to interpret. The algorithm works by recursively splitting the data based on the value of a feature. After each split, the resulting subsets become more homogeneous, until eventually the samples in each node mostly share the same target value.
Consider a problem where we want to decide what to do on a particular day. Our decision is based on three inputs (features): Work to do?, Outlook?, and Friends busy?. Our decision tree might look something like the figure below.
Suppose our input is Work to do? = No, Outlook? = Sunny, and Friends busy? = No. Then our decision would be Go to beach. To reach it, we start at the top (root) node and, at each node, move to the child node that matches the value of the feature tested there, until we reach a terminal node. Each terminal node (the gray nodes) holds one possible output value.
Note that it is possible for multiple terminal nodes to output the same value. For example, in the above picture, there are two possible ways to get "Stay in".
Let's assume that we have a metric that measures how impure a dataset is, i.e. how mixed its target classes are. The following pseudocode constructs a decision tree:
- Start with all the data in one (root) node.
- Split the data into two parts, A and B, using the feature (and, for a numerical feature, the threshold) that results in the largest purity gain (or impurity reduction).
- Repeat this splitting on each child node until the nodes are pure, i.e. they contain samples of a single class, or some other stopping criterion is met (for example, a maximum depth).
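The three steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not how any library implements it: all features are numerical, every split has the form `feature <= threshold` (a binary split), the impurity metric is Gini impurity (1 - Σ p_k²), and growth stops when a node is pure, no split reduces impurity, or a hypothetical `max_depth` is reached. The names `gini`, `best_split`, `build_tree`, and `predict` are illustrative only.

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a non-empty list of class labels: 1 - sum_k p_k^2."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_split(X, y):
    """Try every (feature, threshold) pair; return (gain, feature, threshold) with the largest gain."""
    parent = gini(y)
    best = None
    for f in range(len(X[0])):
        for t in sorted({row[f] for row in X}):
            left = [label for row, label in zip(X, y) if row[f] <= t]
            right = [label for row, label in zip(X, y) if row[f] > t]
            if not left or not right:
                continue  # skip splits that leave one side empty
            weighted = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
            gain = parent - weighted
            if best is None or gain > best[0]:
                best = (gain, f, t)
    return best


def build_tree(X, y, depth=0, max_depth=3):
    """Recursively split until a node is pure, cannot be improved, or max_depth is reached."""
    majority = Counter(y).most_common(1)[0][0]
    if gini(y) == 0.0 or depth == max_depth:
        return majority  # leaf: predict the majority class
    split = best_split(X, y)
    if split is None or split[0] <= 0.0:
        return majority  # no split reduces impurity
    _, f, t = split
    left_idx = [i for i, row in enumerate(X) if row[f] <= t]
    right_idx = [i for i, row in enumerate(X) if row[f] > t]
    return {
        "feature": f,
        "threshold": t,
        "left": build_tree([X[i] for i in left_idx], [y[i] for i in left_idx], depth + 1, max_depth),
        "right": build_tree([X[i] for i in right_idx], [y[i] for i in right_idx], depth + 1, max_depth),
    }


def predict(tree, row):
    """Walk from the root to a leaf, going left when row[feature] <= threshold."""
    while isinstance(tree, dict):
        tree = tree["left"] if row[tree["feature"]] <= tree["threshold"] else tree["right"]
    return tree
```

For example, `build_tree([[1.0], [2.0], [3.0], [4.0]], ["a", "a", "b", "b"])` splits once at `feature 0 <= 2.0`, producing two pure leaves. Trying every observed value as a threshold is the simplest exhaustive search; real implementations sort each feature once and sweep the thresholds to avoid recomputing impurities from scratch.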
To make this procedure concrete, we need to define the impurity metric. Commonly used metrics are Gini impurity (IG), entropy (IE), and classification error (IC). [note 1]
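Mathematically, the impurity metrics are defined as follows, where $p_1, p_2, \ldots, p_k$ are the probabilities (relative frequencies) of each of the $k$ target classes in the dataset $D$:

$$I_G(D) = 1 - \sum_{i=1}^{k} p_i^2$$

$$I_E(D) = -\sum_{i=1}^{k} p_i \log_2 p_i$$

$$I_C(D) = 1 - \max_{i} p_i$$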
Gini impurity is the probability of misclassifying a sample if we predicted its label by randomly sampling a class according to the class distribution of the dataset. Entropy is a standard metric from information theory. Classification error is the error rate we would get by always predicting the most common class.
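The three impurity metrics translate directly into code. The sketch below computes the class probabilities $p_i$ from a list of labels and applies each formula; the function names and the example node (6 "yes" and 4 "no" samples) are illustrative assumptions.

```python
import math
from collections import Counter


def class_probabilities(labels):
    """Return p_1..p_k: the fraction of samples belonging to each class present."""
    n = len(labels)
    return [count / n for count in Counter(labels).values()]


def gini_impurity(labels):
    # I_G = 1 - sum_i p_i^2
    return 1.0 - sum(p ** 2 for p in class_probabilities(labels))


def entropy(labels):
    # I_E = -sum_i p_i * log2(p_i); absent classes are not in the Counter, so log2(0) never occurs
    return -sum(p * math.log2(p) for p in class_probabilities(labels))


def classification_error(labels):
    # I_C = 1 - max_i p_i
    return 1.0 - max(class_probabilities(labels))


node = ["yes"] * 6 + ["no"] * 4  # p = (0.6, 0.4)
print(round(gini_impurity(node), 4))         # 0.48
print(round(entropy(node), 4))               # 0.971
print(round(classification_error(node), 4))  # 0.4
```

All three metrics are 0 for a pure node and largest when the classes are evenly mixed; for two classes the maxima are 0.5 (Gini), 1.0 (entropy), and 0.5 (classification error).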
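Finally, the purity gain of splitting a parent dataset $D_p$ on feature $f$ into two child datasets $D_{\text{left}}$ and $D_{\text{right}}$ is

$$\mathrm{Gain}(D_p, f) = I(D_p) - \frac{N_{\text{left}}}{N_p} I(D_{\text{left}}) - \frac{N_{\text{right}}}{N_p} I(D_{\text{right}})$$

where $f$ is the feature used to perform the split, $D_p$, $D_{\text{left}}$ and $D_{\text{right}}$ are the parent and child datasets, $I$ is the impurity measure (e.g. Gini impurity), and $N_x$ is the number of samples in $D_x$. That is, the purity gain is the difference between the impurity of the parent node and the weighted sum of the child-node impurities, where each child is weighted by its share of the parent's samples.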
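As a worked example of the purity-gain formula, the sketch below compares two candidate splits of a parent node with 40 samples of class A and 40 of class B. The split counts are made-up illustrative numbers, and `purity_gain` is a hypothetical helper that takes the impurity function as a parameter so the metrics can be swapped.

```python
from collections import Counter


def gini(labels):
    # Gini impurity I_G = 1 - sum_i p_i^2
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def classification_error(labels):
    # Classification error I_C = 1 - max_i p_i
    return 1.0 - max(Counter(labels).values()) / len(labels)


def purity_gain(parent, left, right, impurity=gini):
    """I(D_p) - N_left/N_p * I(D_left) - N_right/N_p * I(D_right)."""
    n_p = len(parent)
    return (impurity(parent)
            - len(left) / n_p * impurity(left)
            - len(right) / n_p * impurity(right))


parent = ["A"] * 40 + ["B"] * 40
split_1 = (["A"] * 30 + ["B"] * 10, ["A"] * 10 + ["B"] * 30)  # (30 A, 10 B) | (10 A, 30 B)
split_2 = (["A"] * 20 + ["B"] * 40, ["A"] * 20)               # (20 A, 40 B) | (20 A, 0 B)

print(round(purity_gain(parent, *split_1), 4))  # 0.125
print(round(purity_gain(parent, *split_2), 4))  # 0.1667
print(round(purity_gain(parent, *split_1, impurity=classification_error), 4))  # 0.25
print(round(purity_gain(parent, *split_2, impurity=classification_error), 4))  # 0.25
```

Gini impurity prefers the second split, which produces one completely pure child, while classification error scores both splits the same. This insensitivity is one reason Gini impurity or entropy is usually preferred over classification error when growing a tree.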
For computational efficiency, most machine learning libraries implement binary decision trees, that is, each parent node has exactly two child nodes.
Advantages:
- Easy to understand because they are interpretable
- Require little data preparation (for example, no feature scaling)
- Can handle both numerical and categorical data

Disadvantages:
- Prone to overfitting, which can be partially mitigated by pruning
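One way to prune in practice is minimal cost-complexity (post-)pruning, which scikit-learn (0.22 and later) exposes through the `ccp_alpha` parameter and the `cost_complexity_pruning_path` method of `DecisionTreeClassifier`. The sketch below is an illustration under assumptions: it uses the Iris data with a 70/30 split, selects `ccp_alpha` by 5-fold cross-validation on the training set, and only then evaluates once on the test set.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Grow an unpruned tree and compute the effective alphas at which its
# subtrees would be pruned away (a larger alpha gives a smaller tree).
full_tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)
path = full_tree.cost_complexity_pruning_path(X_train, y_train)

results = []  # (alpha, number of nodes, mean cross-validated accuracy)
for alpha in path.ccp_alphas:
    alpha = max(float(alpha), 0.0)  # guard against tiny negative values from rounding
    clf = DecisionTreeClassifier(random_state=42, ccp_alpha=alpha)
    cv_acc = cross_val_score(clf, X_train, y_train, cv=5).mean()
    n_nodes = clf.fit(X_train, y_train).tree_.node_count
    results.append((alpha, n_nodes, cv_acc))

# Keep the alpha with the best cross-validated accuracy, then evaluate on the test set.
best_alpha, best_nodes, best_cv = max(results, key=lambda r: r[2])
pruned = DecisionTreeClassifier(random_state=42, ccp_alpha=best_alpha).fit(X_train, y_train)
print(f"unpruned nodes: {full_tree.tree_.node_count}, pruned nodes: {best_nodes}")
print(f"best alpha: {best_alpha:.4f}, test accuracy: {pruned.score(X_test, y_test):.3f}")
```

Limiting `max_depth`, as in the example that follows, is a simpler pre-pruning alternative: it stops the tree from growing instead of cutting it back afterwards.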
Now let us train a decision tree classifier on the Iris dataset with scikit-learn (sklearn).
```python
# import libraries and the dataset
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split

dataset = load_iris()
X = dataset.data
y = dataset.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# criterion: impurity measure used for splits; "entropy" here, the default is "gini"
# max_depth: decision trees are prone to overfitting, so we limit how deep
#   the tree can grow (a simple form of pruning)
# random_state: makes sure we get the same results each time we run this code
model = DecisionTreeClassifier(criterion="entropy", max_depth=3, random_state=42)
model.fit(X_train, y_train)
# DecisionTreeClassifier(criterion='entropy', max_depth=3, random_state=42)

# test the model (mean accuracy on the test set)
model.score(X_test, y_test)
# Output: 0.97777777777777775
```
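```python
# visualize the fitted tree (requires the Graphviz binaries and the pydotplus package)
from sklearn.tree import export_graphviz
import pydotplus
from IPython.display import Image

# with out_file=None, export_graphviz returns the DOT source as a string
dot_data = export_graphviz(model, out_file=None,
                           feature_names=dataset.feature_names,
                           class_names=dataset.target_names,
                           filled=True)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf("iris.pdf")  # save the tree as a PDF
graph.write_png("iris.png")  # save a PNG for inline display
Image(filename="iris.png")
```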
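If Graphviz or pydotplus is not available, scikit-learn (0.21 and later) can print the same tree as indented text rules with `sklearn.tree.export_text`; `sklearn.tree.plot_tree` draws it with matplotlib instead. The sketch below refits the same model as above so that it runs on its own.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

dataset = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    dataset.data, dataset.target, test_size=0.3, random_state=42)
model = DecisionTreeClassifier(criterion="entropy", max_depth=3, random_state=42)
model.fit(X_train, y_train)

# export_text renders each split as "feature <= threshold" lines and each leaf
# as "class: <label>", indented by depth; no Graphviz installation is needed.
rules = export_text(model, feature_names=list(dataset.feature_names))
print(rules)
```

Each line of the output corresponds to one branch of the flowchart, so the rules can be read top to bottom exactly like the beach example at the start.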