This is a very simple example of building a decision tree model on a very small dataset of only six rows. With so few rows we can easily visualize the tree and understand how it was built. Let's dive right into the Python code. To keep things simple, I have included the data directly in the code rather than importing it from an external file.
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree
data = {"J": [1, 1, 0, 1, 0, 1],
        "K": [1, 1, 0, 0, 1, 1],
        "L": [1, 0, 1, 0, 1, 0],
        "Class": ['A', 'A', 'B', 'B', 'A', 'B']}
We are trying to predict the class. Notice that K is the best feature for predicting it.
When K = 1 the class is A three out of four times, which makes K the best predictor.
When L = 1 the class is A two out of three times, which is weaker.
When J = 1 the class is A only half the time, making J a poor predictor.
So, in order of usefulness, they are K, L, and J.
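These proportions are easy to verify with pandas before trusting them; a quick sketch using the same six rows:

```python
import pandas as pd

# Same six rows as the dataset above
data = {"J": [1, 1, 0, 1, 0, 1],
        "K": [1, 1, 0, 0, 1, 1],
        "L": [1, 0, 1, 0, 1, 0],
        "Class": ['A', 'A', 'B', 'B', 'A', 'B']}
df = pd.DataFrame(data)

# Fraction of class-A rows among the rows where each feature equals 1
# K comes out highest (0.75), then L (0.67), then J (0.50)
for col in ['J', 'K', 'L']:
    frac_a = (df.loc[df[col] == 1, 'Class'] == 'A').mean()
    print(f"P(Class=A | {col}=1) = {frac_a:.2f}")
```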
df = pd.DataFrame(data)
y = df['Class']
X = df.copy()
X = X.drop('Class', axis=1)
Instantiate the Model
decision_tree = DecisionTreeClassifier(random_state=42)
decision_tree.fit(X, y)
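Once the model is fitted, the tree's own ranking of the features can be read from its `feature_importances_` attribute; a self-contained sketch on the same data:

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Same six rows as the dataset above
data = {"J": [1, 1, 0, 1, 0, 1],
        "K": [1, 1, 0, 0, 1, 1],
        "L": [1, 0, 1, 0, 1, 0],
        "Class": ['A', 'A', 'B', 'B', 'A', 'B']}
df = pd.DataFrame(data)
X = df.drop('Class', axis=1)
y = df['Class']

model = DecisionTreeClassifier(random_state=42).fit(X, y)

# Gini-based importances: K should rank first, and J, which the
# tree never splits on, should contribute nothing
for name, importance in zip(X.columns, model.feature_importances_):
    print(f"{name}: {importance:.2f}")
```

This matches the eyeball ranking above: K carries most of the impurity reduction, L a little, and J none.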
cn = list(set(df['Class']))
plt.figure(figsize=(6, 4))
plot_tree(decision_tree, max_depth=4, fontsize=9, feature_names=['J', 'K', 'L'],
          class_names=cn, filled=True);
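If the plot is hard to read, the same splits can also be printed as plain text with `sklearn.tree.export_text`; a minimal sketch on the same data:

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text

# Same six rows as the dataset above
data = {"J": [1, 1, 0, 1, 0, 1],
        "K": [1, 1, 0, 0, 1, 1],
        "L": [1, 0, 1, 0, 1, 0],
        "Class": ['A', 'A', 'B', 'B', 'A', 'B']}
df = pd.DataFrame(data)
model = DecisionTreeClassifier(random_state=42).fit(df[['J', 'K', 'L']], df['Class'])

# Text rendering of the fitted tree: the root split should be on K
print(export_text(model, feature_names=['J', 'K', 'L']))
```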
Test Dataset One – Perfect K
data_test = {"J": [1, 0, 0],
             "K": [1, 1, 0],  # assumed values: the "perfect" K tracks Class exactly
             "L": [1, 0, 1],  # assumed values, illustrative only
             "Class": ['A', 'A', 'B']}
df_test = pd.DataFrame(data_test)
y_test = df_test['Class']
X_test = df_test.copy()
X_test = X_test.drop('Class', axis=1)
predictions = decision_tree.predict(X_test)
from sklearn.metrics import classification_report, confusion_matrix
print(confusion_matrix(y_test, predictions))
print(classification_report(y_test, predictions))
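As a reminder on reading the matrix: scikit-learn sorts the labels, so with classes A and B the rows are the true labels and the columns the predicted labels. A tiny illustration with one deliberate mistake:

```python
from sklearn.metrics import confusion_matrix

y_true = ['A', 'A', 'B']
y_pred = ['A', 'B', 'B']

# Rows = true labels (A, then B); columns = predicted labels.
# The one true A that was predicted as B lands in row A, column B.
print(confusion_matrix(y_true, y_pred))
# [[1 1]
#  [0 1]]
```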
Test Dataset Two – Imperfect K
data_test2 = {"J": [1, 0, 0],
              "K": [1, 0, 0],  # assumed values: here K misses on the second row
              "L": [1, 0, 1],  # assumed values, illustrative only
              "Class": ['A', 'A', 'B']}
df_test2 = pd.DataFrame(data_test2)
y_test2 = df_test2['Class']
X_test2 = df_test2.copy()
X_test2 = X_test2.drop('Class', axis=1)
predictions2 = decision_tree.predict(X_test2)
print(confusion_matrix(y_test2, predictions2))
print(classification_report(y_test2, predictions2))
              precision    recall  f1-score   support

   macro avg       0.75      0.75      0.67         3
weighted avg       0.83      0.67      0.67         3