将sklearn的GridSearchCV与管道一起使用,只需预处理一次
machine-learning
numpy
python
scikit-learn
4
0

我正在使用scickit-learn来调整模型的超参数。我正在使用管道将预处理器与估算器链接在一起。我的问题的一个简单版本如下所示:

import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression


grid = GridSearchCV(make_pipeline(StandardScaler(), LogisticRegression()),
                    param_grid={'logisticregression__C': [0.1, 10.]},
                    cv=2,
                    refit=False)

_ = grid.fit(X=np.random.rand(10, 3),
             y=np.random.randint(2, size=(10,)))

在我的情况下,预处理(在玩具示例中为StandardScale())很耗时,并且我没有调整其任何参数。

因此,当我执行示例时,StandardScaler被执行12次。 2个适合/预测* 2个简历* 3个参数。但是,每次对参数C的不同值执行StandardScaler时,它都会返回相同的输出,因此效率更高,只需计算一次,然后运行管道的估计器部分即可。

我可以在预处理(不调整超参数)和估计器之间手动拆分管道。但是要将预处理应用于数据,我应该只提供训练集。因此,我将不得不手动实现拆分,而根本不使用GridSearchCV。

有没有简单/标准的方法来避免在使用GridSearchCV时重复进行预处理?

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

在当前版本的scikit-learn(0.18.1)中无法执行此操作。已针对github项目提出了一个修复程序:

https://github.com/scikit-learn/scikit-learn/issues/8830

https://github.com/scikit-learn/scikit-learn/pull/8322

收藏
评论

本质上,GridSearchCV还是一个估计器,实现管道使用的fit()和predict()方法。

所以代替:

grid = GridSearchCV(make_pipeline(StandardScaler(), LogisticRegression()),
                    param_grid={'logisticregression__C': [0.1, 10.]},
                    cv=2,
                    refit=False)

做这个:

clf = make_pipeline(StandardScaler(), 
                    GridSearchCV(LogisticRegression(),
                                 param_grid={'logisticregression__C': [0.1, 10.]},
                                 cv=2,
                                 refit=True))

clf.fit()
clf.predict()

它的作用是,仅一次调用StandardScalar(),一次调用clf.fit()而不是您所描述的多次调用。

编辑:

当在管道内部使用GridSearchCV时,将refit更改为True 。如文档中所述

refit:boolean,default = True用整个数据集重新拟合最佳估计量。如果为“ False”,则拟合后将无法使用此GridSearchCV实例进行预测。

如果refit = False,则clf.fit()将无效,因为管道内部的GridSearchCV对象将在fit()之后重新初始化。当refit=True ,将对在fit()传递的整个数据使用最佳评分参数组合重新fit() GridSearchCV。

因此,如果您要制作管道,仅查看网格搜索的分数,则只有refit=False合适。如果要调用clf.predict()方法,则必须使用refit=True ,否则将引发Not Fitted错误。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号