Source code for oarphpy.util.tfutil

# Copyright 2023 Maintainers of OarphPy
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import pathlib
from contextlib import contextmanager

from oarphpy.util.misc import GPUS_UNRESTRICTED

# NB: several utilities below report progress through a module-level `log`;
# a standard library logger stands in for it here.
log = logging.getLogger(__name__)


def tf_create_session_config(restrict_gpus=GPUS_UNRESTRICTED, extra_opts=None):
  extra_opts = extra_opts or {}
  
  import tensorflow as tf
  config = tf.compat.v1.ConfigProto()

  tf_session_config_restrict_gpus(config, restrict_gpus=restrict_gpus)
  config.log_device_placement = False
  
  # # Enable CPU XLA!
  # config.graph_options.optimizer_options.global_jit_level = \
  #   tf.OptimizerOptions.ON_1

  for k, v in extra_opts.items():
    setattr(config, k, v)
  return config
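
# Example usage (a sketch; assumes a TF 1.x-style Session workflow through
# `tf.compat.v1`, and that GPU id 0 exists on the host).  Note that
# `extra_opts` entries are applied last, so they override the defaults set
# above:
#
#   config = tf_create_session_config(
#     restrict_gpus=[0],
#     extra_opts={'log_device_placement': True})
#   sess = tf_create_session(config=config)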


def tf_session_config_restrict_gpus(config, restrict_gpus=GPUS_UNRESTRICTED):
  if restrict_gpus is GPUS_UNRESTRICTED:
    config.allow_soft_placement = True
  else:
    config.device_count['GPU'] = len(restrict_gpus)
    config.gpu_options.visible_device_list = (
      ','.join(str(g) for g in restrict_gpus))
  config.gpu_options.allow_growth = True


def tf_create_session(config=None):
  config = config or tf_create_session_config()

  import tensorflow as tf
  sess = tf.compat.v1.Session(config=config)
  return sess


def tf_cpu_session_config():
  return tf_create_session_config(restrict_gpus=[])


def tf_cpu_session(config=None):
  if not config:
    config = tf_cpu_session_config()
  else:
    tf_session_config_restrict_gpus(config, restrict_gpus=[])
  return tf_create_session(config=config)
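
# Example (a sketch, assuming TF 1.x graph mode, e.g. after
# `tf.compat.v1.disable_eager_execution()`): force a CPU-only session even
# on a GPU machine:
#
#   import tensorflow as tf
#   x = tf.compat.v1.placeholder(tf.int32)
#   with tf_cpu_session() as sess:
#     print(sess.run(x + 1, feed_dict={x: 41}))  # prints 42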


@contextmanager
def tf_data_session(dataset, sess=None, config=None):
  import tensorflow as tf

  # Must declare these before the graph gets finalized below
  iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
  next_element = iterator.get_next()
  
  # Silly way to iterate over a tf.Dataset
  # https://stackoverflow.com/a/47917849
  sess = sess or tf_cpu_session()
  with sess as sess:
    def iter_dataset():
      # see MonitoredTrainingSession.StepContext
      while True:
        try:
          # with loop_until_data_exhausted():
          yield sess.run(next_element)
        except (tf.errors.OutOfRangeError, StopIteration):
          break
    yield sess, iter_dataset


# TPUs don't support strings :P  but they do support lists of integers
def to_padded_codepoint_list(s, max_len=5000):
  # Encode `s` as UTF-8 byte values, truncated / zero-padded out to `max_len`
  s_enc = s.encode('utf-8')
  return list(s_enc[:max_len]) + ([0] * (max_len - len(s_enc)))


def from_padded_codepoint_list(l, null_value=0):
  return ''.join(
    chr(el) for el in l  # May only work in python3?
    if (null_value is None or el != null_value))
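
# Round-trip example (a sketch): the padded list has a fixed, TPU-friendly
# length, and decoding drops the null padding:
#
#   codes = to_padded_codepoint_list('moof', max_len=8)
#   # codes == [109, 111, 111, 102, 0, 0, 0, 0]
#   assert from_padded_codepoint_list(codes) == 'moof'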


def give_me_frozen_graph(
      checkpoint,
      nodes=None,
      blacklist=None,
      base_graph=None,
      sess=None,
      saver=None):
  """
  Tensorflow has several ways to load checkpoints / graph artifacts.
  It's impossible to know if some API is stable or if tomorrow somebody
  will invent something new and break everything because PyTorch is shiny
  (e.g. TF Eager).  Sam Abrahams wrote a book on Tensorflow
  ( https://www.amazon.com/TensorFlow-Machine-Intelligence-hands--introduction-ebook/dp/B01IZ43JV4/ )
  and one time couldn't tell me definitively which API to use.  What's more,
  freeze_graph.py is an optional script instead of a library module in
  Tensorflow.  Chaos!!

  So, based upon spark-dl's `strip_and_freeze_until()`
  ( https://github.com/databricks/spark-deep-learning/blob/4daa1179f498df4627310afea291133539ce7001/python/sparkdl/graph/utils.py#L199 ),
  here's a utility for getting a frozen, serializable, pyspark-friendly
  graph from a checkpoint artifact / metagraph thingy (I have no idea).
  """

  def op_name(v):
    name = v
    if hasattr(v, 'name'):
      name = v.name
    if ':' not in name:
      return name
    toks = name.split(':')
    assert len(toks) <= 2, (toks, v, name)
    return toks[0]

  import tensorflow as tf
  graph = base_graph or tf.Graph()
  if nodes:
    ops = [graph.get_operation_by_name(op_name(n)) for n in nodes]
  else:
    ops = graph.get_operations()
  # if blacklist:
  #   for n in blacklist:
  #     ops.remove(graph.get_operation_by_name(op_name(n)))

  with graph.as_default():
    with (sess or tf_cpu_session()) as sess:
      saver = saver or tf.train.Saver()
      log.info("Reading from checkpoint %s ..." % checkpoint)
      saver.restore(sess, checkpoint)
      log.info("... done.")

      gdef_frozen = tf.graph_util.convert_variables_to_constants(
        sess,
        graph.as_graph_def(add_shapes=True),
        [op.name for op in ops])
          # variable_names_blacklist=blacklist)
  return gdef_frozen
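
# Example usage (a sketch; the checkpoint path and output node name below are
# hypothetical and depend on your model):
#
#   import tensorflow as tf
#   graph = tf.Graph()
#   with graph.as_default():
#     pass  # ... build the model that matches the checkpoint ...
#   gdef = give_me_frozen_graph(
#     '/path/to/my_model.ckpt',     # hypothetical checkpoint path
#     nodes=['softmax_output'],     # hypothetical output node name
#     base_graph=graph)
#   frozen_bytes = gdef.SerializeToString()  # safe to ship to pyspark workers
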
def tf_variable_summaries(var, prefix=''):
  """Create Tensorboard summaries showing basic stats of the
  variable `var`."""
  import tensorflow as tf

  if prefix:
    prefix = prefix + '/'
  else:
    def var_name(v):
      """Magic: get the name of the variable that the caller passed to
      `tf_variable_summaries()`"""
      import inspect
      lcls = inspect.stack()[2][0].f_locals
      for name in lcls:
        if id(v) == id(lcls[name]):
          return name
      return None

    prefix = var_name(var)
    if not prefix:
      prefix = str(var.name)
    idx = prefix.find('/')
    if idx >= 0:
      prefix = prefix[:idx]  # Exclude slashes in var name
    idx = prefix.find(':')
    if idx >= 0:
      prefix = prefix[:idx]  # Exclude : too

  with tf.variable_scope(prefix):
    mean = tf.reduce_mean(var)
    tf.summary.scalar('mean', mean)
    with tf.name_scope('stddev'):
      stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
    tf.summary.scalar('stddev', stddev)
    tf.summary.scalar('max', tf.reduce_max(var))
    tf.summary.scalar('min', tf.reduce_min(var))
    tf.summary.histogram('histogram', var)
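
# Example usage (a sketch, in TF 1.x graph mode):
#
#   weights = tf.compat.v1.get_variable('weights', shape=[10, 10])
#   tf_variable_summaries(weights)
#   # Emits weights/mean, weights/stddev, weights/max, weights/min, and a
#   # histogram; collect via tf.compat.v1.summary.merge_all() and write with
#   # a tf.compat.v1.summary.FileWriter.
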
class TFSummaryRow(object):
  __slots__ = (
    'path',
    'split',
    'step',
    'wall_time',
    'tag',
    'simple_value',
    'image',
    'tensor',
  )

  def __init__(self):
    self.path = ''
    self.split = ''
    self.step = -1
    self.wall_time = 0
    self.tag = ''
    self.simple_value = float('nan')
    self.image = None
    self.tensor = None

  @staticmethod
  def fill_simple_value(row, summary):
    if summary.HasField('simple_value'):
      row.simple_value = summary.simple_value

  @staticmethod
  def fill_image(row, summary):
    if summary.HasField('image'):
      import imageio
      row.image = imageio.imread(summary.image.encoded_image_string)

  @staticmethod
  def fill_tensor(row, summary):
    if summary.HasField('tensor'):
      import tensorflow as tf
      row.tensor = tf.make_ndarray(summary.tensor)

  def as_dict(self):
    return dict((k, getattr(self, k)) for k in self.__slots__)

  def as_row(self, extra=None):
    from pyspark.sql import Row
    from au.spark import NumpyArray  # NB: requires the `au` project's Spark helpers
    d = self.as_dict()
    d['image'] = NumpyArray(d['image'])
    d['tensor'] = NumpyArray(d['tensor'])
    d.update(**(extra or {}))
    return Row(**d)


class TFSummaryReader(object):

  # Subclass and use this attribute to elide / ignore some summary messages
  FILLERS = (
    TFSummaryRow.fill_simple_value,
    TFSummaryRow.fill_image,
    TFSummaryRow.fill_tensor,
  )

  def __init__(self, paths=None, glob_events_from_dir=None):
    self._paths = paths or []
    if glob_events_from_dir and os.path.exists(glob_events_from_dir):
      self._paths.extend(
        pathlib.Path(glob_events_from_dir).rglob('**/events.out*'))

  def __iter__(self):
    import tensorflow as tf
    for path in self._paths:
      path = str(path)
      log.info("Reading summaries from path %s ..." % path)

      split = ''
      # TF estimators put eval summaries in the 'eval' subdir
      eval_str = os.sep + 'eval' + os.sep
      if eval_str in path:
        split = 'eval'

      def iter_events_verbose(path):
        # When there's an error in the file, e.g. a truncated record,
        # Tensorflow doesn't print the path :(
        try:
          for tf_event in tf.train.summary_iterator(path):
            yield tf_event
        except Exception as e:
          raise Exception(("Error reading file %s" % path, e))

      for tf_event in iter_events_verbose(path):
        for tf_summary in tf_event.summary.value:
          row = TFSummaryRow()
          row.path = path
          row.split = split
          row.wall_time = tf_event.wall_time
          row.step = tf_event.step
          row.tag = tf_summary.tag
          for filler in self.FILLERS:
            filler(row, tf_summary)
          yield row
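
# Example usage (a sketch; the logdir path and 'loss' tag are hypothetical):
#
#   reader = TFSummaryReader(glob_events_from_dir='/tmp/my_tf_logdir')
#   for row in reader:
#     if row.tag == 'loss':
#       print(row.step, row.simple_value)
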
class TFRecordsFileAsListOfStrings(object):
  """
  Friends Don't Let Friends Use TFRecords.

  This utility provides a Tensorflow-free, minimal-dependency solution
  for reading TFRecords from a *file stream* (e.g. a buffered reader) and
  exposes random access over the stream's records.

  Based upon:
   * https://github.com/apache/beam/blob/master/sdks/python/apache_beam/io/tfrecordio.py#L67
   * https://www.tensorflow.org/versions/r1.11/api_guides/python/python_io#TFRecords_Format_Details
   * https://github.com/gdlg/simple-waymo-open-dataset-reader/blob/master/simple_waymo_open_dataset_reader/__init__.py#L19
   * https://gist.github.com/ed-alertedh/9f49bfc6216585f520c7c7723d20d951
  """

  ## Public API

  def __init__(self, fileobj):
    self.fileobj = fileobj
    self._offset_length = None

  def __len__(self):
    self._maybe_build_index()
    return len(self._offset_length)

  def __getitem__(self, idx):
    if idx >= len(self):
      raise IndexError
    else:
      start, length = self._offset_length[idx]
      return self._get_data(start, length)

  def __iter__(self):
    self.fileobj.seek(0)
    for offset, length in self._iter_offset_length():
      yield self._get_data(offset, length)

  ## Utils

  @classmethod
  def _masked_crc32c(cls, value):
    """Compute a masked crc32c checksum for a value.  FMI see
    https://www.tensorflow.org/versions/r1.11/api_guides/python/python_io#TFRecords_Format_Details
    """
    if not hasattr(cls, '_crc32c_fn'):
      import crcmod
      cls._crc32c_fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')

    crc = cls._crc32c_fn(value)
    return (((crc >> 15) | (crc << 17)) + 0xa282ead8) & 0xffffffff

  @classmethod
  def _check_crc(cls, value, expected_crc):
    crc_actual = cls._masked_crc32c(value)
    if crc_actual != expected_crc:
      import codecs
      raise ValueError(
        'Invalid TFRecord, mismatch: %s %s %s' % (
          codecs.encode(value[:10], 'hex'), crc_actual, expected_crc))

  def _iter_offset_length(self):
    self.fileobj.seek(0)
    while True:
      header = self.fileobj.read(12)
      if header == b'':
        break
      assert len(header) == 12

      import struct
      length, lengthcrc = struct.unpack("<QI", header)
      self._check_crc(header[:8], lengthcrc)

      base = self.fileobj.tell()
      yield (base, length)

      # Skip over the payload and the payload CRC
      self.fileobj.seek(base + length + 4)

  def _maybe_build_index(self):
    if self._offset_length is None:
      self._offset_length = list(self._iter_offset_length())
      self.fileobj.seek(0)

  def _get_data(self, start, length):
    self.fileobj.seek(start)
    payload = self.fileobj.read(length + 4)
    assert len(payload) == (length + 4)

    import struct
    data, datacrc = struct.unpack("<%dsI" % length, payload)
    self._check_crc(data, datacrc)
    return data
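
# Example usage (a sketch; 'data.tfrecord' is a hypothetical file, and the
# `crcmod` package must be installed for the CRC checks):
#
#   with open('data.tfrecord', 'rb') as f:
#     records = TFRecordsFileAsListOfStrings(f)
#     print(len(records))   # number of records (builds an offset index)
#     first = records[0]    # random access to the raw bytes of record 0
#     for raw in records:   # sequential iteration over raw record bytes
#       print(len(raw))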