keras:如何保存历史对象的训练历史属性
deep-learning
keras
machine-learning
neural-network
5
0

在Keras中,我们可以将model.fit的输出返回到历史记录,如下所示:

 history = model.fit(X_train, y_train, 
                     batch_size=batch_size, 
                     nb_epoch=nb_epoch,
                     validation_data=(X_test, y_test))

现在,如何将历史记录对象的历史记录属性保存到文件中以供进一步使用(例如,绘制acc或针对历时的损失图)?

参考资料:
Stack Overflow
收藏
评论
共 5 个回答
高赞 时间 活跃

另一种方法是:

由于history.historydict ,因此您也可以将其转换为pandas DataFrame对象,然后可以将其保存以满足您的需求。

一步步:

import pandas as pd

# assuming you stored your model.fit results in a 'history' variable:
history = model.fit(x_train, y_train, epochs=10)

# convert the history.history dict to a pandas DataFrame:     
hist_df = pd.DataFrame(history.history) 

# save to json:  
hist_json_file = 'history.json' 
with open(hist_json_file, mode='w') as f:
    hist_df.to_json(f)

# or save to csv: 
hist_csv_file = 'history.csv'
with open(hist_csv_file, mode='w') as f:
    hist_df.to_csv(f)
收藏
评论

可以将model历史记录保存到文件中,如下所示

import json
hist = model.fit(X_train, y_train, epochs=5, batch_size=batch_size,validation_split=0.1)
with open('file.json', 'w') as f:
    json.dump(hist.history, f)
收藏
评论

history对象具有一个history字段,该history字段是一个字典,其中包含跨每个训练时期的不同训练指标。因此,例如, history.history['loss'][99]将在训练的第100个时期返回模型损失。为了节省您可以pickle该词典或简单地将该词典中的其他列表保存到适当的文件中。

收藏
评论

我使用的是以下内容:

    with open('/trainHistoryDict', 'wb') as file_pi:
        pickle.dump(history.history, file_pi)

这样,我便将历史记录另存为字典,以备日后绘制损失或准确性时使用。

收藏
评论

我遇到了一个问题,即在keras中列表内的值不是json可序列化的。因此,出于使用原因,我编写了这两个方便的函数。

import json,codecs
import numpy as np
def saveHist(path,history):

    new_hist = {}
    for key in list(history.history.keys()):
        if type(history.history[key]) == np.ndarray:
            new_hist[key] = history.history[key].tolist()
        elif type(history.history[key]) == list:
           if  type(history.history[key][0]) == np.float64:
               new_hist[key] = list(map(float, history.history[key]))

    print(new_hist)
    with codecs.open(path, 'w', encoding='utf-8') as file:
        json.dump(new_hist, file, separators=(',', ':'), sort_keys=True, indent=4) 

def loadHist(path):
    with codecs.open(path, 'r', encoding='utf-8') as file:
        n = json.loads(file.read())
    return n

其中saveHist仅需要获取应保存json文件的路径,以及从keras fitfit_generator方法返回的历史对象。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号