# Source code for nbrefactor.visualization.plot_module_tree

""" Visualization methods for refactored notebooks
"""

# import networkx as nx
from graphviz import Digraph
# import matplotlib.pyplot as plt
# from networkx.drawing.nx_agraph import graphviz_layout

def plot_module_tree(root_node, format='pdf'):
    """
    Visualize the :class:`~nbrefactor.datastructs.ModuleNode` tree structure \
    as a top-down Graphviz digraph. Primarily used for debugging and ensuring \
    the generated structure is as intended.

    Args:
        root_node (:class:`~nbrefactor.datastructs.ModuleNode`): the root \
            node of the tree to visualize.
        format (str): the Graphviz output format assigned to the returned \
            graph (e.g. ``'pdf'``, ``'png'``, ``'svg'``), used when the \
            graph is later rendered or saved. Defaults to ``'pdf'``.

    Returns:
        :class:`graphviz.Digraph`: the module tree graph to be plotted or saved
    """
    def add_nodes_edges(graph, node, parent_name=None):
        # id() is unique among all simultaneously-alive objects, and every
        # node in the tree is alive for the duration of this traversal, so it
        # is a sufficient DOT identifier on its own. node.name is used only as
        # the display label: embedding it in the identifier would break edges
        # for names containing ':', which graphviz parses as `node:port`.
        node_name = str(id(node))
        graph.node(node_name, node.name)

        if parent_name is not None:
            # create an edge from the parent to this node
            graph.edge(parent_name, node_name, minlen='2')

        # recursive calls to all child `~ModuleNode` nodes
        for child in node.children.values():
            add_nodes_edges(graph, child, node_name)

    dag = Digraph(comment='ModuleNodeTree',
                  graph_attr={
                      # 'spline' is the valid Graphviz value; the previous
                      # 'splines' was unrecognized and only triggered a warning
                      'splines': 'spline',
                      'rankdir': 'TB',
                      'nodesep': '0.5',
                  })

    # recursively add the tree nodes to the DAG
    add_nodes_edges(dag, root_node)

    dag.format = format
    return dag


# TODO: Needs improvements. It has to look much better to justify the
# dependency clutter. Could potentially also explore an igraph + plotly
# approach to get that *hierarchical directory tree* look; refer to
# (https://plotly.com/python/tree-plots/)
# def plot_module_tree_nx(root_node, output_path):
#     """
#     Alternative visualization method to draw the :class:`~ModuleNode` tree \
#     structure using NetworkX, Graphviz layout, and plt.
#
#     Args:
#         root_node (:class:`~nbrefactor.datastructs.ModuleNode`): the root \
#             node of the tree to visualize.
#         output_path (str): the file path where the figure is saved.
#
#     Returns:
#         None: the figure is saved to ``output_path`` with matplotlib.
#     """
#     def truncate_name(name, max_length=50):
#         if len(name) <= max_length:
#             return name
#         half_length = (max_length - 3) // 2
#         return f'{name[:half_length]}...{name[-half_length:]}'
#
#     def add_edges_to_graph(graph, node):
#         node_name = f'{node.name}'
#         node_name = truncate_name(node_name)
#         for child in node.children.values():
#             child_name = truncate_name(f'{child.name}')
#             graph.add_edge(node_name, child_name)
#             add_edges_to_graph(graph, child)
#
#     dag = nx.DiGraph()
#     add_edges_to_graph(dag, root_node)
#     pos = graphviz_layout(dag, prog='dot')
#     plt.figure(figsize=(18, 8))
#     nx.draw(dag, pos, with_labels=True, node_size=3000, node_shape='$\u25AC$',
#             node_color='none', font_size=10, edge_color='gray')
#     plt.savefig(output_path)