博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
使用GridSearchCV寻找最佳参数组合——机器学习工具箱代码
阅读量:6340 次
发布时间:2019-06-22

本文共 2359 字,大约阅读时间需要 7 分钟。

 

 

# -*- coding: utf-8 -*-import numpy as npfrom sklearn.feature_extraction import FeatureHasherfrom sklearn import datasetsfrom sklearn.ensemble import GradientBoostingClassifierfrom sklearn.neighbors import KNeighborsClassifierimport xgboost as xgbfrom sklearn.model_selection import GridSearchCVfrom sklearn.model_selection import train_test_splitfrom sklearn import metricsfrom matplotlib import pyplot as pltfrom sklearn.ensemble import GradientBoostingClassifierfrom sklearn.model_selection import GridSearchCVdef report(test_Y, pred_Y):    print("accuracy_score:")    print(metrics.accuracy_score(test_Y, pred_Y))    print("f1_score:")    print(metrics.f1_score(test_Y, pred_Y))    print("recall_score:")    print(metrics.recall_score(test_Y, pred_Y))    print("precision_score:")    print(metrics.precision_score(test_Y, pred_Y))    print("confusion_matrix:")    print(metrics.confusion_matrix(test_Y, pred_Y))    print("AUC:")    print(metrics.roc_auc_score(test_Y, pred_Y))    f_pos, t_pos, thresh = metrics.roc_curve(test_Y, pred_Y)    auc_area = metrics.auc(f_pos, t_pos)    plt.plot(f_pos, t_pos, 'darkorange', lw=2, label='AUC = %.2f' % auc_area)    plt.legend(loc='lower right')    plt.plot([0, 1], [0, 1], color='navy', linestyle='--')    plt.title('ROC')    plt.ylabel('True Pos Rate')    plt.xlabel('False Pos Rate')    plt.show()if __name__== '__main__':    x, y = datasets.make_classification(n_samples=1000, n_features=100,n_redundant=0, random_state = 1)    train_X, test_X, train_Y, test_Y = train_test_split(x,                                                        y,                                                        test_size=0.2,                                                        random_state=66)    #clf = GradientBoostingClassifier(n_estimators=100)    #clf.fit(train_X, train_Y)    #pred_Y = clf.predict(test_X)    #report(test_Y, pred_Y)    scoring= "f1"    parameters ={'n_estimators': range( 50, 200, 25), 'max_depth': range( 2, 10, 2)}    gsearch = GridSearchCV(estimator= GradientBoostingClassifier(), param_grid= parameters, scoring='accuracy', iid= False, cv= 5)     gsearch.fit(x, y)    print("gsearch.best_params_")     print(gsearch.best_params_)     print("gsearch.best_score_")     print(gsearch.best_score_)

 效果:

gsearch.best_params_

{'max_depth': 4, 'n_estimators': 100}
gsearch.best_score_
0.868142228555714

转载地址:http://qhhoa.baihongyu.com/

你可能感兴趣的文章
小程序爆红 专家:对简单APP是巨大打击
查看>>
FarBox--另类有趣的网站服务【转】
查看>>
在非纯色背景上,叠加背景透明的BUTTON和STATIC_TEXT控件
查看>>
Distributed2:Linked Server Login 添加和删除
查看>>
Java中取两位小数
查看>>
使用 ftrace 调试 Linux 内核【转】
查看>>
唯一聚集索引上的唯一和非唯一非聚集索引
查看>>
Spark新愿景:让深度学习变得更加易于使用——见https://github.com/yahoo/TensorFlowOnSpark...
查看>>
linux磁盘配额
查看>>
NFS文件共享服务器的搭建
查看>>
IP_VFR-4-FRAG_TABLE_OVERFLOW【cisco设备报错】碎片***
查看>>
Codeforces Round #256 (Div. 2) D. Multiplication Table 【二分】
查看>>
ARM汇编指令格式
查看>>
HDU-2044-一只小蜜蜂
查看>>
HDU-1394-Minimum Inversion Number
查看>>
[转] createObjectURL方法 实现本地图片预览
查看>>
JavaScript—DOM编程核心.
查看>>
JavaScript碎片
查看>>
Bootstrap-下拉菜单
查看>>
soapUi 接口测试
查看>>