回归分析
回归分析是一种统计方法,用于研究变量之间的关系,特别是自变量对因变量的影响。
简单线性回归
概述
简单线性回归研究两个变量之间的线性关系。
实现
python
import numpy as np
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
# 创建数据
X = np.array([1, 2, 3, 4, 5]).reshape(-1, 1)
y = np.array([2, 4, 5, 4, 5])
# 拟合模型
model = LinearRegression()
model.fit(X, y)
# 输出结果
print(f"截距: {model.intercept_}")
print(f"系数: {model.coef_[0]}")
print(f"R²: {model.score(X, y)}")
# 预测
y_pred = model.predict(X)
# 可视化
plt.scatter(X, y)
plt.plot(X, y_pred, color='red')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('简单线性回归')
plt.show()多元线性回归
概述
多元线性回归研究多个自变量对因变量的影响。
实现
python
# 创建数据
X = np.array([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
y = np.array([3, 5, 7, 9, 11])
# 拟合模型
model = LinearRegression()
model.fit(X, y)
# 输出结果
print(f"截距: {model.intercept_}")
print(f"系数: {model.coef_}")
print(f"R²: {model.score(X, y)}")多项式回归
概述
多项式回归用于拟合非线性关系。
实现
python
from sklearn.preprocessing import PolynomialFeatures
# 创建数据
X = np.array([1, 2, 3, 4, 5]).reshape(-1, 1)
y = np.array([1, 4, 9, 16, 25])
# 创建多项式特征
poly = PolynomialFeatures(degree=2)
X_poly = poly.fit_transform(X)
# 拟合模型
model = LinearRegression()
model.fit(X_poly, y)
# 预测
y_pred = model.predict(X_poly)
# 可视化
plt.scatter(X, y)
plt.plot(X, y_pred, color='red')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('多项式回归')
plt.show()正则化回归
Ridge 回归
python
from sklearn.linear_model import Ridge
# 创建数据
X = np.array([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
y = np.array([3, 5, 7, 9, 11])
# 拟合 Ridge 回归模型
model = Ridge(alpha=1.0)
model.fit(X, y)
# 输出结果
print(f"截距: {model.intercept_}")
print(f"系数: {model.coef_}")Lasso 回归
python
from sklearn.linear_model import Lasso
# 拟合 Lasso 回归模型
model = Lasso(alpha=0.1)
model.fit(X, y)
# 输出结果
print(f"截距: {model.intercept_}")
print(f"系数: {model.coef_}")模型评估
常用指标
python
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
# 预测
y_pred = model.predict(X)
# 计算指标
mse = mean_squared_error(y, y_pred)
mae = mean_absolute_error(y, y_pred)
r2 = r2_score(y, y_pred)
print(f"MSE: {mse}")
print(f"MAE: {mae}")
print(f"R²: {r2}")残差分析
python
# 计算残差
residuals = y - y_pred
# 绘制残差图
plt.scatter(y_pred, residuals)
plt.axhline(y=0, color='red', linestyle='--')
plt.xlabel('预测值')
plt.ylabel('残差')
plt.title('残差图')
plt.show()注意事项
- 线性假设: 确保变量之间存在线性关系
- 多重共线性: 检查自变量之间的相关性
- 异方差性: 检查残差是否恒定
- 正态性: 检查残差是否服从正态分布