Source code for horovod.spark.common.util

# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import horovod.spark.common._namedtuple_fix

import contextlib
import os
import time

from multiprocessing.pool import ThreadPool
import pyarrow as pa
import numpy as np
import pyspark.sql.functions as f
from pyspark.ml.linalg import DenseVector, SparseVector, Vector, VectorUDT
from pyspark.sql.types import ArrayType, BinaryType, BooleanType, FloatType, DoubleType, \
    IntegerType, LongType, NullType, StringType
try:
    # Spark 3.0 moved to a pandas submodule
    from pyspark.sql.pandas.types import from_arrow_type
except ImportError:
    from pyspark.sql.types import from_arrow_type

from pyspark.sql import SparkSession

from horovod.runner.common.util import codec, host_hash as hh
from horovod.spark.common import cache, constants

_training_cache = cache.TrainingDataCache()


[docs]def host_hash(salt=None): """ Computes this host's host hash by invoking horovod.runner.common.util.host_hash.host_hash. Consider environment variable CONTAINER_ID which is present when running Spark via YARN. A YARN container does not share memory with other containers on the same host, so it must be considered a `host` in the sense of the `host_hash`. :param salt: extra information to include in the hash, ignores Falsy values :return: host hash """ # turn salt into an array of a single string if given salt = [str(salt)] if salt else [] # We would violate resource allocation if we run all tasks of a host in one container. # See [issues 1497](https://github.com/horovod/horovod/issues/1497) for details. container = os.environ.get("CONTAINER_ID") if container is not None: salt.append(container) return hh.host_hash(salt='-'.join(salt))
def data_type_to_str(dtype): if dtype == VectorUDT or dtype == SparseVector or dtype == DenseVector: return 'Vector' elif dtype == IntegerType: return 'Int' elif dtype == StringType: return 'String' elif dtype == FloatType: return 'Float' elif dtype == BinaryType: return 'Binary' elif dtype == DoubleType: return 'Double' elif dtype == LongType: return 'Long' elif dtype == BooleanType: return 'Boolean' else: raise ValueError('Unrecognized data type: {}'.format(dtype)) def numpy_type_to_str(dtype): if dtype == np.int32: return 'Int' elif dtype == np.float32: return 'Float' elif dtype == np.uint8: return 'Binary' elif dtype == np.float64: return 'Double' elif dtype == np.int64: return 'Long' elif dtype == np.bool: return 'Boolean' else: raise ValueError('Cannot convert numpy data type to Spark string: {}'.format(dtype)) def spark_scalar_to_python_type(dtype): if dtype == IntegerType: return int elif dtype == StringType: return str elif dtype == FloatType: return float elif dtype == DoubleType: return float elif dtype == LongType: return int elif dtype == BooleanType: return bool elif dtype == BinaryType: return bytes else: raise ValueError('cannot convert Spark data Type {} to native python type'.format(dtype)) def pyarrow_to_spark_data_type(dtype): # PySpark will interpret list types as Arrays, but for ML applications we want to default to # treating these as DenseVectors. if pa.types.is_list(dtype): return DenseVector return type(from_arrow_type(dtype)) def data_type_to_numpy(dtype): if dtype == VectorUDT or dtype == SparseVector or dtype == DenseVector: return np.float64 elif dtype == ArrayType: return np.float64 elif dtype == IntegerType: return np.int32 elif dtype == StringType: return np.str elif dtype == FloatType: return np.float32 elif dtype == BinaryType: return np.uint8 elif dtype == DoubleType: return np.float64 elif dtype == LongType: return np.int64 elif dtype == BooleanType: return np.bool else: raise ValueError('Unrecognized data type: {}'.format(dtype)) def check_shape_compatibility(metadata, feature_columns, label_columns, input_shapes=None, output_shapes=None, label_shapes=None): # Check for model and input type incompatibility. Columns must have the same size # (total number of elements) of the corresponding inputs. feature_count = len(feature_columns) if input_shapes is not None: if feature_count != len(input_shapes): raise ValueError('Feature column count {features} must equal ' 'model inputs count {inputs}' .format(features=feature_count, inputs=len(input_shapes))) for idx, col, input_shape in zip(range(feature_count), feature_columns, input_shapes): col_size = metadata[col]['shape'] if col_size is None: # When training directly on Parquet, we do not compute shape metadata continue input_size = abs(np.prod(input_shape)) if col_size != input_size: raise ValueError( 'Feature column \'{col}\' with size {feature} must equal that of the ' 'model input at index {idx} with size {input}' .format(col=col, feature=col_size, idx=idx, input=input_size)) label_count = len(label_columns) if label_shapes is not None and label_count != len(label_shapes): raise ValueError('Label column count {labels} must equal ' 'provided label shapes count {outputs}' .format(labels=label_count, outputs=len(label_shapes))) if output_shapes is not None and label_count != len(output_shapes): raise ValueError('Label column count {labels} must equal ' 'model outputs count {outputs}' .format(labels=label_count, outputs=len(output_shapes))) def _check_label_cols_size(target_shapes, target_name): for idx, col, target_shape in zip(range(label_count), label_columns, target_shapes): col_size = metadata[col]['shape'] if col_size is None: # When training directly on Parquet, we do not compute shape metadata continue target_size = abs(np.prod(target_shape)) if col_size != target_size: raise ValueError('Label column \'{col}\' with size {label} must equal that of the ' '{target_name} shape at index {idx} with size {output}' .format(col=col, label=col_size, idx=idx, output=target_size, target_name=target_name)) if label_shapes is not None: _check_label_cols_size(label_shapes, 'label') elif output_shapes is not None: # Check the label size against the model output shapes only if label_shapes is not provided. _check_label_cols_size(output_shapes, 'model output') def _get_col_info(df): """ Infer the type and shape of all the columns. NOTE: This function processes the entire DataFrame, and can therefore be very expensive to run. TODO(travis): Only run this if user sets compress_sparse param, otherwise convert all to Array. """ def get_meta(row): row_dict = row.asDict() row_schema = [] for col_name, data_col in row_dict.items(): dtype = type(data_col) if isinstance(data_col, DenseVector): # shape and size of dense vector are the same shape = size = data_col.array.shape[0] elif isinstance(data_col, SparseVector): # shape is the total size of vector shape = data_col.size # size is the number of nonzero elements in the sparse vector size = data_col.indices.shape[0] elif isinstance(data_col, list): shape = size = len(data_col) else: shape = size = 1 row_schema.append((col_name, ({dtype}, {shape}, {size}))) return row_schema def merge(x, y): x_dtypes, x_shapes, x_sizes = x y_dtypes, y_shapes, y_sizes = y dtypes = x_dtypes | y_dtypes shapes = x_shapes | y_shapes sizes = x_sizes | y_sizes return dtypes, {min(shapes), max(shapes)}, {min(sizes), max(sizes)} raw_col_info_list = df.rdd.flatMap(get_meta).reduceByKey(merge).collect() all_col_types = {} col_shapes = {} col_max_sizes = {} for col_info in raw_col_info_list: col_name, col_meta = col_info dtypes, shapes, sizes = col_meta all_col_types[col_name] = dtypes col_shapes[col_name] = shapes col_max_sizes[col_name] = sizes for col in df.schema.names: # All rows in every column must have the same shape shape_set = col_shapes[col] if len(shape_set) != 1: raise ValueError( 'Column {col} does not have uniform shape. ' 'shape set: {shapes_set}'.format(col=col, shapes_set=shape_set)) col_shapes[col] = shape_set.pop() # All rows in every column must have the same size unless they have SparseVectors sizes = col_max_sizes[col] if len(sizes) > 1 and not (SparseVector in all_col_types[col]): raise ValueError( 'Rows of column {col} have varying sizes. This is only allowed if datatype is ' 'SparseVector or a mix of Sparse and DenseVector.'.format(col=col)) col_max_sizes[col] = max(sizes) return all_col_types, col_shapes, col_max_sizes def _get_metadata(df): """ Infer the type and shape of all the columns and determines if what intermediate format they need to be converted to in case they are a vector. Example return value: { 'col1': { 'dtype': <type 'float'>, 'intermediate_format': 'nochange', 'max_size': 1, 'shape': 1 }, 'col2': { 'dtype': <type 'float'>, 'intermediate_format': 'nochange', 'max_size': 1, 'shape': 1 }, 'col3': { 'dtype': <class 'pyspark.ml.linalg.SparseVector'>, 'intermediate_format': 'custom_sparse_format', 'max_size': 37, 'shape': 56 } } """ all_col_types, col_shapes, col_max_sizes = _get_col_info(df) metadata = dict() for field in df.schema.fields: col = field.name col_types = all_col_types[col].copy() if DenseVector in col_types: # If a col has DenseVector type (whether it is mixed sparse and dense vector or just # DenseVector), convert all of the values to dense vector is_sparse_vector_only = False spark_data_type = DenseVector convert_to_target = constants.ARRAY elif SparseVector in col_types: # If a col has only sparse vectors, convert all the data into custom dense vectors is_sparse_vector_only = True spark_data_type = SparseVector convert_to_target = constants.CUSTOM_SPARSE else: is_sparse_vector_only = False spark_data_type = type(field.dataType) convert_to_target = constants.NOCHANGE # Explanation of the fields in metadata # dtype: # # spark_data_type: # The spark data type from dataframe schema: type(field.dataType). If column has # mixed SparseVector and DenseVector we categorize it as DenseVector. # # is_sparse_vector_only: # If all the rows in the column were sparse vectors. # # shape: # Determines the shape of the data in the spark dataframe. It is useful for sparse # vectors. # # intermediate_format: # Specifies if the column need to be converted to a different format so that # petastorm can read it. It can be one of ARRAY, CUSTOM_SPARSE, or NOCHANGE. It is # required because petastorm cannot read DenseVector and SparseVectors. We need to # identify these types and convert them to petastorm compatible type of array. metadata[col] = {'spark_data_type': spark_data_type, 'is_sparse_vector_only': is_sparse_vector_only, 'shape': col_shapes[col], 'intermediate_format': convert_to_target, 'max_size': col_max_sizes[col]} return metadata def to_petastorm_fn(schema_cols, metadata): ARRAY = constants.ARRAY CUSTOM_SPARSE = constants.CUSTOM_SPARSE # Convert Spark Vectors into arrays so Petastorm can read them def to_petastorm(row): import numpy as np from pyspark import Row converted = {} for col in schema_cols: col_data = row[col] if isinstance(col_data, Vector): intermediate_format = metadata[col]['intermediate_format'] if metadata else ARRAY if intermediate_format == ARRAY: converted[col] = col_data.toArray().tolist() elif intermediate_format == CUSTOM_SPARSE: # Currently petastorm does not support reading pyspark sparse vector. We put # the indices and values into one array. when consuming the data, we re-create # the vector from this format. size = len(col_data.indices) padding_zeros = 2 * (metadata[col]['max_size'] - len(col_data.indices)) converted[col] = np.concatenate( (np.array([size]), col_data.indices, col_data.values, np.zeros(padding_zeros))).tolist() if converted: row = row.asDict().copy() row.update(converted) return Row(**row) return to_petastorm def _has_vector_column(df): for field in df.schema.fields: if isinstance(field.dataType, VectorUDT): return True return False def _get_dataset_info(dataset, dataset_id, path): total_rows = 0 total_byte_size = 0 for piece in dataset.pieces: metadata = piece.get_metadata() total_rows += metadata.num_rows for row_group_index in range(metadata.num_row_groups): row_group = metadata.row_group(row_group_index) total_byte_size += row_group.total_byte_size if total_rows == 0: raise ValueError('No rows found in {} dataset: {}'.format(dataset_id, path)) if total_byte_size == 0: raise ValueError('No data found in {} dataset: {}'.format(dataset_id, path)) if total_rows > total_byte_size: raise ValueError('Found {} bytes in {} rows; {} dataset may be corrupted.' .format(total_byte_size, total_rows, dataset_id)) return total_rows, total_byte_size def _save_meta_to_fs(fs, path, schema, rows, total_byte_size): with fs.open(path, 'wb') as train_meta_file: serialized_content = codec.dumps_base64(dict(schema=schema, rows=rows, total_byte_size=total_byte_size)) train_meta_file.write(serialized_content.encode('utf-8')) def _load_metadata_from_fs(fs, path): with fs.open(path, 'rb') as train_meta_file: meta = train_meta_file.read() meta = codec.loads_base64(meta.decode()) data_schema = meta['schema'] rows = meta['rows'] total_byte_size = meta['total_byte_size'] return data_schema, rows, total_byte_size def get_simple_meta_from_parquet(store, label_columns, feature_columns, sample_weight_col, dataset_idx=None): train_data_path = store.get_train_data_path(dataset_idx) validation_data_path = store.get_val_data_path(dataset_idx) if not store.exists(train_data_path): raise ValueError("{} path does not exist in the store".format(train_data_path)) train_data_meta_path = store.get_data_metadata_path(train_data_path) val_data_meta_path = store.get_data_metadata_path(validation_data_path) fs = store.fs schema_cols = feature_columns + label_columns if sample_weight_col: schema_cols.append(sample_weight_col) def make_metadata_dictionary(_train_data_schema): _metadata = {} for col in schema_cols: col_schema = _train_data_schema.field_by_name(col) col_info = { 'spark_data_type': pyarrow_to_spark_data_type(col_schema.type), 'is_sparse_vector_only': False, 'shape': None, # Only used by SparseVector columns 'intermediate_format': constants.NOCHANGE, 'max_size': None # Only used by SparseVector columns } _metadata[col] = col_info _avg_row_size = train_data_total_byte_size / train_rows return _metadata, _avg_row_size # In the try block we try to read the data metadata from the cached metadata in the store. If # anything goes wrong, we will ignore the cache and create the metadata from data. try: if store.exists(train_data_meta_path) and \ (store.exists(val_data_meta_path) or not store.exists(validation_data_path)): train_data_schema, train_rows, train_data_total_byte_size = \ _load_metadata_from_fs(fs, train_data_meta_path) metadata, avg_row_size = make_metadata_dictionary(train_data_schema) val_rows = 0 if store.exists(val_data_meta_path): val_data_schema, val_rows, val_data_total_byte_size = _load_metadata_from_fs(fs, val_data_meta_path) return train_rows, val_rows, metadata, avg_row_size except Exception as ex: print(ex) train_data = store.get_parquet_dataset(train_data_path) train_data_schema = train_data.schema.to_arrow_schema() train_rows, train_data_total_byte_size = _get_dataset_info(train_data, 'training', train_data_path) # Write train metadata to filesystem _save_meta_to_fs(fs, train_data_meta_path, train_data_schema, train_rows, train_data_total_byte_size) val_rows = 0 if store.exists(validation_data_path): val_data = store.get_parquet_dataset(validation_data_path) val_data_schema = val_data.schema.to_arrow_schema() val_rows, val_data_total_byte_size = _get_dataset_info(val_data, 'validation', validation_data_path) # Write validation metadata to filesystem _save_meta_to_fs(fs, val_data_meta_path, val_data_schema, val_rows, val_data_total_byte_size) metadata, avg_row_size = make_metadata_dictionary(train_data_schema) return train_rows, val_rows, metadata, avg_row_size def _train_val_split(df, validation): train_df = df val_df = None validation_ratio = 0.0 if isinstance(validation, float) and validation > 0: train_df, val_df = train_df.randomSplit([1.0 - validation, validation]) validation_ratio = validation elif isinstance(validation, str): dtype = [field.dataType for field in df.schema.fields if field.name == validation][0] bool_dtype = isinstance(dtype, BooleanType) val_df = train_df.filter( f.col(validation) if bool_dtype else f.col(validation) > 0).drop(validation) train_df = train_df.filter( ~f.col(validation) if bool_dtype else f.col(validation) == 0).drop(validation) # Approximate ratio of validation data to training data for proportionate scale # of partitions timeout_ms = 1000 confidence = 0.90 train_rows = train_df.rdd.countApprox(timeout=timeout_ms, confidence=confidence) val_rows = val_df.rdd.countApprox(timeout=timeout_ms, confidence=confidence) validation_ratio = val_rows / (val_rows + train_rows) elif validation: raise ValueError('Unrecognized validation type: {}'.format(type(validation))) return train_df, val_df, validation_ratio _DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS = \ int(os.environ.get('DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS', '30')) _DATABRICKS_FILE_AVAILABILITY_CHECK_INTERVAL_SECS = \ float(os.environ.get('DATABRICKS_FILE_AVAILABILITY_CHECK_INTERVAL_SECS', '0.1')) def _wait_file_available_on_dbfs(store, url_list): """ On databricks runtime, Waiting about DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS seconds (default 30 seconds) to make sure all files are available for reading. This is because Databricks filesystem backend storage such as S3 which only providing eventually consistency. """ # Import LocalStore here to avoid circular import from horovod.spark.common.store import LocalStore if isinstance(store, LocalStore): return if not is_databricks(): return def wait_for_file(path): end_time = time.time() + _DATABRICKS_FILE_AVAILABILITY_WAIT_TIMEOUT_SECS while time.time() < end_time: if store.exists(path): return True time.sleep(_DATABRICKS_FILE_AVAILABILITY_CHECK_INTERVAL_SECS) return False if len(url_list) == 0: raise ValueError('Input url_list argument is empty.') pool = ThreadPool(min(len(url_list), 64)) try: results = pool.map(wait_for_file, url_list) failed_list = [url for url, result in zip(url_list, results) if not result] if failed_list: raise TimeoutError('Timeout while waiting for all files to appear at urls {failed_list}.' .format(failed_list=','.join(failed_list))) finally: pool.close() pool.join() def _get_spark_df_saved_file_list(saved_path): spark_session = SparkSession.builder.getOrCreate() return list(spark_session.read.parquet(saved_path)._jdf.inputFiles()) def _get_or_create_dataset(key, store, df, feature_columns, label_columns, validation, sample_weight_col, compress_sparse, num_partitions, num_processes, verbose): with _training_cache.lock: if _training_cache.is_cached(key, store): dataset_idx = _training_cache.get_dataset(key) train_rows, val_rows, metadata, avg_row_size = _training_cache.get_dataset_properties(dataset_idx) train_data_path = store.get_train_data_path(dataset_idx) val_data_path = store.get_val_data_path(dataset_idx) if verbose: print('using cached dataframes for key: {}'.format(key)) print('train_data_path={}'.format(train_data_path)) print('train_rows={}'.format(train_rows)) print('val_data_path={}'.format(val_data_path)) print('val_rows={}'.format(val_rows)) else: dataset_idx = _training_cache.next_dataset_index(key) train_data_path = store.get_train_data_path(dataset_idx) val_data_path = store.get_val_data_path(dataset_idx) if verbose: print('writing dataframes') print('train_data_path={}'.format(train_data_path)) print('val_data_path={}'.format(val_data_path)) schema_cols = feature_columns + label_columns if sample_weight_col: schema_cols.append(sample_weight_col) if isinstance(validation, str): schema_cols.append(validation) df = df[schema_cols] metadata = None if _has_vector_column(df): if compress_sparse: metadata = _get_metadata(df) to_petastorm = to_petastorm_fn(schema_cols, metadata) df = df.rdd.map(to_petastorm).toDF() train_df, val_df, validation_ratio = _train_val_split(df, validation) train_partitions = max(int(num_partitions * (1.0 - validation_ratio)), num_processes) if verbose: print('train_partitions={}'.format(train_partitions)) train_df \ .coalesce(train_partitions) \ .write \ .mode('overwrite') \ .parquet(train_data_path) saved_file_list = _get_spark_df_saved_file_list(train_data_path) if val_df: val_partitions = max(int(num_partitions * validation_ratio), num_processes) if verbose: print('val_partitions={}'.format(val_partitions)) val_df \ .coalesce(val_partitions) \ .write \ .mode('overwrite') \ .parquet(val_data_path) saved_file_list += _get_spark_df_saved_file_list(val_data_path) try: _wait_file_available_on_dbfs(store, saved_file_list) except TimeoutError as e: err_msg = 'Timeout while waiting for all parquet-store files to appear, Please ' \ 'check whether these files were saved successfully when materializing ' \ 'dataframe. Internal Error: {e}'.format(e=str(e)) raise RuntimeError(err_msg) train_rows, val_rows, pq_metadata, avg_row_size = get_simple_meta_from_parquet( store, label_columns, feature_columns, sample_weight_col, dataset_idx) if verbose: print('train_rows={}'.format(train_rows)) if val_df: if val_rows == 0: raise ValueError( 'Validation DataFrame is empty with validation param: {}' .format(validation)) if verbose: print('val_rows={}'.format(val_rows)) metadata = metadata or pq_metadata _training_cache.set_dataset_properties( dataset_idx, (train_rows, val_rows, metadata, avg_row_size)) return dataset_idx def check_validation(validation, df=None): if validation: if isinstance(validation, float): if validation < 0 or validation >= 1: raise ValueError('Validation split {} must be in the range: [0, 1)' .format(validation)) elif isinstance(validation, str): if df is not None and validation not in df.columns: raise ValueError('Validation column {} does not exist in the DataFrame' .format(validation)) else: raise ValueError('Param validation must be of type "float" or "str", found: {}' .format(type(validation))) @contextlib.contextmanager def prepare_data(num_processes, store, df, label_columns, feature_columns, validation=None, sample_weight_col=None, compress_sparse=False, partitions_per_process=10, verbose=0): check_validation(validation, df=df) if num_processes <= 0 or partitions_per_process <= 0: raise ValueError('num_proc={} and partitions_per_process={} must both be > 0' .format(num_processes, partitions_per_process)) if not label_columns: raise ValueError('Parameter label_columns cannot be None or empty') num_partitions = num_processes * partitions_per_process if verbose: print('num_partitions={}'.format(num_partitions)) for col in label_columns: if col not in df.columns: raise ValueError('Label column {} does not exist in the DataFrame'.format(col)) if feature_columns is None: feature_columns = [col for col in df.columns if col not in set(label_columns)] else: for col in feature_columns: if col not in df.columns: raise ValueError('Feature column {} does not exist in the DataFrame'.format(col)) key = _training_cache.create_key(df, store, validation) with _training_cache.use_key(key): dataset_idx = _get_or_create_dataset(key, store, df, feature_columns, label_columns, validation, sample_weight_col, compress_sparse, num_partitions, num_processes, verbose) yield dataset_idx def get_dataset_properties(dataset_idx): return _training_cache.get_dataset_properties(dataset_idx) def clear_training_cache(): _training_cache.clear() def to_list(var, length): if var is None: return None if not isinstance(var, list): var = [var] # If var has only one element, pad it to match the given length. if len(var) == 1: var = [var[0] for _ in range(length)] else: count = len(var) if count != length: raise ValueError(f'List {var} must be length {length} (found: {count})') return var def _get_assigned_gpu_or_local_rank(local_rank): from horovod.spark.task import get_available_devices available_devices = get_available_devices() if available_devices and os.getenv('HOROVOD_SPARK_USE_LOCAL_RANK_GPU_INDEX', '0') == '0': # if GPU-aware scheduling is available, pin to the assigned GPU index # this is always the first GPU index in available_devices as only one GPU is expected to be assigned return int(available_devices[0]) else: # pin to local rank GPU index return local_rank def is_databricks(): return "DATABRICKS_RUNTIME_VERSION" in os.environ def get_output_cols(input_df_schema, output_cols): final_output_cols = [f.name for f in input_df_schema.fields] return final_output_cols + output_cols def get_spark_df_output_schema(input_df_schema, label_cols, output_cols, metadata): if len(label_cols) != len(output_cols): raise Exception('Number of output columns must equal to number of label columns.') import copy from pyspark.sql.types import StructType, StructField from pyspark.ml.linalg import VectorUDT ## output col name should not be the same as any data in existing df input_df_col_set = {f.name for f in input_df_schema.fields} for col in output_cols: if col in input_df_col_set: raise ValueError("Output col '{}' exists in original df schema: {}".format(col, input_df_col_set)) # assuming the label_cols and output_cols are 1:1 mapping. output_fields = copy.deepcopy(input_df_schema.fields) for label_col, output_col in zip(label_cols, output_cols): col_type = metadata[label_col]['spark_data_type'] if col_type == DenseVector or col_type == SparseVector: col_type = VectorUDT field = StructField(output_col, col_type(), True) output_fields.append(field) # Final output schema contains both input and output fields output_schema = StructType(output_fields) return output_schema def _set_mp_start_method(method, verbose): import multiprocessing as mp supported_values = mp.get_all_start_methods() if method not in supported_values: raise ValueError(f"Multiprocessing start method {method} is not supported on this system. " f"Supported values are : {supported_values}") mp.set_start_method(method) if verbose: print(f'Multiprocessing start method has been set to {mp.get_start_method()}')