# pylint:disable=wrong-import-position,broad-exception-caught,ungrouped-imports
import pathlib
import copy
from typing import Optional, Tuple, Any, Union, List, Iterable
import logging
import networkx
import ailment
import angr
from .call_counter import AILBlockCallCounter
from .seq_to_blocks import SequenceToBlocks
_l = logging.getLogger(__name__)
def remove_last_statement(node):
    """
    Remove the last statement from a structured node (in place) and return it.

    Dispatch is on the exact type of ``node``; subclasses of the handled types are not matched.

    - ``CodeNode`` and ``LoopNode`` delegate to their wrapped node (``node.node`` / ``node.sequence_node``).
    - ``ailment.Block`` loses its final statement; an empty ``statements`` list raises ``IndexError``.
    - ``MultiNode`` / ``SequenceNode`` delegate to their last child and drop that child afterwards if
      ``BaseNode.test_empty_node`` reports it as empty. A container without children yields None.
    - ``ConditionNode`` is only supported when exactly one of its two branches is set.

    :param node:    The node to remove the last statement from.
    :return:        The removed statement, or None if an empty container was encountered.
    :raises NotImplementedError: For unsupported node types, and for condition nodes whose branches are both set
                                 or both missing.
    """
    kind = type(node)

    if kind is CodeNode:
        return remove_last_statement(node.node)

    if kind is ailment.Block:
        removed = node.statements[-1]
        # rebind to a sliced copy instead of mutating the existing list
        node.statements = node.statements[:-1]
        return removed

    if kind is MultiNode or kind is SequenceNode:
        if not node.nodes:
            return None
        tail = node.nodes[-1]
        removed = remove_last_statement(tail)
        if BaseNode.test_empty_node(tail):
            # the child became empty after the removal; detach it from the container
            node.nodes = node.nodes[:-1]
        return removed

    if kind is ConditionNode:
        has_true = node.true_node is not None
        has_false = node.false_node is not None
        if has_false and not has_true:
            return remove_last_statement(node.false_node)
        if has_true and not has_false:
            return remove_last_statement(node.true_node)
        raise NotImplementedError("More than one last statement exist")

    if kind is LoopNode:
        return remove_last_statement(node.sequence_node)

    raise NotImplementedError()
def append_statement(node, stmt):
    """
    Append a statement to the end of a structured node (in place).

    Dispatch is on the exact type of ``node``. ``CodeNode`` forwards to its wrapped node, ``MultiNode`` and
    ``SequenceNode`` forward to their last child, and ``ailment.Block`` receives the statement directly.

    :param node:    The node to append the statement to.
    :param stmt:    The statement to append.
    :raises NotImplementedError: For unsupported node types and for containers without any child.
    """
    kind = type(node)

    if kind is CodeNode:
        append_statement(node.node, stmt)
    elif kind is ailment.Block:
        node.statements.append(stmt)
    elif kind is MultiNode or kind is SequenceNode:
        if not node.nodes:
            # there is no child to receive the statement
            raise NotImplementedError()
        append_statement(node.nodes[-1], stmt)
    else:
        raise NotImplementedError()
def replace_last_statement(node, old_stmt, new_stmt):
    """
    Replace the last statement of a structured node with ``new_stmt`` (in place), but only where the last
    statement is the very same object as ``old_stmt`` (identity comparison).

    Dispatch is on the exact type of ``node``. Containers forward to their last child (empty containers are left
    untouched) and a ``ConditionNode`` forwards to each branch that is set.

    :param node:        The node whose last statement should be replaced.
    :param old_stmt:    The statement object expected at the end of a block.
    :param new_stmt:    The statement to put in its place.
    :raises NotImplementedError: For unsupported node types.
    """
    kind = type(node)

    if kind is CodeNode:
        replace_last_statement(node.node, old_stmt, new_stmt)
    elif kind is ailment.Block:
        if node.statements[-1] is old_stmt:
            node.statements[-1] = new_stmt
    elif kind is MultiNode or kind is SequenceNode:
        if node.nodes:
            replace_last_statement(node.nodes[-1], old_stmt, new_stmt)
    elif kind is ConditionNode:
        if node.true_node is not None:
            replace_last_statement(node.true_node, old_stmt, new_stmt)
        if node.false_node is not None:
            replace_last_statement(node.false_node, old_stmt, new_stmt)
    else:
        raise NotImplementedError()
def get_ast_subexprs(claripy_ast):
    """
    Lazily yield the sub-expressions of an AST, splitting "And" nodes.

    For every processed node whose ``op`` is "And", its first argument is yielded as-is and the remaining
    arguments are queued for processing; any other node is yielded unchanged. Note that the first argument of
    an "And" node is not split further, even if it is an "And" node itself.

    :param claripy_ast: The root AST; every processed node must provide ``op`` (and ``args`` for "And" nodes).
    :return:            A generator of sub-expressions in first-in-first-out processing order.
    """
    pending = [claripy_ast]
    cursor = 0
    # walk the work list with a cursor rather than popping from the front
    while cursor < len(pending):
        current = pending[cursor]
        cursor += 1
        if current.op != "And":
            yield current
            continue
        pending.extend(current.args[1:])
        yield current.args[0]
def insert_node(parent, insert_location: str, node, node_idx: Optional[Union[int, Tuple[int]]], label=None):
    """
    Insert ``node`` before or after an existing child of ``parent`` (in place).

    How ``node_idx`` and ``label`` are interpreted depends on the type of ``parent``:

    - ``SequenceNode`` / ``MultiNode``: ``node_idx`` is the index of the reference child in ``parent.nodes``.
    - ``CodeNode``: the wrapped node is replaced by a new ``SequenceNode`` holding it and ``node``.
    - ``ConditionNode``: ``node_idx == 0`` selects the true branch, any other value the false branch. The branch
      is wrapped in (or created as) a ``SequenceNode`` and the insertion happens relative to index 0 of it.
    - ``CascadingConditionNode``: ``node_idx`` indexes ``parent.condition_and_nodes``; the selected child is
      wrapped in a ``SequenceNode`` if necessary and the insertion happens relative to index 0 of it.
    - ``SwitchCaseNode``: ``label`` must be "case" (``node_idx`` is the case number) or "default"; the selected
      node is replaced by a new ``SequenceNode`` holding it and ``node``.
    - ``LoopNode``: ``label`` must be "body"; the body is wrapped in a ``SequenceNode`` if necessary and
      ``node_idx`` is forwarded to it.

    :param parent:          The node that receives the new node.
    :param insert_location: Either "before" or "after".
    :param node:            The node to insert.
    :param node_idx:        Position of the reference child; see above.
    :param label:           Selects the part of a switch-case or loop node to insert into.
    :raises ValueError:     If ``insert_location`` is invalid, or when inserting into a loop condition.
    :raises TypeError:      For a switch-case parent with label "switch_expr" or an unsupported label.
    :raises NotImplementedError: For unsupported parent types and unsupported loop labels.
    """
    if insert_location not in {"before", "after"}:
        raise ValueError('"insert_location" must be either "before" or "after"')
    place_before = insert_location == "before"

    if isinstance(parent, SequenceNode):
        parent.nodes.insert(node_idx if place_before else node_idx + 1, node)

    elif isinstance(parent, CodeNode):
        # Make a new sequence node out of the wrapped node and the new node
        members = [node, parent.node] if place_before else [parent.node, node]
        parent.node = SequenceNode(parent.addr, nodes=members)

    elif isinstance(parent, MultiNode):
        parent.nodes.insert(node_idx if place_before else node_idx + 1, node)

    elif isinstance(parent, ConditionNode):
        # index 0 addresses the true node, everything else the false node
        branch_attr = "true_node" if node_idx == 0 else "false_node"
        branch = getattr(parent, branch_attr)
        if not isinstance(branch, SequenceNode):
            if branch is None:
                wrapped = SequenceNode(parent.addr, nodes=[])
            else:
                wrapped = SequenceNode(branch.addr, nodes=[branch])
            setattr(parent, branch_attr, wrapped)
        insert_node(getattr(parent, branch_attr), insert_location, node, 0)

    elif isinstance(parent, CascadingConditionNode):
        cond, child_node = parent.condition_and_nodes[node_idx]
        if not isinstance(child_node, SequenceNode):
            child_node = SequenceNode(child_node.addr, nodes=[child_node])
            parent.condition_and_nodes[node_idx] = (cond, child_node)
        insert_node(child_node, insert_location, node, 0)

    elif isinstance(parent, SwitchCaseNode):
        # note that this case will be hit only when the parent node is not a container, such as SequenceNode or
        # MultiNode. we always need to create a new SequenceNode and replace the original node in place.
        if label == "switch_expr":
            raise TypeError("You cannot insert a node after an expression.")

        if label == "case":
            # node_idx is the case number.
            if insert_location == "before":
                ordered = [node, parent.cases[node_idx]]
            elif insert_location == "after":
                ordered = [parent.cases[node_idx], node]
            else:
                raise TypeError(f'Unsupported insert_location value "{insert_location}".')
            parent.cases[node_idx] = SequenceNode(ordered[0].addr, nodes=ordered)
        elif label == "default":
            if insert_location == "before":
                ordered = [node, parent.default_node]
            elif insert_location == "after":
                ordered = [parent.default_node, node]
            else:
                raise TypeError("Unsupported 'insert_location' value %r." % insert_location)
            parent.default_node = SequenceNode(ordered[0].addr, nodes=ordered)
        else:
            raise TypeError(
                f'Unsupported label value "{label}". Must be one of the following: switch_expr, case, ' f"default."
            )

    elif isinstance(parent, LoopNode):
        if label == "condition":
            raise ValueError("Cannot insert nodes into a condition expression.")
        if label != "body":
            raise NotImplementedError()
        if not isinstance(parent.sequence_node, SequenceNode):
            parent.sequence_node = SequenceNode(parent.sequence_node.addr, nodes=[parent.sequence_node])
        insert_node(parent.sequence_node, insert_location, node, node_idx)

    else:
        raise NotImplementedError()
def _merge_ail_nodes(graph, node_a: ailment.Block, node_b: ailment.Block) -> ailment.Block:
    """
    Merge two blocks of an AIL graph into one new block and rewire the surrounding edges (in place).

    The merged block is a copy of whichever input has the lower address, followed by the statements of the other
    block. A trailing ``Jump`` of the leading block is dropped before concatenation. The merged node inherits the
    incoming edges of ``node_a`` and the outgoing edges of ``node_b``; its "original_nodes" attribute is the
    union of the attributes of both inputs.

    :param graph:   The graph containing both nodes; both are removed and the merged node is added.
    :param node_a:  The node whose incoming edges are kept.
    :param node_b:  The node whose outgoing edges are kept.
    :return:        The merged block.
    """
    # snapshot edges and attributes first; removing the nodes below discards them
    incoming = list(graph.in_edges(node_a, data=True))
    outgoing = list(graph.out_edges(node_b, data=True))
    origins_a = graph.nodes[node_a].get("original_nodes", set())
    origins_b = graph.nodes[node_b].get("original_nodes", set())

    if node_a.addr <= node_b.addr:
        merged = node_a.copy()
    else:
        merged = node_b.copy()
    tail = node_b if merged == node_a else node_a

    # remove jumps in the middle of nodes when merging
    if merged.statements and isinstance(merged.statements[-1], ailment.Stmt.Jump):
        merged.statements = merged.statements[:-1]
    merged.statements += tail.statements
    merged.original_size += tail.original_size

    graph.remove_node(node_a)
    graph.remove_node(node_b)
    graph.add_node(merged, original_nodes=origins_a.union(origins_b))

    # an edge between the two merged nodes becomes a self-loop on the merged node
    for pred, _, attrs in incoming:
        graph.add_edge(merged if pred is node_b else pred, merged, **attrs)
    for _, succ, attrs in outgoing:
        graph.add_edge(merged, merged if succ is node_a else succ, **attrs)

    return merged
def to_ail_supergraph(transition_graph: networkx.DiGraph) -> networkx.DiGraph:
    """
    Takes an AIL graph and converts it into an AIL graph that treats calls and redundant jumps
    as parts of a bigger block instead of transitions. Calls to returning functions do not terminate basic blocks.

    Based on region_identifier super_graph

    The input graph is not modified; all work happens on a copy in which every node carries an "original_nodes"
    attribute (initially a set containing only the node itself).

    :param transition_graph: The AIL graph to convert.
    :return: A converted super transition graph
    """
    # make a copy of the graph
    supergraph = networkx.DiGraph(transition_graph)
    networkx.set_node_attributes(supergraph, {n: {n} for n in supergraph.nodes}, "original_nodes")

    # every modification invalidates the edge iterator, so restart the scan after each change until a full pass
    # over the edges changes nothing
    changed = True
    while changed:
        changed = False
        for src, dst, data in supergraph.edges(data=True):
            edge_type = data.get("type", None)
            single_link = (
                len(list(supergraph.successors(src))) == 1 and len(list(supergraph.predecessors(dst))) == 1
            )

            if single_link:
                # calls in the middle of blocks OR boring jumps
                if edge_type == "fake_return" or src.addr + src.original_size == dst.addr:
                    _merge_ail_nodes(supergraph, src, dst)
                    changed = True
                    break

            # calls to functions with no return
            elif edge_type == "call":
                supergraph.remove_node(dst)
                changed = True
                break

    return supergraph
def is_empty_node(node) -> bool:
    """
    Test whether a node holds no statements at all.

    A block is empty if it has no statements; a ``MultiNode`` or ``SequenceNode`` is empty if all of its children
    are (which includes having no children). Every other kind of node is reported as non-empty.

    :param node:    The node to test.
    :return:        True if the node is empty, False otherwise.
    """
    if isinstance(node, ailment.Block):
        return not node.statements
    if isinstance(node, (MultiNode, SequenceNode)):
        return all(is_empty_node(child) for child in node.nodes)
    return False
def is_empty_or_label_only_node(node) -> bool:
    """
    Test whether a node holds nothing but labels.

    A block qualifies if it has no non-label statements; a ``MultiNode`` or ``SequenceNode`` qualifies if all of
    its children do (which includes having no children). Every other kind of node does not qualify.

    :param node:    The node to test.
    :return:        True if the node is empty or consists of labels only, False otherwise.
    """
    if isinstance(node, ailment.Block):
        return not has_nonlabel_statements(node)
    if isinstance(node, (MultiNode, SequenceNode)):
        return all(is_empty_or_label_only_node(child) for child in node.nodes)
    return False
def has_nonlabel_statements(block: ailment.Block) -> bool:
    """
    Test whether a block contains at least one statement that is not a label.

    :param block:   The block to inspect.
    :return:        True if any statement of the block is not an ``ailment.Stmt.Label``; False otherwise,
                    including when the block has no statements.
    """
    # bool() keeps the result a real boolean; a bare `block.statements and ...` would hand the empty
    # statements container itself back to the caller
    return bool(block.statements) and any(not isinstance(stmt, ailment.Stmt.Label) for stmt in block.statements)
def first_nonlabel_statement(block: Union[ailment.Block, "MultiNode"]) -> Optional[ailment.Stmt.Statement]:
    """
    Find the first statement that is not a label.

    For a ``MultiNode`` the children are searched in order and the first hit is returned.

    :param block:   The block (or multi-node) to search.
    :return:        The first non-label statement, or None if there is none.
    """
    if isinstance(block, MultiNode):
        per_child = (first_nonlabel_statement(child) for child in block.nodes)
        return next((found for found in per_child if found is not None), None)

    return next((stmt for stmt in block.statements if not isinstance(stmt, ailment.Stmt.Label)), None)
def last_nonlabel_statement(block: ailment.Block) -> Optional[ailment.Stmt.Statement]:
    """
    Find the last statement of a block that is not a label.

    :param block:   The block to search.
    :return:        The last non-label statement, or None if there is none.
    """
    backwards = reversed(block.statements)
    return next((stmt for stmt in backwards if not isinstance(stmt, ailment.Stmt.Label)), None)
def first_nonlabel_node(seq: "SequenceNode") -> Optional[Union["BaseNode", ailment.Block]]:
    """
    Find the first child of a sequence node that is not a block without non-label statements.

    ``CodeNode`` children are judged by the node they wrap, but the ``CodeNode`` itself is what gets returned.

    :param seq: The sequence node whose children are searched in order.
    :return:    The first qualifying child, or None if every child is an empty or label-only block.
    """
    for candidate in seq.nodes:
        payload = candidate.node if isinstance(candidate, CodeNode) else candidate
        label_only_block = isinstance(payload, ailment.Block) and not has_nonlabel_statements(payload)
        if not label_only_block:
            return candidate
    return None
def remove_labels(graph: networkx.DiGraph):
    """
    Build a new graph in which every node is a copy of the original node with all label statements removed.

    Edges are recreated between the copies together with their edge data. The input graph and its nodes are
    left untouched.

    :param graph:   The AIL graph to strip labels from.
    :return:        A new graph without label statements.
    """
    stripped_by_node = {}
    for original in graph:
        stripped = original.copy()
        stripped.statements = [stmt for stmt in stripped.statements if not isinstance(stmt, ailment.Stmt.Label)]
        stripped_by_node[original] = stripped

    result = networkx.DiGraph()
    result.add_nodes_from(stripped_by_node.values())
    for src, dst, attrs in graph.edges(data=True):
        result.add_edge(stripped_by_node[src], stripped_by_node[dst], **attrs)
    return result
def add_labels(graph: networkx.DiGraph):
    """
    Build a new graph in which every node is a copy of the original node with a label prepended.

    The label of a node is named ``LABEL_<addr in hex>`` and carries the address and index of the node. Edges
    are recreated between the copies together with their edge data, matching what ``remove_labels`` does. The
    input graph and its nodes are left untouched.

    :param graph:   The AIL graph to add labels to.
    :return:        A new graph whose blocks all start with a label.
    """
    new_graph = networkx.DiGraph()
    nodes_map = {}
    for node in graph:
        lbl = ailment.Stmt.Label(None, f"LABEL_{node.addr:x}", node.addr, block_idx=node.idx)
        node_copy = node.copy()
        node_copy.statements = [lbl] + node_copy.statements
        nodes_map[node] = node_copy

    new_graph.add_nodes_from(nodes_map.values())
    # carry the edge attributes over; otherwise relabeling a graph silently discards them
    for src, dst, data in graph.edges(data=True):
        new_graph.add_edge(nodes_map[src], nodes_map[dst], **data)

    return new_graph
def update_labels(graph: networkx.DiGraph):
    """
    A utility function to recreate the labels for every node in an AIL graph. This is useful when you are working
    with a graph where only _some_ of the nodes have labels.

    :param graph:   The AIL graph to relabel.
    :return:        The result of stripping all labels with ``remove_labels`` and then applying ``add_labels``.
    """
    unlabeled = remove_labels(graph)
    return add_labels(unlabeled)
def structured_node_is_simple_return(node: Union["SequenceNode", "MultiNode"], graph: networkx.DiGraph) -> bool:
    """
    Check whether the node is a "simple return". A simple return looks like this:

    if (cond) {
        // simple return
        ...
        return 0;
    }
    ...

    Returns True for any node that consists of blocks only, whose last block does not end in a jump or a
    conditional jump, and whose last block is part of the graph and has no successors there.

    :param node:    The structured node (or a single block) to check.
    :param graph:   The graph used to look up successors of the last block. If None, the result is False.
    :return:        True if the node is a simple return, False otherwise.
    """

    def _flatten_structured_node(packed_node: Union["SequenceNode", "MultiNode"]) -> List[ailment.Block]:
        # recursively unpack nested sequence/multi nodes into a flat list of their leaves
        if not packed_node or not packed_node.nodes:
            return []

        blocks = []
        for _node in packed_node.nodes:
            if isinstance(_node, (SequenceNode, MultiNode)):
                blocks += _flatten_structured_node(_node)
            else:
                blocks.append(_node)
        return blocks

    # sanity check: we need a graph to understand returning blocks
    if graph is None:
        return False

    last_block = None
    if isinstance(node, (SequenceNode, MultiNode)) and node.nodes:
        flat_blocks = _flatten_structured_node(node)
        # flat_blocks is empty when the node only holds empty containers; all() would accept that and the
        # indexing below would fail, so require at least one block
        if flat_blocks and all(isinstance(block, ailment.Block) for block in flat_blocks):
            last_block = flat_blocks[-1]
    elif isinstance(node, ailment.Block):
        last_block = node

    if last_block is None:
        return False

    # a block that ends with a (conditional) jump transfers control elsewhere and is not a return
    if last_block.statements and isinstance(
        last_block.statements[-1], (ailment.Stmt.ConditionalJump, ailment.Stmt.Jump)
    ):
        return False

    return last_block in graph and not list(graph.successors(last_block))
def is_statement_terminating(stmt: ailment.statement.Statement, functions) -> bool:
    """
    Test whether a statement terminates the execution of the current function.

    That is the case for return statements, and for calls with a constant target where the function found at
    the target address has ``returning`` set to False.

    :param stmt:        The statement to test.
    :param functions:   A function manager providing ``get_by_addr``.
    :return:            True if the statement is terminating. False otherwise, including when the call target
                        cannot be found (``get_by_addr`` raises KeyError).
    """
    if isinstance(stmt, ailment.Stmt.Return):
        return True

    calls_constant_target = isinstance(stmt, ailment.Stmt.Call) and isinstance(stmt.target, ailment.Expr.Const)
    if not calls_constant_target:
        return False

    # is it calling a non-returning function?
    callee_addr = stmt.target.value
    try:
        callee = functions.get_by_addr(callee_addr)
        return callee.returning is False
    except KeyError:
        return False
def peephole_optimize_exprs(block, expr_opts):
    """
    Run expression-level peephole optimizers over every expression that an ``AILBlockWalker`` visits in a block.

    For each visited expression, the optimizers whose ``expr_classes`` match are applied repeatedly, restarting
    from the first optimizer after every replacement, until no optimizer produces a new expression. The walker's
    stock expression handler is then invoked on the outcome.

    :param block:       The block to walk.
    :param expr_opts:   Optimizers providing ``expr_classes`` and ``optimize(expr, stmt_idx=..., block=...)``.
    :return:            True if any optimizer replaced an expression, False otherwise.
    """
    updated = False

    def _handle_expr(
        expr_idx: int, expr: ailment.Expr.Expression, stmt_idx: int, stmt: Optional[ailment.Stmt.Statement], block
    ) -> Optional[ailment.Expr.Expression]:
        nonlocal updated

        original = expr
        progressing = True
        while progressing:
            progressing = False
            for expr_opt in expr_opts:
                if not isinstance(expr, expr_opt.expr_classes):
                    continue
                candidate = expr_opt.optimize(expr, stmt_idx=stmt_idx, block=block)
                if candidate is None or candidate is expr:
                    continue
                # restart with the first optimizer on the new expression
                expr = candidate
                progressing = True
                break

        if expr is original:
            return ailment.AILBlockWalker._handle_expr(walker, expr_idx, expr, stmt_idx, stmt, block)

        updated = True
        # continue to process the expr
        walked = ailment.AILBlockWalker._handle_expr(walker, expr_idx, expr, stmt_idx, stmt, block)
        return expr if walked is None else walked

    # run expression optimizers
    walker = ailment.AILBlockWalker()
    walker._handle_expr = _handle_expr
    walker.walk(block)

    return updated
def peephole_optimize_expr(expr, expr_opts):
    """
    Run expression-level peephole optimizers on a single expression that is not tied to a block.

    The optimizers whose ``expr_classes`` match are applied repeatedly, restarting from the first optimizer
    after every replacement, until no optimizer produces a new expression. The stock expression handler of
    ``AILBlockWalker`` is then invoked on the outcome; the same procedure applies to every expression for which
    the walker calls the handler again.

    :param expr:        The expression to optimize.
    :param expr_opts:   Optimizers providing ``expr_classes`` and ``optimize(expr)``.
    :return:            The result of handling the expression with expression index 0, statement index 0, and
                        neither a statement nor a block.
    """

    def _handle_expr(
        expr_idx: int, expr: ailment.Expr.Expression, stmt_idx: int, stmt: Optional[ailment.Stmt.Statement], block
    ) -> Optional[ailment.Expr.Expression]:
        original = expr
        progressing = True
        while progressing:
            progressing = False
            for expr_opt in expr_opts:
                if not isinstance(expr, expr_opt.expr_classes):
                    continue
                candidate = expr_opt.optimize(expr)
                if candidate is None or candidate is expr:
                    continue
                # restart with the first optimizer on the new expression
                expr = candidate
                progressing = True
                break

        # continue to process the expr
        walked = ailment.AILBlockWalker._handle_expr(walker, expr_idx, expr, stmt_idx, stmt, block)
        if expr is original:
            return walked
        return expr if walked is None else walked

    # run expression optimizers
    walker = ailment.AILBlockWalker()
    walker._handle_expr = _handle_expr
    return walker._handle_expr(0, expr, 0, None, None)
def copy_graph(graph: networkx.DiGraph):
    """
    Copy AIL Graph.

    Every block is shallow-copied and receives its own shallow copy of the statements container, so statements
    can be added to or removed from a copied block without affecting the original one. Edges are recreated
    between the copies together with their edge data; node attributes are not carried over.

    :param graph:   The AIL graph to copy.
    :return:        A copy of the AIL graph.
    """
    duplicate = networkx.DiGraph()
    clone_of = {}

    # copy all blocks
    for original in graph.nodes():
        clone = copy.copy(original)
        clone.statements = copy.copy(original.statements)
        clone_of[original] = clone
        duplicate.add_node(clone)

    # copy all edges
    for src, dst, attrs in graph.edges(data=True):
        duplicate.add_edge(clone_of[src], clone_of[dst], **attrs)

    return duplicate
def peephole_optimize_stmts(block, stmt_opts):
    """
    Run statement-level peephole optimizers over the statements of a block.

    For each statement, the optimizers whose ``stmt_classes`` match are applied repeatedly, restarting from the
    first optimizer after every replacement, until no optimizer produces a new statement. The resulting
    statements are collected in a new list; this function itself never assigns to ``block.statements``.

    :param block:       The block whose statements are optimized.
    :param stmt_opts:   Optimizers providing ``stmt_classes`` and ``optimize(stmt, stmt_idx=..., block=...)``.
    :return:            A tuple of (the new list of statements, whether any statement was replaced).
    """
    changed = False
    result = []

    # run statement optimizers
    # note that an optimizer may optionally edit or remove statements whose statement IDs are greater than
    # stmt_idx, which is why the length of block.statements is re-read on every round
    position = 0
    while position < len(block.statements):
        original = block.statements[position]
        current = original

        progressing = True
        while progressing:
            progressing = False
            for opt in stmt_opts:
                if not isinstance(current, opt.stmt_classes):
                    continue
                candidate = opt.optimize(current, stmt_idx=position, block=block)
                if candidate is None or candidate is current:
                    continue
                # restart with the first optimizer on the new statement
                current = candidate
                progressing = True
                break

        result.append(current)
        if current is not original:
            changed = True
        position += 1

    return result, changed
def match_stmt_classes(all_stmts: List, idx: int, stmt_class_seq: Iterable[type]) -> bool:
    """
    Test whether the statements starting at ``idx`` are instances of the given classes, position by position.

    :param all_stmts:       The list of statements to match against.
    :param idx:             The index of the first statement to match.
    :param stmt_class_seq:  The expected classes, one per consecutive statement.
    :return:                True if every class matches the statement at its position. False on the first
                            mismatch or as soon as the sequence reaches past the end of ``all_stmts``. An
                            empty class sequence always matches.
    """
    return all(
        idx + offset < len(all_stmts) and isinstance(all_stmts[idx + offset], cls)
        for offset, cls in enumerate(stmt_class_seq)
    )
def peephole_optimize_multistmts(block, stmt_opts):
    """
    Run multi-statement peephole optimizers over the statements of a block.

    Work happens on a copy of ``block.statements``. At each position, the first optimizer having a class
    sequence in ``stmt_classes`` that matches the statements at that position (only the first matching sequence
    of an optimizer is considered) receives the matched statements. If it returns a list, the matched statements
    are replaced by that list and the optimizers are run again at the same position; otherwise the next
    optimizer is tried.

    :param block:       The block whose statements are optimized.
    :param stmt_opts:   Optimizers providing ``stmt_classes`` (a collection of class sequences) and
                        ``optimize(stmts, stmt_idx=..., block=...)``.
    :return:            A tuple of (the new list of statements, whether any replacement happened).
    """
    changed = False
    statements = block.statements[:]

    # run multi-statement optimizers
    position = 0
    while position < len(statements):
        progressing = True
        while progressing and position < len(statements):
            progressing = False
            for opt in stmt_opts:
                # length of the first class sequence of this optimizer that matches at the current position
                window = next(
                    (len(seq) for seq in opt.stmt_classes if match_stmt_classes(statements, position, seq)),
                    None,
                )
                if window is None:
                    continue

                replacement = opt.optimize(statements[position : position + window], stmt_idx=position, block=block)
                if replacement is None:
                    continue

                # update statements
                statements = statements[:position] + replacement + statements[position + window :]
                changed = True
                progressing = True
                break

        # move on to the next statement
        position += 1

    return statements, changed
def decompile_functions(path, functions=None, structurer=None, catch_errors=False) -> Optional[str]:
    """
    Decompile a binary into a set of functions.

    PLT stubs, syscalls, alignment functions and SimProcedures are skipped. A function that fails to decompile,
    or whose output does not pass the sanity checks, is logged and represented in the output by a comment line
    of the form ``// [error: <function> | <reason>]``.

    :param path:            The path to the binary to decompile.
    :param functions:       The functions to decompile. Strings that parse as integer literals (``int(s, 0)``)
                            are converted to integers. If None, every entry of the function manager of the
                            recovered CFG is decompiled, in sorted order.
    :param structurer:      The structuring algorithm to use. Defaults to "phoenix".
    :param catch_errors:    If True, functions missing from the CFG are skipped with a warning and exceptions
                            raised by the decompiler are recorded in the output instead of being propagated.
    :return:                The decompilation of all functions appended in order.
    :raises ValueError:     If a requested function does not exist in the CFG and ``catch_errors`` is False.
    """
    # delayed imports to avoid circular imports
    from angr.analyses.decompiler.decompilation_options import PARAM_TO_OPTION

    structurer = structurer or "phoenix"

    path = pathlib.Path(path).resolve().absolute()
    proj = angr.Project(path, auto_load_libs=False)
    cfg = proj.analyses.CFG(normalize=True, data_references=True)
    proj.analyses.CompleteCallingConventions(recover_variables=True, analyze_callsites=True)

    # collect all functions when None are provided
    if functions is None:
        functions = sorted(cfg.kb.functions)

    # normalize the functions that could be ints as names
    normalized_functions: List[Union[int, str]] = []
    for func in functions:
        if isinstance(func, str):
            try:
                func = int(func, 0)
            except ValueError:
                # not an integer literal; keep the string as it is
                pass
        normalized_functions.append(func)

    # verify that all functions exist
    functions = []
    for func in normalized_functions:
        if func in cfg.functions:
            functions.append(func)
        elif catch_errors:
            _l.warning("Function %s does not exist in the CFG.", str(func))
        else:
            raise ValueError(f"Function {func} does not exist in the CFG.")

    # decompile all functions
    dec_options = [
        (PARAM_TO_OPTION["structurer_cls"], structurer),
    ]
    # collect the output per function and join once at the end
    pieces: List[str] = []
    for func in functions:
        f = cfg.functions[func]
        if f is None or f.is_plt or f.is_syscall or f.is_alignment or f.is_simprocedure:
            continue

        exception_string = ""
        if not catch_errors:
            dec = proj.analyses.Decompiler(f, cfg=cfg, options=dec_options)
        else:
            try:
                # TODO: add a timeout
                dec = proj.analyses.Decompiler(f, cfg=cfg, options=dec_options)
            except Exception as e:
                exception_string = str(e).replace("\n", " ")
                dec = None

        # do sanity checks on decompilation, skip checks if we already errored
        if not exception_string:
            if dec is None or not dec.codegen or not dec.codegen.text:
                exception_string = "Decompilation had no code output (failed in decompilation)"
            elif "{\n}" in dec.codegen.text:
                exception_string = "Decompilation outputted an empty function (failed in structuring)"
            elif structurer in ["dream", "combing"] and "goto" in dec.codegen.text:
                exception_string = "Decompilation outputted a goto for a Gotoless algorithm (failed in structuring)"

        if exception_string:
            _l.critical("Failed to decompile %s because %s", repr(f), exception_string)
            pieces.append(f"// [error: {func} | {exception_string}]\n")
        else:
            pieces.append(dec.codegen.text + "\n")

    return "".join(pieces)
def calls_in_graph(graph: networkx.DiGraph) -> int:
    """
    Counts the number of calls in a graph full of AIL Blocks

    :param graph:   The graph whose nodes are walked with a single ``AILBlockCallCounter``.
    :return:        The ``calls`` value of the counter after walking every node.
    """
    call_counter = AILBlockCallCounter()
    for blk in graph.nodes:
        call_counter.walk(blk)
    return call_counter.calls
def find_block_by_addr(graph: networkx.DiGraph, addr: int):
    """
    Find a block in a graph by its address.

    :param graph:   The graph to search.
    :param addr:    The address to look for. Only the ``addr`` attribute of each block is compared.
    :return:        The first block encountered whose address equals ``addr``.
    :raises KeyError: If no block in the graph has the given address.
    """
    not_found = object()
    match = next((blk for blk in graph.nodes() if blk.addr == addr), not_found)
    if match is not_found:
        raise KeyError("The block is not in the graph!")
    return match
def sequence_to_blocks(seq: "BaseNode") -> List[ailment.Block]:
    """
    Converts a sequence node (BaseNode) to a list of ailment blocks contained in it and all its children.

    :param seq: The node to walk with a ``SequenceToBlocks`` walker.
    :return:    The ``blocks`` collected by the walker.
    """
    collector = SequenceToBlocks()
    collector.walk(seq)
    return collector.blocks
def sequence_to_statements(
    seq: "BaseNode", exclude=(ailment.statement.Jump, ailment.statement.Jump)
) -> List[ailment.statement.Statement]:
    """
    Converts a sequence node (BaseNode) to a list of ailment Statements contained in it and all its children.
    May exclude certain types of statements.

    :param seq:     The node whose blocks are collected via ``sequence_to_blocks``.
    :param exclude: A class or tuple of classes; statements that are instances of them are left out.
    :return:        The remaining statements, in block order.
    """
    # NOTE(review): the default tuple lists Jump twice; possibly ConditionalJump was meant to be the second
    # entry - confirm before changing, since that would alter what callers get by default.
    collected = []
    for blk in sequence_to_blocks(seq):
        if blk.statements:
            collected.extend(stmt for stmt in blk.statements if not isinstance(stmt, exclude))
    return collected
# delayed import
from .structuring.structurer_nodes import (
MultiNode,
BaseNode,
CodeNode,
SequenceNode,
ConditionNode,
SwitchCaseNode,
CascadingConditionNode,
LoopNode,
)