如何使用张量流执行k倍交叉验证?
cross-validation
python
tensorflow
6
0

我正在遵循IRIS的tensorflow示例

我现在的情况是,我将所有数据保存在一个CSV文件中,没有分开,我想对该数据应用k倍交叉验证。

我有

data_set = tf.contrib.learn.datasets.base.load_csv(filename="mydata.csv",
                                                   target_dtype=np.int)

与IRIS示例一样,如何使用多层神经网络对此数据集执行k折交叉验证?

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

NN通常用于不使用CV的大型数据集-而且非常昂贵。对于IRIS(每种物种有50个样本),您可能需要它。为什么不使用带有不同随机种子的scikit-learn来分开训练和测试?

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

对于k倍的k:

  1. 拆分数据,将不同的值传递给“ random_state”
  2. 使用_train学习网络
  3. 使用_test进行测试

如果您不喜欢随机种子,并且想要更结构化的k倍拆分,则可以使用此处的方法

from sklearn.model_selection import KFold, cross_val_score
X = ["a", "a", "b", "c", "c", "c"]
k_fold = KFold(n_splits=3)
for train_indices, test_indices in k_fold.split(X):
    print('Train: %s | test: %s' % (train_indices, test_indices))
Train: [2 3 4 5] | test: [0 1]
Train: [0 1 4 5] | test: [2 3]
Train: [0 1 2 3] | test: [4 5]
收藏
评论

我知道这个问题很旧,但是如果有人想做类似的事情,请扩大ahmedhosny的回答:

新的tensorflow数据集API能够使用python生成器创建数据集对象,因此与scikit-learn的KFold一起使用的一个选项是从KFold.split()生成器创建数据集:

import numpy as np

from sklearn.model_selection import LeaveOneOut,KFold

import tensorflow as tf
import tensorflow.contrib.eager as tfe
tf.enable_eager_execution()

from sklearn.datasets import load_iris
data = load_iris()
X=data['data']
y=data['target']

def make_dataset(X_data,y_data,n_splits):

    def gen():
        for train_index, test_index in KFold(n_splits).split(X_data):
            X_train, X_test = X_data[train_index], X_data[test_index]
            y_train, y_test = y_data[train_index], y_data[test_index]
            yield X_train,y_train,X_test,y_test

    return tf.data.Dataset.from_generator(gen, (tf.float64,tf.float64,tf.float64,tf.float64))

dataset=make_dataset(X,y,10)

然后,可以在基于图的张量流中或使用热切的执行来遍历数据集。使用热切的执行:

for X_train,y_train,X_test,y_test in tfe.Iterator(dataset):
    ....
收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号