Decision Tree – Only Six Rows


This entry is part 5 of 6 in the series Decision Trees

This is a very simple example of building a decision tree model on a very small dataset that has only six rows of data. With so few rows, we can easily see the tree and understand how it was built. Let’s dive right into the Python code. To make things easier, I have provided the data right in the code itself, as opposed to importing it from an external file.

import pandas as pd
import matplotlib.pyplot as plt

# from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# This function displays the splits of the tree
from sklearn.tree import plot_tree

# to manually create a DataFrame, start with a dictionary of equal-length lists.
# The data has been slightly modified from the Udemy course
data = {"J": [1, 1, 0, 1, 0, 1],
        "K": [1, 1, 0, 0, 1, 1],
        "L": [1, 0, 1, 0, 1, 0],
        "Class": ['A', 'A', 'B', 'B', 'A', 'B']}
df = pd.DataFrame(data)
df

  • We are trying to predict the class. Notice that K is the best feature for predicting it.
  • When K = 1 the class is A three out of four times, and when K = 0 it is B both times, which makes K the best predictor.
  • When L = 1 the class is A two out of three times, which makes L a weaker predictor.
  • When J = 1 the class is A half of the time, making J a poor predictor.
  • So, from best to worst, the order is K, L and J.
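The ranking in the bullets can be checked numerically. Below is a minimal sketch that assumes Gini impurity (scikit-learn's default splitting criterion) is the yardstick; the helper names `gini` and `weighted_gini` are my own and are not part of the original code or of any library. A lower weighted Gini means the split leaves purer groups.

```python
import pandas as pd

# The same six training rows used in this post
data = {"J": [1, 1, 0, 1, 0, 1],
        "K": [1, 1, 0, 0, 1, 1],
        "L": [1, 0, 1, 0, 1, 0],
        "Class": ['A', 'A', 'B', 'B', 'A', 'B']}
df = pd.DataFrame(data)


def gini(labels):
    """Gini impurity of a Series of class labels: 1 - sum(p_k ** 2)."""
    proportions = labels.value_counts(normalize=True)
    return 1.0 - (proportions ** 2).sum()


def weighted_gini(frame, feature, target="Class"):
    """Gini impurity after splitting on `feature`, weighted by group size."""
    total = len(frame)
    return sum(len(group) / total * gini(group[target])
               for _, group in frame.groupby(feature))


for feature in ["J", "K", "L"]:
    # Fraction of rows with feature == 1 that belong to class A
    share_a = (df.loc[df[feature] == 1, "Class"] == "A").mean()
    print(f"{feature}: P(A | {feature}=1) = {share_a:.2f}, "
          f"weighted Gini after split = {weighted_gini(df, feature):.3f}")
```

Splitting on K lowers the impurity from 0.500 (three A and three B rows) to 0.250, L only reaches 0.444, and J stays at 0.500, so J does not help at all. That matches the K, L, J order above.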
# Define the y (target) variable
y = df['Class']

# Define the X (predictor) variables
X = df.copy()
X = X.drop('Class', axis=1)
X

y

Instantiate the Model

# Instantiate the model
decision_tree = DecisionTreeClassifier(random_state=42)

# We will use ALL of the data to train the model.
# Fit the model to training data
decision_tree.fit(X, y)
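Once the model is fitted, it can report which features it actually relied on. This is a minimal sketch using scikit-learn's `feature_importances_` attribute, which sums each feature's weighted impurity decrease over the splits that use it and normalizes the totals to 1. The data is rebuilt so the snippet runs on its own; `importances` is just an illustrative variable name.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Rebuild the six-row training set so this snippet is self-contained
data = {"J": [1, 1, 0, 1, 0, 1],
        "K": [1, 1, 0, 0, 1, 1],
        "L": [1, 0, 1, 0, 1, 0],
        "Class": ['A', 'A', 'B', 'B', 'A', 'B']}
df = pd.DataFrame(data)
X = df.drop(columns="Class")
y = df["Class"]

decision_tree = DecisionTreeClassifier(random_state=42)
decision_tree.fit(X, y)

# Normalized total impurity decrease contributed by each feature
importances = pd.Series(decision_tree.feature_importances_, index=X.columns)
print(importances.sort_values(ascending=False))
```

By my hand calculation this should give K about 0.75, L about 0.25 and J zero, because the tree never splits on J.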

print(X.columns)
print('\n')
X.info()

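# Class names in the order the model uses internally (sorted: 'A', 'B').
# A Python set has no guaranteed order, so list(set(...)) could mislabel the plot.
cn = list(decision_tree.classes_)
cn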

# Plot the tree
plt.figure(figsize=(6,4))
plot_tree(decision_tree, max_depth=4, fontsize=9, feature_names=['J', 'K', 'L'], class_names=cn, filled=True);
plt.show()
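The plotted tree can be cross-checked with a plain-text dump from scikit-learn's `export_text`. This sketch also looks at an interesting detail of the data: the training rows with index 1 and 5 share the same features (J=1, K=1, L=0) but have different classes, so no split can ever separate them. The name `duplicate_row` is a hypothetical label used only for this illustration.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

data = {"J": [1, 1, 0, 1, 0, 1],
        "K": [1, 1, 0, 0, 1, 1],
        "L": [1, 0, 1, 0, 1, 0],
        "Class": ['A', 'A', 'B', 'B', 'A', 'B']}
df = pd.DataFrame(data)
X = df.drop(columns="Class")
y = df["Class"]
decision_tree = DecisionTreeClassifier(random_state=42).fit(X, y)

# Plain-text version of the tree that plot_tree draws
rules = export_text(decision_tree, feature_names=list(X.columns))
print(rules)

# Rows 1 and 5 both have J=1, K=1, L=0 but classes A and B,
# so they land in the same leaf with a 50/50 class split.
duplicate_row = pd.DataFrame({"J": [1], "K": [1], "L": [0]})
proba = decision_tree.predict_proba(duplicate_row)[0]
print(dict(zip(decision_tree.classes_, proba)))

# On an exact tie, predict() returns the first class in classes_ ('A')
print(decision_tree.predict(duplicate_row))
```

The tie-break matters later: any test row with K=1 and L=0 falls into this mixed leaf and is predicted as A.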

Test Dataset One – Perfect K

# to manually create a DataFrame, start with a dictionary of equal-length lists.
# The data has been slightly modified from the Udemy course
data_test = {"J": [1, 0, 0],
             "K": [1, 1, 0],
             "L": [1, 0, 1],
             "Class": ['A', 'A', 'B']}
df_test = pd.DataFrame(data_test)
df_test
# Define the y (target) variable
y_test = df_test['Class']

# Define the X (predictor) variables
X_test = df_test.copy()
X_test = X_test.drop('Class', axis=1)
X_test

predictions = decision_tree.predict(X_test)
# Compare the predictions with the true classes
from sklearn.metrics import classification_report, confusion_matrix
print(confusion_matrix(y_test, predictions))
print('\n')
print(classification_report(y_test, predictions))
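To read the "Perfect K" result row by row, the test rows can be lined up with their predictions and with the leaf each one reaches. This is a sketch using the standard scikit-learn calls `predict`, `apply` (which returns the leaf node id for each row) and `accuracy_score`; `features`, `train`, `test` and `comparison` are illustrative names, not from the original code.

```python
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier

features = ["J", "K", "L"]
train = pd.DataFrame({"J": [1, 1, 0, 1, 0, 1],
                      "K": [1, 1, 0, 0, 1, 1],
                      "L": [1, 0, 1, 0, 1, 0],
                      "Class": ['A', 'A', 'B', 'B', 'A', 'B']})
test = pd.DataFrame({"J": [1, 0, 0],
                     "K": [1, 1, 0],
                     "L": [1, 0, 1],
                     "Class": ['A', 'A', 'B']})

decision_tree = DecisionTreeClassifier(random_state=42)
decision_tree.fit(train[features], train["Class"])

comparison = test.copy()
comparison["Predicted"] = decision_tree.predict(test[features])
# apply() gives the id of the leaf node each test row ends up in
comparison["Leaf"] = decision_tree.apply(test[features])
print(comparison)

accuracy = accuracy_score(comparison["Class"], comparison["Predicted"])
print(f"Accuracy: {accuracy:.2f}")
```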

Test Dataset Two – Imperfect K

# to manually create a DataFrame, start with a dictionary of equal-length lists.
# The data has been slightly modified from the Udemy course
data_test2 = {"J": [1, 0, 0],
              "K": [1, 0, 0],
              "L": [1, 0, 1],
              "Class": ['A', 'A', 'B']}
df_test2 = pd.DataFrame(data_test2)
df_test2

# Define the y (target) variable
y_test2 = df_test2['Class']

# Define the X (predictor) variables
X_test2 = df_test2.copy()
X_test2 = X_test2.drop('Class', axis=1)
X_test2

predictions2 = decision_tree.predict(X_test2)
print(confusion_matrix(y_test2, predictions2))
print('\n')
print(classification_report(y_test2, predictions2))

[[1 1]
 [0 1]]


              precision    recall  f1-score   support

           A       1.00      0.50      0.67         2
           B       0.50      1.00      0.67         1

    accuracy                           0.67         3
   macro avg       0.75      0.75      0.67         3
weighted avg       0.83      0.67      0.67         3
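To see where the numbers in this report come from, the short worked sketch below recomputes them in plain Python from the confusion matrix printed above. It assumes scikit-learn's conventions: rows are the true classes, columns are the predicted classes, and both are ordered A then B.

```python
# Confusion matrix from the output above (rows = true, columns = predicted)
labels = ["A", "B"]
cm = [[1, 1],
      [0, 1]]

support = [sum(row) for row in cm]          # number of true rows per class
predicted = [sum(col) for col in zip(*cm)]  # number of predictions per class

precision, recall, f1 = [], [], []
for i in range(len(labels)):
    tp = cm[i][i]  # correct predictions for class i
    p = tp / predicted[i] if predicted[i] else 0.0
    r = tp / support[i] if support[i] else 0.0
    precision.append(p)
    recall.append(r)
    f1.append(2 * p * r / (p + r) if p + r else 0.0)

total = sum(support)
accuracy = sum(cm[i][i] for i in range(len(labels))) / total
macro_f1 = sum(f1) / len(f1)
weighted_precision = sum(p * s for p, s in zip(precision, support)) / total

for name, p, r, f in zip(labels, precision, recall, f1):
    print(f"{name}: precision={p:.2f} recall={r:.2f} f1={f:.2f}")
print(f"accuracy={accuracy:.2f} macro f1={macro_f1:.2f} "
      f"weighted precision={weighted_precision:.2f}")
```

The weighted average differs from the macro average because class A has two test rows and class B only one.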