Artificial Intelligence · 7 min read

Visualizing Decision Trees for Product Purchase Prediction with scikit-learn and dtreeviz

This tutorial explains how to prepare advertising click data, train a decision‑tree classifier, and generate clear visualizations using scikit‑learn and dtreeviz, while also showing how to inspect individual prediction paths and feature importance.

Model Perspective

Decision Tree Model

Decision trees are interpretable models for classification and regression. With scikit-learn and dtreeviz you can visualize a trained tree, which makes it much easier to present and explain modeling results.

Product Purchase Prediction

Data preparation

The example uses an advertising click dataset with features Gender, Age, EstimatedSalary and target Purchased. After removing the user ID column, the data are cleaned, categorical text is encoded to numeric, and the dataset is split into training and test sets.

Model training

<code># import libraries
import pandas as pd
import matplotlib.pyplot as plt
import sklearn.tree
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split

# load dataset and drop the identifier column
dataset = pd.read_csv('data/Social_Network_Ads.csv')
dataset = dataset.drop(columns=['User ID'])

# one-hot encode the categorical Gender column
enc = OneHotEncoder()
enc.fit(dataset.iloc[:, [0]])
onehotlabels = enc.transform(dataset.iloc[:, [0]]).toarray()
genders = pd.DataFrame({'Female': onehotlabels[:, 0], 'Male': onehotlabels[:, 1]})
result = pd.concat([genders, dataset.iloc[:, 1:]], axis=1, sort=False)

# split features and target, then training and test sets
y = result['Purchased']
X = result.drop(columns=['Purchased'])
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

# train a shallow decision tree
classifier = sklearn.tree.DecisionTreeClassifier(criterion='entropy', random_state=100, max_depth=2)
classifier.fit(X_train, y_train)
</code>

Visualization with scikit-learn

Using plot_tree you can draw the tree with feature and class names.

<code>feature_names = ['Female', 'Male', 'Age', 'EstimatedSalary']
class_names = ['No Purchase', 'Purchase']
from sklearn.tree import plot_tree, export_text
plt.figure(figsize=(10, 5))
plot_tree(classifier, feature_names=feature_names, class_names=class_names, filled=True)
plt.show()
</code>
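The export_text function imported above gives a plain-text rendering of the same tree, handy for logs or terminals where plotting is unavailable. A minimal, self-contained sketch (synthetic Age/EstimatedSalary data stand in for the advertising dataset, so the thresholds shown will differ from the tutorial's):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# synthetic stand-in for the advertising data: Age and EstimatedSalary
rng = np.random.default_rng(0)
X = np.column_stack([rng.integers(18, 60, 200), rng.integers(20000, 150000, 200)])
y = ((X[:, 0] > 40) & (X[:, 1] > 60000)).astype(int)

clf = DecisionTreeClassifier(criterion='entropy', max_depth=2, random_state=100)
clf.fit(X, y)

# text rendering of the fitted tree: one line per split, leaves show the class
print(export_text(clf, feature_names=['Age', 'EstimatedSalary']))
```

Each indented `|---` line is a split condition, and each leaf line reports the predicted class.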

Advanced visualization with dtreeviz

dtreeviz produces richer visualizations, including histograms and pie charts, and allows orientation changes or depth range selection.

<code>import dtreeviz
viz_model = dtreeviz.model(classifier,
                           X_train=X_train, y_train=y_train,
                           feature_names=feature_names,
                           target_name='Purchased',
                           class_names=['No Purchase', 'Purchase'])
viz_model.view(scale=1)
</code>

You can also inspect a single prediction path and feature importance for a specific sample.

<code># pick a sample in the same encoded feature space the model was trained on
x0 = X_test.iloc[0]
viz_model.view(x=x0)
print(viz_model.explain_prediction_path(x0))
</code>

The analysis shows that Age has the greatest impact, followed by EstimatedSalary, while Gender contributes little.
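The same ranking can be checked numerically with scikit-learn's feature_importances_ attribute, which sums each feature's impurity reduction across the tree. A self-contained sketch using the tutorial's column names but made-up data in which purchases are driven by age and salary, not gender:

```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# synthetic data shaped like the encoded advertising dataset
rng = np.random.default_rng(0)
n = 400
X = pd.DataFrame({
    'Female': rng.integers(0, 2, n),
    'Age': rng.integers(18, 60, n),
    'EstimatedSalary': rng.integers(20000, 150000, n),
})
X['Male'] = 1 - X['Female']
X = X[['Female', 'Male', 'Age', 'EstimatedSalary']]
# target depends on age and salary only, so gender should score near zero
y = ((X['Age'] > 42) | (X['EstimatedSalary'] > 120000)).astype(int)

clf = DecisionTreeClassifier(criterion='entropy', max_depth=2, random_state=100).fit(X, y)
# importances sum to 1 across all features
for name, imp in zip(X.columns, clf.feature_importances_):
    print(f'{name}: {imp:.3f}')
```

On the real dataset, the analogous printout backs the conclusion above: Age carries the most weight, EstimatedSalary comes second, and the gender columns contribute essentially nothing.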

Python · classification · decision tree · visualization · scikit-learn · dtreeviz
Written by

Model Perspective

Insights, knowledge, and enjoyment from a mathematical modeling researcher and educator. Hosted by Haihua Wang, a modeling instructor and author of "Clever Use of Chat for Mathematical Modeling", "Modeling: The Mathematics of Thinking", "Mathematical Modeling Practice: A Hands‑On Guide to Competitions", and co‑author of "Mathematical Modeling: Teaching Design and Cases".
