Source code for horovod.tensorflow.mpi_ops

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Modifications copyright (C) 2019 Uber Technologies, Inc.
# Modifications copyright Microsoft
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Inter-process communication using MPI."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import tensorflow as tf
from tensorflow.python.framework import load_library
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader

from horovod.common.util import get_ext_suffix, get_average_backwards_compatibility_fun, gpu_available, \
    num_rank_is_power_2
from horovod.common.basics import HorovodBasics as _HorovodBasics
from horovod.tensorflow.util import _executing_eagerly


def _load_library(name):
    """Loads a .so file containing the specified operators.

    Args:
      name: The name of the .so file to load.

    Raises:
      NotFoundError if the .so file could not be loaded.
    """
    filename = resource_loader.get_path_to_datafile(name)
    library = load_library.load_op_library(filename)
    return library


MPI_LIB = _load_library('mpi_lib' + get_ext_suffix())

_basics = _HorovodBasics(__file__, 'mpi_lib')

# import basic methods
init = _basics.init
shutdown = _basics.shutdown
size = _basics.size
local_size = _basics.local_size
rank = _basics.rank
local_rank = _basics.local_rank
mpi_threads_supported = _basics.mpi_threads_supported
mpi_enabled = _basics.mpi_enabled
mpi_built = _basics.mpi_built
gloo_enabled = _basics.gloo_enabled
gloo_built = _basics.gloo_built
nccl_built = _basics.nccl_built
ddl_built = _basics.ddl_built
ccl_built = _basics.ccl_built

# import reduction op values
Average = _basics.Average
Sum = _basics.Sum
Adasum = _basics.Adasum

is_homogeneous = _basics.is_homogeneous

handle_average_backwards_compatibility = get_average_backwards_compatibility_fun(_basics)

check_num_rank_power_of_2 = num_rank_is_power_2
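

# Illustrative usage sketch (not part of the original module): the basics
# re-exported above are the entry points most programs call first. A typical
# script launched with `horovodrun -np 4 python train.py` might use them as:
def _example_basics_usage():
    init()  # initialize Horovod (MPI/Gloo) once per process
    print('rank %d of %d (local rank %d of %d)' %
          (rank(), size(), local_rank(), local_size()))
    shutdown()  # optional explicit teardown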


# This function creates a default device map that includes all visible devices.
# Run it in a subprocess rather than in the main process.
def _check_has_gpu():
    import tensorflow as tf
    return tf.test.is_gpu_available()
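

# Illustrative only (not part of the original module): the comment above asks
# for _check_has_gpu to run in a subprocess so the parent process does not
# initialize the visible devices itself. A minimal sketch using a spawned
# worker process:
def _example_check_gpu_in_subprocess():
    import multiprocessing
    ctx = multiprocessing.get_context('spawn')
    with ctx.Pool(1) as pool:
        return pool.apply(_check_has_gpu)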


def _normalize_name(name):
    """Normalizes operation name to TensorFlow rules."""
    return re.sub('[^a-zA-Z0-9_]', '_', name)


def _allreduce(tensor, name=None, op=Sum):
    """An op which reduces an input tensor over all the Horovod processes. The
    default reduction is a sum.

    The reduction operation is keyed by the name of the op. The tensor type and
    shape must be the same on all Horovod processes for a given name. The reduction
    will not start until all processes are ready to send and receive the tensor.

    Returns:
      A tensor of the same shape and type as `tensor`, summed across all
      processes.
    """
    if name is None and not _executing_eagerly():
        name = 'HorovodAllreduce_%s' % _normalize_name(tensor.name)
    return MPI_LIB.horovod_allreduce(tensor, name=name, reduce_op=op)
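

# Illustrative only (not part of the original module): a minimal sketch of
# calling the private _allreduce helper directly; the public
# horovod.tensorflow.allreduce wrapper builds averaging, compression, etc. on
# top of it. Assumes init() has already been called on every process.
def _example_allreduce_sum():
    ones = tf.ones([2, 2])
    # Every process contributes the same tensor, so the result is size() * ones.
    return _allreduce(ones, name='example_allreduce', op=Sum)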


@ops.RegisterGradient('HorovodAllreduce')
def _allreduce_grad(op, grad):
    """Gradient for allreduce op.

    Args:
      op: An operation.
      grad: `Tensor` gradient with respect to the output of the op.

    Returns:
      The gradient with respect to the input of the op.
    """
    return _allreduce(grad)
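

# Illustrative only (not part of the original module): because the registered
# gradient of HorovodAllreduce is itself an allreduce of the upstream
# gradient, differentiating through _allreduce yields an allreduced gradient.
# A minimal eager-mode sketch (assumes init() has been called):
def _example_allreduce_gradient():
    x = tf.constant([1.0, 2.0])
    with tf.GradientTape() as tape:
        tape.watch(x)
        y = tf.reduce_sum(_allreduce(x))
    # Each gradient element is the sum of 1.0 across all processes, i.e. the
    # result is a tensor filled with size().
    return tape.gradient(y, x)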


def allgather(tensor, name=None):
    """An op which concatenates the input tensor with the same input tensor on
    all other Horovod processes. The concatenation is done on the first
    dimension, so the input tensors on the different processes must have the
    same rank and shape, except for the first dimension, which is allowed to be
    different.

    Returns:
      A tensor of the same type as `tensor`, concatenated on dimension zero
      across all processes. The shape is identical to the input shape, except
      for the first dimension, which may be greater and is the sum of all first
      dimensions of the tensors in different Horovod processes.
    """
    if name is None and not _executing_eagerly():
        name = 'HorovodAllgather_%s' % _normalize_name(tensor.name)
    return MPI_LIB.horovod_allgather(tensor, name=name)
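

# Illustrative only (not part of the original module): a sketch of allgather
# with per-rank tensors whose first dimension differs, which is exactly the
# case the first-dimension rule above allows. Assumes init() has been called
# on every process.
def _example_allgather():
    # Rank r contributes r + 1 rows; the gathered result has
    # 1 + 2 + ... + size() rows, ordered by rank.
    local = tf.fill([rank() + 1, 3], float(rank()))
    return allgather(local)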


@ops.RegisterGradient('HorovodAllgather')
def _allgather_grad(op, grad):
    """Gradient for allgather op.

    Args:
      op: An operation.
      grad: `Tensor` gradient with respect to the output of the op.

    Returns:
      The gradient with respect to the input of the op.
    """
    grad = _allreduce(grad)

    with tf.device('/cpu:0'):
        # Keep the tensor of split sizes on CPU.
        x = op.inputs[0]
        d0 = x.get_shape().as_list()[0]
        d = tf.convert_to_tensor([d0], dtype=tf.int32)

        s = size()
        d = tf.reshape(allgather(d), [s])

    splits = tf.split(grad, num_or_size_splits=d, axis=0)
    return splits[rank()]


def broadcast(tensor, root_rank, name=None):
    """An op which broadcasts the input tensor on root rank to the same input
    tensor on all other Horovod processes.

    The broadcast operation is keyed by the name of the op. The tensor type and
    shape must be the same on all Horovod processes for a given name. The
    broadcast will not start until all processes are ready to send and receive
    the tensor.

    Returns:
      A tensor of the same shape and type as `tensor`, with the value
      broadcasted from root rank.
    """
    if name is None and not _executing_eagerly():
        name = 'HorovodBroadcast_%s' % _normalize_name(tensor.name)
    return MPI_LIB.horovod_broadcast(tensor, name=name, root_rank=root_rank)
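

# Illustrative only (not part of the original module): a common use of
# broadcast is to synchronize initial values from the root rank before
# training starts. A minimal sketch (assumes init() has been called):
def _example_broadcast_from_root():
    value = tf.constant([1.0, 2.0, 3.0]) * float(rank())
    # After this op runs, every process holds rank 0's tensor ([0., 0., 0.]).
    return broadcast(value, root_rank=0)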


@ops.RegisterGradient('HorovodBroadcast')
def _broadcast_grad(op, grad):
    """Gradient for broadcast op.

    Args:
      op: An operation.
      grad: `Tensor` gradient with respect to the output of the op.

    Returns:
      The gradient with respect to the input of the op.
    """
    root_rank = op.get_attr('root_rank')
    grad_reduced = _allreduce(grad)
    if rank() != root_rank:
        return grad_reduced * 0
    return grad_reduced