From 962fefb3c3facda9117e406b6717fceb8ebd4296 Mon Sep 17 00:00:00 2001 From: sileod Date: Wed, 3 Jan 2024 15:47:17 +0100 Subject: [PATCH] Update xpflow.py --- xpflow.py | 110 ++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 98 insertions(+), 12 deletions(-) diff --git a/xpflow.py b/xpflow.py index a9c3e44..215087a 100644 --- a/xpflow.py +++ b/xpflow.py @@ -4,7 +4,21 @@ import hashlib import json from sorcery import dict_of -import os, sys, traceback +import os, sys, traceback, psutil +import functools +import tqdm +import logging + +def without(d, key): + if key not in d: + return d + new_d = d.copy() + new_d.pop(key) + return new_d + +def is_interactive(): + import __main__ as main + return not hasattr(main, '__file__') def override(xp): import argparse, sys @@ -12,19 +26,21 @@ def override(xp): _, unknown = parser.parse_known_args(sys.argv[1:]) cmd_args_dict = dict(zip(unknown[:-1:2],unknown[1::2])) cmd_args_dict = {k.lstrip('-'): v for (k,v) in cmd_args_dict.items()} - print(f"cmd_args: {cmd_args_dict}") for k,v in cmd_args_dict.items(): if k in xp: xp[k]=type(xp[k])(v) return xp class edict(EasyDict): - def __hash__(self): - json_dump = json.dumps(self, sort_keys=True, ensure_ascii=True) - digest = hashlib.md5(json_dump.encode('utf-8')).hexdigest() - identifier = int(digest, 16) - return identifier + def __hash__(self): + try: + json_dump = json.dumps(self, sort_keys=True, ensure_ascii=True) + digest = hashlib.md5(json_dump.encode('utf-8')).hexdigest() + identifier = int(digest, 16) + return identifier + except: + return 0 class Xpl(): def __init__(self,a,b): self.a=a @@ -74,7 +90,8 @@ def edict(self): def __iter__(self): keys = self.keys() values_list = self._values() - for values in values_list: + history=[] + for i, values in enumerate(values_list): args = edict({}) for a, v in zip(keys, values): @@ -86,6 +103,9 @@ def __iter__(self): xp = selfi.edict() xp = override(xp) xp._hash = hash(xp) + history+=[xp] + if i==len(values_list)-1: + xp._history=history yield xp def first(self): @@ -99,6 +119,8 @@ def __len__(self): return len([x for x in self]) +# Context managers: + class NoPrint: def __enter__(self): self._original_stdout = sys.stdout @@ -108,11 +130,26 @@ def __exit__(self, exc_type, exc_val, exc_tb): sys.stdout.close() sys.stdout = self._original_stdout +class NoTqdm: + def __enter__(self): + tqdm.__init__ = functools.partialmethod(tqdm.__init__, disable=True) + def __exit__(self, exc_type, exc_value, exc_traceback): + tqdm.__init__ = functools.partialmethod(tqdm.__init__, disable=False) + + +class NoLogging: + def __enter__(self): + self._old_logging_level = logging.root.manager.disable + logging.disable(logging.CRITICAL) + + def __exit__(self, exc_type, exc_value, traceback): + logging.disable(self._old_logging_level) + class Catch: - def __init__(self, exceptions=[], exit_fn=lambda:None): + def __init__(self, exceptions=[], exit_fn=lambda:None,info=''): self.allowed_exceptions = exceptions - self.encountered_expcetions=[] self.exit_fn=exit_fn + self.info=info def __enter__(self): return self @@ -127,6 +164,55 @@ def __exit__(self, exception_type, exception_value, tb): _EXCEPTIONS=[] if exception_type and (exception_type in self.allowed_exceptions or not self.allowed_exceptions): - print(f"{exception_type.__name__} swallowed!",exception_value,traceback.print_tb(tb)) - _EXCEPTIONS+=[dict_of(exception_type,exception_value,tb)] + print(f"{exception_type.__name__} swallowed!",str(self.info),exception_value,traceback.print_tb(tb)) + _EXCEPTIONS+=[dict_of(exception_type,exception_value,tb,info=self.info)] return True + + +class Notifier: + def __init__(self, exit_fn=lambda x:None): + self.exit_fn=exit_fn + def __enter__(self): + return self + def __exit__(self, *args): + self.exit_fn(str(args)) + +class MeasureRAM: + def __init__(self, id=None, logger=print): + self.id = id + self.logger = logger + if self.logger==print: + self.logger=type('logger', (object,), {'log':print})() + + def __enter__(self): + self.mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024) + + def __exit__(self, exc_type, exc_val, exc_tb): + mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024) + variation_mb = (mem_after - self.mem_before) + if self.logger: + self.logger.log(dict_of(self.id,variation_mb)) + +class DisableOutput: + def __init__(self): + self.devnull = None + + def __enter__(self): + # Disable all output streams + self.devnull = open(os.devnull, "w") + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + sys.stdout = self.devnull + sys.stderr = self.devnull + logging.disable(logging.CRITICAL) # Disable all logging output + return self + + def __exit__(self, *args): + # Re-enable the output streams + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + self.devnull.close() + logging.disable(logging.NOTSET) # Re-enable logging output + + def write(self, *args, **kwargs): + pass