# Example: comparing shallow vs. deep decision-tree regressors (packtml CARTRegressor)
# -*- coding: utf-8 -*-

from __future__ import absolute_import

import sys

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

from packtml.decision_tree import CARTRegressor
# #############################################################################
# Build a regression dataset: 80 sorted points on [0, 5) with a noiseless
# sine-curve target. (NOTE: this is a regression problem, not classification.)
rs = np.random.RandomState(42)
X = np.sort(5 * rs.rand(80, 1), axis=0)
y = np.sin(X).ravel()

# Hold out a test fold (sklearn default: 75% train / 25% test)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# #############################################################################
# Fit a shallow CART regressor and score it on the held-out fold.
clf = CARTRegressor(X_train, y_train, max_depth=3, random_state=42)
pred = clf.predict(X_test)
clf_mse = mean_squared_error(y_test, pred)
shallow_report = "Test MSE (depth=3): %.3f" % clf_mse
print(shallow_report)

# A deeper tree should track the sine curve more closely (lower test MSE).
clf2 = CARTRegressor(X_train, y_train, max_depth=10, random_state=42)
pred2 = clf2.predict(X_test)
clf2_mse = mean_squared_error(y_test, pred2)
deep_report = "Test MSE (depth=10): %.3f" % clf2_mse
print(deep_report)

# #############################################################################
# Visualize the difference in learning ability between the two tree depths.
train_x = X_train.ravel()
test_x = X_test.ravel()

fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(12, 8))

# Left panel: training data (faint red) overlaid with the shallow tree's
# test predictions.
left_ax.scatter(train_x, y_train, alpha=0.25, c='r')
left_ax.scatter(test_x, pred, alpha=1.)
left_ax.set_title("Shallow tree (depth=3) test MSE: %.3f" % clf_mse)

# Right panel: same layout for the deeper tree.
right_ax.scatter(train_x, y_train, alpha=0.4, c='r')
right_ax.scatter(test_x, pred2, alpha=1.)
right_ax.set_title("Deeper tree (depth=10) test MSE: %.3f" % clf2_mse)

# if we're supposed to save it, do so INSTEAD OF showing it
if len(sys.argv) > 1:
    plt.savefig(sys.argv[1])
else:
    plt.show()