Source code for horovod.torch

# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
# Modifications copyright Microsoft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from contextlib import contextmanager

import io
import warnings

import cloudpickle

from horovod.common.util import check_extension

try:
    check_extension('horovod.torch', 'HOROVOD_WITH_PYTORCH',
                    __file__, 'mpi_lib_v2')
except:
    check_extension('horovod.torch', 'HOROVOD_WITH_PYTORCH',
                    __file__, 'mpi_lib', '_mpi_lib')

try:
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable


from horovod.torch.compression import Compression
from horovod.torch.mpi_ops import allreduce, allreduce_async, allreduce_, allreduce_async_
from horovod.torch.mpi_ops import allgather, allgather_async
from horovod.torch.mpi_ops import broadcast, broadcast_async, broadcast_, broadcast_async_
from horovod.torch.mpi_ops import join
from horovod.torch.mpi_ops import poll, synchronize
from horovod.torch.mpi_ops import init, shutdown
from horovod.torch.mpi_ops import size, local_size, rank, local_rank
from horovod.torch.mpi_ops import mpi_threads_supported, mpi_enabled, mpi_built
from horovod.torch.mpi_ops import gloo_enabled, gloo_built
from horovod.torch.mpi_ops import nccl_built, ddl_built, ccl_built
from horovod.torch.mpi_ops import Average, Sum, Adasum
from horovod.torch.sync_batch_norm import SyncBatchNorm

import torch
import collections


# Please run this function in a subprocess
def _check_has_gpu():
    import torch
    return torch.cuda.is_available()
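
# A minimal sketch (not part of the original module) of how the check above
# might be invoked from a subprocess, as its comment advises, so that CUDA is
# not initialized in the parent process as a side effect. The helper name
# `_check_has_gpu_in_subprocess` is illustrative only.
def _check_has_gpu_in_subprocess():
    import multiprocessing

    # 'spawn' starts a fresh interpreter, keeping CUDA state out of this process.
    ctx = multiprocessing.get_context('spawn')
    with ctx.Pool(1) as pool:
        return pool.apply(_check_has_gpu)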


class _DistributedOptimizer(torch.optim.Optimizer):
    def __init__(self, params, named_parameters, compression,
                 backward_passes_per_step=1, op=Average):
        super(self.__class__, self).__init__(params)
        self._compression = compression

        if named_parameters is not None:
            named_parameters = list(named_parameters)
        else:
            named_parameters = [('allreduce.noname.%s' % i, v)
                                for param_group in self.param_groups
                                for i, v in enumerate(param_group['params'])]
        # make sure that named_parameters are tuples
        if any([not isinstance(p, tuple) for p in named_parameters]):
            raise ValueError('named_parameters should be a sequence of '
                             'tuples (name, parameter), usually produced by '
                             'model.named_parameters().')

        dups = _DistributedOptimizer.find_duplicates([k for k, _ in named_parameters])
        if len(dups) > 0:
            raise ValueError('Parameter names in named_parameters must be unique. '
                             'Found duplicates: %s' % ', '.join(dups))

        all_param_ids = {id(v)
                         for param_group in self.param_groups
                         for v in param_group['params']}
        named_param_ids = {id(v) for k, v in named_parameters}
        unnamed_param_ids = all_param_ids - named_param_ids
        if len(unnamed_param_ids):
            raise ValueError('named_parameters was specified, but one or more model '
                             'parameters were not named. Python object ids: '
                             '%s' % ', '.join(str(id) for id in unnamed_param_ids))

        self._parameter_names = {v: k for k, v in sorted(named_parameters)}
        self.backward_passes_per_step = backward_passes_per_step
        self._allreduce_delay = {v: self.backward_passes_per_step
                                 for _, v in sorted(named_parameters)}
        self.op = op
        self._handles = {}
        self._grad_accs = []
        self._requires_update = set()
        self._synchronized = False
        self._should_synchronize = True
        if size() > 1:
            self._register_hooks()

    @staticmethod
    def find_duplicates(lst):
        seen = set()
        dups = set()
        for el in lst:
            if el in seen:
                dups.add(el)
            seen.add(el)
        return dups

    def set_backward_passes_per_step(self, passes):
        self.backward_passes_per_step = passes
        for p in self._allreduce_delay:
            self._allreduce_delay[p] = self.backward_passes_per_step

    def _register_hooks(self):
        for param_group in self.param_groups:
            for p in param_group['params']:
                if p.requires_grad:
                    p.grad = p.data.new(p.size()).zero_()
                    self._requires_update.add(p)
                    p_tmp = p.expand_as(p)
                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_hook(p))
                    self._grad_accs.append(grad_acc)

    def _allreduce_grad_async(self, p):
        name = self._parameter_names.get(p)
        tensor = p.grad
        tensor_compressed, ctx = self._compression.compress(tensor)

        handle = allreduce_async_(tensor_compressed, name=name, op=self.op)
        return handle, ctx

    def _make_hook(self, p):
        def hook(*ignore):
            if p in self._handles and self._handles[p][0] is not None:
                if self._allreduce_delay[p] <= 0:
                    raise AssertionError(
                        "Gradients were computed more than "
                        "backward_passes_per_step times before call "
                        "to step(). Increase backward_passes_per_step to "
                        "accumulate gradients locally.")
            assert not p.grad.requires_grad
            assert self._allreduce_delay[p] > 0
            handle, ctx = None, None
            self._allreduce_delay[p] -= 1
            if self._allreduce_delay[p] == 0:
                handle, ctx = self._allreduce_grad_async(p)
            self._handles[p] = (handle, ctx)
        return hook

    def synchronize(self):
        missing_p = self._requires_update - set(self._handles.keys())
        for p in missing_p:
            handle, ctx = self._allreduce_grad_async(p)
            self._handles[p] = (handle, ctx)

        for p, value in self._handles.items():
            handle, ctx = value
            if handle is None:
                handle, ctx = self._allreduce_grad_async(p)
                self._handles[p] = (handle, ctx)
        for p, (handle, ctx) in self._handles.items():
            output = synchronize(handle)
            self._allreduce_delay[p] = self.backward_passes_per_step
            p.grad.set_(self._compression.decompress(output, ctx))
        self._handles.clear()

        self._synchronized = True

    @contextmanager
    def skip_synchronize(self):
        """
        A context manager used to specify that optimizer.step() should
        not perform synchronization.

        It's typically used in the following pattern:

        .. code-block:: python

            optimizer.synchronize()
            with optimizer.skip_synchronize():
                optimizer.step()
        """
        self._should_synchronize = False
        try:
            yield
        finally:
            self._should_synchronize = True

    def step(self, closure=None):
        if self._should_synchronize:
            if self._synchronized:
                warnings.warn("optimizer.step() called without "
                              "optimizer.skip_synchronize() context after "
                              "optimizer.synchronize(). This can cause training "
                              "slowdown. You may want to consider using "
                              "optimizer.skip_synchronize() context if you use "
                              "optimizer.synchronize() in your code.")
            self.synchronize()
        self._synchronized = False
        return super(self.__class__, self).step(closure)

    def zero_grad(self):
        if self._handles:
            raise AssertionError("optimizer.zero_grad() was called after loss.backward() "
                                 "but before optimizer.step() or optimizer.synchronize(). "
                                 "This is prohibited as it can cause a race condition.")
        return super(self.__class__, self).zero_grad()
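
# A hedged usage sketch (ours, not from the original source) of local gradient
# accumulation with `backward_passes_per_step`, as described by the hook error
# message above: gradients from several backward passes accumulate locally and
# are allreduced only on the final pass before `step()`. The names `model`,
# `data_loader`, `loss_fn`, and `base_optimizer` are placeholders.
def _example_gradient_accumulation(model, data_loader, loss_fn, base_optimizer,
                                   passes_per_step=4):
    optimizer = DistributedOptimizer(
        base_optimizer,
        named_parameters=model.named_parameters(),
        backward_passes_per_step=passes_per_step)
    optimizer.zero_grad()
    for i, (data, target) in enumerate(data_loader):
        loss_fn(model(data), target).backward()
        if (i + 1) % passes_per_step == 0:
            optimizer.step()       # the allreduce of accumulated gradients completes here
            optimizer.zero_grad()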


class _DistributedAdasumOptimizer(torch.optim.Optimizer):
    def __init__(self, params, named_parameters, compression,
                 backward_passes_per_step=1):
        super(self.__class__, self).__init__(params)

        self._compression = compression

        if named_parameters is not None:
            named_parameters = list(named_parameters)
        else:
            named_parameters = [('allreduce.noname.%s' % i, v)
                                for param_group in self.param_groups
                                for i, v in enumerate(param_group['params'])]

        # make sure that named_parameters are tuples
        if any([not isinstance(p, tuple) for p in named_parameters]):
            raise ValueError('named_parameters should be a sequence of '
                             'tuples (name, parameter), usually produced by '
                             'model.named_parameters().')

        dups = _DistributedOptimizer.find_duplicates([k for k, _ in named_parameters])
        if len(dups) > 0:
            raise ValueError('Parameter names in named_parameters must be unique. '
                             'Found duplicates: %s' % ', '.join(dups))

        all_param_ids = {id(v)
                         for param_group in self.param_groups
                         for v in param_group['params']}
        named_param_ids = {id(v) for k, v in named_parameters}
        unnamed_param_ids = all_param_ids - named_param_ids
        if len(unnamed_param_ids):
            raise ValueError('named_parameters was specified, but one or more model '
                             'parameters were not named. Python object ids: '
                             '%s' % ', '.join(str(id) for id in unnamed_param_ids))

        self._parameter_names = {v: k for k, v in sorted(named_parameters)}
        self.backward_passes_per_step = backward_passes_per_step
        self._allreduce_delay = {v: self.backward_passes_per_step
                                 for _, v in sorted(named_parameters)}
        self._handles = {}
        self._grad_accs = []
        self._requires_update = set()
        self._synchronized = False
        self._should_synchronize = True

        self._starting_models = {
            p : torch.zeros_like(p, requires_grad=False)
            for _, p in named_parameters
        }

        self._register_hooks()

    def set_backward_passes_per_step(self, passes):
        self.backward_passes_per_step = passes
        for p in self._allreduce_delay:
            self._allreduce_delay[p] = self.backward_passes_per_step

    def _register_hooks(self):
        for param_group in self.param_groups:
            for p in param_group['params']:
                if p.requires_grad:
                    p.grad = p.data.new(p.size()).zero_()
                    self._requires_update.add(p)
                    p_tmp = p.expand_as(p)
                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_hook(p))
                    self._grad_accs.append(grad_acc)

    def _allreduce_grad_async(self, p):
        # Delta optimizer implements this logic:
        #  start = current.copy()
        #  step() -> computes 'current - \alpha.f(g)' where f is
        #            optimizer logic and g is the gradient
        #  delta = current-start
        #  allreduce_(delta)
        #  start += delta
        #  current = start
        # In order to support this logic using a function hook to improve performance,
        # we do:
        # delta = (start - \alpha.f(g)) - start
        #       = -\alpha.f(g)
        # set start to zero and step computes -\alpha.f(g)
        # where f is the underlying optimizer logic

        name = self._parameter_names.get(p)
        start = self._starting_models[p]
    
        stashed_params = []
        for group in self.param_groups:
            stashed_params.append(group['params'])
            # only want to step on p
            if any([p is v for v in group['params']]):
                group['params'] = [p]
            else:
                group['params'] = []

        start.data.copy_(p)

        super(self.__class__, self).step()

        # compute delta = curr - start
        p.data.sub_(start)
        
        # allreduce as before
        tensor_compressed, ctx = self._compression.compress(p)
        handle = allreduce_async_(tensor_compressed.data, name=name, op=Adasum)

        # reset stashed parameters
        for stashed, group in zip(stashed_params, self.param_groups):
            group['params'] = stashed        

        return handle, ctx

    def _make_hook(self, p):
        def hook(*ignore):
            if p in self._handles and self._handles[p][0] is not None:
                if self._allreduce_delay[p] <= 0:
                    raise AssertionError(
                        "Gradients were computed more than "
                        "backward_passes_per_step times before call "
                        "to step(). Increase backward_passes_per_step to "
                        "accumulate gradients locally.")
            assert not p.grad.requires_grad
            assert self._allreduce_delay[p] > 0
            handle, ctx = None, None
            self._allreduce_delay[p] -= 1
            if self._allreduce_delay[p] == 0:
                handle, ctx = self._allreduce_grad_async(p)
            self._handles[p] = (handle, ctx)
        return hook

    def synchronize(self):
        pass

    @contextmanager
    def skip_synchronize(self):
        raise AssertionError("Skipping synchronization is not supported when using Adasum optimizer.")

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        missing_p = self._requires_update - set(self._handles.keys())
        for p in missing_p:
            handle, ctx = self._allreduce_grad_async(p)
            self._handles[p] = (handle, ctx)

        for p, (handle, ctx) in self._handles.items():
            # This means step() was called before backward_passes_per_step
            # backward passes finished. We do a synchronous allreduce here.
            if not handle:
                handle, ctx = self._allreduce_grad_async(p)
                self._handles[p] = (handle, ctx)
            delta = synchronize(handle)
            delta = self._compression.decompress(delta, ctx)            
            start = self._starting_models[p]
            start.data.add_(delta.data)
            p.data.copy_(start)
            self._allreduce_delay[p] = self.backward_passes_per_step
        self._handles.clear()
        return loss

    def zero_grad(self):
        if self._handles:
            raise AssertionError("optimizer.zero_grad() was called after loss.backward() "
                                 "but before optimizer.step() or optimizer.synchronize(). "
                                 "This is prohibited as it can cause a race condition.")
        return super(self.__class__, self).zero_grad()
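
# A small self-contained sketch (ours, not from the original source) of the
# "delta" trick documented in _DistributedAdasumOptimizer._allreduce_grad_async
# above: stash the parameter, let the wrapped optimizer take its local step,
# and read off delta = new_value - stashed_value, which for plain SGD equals
# -lr * grad.
def _example_adasum_delta_trick():
    p = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
    opt = torch.optim.SGD([p], lr=0.1)
    p.grad = torch.tensor([0.5, -0.5])
    start = p.detach().clone()   # stash the starting model
    opt.step()                   # p becomes start - lr * grad
    delta = p.detach() - start   # equals -lr * grad == [-0.05, 0.05]
    return delta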


def DistributedOptimizer(optimizer, named_parameters=None,
                         compression=Compression.none,
                         backward_passes_per_step=1,
                         op=Average):
    """
    An optimizer that wraps another torch.optim.Optimizer, using an allreduce to
    combine gradient values before applying gradients to model weights.

    Allreduce operations are executed after each gradient is computed by
    ``loss.backward()`` in parallel with each other. The ``step()`` method ensures
    that all allreduce operations are finished before applying gradients to the model.

    DistributedOptimizer exposes the ``synchronize()`` method, which forces
    allreduce operations to finish before continuing the execution. It's useful
    in conjunction with gradient clipping, or other operations that modify
    gradients in place before ``step()`` is executed. Make sure to use
    ``optimizer.skip_synchronize()`` if you're calling ``synchronize()`` in your code.

    Example of gradient clipping:

    .. code-block:: python

        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.synchronize()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        with optimizer.skip_synchronize():
            optimizer.step()

    Arguments:
        optimizer: Optimizer to use for computing gradients and applying updates.
        named_parameters: A mapping between parameter names and values. Used for naming of
                          allreduce operations. Typically just ``model.named_parameters()``.
        compression: Compression algorithm used during allreduce to reduce the amount
                     of data sent during each parameter update step. Defaults to
                     not using compression.
        backward_passes_per_step: Number of expected backward passes to perform
                                  before calling step()/synchronize(). This
                                  allows accumulating gradients over multiple
                                  mini-batches before reducing and applying them.
        op: The reduction operation to use when combining gradients across different ranks.
    """
    # We dynamically create a new class that inherits from the optimizer that was passed in.
    # The goal is to override the `step()` method with an allreduce implementation.
    if op != Adasum or size() == 1:
        cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
                   dict(_DistributedOptimizer.__dict__))
        return cls(optimizer.param_groups, named_parameters, compression,
                   backward_passes_per_step, op)
    else:
        cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
                   dict(_DistributedAdasumOptimizer.__dict__))
        return cls(optimizer.param_groups, named_parameters, compression,
                   backward_passes_per_step)
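
# A minimal illustration (ours, not part of the original source) of the dynamic
# subclassing pattern used above: ``type()`` builds a class whose base is the
# wrapped optimizer's class and whose namespace is copied from a mixin, so the
# mixin's ``step()`` override runs while the rest of the behaviour is inherited.
# The mixin name `_LoggingStepMixin` and this helper are illustrative only.
def _example_dynamic_subclass(base_optimizer):
    class _LoggingStepMixin(torch.optim.Optimizer):
        def step(self, closure=None):
            print('running before the wrapped %s.step()' % self.__class__.__name__)
            return super(self.__class__, self).step(closure)

    cls = type(base_optimizer.__class__.__name__,
               (base_optimizer.__class__,),
               dict(_LoggingStepMixin.__dict__))
    # Rebuilding from the existing param_groups preserves the hyperparameters
    # (lr, momentum, ...) already stored in each group, mirroring the call above.
    return cls(base_optimizer.param_groups)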

def broadcast_parameters(params, root_rank):
    """
    Broadcasts the parameters from root rank to all other processes.
    Typical usage is to broadcast the ``model.state_dict()``,
    ``model.named_parameters()``, or ``model.parameters()``.

    Arguments:
        params: One of the following:
            - list of parameters to broadcast
            - dict of parameters to broadcast
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.
    """
    if isinstance(params, dict):
        params = sorted(params.items())
    elif isinstance(params, list):
        # support both named_parameters() and regular parameters()
        params = [p if isinstance(p, tuple) else (None, p) for p in params]
    else:
        raise ValueError('invalid params of type: %s' % type(params))

    # Run asynchronous broadcasts.
    handles = []
    for name, p in params:
        handle = broadcast_async_(p, root_rank, name)
        handles.append(handle)

    # Wait for completion.
    for handle in handles:
        synchronize(handle)

def broadcast_optimizer_state(optimizer, root_rank):
    """
    Broadcasts an optimizer state from root rank to all other processes.

    Arguments:
        optimizer: An optimizer.
        root_rank: The rank of the process from which the optimizer will be
                   broadcasted to all other processes.
    """
    if isinstance(optimizer, torch.optim.LBFGS):
        # TODO(travis): L-BFGS cannot be easily supported without serializing
        # the entire state_dict, as its structure is deeply nested and contains
        # None type parameter values
        raise ValueError('cannot broadcast torch.optim.LBFGS state')

    state_dict = optimizer.state_dict()

    # Newly created optimizers will not have their state initialized, so
    # do that initialization here
    if len(state_dict['state']) == 0:
        for group in optimizer.param_groups:
            for p in group['params']:
                if p.requires_grad and id(p) not in state_dict['state']:
                    p.grad = p.data.new(p.size()).zero_()
        # This function accepts a torch.optim.Optimizer or a DistributedOptimizer
        # wrapped around a torch optimizer. Calling step() with a DistributedOptimizer
        # forces allreduce on all model parameters, which will result in deadlock
        # unless every rank calls step(). Therefore, to finish state initialization
        # only call optimizer.step() with a torch.optim.Optimizer.
        if optimizer.__module__ == DistributedOptimizer.__module__:
            super(optimizer.__class__, optimizer).step()
        else:
            optimizer.step()
        state_dict = optimizer.state_dict()

    # If the state_dict is still empty after initialization, then
    # the optimizer is stateless, and there is nothing to broadcast.
    # Furthermore, attempting to access the state dict would result in
    # an error.
    if len(state_dict['state']) == 0:
        return

    params = []
    callbacks = {}
    occurrences = collections.defaultdict(int)

    # Returns the full type structure of the possibly nested objects for recursive casting back
    def _get_types(x):
        if isinstance(x, Iterable):
            return type(x), [_get_types(xi) for xi in x]
        else:
            return type(x)

    # Casts an object encoded in a tensor back into its original type and subtypes
    def _recursive_cast(x, dtype):
        if isinstance(dtype, tuple):
            t, dtypes = dtype
            x = t(x)
            return t([_recursive_cast(x[i], dtypes[i]) for i in range(len(x))])
        else:
            return dtype(x)

    # Some optimizer parameters may be represented as scalars instead of
    # tensors. In such cases, we need to wrap the scalar in a tensor, then
    # broadcast, then update the appropriate value in the state_dict with the
    # new unwrapped scalar value via a callback.
    def _create_callback(pid, name, t, p):
        def _from_tensor():
            state_dict['state'][pid][name] = t(p.cpu().numpy()[0])
        return _from_tensor

    def _create_option_callback(index, option_key, option_tensor, dtypes):
        def _from_tensor():
            optimizer.param_groups[index][option_key] = _recursive_cast(
                option_tensor.cpu().numpy()[0], dtypes)
        return _from_tensor

    # Param groups are an ordered list, normally there is only one per model,
    # but users can add additional param groups for example to train
    # previously frozen layers
    for index, group in enumerate(state_dict['param_groups']):
        # Broadcast options like learning rate
        for option_key, option_value in group.items():
            if option_key == 'params':
                continue

            # Options like the learning rate are scalar, and need to be wrapped in tensors
            key = '%s.%d' % (option_key, index)
            dtypes = _get_types(option_value)
            option_tensor = torch.Tensor([option_value])
            callbacks[key] = _create_option_callback(index, option_key, option_tensor, dtypes)
            params.append((key, option_tensor))

        # The params list here is ordered by the layers in the model
        for pid in group['params']:
            if pid not in state_dict['state']:
                # The param has not set requires_grad, so skip broadcast
                continue

            param_state = state_dict['state'][pid]
            for name, p in param_state.items():
                # Some parameter names may appear more than once, in which
                # case we ensure they have a unique identifier defined by
                # their order
                occurrences[name] += 1
                key = '%s.%d' % (str(name), occurrences[name])

                if not torch.is_tensor(p):
                    # Wrap the scalar in a FloatTensor, and remember its type
                    # so we can cast it back after unwrapping
                    t = type(p)
                    p = torch.Tensor([p])
                    callbacks[key] = _create_callback(pid, name, t, p)

                params.append((key, p))

    # Synchronized broadcast of all parameters
    broadcast_parameters(params, root_rank)

    # Post-broadcast cleanup for non-tensor parameters
    for key, p in params:
        if key in callbacks:
            callbacks[key]()
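
# A hedged usage sketch (ours) combining the two broadcast helpers above, as
# their docstrings suggest: after building the model and optimizer on every
# rank, rank 0's initial weights and optimizer state are broadcast so that all
# workers start from the same point. `model` and `optimizer` are placeholders.
def _example_broadcast_initial_state(model, optimizer):
    broadcast_parameters(model.state_dict(), root_rank=0)
    broadcast_optimizer_state(optimizer, root_rank=0)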

def broadcast_object(obj, root_rank, name=None):
    """
    Serializes and broadcasts an object from root rank to all other processes.
    Typical usage is to broadcast the `optimizer.state_dict()`, for example:

    .. code-block:: python

        state_dict = broadcast_object(optimizer.state_dict(), 0)
        if hvd.rank() > 0:
            optimizer.load_state_dict(state_dict)

    Arguments:
        obj: An object capable of being serialized without losing any context.
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.
        name: Optional name to use during broadcast, will default to the class type.
    Returns:
        The object that was broadcast from the `root_rank`.
    """
    if name is None:
        name = str(type(obj))

    if rank() == root_rank:
        b = io.BytesIO()
        cloudpickle.dump(obj, b)
        t = torch.ByteTensor(bytearray(b.getvalue()))
        sz = torch.IntTensor([t.shape[0]])
        broadcast_(sz, root_rank, name + '.sz')
    else:
        sz = torch.IntTensor([0])
        broadcast_(sz, root_rank, name + '.sz')
        t = torch.ByteTensor(sz.tolist()[0])

    broadcast_(t, root_rank, name + '.t')

    if rank() != root_rank:
        buf = io.BytesIO(t.numpy().tobytes())
        obj = cloudpickle.load(buf)

    return obj
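
# A hedged sketch (ours) of broadcasting an arbitrary picklable Python object
# with broadcast_object, e.g. a value only the root rank knows, such as the
# epoch to resume from after rank 0 reads a checkpoint. The helper
# `read_epoch_from_checkpoint` is hypothetical and not part of this module.
def _example_broadcast_resume_epoch(read_epoch_from_checkpoint):
    resume_epoch = read_epoch_from_checkpoint() if rank() == 0 else 0
    # Every rank must participate; non-root ranks receive rank 0's value.
    return broadcast_object(resume_epoch, root_rank=0, name='resume_epoch')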