import os
import pickle
import sys
from collections import Counter
from contextlib import contextmanager

from oarphpy import util

### Import Spark
### We're going to expose any import and setup errors HERE, at the time of
### importing `oarphpy.spark`, because:
###  * `pyspark` requires java even to use as a client for a remote cluster;
###       if you simply import pyspark but don't have java, your job will
###       crash with an unhelpful error message.
###  * importing the `pyspark` module may require first monkeying with
###      `os.environ` / `sys.path` (e.g. through `findspark`) or else you get
###      a broken imported pyspark
### So below we try to find Spark / Java, and produce a helpful error
### message otherwise


  # In python3, we filter:
  #  " 
  #      DeprecationWarning: invalid escape sequence \*"
  if sys.version_info.major >= 3:
    import warnings
      message=r'invalid escape sequence')

  # Is pyspark on the PYTHONPATH?
    import pyspark
  except ModuleNotFoundError as e:
    # OK, can findspark find a local install of Spark / Hadoop, e.g.
    # at $SPARK_HOME or /opt/spark ?

      import findspark
    except ImportError:
      # Don't make findspark a hard requirement

    import pyspark

except Exception as e:
  msg = """
      This portion of OarphPy requires Spark, which in turn requires
      Java 8 or higher.  Mebbe try installing using:
        $ pip install pyspark
      That will fix import errors.  To get Java, try:
        $ apt-get install -y openjdk-11-jdk && echo JAVA_HOME=/usr/lib/jvm/java-11-openjdk-amd64 >> /etc/environment
      If you have spark installed locally (e.g. from source), set $SPARK_HOME
      *** Original error: %s
  """ % (e,)
  raise ImportError(msg)

### General Spark Utils

def num_executors(spark):
  """Return the number of executors currently known to `spark`.

  NOTE(review): the body of this function was missing from this copy of the
  file; the expression below is a reconstruction and should be confirmed
  against the upstream source.

  Args:
    spark: A `SparkSession`.

  Returns:
    int: The size of the JVM-side executor memory status map.
  """
  # NB: Not a public API! But likely stable.
  return spark.sparkContext._jsc.sc().getExecutorMemoryStatus().size()

def for_each_executor(spark, thunk):
  """Evaluate `thunk` at most once per executor and return the results.

  While Spark does not officially provide an API for this (since the number
  of executors can change through the course of a job), this approach
  leverages the best practices to evaluate `thunk` at most once per executor.

  Args:
    spark: A `SparkSession`.
    thunk: A zero-argument callable to run on the executors.

  Returns:
    list: One result of `thunk()` per distinct executor python process.

  Raises:
    AssertionError: If the number of distinct results does not match the
      number of executors reported by `num_executors()`.
  """

  class LazyFunc(object):
    """Caches the result of the wrapped thunk within one python process."""

    def __init__(self, th):
      self._thunk = th

    def __call__(self, _):
      if not hasattr(self, '_result'):
        # call body will be executed at most once per executor python process
        import uuid
        self._result = self._thunk()
        # A random id lets the driver de-duplicate results that came from
        # the same process.
        self._id = str(uuid.uuid4())
      return (self._id, self._result)

  sc = spark.sparkContext
  N = max(1, num_executors(spark))
  rdd = sc.parallelize(list(range(N)), numSlices=N)
  id_to_res = dict(rdd.map(LazyFunc(thunk)).collect())
  results = list(id_to_res.values())

  # Return results from only distinct executors
  assert len(results) == N
  return results
def cluster_cpu_count(spark):
  """Like `os.cpu_count()` but for an entire Spark cluster.  Useful for
  scaling memory-intensive jobs that use the RDD api.

  Args:
    spark: A `SparkSession`.

  Returns:
    int: Sum of `multiprocessing.cpu_count()` over the results returned by
      `for_each_executor()`.
  """
  def get_cpu_count():
    # Imported inside the thunk so the import happens in the process that
    # evaluates it.
    import multiprocessing
    return multiprocessing.cpu_count()

  return sum(for_each_executor(spark, get_cpu_count))
def run_callables(spark, callables, parallel=-1):
  """Run each of `callables` as a Spark task and collect the results.

  Args:
    spark: A `SparkSession`.
    callables: An iterable of zero-argument callables.
    parallel (int): Number of RDD slices to use; a value <= 0 means one
      slice per callable.

  Returns:
    list: `(callable, result)` pairs, where both members have been
      round-tripped through `cloudpickle`.
  """
  import cloudpickle
    # Spark uses regular pickle for RDD data; here we need cloudpickle in
    # order to serialize and run code.

  callable_bytess = [cloudpickle.dumps(c) for c in callables]
  if parallel <= 0:
    parallel = len(callable_bytess)

  rdd = spark.sparkContext.parallelize(callable_bytess, numSlices=parallel)

  def invoke(callable_bytes):
    import cloudpickle
    c = cloudpickle.loads(callable_bytes)
    res = c()
    # Return the serialized callable too so the driver can pair each result
    # with the callable that produced it.
    return callable_bytes, cloudpickle.dumps(res)

  rdd = rdd.map(invoke)
  all_results = [
    (cloudpickle.loads(callable_bytes), cloudpickle.loads(res))
    for callable_bytes, res in rdd.collect()
  ]
  return all_results
def union_dfs(*dfs):
  """Return the union of a sequence DataFrames and attempt to merge
  the schemas of each (i.e. union of all columns).

  Columns present in only one side are added to the other side as typed
  NULL columns before the frames are combined with `unionByName()`.

  Args:
    *dfs: Zero or more Spark DataFrames.

  Returns:
    The merged DataFrame; the single input unchanged if only one is given;
    or the (empty) input tuple if no DataFrames are given.

  Raises:
    TypeError: If a column name appears on both sides with a different
      type or a different nullability.
  """
  if not dfs:
    return dfs

  df = dfs[0]
  for df_other in dfs[1:]:
    # Imported lazily so that the trivial zero/one-DataFrame cases do not
    # need pyspark at all.
    from pyspark.sql.functions import lit

    left_types = {f.name: f.dataType for f in df.schema}
    right_types = {f.name: f.dataType for f in df_other.schema}
    left_fields = set(
      (f.name, f.dataType, f.nullable) for f in df.schema)
    right_fields = set(
      (f.name, f.dataType, f.nullable) for f in df_other.schema)

    # First go over `df`-unique fields
    for l_name, l_type, l_nullable in left_fields.difference(right_fields):
      if l_name in right_types:
        r_type = right_types[l_name]
        if l_type != r_type:
          raise TypeError(
            "Union failed. Type conflict on field %s. left type %s, right type %s" % (l_name, l_type, r_type))
        else:
          # Same name and type but a differing tuple: only nullability
          # can be the difference.
          raise TypeError(
            "Union failed. Nullability conflict on field %s. left nullable %s, right nullable %s" % (l_name, l_nullable, not(l_nullable)))
      df_other = df_other.withColumn(l_name, lit(None).cast(l_type))

    # Now go over `df_other`-unique fields
    for r_name, r_type, r_nullable in right_fields.difference(left_fields):
      if r_name in left_types:
        # Look the name up on the LEFT side; comparing against
        # `right_types` would compare the field's type with itself.
        l_type = left_types[r_name]
        if r_type != l_type:
          raise TypeError(
            "Union failed. Type conflict on field %s. right type %s, left type %s" % (r_name, r_type, l_type))
        else:
          raise TypeError(
            "Union failed. Nullability conflict on field %s. right nullable %s, left nullable %s" % (r_name, r_nullable, not(r_nullable)))
      df = df.withColumn(r_name, lit(None).cast(r_type))

    df = df.unionByName(df_other)
  return df
def get_balanced_sample(spark_df, col, n_per_category=None, seed=1337):
  """Given a column `col` in `spark_df`, return a *balanced* sample
  (countering class imbalances in `spark_df[col]`).  Optionally limit the
  sample to having up to `n_per_category` examples for every distinct
  categorical value of `spark_df[col]`.

  Args:
    spark_df: The DataFrame to sample from.
    col (str): Name of the categorical column to balance on.
    n_per_category (int): Optional cap on the target count per category.
    seed (int): Seed passed through to `DataFrame.sampleBy()`.

  Returns:
    The DataFrame returned by `spark_df.sampleBy()`.

  Raises:
    AssertionError: If `spark_df` has no categories (e.g. it is empty).
  """
  from pyspark.sql import functions as F

  category_to_count_df = spark_df.groupBy(col).agg(F.count('*'))
  category_to_count = category_to_count_df.rdd.collectAsMap()
  assert category_to_count

  # We will only sample as many as the rarest category
  numerator = min(category_to_count.values())
  if n_per_category is not None:
    numerator = min(numerator, n_per_category)

  # Per-category sampling fraction: target count over actual count.
  fractions = dict(
    (category, float(numerator) / count)
    for category, count in category_to_count.items()
  )
  return spark_df.sampleBy(col, fractions=fractions, seed=seed)
### Test Utilities (for unit tests and other debugging)
def cluster_get_info(spark):
  """Return a text report showing various details about each worker
  in the cluster.

  Args:
    spark: A `SparkSession`.

  Returns:
    str: One formatted section per result of `for_each_executor()`,
      separated by blank lines.
  """
  infos = for_each_executor(spark, lambda: util.get_sys_info())

  def format_info(info):
    # NOTE(review): the line breaks of this template were lost in this copy
    # of the file; the layout below is a reconstruction.
    s = """
      Host: {hostname} {host}
      Egg: {filepath}
      Internet connectivity: {have_internet}
      Num CPUs: {n_cpus}
      Memory:
      {memory}

      PYTHONPATH:
      {PYTHONPATH}

      nvidia-smi:
      {nvidia_smi}

      Disk:
      {disk_free}
      """.format(**info)
    # Strip the source-code indentation from every line of the report.
    return '\n'.join(l.lstrip() for l in s.split('\n'))

  info_str = '\n\n'.join(format_info(info) for info in infos)
  return info_str
def test_pi(spark):
  """Run the "textbook" Monte Carlo Pi Sampling demo and assert
  correctness.  When run on a cluster, this test can help suss out
  networking issues between the master and worker machines.

  Args:
    spark: A `SparkSession`.

  Raises:
    AssertionError: If the estimate is not within 0.1 of 3.14.
  """
  util.log.info("Running PI ...")
  sc = spark.sparkContext
  num_samples = 1000000

  def inside(p):
    # `p` is ignored; each element simply triggers one random draw.
    import random
    x, y = random.random(), random.random()
    return x*x + y*y < 1

  count = sc.parallelize(list(range(0, num_samples))).filter(inside).count()
  pi = 4 * float(count) / num_samples
  util.log.info("Pi estimate: %s" % pi)
  assert abs(pi - 3.14) < 0.1, "Spark program had an error?"
def _op_test():
  """Default payload for `test_egg()`: exercise a small piece of
  `oarphpy.util` to show that the library is importable and working."""
  from oarphpy import util
  res = list(util.ichunked([1, 2, 3], 1))
  assert res == [(1,), (2,), (3,)], res
def test_egg(spark, modname='oarphpy', test_egg_contents=_op_test):
  """Test the egg that `SessionFactory` (below) dynamically builds and
  includes in spark jobs.  We'll run the function `test_egg_contents` on
  each worker to actually execute code included in the egg.

  Args:
    spark: A `SparkSession`.
    modname (str): Name of the python module expected inside the egg.
    test_egg_contents: Zero-argument callable to run on every worker.

  Raises:
    AssertionError: If a worker cannot find or use the egg.
  """

  # SessionFactory below always creates eggs with this name
  EXPECTED_EGG_NAME = modname + '-0.0.0' + _egg_py_suffix()

  def run_test_on_worker():
    # Normally, pytest puts the local source tree on the PYTHONPATH.  That
    # setting gets inherited when Spark forks a python subprocess to run
    # this function.  Remove the source tree from the PYTHONPATH here
    # in order to force pyspark to read from the egg file / SparkFiles.
    # We may safely edit the PYTHONPATH here because this code is run in a
    # child python process that will soon exit.
    import sys
    # NOTE(review): '/opt/oparhpy' looks like a typo of '/opt/oarphpy';
    # left untouched because changing it could alter which copy of the
    # library workers import.
    if '/opt/oparhpy' in sys.path:
      sys.path.remove('/opt/oparhpy')
    if '' in sys.path:
      sys.path.remove('')

    ## Check for the egg, which Spark puts on the PYTHONPATH
    egg_path = ''
    for p in sys.path:
      if EXPECTED_EGG_NAME in p:
        egg_path = p
    assert egg_path, "Egg not found in {}".format(sys.path)

    ## Is the egg any good?
    import zipfile
    f = zipfile.ZipFile(egg_path)
    egg_contents = f.namelist()
    assert any(modname in fname for fname in egg_contents), egg_contents

    ## Use the egg!
    test_egg_contents()

    return util.get_sys_info()

  util.log.info("Testing egg ...")

  infos = for_each_executor(spark, run_test_on_worker)
  host_to_syspath = dict((info['host'], info['PYTHONPATH']) for info in infos)
  assert all(EXPECTED_EGG_NAME in p for p in host_to_syspath.values()), \
    "One or more workers missing egg %s: %s" % (
      EXPECTED_EGG_NAME, host_to_syspath)
def test_tensorflow(spark):
  """Remotely test Tensorflow support in the given `spark` cluster.

  Args:
    spark: A `SparkSession`.

  Raises:
    AssertionError: If a worker computes a wrong product.
  """

  def test_and_get_info():
    import random
    import tensorflow as tf

    tf.compat.v1.disable_v2_behavior()
      # NB: this only impacts the executor Python process

    x = int(10 * random.random())
    a = tf.constant(x)
    b = tf.constant(3)

    from oarphpy import util
    sess = util.tf_create_session()
    res = sess.run(a * b)
    assert res == 3 * x

    import socket
    info = {
      'hostname': socket.gethostname(),
      'gpu': tf.test.gpu_device_name(),
    }
    return [info]

  util.log.info("Testing Tensorflow ...")
  res = for_each_executor(spark, test_and_get_info)
  util.log.info("... Tensorflow success!  Info:")
  import pprint
  util.log.info('\n\n' + pprint.pformat(res) + '\n\n')
### Counters

try:
  from pyspark.accumulators import AccumulatorParam
  AccumulatorParam_BASE = AccumulatorParam
except Exception:
  # Best-effort fallback.  `except Exception` (rather than a bare `except`)
  # so that KeyboardInterrupt / SystemExit are not swallowed.
  class AccumulatorParam_BASE(object):
    # A non-functional dummy when pyspark is not available
    pass


class CounterAccumulator(AccumulatorParam_BASE):
  """Accumulator parameter whose values are `collections.Counter`s that
  are merged by addition."""

  def zero(self, value):
    """Return the identity element: an empty `Counter` (`value` is unused)."""
    return Counter({})

  def addInPlace(self, value1, value2):
    """Merge two `Counter`s.  Despite the name, this returns a new `Counter`
    (`value1 + value2`) and mutates neither argument."""
    return value1 + value2
def create_counter_accumulator(spark):
  """Create a Spark accumulator that holds a `Counter`, merged using
  `CounterAccumulator`."""
  sc = spark.sparkContext
  acc = sc.accumulator(Counter(), CounterAccumulator())
  return acc


class CounterCollection(object):
  """A named collection of counters backed by one accumulator.

  Plain counters are stored under their own key.  "KV tallies" are stored
  under keys of the form `__psegs_kv.<tag>.key=<key>` so that several keyed
  counts can be grouped under a single tag.
  """

  def __init__(self, spark, name=''):
    self._acc = create_counter_accumulator(spark)
    self._name = name

  def __getitem__(self, key):
    return self._acc.value[key]

  def __setitem__(self, key, value):
    # NB: item assignment *adds* `value` to the counter (it delegates to
    # `tally()`); it does not overwrite.
    self.tally(key, value)

  def tally(self, key, value):
    """Add `value` to the counter named `key`."""
    c = Counter()
    c[key] = value
    self._acc += c

  def kv_tally(self, tag, key='', value=0):
    """Add `value` to the count for `key` grouped under `tag`."""
    self.tally('__psegs_kv.' + tag + '.key=' + key, value)

  def get_kv_tally(self, tag):
    """Return a dict of key -> count for everything tallied under `tag`."""
    key_prefix = '__psegs_kv.' + tag + '.key='
    return dict(
      (k[len(key_prefix):], v)
      for k, v in self._acc.value.items()
      if k.startswith(key_prefix)
    )

  def __str__(self):
    import pprint

    HEADER = (
      "CounterCollection({name})\n"
      "Spark Accumulator: {acc}\n"
      "Counters:\n"
    ).format(
      name=self._name,
      acc=self._acc.aid,
    )

    kv_tags = set()
    kvs = []
    kv_prefix = '__psegs_kv.'
    for k, v in self._acc.value.items():
      if k.startswith(kv_prefix):
        k = k[len(kv_prefix):]
        tag = k.split('.key=')[0]
        kv_tags.add(tag)
      else:
        kvs.append((k, v))

    for tag in kv_tags:
      kvs.append((tag, self.get_kv_tally(tag)))

    # Sort on the name only: values are a mix of ints and dicts, which
    # cannot be compared with each other if two names happen to be equal.
    BODY = '\n'.join(
      "%s: %s" % (k, pprint.pformat(v))
      for k, v in sorted(kvs, key=lambda kv: kv[0]))

    return "%s\n%s\n" % (HEADER, BODY)

  def __repr__(self):
    # For things like pprint that use __repr__ instead of __str__
    return str(self)

  @contextmanager
  def log_progress(self, log_func=None, log_freq_sec=10):
    """Context manager that periodically logs this collection from a
    background thread while the `with` body runs.

    Args:
      log_func: Callable invoked with this collection; defaults to
        `util.log.info`.
      log_freq_sec: Minimum number of seconds between log calls.

    Yields:
      This `CounterCollection`.
    """
    log_func = log_func or util.log.info

    import threading
    import time
    exit_event = threading.Event()

    def spin_log():
      start_wait = time.time()
      while not exit_event.is_set():
        if time.time() - start_wait >= log_freq_sec:
          log_func(self)
          start_wait = time.time()
        # Poll twice a second, but wake immediately once asked to exit.
        exit_event.wait(0.5)

    bkg_th = threading.Thread(target=spin_log, args=())
    bkg_th.daemon = True
    bkg_th.start()

    try:
      yield self
    finally:
      # Always stop the logger thread, even when the `with` body raises.
      exit_event.set()
      bkg_th.join()


### OarphPy-specific Extras

def archive_rdd(spark, path):
  """Return an RDD of the archive flyweights that
  `util.ArchiveFileFlyweight.fws_from(path)` produces."""
  fws = util.ArchiveFileFlyweight.fws_from(path)
  return spark.sparkContext.parallelize(fws)
################################################################################ ### Spark Session Factories ### def _egg_py_suffix(): py_vers = (sys.version_info.major, sys.version_info.minor) py_suffix = '-py{}.{}.egg'.format(*py_vers) return py_suffix
class SessionFactory(object):
  """This class is a factory for pre-configured `SparkSession` instances.
  A primary feature of this factory is that it automagically includes the
  caller's python module as a Spark PyFile, thus effectively shipping the
  user's python project with the Spark job.  This class also helps
  centralize project-wide Spark configuration.

  This class is designed as programmatic replacement for the `spark-submit`
  shell script.

  To create and use a session:

  >>> from oarphpy import spark as S
  >>> spark = S.SessionFactory.getOrCreate()
  >>> S.num_executors(spark)
  1

  Or using as a context manager:

  >>> with S.SessionFactory.sess() as spark:
  ...   print(S.num_executors(spark))
  1

  See `LocalK8SSpark` below for an example of how to subclass this factory
  for your own project.  See also `NBSpark` below for an example subclass
  that enables interop with the `sparkmonitor` package for jupyter.
  """

  # Default to local Spark master
  MASTER = None

  # Default `SparkConf` instance to use for any new session
  CONF = None

  # Default set of `SparkConf` key-value settings to use for any new session,
  # e.g. {
  #   'spark.port.maxRetries': '96',
  #     # For local instances with many CPUs, let Spark use tons of ports
  #
  #   'spark.driver.memory': '8g',
  #   'spark.sql.files.maxPartitionBytes': int(8 * 1e6),
  #     # Helps aid in reading larger Parquet datasets
  #
  #   'spark.jars.packages': '<maven coordinates>',
  #     # To enable Delta
  #
  #   'spark.pyspark.python': 'python3',
  #     # To force python3
  #
  #   'spark.python.worker.reuse': False,
  #     # Useful when using Tensorflow on Spark, because TF leaks :P
  # }
  CONF_KV = {}

  # Force all sessions to package and use code in this source root directory.
  # Specify a path to the library dir, i.e. the given path should be a
  # directory containing an __init__.py file (but a setup.py is not needed).
  # If False-y, we'll try to auto-deduce this path from the calling code.
  SRC_ROOT = None

  # If you have more than one python module in `SRC_ROOT`, provide a list
  # of their names, or the list ['*'] to match all eligible modules (i.e.
  # every subdirectory with an __init__.py).
  SRC_ROOT_MODULES = []

  ### Core Features

  #### create_egg() and Support

  @classmethod
  def _resolve_src_root(cls):
    """Return the directory to egg-ify: `cls.SRC_ROOT` if set (minus any
    trailing path separator), else a directory deduced from the call stack,
    else `cls.SRC_ROOT` again if deduction fails."""
    src_root = cls.SRC_ROOT

    if src_root is None:
      try:
        import inspect
        frames = inspect.stack()
        for frame in frames:
          # Ignore frames associated with this class
          if ('oarphpy/spark.py' in frame.filename and
                hasattr(cls, frame.function)):
            continue

          # Ignore user using ::sess() context manager
          if ('contextlib.py' in frame.filename and
                frame.function == '__enter__'):
            continue

          # Ok we might have found the calling user program!
          candidate = os.path.abspath(frame.filename)
          i_am_in_a_module = os.path.exists(
            os.path.join(os.path.dirname(candidate), '__init__.py'))
          if i_am_in_a_module:
            src_root = os.path.abspath(os.path.join(candidate, os.pardir))
            break

        if not src_root:
          raise ValueError("Ran out of candidate stack frames")

      except Exception as e:
        util.log.info(
          "Failed to auto-resolve src root (error: %s) "
          "falling back to %s" % (e, cls.SRC_ROOT))
        src_root = cls.SRC_ROOT

    if src_root and src_root.endswith(os.sep):
      src_root = src_root[:-1]

    return src_root

  @classmethod
  def _create_tmp_workdir(cls):
    """Create and return a fresh temp directory that is removed at
    interpreter exit."""
    # Create a working directory for the build and the egg.  The spark context
    # dies upon process exit, so we'll keep the temp directory alive for
    # the same duration.
    import atexit
    import shutil
    import tempfile
    path = tempfile.mkdtemp(suffix='_oarphpy_eggbuild')
    # ignore_errors: the directory may already be gone at exit; cleanup
    # must not turn a clean shutdown into a traceback.
    atexit.register(lambda: shutil.rmtree(path, ignore_errors=True))
    return path

  @classmethod
  def _create_new_egg(cls, src_root, out_dir):
    """Build an egg of the module at `src_root` into `out_dir` and return
    the path to the egg file."""
    assert os.path.exists(src_root)
    assert os.path.exists(out_dir)

    MODNAME = os.path.basename(src_root)

    if sys.version_info.major >= 3:
      # For whatever reason,
      # In py 2.7.x, setuptools wants the path of the python module
      # In py 3.x, setuptools wants the directory containing the python module
      src_root = os.path.dirname(src_root)

    util.log.info("Using source root %s " % src_root)

    # Below is a programmatic way to run something like:
    # $ cd /opt/au && python setup.py clean bdist_egg
    # But we don't actually need a setup.py (!)
    from setuptools.dist import Distribution
    from setuptools import PackageFinder

    MODNAME = MODNAME.replace('-', '_') # setuptools will do it anyways

    # By default we only want MODNAME in the egg, but we'll support
    # multiple modules (e.g. both oarphpy and oarphpy_test).
    include = [MODNAME + '*']
    if cls.SRC_ROOT_MODULES == ['*']:
      include = cls.SRC_ROOT_MODULES
    elif cls.SRC_ROOT_MODULES:
      include = [m + '*' for m in cls.SRC_ROOT_MODULES]

    # We want to confine setuptools to a clean directory because it'll create
    # stateful files and directories like `build/`
    setuptools_workdir = os.path.join(out_dir, 'workdir')
    util.cleandir(setuptools_workdir)

    dist = Distribution(attrs=dict(
      script_name='setup.py',
      script_args=[
        'clean',
        'bdist_egg',
          '--dist-dir', out_dir,
          '--bdist-dir', setuptools_workdir,
      ],
      name=MODNAME,
      src_root=src_root,
      packages=PackageFinder.find(where=src_root, include=include),
    ))
    util.log.info("Generating egg to %s ..." % out_dir)
    with util.with_cwd(setuptools_workdir):
      with util.quiet():
        dist.parse_command_line()
        dist.run_commands()

    egg_path = os.path.join(out_dir, MODNAME + '-0.0.0' + _egg_py_suffix())
    assert os.path.exists(egg_path), "Can't find {}".format(egg_path)
    util.log.info("... done. Egg at %s" % egg_path)
    return egg_path

  @classmethod
  def create_egg(cls, force_new=False):
    """Build a Python Egg from the current project and return a path
    to the artifact.  The path may be to a cached, pre-computed egg only if
    not `force_new`.  The 'current project' is either class-defaulted or
    auto-deduced.  Returns '' if no source root could be resolved.

    Why an Egg?  `pyspark` supports zipfiles and egg files as Python
    artifacts.  One might wish to use a wheel instead of an egg.  The
    drawbacks to using a wheel include:

     * wheels often require native libraries to be installed (e.g. via
        `apt-get`), and those deps are typically best baked into the Spark
        Worker environment (versus installed every job run).
     * Building a wheel before every job submission can be rather slow,
        especially when Tensorflow is a dependency.
     * Spark treats wheels as zip files and unzips them on every run; this
        unzip operation can be very expensive if the zipfile contains large
        binaries (e.g. tensorflow)
     * Wheels are not yet officially supported.

    In comparison, an Egg provides the main benefits we want (to ship project
    code, often pre-committed code, to workers).
    """

    if force_new or not hasattr(cls, '_cached_egg_path'):
      cls._cached_egg_path = ''

    if not cls._cached_egg_path:
      # Lazily egg-ify `src_root`.  NB: We don't delete any previous eggs
      # if `force_new` because some Spark session might still try to read
      # the old file.

      # Decide the source root that we'll egg-ify.
      src_root = cls._resolve_src_root()

      if src_root:
        util.log.info("Using source root %s " % src_root)
        # Only allocate a build directory when there is something to build.
        out_dir = cls._create_tmp_workdir()
        cls._cached_egg_path = cls._create_new_egg(src_root, out_dir)

    return cls._cached_egg_path

  ## Primary Public Interface

  @classmethod
  def getOrCreate(cls):
    """Spark sessions are typically instantiated using the
    `Builder.getOrCreate()` or `SparkSession.getOrCreate()` methods.  This
    method is a drop-in replacement that uses class-specified defaults,
    includes the local python module as a PyFile, etc.

    Returns:
      A `pyspark.sql.session.SparkSession` instance.
    """

    # TODO: take a SparkConf like SparkSession does.

    # Warm the cache; this call surfaces build errors *before* trying
    # to start spark.
    _ = cls.create_egg()

    from pyspark import sql
    builder = sql.SparkSession.builder
    if cls.MASTER is not None:
      builder = builder.master(cls.MASTER)
    elif 'SPARK_MASTER' in os.environ:
      # spark-submit honors this env var
      builder = builder.master(os.environ['SPARK_MASTER'])
    if cls.CONF is not None:
      builder = builder.config(conf=cls.CONF)
    if cls.CONF_KV is not None:
      for k, v in cls.CONF_KV.items():
        builder = builder.config(k, v)

    try:
      spark = builder.getOrCreate()
    except Exception as e:
      # If the user just did `pip install pyspark`, then they might not
      # have java.  Try to provide a helpful error message.
      if hasattr(e, 'args') and e.args:
        # str(): the first arg is not guaranteed to be a string.
        if 'Java gateway process exited' in str(e.args[0]):
          msg = """
            This portion of OarphPy requires Spark, which in turn requires
            Java 8 or higher.  Looks like you might be missing Java.  To get it, try:
              $ apt-get update && apt-get install -y openjdk-8-jdk && echo JAVA_HOME=/usr/lib/jvm/java-8-openjdk-amd64 >> /etc/environment
            Original error: %s
          """ % (e,)
          raise Exception(msg)
      raise

    # To show info logs
    # spark.sparkContext.setLogLevel('INFO')

    egg_path = cls.create_egg()
    if egg_path:
      spark.sparkContext.addPyFile(egg_path)
    else:
      util.log.info(
        "Could not resolve a source root, skipping auto-egg inclusion.")
    return spark

  @classmethod
  @contextmanager
  def sess(cls, *args):
    """Context manager yielding a session: the first positional argument
    if one is given and truthy, else `cls.getOrCreate()`."""
    if args and args[0]:
      spark = args[0]
      yield spark
    else:
      spark = cls.getOrCreate()
      yield spark

  @classmethod
  def selftest(cls, modname=None):
    """Run `test_pi()` (and `test_egg()` if `modname` is given) against a
    session from this factory.  Returns True if no assertion fails."""
    with cls.sess() as spark:
      test_pi(spark)
      if modname:
        test_egg(spark, modname=modname)
    return True
class LocalK8SSpark(SessionFactory):
  """Example of how to subclass the Spark factory above for use with K8S"""

  # NOTE(review): the master URL appears truncated in this copy of the file
  # (nothing follows the 'k8s://' scheme) -- confirm against upstream.
  MASTER = 'k8s://'
  CONF_KV = {
    'spark.kubernetes.container.image': 'my-docker-image',
  }

  @classmethod
  def getOrCreate(cls):
    """Like `SessionFactory.getOrCreate()`, but first defaults
    `spark.driver.host` to `$SPARK_LOCAL_IP` or, failing that, to the
    address returned by `util.get_non_loopback_iface()`."""
    if 'spark.driver.host' not in cls.CONF_KV:
      # In practice, we need to set `host` explicitly in order to get the
      # proper driver IP address to the workers.  This choice may break
      # cluster mode where the driver process will run in the cluster
      # instead of locally.  This choice may also break in certain networking
      # setups.  Spark networking is a pain :(
      cls.CONF_KV['spark.driver.host'] = (
        os.environ.get('SPARK_LOCAL_IP', util.get_non_loopback_iface()))
    # NB: super() must name *this* class; naming an unrelated sibling class
    # here raises TypeError at call time.
    return super(LocalK8SSpark, cls).getOrCreate()
# NBSpark is a session builder for local Jupyter notebooks. NBSpark also serves # as an example of how to subclass the Spark factory above for use with the # `sparkmonitor` jupyter package. # (FMI see demo # Try to find sparkmonitor's Spark interop jar try: HAVE_SPARK_2 = pyspark.__version__.startswith('2') import sparkmonitor SPARKMONITOR_HOME = os.path.dirname(sparkmonitor.__file__) SPARKMONITOR_JAR_PATH = os.path.join( SPARKMONITOR_HOME, 'listener_2.11.jar' if HAVE_SPARK_2 else 'listener_2.12.jar') # E.g. /usr/local/lib/python3.6/dist-packages/sparkmonitor/listener_x.xx.jar except Exception: SPARKMONITOR_JAR_PATH = '' def _dict_concat(*dicts): out = {} for d in dicts: out.update(**d) return out
class NBSpark(SessionFactory):
  """NBSpark is a session builder for local Jupyter notebooks.  Also
  includes support for the `sparkmonitor` jupyter package.
  """

  # Enable support for shipping local modifications to library code without
  # needing to restart the notebook kernel.  This feature will attempt to
  # re-build the user code Egg if there are new local changes and update the
  # Egg in the current Spark session.  Enabling this feature leads to small
  # potential performance degradation; see below.
  # NB: If you have an RDD or DataFrame that uses Egg data structures or code,
  # you'll need to re-compute that RDD or DataFrame for updated code to take
  # effect; updating the egg does not invalidate any `cache()ed` or
  # `persist()ed` Spark data.
  MAYBE_REBUILD_EGG_EVERY_CELL_RUN = True

  # Options to support dynamic Egg updating.  Firstly, `spark.files.overwrite`
  # is needed to accommodate updates to SparkFiles at all; `pyspark` will error
  # on update otherwise.  Secondly, while updating SparkFiles works as
  # expected, updating *Python* modules can lead to 'zipimport malformed zip
  # header' crashes because:
  #  * zipimport caches loaded modules.  Clearing this private global can
  #     fix the error, but it must be cleared immediately before import,
  #     which is an undesirable constraint to impress upon the user.
  #  * pyspark clears importlib caches, but only in the driver.
  # A fix we've found reliable is to disable Python process re-use via
  # `spark.python.worker.reuse` so that the Spark Python worker always has
  # an empty zipimport cache prior to execution.
  REBUILD_EGG_OPTS = {
    'spark.files.overwrite': 'true',
    'spark.python.worker.reuse': 'false',
  }

  # These settings are required as part of standard `sparkmonitor` setup
  SPARKMONITOR_OPTS = {
    'spark.extraListeners':
      'sparkmonitor.listener.JupyterSparkMonitorListener',
    'spark.driver.extraClassPath': SPARKMONITOR_JAR_PATH,
  }

  CONF_KV = _dict_concat(
    REBUILD_EGG_OPTS if MAYBE_REBUILD_EGG_EVERY_CELL_RUN else {},
    SPARKMONITOR_OPTS if SPARKMONITOR_JAR_PATH else {}
  )

  @classmethod
  def getOrCreate(cls):
    """Like `SessionFactory.getOrCreate()`, and (if
    `MAYBE_REBUILD_EGG_EVERY_CELL_RUN`) additionally hooks into IPython so
    that the egg is rebuilt and re-shipped whenever a source file is newer
    than the current egg."""
    spark = super(NBSpark, cls).getOrCreate()

    if cls.MAYBE_REBUILD_EGG_EVERY_CELL_RUN:

      def maybe_rebuild_egg(*args):
        egg_path = cls.create_egg()
        src_root = cls._resolve_src_root()
        if not egg_path or not src_root:
          # No source root could be resolved, so there is nothing to
          # rebuild; bail out rather than stat'ing an empty path on
          # every cell run.
          return

        egg_time = os.path.getmtime(egg_path)

        latest_src_time = os.path.getmtime(src_root)
        for path in util.all_files_recursive(src_root):
          latest_src_time = max(latest_src_time, os.path.getmtime(path))

        if latest_src_time > egg_time:
          util.log.info("Source has changed! Rebuilding Egg ...")
          egg_path = cls.create_egg(force_new=True)
          spark.sparkContext.addPyFile(egg_path) # Updated!

      # Patch into IPython / Jupyter / Google Colab to maybe rebuild the egg
      # every cell run.  NB:
      #  * `get_ipython()` is a global provided in any IPython-esque session
      #  * Using pre_run_code_hook triggers a UserWarning / deprecation
      #      warning, but the suggested events don't exist in older IPython
      #  * In IPython >= 5, using `pre_run_code_hook` throws an error :(
      import IPython
      if IPython.version_info[0] >= 5:
        get_ipython().events.register('pre_execute', maybe_rebuild_egg)
      else:
        import warnings
        warnings.filterwarnings(
          action='ignore',
          message=r'Hook pre_run_code_hook is deprecated')
        get_ipython().set_hook('pre_run_code_hook', maybe_rebuild_egg)

    return spark
################################################################################
### Spark SQL Type Adaption Utils
###

# Arrays of at least this many kilobytes are stored as packed bytes instead
# of as a list of python scalars (see `Tensor.from_numpy()`).
TENSOR_AUTO_PACK_MIN_KBYTES = 2


class Tensor(object):
  """An ndarray-like object designed to store numpy arrays in Parquet /
  Spark SQL format.  Spark's DenseVector and Matrix unfortunately don't
  support arbitrary tensor shape.  Furthermore, `Tensor` stores data in an
  explicit order accessible to external readers such as Eigen in C++ or
  nd4j / BLAS wrappers in Java.

  Attributes (slots):
    shape: list of ints, the array shape.
    dtype: str, the numpy dtype name.
    order: str, memory order of the stored values ('C' = row-major).
    values: list of scalars (the data, when not packed).
    values_packed: bytearray of raw bytes (the data, when packed).
  """

  __slots__ = ('shape', 'dtype', 'order', 'values', 'values_packed')

  @staticmethod
  def from_numpy(arr):
    """Build a `Tensor` from the numpy array `arr`."""
    t = Tensor()
    t.shape = list(arr.shape)
    t.dtype = arr.dtype.name
    t.order = 'C' # C-style row-major
    if arr.nbytes >= TENSOR_AUTO_PACK_MIN_KBYTES * (2**10):
      t.values = [arr.dtype.type(0)]
        # Need a non-empty array for type deduction
      t.values_packed = bytearray(arr.tobytes(order='C'))
    else:
      t.values = arr.flatten(order='C').tolist()
      t.values_packed = bytearray()
    return t

  @staticmethod
  def to_numpy(t):
    """Rebuild a numpy array from the `Tensor` `t`."""
    import numpy as np
    if t.values_packed:
      # Packed bytes win whenever present; `values` then only carries a
      # type-deduction placeholder.
      return np.reshape(
        np.frombuffer(t.values_packed, dtype=np.dtype(t.dtype)),
        t.shape)
    else:
      return np.array(
        np.reshape(t.values, t.shape, order=t.order),
        dtype=np.dtype(t.dtype))
class CloudpickeledCallableData(object):
  """Plain slotted record holding the serialized form of a
  `CloudpickeledCallable`: the pickled bytes and a display name."""

  __slots__ = ('func_bytes', 'func_pyclass')

  def __init__(self, **kwargs):
    # Every slot is always set; slots absent from `kwargs` become None.
    for k in self.__slots__:
      setattr(self, k, kwargs.get(k))
[docs]class CloudpickeledCallable(object): """Wraps callable objects (e.g. functions, including lambdas) and uses `cloudpickle` for serialization. Spark uses `cloudpickle` for serializing _tasks_ (e.g. map functions) but uses `pickle` for serializing _data_. In particular, data in a Spark RDD or DataFrame must be pickleable. `CloudpickeledCallable` provides a wrapper so that you can embed Python functions as *data* in RDDs, DataFrames, and other forms of data at rest (e.g. pickle files or Parquet data). Note that `cloudpickle` can be selective about how much of the object tree that it serializes; some imports and globals may get ignored. When you deserialize and invoke a `CloudpickeledCallable`, your interpreter should have the same (or similar) code as that used when serializing the callable, otherwise behavior may be difficult to predict. Note that `cloudpickle` can't handle non-serializable data like thread local variables, mutices, etc. `CloudpickeledCallable` won't work for all code. `CloudpickeledCallable` is useful for embedding flyweights in your dataset. (FMI see <> ) For example: >>> def load_matrix(path): >>> import numpy as np >>> return np.loadtxt(path) >>> my_db_row = { >>> 'path': 'path/to/data.txt', >>> 'factory': >>> CloudpickeledCallable(lambda: load_matrix('path/to/data.txt')) >>> } >>> import pickle >>> pickle.dump(my_db_row, open('dumped.pkl', 'wb')) Now if you deserialize `my_db_row` from disk and run `my_db_row['factory']()`, your `load_matrix()` helper will get invoked with the embedded path. Thus your `my_db_row` is a flyweight for the data in 'path/to/data.txt'. """ __slots__ = ('_func', '_func_pyclass') # Most Python 3.x distributions don't have newer than pickle 4, so stick # to version 4 for compatibility. 
FMI: # * # * PICKLE_PROTOCOL = min(4, pickle.HIGHEST_PROTOCOL) @staticmethod def _get_func_name(func): module = '<unknown_module>' if hasattr(func, '__module__'): module = func.__module__ if hasattr(func, '__name__'): return '%s.%s' % (module, func.__name__) else: import inspect src = inspect.getsource(func) lambda_varname = inspect.getsource(func).split('=')[0].strip() return '%s.%s' % (module, lambda_varname) def __init__(self, func=None): self._func = func if func is None: self._func_pyclass = '(empty CloudpickeledCallable)' else: self._func_pyclass = CloudpickeledCallable._get_func_name(func) @classmethod def empty(cls): return cls() def __call__(self, *args, **kwargs): assert self._func is not None, \ "This CloudpickeledCallable is the null CloudpickeledCallable" return self._func(*args, **kwargs) @classmethod def to_cc_data(cls, cc): if cc == cls.empty(): func_bytes = bytearray() else: import cloudpickle func_bytes = bytearray( cloudpickle.dumps(cc._func, protocol=cls.PICKLE_PROTOCOL)) return CloudpickeledCallableData( func_bytes=func_bytes, func_pyclass=cc._func_pyclass) @classmethod def from_cc_data(cls, ccd): if len(ccd.func_bytes) > 0: import cloudpickle func = cloudpickle.loads(ccd.func_bytes) else: func = None cc = cls(func=func) cc._func_pyclass = ccd.func_pyclass return cc def __getstate__(self): return (self.to_cc_data(self),) def __setstate__(self, d): cc = self.from_cc_data(d[0]) self._func = cc._func self._func_pyclass = cc._func_pyclass def __eq__(self, other): return self._func == other._func def __repr__(self): return "CloudpickeledCallable(_func_pyclass=%s)" % self._func_pyclass
[docs]class RowAdapter(object): """Transforms between custom objects and `pyspark.sql.Row`s used in Spark SQL or Parquet files. Use to encode numpy arrays and standard Python objects with a transparent Parquet schema that is accessible to other readers. Usage: * Use `RowAdapter.to_row()` in place of the `pyspark.sql.Row` constructor. * Call `RowAdapter.from_row()` on any `pyspark.sql.Row` instance, e.g. within an `` call or after a `DataFrame.collect()` call. * Decoding requires Python objects to have an available zero-arg __init__() Unfortunately, we can't use Spark's UDT API to embed this adapter (and obviate user calls) because UDTs require schema definitions. Furthermore, Spark <=2.x could not handle UDTs nested in maps or lists; i.e. [UDT()] (i.e. a list of UDTs) and {'foo': UDT()} (i.e. a map with UDT values) would cause Spark to crash. Moreover, for ndarray data, Spark's ml.linalg package coerces all data to floats. Benefits of RowAdapter: * Transparently handles numpy arrays and numpy boxed scalar types (e.g. np.float32). * Deep type adaptation; supports nested types. * At the decode stage, supports evolution of object types independent of the schema of data at rest: - Added object fields don't get set unless there's a recorded value - Removed object fields will get ignored - NB: Fields that change type will get set with the data at rest; if you need to change type, consider adding a new field. * Handles slotted Python objects (which Spark currently does not support), as well as un-slotted objects (where Spark supports automatic encoding but not decoding). * Enables saving objects and numpy arrays to Parquet in a format accessible to external systems (no additional SERDES library required) * Uses cloudpickle to serialize `CloudpickeledCallable`-wrapped functions. 
""" IGNORE_PROTECTED = False IGNORE_PRIVATE = True @staticmethod def _get_classname_from_obj(o): # Based upon module = o.__class__.__module__ # NB: __module__ might be null if module is None or module == str.__class__.__module__: return o.__class__.__name__ # skip "__builtin__" else: return module + '.' + o.__class__.__name__ @staticmethod def _get_class_from_path(path): # Pydoc is a bit safer and more robust than anything we can write import pydoc obj_cls, obj_name = pydoc.resolve(path) assert obj_cls return obj_cls @classmethod def to_schema(cls, obj): row = cls.to_row(obj) import pyspark.sql.types as pst return pst._infer_schema(row) @classmethod def to_row(cls, obj): from pyspark.sql import Row import numpy as np if isinstance(obj, Row): # Row is immutable, so we have to recreate row_dict = obj.asDict() return Row(**cls.to_row(row_dict)) elif isinstance(obj, np.ndarray): return cls.to_row(Tensor.from_numpy(obj)) elif isinstance(obj, CloudpickeledCallable): return cls.to_row(CloudpickeledCallable.to_cc_data(obj)) elif isinstance(obj, np.generic): # Those pesky boxed scalars like np.float32 return obj.item() elif hasattr(obj, '__slots__') or hasattr(obj, '__dict__'): def is_hidden(fname): # Check private first to disambiguate `_` vs `__` prefixes if cls.IGNORE_PRIVATE and fname.startswith('__'): return True if cls.IGNORE_PROTECTED and fname.startswith('_'): return True return False tag = ('__pyclass__', RowAdapter._get_classname_from_obj(obj)) if hasattr(obj, '__slots__'): obj_attrs = [ (k, cls.to_row(getattr(obj, k))) for k in obj.__slots__ if not is_hidden(k) ] else: obj_attrs = [ (k, cls.to_row(v)) for k, v in obj.__dict__.items() if not is_hidden(k) ] return Row(**dict([tag] + obj_attrs)) elif isinstance(obj, list): return [cls.to_row(x) for x in obj] elif isinstance(obj, tuple): return tuple(cls.to_row(x) for x in obj) # Spark will typically transform to list elif isinstance(obj, dict): return dict((k, cls.to_row(v)) for k, v in obj.items()) else: return obj 
@classmethod
def from_row(cls, row):
  """Recursively rebuild Python objects from `row`.

  Args:
    row: Anything.  An object exposing `__fields__` is treated as a pyspark
      `Row`; lists and dicts are converted element-wise (dict keys are left
      untouched); any other value is returned unchanged.

  Returns:
    * For a row with a `__pyclass__` field: an instance of the class named
        by that field, created with `__new__` (so `__init__` is NOT run)
        and populated from the row's fields.  A `Tensor` result is then
        converted with `Tensor.to_numpy()` and a `CloudpickeledCallableData`
        result with `CloudpickeledCallable.from_cc_data()`.
    * For a row without `__pyclass__`: a new `Row` whose values have been
        converted recursively.

  Raises:
    ValueError: if the class named by `__pyclass__` has neither
      `__slots__` nor `__dict__`.
  """
  if hasattr(row, '__fields__'):
    # Probably a pyspark Row instance; try to convert it to an object
    if '__pyclass__' in row.__fields__:
      obj_cls_name = row['__pyclass__']
      obj_cls = RowAdapter._get_class_from_path(obj_cls_name)
      # Allocate without running __init__; attributes are restored below
      obj = obj_cls.__new__(obj_cls)

      attr_to_default = {}
      if hasattr(obj, '__attrs_attrs__'):
        # In the case that an attrs-based class has now added an attribute
        # since being row-ified, the row will be missing a value for that
        # new attribute.  Use the attrs-specified default if available.
        # (Only the `__slots__` branch below consults these defaults.)
        attr_to_default = dict(
          (a.name, a.default) for a in obj.__attrs_attrs__)

      if hasattr(obj, '__slots__'):
        for k in obj.__slots__:
          if k in row:
            setattr(obj, k, cls.from_row(row[k]))
          elif k in attr_to_default:
            setattr(obj, k, attr_to_default[k])
      elif hasattr(obj, '__dict__'):
        for k, v in row.asDict().items():
          if k == '__pyclass__':
            continue
          obj.__dict__[k] = cls.from_row(v)
      else:
        raise ValueError(
          "Object %s no longer has __slots__ nor __dict__" % obj_cls_name)

      # Since __init__ was bypassed, run the attrs post-init hook ourselves.
      # Guard on the hook itself: testing for `__attrs_init__` (as this code
      # previously did) is wrong because that attribute neither implies nor
      # is implied by the presence of `__attrs_post_init__`.
      post_init = getattr(obj, '__attrs_post_init__', None)
      if post_init is not None:
        post_init()

      if isinstance(obj, Tensor):
        obj = Tensor.to_numpy(obj)
      elif isinstance(obj, CloudpickeledCallableData):
        obj = CloudpickeledCallable.from_cc_data(obj)
      return obj
    else:
      # No known __pyclass__, so fall back to generic
      from pyspark.sql import Row
      attrs = dict((k, cls.from_row(v)) for k, v in row.asDict().items())
      return Row(**attrs)
  elif isinstance(row, list):
    return [cls.from_row(x) for x in row]
  elif isinstance(row, dict):
    return dict((k, cls.from_row(v)) for k, v in row.items())
  return row
################################################################################ ### Spark-Tensorflow Utils ###
def spark_df_to_tf_dataset(
      spark_df,
      shard_col,
      spark_row_to_tf_element, # E.g. lambda r: (np.array[0],),
      tf_element_types, # E.g. [tf.int64]
      tf_output_shapes=None,
      non_deterministic_element_order=True,
      num_reader_threads=-1,
      logging_name='spark_tf_dataset'):
  """Create a `tf.data.Dataset` that reads from the Spark Dataframe
  `spark_df`.  Executes parallel reads using the Tensorflow's internal
  (native code) threadpool.  Each thread reads a single Spark partition
  at a time.

  This utility is similar to Petastorm's `make_reader()` but is far simpler
  and leverages Tensorflow's built-in threadpool (so we let Tensorflow do
  the read scheduling).

  Status: alpha-quality; some perf quirks.

  Args:
    spark_df (pyspark.sql.DataFrame): Read from this Dataframe.
    shard_col (str): Implicitly shard the dataset using this column; read
      one shard per reader thread at a time to conserve memory.
    spark_row_to_tf_element (func): Use this function to map each
      pyspark.sql.Row in `spark_df` to a tuple that represents a single
      element of the induced TF Dataset.
    tf_element_types (tuple): The types of the elements that
      `spark_row_to_tf_element` returns; e.g. (tf.float32, tf.string).
    tf_output_shapes (tuple): Optionally specify the shape of the output
      of `spark_row_to_tf_element`; e.g.
      (tf.TensorShape([]), tf.TensorShape([None])) (where the former return
      element is a single scalar and the latter is a list)
    non_deterministic_element_order (bool): Allow the resulting
      `tf.data.Dataset` to have elements in non-deterministic order for
      speed gains.  NB: only the commented-out alternative at the bottom of
      this function references this flag; the active `interleave()` call
      does not consult it.
    num_reader_threads (int): Tell Tensorflow to use this many reader
      threads; any value less than 1 provisions one reader thread per two
      CPU cores (and at least one thread).
    logging_name (str): Log progress under this name.

  Returns:
    The induced TF Dataset with one element per row in `spark_df`.
  """

  # NOTE(review): this listing was recovered from a rendering that dropped
  # several expressions.  The `util.log.info(...)` calls, the shard-id query
  # assigned to `pids`, `tf.data.Dataset.from_tensor_slices(pids)`, the
  # `rows` iterator and `tf.data.Dataset.from_generator(...)` below are
  # reconstructions -- confirm them against the upstream source.

  if num_reader_threads < 1:
    import multiprocessing
    num_reader_threads = max(1, int(0.5 * multiprocessing.cpu_count()))
      # NB: Tensorflow prefetch appears to launch twice as many threads
      # as you ask for :P

  df = spark_df

  # Each Tensorflow reader thread will read a single Spark partition
  # Breadcrumbs: We tried using the built-in spark partition id, but
  # Spark appears to want to re-compute this for every partition read,
  # leading to an O(n^2) overhead cost (for `n` partitions).
  # df = spark_df.withColumn('shard_col', spark_partition_id())

  util.log.info("Getting shards ...")
  pids = df.select(shard_col).distinct().rdd.flatMap(lambda x: x).collect()
  util.log.info("... found %s shards." % (len(pids),))

  import threading
  import tensorflow as tf
  pid_ds = tf.data.Dataset.from_tensor_slices(pids)

  class PartitionToRows(object):
    """Generator callable: given a shard id, yield that shard's elements
    while recording per-shard and overall throughput."""

    def __init__(self):
      self.overall_thruput = util.ThruputObserver(
        name=logging_name, log_on_del=True, n_total=df.count())
          # TODO: count() can be slow
      self.overall_thruput.start_block()
      self.lock = threading.Lock()

    def __call__(self, pid):
      # Convert pesky numpy boxed numeric types if needed
      import numpy as np
      if isinstance(pid, np.generic):
        pid = pid.item()

      part_df = df.filter(df[shard_col] == pid)
      part_rdd = part_df.rdd.repartition(100)
      rows = part_rdd.map(spark_row_to_tf_element).toLocalIterator()
      util.log.info("Reading partition %s " % pid)
      t = util.ThruputObserver(name='Partition %s' % pid, log_on_del=True)
      t.start_block()
      for row in rows:
        yield row
        t.update_tallies(n=1, num_bytes=util.get_size_of_deep(row))
      t.stop_block()
      util.log.info("Done reading partition %s, stats:\n %s" % (pid, t))
      with self.lock:
        # Since partitions are read in parallel, we need to maintain
        # independent timing stats for the main thread
        self.overall_thruput.stop_block(n=t.n, num_bytes=t.num_bytes)
        self.overall_thruput.maybe_log_progress(every_n=1)
        self.overall_thruput.start_block()

  ds = pid_ds.interleave(
    lambda pid_t: \
      tf.data.Dataset.from_generator(
        PartitionToRows(),
        args=(pid_t,),
        output_types=tf_element_types,
        output_shapes=tf_output_shapes),
    cycle_length=num_reader_threads,
    num_parallel_calls=num_reader_threads)

  # Breadcrumbs: alternative TF API
  # ds = pid_ds.apply(
  #   tf.data.experimental.parallel_interleave(
  #     lambda pid_t:
  #       tf.data.Dataset.from_generator(
  #         get_rows,
  #         args=(pid_t,),
  #         output_types=tf_element_types),
  #     cycle_length=num_reader_threads,
  #     sloppy=non_deterministic_element_order))
  return ds