Source code for oarphpy.util.misc

# Copyright 2023 Maintainers of OarphPy
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import os
import pickle
import shutil
import subprocess
import sys
import tempfile
import threading
import time

from contextlib import contextmanager


def np_truthy(v):
  import numpy as np
  if isinstance(v, np.ndarray):
    return bool(v.size)
  else:
    return bool(v)


try:
  import six
  FIXED_SIZE_TYPES = tuple(itertools.chain.from_iterable(
        (six.integer_types, (float,))))
  INTEGRAL_TYPES = tuple(itertools.chain.from_iterable(
        (six.string_types, six.class_types,
          (bytes, bytearray), FIXED_SIZE_TYPES)))
except Exception as e:
  FIXED_SIZE_TYPES = tuple()
  INTEGRAL_TYPES = tuple()

def get_size_of_deep(v):
  """(Hacky) Get size of the value `v` in bytes.  Does not rely on a more
  precise library like guppy or pympler.  Intended for values `v` that
  contain large binary blobs."""
  # NB: requires `six` module!
  if isinstance(v, INTEGRAL_TYPES):
    # These types can trigger expensive recursion unless we base case them
    return sys.getsizeof(v)
  elif hasattr(v, 'nbytes'):
    return v.nbytes
  elif hasattr(v, 'items'):
    # Typically a dict
    return sum(
      get_size_of_deep(key) + get_size_of_deep(value)
      for key, value in v.items())
  elif hasattr(v, '__len__') and hasattr(v, '__getitem__'):
    # Typically a list or tuple
    if len(v) == 0:
      return 0
    if isinstance(v[0], FIXED_SIZE_TYPES):
      return len(v) * sys.getsizeof(v[0])
    else:
      return sum(get_size_of_deep(v[i]) for i in range(len(v)))
  elif hasattr(v, '__next__') or (
      sys.version_info[0] == 2 and hasattr(v, 'next')):
    # Don't consume generators
    return sys.getsizeof(v)
  elif hasattr(v, '__iter__'):
    # Some other sequence type, but NOT a generator (see above)
    return sum(get_size_of_deep(el) for el in iter(v))
  elif hasattr(v, '__dict__'):
    return sum(
      get_size_of_deep(dk) + get_size_of_deep(dv)
      for dk, dv in v.__dict__.items())
  elif hasattr(v, '__slots__'):
    return sum(get_size_of_deep(getattr(v, k)) for k in v.__slots__)
  else:
    return sys.getsizeof(v)
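
# Hypothetical usage sketch (not part of the oarphpy API): sizing a dict that
# holds a large numpy blob.  `get_size_of_deep` counts the array via its
# `nbytes` and falls back to `sys.getsizeof` for the keys; requires the `six`
# module to be installed.  The `_example_*` helper name is ours.
def _example_get_size_of_deep():
  import numpy as np
  blob = {'img': np.zeros((10, 10), dtype=np.uint8), 'label': 'cat'}
  # The array alone contributes 100 bytes; the keys and string add more.
  assert get_size_of_deep(blob) >= 100
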
def stable_hash(x):
  """A hash of `x` that is stable across program runs.

  Background:
  As of Python 3, `hash()` is given a fresh seed every time the interpreter
  starts; hash codes are not stable without setting the env var
  `PYTHONHASHSEED`.

  Can we simply adjust for the seed programmatically?  Note that while it is
  possible to get the hash seed at runtime:
   * https://stackoverflow.com/questions/41088635/extract-hash-seed-in-unit-testing
  Python doesn't use the seed in an easily-inverted way:
   * https://github.com/python/cpython/blob/630c8df5cf126594f8c1c4579c1888ca80a29d59/Python/pyhash.c#L237

  Thus for stability (and even light portability), we leverage Python
  serialization to provide a key for `x`.
  """
  try:
    import cloudpickle as pkl
  except ImportError:
    try:
      # Cloudpickle embedded starting in Spark 3.x
      import pyspark.cloudpickle as pkl
    except ImportError:
      import pickle as pkl

  # NB: Python v3.6.8 defaults to 3 but v3.8.2 defaults to 4
  PICKLE_PROTOCOL = 3 if sys.version_info[0] == 3 else 2
  key = pkl.dumps(x, protocol=PICKLE_PROTOCOL)

  import hashlib
  return int(hashlib.md5(key).hexdigest(), 16)
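
# Hypothetical usage sketch (not part of the oarphpy API): equal values get
# equal hashes within and across runs, unlike the builtin `hash()` on str.
def _example_stable_hash():
  assert stable_hash('moof') == stable_hash('moof')
  assert isinstance(stable_hash(('a', 1, 2.0)), int)
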
def ichunked(seq, n):
  """Generate chunks of `seq` of size (at most) `n`.  More efficient
  and less junk than itertools recipes version using izip_longest...
  """
  n = max(1, n)
  seq = iter(seq)
  while True:
    chunk = tuple(itertools.islice(seq, n))
    if chunk:
      yield chunk
    else:
      break
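
# Hypothetical usage sketch (not part of the oarphpy API): batching an
# iterable into chunks of 3; the final chunk may be short.
def _example_ichunked():
  assert list(ichunked(range(7), 3)) == [(0, 1, 2), (3, 4, 5), (6,)]
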
def roundrobin(*seqs):
  """Generate a sequence pulling round-robin from each of `seqs`; similar to
  `itertools.roundrobin()` recipe but
    (1) won't hide nested `StopIteration`s
    (2) uses a queue to reduce dynamic allocations
  """
  import six
  from collections import deque
  its = deque((iter(s) for s in seqs), maxlen=len(seqs))
  while its:
    it = its.popleft()
    try:
      v = six.next(it)
    except StopIteration:
      continue
    yield v
    its.append(it)
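
# Hypothetical usage sketch (not part of the oarphpy API): interleaving
# sequences of unequal length; exhausted sequences simply drop out.
def _example_roundrobin():
  assert list(roundrobin('ABC', 'D', 'EF')) == ['A', 'D', 'E', 'B', 'F', 'C']
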
def as_row_of_constants(inst):
  """Row-ify an object instance `inst` keeping only the "class-constant"
  attributes of `inst`, i.e. the members with UPPERCASE names.

  >>> class Foo(object):
  ...   CONST = 5
  ...   def __init__(self, x=0):
  ...     self.x = x

  >>> as_row_of_constants(Foo())
  OrderedDict([('CONST', 5)])

  """
  from collections import OrderedDict
  row = OrderedDict()

  def is_constant_field(name):
    return not name.startswith('_') and name.isupper()

  for attr in sorted(dir(inst)):
    if is_constant_field(attr):
      v = getattr(inst, attr)
      if isinstance(v, (str, float, int, list, dict)):
        row[attr] = v
      else:
        subrow = as_row_of_constants(v)
        if subrow:
          if hasattr(v, '__name__'):
            row[attr] = v.__name__
          else:
            row[attr] = v.__class__.__name__
          for col, colval in subrow.items():
            row[attr + '_' + col] = colval
  return row
def fname_timestamp(random_suffix=True):
  timestr = time.strftime("%Y-%m-%d-%H_%M_%S")
  if random_suffix:
    # Ideally we use a UUID but idk
    # https://stackoverflow.com/a/2257449
    import random
    import string
    NUM_CHARS = 5
    chars = (
      random.choice(string.ascii_uppercase + string.digits)
      for _ in range(NUM_CHARS)
    )
    timestr = timestr + "." + ''.join(chars)
  return timestr
class Proxy(object):
  """A thin wrapper around an `instance` that supports custom semantics."""

  __slots__ = ('instance',)

  def __init__(self, instance):
    self.instance = instance

  def __getattr__(self, name):
    return getattr(self.instance, name)

  def _on_delete(self):
    pass

  def __del__(self):
    self._on_delete()
    del self.instance
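
# Hypothetical usage sketch (not part of the oarphpy API): a Proxy subclass
# that closes its wrapped resource on garbage collection.  The `_Closing`
# name is ours, chosen only for this sketch.
def _example_proxy():
  import io

  class _Closing(Proxy):
    def _on_delete(self):
      self.instance.close()

  buf = _Closing(io.StringIO('moof'))
  assert buf.read() == 'moof'  # Attribute access forwards to the instance
  del buf                      # Triggers _on_delete(), which closes the file
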
@contextmanager
def quiet():
  """Silence stdout and stderr for any commands in this context"""
  old_stdout = sys.stdout
  old_stderr = sys.stderr
  f = open(os.devnull, 'w')

  # python3, pytest, and docker can combine to create an
  # odd 'ValueError: underlying buffer has been detached'
  # error unless we detach() now.
  import codecs
  f = codecs.getwriter("utf-8")(f.detach())

  new_stdout = sys.stdout = f
  new_stderr = sys.stderr = f
  try:
    yield new_stdout, new_stderr
  finally:
    try:
      new_stdout.seek(0)
      new_stderr.seek(0)
    except Exception:
      pass
    sys.stdout = old_stdout
    sys.stderr = old_stderr
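
# Hypothetical usage sketch (not part of the oarphpy API): suppressing chatty
# output from calls made inside the context.
def _example_quiet():
  with quiet():
    print("this goes to /dev/null")
  print("stdout and stderr are restored here")
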
@contextmanager
def with_cwd(path):
  """Use a current working directory of `path` for this context"""
  old_cwd = os.getcwd()
  os.chdir(path)
  try:
    yield
  finally:
    os.chdir(old_cwd)
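
# Hypothetical usage sketch (not part of the oarphpy API): temporarily
# running inside a scratch directory.
def _example_with_cwd():
  with tempfile.TemporaryDirectory() as tmp:
    with with_cwd(tmp):
      assert os.path.realpath(os.getcwd()) == os.path.realpath(tmp)
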
@contextmanager
def imageio_ignore_warnings():
  # Imageio needs some fix: https://github.com/imageio/imageio/issues/376
  import imageio.core.util

  def silence_imageio_warning(*args, **kwargs):
    pass

  old = imageio.core.util._precision_warn
  imageio.core.util._precision_warn = silence_imageio_warning
  try:
    yield
  finally:
    imageio.core.util._precision_warn = old
def to_png_bytes(arr):
  """Typically used for testing; when comparing images, we need to compare
  actual and expected via image bytes b/c imageio does some sort of subtle
  color normalization and we want our fixtures to simply be user-readable
  PNGs."""
  import io
  import imageio
  buf = io.BytesIO()
  imageio.imwrite(buf, arr, 'png')
  return buf.getvalue()
def to_jpeg_bytes(arr, quality=100):
  """Given a numpy array image `arr`, return the image encoded as a
  jpeg buffer."""
  import io
  import imageio
  buf = io.BytesIO()
  imageio.imwrite(buf, arr, 'jpg', quality=quality)
  return buf.getvalue()
def get_jpeg_size(jpeg_bytes):
  """Get the size of a JPEG image without reading and decompressing the
  entire file.  Based upon:
   * https://github.com/shibukawa/imagesize_py/blob/master/imagesize.py#L87
  """
  import struct
  from io import BytesIO
  buf = BytesIO(jpeg_bytes)
  head = buf.read(24)
  if not head.startswith(b'\377\330'):
    raise ValueError("Invalid JPEG header")
  buf.seek(0)
  size = 2
  ftype = 0
  while not 0xc0 <= ftype <= 0xcf or ftype in [0xc4, 0xc8, 0xcc]:
    buf.seek(size, 1)
    byte = buf.read(1)
    while ord(byte) == 0xff:
      byte = buf.read(1)
    ftype = ord(byte)
    size = struct.unpack('>H', buf.read(2))[0] - 2
  # Now we're at a SOFn block
  buf.seek(1, 1)  # Skip `precision' byte.
  height, width = struct.unpack('>HH', buf.read(4))
  return width, height
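
# Hypothetical usage sketch (not part of the oarphpy API): round-tripping a
# numpy image through `to_jpeg_bytes` and reading its dimensions back without
# decoding the pixels.  Assumes numpy and an imageio build compatible with
# the v2-style `imwrite` call above.
def _example_get_jpeg_size():
  import numpy as np
  arr = np.zeros((48, 64, 3), dtype=np.uint8)
  jpeg_bytes = to_jpeg_bytes(arr)
  assert get_jpeg_size(jpeg_bytes) == (64, 48)  # (width, height)
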
def run_cmd(cmd, collect=False, nolog=False):
  dolog = not nolog
  cmd = cmd.replace('\n', '').strip()

  if dolog:
    from oarphpy.util.log import create_log
    log = create_log()
    log.info("Running %s ..." % cmd)

  if collect:
    out = subprocess.check_output(cmd, shell=True, stderr=subprocess.STDOUT)
  else:
    subprocess.check_call(cmd, shell=True)
    out = None

  if dolog:
    log.info("... done with %s " % cmd)

  return out


def get_non_loopback_iface():
  # Based upon https://stackoverflow.com/a/1267524
  import socket
  non_loopbacks = [
    ip for ip in socket.gethostbyname_ex(socket.gethostname())[2]
    if not ip.startswith("127.")
  ]
  if non_loopbacks:
    return non_loopbacks[0]

  # Get an iface that can connect to Google DNS ...
  s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  s.connect(("8.8.8.8", 80))
  iface = s.getsockname()[0]
  s.close()
  return iface


_SYS_INFO_LOCK = threading.Lock()

def get_sys_info():
  global _SYS_INFO_LOCK

  from oarphpy.util.log import create_log
  log = create_log()

  log.info("Listing system info ...")

  info = {}
  info['filepath'] = os.path.abspath(__file__)
  info['PYTHONPATH'] = ':'.join(sys.path)

  @contextmanager
  def atomic_ignore_exceptions():
    with _SYS_INFO_LOCK:
      try:
        yield
      except Exception:
        pass

  def safe_cmd(cmd):
    with atomic_ignore_exceptions():
      return run_cmd(cmd, collect=True) or ''

  # NB: some commands, especially nvidia-smi, crash under concurrent access
  info['nvidia_smi'] = safe_cmd('nvidia-smi')
  info['cpuinfo'] = safe_cmd('cat /proc/cpuinfo')
  info['disk_free'] = safe_cmd('df -h')
  info['ifconfig'] = safe_cmd('ifconfig')
  info['memory'] = safe_cmd('free -h')

  TEST_URI = 'https://raw.githubusercontent.com/pwais/au2018/master/README.md'
  info['have_internet'] = bool(safe_cmd('curl ' + TEST_URI))

  import socket
  info['hostname'] = socket.gethostname()
  info['host'] = get_non_loopback_iface()

  import multiprocessing
  info['n_cpus'] = multiprocessing.cpu_count()

  log.info("... got all system info.")
  return info


def copy_n_from_zip(src, dest, n):
  # NB: the original flattened listing referenced `log` without defining it;
  # we import the oarphpy logger here, matching the convention in download()
  from oarphpy.util import log

  log.info("Copying %s of %s -> %s ..." % (n, src, dest))

  mkdir(os.path.split(dest)[0])

  import zipfile
  with zipfile.ZipFile(src) as zin:
    with zipfile.ZipFile(dest, mode='w') as zout:
      for name in itertools.islice(sorted(zin.namelist()), n):
        zout.writestr(name, zin.read(name))

  log.info("... done")


def mkdir(path):
  import errno
  try:
    os.makedirs(path)
  except OSError as exc:
    if exc.errno == errno.EEXIST and os.path.isdir(path):
      pass
    else:
      raise


def rm_rf(path):
  shutil.rmtree(path)


def all_files_recursive(root_dir, pattern='*'):
  import fnmatch
  paths = []
  for root, dirs, files in os.walk(root_dir):
    for basename in files:
      if fnmatch.fnmatch(basename, pattern):
        paths.append(os.path.join(root, basename))
  return paths


def cleandir(path):
  mkdir(path)
  rm_rf(path)
  mkdir(path)


def missing_or_empty(path):
  if not os.path.exists(path):
    return True
  else:
    for p in all_files_recursive(path):
      return False
    return True


def is_stupid_mac_file(path):
  fname = os.path.basename(path)
  return fname.startswith('._') or fname in ('.DS_Store',)
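
# Hypothetical usage sketch (not part of the oarphpy API): the filesystem
# helpers above compose naturally for scratch-directory workflows.
def _example_file_utils():
  with tempfile.TemporaryDirectory() as tmp:
    scratch = os.path.join(tmp, 'scratch')
    cleandir(scratch)
    assert missing_or_empty(scratch)
    with open(os.path.join(scratch, 'data.txt'), 'w') as f:
      f.write('moof')
    assert not missing_or_empty(scratch)
    assert len(all_files_recursive(scratch, pattern='*.txt')) == 1
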
def download(uri, dest, try_expand=True):
  """Fetch `uri`, which is a file or archive, and put in `dest`, which
  is either a destination file path or destination directory."""

  import math

  from oarphpy.util import log
  from oarphpy.util.thruput_observer import ThruputObserver

  # Import urllib
  try:
    import urllib.error as urlliberror
    import urllib.request
    HTTPError = urlliberror.HTTPError
    URLError = urlliberror.URLError
  except ImportError:
    import urllib2 as urllib
    HTTPError = urllib.HTTPError
    URLError = urllib.URLError
    import urllib.request

  import patoolib

  if os.path.exists(dest):
    return

  fname = os.path.split(uri)[-1]
  tempdest = tempfile.NamedTemporaryFile(suffix='_' + fname)
  try:
    log.info("Fetching %s ..." % uri)
    response = urllib.request.urlopen(uri)
    size = int(response.info().get('Content-Length').strip())
    log.info("... downloading %s MB ..." % (float(size) * 1e-6))
    chunk_size = min(size, 8192)
    t = ThruputObserver(
          name=uri, log_freq=10000, n_total=math.ceil(size / chunk_size))
    while True:
      with t.observe(n=1, num_bytes=chunk_size):
        data = response.read(chunk_size)
        if not data:
          break
        tempdest.write(data)
      t.maybe_log_progress()
    sys.stdout.write("")
    sys.stdout.flush()
    log.info("... fetched!")
  except HTTPError as e:
    raise Exception("[HTTP Error] {code}: {reason}."
                        .format(code=e.code, reason=e.reason))
  except URLError as e:
    raise Exception("[URL Error] {reason}.".format(reason=e.reason))

  tempdest.flush()

  if try_expand:
    try:
      # Is it an archive? expand!
      mkdir(dest)
      patoolib.extract_archive(tempdest.name, outdir=dest)
      log.info("Extracted archive.")
    except Exception:
      # Just move the file
      shutil.move(tempdest.name, dest)
      tempdest.delete = False
  else:
    shutil.move(tempdest.name, dest)
    tempdest.delete = False
  log.info("Downloaded to %s" % dest)
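
# Hypothetical usage sketch (not part of the oarphpy API): fetching an
# archive and expanding it into a directory.  The URI and paths below are
# made up; `download` is a no-op if `dest` already exists.
def _example_download():
  download(
    'https://example.com/fixtures/images.zip',
    '/tmp/oarphpy_fixtures',
    try_expand=True)
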
### GPU Utils

GPUS_UNRESTRICTED = None

class GPUInfo(object):
  __slots__ = (
    'index',
    'name',
    'mem_util_frac',
    'mem_free',
    'mem_used',
    'mem_total'
  )

  def __str__(self):
    data = ', '.join((k + '=' + str(getattr(self, k))) for k in self.__slots__)
    return 'GPUInfo(' + data + ')'

  def __eq__(self, other):
    return all(
      getattr(self, k) == getattr(other, k)
      for k in self.__slots__)

  @staticmethod
  def from_nvidia_smi(row):
    info = GPUInfo()
    info.index = int(row['index'])
    info.name = row['name']
    info.mem_util_frac = float(row['utilization.memory [%]']) / 100.

    def to_num_bytes(s):
      return int(s) * int(1e6)

    info.mem_free = to_num_bytes(row['memory.free [MiB]'])
    info.mem_used = to_num_bytes(row['memory.used [MiB]'])
    info.mem_total = to_num_bytes(row['memory.total [MiB]'])
    return info

  @staticmethod
  def get_infos(only_visible=True):
    from oarphpy.util import log

    # Much safer than pycuda and Tensorflow, which can both segfault if the
    # nvidia driver is absent :P
    try:
      cmd = "nvidia-smi --query-gpu=index,name,utilization.memory,name,memory.total,memory.free,memory.used --format=csv,nounits"
      out = run_cmd(cmd, collect=True)
    except Exception as e:
      log.info("No GPUs found")
      return []

    # NB: nvidia doesn't actually return *valid* csv.
    # Why would they?  They make hardware, not software!
    out = out.decode('utf-8')
    out = out.replace(', ', ',')

    import csv
    rows = list(csv.DictReader(out.split('\n')))
    infos = [GPUInfo.from_nvidia_smi(row) for row in rows]

    log.info("Found GPUs: %s" % ([str(info) for info in infos],))

    if only_visible:
      if 'CUDA_VISIBLE_DEVICES' in os.environ:
        allowed_gpus = set(
          int(g) for g in os.environ['CUDA_VISIBLE_DEVICES'].split(',')
          if g)
        log.info("... restricting to GPUs %s ..." % (allowed_gpus,))
        infos = [
          info for info in infos
          if info.index in allowed_gpus
        ]
    return infos

  @staticmethod
  def num_total_gpus():
    return len(GPUInfo.get_infos())
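
# Hypothetical usage sketch (not part of the oarphpy API): listing the GPUs
# visible to this process; returns an empty list on machines without a
# working `nvidia-smi`.
def _example_gpu_infos():
  for info in GPUInfo.get_infos(only_visible=True):
    print(info.index, info.name, info.mem_free)
  print('total GPUs:', GPUInfo.num_total_gpus())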