如何使用keras保存最终模型?
keras
machine-learning
python
9
0

我使用KerasClassifier训练分类器。

代码如下:

import numpy
from pandas import read_csv
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from keras.utils import np_utils
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
from sklearn.preprocessing import LabelEncoder
from sklearn.pipeline import Pipeline
# fix random seed for reproducibility
seed = 7
numpy.random.seed(seed)
# load dataset
dataframe = read_csv("iris.csv", header=None)
dataset = dataframe.values
X = dataset[:,0:4].astype(float)
Y = dataset[:,4]
# encode class values as integers
encoder = LabelEncoder()
encoder.fit(Y)
encoded_Y = encoder.transform(Y)
#print("encoded_Y")
#print(encoded_Y)
# convert integers to dummy variables (i.e. one hot encoded)
dummy_y = np_utils.to_categorical(encoded_Y)
#print("dummy_y")
#print(dummy_y)
# define baseline model
def baseline_model():
    # create model
    model = Sequential()
    model.add(Dense(4, input_dim=4, init='normal', activation='relu'))
    #model.add(Dense(4, init='normal', activation='relu'))
    model.add(Dense(3, init='normal', activation='softmax'))
    # Compile model
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    return model

estimator = KerasClassifier(build_fn=baseline_model, nb_epoch=200, batch_size=5, verbose=0)
#global_model = baseline_model()
kfold = KFold(n_splits=10, shuffle=True, random_state=seed)
results = cross_val_score(estimator, X, dummy_y, cv=kfold)
print("Accuracy: %.2f%% (%.2f%%)" % (results.mean()*100, results.std()*100))

但是如何保存最终模型以供将来预测?

我通常使用以下代码保存模型:

# serialize model to JSON
model_json = model.to_json()
with open("model.json", "w") as json_file:
    json_file.write(model_json)
# serialize weights to HDF5
model.save_weights("model.h5")
print("Saved model to disk")

但是我不知道如何将保存模型的代码插入KerasClassifier的代码中。

谢谢。

参考资料:
Stack Overflow
收藏
评论
共 4 个回答
高赞 时间 活跃

您可以通过这种方式保存模型并加载。

from keras.models import Sequential
from keras_contrib.losses import import crf_loss
from keras_contrib.metrics import crf_viterbi_accuracy

# To save model
model.save('my_model_01.hdf5')

# To load the model
custom_objects={'CRF': CRF,'crf_loss':crf_loss,'crf_viterbi_accuracy':crf_viterbi_accuracy}

# To load a persisted model that uses the CRF layer 
model1 = load_model("/home/abc/my_model_01.hdf5", custom_objects = custom_objects)
收藏
评论

您可以将模型保存为json,并将权重保存为hdf5文件格式。

# keras library import  for Saving and loading model and weights

from keras.models import model_from_json
from keras.models import load_model

# serialize model to JSON
#  the keras model which is trained is defined as 'model' in this example
model_json = model.to_json()


with open("model_num.json", "w") as json_file:
    json_file.write(model_json)

# serialize weights to HDF5
model.save_weights("model_num.h5")

创建包含模型和权重的文件“ model_num.h5”和“ model_num.json”

要使用相同的训练模型进行进一步测试,您只需加载hdf5文件并将其用于预测不同的数据。这是从保存的文件中加载模型的方法。

# load json and create model
json_file = open('model_num.json', 'r')

loaded_model_json = json_file.read()
json_file.close()
loaded_model = model_from_json(loaded_model_json)

# load weights into new model
loaded_model.load_weights("model_num.h5")
print("Loaded model from disk")

loaded_model.save('model_num.hdf5')
loaded_model=load_model('model_num.hdf5')

要预测不同的数据,您可以使用此

loaded_model.predict_classes("your_test_data here")
收藏
评论

该模型具有save方法,该方法保存重构模型所需的所有详细信息。来自keras文档的示例:

from keras.models import load_model

model.save('my_model.h5')  # creates a HDF5 file 'my_model.h5'
del model  # deletes the existing model

# returns a compiled model
# identical to the previous one
model = load_model('my_model.h5')
收藏
评论

您可以使用model.save(filepath)model.save(filepath)模型保存到单个HDF5文件中,该文件包含:

  • 模型的架构,从而可以重新创建模型。
  • 模型的权重。
  • 训练配置(损失,优化器)
  • 优化程序的状态,从而可以从您上次中断的地方继续正确地进行训练。

在您的Python代码中,最后一行可能是:

model.save("m.hdf5")

这使您可以将模型的整个状态保存在单个文件中。可以通过keras.models.load_model()重新实例化保存的模型。

load_model()返回的模型是已准备好可以使用的已编译模型(除非保存的模型从未被首先编译过)。

model.save()参数:

  • filepath:字符串,要将权重保存到的文件的路径。
  • 覆盖:是在目标位置静默覆盖任何现有文件,还是向用户提供手动提示。
  • include_optimizer:如果为True,则将优化器的状态保存在一起。
收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号