Source code for horovod.tensorflow.functions

# Copyright 2020 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import io

import cloudpickle
import numpy as np
import tensorflow as tf

from tensorflow.python.framework import ops

from horovod.tensorflow.mpi_ops import allgather, broadcast
from horovod.tensorflow.mpi_ops import rank, size
from horovod.tensorflow.util import _cache, _executing_eagerly, _make_subgraph


@_cache
def _make_broadcast_group_fn():
    """Build the callable used to broadcast a group of variables.

    The execution mode is checked when this factory body runs; the factory
    itself is decorated with ``_cache``.

    Returns:
        A callable ``broadcast_group(variables, root_rank)``. When executing
        eagerly it is the result of passing the function through
        ``_make_subgraph``; otherwise it is a plain function that returns a
        single ``tf.group`` op covering every assignment.
    """
    if not _executing_eagerly():
        # Graph mode requires an Op
        def broadcast_group(variables, root_rank):
            assign_ops = [
                variable.assign(broadcast(variable, root_rank))
                for variable in variables
            ]
            return tf.group(*assign_ops)

        return broadcast_group

    # Eager mode will parallelize independent control flow
    def broadcast_group(variables, root_rank):
        for variable in variables:
            variable.assign(broadcast(variable, root_rank))

    return _make_subgraph(broadcast_group)


def broadcast_variables(variables, root_rank):
    """Broadcasts variables from root rank to all other processes.

    Arguments:
        variables: variables for broadcast
        root_rank: rank of the process from which global variables will be
                   broadcasted to all other processes.

    Returns:
        Whatever the callable built by `_make_broadcast_group_fn` returns
        for these arguments.
    """
    return _make_broadcast_group_fn()(variables, root_rank)
def broadcast_object(obj, root_rank=0, session=None, name=None):
    """
    Serializes and broadcasts an object from root rank to all other processes.

    The root rank pickles `obj` with cloudpickle. Two broadcasts follow:
    the payload length under ``name + '.sz'`` and then the payload bytes
    under ``name + '.t'``. Non-root ranks size their receive buffer from the
    first broadcast and unpickle the result of the second.

    Arguments:
        obj: An object capable of being serialized without losing any context.
        root_rank: The rank of the process from which parameters will be
                   broadcasted to all other processes.
        session: Session for TensorFlow v1 compatibility. Only used when not
                 executing eagerly; the default session is used if omitted.
        name: Optional name to use during broadcast, will default to the class
              type.

    Returns:
        The object that was broadcast from the `root_rank`.
    """
    # NOTE(review): the default name comes from the type of the *local* obj,
    # so ranks passing objects of different types get different default
    # names -- confirm whether broadcast needs matching names across ranks.
    if name is None:
        name = type(obj).__name__

    def evaluate(tensor):
        # Eager tensors expose their value directly; otherwise run a session.
        if _executing_eagerly():
            return tensor.numpy()
        active_session = session or ops.get_default_session()
        return active_session.run(tensor)

    is_root = rank() == root_rank

    if is_root:
        stream = io.BytesIO()
        cloudpickle.dump(obj, stream)
        payload = tf.convert_to_tensor(bytearray(stream.getvalue()),
                                       dtype=tf.uint8)
        length = tf.convert_to_tensor([payload.shape[0]], dtype=tf.int32)
        # The root already knows the length; the result is discarded.
        evaluate(broadcast(length, root_rank, name + '.sz'))
    else:
        empty_length = tf.convert_to_tensor([0], dtype=tf.int32)
        length = evaluate(broadcast(empty_length, root_rank, name + '.sz'))
        # Receive buffer with as many bytes as the root announced.
        payload = tf.zeros(length.tolist()[0], dtype=tf.uint8)

    received = evaluate(broadcast(payload, root_rank, name + '.t'))

    if is_root:
        return obj

    # NOTE: unpickling runs code embedded in the payload sent by the root.
    return cloudpickle.load(io.BytesIO(received.tobytes()))
def broadcast_object_fn(root_rank=0, session=None, name=None):
    """
    Builds a reusable function that broadcasts objects from the root rank.

    Two placeholders and their broadcast ops (payload size, then payload
    bytes) are created once, when this factory is called. The returned
    function feeds them through the session on every call.

    Arguments:
        root_rank: The rank of the process whose object is broadcast.
        session: Session used to run the broadcast ops. If omitted, the
                 default session at the time this factory is called is used.
        name: Optional name prefix for the broadcast ops, will default to
              'broadcast_object_fn'.

    Returns:
        A function ``_bcast(obj)`` that returns `obj` unchanged on the root
        rank and the unpickled object received from the root elsewhere.
    """
    op_prefix = name or 'broadcast_object_fn'

    size_input = tf.placeholder(tf.int32, [1], name='bcast_object_size')
    size_op = broadcast(size_input, root_rank, op_prefix + '.sz')

    data_input = tf.placeholder(tf.uint8, [None], name='bcast_object_data')
    data_op = broadcast(data_input, root_rank, op_prefix + '.t')

    # Resolved once here, not on each call of the returned function.
    active_session = session or ops.get_default_session()

    def _bcast(obj):
        is_root = rank() == root_rank

        if is_root:
            stream = io.BytesIO()
            cloudpickle.dump(obj, stream)
            payload = bytearray(stream.getvalue())
            # The root already knows the length; the result is discarded.
            active_session.run(size_op,
                               feed_dict={size_input: [len(payload)]})
        else:
            received_size = active_session.run(size_op,
                                               feed_dict={size_input: [0]})
            # Receive buffer shaped by the size the root announced.
            payload = np.zeros(received_size, dtype=np.uint8)

        received = active_session.run(data_op,
                                      feed_dict={data_input: payload})

        if is_root:
            return obj

        # NOTE: unpickling runs code embedded in the payload sent by the root.
        return cloudpickle.load(io.BytesIO(received.tobytes()))

    return _bcast
def allgather_object(obj, session=None, name=None):
    """
    Serializes and allgathers an object from all other processes.

    Every rank pickles its own `obj` with cloudpickle. Two allgathers
    follow: the per-rank payload lengths under ``name + '.sz'`` and the
    payload bytes under ``name + '.t'``. The gathered byte array is then cut
    into consecutive segments using the gathered lengths, and each segment
    is unpickled.

    Arguments:
        obj: An object capable of being serialized without losing any context.
        session: Session for TensorFlow v1 compatibility. Only used when not
                 executing eagerly; the default session is used if omitted.
        name: Optional name to use during allgather, will default to the class
              type.

    Returns:
        The list of objects that were allgathered across all ranks.
    """
    if name is None:
        name = type(obj).__name__

    def load(byte_array):
        # NOTE: unpickling runs code embedded in the bytes sent by other ranks.
        buf = io.BytesIO(byte_array.tobytes())
        return cloudpickle.load(buf)

    def to_numpy(v):
        if not _executing_eagerly():
            sess = session or ops.get_default_session()
            return sess.run(v)
        else:
            return v.numpy()

    b = io.BytesIO()
    cloudpickle.dump(obj, b)

    t = tf.convert_to_tensor(bytearray(b.getvalue()), dtype=tf.uint8)
    sz = tf.convert_to_tensor([t.shape[0]], dtype=tf.int32)

    sizes = to_numpy(allgather(sz, name=name + '.sz'))
    gathered = to_numpy(allgather(t, name=name + '.t'))

    # Compute every segment boundary in a single pass instead of re-summing
    # the size prefix for each rank, and accumulate in int64 so the offsets
    # do not depend on how int32 scalars are promoted when added.
    ends = np.cumsum(sizes, dtype=np.int64)
    starts = ends - sizes

    return [load(gathered[starts[i]:ends[i]]) for i in range(size())]