将张量分为训练集和测试集
cross-validation
tensorflow
7
0

假设我已经使用TextLineReader读取了文本文件。有没有办法在Tensorflow其拆分为训练集和测试集?就像是:

def read_my_file_format(filename_queue):
  reader = tf.TextLineReader()
  key, record_string = reader.read(filename_queue)
  raw_features, label = tf.decode_csv(record_string)
  features = some_processing(raw_features)
  features_train, labels_train, features_test, labels_test = tf.train_split(features,
                                                                            labels,
                                                                            frac=.1)
  return features_train, labels_train, features_test, labels_test
参考资料:
Stack Overflow
收藏
评论
共 4 个回答
高赞 时间 活跃

正如elham所提到的,您可以使用scikit-learn轻松完成此操作。 scikit-learn是一个用于机器学习的开源库。有大量的数据准备工具,包括model_selection模块,该模块处理比较,验证和选择参数。

model_selection.train_test_split()方法经过专门设计,可将您的数据随机且按百分比分为训练集和测试集。

X_train, X_test, y_train, y_test = train_test_split(features,
                                                    labels,
                                                    test_size=0.33,
                                                    random_state=42)

test_size是保留用于测试的百分比, random_state是播种随机采样的种子。

我通常使用它来提供训练和验证数据集,并分别保存真实的测试数据。您也可以只运行两次train_test_split来执行此操作。即将数据分为(训练+验证)和测试,然后将训练+验证分为两个单独的张量。

收藏
评论

使用tf.data.Dataset api的map和filter函数,我设法获得了不错的结果。只需使用地图功能即可在训练和测试之间随机选择示例。为此,对于每个示例,您都可以从均匀分布中获取样本,并检查样本值是否低于速率划分。

def split_train_test(parsed_features, train_rate):
    parsed_features['is_train'] = tf.gather(tf.random_uniform([1], maxval=100, dtype=tf.int32) < tf.cast(train_rate * 100, tf.int32), 0)
    return parsed_features

def grab_train_examples(parsed_features):
    return parsed_features['is_train']

def grab_test_examples(parsed_features):
    return ~parsed_features['is_train']
收藏
评论
import sklearn.model_selection as sk

X_train, X_test, y_train, y_test = 
sk.train_test_split(features,labels,test_size=0.33, random_state = 42)
收藏
评论

像下面这样的东西应该起作用: tf.split_v(tf.random_shuffle(...

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号