(tensorflow)tf.keras.callbacks.ModelCheckpoint在训练期间保存模型

时间:4年前   阅读:11536

在训练期间保存模型 

可以使用训练好的模型而无需从头开始重新训练,或在您打断的地方开始训练,以防止训练过程没有保存。

 tf.keras.callbacks.ModelCheckpoint 允许在训练的过程中和结束时回调保存的模型。tf.keras.callbacks.ModelCheckpoint对象是回调对象,该回调对象可以在每一个周期保存模型 

实例化方法:

tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch', **kwargs
)

Arguments:

  • filepath:字符串类型,储存模型的路径

  • monitor: 监视指标val_acc或者val_loss

  • verbose: 日志模型, 0 或者 1.

  • save_best_only: 布尔型,如果为true,根据监视指标,最后一个最好的模型将不会被覆盖,如果filepath 没有包含格式化选项{epoch} ,保存的模型文件将会被新周期里更好的模型覆盖。

  • mode: 可选填{'auto', 'min', 'max'}中的一个,如果save_best_only被设置为true,那么覆盖操作的执行,将根据监视指标和这个参数来决定,比如max val_acc 和min val_loss ,如果被天填auto则会跟俊监视指标指定选择max 还是min

  • save_weights_only: 如果参数为真,则只保存模型的权值,而非整个模型,这个参数只要影响该回调对象是调用(model.save_weights(filepath)方法还是调用model.save(filepath).

  • save_freq: 填写epoch或者一个数字,如果是epoch时,回调会在每个epoch之后保存模型。当使用integer时,回调将在处理n个样本后保存模型,默认是'epoch'

例如:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
 
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)
history = model.fit(x_train, y_train,
                    batch_size=64,
                    epochs=10, callbacks=[cp_callback])
Train on 60000 samples
Epoch 1/10
58304/60000 [============================>.] - ETA: 0s - loss: 0.3180
Epoch 00001: saving model to training_1/cp.ckpt
60000/60000 [==============================] - 9s 150us/sample - loss: 0.3145
Epoch 2/10
58432/60000 [============================>.] - ETA: 0s - loss: 0.1484
Epoch 00002: saving model to training_1/cp.ckpt
60000/60000 [==============================] - 2s 32us/sample - loss: 0.1476


版权声明:本文为期权记的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。

原文链接:https://www.qiquanji.com/post/9668.html

微信扫码关注

更新实时通知

上一篇:tensorflow(tf)keras save h5、hdf5模型,loadmodel现AttributeError: ‘str‘ object has no attribute ‘decode‘

下一篇:TensorFlow创建自定义类继承tf.layers.Layer创建新的layer层,自定义类继承keras.Model创建自定义model

网友评论

请先 登录 再评论,若不是会员请先 注册