打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
Keras如何保存训练模型

一、保存模型

方法一:通过Checkpoint保存

在Keras中有ModelCheckpoint函数,调用该函数可以将每个epoch后的模型进行保存。详见官方文档。具体的使用方法如下:

from keras.callbacks import ModelCheckpointcheckpoint = ModelCheckpoint(filepath,                             monitor='val_loss',                             verbose=0,                             save_best_only=False,                             save_weights_only=False,                             mode='auto',                             period=1)

参数:

  • filepath: 保存模型的路径

  • monitor: 被监测的对象,比如acc, loss, val_acc,...

  • verbose: 冗余的,如果想加上进度条,就是1,如果不想,就是0

  • set_best_only = True: 只保存最好的模型

  • save_weights_only = False: 如果 True,那么只有模型的权重会被保存 model.save_weights(filepath), 否则的话,整个模型会被保存 model.save(filepath)

  • mode = 'auto': auto maxmin,如果使监测acc,就是max,如果是监测loss,就是minauto就i会从监测值中自己判断

  • period: 每个检查点之间的间隔(训练轮数)

方法二:通过save_model()保存

保存模型有两种形式,一种是将模型整个都保存下来,包括权重和结构;另一种是只保留权重。

1. 1保存整个模型

可以使用model.save(filepath)将Keras的模型保存到HDF5文件中,该文件将包含:模型结构、模型权重、配置项(优化函数、优化器)和优化状态,允许准确地从上次结束地地方继续训练,详见官方文档

from keras.model import load_modelmodel.save('my_model.h5')model = load_model('my_model.h5')

1.2 保存部分模型

1.2.1 分别保存模型的权重和结构

我们可以使用to_json()to_yaml()方法将模型结构保存到josn文件或者yaml文件中。

'''方法1:保存为json'''json_string = model.to_json()open('model_architecture_1.json', 'w').write(json_string)   #重命名#从json中读出数据from keras.models import model_from_jsonmodel = model_from_json(json_string)
'''方法2:保存为yaml'''yaml_string = model.to_yaml()open('model_arthitecture_2.yaml', 'w').write(yaml_string)  #重命名#从json中读出数据from keras.models import model_from_yamlmodel = model_from_yaml(yaml_string)

1.2.2 只保存模型权重

通过model.save_weights('my_model_weights.h5')将权重保存在HDF5文件中。如果有可以实例化模型的代码。则可以将保存的权重加载到相同结构的模型中:

model.load_weights('my_model_weights.h5')

二、导入模型

如果我们想导入训练好的最好的模型来进行预测,最好使用方法一,将最好的模型保存下来然后导入进行预测,如果想接着上一次的模型继续训练,可以两种方法都可以。

from keras.models import load_modelmodel = load_model(filepath)

保存模型可能会出错的地方:

  1. filepath不会自己建立文件夹,例如checkpointer = ModelCheckpoint(filepath='tmp\model.h5'),如果同目录下没有tmp文件夹,程序将会出错,模型无法保存。

  2. 在进行训练的时候,需要把checkpointer加进去回调函数callbacks中,并用中括号扩起来,即model.fit(x, y, callbacks=[checkpointer])

附:保存模型图

我们可以通过model.summary()将模型的结构打印出来,另外,我们可以通过plot_model(model, 'model_plot.png')将模型的基本结构框图保存下来。另外你也可以使用keras.utils.vis_utils模块将模型的详细结构框图保存下来。

from keras.utils import plot_modelplot_model(model, to_file='model.png')

参考文档:
官方文档
Keras保存最好的模型
Keras框架下的保存模型和加载模型

本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
FAQ
基于matplotlib和keras的神经网络结果可视化
机器学习之keras模型保存为pb文件
用LSTM检测垃圾邮件LSTM Spam Detection
自己动手做一个识别手写数字的web应用01
DL框架之Keras:深度学习框架Keras框架的简介、安装(Python库)、相关概念、Keras模型使用、使用方法之详细攻略
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服