'''
Created on Jul 18, 2011
@author: sean
'''
from meta.asttools import Visitor, visit_children
import _ast
from meta.asttools.visitors.symbol_visitor import get_symbols
# networkx is optional at import time; CollectNodes calls DiGraph() when it
# is instantiated, so DiGraph is left as None when networkx is unavailable.
try:
    from networkx import DiGraph
except ImportError:
    DiGraph = None
def collect_(self, node):
    '''
    Default visitor: visit every child of *node* and merge the returned names.

    If *node* has an expression context, the merged names are also recorded
    on the visitor -- in ``self.modified`` for a store context and in
    ``self.used`` for a load context.

    :param node: the ast node whose children are visited.
    :returns: the set of names collected from all children.
    '''
    found = set().union(*[self.visit(child) for child in self.children(node)])

    context = getattr(node, 'ctx', None)
    if isinstance(context, _ast.Store):
        self.modified.update(found)
    elif isinstance(context, _ast.Load):
        self.used.update(found)

    return found
class CollectNodes(Visitor):
    '''
    Base AST visitor that collects names and records dependencies between them.

    Every ``visit*`` method returns the set of names found in the visited
    node.  As side effects the visitor maintains:

    * ``graph`` -- a ``networkx.DiGraph``; an edge ``(a, b)`` (added through
      ``add_edges(graph, [a], [b])``) records that ``a`` depends on ``b``.
    * ``modified`` -- names seen in a store context.
    * ``used`` -- names seen in a load context.
    * ``undefined`` -- names loaded when no graph node existed for them yet.
    * ``sources`` / ``targets`` -- initialised empty; filled in by subclasses.
    * ``context_names`` -- every visited ``Name`` gets an edge to each of
      these names (callers set this to express control-flow dependencies).
    * ``call_deps`` -- when true, ``visitCall`` links callee names and
      argument names in both directions.
    '''
    def __init__(self, call_deps=False):
        if DiGraph is None:
            # Fail with a clear message instead of "'NoneType' object is not
            # callable" when the optional networkx import failed.
            raise ImportError('networkx is required to build dependency graphs '
                              '(could not import networkx.DiGraph)')
        self.graph = DiGraph()
        self.modified = set()
        self.used = set()
        self.undefined = set()
        self.sources = set()
        self.targets = set()
        self.context_names = set()
        self.call_deps = call_deps

    # Fallback for node types without a dedicated visit method (presumably
    # dispatched to by ``Visitor`` -- verify against meta.asttools).
    visitDefault = collect_

    def visitName(self, node):
        name = node.id
        if isinstance(node.ctx, _ast.Store):
            self.modified.add(name)
        elif isinstance(node.ctx, _ast.Load):
            # ``add``, not ``update``: updating a set with a string would
            # insert each character of the identifier separately.
            self.used.add(name)

        if not self.graph.has_node(name):
            self.graph.add_node(name)
            # A load of a name that has no node yet was never defined here.
            if isinstance(node.ctx, _ast.Load):
                self.undefined.add(name)

        # Every name depends on the currently active context names.
        add_edges(self.graph, [name], self.context_names)

        return {name}

    def visitalias(self, node):
        # ``import x as y`` binds ``y``; ``import a.b.c`` binds ``a``.
        name = node.asname or node.name
        name = name.split('.', 1)[0]
        if not self.graph.has_node(name):
            self.graph.add_node(name)
        return {name}

    def visitCall(self, node):
        left = self.visit(node.func)

        right = set()
        for attr in ('args', 'keywords'):
            for child in getattr(node, attr, None) or ():
                if child:
                    right.update(self.visit(child))

        # Python < 3.5 stores ``*args`` / ``**kwargs`` in separate fields;
        # newer ``Call`` nodes no longer have them, hence the default.
        for attr in ('starargs', 'kwargs'):
            child = getattr(node, attr, None)
            if child:
                right.update(self.visit(child))

        for src in left | right:
            if not self.graph.has_node(src):
                self.undefined.add(src)

        if self.call_deps:
            add_edges(self.graph, left, right)
            add_edges(self.graph, right, left)

        right.update(left)
        return right

    def visitSubscript(self, node):
        if isinstance(node.ctx, _ast.Load):
            return collect_(self, node)

        # Store/del through a subscript: the container is modified and
        # depends on the index expression.
        sources = self.visit(node.slice)
        targets = self.visit(node.value)
        self.modified.update(targets)
        add_edges(self.graph, targets, sources)
        return targets

    def handle_generators(self, generators):
        '''
        Walk the ``for ... in ... if ...`` clauses of a comprehension.

        Clauses are processed in evaluation order: a clause's ``iter`` cannot
        see its own target, but its ``if`` conditions can.

        :returns: ``(defined, required)`` -- names bound by the clauses, and
            names they load that are not bound by an earlier binding.
        '''
        defined = set()
        required = set()
        for generator in generators:
            required.update(get_symbols(generator.iter, _ast.Load) - defined)
            required.update(get_symbols(generator.target, _ast.Load) - defined)
            defined.update(get_symbols(generator.target, _ast.Store))
            for condition in generator.ifs:
                required.update(get_symbols(condition, _ast.Load) - defined)
        return defined, required

    def _visit_comprehension(self, node, *exprs):
        '''Shared handling for comprehensions; *exprs* are the result parts.'''
        defined, required = self.handle_generators(node.generators)
        for expr in exprs:
            required.update(get_symbols(expr, _ast.Load) - defined)

        for symbol in required:
            if not self.graph.has_node(symbol):
                self.graph.add_node(symbol)
                self.undefined.add(symbol)

        return required

    def visitListComp(self, node):
        return self._visit_comprehension(node, node.elt)

    def visitSetComp(self, node):
        return self._visit_comprehension(node, node.elt)

    def visitGeneratorExp(self, node):
        # Same shape as a list comprehension; without this the default
        # visitor would treat the loop variable as a modified outer name.
        return self._visit_comprehension(node, node.elt)

    def visitDictComp(self, node):
        return self._visit_comprehension(node, node.key, node.value)
def add_edges(graph, targets, sources):
    '''
    Add an edge ``(target, source)`` for every pair not already in *graph*.

    :param graph: object providing ``has_edge`` and ``add_edge``.
    :param targets: names that depend on *sources*.
    :param sources: names the *targets* depend on.
    '''
    pairs = ((tgt, src) for tgt in targets for src in sources)
    for tgt, src in pairs:
        if graph.has_edge(tgt, src):
            continue
        graph.add_edge(tgt, src)
class GlobalDeps(object):
    '''
    Context manager that temporarily extends ``gen.context_names``.

    On entry the names in *nodes* are added (in place) to
    ``gen.context_names`` after a copy of its previous contents is saved; on
    exit ``gen.context_names`` is rebound to that copy.  Exceptions raised
    inside the ``with`` block are not suppressed.
    '''

    def __init__(self, gen, nodes):
        self.gen = gen
        self.nodes = nodes

    def __enter__(self):
        # Copy first: the update below mutates the live set.
        self._saved_names = set(self.gen.context_names)
        self.gen.context_names.update(self.nodes)

    def __exit__(self, *exc_info):
        self.gen.context_names = self._saved_names
class GraphGen(CollectNodes):
    '''
    Create a graph from the execution flow of the ast.

    Adds statement-level handling on top of ``CollectNodes``: assignments add
    edges from the assigned names to the names in the value, and the bodies
    of compound statements (``if``/``for``/``while``/``with``/``try``) are
    visited inside ``depends_on(...)`` so that every name they touch also
    depends on the names of the controlling expression.
    '''

    visitModule = visit_children

    def depends_on(self, nodes):
        '''Return a context manager adding *nodes* to ``context_names``.'''
        return GlobalDeps(self, set(nodes))

    def _visit_body(self, stmnts, nodes):
        '''Visit each statement in *stmnts*, merging its names into *nodes*.'''
        for stmnt in stmnts:
            nodes.update(self.visit(stmnt))
        return nodes

    def visit_lambda(self, node):
        sources = self.visit(node.args)
        self.sources.update(sources)
        self.visit(node.body)

    def visitLambda(self, node):
        # The lambda body is analysed in its own generator (own scope);
        # propagate call_deps so nested calls are treated consistently.
        gen = GraphGen(call_deps=self.call_deps)
        gen.visit_lambda(node)

        for undef in gen.undefined:
            if not self.graph.has_node(undef):
                self.graph.add_node(undef)

        return gen.undefined

    def visit_function_def(self, node):
        sources = self.visit(node.args)
        self.sources.update(sources)
        for stmnt in node.body:
            self.visit(stmnt)

    def visitFunctionDef(self, node):
        gen = GraphGen(call_deps=self.call_deps)
        gen.visit_function_def(node)

        if not self.graph.has_node(node.name):
            self.graph.add_node(node.name)

        for undef in gen.undefined:
            if not self.graph.has_node(undef):
                self.graph.add_node(undef)

        # The function name depends on the free names its body reads.
        add_edges(self.graph, [node.name], gen.undefined)

        return gen.undefined

    def visitAssign(self, node):
        nodes = self.visit(node.value)

        # A name being re-bound gets a self edge.
        stored = get_symbols(node, _ast.Store)
        re_defined = stored.intersection(set(self.graph.nodes()))
        if re_defined:
            add_edges(self.graph, re_defined, re_defined)

        targets = set()
        for target in node.targets:
            targets.update(self.visit(target))

        add_edges(self.graph, targets, nodes)

        return targets | nodes

    def visitAugAssign(self, node):
        targets = self.visit(node.target)
        values = self.visit(node.value)

        self.modified.update(targets)

        # ``x += v`` makes x depend on v and on its own previous value.
        for target in targets:
            add_edges(self.graph, [target], values)
            add_edges(self.graph, [target], targets)

        return targets | values

    def visitFor(self, node):
        targets = self.visit(node.target)
        for_iter = self.visit(node.iter)
        nodes = targets | for_iter

        add_edges(self.graph, targets, for_iter)

        # The ``else`` clause only runs depending on the iteration, so it is
        # visited under the same dependency as the body.
        with self.depends_on(for_iter):
            self._visit_body(node.body, nodes)
            self._visit_body(node.orelse, nodes)

        return nodes

    def visitIf(self, node):
        names = self.visit(node.test)
        nodes = set(names)

        # Both branches are selected by the test.
        with self.depends_on(names):
            self._visit_body(node.body, nodes)
            self._visit_body(node.orelse, nodes)

        return nodes

    def visitReturn(self, node):
        # A bare ``return`` has no value to visit.
        if node.value is None:
            return set()
        targets = self.visit(node.value)
        self.targets.update(targets)
        return targets

    def visitWith(self, node):
        # Single-item ``With`` layout (``context_expr`` / ``optional_vars``).
        targets = self.visit(node.context_expr)
        nodes = set(targets)

        if node.optional_vars is not None:
            bound = self.visit(node.optional_vars)
            nodes.update(bound)
            add_edges(self.graph, bound, targets)

        with self.depends_on(targets):
            self._visit_body(node.body, nodes)

        return nodes

    def visitWhile(self, node):
        targets = self.visit(node.test)
        nodes = set(targets)

        with self.depends_on(targets):
            self._visit_body(node.body, nodes)
            self._visit_body(node.orelse, nodes)

        return nodes

    def visitTryFinally(self, node):
        # The body may contain any number of statements, and the result must
        # be a set so enclosing ``nodes.update(self.visit(...))`` calls work.
        nodes = self._visit_body(node.body, set())

        with self.depends_on(nodes):
            self._visit_body(node.finalbody, nodes)

        return nodes

    def visitTryExcept(self, node):
        body_nodes = self._visit_body(node.body, set())
        all_nodes = set(body_nodes)

        for hndlr in node.handlers:
            nodes = set(body_nodes)
            if hndlr.name:
                nodes.update(self.visit(hndlr.name))
            if hndlr.type:
                nodes.update(self.visit(hndlr.type))

            with self.depends_on(nodes):
                self._visit_body(hndlr.body, nodes)

            all_nodes.update(nodes)

        nodes = set(body_nodes)
        with self.depends_on(nodes):
            self._visit_body(node.orelse, nodes)
        all_nodes.update(nodes)

        return all_nodes
def make_graph(node, call_deps=False):
    '''
    Create a dependency graph from an ast node.

    :param node: ast node.
    :param call_deps: if true, then every function call creates a cyclic
        dependence between the names in the callee expression and the names
        in its arguments (e.g. for ``f(c)``, ``f`` depends on ``c`` and ``c``
        depends on ``f``).
    :returns: a tuple of (graph, undefined): the ``networkx.DiGraph`` built by
        ``GraphGen`` (an edge ``(a, b)`` means ``a`` depends on ``b``) and the
        set of names that were loaded before any graph node existed for them.
    '''
    gen = GraphGen(call_deps=call_deps)
    gen.visit(node)
    return gen.graph, gen.undefined