Predicting Heart Disease with Machine Learning

Predicting heart disease from a set of clinical attributes.

Notes:

  1. A new series begins! It contains 2 projects, neither of them hard — a good entry point for beginners.

  2. Since this is only the first project of the series, many details are kept shallow and shown only for demonstration; they are fleshed out in the second project.

Below is an overview of the overall approach.


1. Problem definition

Given a patient's clinical measurements, can we predict whether they have heart disease?

2. Data source

https://archive.ics.uci.edu/ml/datasets/Heart+Disease

3. Evaluation

We aim for 95% accuracy.

4. Features and labels

Data dictionary

  1. age: age in years
  2. sex: sex (1 = male; 0 = female)
  3. cp: chest pain type
     – Value 0: typical angina
     – Value 1: atypical angina
     – Value 2: non-anginal pain
     – Value 3: asymptomatic
  4. trestbps: resting blood pressure (in mm Hg on admission to the hospital)
  5. chol: serum cholesterol in mg/dl
  6. fbs: fasting blood sugar > 120 mg/dl (1 = true; 0 = false)
  7. restecg: resting electrocardiographic results
     – Value 0: normal
     – Value 1: having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV)
     – Value 2: showing probable or definite left ventricular hypertrophy by Estes' criteria
  8. thalach: maximum heart rate achieved
  9. exang: exercise induced angina (1 = yes; 0 = no)
  10. oldpeak: ST depression induced by exercise relative to rest
  11. slope: the slope of the peak exercise ST segment
      – Value 0: upsloping
      – Value 1: flat
      – Value 2: downsloping
  12. ca: number of major vessels (0–3) colored by fluoroscopy
  13. thal: 0 = normal; 1 = fixed defect; 2 = reversible defect
  14. target: 0 = no disease, 1 = disease

0. Imports

# EDA
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
sns.set()
plt.rcParams['font.sans-serif'] = ['SimHei']  # a font that can render CJK glyphs in plots
plt.rcParams['axes.unicode_minus'] = False    # render minus signs correctly with that font
%config InlineBackend.figure_format = 'svg'

# sklearn models
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

# model evaluation
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import plot_roc_curve  # deprecated in scikit-learn 1.0, removed in 1.2

Load the data

hd_df = pd.read_csv('heart-disease.csv')
hd_df.shape
(303, 14)

1. EDA

Learn as much as possible about the dataset — become the resident expert on it:

  1. What problem are we trying to solve?
  2. What data do we have, and how should it be handled?
  3. Are there missing values, and how should they be handled?
  4. Are there outliers, and how should they be handled?
  5. How can we squeeze out more information by creating derived features and processing/selecting the existing ones?
hd_df.head()
   age  sex  cp  trestbps  chol  fbs  restecg  thalach  exang  oldpeak  slope  ca  thal  target
0   63    1   3       145   233    1        0      150      0      2.3      0   0     1       1
1   37    1   2       130   250    0        1      187      0      3.5      0   0     2       1
2   41    0   1       130   204    0        0      172      0      1.4      2   0     2       1
3   56    1   1       120   236    0        1      178      0      0.8      2   0     2       1
4   57    0   0       120   354    0        1      163      1      0.6      2   0     2       1
hd_df.tail()
     age  sex  cp  trestbps  chol  fbs  restecg  thalach  exang  oldpeak  slope  ca  thal  target
298   57    0   0       140   241    0        1      123      1      0.2      1   0     3       0
299   45    1   3       110   264    0        1      132      0      1.2      1   0     3       0
300   68    1   0       144   193    1        1      141      0      3.4      1   2     3       0
301   57    1   0       130   131    0        1      115      1      1.2      1   1     3       0
302   57    0   1       130   236    0        0      174      0      0.0      1   1     2       0
# class distribution
targets = hd_df['target'].value_counts()
targets
1    165
0    138
Name: target, dtype: int64
targets.plot(
    kind='bar', 
    color=['salmon', 'lightblue'],
    figsize=(10,6)
)
plt.xticks(rotation=0)
plt.show()

[Figure: bar chart of the target class distribution]

hd_df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 303 entries, 0 to 302
Data columns (total 14 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   age       303 non-null    int64  
 1   sex       303 non-null    int64  
 2   cp        303 non-null    int64  
 3   trestbps  303 non-null    int64  
 4   chol      303 non-null    int64  
 5   fbs       303 non-null    int64  
 6   restecg   303 non-null    int64  
 7   thalach   303 non-null    int64  
 8   exang     303 non-null    int64  
 9   oldpeak   303 non-null    float64
 10  slope     303 non-null    int64  
 11  ca        303 non-null    int64  
 12  thal      303 non-null    int64  
 13  target    303 non-null    int64  
dtypes: float64(1), int64(13)
memory usage: 33.3 KB
# check for missing values
hd_df.isna().sum()
age         0
sex         0
cp          0
trestbps    0
chol        0
fbs         0
restecg     0
thalach     0
exang       0
oldpeak     0
slope       0
ca          0
thal        0
target      0
dtype: int64
# descriptive statistics
hd_df.describe([0.01, 0.25, 0.5, 0.75, 0.99]).T
          count        mean        std    min      1%    25%    50%    75%     99%    max
age       303.0   54.366337   9.082101   29.0   35.00   47.5   55.0   61.0   71.00   77.0
sex       303.0    0.683168   0.466011    0.0    0.00    0.0    1.0    1.0    1.00    1.0
cp        303.0    0.966997   1.032052    0.0    0.00    0.0    1.0    2.0    3.00    3.0
trestbps  303.0  131.623762  17.538143   94.0  100.00  120.0  130.0  140.0  180.00  200.0
chol      303.0  246.264026  51.830751  126.0  149.00  211.0  240.0  274.5  406.74  564.0
fbs       303.0    0.148515   0.356198    0.0    0.00    0.0    0.0    0.0    1.00    1.0
restecg   303.0    0.528053   0.525860    0.0    0.00    0.0    1.0    1.0    1.98    2.0
thalach   303.0  149.646865  22.905161   71.0   95.02  133.5  153.0  166.0  191.96  202.0
exang     303.0    0.326733   0.469794    0.0    0.00    0.0    0.0    1.0    1.00    1.0
oldpeak   303.0    1.039604   1.161075    0.0    0.00    0.0    0.8    1.6    4.20    6.2
slope     303.0    1.399340   0.616226    0.0    0.00    1.0    1.0    2.0    2.00    2.0
ca        303.0    0.729373   1.022606    0.0    0.00    0.0    0.0    1.0    4.00    4.0
thal      303.0    2.313531   0.612277    0.0    1.00    2.0    2.0    3.0    3.00    3.0
target    303.0    0.544554   0.498835    0.0    0.00    0.0    1.0    1.0    1.00    1.0

Relationship between sex and the target

hd_df['sex'].value_counts()
1    207
0     96
Name: sex, dtype: int64
# improved crosstab helper: adds a rate column to pd.crosstab
def to_cross_tab(origin_df, index_name, col_name):
    df = pd.crosstab(origin_df[index_name], origin_df[col_name])
    df['rate'] = df.iloc[:,1] / (df.iloc[:,0] + df.iloc[:,1])
    return df
sex_target_df = to_cross_tab(hd_df, 'sex', 'target')
sex_target_df
target    0   1      rate
sex
0        24  72  0.750000
1       114  93  0.449275
# plotting helper
def to_plot(df, title, xlabel, ylabel, legend):
    df.plot(
        kind='bar', 
        color=['lightblue', 'salmon'],
        figsize=(10,6)
    )
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.xticks(rotation=0)
    plt.legend(legend)
    plt.show()
to_plot(sex_target_df[[0,1]], 'Heart disease counts by sex', '0 = female, 1 = male', 'Count', ['No disease', 'Disease'])


[Figure: bar chart of disease counts by sex]

The disease rate is clearly much higher among women (75%) than among men (about 45%).
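To put a number on this, one option (an addition, not in the original notebook) is a chi-square test of independence on the sex/target contingency table — a minimal sketch:

# chi-square test of independence between sex and target
from scipy.stats import chi2_contingency

chi2, p, dof, expected = chi2_contingency(pd.crosstab(hd_df['sex'], hd_df['target']))
print(f'chi2 = {chi2:.2f}, p = {p:.2g}')  # a small p suggests sex and target are associated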


Age vs. maximum heart rate in the diseased and non-diseased groups

plt.figure(figsize=(10,6))

# patients with heart disease
plt.scatter(hd_df['age'][hd_df['target']==1],
            hd_df['thalach'][hd_df['target']==1],
            c='salmon'
)

# patients without heart disease
plt.scatter(hd_df['age'][hd_df['target']==0],
            hd_df['thalach'][hd_df['target']==0],
            c='lightblue'
)

# annotations
plt.title('Age vs. maximum heart rate, split by heart disease status')
plt.xlabel('Age')
plt.ylabel('Maximum heart rate')
plt.legend(['Disease', 'No disease'])

plt.show()


[Figure: scatter of age vs. maximum heart rate, colored by disease status]

# age distribution
hd_df['age'].hist()
<AxesSubplot:>


[Figure: histogram of age]

# normality test
stats.normaltest(hd_df['age'])
NormaltestResult(statistic=8.74798581312778, pvalue=0.012600826063683705)

With p ≈ 0.013 < 0.05, the test actually rejects normality: age looks roughly bell-shaped in the histogram, but it does not pass a formal normality test.
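As an optional visual check (an addition, not in the original), a Q-Q plot makes departures from normality easier to judge than the histogram alone:

# Q-Q plot against a normal distribution; points bending away from the line indicate non-normality
stats.probplot(hd_df['age'], dist='norm', plot=plt)
plt.show()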


Relationship between chest pain type and the target

  1. cp: chest pain type
     – Value 0: typical angina
     – Value 1: atypical angina
     – Value 2: non-anginal pain
     – Value 3: asymptomatic
cp_target_df = to_cross_tab(hd_df, 'cp', 'target')
cp_target_df
target    0   1      rate
cp
0       104  39  0.272727
1         9  41  0.820000
2        18  69  0.793103
3         7  16  0.695652
to_plot(cp_target_df[[0,1]], 'Heart disease counts by chest pain type', 'Chest pain type', 'Count', ['No disease', 'Disease'])


[Figure: bar chart of disease counts by chest pain type]

# correlation matrix
corr_matrix = hd_df.corr()
corr_matrix
               age       sex        cp  trestbps      chol       fbs   restecg   thalach     exang   oldpeak     slope        ca      thal    target
age       1.000000 -0.098447 -0.068653  0.279351  0.213678  0.121308 -0.116211 -0.398522  0.096801  0.210013 -0.168814  0.276326  0.068001 -0.225439
sex      -0.098447  1.000000 -0.049353 -0.056769 -0.197912  0.045032 -0.058196 -0.044020  0.141664  0.096093 -0.030711  0.118261  0.210041 -0.280937
cp       -0.068653 -0.049353  1.000000  0.047608 -0.076904  0.094444  0.044421  0.295762 -0.394280 -0.149230  0.119717 -0.181053 -0.161736  0.433798
trestbps  0.279351 -0.056769  0.047608  1.000000  0.123174  0.177531 -0.114103 -0.046698  0.067616  0.193216 -0.121475  0.101389  0.062210 -0.144931
chol      0.213678 -0.197912 -0.076904  0.123174  1.000000  0.013294 -0.151040 -0.009940  0.067023  0.053952 -0.004038  0.070511  0.098803 -0.085239
fbs       0.121308  0.045032  0.094444  0.177531  0.013294  1.000000 -0.084189 -0.008567  0.025665  0.005747 -0.059894  0.137979 -0.032019 -0.028046
restecg  -0.116211 -0.058196  0.044421 -0.114103 -0.151040 -0.084189  1.000000  0.044123 -0.070733 -0.058770  0.093045 -0.072042 -0.011981  0.137230
thalach  -0.398522 -0.044020  0.295762 -0.046698 -0.009940 -0.008567  0.044123  1.000000 -0.378812 -0.344187  0.386784 -0.213177 -0.096439  0.421741
exang     0.096801  0.141664 -0.394280  0.067616  0.067023  0.025665 -0.070733 -0.378812  1.000000  0.288223 -0.257748  0.115739  0.206754 -0.436757
oldpeak   0.210013  0.096093 -0.149230  0.193216  0.053952  0.005747 -0.058770 -0.344187  0.288223  1.000000 -0.577537  0.222682  0.210244 -0.430696
slope    -0.168814 -0.030711  0.119717 -0.121475 -0.004038 -0.059894  0.093045  0.386784 -0.257748 -0.577537  1.000000 -0.080155 -0.104764  0.345877
ca        0.276326  0.118261 -0.181053  0.101389  0.070511  0.137979 -0.072042 -0.213177  0.115739  0.222682 -0.080155  1.000000  0.151832 -0.391724
thal      0.068001  0.210041 -0.161736  0.062210  0.098803 -0.032019 -0.011981 -0.096439  0.206754  0.210244 -0.104764  0.151832  1.000000 -0.344029
target   -0.225439 -0.280937  0.433798 -0.144931 -0.085239 -0.028046  0.137230  0.421741 -0.436757 -0.430696  0.345877 -0.391724 -0.344029  1.000000
plt.figure(figsize=(14, 10))
sns.heatmap(
    corr_matrix, 
    vmin=-1, 
    annot=True, 
    linewidths=5, 
    fmt='.2f', 
    cmap='YlGnBu'
)
plt.show()


[Figure: correlation heatmap]

These correlations look promising: most features show some correlation with the target, and no feature pair exceeds 0.8, so nothing needs to be dropped for collinearity. Strictly speaking, though, a proper correlation analysis should treat categorical and continuous variables separately, and the continuous ones should first be tested for normality.
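As a sketch of that more careful approach (not in the original notebook; treating exactly these five columns as continuous is an assumption), point-biserial correlation handles continuous-vs-binary pairs, and Spearman rank correlation drops the normality requirement entirely:

# point-biserial (continuous feature vs. binary target) and Spearman (rank-based) correlations
continuous_cols = ['age', 'trestbps', 'chol', 'thalach', 'oldpeak']  # assumed continuous features
for col in continuous_cols:
    r_pb, p_pb = stats.pointbiserialr(hd_df[col], hd_df['target'])
    r_sp, p_sp = stats.spearmanr(hd_df[col], hd_df['target'])
    print(f'{col:>8}: point-biserial r = {r_pb:+.3f} (p = {p_pb:.3g}), spearman r = {r_sp:+.3f} (p = {p_sp:.3g})')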


3. Modeling

hd_df.head()
   age  sex  cp  trestbps  chol  fbs  restecg  thalach  exang  oldpeak  slope  ca  thal  target
0   63    1   3       145   233    1        0      150      0      2.3      0   0     1       1
1   37    1   2       130   250    0        1      187      0      3.5      0   0     2       1
2   41    0   1       130   204    0        0      172      0      1.4      2   0     2       1
3   56    1   1       120   236    0        1      178      0      0.8      2   0     2       1
4   57    0   0       120   354    0        1      163      1      0.6      2   0     2       1
X = hd_df.drop(columns=['target'])
y = hd_df['target']
# set a random seed so others can reproduce the experiment
np.random.seed(13)

# split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

Try logistic regression, KNN, and random forest in turn

# dictionary of models
models = {
    'lr': LogisticRegression(),
    'knn': KNeighborsClassifier(),
    'rf': RandomForestClassifier()
}

# a simple exploratory fit-and-score helper
def fit_and_score(models, X_train, X_test, y_train, y_test):
    np.random.seed(13)
    model_score = {}
    for name, model in models.items():
        model.fit(X_train, y_train)
        model_score[name] = model.score(X_test, y_test)
    return model_score
model_scores = fit_and_score(models, X_train, X_test, y_train, y_test)
model_scores
 {'lr': 0.8360655737704918, 'knn': 0.639344262295082, 'rf': 0.819672131147541}

Model comparison

model_compare = pd.DataFrame(model_scores, index=['accuracy'])
model_compare.T.plot(kind='bar')
plt.xticks(rotation=0)
plt.show()

[Figure: bar chart comparing model accuracies]

What's next?

  • Hyperparameter tuning
  • Feature importance
  • Confusion matrix
  • Cross-validation
  • Precision
  • Recall
  • F1 score
  • Classification report
  • ROC curve
  • AUC
# tune KNN by hand (pretending we don't know about GridSearchCV / RandomizedSearchCV)
train_scores = []
test_scores = []

neighbors = range(1, 21)

knn = KNeighborsClassifier()
for i in neighbors:
    knn.set_params(n_neighbors=i)
    knn.fit(X_train, y_train)
    train_scores.append(knn.score(X_train, y_train))
    test_scores.append(knn.score(X_test, y_test))
train_scores
[1.0,
 0.8016528925619835,
 0.8057851239669421,
 0.7603305785123967,
 0.768595041322314,
 0.7355371900826446,
 0.7396694214876033,
 0.71900826446281,
 0.7024793388429752,
 0.6900826446280992,
 0.7107438016528925,
 0.6859504132231405,
 0.7024793388429752,
 0.6776859504132231,
 0.6942148760330579,
 0.6859504132231405,
 0.6694214876033058,
 0.6859504132231405,
 0.7024793388429752,
 0.7066115702479339]
test_scores
[0.6065573770491803,
 0.4426229508196721,
 0.5737704918032787,
 0.5409836065573771,
 0.639344262295082,
 0.6557377049180327,
 0.6065573770491803,
 0.6721311475409836,
 0.6557377049180327,
 0.6557377049180327,
 0.6885245901639344,
 0.6885245901639344,
 0.6885245901639344,
 0.7377049180327869,
 0.7213114754098361,
 0.7213114754098361,
 0.7213114754098361,
 0.7049180327868853,
 0.7377049180327869,
 0.7377049180327869]
plt.plot(neighbors, train_scores, label='Train score')
plt.plot(neighbors, test_scores, label='Test score')
plt.xticks(range(1,21,1))
plt.xlabel('n_neighbors')
plt.ylabel('Accuracy')
plt.legend()
plt.show()


[Figure: train vs. test accuracy across n_neighbors values]

Even at its best, KNN never reaches 80% test accuracy, so we drop it.
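One likely culprit: KNN is distance-based, and the features here live on wildly different scales (chol in the hundreds next to binary flags), so the largest-scale columns dominate the distance. A sketch of a fairer attempt with standardization — not tried in the original, and n_neighbors=11 is just an illustrative pick:

# retry KNN with standardized features; scaling usually matters a lot for distance-based models
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

knn_pipe = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=11))
knn_pipe.fit(X_train, y_train)
knn_pipe.score(X_test, y_test)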


Tuning with RandomizedSearchCV

# Logistic regression
# We mainly want to find the best C, so the other parameters are left alone. np.logspace deliberately spreads the candidate C values far apart, since we have no idea where the optimum lies.
log_reg_grid = {
    'C':np.logspace(-4, 4, 20),
    'solver': ['liblinear']
}

# Random forest
rf_grid = {
    'n_estimators': np.arange(10, 1000, 50),
    'max_depth': [None, 3, 5, 10],
    'min_samples_split': np.arange(2, 20, 2),
    'min_samples_leaf': np.arange(1, 20, 2)
}
np.random.seed(13)

# instantiate the RandomizedSearchCV object
rs_log_reg = RandomizedSearchCV(
    LogisticRegression(),
    param_distributions=log_reg_grid,
    cv=5,
    n_iter=20,
    verbose=True
)
# fit
rs_log_reg.fit(X_train, y_train)
Fitting 5 folds for each of 20 candidates, totalling 100 fits
RandomizedSearchCV(cv=5, estimator=LogisticRegression(), n_iter=20,
                   param_distributions={'C': array([1.00000000e-04, 2.63665090e-04, 6.95192796e-04, 1.83298071e-03,
       4.83293024e-03, 1.27427499e-02, 3.35981829e-02, 8.85866790e-02,
       2.33572147e-01, 6.15848211e-01, 1.62377674e+00, 4.28133240e+00,
       1.12883789e+01, 2.97635144e+01, 7.84759970e+01, 2.06913808e+02,
       5.45559478e+02, 1.43844989e+03, 3.79269019e+03, 1.00000000e+04]),
                                        'solver': ['liblinear']},
                   verbose=True)
rs_log_reg.best_params_
{'solver': 'liblinear', 'C': 1.623776739188721}
rs_log_reg.score(X_test, y_test)
0.819672131147541

A slight drop in score, annoyingly. Since this is only the first project, tuning is shown just for demonstration, so we'll leave it here.
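Before writing the tuning off, it's worth comparing the best candidate's mean cross-validation accuracy against the single test-set score — with only 61 test samples, a gap of this size can easily be noise:

# mean 5-fold CV accuracy of the best parameter combination found by the search
rs_log_reg.best_score_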

np.random.seed(13)

# instantiate the RandomizedSearchCV object
rs_rf = RandomizedSearchCV(
    RandomForestClassifier(),
    param_distributions=rf_grid,
    cv=5,
    n_iter=20,
    verbose=True
)
# fit
rs_rf.fit(X_train, y_train)
Fitting 5 folds for each of 20 candidates, totalling 100 fits
RandomizedSearchCV(cv=5, estimator=RandomForestClassifier(), n_iter=20,
                   param_distributions={'max_depth': [None, 3, 5, 10],
                                        'min_samples_leaf': array([ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19]),
                                        'min_samples_split': array([ 2,  4,  6,  8, 10, 12, 14, 16, 18]),
                                        'n_estimators': array([ 10,  60, 110, 160, 210, 260, 310, 360, 410, 460, 510, 560, 610,
       660, 710, 760, 810, 860, 910, 960])},
                   verbose=True)
rs_rf.best_params_
{'n_estimators': 310,
 'min_samples_split': 16,
 'min_samples_leaf': 9,
 'max_depth': None}
rs_rf.score(X_test, y_test)
0.8360655737704918

A slight improvement over the untuned random forest.


Tuning with GridSearchCV

This time we search over a few more parameters. Note that not every penalty/solver combination here is valid (for LogisticRegression, 'l1' works only with 'liblinear' and 'saga'), so the invalid candidates fail during fitting and are scored as NaN by default.

log_reg_grid = {
    'C':np.logspace(-4, 4, 30),
    'solver': ['liblinear', 'sag', 'saga', 'newton-cg', 'lbfgs'],
    'penalty': ['l1', 'l2']
}

# instantiate the GridSearchCV object
gs_log_reg = GridSearchCV(
    LogisticRegression(),
    param_grid=log_reg_grid,
    cv=5,
    verbose=True
)
# fit
gs_log_reg.fit(X_train, y_train)
GridSearchCV(cv=5, estimator=LogisticRegression(),
             param_grid={'C': array([1.00000000e-04, 1.88739182e-04, 3.56224789e-04, 6.72335754e-04,
       1.26896100e-03, 2.39502662e-03, 4.52035366e-03, 8.53167852e-03,
       1.61026203e-02, 3.03919538e-02, 5.73615251e-02, 1.08263673e-01,
       2.04335972e-01, 3.85662042e-01, 7.27895384e-01, 1.37382380e+00,
       2.59294380e+00, 4.89390092e+00, 9.23670857e+00, 1.74332882e+01,
       3.29034456e+01, 6.21016942e+01, 1.17210230e+02, 2.21221629e+02,
       4.17531894e+02, 7.88046282e+02, 1.48735211e+03, 2.80721620e+03,
       5.29831691e+03, 1.00000000e+04]),
                         'penalty': ['l1', 'l2'],
                         'solver': ['liblinear', 'sag', 'saga', 'newton-cg', 'lbfgs']},
             verbose=True)
gs_log_reg.best_params_
{'C': 221.22162910704503, 'penalty': 'l2', 'solver': 'lbfgs'}

The score comes out the same as before…

gs_log_reg.score(X_test, y_test)
0.819672131147541

4. Evaluation

y_pred = gs_log_reg.predict(X_test)
y_pred
array([0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0,
       1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1,
       0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1], dtype=int64)
y_test
203    0
30     1
58     1
90     1
119    1
      ..
249    0
135    1
41     1
67     1
148    1
Name: target, Length: 61, dtype: int64
plot_roc_curve(gs_log_reg, X_test, y_test)
<sklearn.metrics._plot.roc_curve.RocCurveDisplay at 0x22b49b146a0>


[Figure: ROC curve for the tuned logistic regression]
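Note: plot_roc_curve was deprecated in scikit-learn 1.0 and removed in 1.2. On newer versions, the equivalent call is:

# equivalent ROC plot for scikit-learn >= 1.2, where plot_roc_curve no longer exists
from sklearn.metrics import RocCurveDisplay

RocCurveDisplay.from_estimator(gs_log_reg, X_test, y_test)
plt.show()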

y_pred==1
array([False,  True,  True,  True,  True, False, False, False,  True,
       False,  True,  True, False, False,  True,  True,  True, False,
       False, False, False, False,  True, False,  True,  True,  True,
        True, False, False,  True, False, False,  True, False, False,
        True, False,  True,  True,  True, False, False,  True, False,
        True,  True, False,  True,  True,  True, False,  True,  True,
        True, False, False,  True,  True,  True,  True])
# confusion matrix
def to_confusion_matrix(y_test, y_pred):
    return pd.DataFrame(
        data=confusion_matrix(y_test, y_pred), 
        index=pd.MultiIndex.from_product([['y_test'], [0, 1]]),
        columns=pd.MultiIndex.from_product([['y_pred'], [0, 1]])
    )
cf_matrix = to_confusion_matrix(y_test, y_pred)
cf_matrix
         y_pred
              0   1
y_test 0     21   5
       1      6  29
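As an optional extra (not in the original), the same matrix is often easier to read as a heatmap:

# render the confusion matrix as a heatmap; annot=True writes the counts into each cell
sns.heatmap(confusion_matrix(y_test, y_pred), annot=True, fmt='d', cbar=False)
plt.xlabel('y_pred')
plt.ylabel('y_test')
plt.show()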
# classification report
print(classification_report(y_test, y_pred))
              precision    recall  f1-score   support

           0       0.78      0.81      0.79        26
           1       0.85      0.83      0.84        35

    accuracy                           0.82        61
   macro avg       0.82      0.82      0.82        61
weighted avg       0.82      0.82      0.82        61

Evaluate the model with cross-validation

Use cross-validation to compute accuracy, precision, recall, and F1

gs_log_reg.best_params_
{'C': 221.22162910704503, 'penalty': 'l2', 'solver': 'lbfgs'}
# re-instantiate the logistic regression model with the best parameters
clf = LogisticRegression(
    C=221.22162910704503, 
    penalty='l2', 
    solver='lbfgs'
)
# cross-validated accuracy
cv_acc = cross_val_score(
    clf, 
    X, 
    y, 
    cv=5,
    scoring='accuracy'
)
cv_acc
array([0.81967213, 0.83606557, 0.85245902, 0.83333333, 0.75      ])
cv_acc = np.mean(cv_acc)
cv_acc
0.8183060109289617
# cross-validated precision
cv_precision = cross_val_score(
    clf, 
    X, 
    y, 
    cv=5,
    scoring='precision'
)
cv_precision = np.mean(cv_precision)
cv_precision
0.8088942275474784
# cross-validated recall
cv_recall = cross_val_score(
    clf, 
    X, 
    y, 
    cv=5,
    scoring='recall'
)
cv_recall = np.mean(cv_recall)
cv_recall
0.8787878787878787
# cross-validated F1
cv_f1 = cross_val_score(
    clf, 
    X, 
    y, 
    cv=5,
    scoring='f1'
)
cv_f1 = np.mean(cv_f1)
cv_f1
0.8413377274453797
# visualization
cv_metrics = pd.DataFrame(
    {'accuracy': cv_acc,
     'precision': cv_precision,
     'recall': cv_recall,
     'f1': cv_f1
    },
    index=[0]
)
cv_metrics.T.plot(
    kind='bar',
    legend=False
)
plt.xticks(rotation=0)
plt.show()


[Figure: bar chart of cross-validated metrics]
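The four separate cross_val_score calls above can be collapsed into a single pass — a sketch using cross_validate, not used in the original notebook:

# cross_validate fits each fold once and evaluates every requested metric on it
from sklearn.model_selection import cross_validate

cv_results = cross_validate(clf, X, y, cv=5, scoring=['accuracy', 'precision', 'recall', 'f1'])
{m: cv_results[f'test_{m}'].mean() for m in ['accuracy', 'precision', 'recall', 'f1']}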

5. Evaluating feature importance

clf.fit(X_train, y_train)
LogisticRegression(C=221.22162910704503)
clf.coef_
array([[ 0.00513208, -1.43253864,  0.78004753, -0.01083726, -0.0019836 ,
         0.0976912 ,  0.71562367,  0.03049414, -0.80027663, -0.44530236,
         0.53599288, -0.66841624, -1.15804589]])
feature_dict = dict(zip(X.columns, clf.coef_[0]))  # use X.columns so we don't rely on zip() silently dropping 'target'
feature_dict
{'age': 0.005132076982516595,
 'sex': -1.4325386407347098,
 'cp': 0.7800475335340353,
 'trestbps': -0.010837256399792251,
 'chol': -0.001983600334944071,
 'fbs': 0.09769119644464817,
 'restecg': 0.7156236671955836,
 'thalach': 0.030494138473504826,
 'exang': -0.8002766264626233,
 'oldpeak': -0.44530236148020047,
 'slope': 0.5359928831085665,
 'ca': -0.6684162375711792,
 'thal': -1.158045891987526}
feature_df = pd.DataFrame(feature_dict, index=['feature_importance'])
feature_df.T.plot(
    kind='bar',
    title='Feature Importance',
    legend=False,
)
plt.xticks(rotation=30)
plt.show()

[Figure: bar chart of logistic regression coefficients by feature]
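Since these are raw coefficients on unscaled features, their magnitudes are not directly comparable across features. A model-agnostic cross-check (an addition, not in the original) is permutation importance:

# permutation importance: how much does the test score drop when each feature is shuffled?
from sklearn.inspection import permutation_importance

result = permutation_importance(clf, X_test, y_test, n_repeats=10, random_state=13)
pd.Series(result.importances_mean, index=X.columns).sort_values().plot(kind='barh')
plt.show()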


6. Further experiments

If the target (here, 95% accuracy) hasn't been met, keep digging:

  • Can we collect more data? Machine learning runs on data.
  • Can we try a stronger model, e.g. XGBoost or CatBoost?
  • Can we keep tuning hyperparameters?

If the target has been met, think about:
how to report the results to other people.
