83 lines
2.8 KiB
Python
83 lines
2.8 KiB
Python
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 600
|
|
from sklearn.tree import DecisionTreeClassifier
|
|
from ipywidgets import interact
|
|
|
|
|
|
def visualize_tree(estimator, X, y, boundaries=True,
                   xlim=None, ylim=None, ax=None):
    """Fit ``estimator`` on two-feature data and draw its decision regions.

    The training points are scattered on ``ax``, ``estimator`` is fitted on
    ``(X, y)``, its predictions over a 200 x 200 grid spanning ``xlim`` /
    ``ylim`` are drawn as translucent filled contours, and, if requested,
    the split lines of the fitted tree are overlaid in black.

    Parameters
    ----------
    estimator : object
        Must provide ``fit(X, y)`` and ``predict(X)``.  When ``boundaries``
        is true it must also expose, after fitting, a ``tree_`` attribute
        with ``feature``, ``threshold``, ``children_left`` and
        ``children_right`` arrays indexed by node.
    X : array
        Indexed as ``X[:, 0]`` and ``X[:, 1]`` for the two plotted features.
    y : array-like
        Class labels.  They do not have to be the integers
        ``0..n_classes-1``: labels are encoded to their rank among the
        sorted unique labels before colouring, so the point colours and the
        filled regions always agree.
    boundaries : bool, default True
        Whether to draw the tree's split lines.
    xlim, ylim : pair of float, optional
        Plot limits.  When omitted they are read from the axes after the
        points have been drawn.
    ax : optional
        Axes-like object to draw on; defaults to ``plt.gca()``.

    Returns
    -------
    None
    """
    if ax is None:
        ax = plt.gca()

    # Encode labels as 0..n_classes-1.  The contour levels below sit at
    # half-integers, so they only separate classes correctly when the plotted
    # values are consecutive integers starting at zero.
    classes, y_codes = np.unique(np.ravel(y), return_inverse=True)
    n_classes = len(classes)

    # Plot the training points
    ax.scatter(X[:, 0], X[:, 1], c=y_codes, s=30, cmap='viridis',
               clim=(0, n_classes - 1), zorder=3)
    ax.axis('tight')
    ax.axis('off')
    if xlim is None:
        xlim = ax.get_xlim()
    if ylim is None:
        ylim = ax.get_ylim()

    # Fit the estimator, then predict on a regular grid covering the view
    estimator.fit(X, y)
    xx, yy = np.meshgrid(np.linspace(*xlim, num=200),
                         np.linspace(*ylim, num=200))
    grid_labels = estimator.predict(np.c_[xx.ravel(), yy.ravel()])

    # ``classes`` is sorted, so searchsorted maps each predicted label to the
    # same code that was used for the scatter colours.
    Z = np.searchsorted(classes, grid_labels).reshape(xx.shape)

    # Put the result into a color plot: one filled band per class
    ax.contourf(xx, yy, Z, alpha=0.3,
                levels=np.arange(n_classes + 1) - 0.5,
                cmap='viridis', zorder=1)

    ax.set(xlim=xlim, ylim=ylim)

    if not boundaries:
        return

    tree = estimator.tree_

    def plot_boundaries(node, xlim, ylim):
        """Draw the split at ``node`` clipped to the box, then recurse."""
        if node < 0:
            return

        split = tree.threshold[node]
        left = tree.children_left[node]
        right = tree.children_right[node]

        if tree.feature[node] == 0:
            # Split on the first feature: vertical line, children share ylim
            ax.plot([split, split], ylim, '-k', zorder=2)
            plot_boundaries(left, [xlim[0], split], ylim)
            plot_boundaries(right, [split, xlim[1]], ylim)
        elif tree.feature[node] == 1:
            # Split on the second feature: horizontal line, children share xlim
            ax.plot(xlim, [split, split], '-k', zorder=2)
            plot_boundaries(left, xlim, [ylim[0], split])
            plot_boundaries(right, xlim, [split, ylim[1]])
        # Any other feature value (e.g. a leaf marker) draws nothing.

    plot_boundaries(0, xlim, ylim)
|
|
|
|
|
|
def plot_tree_interactive(X, y):
    """Explore a decision tree fitted on ``(X, y)`` with a depth slider.

    Builds an ``interact`` control whose ``depth`` value ranges over
    ``(1, 5)``; each change fits a fresh ``DecisionTreeClassifier`` with that
    ``max_depth`` and draws it with ``visualize_tree``.

    Returns the result of the ``interact`` call.
    """
    depth_range = (1, 5)

    def interactive_tree(depth=5):
        # A new classifier per call; random_state=0 keeps each depth's
        # picture the same every time the slider returns to it.
        model = DecisionTreeClassifier(random_state=0, max_depth=depth)
        visualize_tree(model, X, y)

    widget = interact(interactive_tree, depth=depth_range)
    return widget
|
|
|
|
|
|
def randomized_tree_interactive(X, y):
    """Explore trees fitted on random 75% subsets of ``(X, y)``.

    Builds an ``interact`` control whose ``random_state`` value ranges over
    ``(0, 100)``.  Each value selects a different random subset of the
    samples, fits a ``DecisionTreeClassifier(max_depth=15)`` on it and draws
    the result with ``visualize_tree`` (without split lines).

    Parameters
    ----------
    X : array
        Indexed as ``X[:, 0]`` / ``X[:, 1]`` and by integer index arrays.
    y : array
        Labels, indexed by the same integer index arrays as ``X``.

    Returns
    -------
    None
    """
    n_subset = int(0.75 * X.shape[0])

    # Take the limits from the full data set so every subset is drawn in the
    # same frame and plots are comparable as the slider moves.
    xlim = (X[:, 0].min(), X[:, 0].max())
    ylim = (X[:, 1].min(), X[:, 1].max())

    def fit_randomized_tree(random_state=0):
        rng = np.random.RandomState(random_state)
        subset = rng.permutation(len(y))[:n_subset]

        # Seed the classifier as well: the tree builder permutes features at
        # each split, so without a seed tied splits could give a different
        # tree for the same slider value.
        clf = DecisionTreeClassifier(max_depth=15, random_state=random_state)
        visualize_tree(clf, X[subset], y[subset], boundaries=False,
                       xlim=xlim, ylim=ylim)

    interact(fit_randomized_tree, random_state=(0, 100))
|