# Source code for srlearn.plotting

# Copyright © 2017-2021 Alexander L. Hayes

"""
Methods for plotting and visualization.
"""

from .base import BaseBoostedRelationalModel


class _GVPlotter:
    def __init__(self, dot_string):
        self.dot_string = dot_string
    def _repr_html_(self):
        import graphviz
        return graphviz.Source(self.dot_string)._repr_svg_()

def plot_digraph(dot_string, format="png"):
    """Plot a digraph as an image.

    Parameters
    ----------
    dot_string : str
        String representing a dot
    format : str
        Format passed to Source (default: ``png``). The special value
        ``"html"`` is not forwarded to graphviz; instead a lightweight
        wrapper is returned that renders the graph as inline SVG when
        displayed in a Jupyter notebook.

    Returns
    -------
    source : graphviz.Source or _GVPlotter
        Graphviz ``Source`` object, or an HTML-renderable wrapper when
        ``format="html"``.

    Raises
    ------
    ImportError
        If the optional ``graphviz`` package is not installed.
    """
    try:
        import graphviz
    except ImportError as excep:
        raise ImportError("graphviz needs to be available to plot_digraph") from excep

    if format == "html":
        # "html" is handled here rather than passed to graphviz: rendering is
        # deferred to the notebook through the wrapper's ``_repr_html_`` hook.
        return _GVPlotter(dot_string)
    return graphviz.Source(dot_string, format=format)
def export_digraph(booster, tree_index=0, out_file=None):
    """Create a digraph representation of a tree.

    Parameters
    ----------
    booster : BaseBoostedRelationalModel
        Model to create a tree from
    tree_index : int
        Index of the tree to visualize. Must satisfy
        ``0 <= tree_index < number of trees``; negative indices are rejected.
    out_file : str, pathlike, file-like, or None
        Handle or name of the output file. If ``None``, returns a string.
        An already-open handle (any object with a ``write`` method) is written
        to directly and is not closed; a name or path is opened, written as
        UTF-8 text, and closed.

    Returns
    -------
    dot : str or None
        The dot string for the selected tree when ``out_file`` is ``None``,
        otherwise ``None``.

    Raises
    ------
    TypeError
        If ``booster`` does not inherit from ``BaseBoostedRelationalModel``.
    IndexError
        If ``tree_index`` is out of range.

    Examples
    --------
    This can be used in two ways: returning a string, or directly writing
    the result to a file.

    .. code-block:: python

        from srlearn.rdn import BoostedRDNClassifier
        from srlearn import Background
        from srlearn.datasets import load_toy_cancer
        from srlearn.plotting import export_digraph

        train, _ = load_toy_cancer()

        bkg = Background(
            modes=train.modes,
        )

        clf = BoostedRDNClassifier(
            background=bkg,
            target="cancer",
        )

        clf.fit(train)

        print(export_digraph(clf, tree_index=0))
        export_digraph(clf, tree_index=0, out_file="tree0.dot")
    """
    if not isinstance(booster, BaseBoostedRelationalModel):
        raise TypeError("booster must inherit from BaseBoostedRelationalModel.")

    dotfiles = booster._dotfiles
    if not 0 <= tree_index < len(dotfiles):
        raise IndexError("tree_index is out of range.")
    dot = dotfiles[tree_index]

    if not out_file:
        # None (or an empty name) means "give the dot text back to the caller".
        return dot
    if hasattr(out_file, "write"):
        # Caller-owned handle: write through it and leave closing to them.
        out_file.write(dot)
    else:
        with open(out_file, "w", encoding="utf-8") as _fh:
            _fh.write(dot)
    return None