使用MXnet时如何保存模型
deep-learning
mxnet
r
6
0

我正在使用MXnet训练CN(在R中),并且可以使用以下代码来训练模型而不会出现任何错误:

model <- mx.model.FeedForward.create(symbol=network,
                                     X=train.iter,
                                     ctx=mx.gpu(0),
                                     num.round=20,
                                     array.batch.size=batch.size,
                                     learning.rate=0.1,
                                     momentum=0.1,  
                                     eval.metric=mx.metric.accuracy,
                                     wd=0.001,
                                     batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
    )

但是由于此过程很耗时,因此我在夜间在服务器上运行它,并且我想保存模型以便在完成培训后使用它。

我用了:

save(list = ls(), file="mymodel.RData")

mx.model.save("mymodel", 10)

但是它们都无法保存模型!例如,当我加载"mymodel.RData" ,我无法预测测试集的标签!

另一个示例是当我加载"mymodel.RData"并尝试使用以下代码对其进行绘制时:

graph.viz(model$symbol$as.json())

我收到以下错误:

Error in model$symbol$as.json() : external pointer is not valid

有人可以给我一个保存然后加载此模型以供将来使用的解决方案吗?

谢谢

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

一个mxnet模型是一个R列表,但是它的第一个组件不是R对象而是C ++指针,并且不能保存并重新加载为R对象。因此,需要对模型进行序列化以使其表现为实际的R对象。序列化的对象也是一个列表,但其第一个对象是包含模型信息的文本字符串。

要保存模型:

modelR <- mx.serialize(model)
save(modelR, file="~/model1.RData")

要检索它并再次使用它:

load("~/model1.RData", verbose=TRUE)
model <- mx.unserialize(modelR)
收藏
评论

您可以通过以下方式保存模型

model <- mx.model.FeedForward.create(symbol=network,
                                 X=train.iter,
                                 ctx=mx.gpu(0),
                                 num.round=20,
                                 array.batch.size=batch.size,
                                 learning.rate=0.1,
                                 momentum=0.1,  
                                 eval.metric=mx.metric.accuracy,
                                 wd=0.001,
                                 epoch.end.callback=mx.callback.save.checkpoint("model_prefix")
                                 batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
)
收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号