# Copyright 2020 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections
import io
from collections.abc import Iterable
import cloudpickle
import torch
from horovod.torch.mpi_ops import allgather, broadcast_, broadcast_async_
from horovod.torch.mpi_ops import synchronize
from horovod.torch.mpi_ops import rank, size
def broadcast_parameters(params, root_rank):
    """
    Broadcasts the parameters from root rank to all other processes.
    Typical usage is to broadcast the ``model.state_dict()``,
    ``model.named_parameters()``, or ``model.parameters()``.

    Arguments:
        params: One of the following:
            - list of parameters to broadcast
            - dict of parameters to broadcast
        root_rank: The rank of the process from which parameters will be
            broadcasted to all other processes.
    """
    # Normalize the input into an ordered list of (name, tensor) pairs.
    if isinstance(params, dict):
        named_params = sorted(params.items())
    elif isinstance(params, list):
        # support both named_parameters() and regular parameters()
        named_params = [entry if isinstance(entry, tuple) else (None, entry)
                        for entry in params]
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Launch one asynchronous broadcast per tensor, then wait for all of
    # them so every rank returns with identical parameter values.
    pending = [broadcast_async_(tensor, root_rank, name)
               for name, tensor in named_params]
    for handle in pending:
        synchronize(handle)
def broadcast_optimizer_state(optimizer, root_rank):
    """
    Broadcasts an optimizer state from root rank to all other processes.

    Arguments:
        optimizer: An optimizer.
        root_rank: The rank of the process from which the optimizer will be
            broadcasted to all other processes.
    """
    from horovod.torch.optimizer import DistributedOptimizer

    if isinstance(optimizer, torch.optim.LBFGS):
        # TODO(travis): L-BFGS cannot be easily supported without serializing
        # the entire state_dict, as its structure is deeply nested and contains
        # None type parameter values
        raise ValueError('cannot broadcast torch.optim.LBFGS state')

    state_dict = optimizer.state_dict()

    # Newly created optimizers will not have their state initialized, so
    # do that initialization here: materialize zero gradients for every
    # trainable param and take a single step() so the optimizer allocates
    # its per-parameter state.
    if len(state_dict['state']) == 0:
        for group in optimizer.param_groups:
            for p in group['params']:
                if p.requires_grad and id(p) not in state_dict['state']:
                    # Zero gradient of matching shape so step() is a no-op
                    # numerically but still creates state entries.
                    p.grad = p.data.new(p.size()).zero_()
                    # SparseAdam requires sparse gradients.
                    if isinstance(optimizer, torch.optim.SparseAdam):
                        p.grad = p.grad.to_sparse()

        # This function accepts a torch.optim.Optimizer or a DistributedOptimizer
        # wrapped around a torch optimizer. Calling step() with a DistributedOptimizer
        # forces allreduce on all model parameters, which will result in deadlock
        # unless every rank calls step(). Therefore, to finish state initialization
        # only call optimizer.step() with a torch.optim.Optimizer.
        if optimizer.__module__ == DistributedOptimizer.__module__:
            super(optimizer.__class__, optimizer).step()
        else:
            optimizer.step()
        state_dict = optimizer.state_dict()

    # If the state_dict is still empty after initialization, then
    # the optimizer is stateless, and there is nothing to broadcast.
    # Furthermore, attempting to access the state dict would result in
    # an error.
    if len(state_dict['state']) == 0:
        return

    params = []        # (key, tensor) pairs -> broadcast_parameters
    scalars = {}       # non-tensor state -> pickled via broadcast_object
    callbacks = {}     # key -> closure writing the broadcast scalar back
    occurrences = collections.defaultdict(int)

    # Returns the full type structure of the possibly nested objects for
    # recursive casting back.
    # NOTE(review): _get_types and _recursive_cast are not referenced anywhere
    # in this function; they look like leftovers — confirm before removing.
    def _get_types(x):
        if isinstance(x, Iterable):
            return type(x), [_get_types(xi) for xi in x]
        else:
            return type(x)

    # Casts an object encoded in a tensor back into its original type and
    # subtypes
    def _recursive_cast(x, dtype):
        if isinstance(dtype, tuple):
            t, dtypes = dtype
            x = t(x)
            return t([_recursive_cast(x[i], dtypes[i]) for i in range(len(x))])
        else:
            return dtype(x)

    # Some optimizer parameters may be represented as scalars instead of
    # tensors. In such cases, we place the scalars into a single dict,
    # then pickle and broadcast with broadcast_object (under the assumption
    # that there are not many scalars, and so the overhead of pickling will
    # be relatively low). Because broadcast_object is performed out-of-place,
    # we then use a callback to assign the new value to the correct element
    # of the optimizer state.
    def _create_state_callback(pid, name):
        def _assign_state(v):
            state_dict['state'][pid][name] = v
        return _assign_state

    def _create_option_callback(index, option_key):
        def _assign_option(v):
            optimizer.param_groups[index][option_key] = v
        return _assign_option

    # Param groups are an ordered list, normally there is only one per model,
    # but users can add additional param groups for example to train
    # previously frozen layers
    for index, group in enumerate(state_dict['param_groups']):
        # Broadcast options like learning rate
        for option_key, option_value in group.items():
            if option_key == 'params':
                continue
            # Options like the learning rate are scalar, and need to be
            # broadcast separately
            key = '%s.%d' % (option_key, index)
            scalars[key] = option_value
            callbacks[key] = _create_option_callback(index, option_key)

        # The params list here is ordered by the layers in the model
        for pid in group['params']:
            if pid not in state_dict['state']:
                # The param has not set requires_grad, so skip broadcast
                continue
            param_state = state_dict['state'][pid]
            for name, p in param_state.items():
                # Some parameter names may appear more than once, in which
                # case we ensure they have a unique identifier defined by
                # their order
                occurrences[name] += 1
                key = '%s.%d' % (str(name), occurrences[name])
                if torch.is_tensor(p):
                    # Tensor -> use broadcast_parameters
                    params.append((key, p))
                else:
                    # Scalar -> use broadcast_object
                    scalars[key] = p
                    callbacks[key] = _create_state_callback(pid, name)

    # Synchronized broadcast of all tensor parameters
    broadcast_parameters(params, root_rank)

    # Broadcast and cleanup for non-tensor parameters
    scalars = broadcast_object(scalars, root_rank)
    for key, p in scalars.items():
        callbacks[key](p)
def broadcast_object(obj, root_rank=0, name=None):
    """
    Serializes and broadcasts an object from root rank to all other processes.
    Typical usage is to broadcast the `optimizer.state_dict()`, for example:

    .. code-block:: python

        state_dict = broadcast_object(optimizer.state_dict(), 0)
        if hvd.rank() > 0:
            optimizer.load_state_dict(state_dict)

    Arguments:
        obj: An object capable of being serialized without losing any context.
        root_rank: The rank of the process from which parameters will be
            broadcasted to all other processes.
        name: Optional name to use during broadcast, will default to the class
            type.
    Returns:
        The object that was broadcast from the `root_rank`.
    """
    if name is None:
        name = type(obj).__name__

    # Agree on the payload size first: the root serializes the object and
    # broadcasts its byte count; other ranks start from a zero placeholder.
    if rank() == root_rank:
        stream = io.BytesIO()
        cloudpickle.dump(obj, stream)
        t = torch.ByteTensor(bytearray(stream.getvalue()))
        sz = torch.IntTensor([t.shape[0]])
    else:
        sz = torch.IntTensor([0])
    broadcast_(sz, root_rank, name + '.sz')

    # Non-root ranks allocate a receive buffer of the agreed size, then all
    # ranks participate in the payload broadcast.
    if rank() != root_rank:
        t = torch.ByteTensor(sz.tolist()[0])
    broadcast_(t, root_rank, name + '.t')

    # Non-root ranks deserialize; the root returns its own object unchanged.
    if rank() != root_rank:
        obj = cloudpickle.load(io.BytesIO(t.numpy().tobytes()))
    return obj
def allgather_object(obj, name=None):
    """
    Serializes and allgathers an object from all other processes.

    Arguments:
        obj: An object capable of being serialized without losing any context.
        name: Optional name to use during allgather, will default to the class
            type.
    Returns:
        The list of objects that were allgathered across all ranks.
    """
    if name is None:
        name = type(obj).__name__

    # Serialize the local object into a flat byte tensor plus its length.
    out = io.BytesIO()
    cloudpickle.dump(obj, out)
    t = torch.ByteTensor(bytearray(out.getvalue()))
    sz = torch.IntTensor([t.shape[0]])

    # Gather every rank's payload length, then the concatenated payloads.
    sizes = allgather(sz, name=name + '.sz').numpy()
    gathered = allgather(t, name=name + '.t').numpy()

    def _deserialize(segment):
        return cloudpickle.load(io.BytesIO(segment.tobytes()))

    # Walk the concatenated buffer with a running offset, slicing out and
    # unpickling each rank's contribution in rank order.
    results = []
    offset = 0
    for i in range(size()):
        length = sizes[i]
        results.append(_deserialize(gathered[offset:offset + length]))
        offset += length
    return results