'''
Created on Aug 4, 2011
@author: sean
'''
from meta.asttools.visitors import Visitor, visit_children
from meta.asttools.visitors.symbol_visitor import get_symbols
import ast
from meta.utils import py2op
class ConditionalSymbolVisitor(Visitor):
def __init__(self):
self._cond_lhs = set()
self._stable_lhs = set()
self._cond_rhs = set()
self._stable_rhs = set()
self.undefined = set()
self.seen_break = False
visitModule = visit_children
visitPass = visit_children
def update_stable_rhs(self, symbols):
new_symbols = symbols - self._stable_rhs
self._update_undefined(new_symbols)
if self.seen_break:
self._cond_rhs.update(new_symbols)
else:
self._cond_rhs -= new_symbols
self._stable_rhs.update(new_symbols)
def update_stable_lhs(self, symbols):
new_symbols = symbols - self._stable_lhs
if self.seen_break:
self._cond_lhs.update(new_symbols)
else:
self._cond_lhs -= new_symbols
self._stable_lhs.update(new_symbols)
def update_cond_rhs(self, symbols):
new_symbols = symbols - self._stable_rhs
self._update_undefined(new_symbols)
self._cond_rhs.update(new_symbols)
def update_cond_lhs(self, symbols):
self._cond_lhs.update(symbols - self._stable_lhs)
def _update_undefined(self, symbols):
self.undefined.update(symbols - self._stable_lhs)
update_undefined = _update_undefined
@property
def stable_lhs(self):
assert not (self._stable_lhs & self._cond_lhs)
return self._stable_lhs
@property
def stable_rhs(self):
assert not (self._stable_rhs & self._cond_rhs)
return self._stable_rhs
@property
def cond_rhs(self):
assert not (self._stable_rhs & self._cond_rhs)
return self._cond_rhs
@property
def cond_lhs(self):
assert not (self._stable_lhs & self._cond_lhs)
return self._cond_lhs
@property
def lhs(self):
assert not (self._stable_lhs & self._cond_lhs)
return self._cond_lhs | self._stable_lhs
@property
def rhs(self):
assert not (self._stable_rhs & self._cond_rhs)
return self._cond_rhs | self._stable_rhs
def visitAugAssign(self, node):
values = get_symbols(node.value)
self.update_stable_rhs(values)
targets = get_symbols(node.target)
self.update_stable_rhs(targets)
self.update_stable_lhs(targets)
def visitAssign(self, node):
ids = set()
for target in node.targets:
ids.update(get_symbols(target, ast.Store))
rhs_ids = get_symbols(node.value, ast.Load)
for target in node.targets:
rhs_ids.update(get_symbols(target, ast.Load))
self.update_stable_rhs(rhs_ids)
self.update_stable_lhs(ids)
def visitBreak(self, node):
self.seen_break = True
def visitContinue(self, node):
self.seen_break = True
def visit_loop(self, node):
gen = ConditionalSymbolVisitor()
for stmnt in node.body:
gen.visit(stmnt)
self.update_cond_lhs(gen.cond_lhs)
self.update_cond_rhs(gen.cond_rhs)
outputs = gen.stable_lhs
inputs = gen.stable_rhs
gen = ConditionalSymbolVisitor()
for stmnt in node.orelse:
gen.visit(stmnt)
self.update_cond_rhs(gen.cond_rhs)
self.update_cond_lhs(gen.cond_lhs)
orelse_outputs = gen.stable_lhs
orelse_inputs = gen.stable_rhs
self.update_stable_lhs(outputs.intersection(orelse_outputs))
self.update_stable_rhs(inputs.intersection(orelse_inputs))
self.update_cond_lhs(outputs.symmetric_difference(orelse_outputs))
self.update_cond_rhs(inputs.symmetric_difference(orelse_inputs))
def visitFor(self, node):
lhs_symbols = get_symbols(node.target, ast.Store)
self.update_cond_lhs(lhs_symbols)
rhs_symbols = get_symbols(node.iter, ast.Load)
self.update_stable_rhs(rhs_symbols)
remove_from_undef = lhs_symbols - self.undefined
self.visit_loop(node)
self.undefined -= remove_from_undef
def visitExpr(self, node):
rhs_ids = get_symbols(node, ast.Load)
self.update_stable_rhs(rhs_ids)
def visitPrint(self, node):
rhs_ids = get_symbols(node, ast.Load)
self.update_stable_rhs(rhs_ids)
def visitWhile(self, node):
rhs_symbols = get_symbols(node.test, ast.Load)
self.update_stable_rhs(rhs_symbols)
self.visit_loop(node)
def visitIf(self, node):
rhs_symbols = get_symbols(node.test, ast.Load)
self.update_stable_rhs(rhs_symbols)
gen = ConditionalSymbolVisitor()
for stmnt in node.body:
gen.visit(stmnt)
if gen.seen_break:
self.seen_break = True
self.update_cond_lhs(gen._cond_lhs)
self.update_cond_rhs(gen._cond_rhs)
outputs = gen.stable_lhs
inputs = gen.stable_rhs
gen = ConditionalSymbolVisitor()
for stmnt in node.orelse:
gen.visit(stmnt)
self.update_cond_lhs(gen._cond_lhs)
self.update_cond_rhs(gen._cond_rhs)
orelse_outputs = gen.stable_lhs
orelse_inputs = gen.stable_rhs
self.update_stable_lhs(outputs.intersection(orelse_outputs))
self.update_stable_rhs(inputs.intersection(orelse_inputs))
self.update_cond_lhs(outputs.symmetric_difference(orelse_outputs))
self.update_cond_rhs(inputs.symmetric_difference(orelse_inputs))
@py2op
def visitExec(self, node):
self.update_stable_rhs(get_symbols(node.body, ast.Load))
if node.globals:
self.update_stable_rhs(get_symbols(node.globals, ast.Load))
if node.locals:
self.update_stable_rhs(get_symbols(node.locals, ast.Load))
def visitAssert(self, node):
self.update_stable_rhs(get_symbols(node.test, ast.Load))
if node.msg:
self.update_stable_rhs(get_symbols(node.msg, ast.Load))
@py2op
def visitRaise(self, node):
if node.type:
self.update_stable_rhs(get_symbols(node.type, ast.Load))
if node.inst:
self.update_stable_rhs(get_symbols(node.inst, ast.Load))
if node.tback:
self.update_stable_rhs(get_symbols(node.tback, ast.Load))
@visitRaise.py3op
def visitRaise(self, node):
if node.exc:
self.update_stable_rhs(get_symbols(node.exc, ast.Load))
if node.cause:
self.update_stable_rhs(get_symbols(node.cause, ast.Load))
def visitTryExcept(self, node):
gen = ConditionalSymbolVisitor()
gen.visit_list(node.body)
self.update_undefined(gen.undefined)
handlers = [csv(hndlr) for hndlr in node.handlers]
for g in handlers:
self.update_undefined(g.undefined)
stable_rhs = gen.stable_rhs.intersection(*[g.stable_rhs for g in handlers])
self.update_stable_rhs(stable_rhs)
all_rhs = gen.rhs.union(*[g.rhs for g in handlers])
self.update_cond_rhs(all_rhs - stable_rhs)
stable_lhs = gen.stable_lhs.intersection(*[g.stable_lhs for g in handlers])
self.update_stable_lhs(stable_lhs)
all_lhs = gen.lhs.union(*[g.lhs for g in handlers])
self.update_cond_lhs(all_lhs - stable_lhs)
gen = ConditionalSymbolVisitor()
gen.visit_list(node.orelse)
self.update_undefined(gen.undefined)
self.update_cond_lhs(gen.lhs)
self.update_cond_rhs(gen.rhs)
@py2op
def visitExceptHandler(self, node):
if node.type:
self.update_stable_rhs(get_symbols(node.type, ast.Load))
if node.name:
self.update_stable_lhs(get_symbols(node.name, ast.Store))
self.visit_list(node.body)
@visitExceptHandler.py3op
def visitExceptHandler(self, node):
if node.type:
self.update_stable_rhs(get_symbols(node.type, ast.Load))
if node.name:
self.update_stable_lhs({node.name})
self.visit_list(node.body)
def visitTryFinally(self, node):
self.visit_list(node.body)
self.visit_list(node.finalbody)
def visitImportFrom(self, node):
symbols = get_symbols(node)
self.update_stable_lhs(symbols)
def visitImport(self, node):
symbols = get_symbols(node)
self.update_stable_lhs(symbols)
def visitLambda(self, node):
gen = ConditionalSymbolVisitor()
gen.update_stable_lhs(symbols={arg for arg in node.args.args})
gen.visit_list(node.body)
self.update_stable_rhs(gen.undefined)
def visitFunctionDef(self, node):
for decorator in node.decorator_list:
self.update_stable_rhs(get_symbols(decorator, ast.Load))
self.update_stable_lhs({node.name})
gen = ConditionalSymbolVisitor()
gen.update_stable_lhs(symbols={arg for arg in node.args.args})
gen.visit_list(node.body)
self.update_stable_rhs(gen.undefined)
def visitGlobal(self, node):
pass
def visitWith(self, node):
self.update_stable_rhs(get_symbols(node.context_expr, ast.Load))
if node.optional_vars:
self.update_stable_lhs(get_symbols(node.optional_vars, ast.Load))
self.visit_list(node.body)
def visitReturn(self, node):
self.update_stable_rhs(get_symbols(node.value, ast.Load))
def csv(node):
gen = ConditionalSymbolVisitor()
gen.visit(node)
return gen
[docs]def lhs(node):
'''
Return a set of symbols in `node` that are assigned.
:param node: ast node
:returns: set of strings.
'''
gen = ConditionalSymbolVisitor()
if isinstance(node, (list, tuple)):
gen.visit_list(node)
else:
gen.visit(node)
return gen.lhs
[docs]def rhs(node):
'''
Return a set of symbols in `node` that are used.
:param node: ast node
:returns: set of strings.
'''
gen = ConditionalSymbolVisitor()
if isinstance(node, (list, tuple)):
gen.visit_list(node)
else:
gen.visit(node)
return gen.rhs
[docs]def conditional_lhs(node):
'''
Group outputs into contitional and stable
:param node: ast node
:returns: tuple of (contitional, stable)
'''
gen = ConditionalSymbolVisitor()
gen.visit(node)
return gen.cond_lhs, gen.stable_lhs
[docs]def conditional_symbols(node):
'''
Group lhs and rhs into contitional, stable and undefined
:param node: ast node
:returns: tuple of (contitional_lhs, stable_lhs),(contitional_rhs, stable_rhs), undefined
'''
gen = ConditionalSymbolVisitor()
gen.visit(node)
lhs = gen.cond_lhs, gen.stable_lhs
rhs = gen.cond_rhs, gen.stable_rhs
undefined = gen.undefined
return lhs, rhs, undefined
if __name__ == '__main__':
source = '''
while k:
a = 1
b = 1
break
d = 1
else:
a =2
c= 3
d = 1
'''
print(conditional_lhs(ast.parse(source)))