Tudor Lapusan's Blog


Visual interpretation of Decision Tree structure

In Machine Learning it’s important to understand why, given specific inputs (model hyperparameters, features or training set), your models produce specific outputs (model performance measured by loss functions).

In my opinion, if we only measure model performance we won’t have the full picture of what’s happening behind the scenes, so we may end up selecting, by luck, the set of hyperparameters that we think generates the best model. Maybe the worst part is that for the next ML project we start almost from scratch, because we didn’t develop a skill that helps us understand where to look when our model performs badly.

I think developing this skill is mandatory to become a better machine learning practitioner.
With this goal in mind, I started to develop a python library, woodpecker, to help me understand the model structure and the reasons behind model performance. The first version of the library contains logic for DecisionTreeClassifier from the sklearn library.
If my library proves to be helpful, I will continue to extend it to other algorithms, like DecisionTreeRegressor, RandomForest and other tree-based ensembles.

This article assumes that you have a basic understanding of how decision trees work and are familiar with terms like split/leaf node, entropy and information gain. If not, just search the internet; it is full of good tutorials :). I personally like this list of video tutorials (the sound quality is not so good).

I chose a well-known dataset, titanic, to present the capabilities of the library. Each row in this dataset describes one passenger. After feature engineering, I selected the following features:
Pclass : A proxy for socio-economic status

  • 1st = Upper
  • 2nd = Middle
  • 3rd = Lower

Age : age is fractional if less than 1. If the age is estimated, it is in the form xx.5
Fare : ticket price
Sex_label : encoded variable for the raw feature Sex. It represents the passenger's gender

  • 0 = female
  • 1 = male

Cabin_label : encoded variable for the raw feature Cabin. It represents the cabin number.
Embarked_label : encoded variable for the raw feature Embarked. It represents the port of embarkation

The goal for this dataset is to predict whether a passenger survived. The target variable is the column Survived:

  • 1 = Survived
  • 0 = Did not survive

Before training our Decision Tree Classifier model, we need to split the dataset into training and validation datasets.
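The split code itself isn't reproduced in this post. Below is a minimal sketch with sklearn's train_test_split.

The snippet does not load the real engineered Titanic dataframe. Instead it builds a hypothetical random stand-in with the same column names. The 80/20 ratio, random_state=42 and stratify=... are illustrative assumptions, not necessarily the settings behind the numbers in this post.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical random stand-in for the engineered Titanic features (NOT the real data).
rng = np.random.default_rng(0)
n = 891
df = pd.DataFrame({
    "Pclass": rng.integers(1, 4, n), "Age": rng.uniform(1, 80, n).round(1),
    "Fare": rng.uniform(5, 100, n).round(2), "Sex_label": rng.integers(0, 2, n),
    "Cabin_label": rng.integers(0, 10, n), "Embarked_label": rng.integers(0, 3, n),
})
# Toy rule so the target is not pure noise: women and 1st class survive more often.
p = 0.2 + 0.5 * (df["Sex_label"] == 0) + 0.1 * (df["Pclass"] == 1)
df["Survived"] = (rng.uniform(size=n) < p).astype(int)

features = ["Pclass", "Age", "Fare", "Sex_label", "Cabin_label", "Embarked_label"]
target = "Survived"

# Keep 20% of the rows for validation; stratify keeps the survival rate similar in both parts.
train, valid = train_test_split(df, test_size=0.2, random_state=42, stratify=df[target])
print(len(train), "training rows,", len(valid), "validation rows")
```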


Now let’s imagine that this is the first time we see the dataset and that we have little domain knowledge about it.
Our first try is a decision tree with depth = 10.


Model performance is as follows:

  • train accuracy 0.917
  • valid accuracy 0.821
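Here is a minimal sketch of how numbers like these can be computed with sklearn. It runs on the same hypothetical random stand-in data, so the printed accuracies will not match the ones above.

criterion="entropy" is an assumption, chosen because this post reasons in terms of entropy; sklearn's default criterion is "gini".

```python
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical random stand-in for the engineered Titanic features (NOT the real data).
rng = np.random.default_rng(0)
n = 891
df = pd.DataFrame({
    "Pclass": rng.integers(1, 4, n), "Age": rng.uniform(1, 80, n).round(1),
    "Fare": rng.uniform(5, 100, n).round(2), "Sex_label": rng.integers(0, 2, n),
    "Cabin_label": rng.integers(0, 10, n), "Embarked_label": rng.integers(0, 3, n),
})
p = 0.2 + 0.5 * (df["Sex_label"] == 0) + 0.1 * (df["Pclass"] == 1)
df["Survived"] = (rng.uniform(size=n) < p).astype(int)
features = ["Pclass", "Age", "Fare", "Sex_label", "Cabin_label", "Embarked_label"]
target = "Survived"
train, valid = train_test_split(df, test_size=0.2, random_state=42)

# First try: a tree limited only by max_depth=10.
decision_tree = DecisionTreeClassifier(max_depth=10, criterion="entropy", random_state=0)
decision_tree.fit(train[features], train[target])

train_acc = accuracy_score(train[target], decision_tree.predict(train[features]))
valid_acc = accuracy_score(valid[target], decision_tree.predict(valid[features]))
print(f"train accuracy {train_acc:.3f}")
print(f"valid accuracy {valid_acc:.3f}")
```

A large gap between the two accuracies is the overfitting signal discussed next.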

As we can see, our model is overfitting: train accuracy is almost 0.1 higher than validation accuracy. It could be caused by a depth that is too large, an inappropriate train/validation split or other factors. We don’t have a clear hint, so let’s visualise the tree and try to understand its structure.



We have a big boy tree!
It is very hard, if not impossible, to understand from this visualisation how the model performs overall or to get a better idea of the tree structure. This is the most common way the sklearn community visualises trees, and when we have a big/complex dataset we are forced to increase the depth of the tree to get better predictions, which makes the picture even harder to read.
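Even without a plot, sklearn can report the size of a fitted tree, which already hints at why the picture is unreadable. Below is a sketch on the same hypothetical random stand-in data (the criterion and random_state are assumptions).

export_text prints the tree's splits as indented text. Here it is cut after two levels.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical random stand-in for the engineered Titanic features (NOT the real data).
rng = np.random.default_rng(0)
n = 891
df = pd.DataFrame({
    "Pclass": rng.integers(1, 4, n), "Age": rng.uniform(1, 80, n).round(1),
    "Fare": rng.uniform(5, 100, n).round(2), "Sex_label": rng.integers(0, 2, n),
    "Cabin_label": rng.integers(0, 10, n), "Embarked_label": rng.integers(0, 3, n),
})
p = 0.2 + 0.5 * (df["Sex_label"] == 0) + 0.1 * (df["Pclass"] == 1)
df["Survived"] = (rng.uniform(size=n) < p).astype(int)
features = ["Pclass", "Age", "Fare", "Sex_label", "Cabin_label", "Embarked_label"]
target = "Survived"
train, valid = train_test_split(df, test_size=0.2, random_state=42)

decision_tree = DecisionTreeClassifier(max_depth=10, criterion="entropy", random_state=0)
decision_tree.fit(train[features], train[target])

# How big is the tree behind the picture?
print("depth:", decision_tree.get_depth(), "leaves:", decision_tree.get_n_leaves())

# The same splits as text, cut after two levels (deeper parts show as "truncated branch").
report = export_text(decision_tree, feature_names=features, max_depth=2)
print(report)
```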

Because of these difficulties in understanding what’s happening inside a decision tree, I decided to develop a library from scratch. This library was started with two main goals in mind:

  • help me better understand model structure and its performance
  • help me explain model predictions to technical and non-technical people.

You can install it by running “pip install git+https://github.com/tlapusan/woodpecker.git” in your terminal. Once it is installed, you can access the library’s capabilities for DecisionTreeClassifier through the DecisionTreeStructure class.

woodpecker init

As you can see, I initialised it with the trained model (decision_tree), the training set, the list of features and the target variable.

A tree model’s performance comes from its leaves, the nodes of the tree that make the final predictions.
Any leaf information shown in an easy, interpretable way can help us understand the model’s performance and choose new hyperparameter values for the next fit.
The most important leaf information is entropy and the number of samples:

  • entropy measures how impure a node is. For a binary target like Survived, entropy ranges over [0, 1], where 1 means the node is totally impure (bad) and 0 means the node is totally pure (good).
  • samples : the number of training-set samples reaching the node.
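For a node, entropy is H = -Σ p_k log2(p_k), summed over the classes k, where p_k is the fraction of the node's samples that belong to class k. With two classes, H is 0 when all samples share one class and reaches its maximum of 1 at a 50/50 split.

As far as I know, this base-2 form is what sklearn reports as impurity when criterion="entropy". Below is a small sketch; the entropy function is a hypothetical helper, not part of woodpecker or sklearn.

```python
import numpy as np


def entropy(class_counts):
    """Base-2 entropy of a node, given the number of training samples per class."""
    counts = np.asarray(class_counts, dtype=float)
    total = counts.sum()
    if total == 0:
        return 0.0
    p = counts[counts > 0] / total  # skip empty classes: 0 * log2(0) is treated as 0
    return float((p * np.log2(1.0 / p)).sum())  # same as -sum(p * log2(p))


print(entropy([30, 30]))           # → 1.0  (totally impure binary node)
print(entropy([62, 0]))            # → 0.0  (totally pure node)
print(round(entropy([30, 10]), 3)) # → 0.811
```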

Using woodpecker, we can easily get a general overview of leaf impurities.
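woodpecker's own plotting calls are not reproduced here. As a plain-sklearn alternative, the same leaf impurities can be read from the fitted tree's low-level tree_ arrays (children_left, impurity) and binned with NumPy.

The sketch below uses the hypothetical random stand-in data, so the bin counts will differ from the figure. The choice of 5 bins is arbitrary.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical random stand-in for the engineered Titanic features (NOT the real data).
rng = np.random.default_rng(0)
n = 891
df = pd.DataFrame({
    "Pclass": rng.integers(1, 4, n), "Age": rng.uniform(1, 80, n).round(1),
    "Fare": rng.uniform(5, 100, n).round(2), "Sex_label": rng.integers(0, 2, n),
    "Cabin_label": rng.integers(0, 10, n), "Embarked_label": rng.integers(0, 3, n),
})
p = 0.2 + 0.5 * (df["Sex_label"] == 0) + 0.1 * (df["Pclass"] == 1)
df["Survived"] = (rng.uniform(size=n) < p).astype(int)
features = ["Pclass", "Age", "Fare", "Sex_label", "Cabin_label", "Embarked_label"]
target = "Survived"
train, valid = train_test_split(df, test_size=0.2, random_state=42)

decision_tree = DecisionTreeClassifier(max_depth=10, criterion="entropy", random_state=0)
decision_tree.fit(train[features], train[target])

tree = decision_tree.tree_
is_leaf = tree.children_left == -1        # sklearn marks a leaf by having no left child (-1)
leaf_impurity = tree.impurity[is_leaf]    # entropy of every leaf

# Text histogram: how many leaves fall into each impurity bin.
counts, edges = np.histogram(leaf_impurity, bins=5, range=(0.0, 1.0))
for lo, hi, c in zip(edges[:-1], edges[1:], counts):
    print(f"impurity {lo:.1f}-{hi:.1f}: {c} leaves")
print("total leaves:", is_leaf.sum())
```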






The tree contains 73 leaves in total. The visualisation above shows that more than 60 leaves have impurity very close to 0. In theory this is a good sign, but it can also be a sign that our model is overfitting. This happens when the near-zero impurities come from nodes with very few samples. Let’s see if this is the case.
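The check described above can be sketched directly with sklearn's tree_.n_node_samples and tree_.impurity, again on the hypothetical random stand-in data.

The thresholds for "very close to 0" (< 0.1) and "very few samples" (< 5) are arbitrary assumptions made for illustration.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical random stand-in for the engineered Titanic features (NOT the real data).
rng = np.random.default_rng(0)
n = 891
df = pd.DataFrame({
    "Pclass": rng.integers(1, 4, n), "Age": rng.uniform(1, 80, n).round(1),
    "Fare": rng.uniform(5, 100, n).round(2), "Sex_label": rng.integers(0, 2, n),
    "Cabin_label": rng.integers(0, 10, n), "Embarked_label": rng.integers(0, 3, n),
})
p = 0.2 + 0.5 * (df["Sex_label"] == 0) + 0.1 * (df["Pclass"] == 1)
df["Survived"] = (rng.uniform(size=n) < p).astype(int)
features = ["Pclass", "Age", "Fare", "Sex_label", "Cabin_label", "Embarked_label"]
target = "Survived"
train, valid = train_test_split(df, test_size=0.2, random_state=42)

decision_tree = DecisionTreeClassifier(max_depth=10, criterion="entropy", random_state=0)
decision_tree.fit(train[features], train[target])

tree = decision_tree.tree_
is_leaf = tree.children_left == -1
leaf_samples = tree.n_node_samples[is_leaf]   # training samples that reached each leaf
leaf_impurity = tree.impurity[is_leaf]

pure = leaf_impurity < 0.1   # "very close to 0"; the 0.1 cut-off is arbitrary
small = leaf_samples < 5     # "very few samples"; also an arbitrary cut-off
print("leaves:", is_leaf.sum())
print("almost pure leaves:", pure.sum())
print("almost pure leaves with < 5 samples:", (pure & small).sum())
```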


As we can see, the majority of leaves have very few samples. Now we know for sure why our model is overfitting, and we have a clue about which hyperparameters to change to reduce it. In this scenario, it is recommended to decrease the value of max_depth and/or increase the values of min_samples_split and min_samples_leaf.
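The recommendation above can be tried out by refitting with stronger constraints and comparing the two trees. The particular values (max_depth=5, min_samples_split=20, min_samples_leaf=10) are illustrative assumptions, not tuned settings, and the data is the hypothetical random stand-in again.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical random stand-in for the engineered Titanic features (NOT the real data).
rng = np.random.default_rng(0)
n = 891
df = pd.DataFrame({
    "Pclass": rng.integers(1, 4, n), "Age": rng.uniform(1, 80, n).round(1),
    "Fare": rng.uniform(5, 100, n).round(2), "Sex_label": rng.integers(0, 2, n),
    "Cabin_label": rng.integers(0, 10, n), "Embarked_label": rng.integers(0, 3, n),
})
p = 0.2 + 0.5 * (df["Sex_label"] == 0) + 0.1 * (df["Pclass"] == 1)
df["Survived"] = (rng.uniform(size=n) < p).astype(int)
features = ["Pclass", "Age", "Fare", "Sex_label", "Cabin_label", "Embarked_label"]
target = "Survived"
train, valid = train_test_split(df, test_size=0.2, random_state=42)

results = {}
for name, params in [
    ("max_depth=10", dict(max_depth=10)),
    ("constrained", dict(max_depth=5, min_samples_split=20, min_samples_leaf=10)),
]:
    model = DecisionTreeClassifier(criterion="entropy", random_state=0, **params)
    model.fit(train[features], train[target])
    train_acc = model.score(train[features], train[target])   # score() is accuracy here
    valid_acc = model.score(valid[features], valid[target])
    results[name] = model
    print(f"{name}: train {train_acc:.3f}, valid {valid_acc:.3f}, leaves {model.get_n_leaves()}")
```

min_samples_leaf guarantees that every leaf holds at least that many training samples. This directly removes the tiny, "pure" leaves found above.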



Above we looked at leaf information in a general way, but what if we want insights from individual leaves?


These three visualisations show information only for leaf nodes (leaf ids on the x-axis):

  • the first visualisation shows leaf impurities
  • the second shows leaf samples
  • the third shows leaf samples by class
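The same per-leaf numbers can be collected into one table with plain sklearn and pandas. decision_tree.apply returns the leaf id reached by each training row, and a crosstab counts the samples of each class per leaf. This is not woodpecker's implementation, and it uses the hypothetical random stand-in data.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical random stand-in for the engineered Titanic features (NOT the real data).
rng = np.random.default_rng(0)
n = 891
df = pd.DataFrame({
    "Pclass": rng.integers(1, 4, n), "Age": rng.uniform(1, 80, n).round(1),
    "Fare": rng.uniform(5, 100, n).round(2), "Sex_label": rng.integers(0, 2, n),
    "Cabin_label": rng.integers(0, 10, n), "Embarked_label": rng.integers(0, 3, n),
})
p = 0.2 + 0.5 * (df["Sex_label"] == 0) + 0.1 * (df["Pclass"] == 1)
df["Survived"] = (rng.uniform(size=n) < p).astype(int)
features = ["Pclass", "Age", "Fare", "Sex_label", "Cabin_label", "Embarked_label"]
target = "Survived"
train, valid = train_test_split(df, test_size=0.2, random_state=42)

decision_tree = DecisionTreeClassifier(max_depth=10, criterion="entropy", random_state=0)
decision_tree.fit(train[features], train[target])

# Leaf id reached by every training row, kept next to the row itself.
train_leaf = train.assign(leaf=decision_tree.apply(train[features]))

leaf_stats = pd.crosstab(train_leaf["leaf"], train_leaf[target])   # samples by class per leaf
leaf_stats.columns = ["not_survived", "survived"]                  # class 0, class 1
leaf_stats["samples"] = leaf_stats["not_survived"] + leaf_stats["survived"]
leaf_stats["impurity"] = decision_tree.tree_.impurity[leaf_stats.index.to_numpy()]
print(leaf_stats.round(3).head(10))
```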


I do feel that these visualisations can help a lot to understand model performance. For now, I can say that the model either predicts very well (impurity 0) or very badly (impurity > 0.5).
The third visualisation also reveals something very interesting.
Look, for example, at leaf 19: more than 60 samples, all with class label 1 (survived).
Interestingly, they are all females (Sex_label=0), from upper socio-economic status (mean Pclass ≈ 1), bought an expensive ticket (mean(Fare)=78.5), and the majority are aged between 22 and 38.
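Summary statistics like the ones above can be sketched with apply plus pandas' describe. Leaf 19 only exists in the tree from this post, so on the hypothetical random stand-in data the snippet picks an analogous leaf: the one with the highest survival rate, and among those the one with the most samples. That selection rule is an assumption made for illustration.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical random stand-in for the engineered Titanic features (NOT the real data).
rng = np.random.default_rng(0)
n = 891
df = pd.DataFrame({
    "Pclass": rng.integers(1, 4, n), "Age": rng.uniform(1, 80, n).round(1),
    "Fare": rng.uniform(5, 100, n).round(2), "Sex_label": rng.integers(0, 2, n),
    "Cabin_label": rng.integers(0, 10, n), "Embarked_label": rng.integers(0, 3, n),
})
p = 0.2 + 0.5 * (df["Sex_label"] == 0) + 0.1 * (df["Pclass"] == 1)
df["Survived"] = (rng.uniform(size=n) < p).astype(int)
features = ["Pclass", "Age", "Fare", "Sex_label", "Cabin_label", "Embarked_label"]
target = "Survived"
train, valid = train_test_split(df, test_size=0.2, random_state=42)

decision_tree = DecisionTreeClassifier(max_depth=10, criterion="entropy", random_state=0)
decision_tree.fit(train[features], train[target])

train_leaf = train.assign(leaf=decision_tree.apply(train[features]))
per_leaf = train_leaf.groupby("leaf")[target].agg(["size", "mean"])   # size, survival rate
# Highest survival rate first, then the most samples.
leaf_id = per_leaf.sort_values(["mean", "size"], ascending=False).index[0]

samples = train_leaf[train_leaf["leaf"] == leaf_id]
print(f"leaf {leaf_id}: {len(samples)} samples, survival rate {samples[target].mean():.2f}")
print(samples[features].describe().loc[["mean", "min", "25%", "75%", "max"]].round(1))
```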

Now let’s look in the opposite direction: leaf 98, with approximately 40 samples, all with class label 0 (not survived).


All are men (Sex_label=1), from lower socio-economic status (mean Pclass ≈ 2.7), bought a cheaper ticket (mean(Fare)=12.9), and the majority are aged between 34 and 38.
By investigating only these two leaves, we can already see patterns in our dataset: young to middle-aged rich women had a very high chance of surviving, while middle-aged poor men did not.

Let’s look at leaf 118, where the chance of survival is roughly 50/50.


All samples seem to be men (Sex_label=1), from upper socio-economic status (Pclass=1), who bought an expensive ticket (mean(Fare)=49.1), and the majority are aged between 28 and 41. It is hard to draw conclusions here, so let’s look at the individual samples:
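Listing the individual samples of a mixed leaf, grouped by class, can be sketched as follows. The data is the hypothetical random stand-in again.

The leaf choice is also hypothetical: the leaf whose survival rate is closest to 50%, preferring larger leaves. The last line compares mean Age and Fare between the two classes, the same comparison made by eye in the next paragraph.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical random stand-in for the engineered Titanic features (NOT the real data).
rng = np.random.default_rng(0)
n = 891
df = pd.DataFrame({
    "Pclass": rng.integers(1, 4, n), "Age": rng.uniform(1, 80, n).round(1),
    "Fare": rng.uniform(5, 100, n).round(2), "Sex_label": rng.integers(0, 2, n),
    "Cabin_label": rng.integers(0, 10, n), "Embarked_label": rng.integers(0, 3, n),
})
p = 0.2 + 0.5 * (df["Sex_label"] == 0) + 0.1 * (df["Pclass"] == 1)
df["Survived"] = (rng.uniform(size=n) < p).astype(int)
features = ["Pclass", "Age", "Fare", "Sex_label", "Cabin_label", "Embarked_label"]
target = "Survived"
train, valid = train_test_split(df, test_size=0.2, random_state=42)

decision_tree = DecisionTreeClassifier(max_depth=10, criterion="entropy", random_state=0)
decision_tree.fit(train[features], train[target])

train_leaf = train.assign(leaf=decision_tree.apply(train[features]))
per_leaf = train_leaf.groupby("leaf")[target].agg(["size", "mean"])
per_leaf["distance_to_half"] = (per_leaf["mean"] - 0.5).abs()
leaf_id = per_leaf.sort_values(["distance_to_half", "size"], ascending=[True, False]).index[0]

samples = train_leaf[train_leaf["leaf"] == leaf_id].sort_values([target, "Age"])
print(samples[features + [target]])                                # every sample in the leaf
print(samples.groupby(target)[["Age", "Fare"]].mean().round(1))    # compare the two classes
```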


Passengers from both classes have similar ages and paid similar prices for their tickets. I cannot see a clear pattern for splitting these samples, and perhaps neither can the model.
This may be a case where we need to add more features to our training set!



– notes

  • most people are used to tuning hyperparameters manually, and in fact this seems to give better results
    • first we should understand whether our model is overfitting or underfitting, and then start changing our hyperparameters in the right direction
  • constrain the model
    • if we increase the values of some hyperparameters, the model will move from overfitting toward underfitting
  • a decision tree is a white-box model
    • it is easily interpretable, compared with black-box models (like CNNs)


  • Node #0 : indicates the node id in the tree
  • Sex_label <= 0.5 : the split condition; this feature and threshold were chosen as the best split based on information gain.
  • value : an array of length two. The value at index 0 is the number of samples of class 0 and the value at index 1 is the number of samples of class 1.
  • class : the class predicted for this node
    • not survived = class 0
    • survived = class 1
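The fields of a node box can be rebuilt from sklearn's own arrays. This is a sketch on the hypothetical random stand-in data, and describe_node is a hypothetical helper.

The pieces come from these sources:
  • The split condition comes from tree_.feature and tree_.threshold.
  • The per-class counts ("value") come from decision_path, which marks every node each training sample passes through.
  • The predicted class is the class with the most samples.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Hypothetical random stand-in for the engineered Titanic features (NOT the real data).
rng = np.random.default_rng(0)
n = 891
df = pd.DataFrame({
    "Pclass": rng.integers(1, 4, n), "Age": rng.uniform(1, 80, n).round(1),
    "Fare": rng.uniform(5, 100, n).round(2), "Sex_label": rng.integers(0, 2, n),
    "Cabin_label": rng.integers(0, 10, n), "Embarked_label": rng.integers(0, 3, n),
})
p = 0.2 + 0.5 * (df["Sex_label"] == 0) + 0.1 * (df["Pclass"] == 1)
df["Survived"] = (rng.uniform(size=n) < p).astype(int)
features = ["Pclass", "Age", "Fare", "Sex_label", "Cabin_label", "Embarked_label"]
target = "Survived"
train, valid = train_test_split(df, test_size=0.2, random_state=42)

decision_tree = DecisionTreeClassifier(max_depth=10, criterion="entropy", random_state=0)
decision_tree.fit(train[features], train[target])

node_indicator = decision_tree.decision_path(train[features])   # sparse (n_samples, n_nodes)
y = train[target].to_numpy()
class_names = ["not survived", "survived"]                      # class 0, class 1


def describe_node(node_id):
    """Rebuild the text of one node box from sklearn's arrays (hypothetical helper)."""
    tree = decision_tree.tree_
    reaches = node_indicator[:, node_id].toarray().ravel().astype(bool)
    value = np.bincount(y[reaches], minlength=2)   # samples of class 0 and class 1
    text = f"Node #{node_id}"
    if tree.children_left[node_id] != -1:           # only split nodes have a condition
        text += f" : {features[tree.feature[node_id]]} <= {tree.threshold[node_id]:.1f}"
    return text + f", samples = {value.sum()}, value = {value.tolist()}, class = {class_names[value.argmax()]}"


print(describe_node(0))
```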

