54 lines
1.8 KiB
Python
54 lines
1.8 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from packtml.regression import SimpleLinearRegression
|
|
from sklearn.linear_model import LinearRegression
|
|
from sklearn.model_selection import train_test_split
|
|
from matplotlib import pyplot as plt
|
|
import numpy as np
|
|
import sys
|
|
|
|
# #############################################################################
# Synthetic, noise-free data: every target is an exact linear combination of
# the two features with no intercept,
#     y = 2a + 1.5b + 0
# so a correct least-squares fit should recover the coefficients (2, 1.5).
random_state = np.random.RandomState(seed=42)

# 500 rows x 2 columns drawn uniformly from [0, 1)
X = random_state.random_sample((500, 2))
feature_a, feature_b = X.T
y = 2. * feature_a + 1.5 * feature_b

# Hold out part of the rows for evaluation (scikit-learn's default test
# fraction); reusing the seeded RandomState keeps the shuffle reproducible.
(X_train, X_test,
 y_train, y_test) = train_test_split(X, y, random_state=random_state)
|
|
|
|
# #############################################################################
# Fit the package's own linear regression on the training rows and score it on
# the held-out rows.
lm = SimpleLinearRegression(X_train, y_train)
predictions = lm.predict(X_test)

# Positive and negative residuals can cancel in a plain sum, so this is only a
# rough sanity figure; with noise-free data it should be ~0 regardless.
print("Test sum of residuals: %.3f" % np.sum(y_test - predictions))

# The data were generated with exactly these coefficients, so the fit must
# recover them.
assert np.allclose(lm.theta, np.array([2.0, 1.5]))
|
|
|
|
# #############################################################################
# Cross-check against scikit-learn: both the fitted coefficients and the
# test-set predictions must agree within np.allclose's default tolerance.
# NOTE(review): sklearn also fits an intercept_ here, while only lm.theta is
# compared with coef_; for this zero-intercept, noise-free data the intercept
# should come out ~0, so the comparison still holds.
lr = LinearRegression(fit_intercept=True).fit(X_train, y_train)

assert np.allclose(lm.theta, lr.coef_)
assert np.allclose(predictions, lr.predict(X_test))
|
|
|
|
# #############################################################################
# Refit using only the first feature so the model can be drawn as a 2-D line.
# X_train / X_test are rebound to single-column (n, 1) views of column 0.
# y_train still depends on both features, so this fit is only an approximation.
X_train = X_train[:, :1]
X_test = X_test[:, :1]
lm = SimpleLinearRegression(X_train, y_train)

# Black dots are the held-out targets; the thick line is the model's
# predictions for the same inputs.
preds = lm.predict(X_test)
plt.scatter(X_test[:, 0], y_test, color='black')
plt.plot(X_test[:, 0], preds, linewidth=3)
|
|
|
|
# An output path on the command line means: write the figure to that file
# rather than opening an interactive window.
if sys.argv[1:]:
    plt.savefig(sys.argv[1])
else:
    plt.show()
|