Source code for oarphpy.util.thruput_observer

# Copyright 2023 Maintainers of OarphPy
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
from contextlib import contextmanager


class ThruputObserver(object):
  """A utility for measuring the runtime and throughput of a subroutine.

  Similar in spirit to `tqdm`, except `ThruputObserver`:
   * Tracks not just time but a size metric (e.g. memory) in bytes
   * Reports percentiles
   * Simply logs strings and is not terminal-interactive

  While `tqdm` is useful for notebooks, `ThruputObserver` seeks to be more
  useful for longer-running batch jobs.
  """

  def __init__(
      self,
      name='',
      log_on_del=False,
      only_stats=None,
      log_freq=100,
      n_total=None,
      n_total_chunks=None):
    self.n = 0
    self.num_bytes = 0
    self.ts = []
    self.name = name
    self.log_on_del = log_on_del
    self.only_stats = only_stats or []
    self.n_total = max(n_total, 1) if n_total is not None else None
    self.n_total_chunks = (
      max(n_total_chunks, 1) if n_total_chunks is not None else None)
    self._start = None
    self.__log_freq = log_freq
    self.__last_log = 0
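  # Construction sketch (illustrative, not part of the original module):
  # passing `n_total` enables the Percent Complete / ETA rows in the stats
  # report, and `only_stats` filters the report down to the named rows.
  #
  #   obs = ThruputObserver(
  #     name='file_ingest',
  #     n_total=10000,                  # Expected total items; enables ETA
  #     only_stats=('N thru', 'Hz'))    # Report only these stat rows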
  @contextmanager
  def observe(self, n=0, num_bytes=0):
    """
    NB: contextmanagers appear to be expensive due to object creation.
    Use ThruputObserver#{start,stop}_block() for <10ms ops.
    FMI https://stackoverflow.com/questions/34872535/why-contextmanager-is-slow
    """
    self.start_block()
    yield self
    self.stop_block(n=n, num_bytes=num_bytes)
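  # Usage sketch for `observe` (illustrative; `chunks` and `process` are
  # hypothetical stand-ins): each `with` block is timed as one chunk.
  #
  #   obs = ThruputObserver(name='chunk_processor', log_freq=100)
  #   for chunk in chunks:
  #     with obs.observe(n=1, num_bytes=len(chunk)):
  #       process(chunk)
  #   print(obs)  # Tabulated throughput, rate, and latency percentiles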
  def start_block(self):
    self._start = time.time()

  def update_tallies(self, n=0, num_bytes=0, new_block=False):
    self.n += n
    self.num_bytes += num_bytes
    if new_block:
      self.stop_block()
      self.start_block()

  def stop_block(self, n=0, num_bytes=0):
    end = time.time()
    self.n += n
    self.num_bytes += num_bytes
    if self._start is not None:
      self.ts.append(end - self._start)
    self._start = None

  def maybe_log_progress(self, every_n=-1):
    if every_n >= 0:
      self.__log_freq = every_n
    if self.n >= self.__last_log + self.__log_freq:
      from oarphpy.util import log
      log.info("Progress for \n" + str(self))

      # Track last log because `n` may increase inconsistently
      self.__last_log = self.n

      if every_n == -1 and (self.n >= (1.7 * self.__log_freq)):
        # Exponentially decay logging frequency. Don't decay quite as fast
        # as Vowpal Wabbit did, though.
        self.__log_freq = int(1.7 * self.__log_freq)
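  # Usage sketch for the lower-overhead manual API (illustrative; `fetch`
  # and `urls` are hypothetical stand-ins). Per the note in `observe`,
  # prefer start_block()/stop_block() for very short (<10ms) operations:
  #
  #   obs = ThruputObserver(name='fetcher', log_freq=1000)
  #   for url in urls:
  #     obs.start_block()
  #     data = fetch(url)
  #     obs.stop_block(n=1, num_bytes=len(data))
  #     obs.maybe_log_progress()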
  @classmethod
  def union(cls, thruputs):
    """Support reduction for use in e.g. MapReduce jobs as a counter"""
    u = cls()
    for t in thruputs:
      u += t
    return u
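  # Reduction sketch (illustrative; `run_workers` is a hypothetical
  # stand-in): `union` combines per-worker observers, e.g. collected from a
  # map step, into one aggregate. This works because __iadd__ (below)
  # merges tallies and timing samples.
  #
  #   worker_obs = run_workers()  # One ThruputObserver per worker
  #   total = ThruputObserver.union(worker_obs)
  #   print(total)  # Stats aggregated across all workers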
  @property
  def total_time(self):
    return sum(self.ts)

  def get_stats(self):
    import numpy as np
    from humanfriendly import format_size
    from humanfriendly import format_timespan

    total_time = self.total_time
    stats = [
      ('Thruput', ''),
      ('N thru',
        (self.n
          if self.n_total is None
          else '%s (of %s)' % (self.n, self.n_total))),
      ('N chunks',
        (len(self.ts)
          if self.n_total_chunks is None
          else '%s (of %s)' % (len(self.ts), self.n_total_chunks))),
      ('Total time', format_timespan(total_time) if total_time else '-'),
      ('Total thru', format_size(self.num_bytes)),
      ('Rate',
        format_size(self.num_bytes / total_time) + ' / sec'
          if total_time else '-'),
      ('Hz',
        "%.2f" % (float(self.n) / total_time) if total_time else '-'),
    ]

    percent_complete = None
    if self.n_total is not None:
      percent_complete = 100. * float(self.n) / self.n_total
    elif self.n_total_chunks is not None:
      percent_complete = 100. * float(len(self.ts)) / self.n_total_chunks
    if percent_complete is not None:
      eta_sec = (
        (100. - percent_complete) *
          (total_time / (percent_complete + 1e-10)))
      stats.extend([
        ('Progress', ''),
        ('Percent Complete', "%.2f" % percent_complete),
        ('Est. Time To Completion', format_timespan(eta_sec)),
      ])

    if len(self.ts) >= 2:
      format_t = lambda t: format_timespan(t, detailed=True)
      stats.extend([
        ('Latency (per chunk)', ''),
        ('Avg', format_t(np.mean(self.ts))),
        ('p50', format_t(np.percentile(self.ts, 50))),
        ('p95', format_t(np.percentile(self.ts, 95))),
        ('p99', format_t(np.percentile(self.ts, 99))),
      ])

    if self.only_stats:
      stats = tuple(
        (name, value)
        for name, value in stats
        if name in self.only_stats
      )
    return stats

  def __iadd__(self, other):
    self.n += other.n
    self.num_bytes += other.num_bytes
    self.ts.extend(other.ts)
    if not self.name:
      self.name = other.name
    if self.n_total is None and other.n_total is not None:
      self.n_total = other.n_total
    if self.n_total_chunks is None and other.n_total_chunks is not None:
      self.n_total_chunks = other.n_total_chunks
    return self

  def __str__(self):
    import tabulate
    stats = self.get_stats()
    summary = tabulate.tabulate(stats)
    prefix = '[Pid:%s Id:%s]' % (os.getpid(), id(self))
    if self.name:
      prefix = '%s %s' % (self.name, prefix)
    summary = '%s\n%s' % (prefix, summary)
    return summary

  def __repr__(self):
    # pprint and some other utils use __repr__ instead of __str__
    return str(self)

  def __del__(self):
    if self.log_on_del:
      self.stop_block()
      from oarphpy.util import create_log
      log = create_log()
      log.info('\n' + str(self) + '\n')

  def __gt__(self, v):
    # Support use in collections.Counter
    if isinstance(v, self.__class__):
      return self.name > v.name
    else:
      return self.n > v

  def __lt__(self, v):
    # Support use in collections.Counter
    if isinstance(v, self.__class__):
      return self.name < v.name
    else:
      return self.n < v

  def __add__(self, other):
    # Support use in collections.Counter
    if isinstance(other, self.__class__):
      return self.union((self, other))
    else:
      return self
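  # Counter interop sketch (illustrative): the __add__/__lt__/__gt__
  # operators above let observers act as values in collections.Counter,
  # e.g. when merging per-stage tallies. Note that Counter addition drops
  # entries that do not compare > 0, i.e. observers with n == 0.
  #
  #   from collections import Counter
  #   c1 = Counter({'reads': obs_a})
  #   c2 = Counter({'reads': obs_b})
  #   merged = c1 + c2  # 'reads' -> ThruputObserver.union((obs_a, obs_b))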
  @staticmethod
  def monitoring_tensor(name, tensor, **observer_init_kwargs):
    """Monitor the size of the given tensorflow `Tensor` and record a text
    TF Summary with the contents of this ThruputObserver."""

    class Observer(object):
      def __init__(self, dtype_size_bytes):
        self.observer = ThruputObserver(name=name, **observer_init_kwargs)
        self.dtype_size_bytes = dtype_size_bytes

      def __call__(self, t_shape):
        import numpy as np
        n = t_shape[0]
        num_bytes = np.prod(t_shape) * self.dtype_size_bytes
        self.observer.stop_block(n=n, num_bytes=num_bytes)
        self.observer.maybe_log_progress()

        # Tensorboard is very picky about wanting Markdown :P
        import tabulatehelper as th
        stats = self.observer.get_stats()
        out = th.md_table(stats, headers=[name])

        self.observer.start_block()
        return out

    import tensorflow as tf
    obs_str_tensor = tf.compat.v1.py_func(
      Observer(tensor.dtype.size), [tf.shape(tensor)], tf.string)
    tf.summary.text(name + '/ThruputObserver', obs_str_tensor)
    return obs_str_tensor
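  # Usage sketch (illustrative; assumes TF1-style graph execution, which is
  # what the tf.compat.v1.py_func call above targets):
  #
  #   batch = tf.compat.v1.placeholder(tf.float32, [None, 224, 224, 3])
  #   summary = ThruputObserver.monitoring_tensor('input_batch', batch)
  #   # Evaluate `summary` alongside the train op; the text summary then
  #   # appears in Tensorboard under 'input_batch/ThruputObserver'.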
  @staticmethod
  def wrap_func(func, **observer_init_kwargs):
    """Decorate `func` and observe a block on each call"""

    class MonitoredFunc(object):
      __slots__ = ('func', 'observer')

      def __init__(self, func, observer_init_kwargs):
        self.func = func
        self.observer = ThruputObserver(**observer_init_kwargs)

      def __call__(self, *args, **kwargs):
        from oarphpy.util.misc import get_size_of_deep
        self.observer.start_block()
        ret = self.func(*args, **kwargs)
        self.observer.stop_block(n=1, num_bytes=get_size_of_deep(ret))
        self.observer.maybe_log_progress()
        return ret

    return MonitoredFunc(func, observer_init_kwargs)
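  # Usage sketch (illustrative; `load_record` and `paths` are hypothetical
  # stand-ins): the wrapper times each call and sizes each return value.
  #
  #   monitored_load = ThruputObserver.wrap_func(
  #     load_record, name='load_record', log_freq=500)
  #   records = [monitored_load(p) for p in paths]
  #   print(monitored_load.observer)  # Per-call latency and thruput stats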
  @staticmethod
  def to_monitored_generator(gen, **observer_init_kwargs):
    from oarphpy.util.misc import get_size_of_deep

    class MonitoredGen(object):
      __slots__ = ('gen', 'observer')

      def __init__(self, gen):
        self.gen = gen
        self.observer = ThruputObserver(**observer_init_kwargs)

      def __iter__(self):
        return self

      def __next__(self):
        return self.next()

      def next(self):
        self.observer.start_block()
        # Support both Python 3 (`__next__`) and Python 2 (`next`) iterators
        if hasattr(self.gen, '__next__'):
          x = self.gen.__next__()
        else:
          x = self.gen.next()
        self.observer.stop_block(n=1, num_bytes=get_size_of_deep(x))
        self.observer.maybe_log_progress()
        return x

      def __str__(self):
        return str(self.observer)

    return MonitoredGen(gen)
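# Usage sketch (illustrative; `read_rows` and `handle` are hypothetical
# stand-ins): wrap a generator so that every yielded item is timed and sized.
#
#   monitored = ThruputObserver.to_monitored_generator(
#     read_rows('data.csv'), name='row_reader', log_freq=10000)
#   for row in monitored:
#     handle(row)
#   print(monitored)  # Delegates to the wrapped observer's stats table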