longling.ML.toolkit.monitor.ValueMonitor 源代码

# coding: utf-8
# 2019/11/28 @ tongshiwei

import warnings
import functools
import math
from typing import Iterable

try:
    NAN = math.nan
except (AttributeError, ImportError):  # pragma: no cover
    NAN = float('nan')

__all__ = ["ValueMonitor", "EMAValue", "as_tmt_value"]


class BaseTMTValue(object):
    """Basic Template is Module & module is Template (TMT) value"""

    def __init__(self):
        self.pipe = []

    def pipe_put(self, loss_value, *args, **kwargs):
        for _pipe in self.pipe:
            _pipe(loss_value, *args, **kwargs)


class TMTValue(BaseTMTValue):
    def __init__(self, value_obj, transform=lambda x: x):
        super(TMTValue, self).__init__()
        self._value_obj = value_obj
        self._transform = transform

    def __call__(self, *args, **kwargs):
        value = self._value_obj(*args, **kwargs)
        super(TMTValue, self).pipe_put(self._transform(value))
        return value


[文档]def as_tmt_value(value_obj, transform=lambda x: x): """ Parameters ---------- value_obj transform Returns ------- Examples -------- >>> def loss_f(a): ... return a >>> loss_f(10) 10 >>> tmt_loss_f = as_tmt_value(loss_f) >>> tmt_loss_f(10) 10 >>> @as_tmt_value ... def loss_f2(a): ... return a >>> loss_f2(10) 10 """ return TMTValue(value_obj, transform)
# def tmt_value(transform=lambda x: x): # def _tmt_value(class_type): # class _TMTValue(class_type, BaseTMTValue): # @functools.wraps(class_type.__call__) # def __call__(self, *args, **kwargs): # value = class_type.__call__(self, *args, **kwargs) # BaseTMTValue.pipe_put(self, self.transform(value)) # return value # # @staticmethod # def transform(loss): # return transform(loss) # # return _TMTValue # # return _tmt_value class ValueMonitor(object): """ 记录损失函数并显示 几个重要的函数: * update: 定义损失函数的如何累加更新 Examples -------- >>> vm = ValueMonitor(["m1", "m2"]) >>> vm.value {'m1': nan, 'm2': nan} >>> vm.keys() dict_keys(['m1', 'm2']) >>> vm.values() dict_values([nan, nan]) >>> vm.items() dict_items([('m1', nan), ('m2', nan)]) >>> "m1" in vm True >>> vm.monitor_off("m1") >>> "m1" in vm False >>> vm.monitor_on("m1") >>> "m1" in vm True >>> def m1_func(): ... return 1 >>> def m2_func(): ... return 2 >>> vm = ValueMonitor({"m1": m1_func, "m2": m2_func}, auto="ignore") >>> vm.value {'m1': nan, 'm2': nan} >>> vm["m1"] nan """ def __init__(self, value_function_names: (list, dict), digits=6, auto=None, *args, **kwargs): """ Parameters ---------- value_function_names digits auto: str, None or bool args kwargs """ self._values = {name: NAN for name in value_function_names} self._funcs = value_function_names if isinstance(value_function_names, dict) else {} self._auto = auto if self._auto is not False: for name, func in self._funcs.items(): self.link(name, func) self._digits = digits def __str__(self): return str(self._values) def __call__(self, *args, **kwargs): self.update(*args, **kwargs) def __getitem__(self, item): return self._values[item] def __contains__(self, name): return name in self._values def monitor_on(self, name, func=None): self._values[name] = NAN if func is not None: self._funcs[name] = func def monitor_off(self, name): del self._values[name] if name in self._funcs: del self._funcs[name] def get_update_value(self, name, value): raise NotImplementedError def update(self, name, value): self._values[name] = self.round(self.get_update_value(name, value)) def reset(self): for name in self._values: self._values[name] = NAN def values(self): return self._values.values() def items(self): return self._values.items() def keys(self): return self._values.keys() @property def func(self): return self._funcs @property def value(self): return self._values def round(self, value): return round(value, self._digits) if self._digits else value def pipe_get(self, name, value): if name in self._values: self.update(name, value) def link(self, name, obj): if hasattr(obj, "pipe"): obj.pipe.append(functools.partial(self.pipe_get, name)) else: hint = "invalid link operation, for %s has no pipe attribute, suggest to use TMTValue" % obj if self._auto is True: raise AttributeError(hint) elif self._auto is None or self._auto == "warn": warnings.warn(hint)
[文档]class EMAValue(ValueMonitor): """ Exponential moving average: smoothing to give progressively lower weights to older values. .. math :: losses[name] = (1 - c) \\times previous\\_value + c \\times loss\\_value >>> ema = EMAValue(["l2"]) >>> ema["l2"] nan >>> ema("l2", 100) >>> ema("l2", 1) >>> ema["l2"] 90.1 >>> list(ema.values()) [90.1] >>> list(ema.keys()) ['l2'] >>> list(ema.items()) [('l2', 90.1)] >>> ema.reset() >>> ema["l2"] nan >>> ema = EMAValue(["l1", "l2"]) >>> ema["l2"], ema["l1"] (nan, nan) >>> ema.updates({"l1": 1, "l2": 10}) >>> ema.updates({"l1": 10, "l2": 100}) >>> ema["l1"] 1.9 >>> ema["l2"] 19.0 >>> ema = EMAValue(["l1"], smoothing_constant=0.0) >>> ema["l1"] nan >>> ema.updates({"l1": 1}) >>> ema.updates({"l1": 10}) >>> ema["l1"] 1.0 >>> ema = EMAValue(["l1"], smoothing_constant=1.0) >>> ema.updates({"l1": 1}) >>> ema.updates({"l1": 10}) >>> ema["l1"] 10.0 >>> @as_tmt_value ... def mse_loss(a): ... return a ** 2 >>> ema = EMAValue({"mse": mse_loss}) >>> ema["mse"] nan >>> mse_loss(1) 1 >>> ema["mse"] 1 >>> mse_loss(10) 100 >>> ema["mse"] 10.9 >>> ema = EMAValue({"mse": mse_loss}) >>> mse_loss(1) 1 >>> ema["mse"] 1 >>> ema.monitor_off("mse") >>> ema.func {} >>> mse_loss(10) 100 >>> "mse" not in ema True >>> ema.monitor_on("mse", mse_loss) >>> mse_loss(10) 100 >>> ema["mse"] 100 """ def __init__(self, value_function_names: (list, dict), smoothing_constant=0.1, *args, **kwargs): """ Parameters ---------- value_function_names smoothing_constant: float, optional Smoothing factor in range [0, 1], [default: 0.01]. Increase to give more weight to recent values. """ super(EMAValue, self).__init__(value_function_names, *args, **kwargs) self.smoothing_constant = smoothing_constant
[文档] def get_update_value(self, name: str, value: (float, int)): """ Parameters ---------- name: str The name of the value to be updated value : int or float New value to include in EMA. """ return ( value if math.isnan(self._values[name]) else (1 - self.smoothing_constant) * self._values[ name] + self.smoothing_constant * value )
def updates(self, pairs: (dict, Iterable)): pairs = pairs.items() if isinstance(pairs, dict) else pairs for name, value in pairs: self.update(name, value)