(tensorflow)tf.keras.callbacks.ModelCheckpoint在训练期间保存模型
时间:4年前 阅读:11536
在训练期间保存模型
可以使用训练好的模型而无需从头开始重新训练,或在您打断的地方开始训练,以防止训练过程没有保存。
tf.keras.callbacks.ModelCheckpoint 允许在训练的过程中和结束时回调保存的模型。tf.keras.callbacks.ModelCheckpoint对象是回调对象,该回调对象可以在每一个周期保存模型
实例化方法:
tf.keras.callbacks.ModelCheckpoint( filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', save_freq='epoch', **kwargs )
Arguments:
filepath
:字符串类型,储存模型的路径monitor
: 监视指标val_acc
或者val_loss
verbose
: 日志模型, 0 或者 1.save_best_only
: 布尔型,如果为true,根据监视指标,最后一个最好的模型将不会被覆盖,如果filepath
没有包含格式化选项{epoch}
,保存的模型文件将会被新周期里更好的模型覆盖。mode
: 可选填{'auto', 'min', 'max'}
中的一个,如果save_best_only
被设置为true,那么覆盖操作的执行,将根据监视指标和这个参数来决定,比如max
val_acc
和min
val_loss
,如果被天填auto
则会跟俊监视指标指定选择max
还是min
save_weights_only
: 如果参数为真,则只保存模型的权值,而非整个模型,这个参数只要影响该回调对象是调用(model.save_weights(filepath)
方法还是调用model.save(filepath)
.save_freq
: 填写epoch
或者一个数字,如果是epoch
时,回调会在每个epoch之后保存模型。当使用integer时,回调将在处理n个样本后保存模型,默认是'epoch'
例如:
checkpoint_path = "training_1/cp.ckpt" checkpoint_dir = os.path.dirname(checkpoint_path) cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path, save_weights_only=True, verbose=1) history = model.fit(x_train, y_train, batch_size=64, epochs=10, callbacks=[cp_callback])
Train on 60000 samples Epoch 1/10 58304/60000 [============================>.] - ETA: 0s - loss: 0.3180 Epoch 00001: saving model to training_1/cp.ckpt 60000/60000 [==============================] - 9s 150us/sample - loss: 0.3145 Epoch 2/10 58432/60000 [============================>.] - ETA: 0s - loss: 0.1484 Epoch 00002: saving model to training_1/cp.ckpt 60000/60000 [==============================] - 2s 32us/sample - loss: 0.1476
版权声明:本文为期权记的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://www.qiquanji.com/post/9668.html
微信扫码关注
更新实时通知
网友评论