# Source code for horovod.keras.callbacks

# Copyright 2018 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import keras
import keras.backend as K

from horovod._keras import callbacks as _impl

"""
Keras Callback that will broadcast all global variables from root rank
to all other processes during initialization.

This is necessary to ensure consistent initialization of all workers when
training is started with random weights or restored from a checkpoint.
"""

[docs]    def __init__(self, root_rank, device=''):
"""
global variables from root rank to all other processes during initialization.

Args:
root_rank: Rank that will send data, other ranks will receive data.
device: Device to be used for broadcasting. Uses GPU by default
if Horovod was build with HOROVOD_GPU_OPERATIONS.
"""

class MetricAverageCallback(_impl.MetricAverageCallbackImpl, keras.callbacks.Callback):
    """
    Keras Callback that will average metrics across all processes at the
    end of the epoch. Useful in conjunction with ReduceLROnPlateau,
    TensorBoard and other metrics-based callbacks.

    Note: This callback must be added to the callback list before the
    ReduceLROnPlateau, TensorBoard or other metrics-based callbacks.
    """

    def __init__(self, device=''):
        """
        Construct a new MetricAverageCallback that will average metrics
        across all processes at the end of the epoch.

        Args:
            device: Device to be used for allreduce. Uses GPU by default
                    if Horovod was built with HOROVOD_GPU_OPERATIONS.
        """
        super(MetricAverageCallback, self).__init__(K, device)
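
# Ordering sketch (illustrative; not part of the original module): the metric
# averaging callback must come before metrics-based callbacks such as
# ReduceLROnPlateau so that they observe the averaged values. Assumes `hvd`
# from the sketch above:
#
#     callbacks = [
#         hvd.callbacks.BroadcastGlobalVariablesCallback(0),
#         hvd.callbacks.MetricAverageCallback(),
#         keras.callbacks.ReduceLROnPlateau(monitor='val_loss', patience=3),
#     ]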

class LearningRateScheduleCallback(_impl.LearningRateScheduleCallbackImpl, keras.callbacks.Callback):
    """
    LearningRateScheduleCallback sets learning rate between epochs start_epoch and
    end_epoch to be initial_lr * multiplier. multiplier can be a constant or
    a function f(epoch) = lr'.

    If multiplier is a function and staircase=True, learning rate adjustment will
    happen at the beginning of each epoch and the epoch passed to the multiplier
    function will be an integer.

    If multiplier is a function and staircase=False, learning rate adjustment will
    happen at the beginning of each batch and the epoch passed to the multiplier
    function will be a floating-point number: epoch' = epoch + batch / steps_per_epoch.
    This functionality is useful for smooth learning rate adjustment schedulers, such
    as LearningRateWarmupCallback.

    initial_lr is the learning rate of the model optimizer at the start of the training.
    """

    def __init__(self, initial_lr, multiplier, start_epoch=0, end_epoch=None, staircase=True,
                 momentum_correction=True, steps_per_epoch=None):
        """
        Construct a new LearningRateScheduleCallback.

        Args:
            initial_lr: Initial learning rate at the start of training.
            multiplier: A constant multiplier or a function f(epoch) = lr'
            start_epoch: The first epoch this adjustment will be applied to. Defaults to 0.
            end_epoch: The epoch this adjustment will stop applying (exclusive end).
                       Defaults to None.
            staircase: Whether to adjust learning rate at the start of epoch (staircase=True)
                       or at the start of every batch (staircase=False).
            momentum_correction: Apply momentum correction to optimizers that have momentum.
                                 Defaults to True.
            steps_per_epoch: The callback will attempt to autodetect number of batches per
                             epoch with Keras >= 2.0.0. Provide this value if you have an older
                             version of Keras.
        """
        super(LearningRateScheduleCallback, self).__init__(K, initial_lr, multiplier, start_epoch, end_epoch,
                                                           staircase, momentum_correction, steps_per_epoch)
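
# Schedule sketch (illustrative; not part of the original module): decay the
# learning rate by 10% per epoch after epoch 30, adjusted every batch for a
# smooth curve. Scaling the base rate by hvd.size() is the usual Horovod
# convention; the `initial_lr` value here is a hypothetical choice:
#
#     initial_lr = 0.001 * hvd.size()
#     decay = hvd.callbacks.LearningRateScheduleCallback(
#         initial_lr, multiplier=lambda epoch: 0.9 ** (epoch - 30),
#         start_epoch=30, staircase=False)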

class LearningRateWarmupCallback(_impl.LearningRateWarmupCallbackImpl, keras.callbacks.Callback):
    """
    Implements gradual learning rate warmup:

        lr = initial_lr / hvd.size() ---> lr = initial_lr

    initial_lr is the learning rate of the model optimizer at the start of the training.

    This technique was described in the paper "Accurate, Large Minibatch SGD: Training
    ImageNet in 1 Hour". See https://arxiv.org/pdf/1706.02677.pdf for details.

    Math recap:

    .. math::

        epoch &= full\\_epochs + \\frac{batch}{steps\\_per\\_epoch}

        lr'(epoch) &= \\frac{lr}{size} * (\\frac{size - 1}{warmup} * epoch + 1)

        lr'(epoch = 0) &= \\frac{lr}{size}

        lr'(epoch = warmup) &= lr
    """

    def __init__(self, initial_lr, warmup_epochs=5, momentum_correction=True, steps_per_epoch=None,
                 verbose=0):
        """
        Construct a new LearningRateWarmupCallback that will gradually warm up the learning rate.

        Args:
            initial_lr: Initial learning rate at the start of training.
            warmup_epochs: The number of epochs of the warmup phase. Defaults to 5.
            momentum_correction: Apply momentum correction to optimizers that have momentum.
                                 Defaults to True.
            steps_per_epoch: The callback will attempt to autodetect number of batches per
                             epoch with Keras >= 2.0.0. Provide this value if you have an older
                             version of Keras.
            verbose: verbosity mode, 0 or 1.
        """
        super(LearningRateWarmupCallback, self).__init__(K, initial_lr, warmup_epochs, momentum_correction,
                                                         steps_per_epoch, verbose)
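
# Warmup sketch (illustrative; not part of the original module): ramp from
# initial_lr / hvd.size() up to initial_lr over the first 5 epochs, then hold
# the rate constant with a unit-multiplier schedule:
#
#     callbacks = [
#         hvd.callbacks.BroadcastGlobalVariablesCallback(0),
#         hvd.callbacks.MetricAverageCallback(),
#         hvd.callbacks.LearningRateWarmupCallback(initial_lr, warmup_epochs=5, verbose=1),
#         hvd.callbacks.LearningRateScheduleCallback(initial_lr, multiplier=1., start_epoch=5),
#     ]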

class BestModelCheckpoint(keras.callbacks.ModelCheckpoint):
    # Checkpoint callback hard-wired to keep only the best model seen so far
    # (save_best_only=True). Note that filepath is initialized to None here,
    # so callers are expected to assign it before training for checkpoints
    # to be written.
    def __init__(self,
                 monitor='val_loss',
                 verbose=0,
                 mode='auto',
                 save_freq='epoch'):
        super(BestModelCheckpoint, self).__init__(filepath=None,
                                                  monitor=monitor,
                                                  verbose=verbose,
                                                  save_best_only=True,
                                                  save_weights_only=False,
                                                  mode=mode,
                                                  save_freq=save_freq)
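
# Usage sketch (illustrative; not part of the original module): with Horovod,
# checkpoint callbacks are typically added on rank 0 only so workers do not
# race writing the same file. Assigning `filepath` after construction is a
# hypothetical way to supply the path this class initializes to None:
#
#     if hvd.rank() == 0:
#         ckpt = BestModelCheckpoint(monitor='val_loss')
#         ckpt.filepath = './checkpoint-best.h5'
#         callbacks.append(ckpt)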