Source code for horovod.spark.runner

# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
import platform
import time

import pyspark
from six.moves import queue

from horovod.run.util.threads import in_thread
from horovod.spark.task import task_service
from horovod.spark.gloo_run import gloo_run
from horovod.spark.mpi_run import mpi_run
from horovod.run.runner import is_gloo_used, run_controller
from horovod.run.common.util import timeout, host_hash, secret
from horovod.run.common.util import settings as hvd_settings
from horovod.spark.driver import driver_service, job_id


MINIMUM_COMMAND_LIFETIME_S = 3

# Spark will fail to initialize correctly locally on Mac OS without this
if platform.system() == 'Darwin':
    os.environ['OBJC_DISABLE_INITIALIZE_FORK_SAFETY'] = 'YES'


def _task_fn(index, driver_addresses, key, settings, use_gloo):
    # settings are deserialized on Spark workers without the key, so it is given here explicitly
    # (Spark RPC communicates the key and supports encryption);
    # for convenience, we put it back into settings
    settings.key = key

    task = task_service.SparkTaskService(index, settings.key, settings.nics, settings.verbose)
    try:
        driver_client = driver_service.SparkDriverClient(driver_addresses, settings.key, settings.verbose)
        driver_client.register_task(index, task.addresses(), host_hash.host_hash())
        task.wait_for_initial_registration(settings.timeout)
        task_indices_on_this_host = driver_client.task_host_hash_indices(host_hash.host_hash())

        # With Gloo, all tasks wait for the command.
        # With MPI, the task with the first index executes orted, which will run mpirun_exec_fn for all tasks.
        minimum_lifetime_after_start = None
        if use_gloo or task_indices_on_this_host[0] == index:
            task.wait_for_command_start(settings.timeout)
            minimum_lifetime_after_start = timeout.Timeout(MINIMUM_COMMAND_LIFETIME_S,
                                                           message='Just measuring runtime')
            task.wait_for_command_termination()
        else:
            # The rest of the tasks need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(task_indices_on_this_host[0])
            first_task_client = \
                task_service.SparkTaskClient(task_indices_on_this_host[0],
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            first_task_client.wait_for_command_termination()

        # The command has terminated. Make sure this task service does not shut down too quickly
        # after the client started the command, as the client needs some time to connect again
        # to wait for the result after starting the command (see horovod.spark.driver.rsh).
        if minimum_lifetime_after_start is not None:
            time.sleep(minimum_lifetime_after_start.remaining())

        return task.fn_result()
    finally:
        # this has to block on running requests (wait_for_command_exit_code)
        # so they can finish serving the exit code
        # shutdown does block with network.BasicService._server._block_on_close = True
        task.shutdown()


def _make_mapper(driver_addresses, settings, use_gloo):
    # serialized settings do not contain the key, so we copy it and provide it explicitly here
    key = settings.key

    def _mapper(index, _):
        yield _task_fn(index, driver_addresses, key, settings, use_gloo)

    return _mapper


def _make_spark_thread(spark_context, spark_job_group, driver, result_queue,
                       settings, use_gloo):
    """Creates `settings.num_proc` Spark tasks in a parallel thread."""
    def run_spark():
        """Creates `settings.num_proc` Spark tasks, each executing `_task_fn` and waits for them to terminate."""
        try:
            spark_context.setJobGroup(spark_job_group,
                                      "Horovod Spark Run",
                                      interruptOnCancel=True)
            procs = spark_context.range(0, numSlices=settings.num_proc)
            # We assume that folks who care about security will enable Spark RPC encryption,
            # thus ensuring that the key that is passed here remains secret.
            result = procs.mapPartitionsWithIndex(_make_mapper(driver.addresses(), settings, use_gloo)).collect()
            result_queue.put(result)
        except:
            driver.notify_spark_job_failed()
            raise

    spark_thread = in_thread(target=run_spark, daemon=False)
    return spark_thread


def _launch_job(use_mpi, use_gloo, settings, driver, env, stdout=None, stderr=None):
    # Determine a set of common interfaces for task-to-task communication.
    nics = set(driver.task_addresses_for_tasks(0).keys())
    for index in range(1, settings.num_proc):
        nics.intersection_update(driver.task_addresses_for_tasks(index).keys())
    if not nics:
        raise Exception('Unable to find a set of common task-to-task communication interfaces: %s'
                        % [(index, driver.task_addresses_for_tasks(index)) for index in range(settings.num_proc)])

    run_controller(use_gloo, lambda: gloo_run(settings, nics, driver, env),
                   use_mpi, lambda: mpi_run(settings, nics, driver, env, stdout, stderr),
                   False, lambda: None,
                   settings.verbose)


def run(fn, args=(), kwargs={}, num_proc=None, start_timeout=None,
        use_mpi=None, use_gloo=None, extra_mpi_args=None, env=None,
        stdout=None, stderr=None, verbose=1, nics=None):
    """
    Runs Horovod in Spark.  Runs `num_proc` processes executing `fn` using the same number of Spark tasks.

    Args:
        fn: Function to run.
        args: Arguments to pass to `fn`.
        kwargs: Keyword arguments to pass to `fn`.
        num_proc: Number of Horovod processes.  Defaults to `spark.default.parallelism`.
        start_timeout: Timeout for Spark tasks to spawn, register and start running the code, in seconds.
                       If not set, falls back to `HOROVOD_SPARK_START_TIMEOUT` environment variable value.
                       If it is not set as well, defaults to 600 seconds.
        extra_mpi_args: Extra arguments for mpi_run. Defaults to no extra args.
        env: Environment dictionary to use in Horovod run.
        stdout: Horovod stdout is redirected to this stream. Defaults to sys.stdout.
        stderr: Horovod stderr is redirected to this stream. Defaults to sys.stderr.
        verbose: Debug output verbosity (0-2). Defaults to 1.
        nics: List of NICs for tcp network communication.

    Returns:
        List of results returned by running `fn` on each rank.
    """
    if start_timeout is None:
        # Lookup default timeout from the environment variable.
        start_timeout = int(os.getenv('HOROVOD_SPARK_START_TIMEOUT', '600'))

    # nics needs to be a set
    if nics and not isinstance(nics, set):
        nics = set(nics)

    tmout = timeout.Timeout(start_timeout,
                            message='Timed out waiting for {activity}. Please check that you have '
                                    'enough resources to run all Horovod processes. Each Horovod '
                                    'process runs in a Spark task. You may need to increase the '
                                    'start_timeout parameter to a larger value if your Spark resources '
                                    'are allocated on-demand.')
    settings = hvd_settings.Settings(verbose=verbose,
                                     extra_mpi_args=extra_mpi_args,
                                     key=secret.make_secret_key(),
                                     timeout=tmout,
                                     nics=nics,
                                     run_func_mode=True)

    spark_context = pyspark.SparkContext._active_spark_context
    if spark_context is None:
        raise Exception('Could not find an active SparkContext, are you '
                        'running in a PySpark session?')

    if num_proc is None:
        num_proc = spark_context.defaultParallelism
        if settings.verbose >= 1:
            print('Running %d processes (inferred from spark.default.parallelism)...' % num_proc)
    else:
        if settings.verbose >= 1:
            print('Running %d processes...' % num_proc)
    settings.num_proc = num_proc

    result_queue = queue.Queue(1)

    # start Spark driver service and launch settings.num_proc Spark tasks
    spark_job_group = 'horovod.spark.run.%d' % job_id.next_job_id()
    driver = driver_service.SparkDriverService(settings.num_proc, fn, args, kwargs,
                                               settings.key, settings.nics)
    gloo_is_used = is_gloo_used(use_gloo=use_gloo, use_mpi=use_mpi, use_jsrun=False)
    spark_thread = _make_spark_thread(spark_context, spark_job_group, driver,
                                      result_queue, settings, gloo_is_used)
    try:
        # wait for all tasks to register, notify them and initiate task-to-task address registration
        _notify_and_register_task_addresses(driver, settings)

        # Determine the index grouping based on host hashes.
        # Barrel shift until index 0 is in the first host.
        host_hashes = list(driver.task_host_hash_indices().keys())
        host_hashes.sort()
        while 0 not in driver.task_host_hash_indices()[host_hashes[0]]:
            host_hashes = host_hashes[1:] + host_hashes[:1]

        settings.hosts = ','.join('%s:%d' % (host_hash, len(driver.task_host_hash_indices()[host_hash]))
                                  for host_hash in host_hashes)

        # Determine the ranks to indices
        ranks_to_indices = []
        for host_hash in host_hashes:
            ranks_to_indices += driver.task_host_hash_indices()[host_hash]
        driver.set_ranks_to_indices(ranks_to_indices)

        # Run the job
        _launch_job(use_mpi, use_gloo, settings, driver, env, stdout, stderr)
    except:
        # Terminate Spark job.
        spark_context.cancelJobGroup(spark_job_group)

        # Re-raise exception.
        raise
    finally:
        spark_thread.join()
        driver.shutdown()

    # Make sure the Spark job did not fail.
    driver.check_for_spark_job_failure()

    # If there's no exception, execution results are in this queue.
    results = result_queue.get_nowait()
    return [results[index] for index in ranks_to_indices]
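

# --- Illustrative sketch, not part of horovod.spark.runner ---
# Shows the barrel shift performed in run() above with made-up host hashes and
# task indices; the name of this helper and its values are assumptions for the example.
def _host_hash_barrel_shift_example():
    # pretend driver.task_host_hash_indices() returned this mapping
    task_host_hash_indices = {'host-a': [1, 3], 'host-b': [0, 2]}

    host_hashes = sorted(task_host_hash_indices.keys())   # ['host-a', 'host-b']
    while 0 not in task_host_hash_indices[host_hashes[0]]:
        # rotate until the host that owns index 0 (the first rank) comes first
        host_hashes = host_hashes[1:] + host_hashes[:1]
    # host_hashes is now ['host-b', 'host-a']

    hosts = ','.join('%s:%d' % (h, len(task_host_hash_indices[h])) for h in host_hashes)
    return hosts                                           # 'host-b:2,host-a:2'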


def _notify_and_register_task_addresses(driver, settings):
    # wait for num_proc tasks to register
    driver.wait_for_initial_registration(settings.timeout)
    if settings.verbose >= 2:
        print('Initial Spark task registration is complete.')

    def notify_and_register(index):
        task_client = task_service.SparkTaskClient(index,
                                                   driver.task_addresses_for_driver(index),
                                                   settings.key, settings.verbose)
        task_client.notify_initial_registration_complete()

        next_task_index = (index + 1) % settings.num_proc
        next_task_addresses = driver.all_task_addresses(next_task_index)
        task_to_task_addresses = task_client.get_task_addresses_for_task(next_task_index, next_task_addresses)
        driver.register_task_to_task_addresses(next_task_index, task_to_task_addresses)

    for index in range(settings.num_proc):
        in_thread(notify_and_register, (index,))

    driver.wait_for_task_to_task_address_updates(settings.timeout)
    if settings.verbose >= 2:
        print('Spark task-to-task address registration is complete.')
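

A minimal usage sketch of the `run` API defined above (illustrative only: the training
function, its Horovod calls, and the `num_proc` value are assumptions for the example,
and `horovod.torch` can be swapped for any other Horovod framework binding):

import horovod.spark


def train(lr):
    # Runs once per Horovod rank, inside a Spark task.
    import horovod.torch as hvd
    hvd.init()
    # ... build the model, scale lr by hvd.size(), train ...
    return hvd.rank()


# Launches 4 Horovod processes as 4 Spark tasks and returns one result per rank.
results = horovod.spark.run(train, args=(0.001,), num_proc=4)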