# Source code for longling.ML.toolkit.hyper_search.nni

# coding: utf-8
# 2020/1/10 @ tongshiwei

import pathlib
import warnings
from heapq import nlargest
from longling.ML.toolkit.analyser import get_max, get_min, get_by_key, key_parser
from longling import as_list, path_append
from longling import list2dict, nested_update

import json
import sqlite3
import os


def _key(x):
    """
    Examples
    --------
    >>> _key(123)
    123.0
    >>> _key('{"default": 123}')
    123.0
    """
    try:
        return float(x)
    except ValueError:
        return float(json.loads(x)["default"])


def show(key, max_key=True, exp_id=None, res_dir="./",
         nni_dir=path_append(os.environ.get("HOME", "./"), "nni-experiments"),
         only_final=False,
         with_keys=None, with_all=False):  # pragma: no cover
    """
    Collect, for every trial of an NNI experiment, its parameters together
    with the selected value of ``key`` read from the trial's ``result.json``.

    Updated in v1.3.17

    cli alias: ``nni show``

    Parameters
    ----------
    key
        Metric name passed to ``get_max`` / ``get_min`` and used to index
        their result.
    max_key
        When true the value is selected with ``get_max``, otherwise with
        ``get_min``.
    exp_id
        Experiment id. When ``None``, the last component of the absolute
        ``res_dir`` is used.
    res_dir
        Directory containing ``<trial>/result.json`` for each trial.
    nni_dir
        Root of the NNI experiments; ``<nni_dir>/<exp_id>/db/nni.sqlite`` and
        ``<nni_dir>/<exp_id>/trials/<trial>/parameter.cfg`` are read.
    only_final
        When true, only trials with a ``FINAL`` row in ``MetricData`` are
        listed; otherwise every distinct trial id in ``MetricData`` is.
    with_keys
        Forwarded to ``get_max`` / ``get_min``.
    with_all
        Forwarded to ``get_max`` / ``get_min``.

    Returns
    -------
    result: list
        One ``[trial, trial_params, value]`` row per trial. When ``with_keys``
        is not ``None`` or ``with_all`` is ``True``, ``dict(appendix[key])`` is
        appended as a fourth element. Rows are sorted by
        ``float(value[key])`` in descending order.
    """
    from contextlib import closing

    if exp_id is None:
        exp_id = pathlib.PurePath(os.path.abspath(res_dir)).name
    nni_dir = path_append(nni_dir, exp_id)
    sqlite_db = path_append(nni_dir, "db", "nni.sqlite", to_str=True)
    print(sqlite_db)

    if only_final:
        query = "SELECT trialJobId FROM MetricData WHERE type='FINAL';"
    else:
        query = "SELECT DISTINCT trialJobId FROM MetricData;"

    # ``closing`` guarantees the connection is released even when the query
    # fails; a bare ``with connection:`` would only manage the transaction.
    # The ids are fully fetched here so the database is no longer held open
    # while the per-trial files are read.
    with closing(sqlite3.connect(sqlite_db)) as conn:
        trials = [row[0] for row in conn.cursor().execute(query)]

    select = get_max if max_key else get_min
    with_appendix = with_keys is not None or with_all is True
    trial_dir = path_append(nni_dir, "trials")

    result = []
    for trial in trials:
        with open(path_append(trial_dir, trial, "parameter.cfg")) as f:
            trial_params = json.load(f)["parameters"]
        trial_res = path_append(res_dir, trial, "result.json", to_str=True)
        value, appendix = select(
            trial_res, key, with_keys=with_keys, with_all=with_all, merge=False
        )
        row = [trial, trial_params, value]
        if with_appendix:
            row.append(dict(appendix[key]))
        result.append(row)

    # NOTE(review): the order is descending regardless of ``max_key``, so with
    # ``max_key=False`` the smallest value comes last. Kept as-is for backward
    # compatibility -- confirm whether best-first ordering is wanted.
    result.sort(key=lambda x: float(x[2][key]), reverse=True)
    return result


def show_top_k(k, exp_id=None,
               exp_dir=path_append(os.environ.get("HOME", "./"), "nni-experiments")):  # pragma: no cover
    """
    Return the ``k`` trials with the largest ``FINAL`` metric of an NNI
    experiment, together with their parameters. Emits a deprecation warning.

    Updated in v1.3.17

    cli alias: ``nni k-best``

    Parameters
    ----------
    k
        Number of trials to return.
    exp_id
        Experiment id; when given (truthy) it is appended to ``exp_dir``.
    exp_dir
        Directory holding ``db/nni.sqlite`` and ``trials/<trial>/parameter.cfg``
        (after ``exp_id`` has been appended, if any).

    Returns
    -------
    list
        ``[trial, result, trial_params]`` rows ordered from the largest metric
        to the smallest, where ``result`` is the raw ``data`` column value.
    """
    from contextlib import closing

    warnings.warn("deprecated method")

    if exp_id:
        exp_dir = path_append(exp_dir, exp_id)
    sqlite_db = path_append(exp_dir, "db", "nni.sqlite", to_str=True)
    print(sqlite_db)

    # ``closing`` releases the connection even if the query raises; the rows
    # are fetched eagerly so nothing below needs the database any more.
    with closing(sqlite3.connect(sqlite_db)) as conn:
        rows = conn.cursor().execute(
            "SELECT trialJobId, data FROM MetricData WHERE type='FINAL';"
        ).fetchall()

    # nlargest already returns its result ordered from largest to smallest,
    # so no additional sort is required.
    top_k = nlargest(k, rows, key=lambda row: _key(row[1]))

    trial_dir = path_append(exp_dir, "trials")
    _ret = []
    for trial, result in top_k:
        with open(path_append(trial_dir, trial, "parameter.cfg")) as f:
            trial_params = json.load(f)["parameters"]
        _ret.append([trial, result, trial_params])

    return _ret


class BaseReporter:
    """
    Reporter interface: both hooks raise ``NotImplementedError`` here and are
    meant to be overridden by subclasses.
    """

    def intermediate(self, data):
        """Hook receiving ``data``; not implemented in the base class."""
        raise NotImplementedError()

    def final(self):
        """Hook taking no arguments; not implemented in the base class."""
        raise NotImplementedError()

def get_params(received_params: dict):
    """
    Expand colon-separated parameter names into a nested parameter dict.

    Updated in v1.3.17

    Parameters
    ----------
    received_params: dict
        Parameter dict obtained from nni ``get_next_parameters()``

    Returns
    -------
    cfg_params: dict
        The updated parameter dict

    Examples
    --------
    >>> get_params({
    ...     "hyper_params_update:hidden_num": 50,
    ...     "hyper_params_update:alpha": 5,
    ...     "learning_rate": 0.1
    ... })
    {'hyper_params_update': {'hidden_num': 50, 'alpha': 5}, 'learning_rate': 0.1}
    """
    merged = {}

    for name, value in received_params.items():
        # A name containing ":" is split into a key path and handed to
        # list2dict; any other name becomes a single top-level entry.
        fragment = list2dict(name.split(":"), value) if ":" in name else {name: value}
        nested_update(merged, fragment)

    return merged