# Source code for longling.ML.MxnetHelper.toolkit.viz

# coding: utf-8
# created by tongshiwei on 18-1-27

import traceback

import mxnet as mx

from longling.lib.stream import build_dir

__all__ = ["VizError", "plot_network"]


class VizError(Exception):  # pragma: no cover
    """Raised by ``plot_network`` when rendering the network graph fails.

    The message carries the formatted traceback of the underlying error.
    """
def plot_network(nn_symbol, save_path="plot/network", shape=None, node_attrs=None, view=False):  # pragma: no cover
    """Visualize an MXNet symbol with ``mx.viz.plot_network`` and render it to disk.

    Parameters
    ----------
    nn_symbol
        The network symbol handed to ``mx.viz.plot_network``.
    save_path : str
        Target path passed to ``graph.render``; must be non-empty.
        ``build_dir(save_path)`` is called first (presumably to create the
        containing directory -- confirm against ``longling.lib.stream``).
    shape
        Forwarded unchanged to ``mx.viz.plot_network`` as ``shape``.
    node_attrs : dict or None
        Forwarded to ``mx.viz.plot_network``; ``None`` is replaced by a
        fresh empty dict on every call.
    view : bool
        Forwarded to ``graph.render`` as ``view``.

    Raises
    ------
    ValueError
        If ``save_path`` is empty/falsy. This is an explicit check, so unlike
        an ``assert`` it is not stripped under ``python -O``.
    VizError
        If ``graph.render`` raises; the message is the formatted traceback and
        the original exception is chained as ``__cause__``.
    """
    # Validate before doing any work, and without relying on `assert`.
    if not save_path:
        raise ValueError("save_path must be a non-empty path, got %r" % (save_path,))

    # Fresh dict per call: avoids sharing a mutable default between calls.
    node_attrs = {} if node_attrs is None else node_attrs
    graph = mx.viz.plot_network(nn_symbol, shape=shape, node_attrs=node_attrs)

    build_dir(save_path)
    try:
        graph.render(save_path, view=view)
    except Exception as err:
        # Wrap any rendering failure (e.g. a missing graphviz binary) in the
        # module's own error type, keeping the original as the cause.
        raise VizError(traceback.format_exc()) from err