Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Control Flow Raising #1657

Merged
merged 25 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions dace/codegen/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,13 @@ def as_cpp(self, codegen, symbols) -> str:
expr += elem.as_cpp(codegen, symbols)
# In a general block, emit transitions and assignments after each individual block or region.
if isinstance(elem, BasicCFBlock) or (isinstance(elem, RegionBlock) and elem.region):
cfg = elem.state.parent_graph if isinstance(elem, BasicCFBlock) else elem.region.parent_graph
if isinstance(elem, BasicCFBlock):
g_elem = elem.state
else:
g_elem = elem.region
cfg = g_elem.parent_graph
sdfg = cfg if isinstance(cfg, SDFG) else cfg.sdfg
out_edges = cfg.out_edges(elem.state) if isinstance(elem, BasicCFBlock) else cfg.out_edges(elem.region)
out_edges = cfg.out_edges(g_elem)
for j, e in enumerate(out_edges):
if e not in self.gotos_to_ignore:
# Skip gotos to immediate successors
Expand Down Expand Up @@ -532,26 +536,27 @@ def as_cpp(self, codegen, symbols) -> str:
expr = ''

if self.loop.update_statement and self.loop.init_statement and self.loop.loop_variable:
# Initialize to either "int i = 0" or "i = 0" depending on whether the type has been defined.
defined_vars = codegen.dispatcher.defined_vars
if not defined_vars.has(self.loop.loop_variable):
try:
init = f'{symbols[self.loop.loop_variable]} '
except KeyError:
init = 'auto '
symbols[self.loop.loop_variable] = None
init += unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols)
init = unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols)
init = init.strip(';')

update = unparse_interstate_edge(self.loop.update_statement.code[0], sdfg, codegen=codegen, symbols=symbols)
update = update.strip(';')

if self.loop.inverted:
expr += f'{init};\n'
expr += 'do {\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
expr += f'{update};\n'
expr += f'\n}} while({cond});\n'
if self.loop.update_before_condition:
expr += f'{init};\n'
expr += 'do {\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
expr += f'{update};\n'
expr += f'}} while({cond});\n'
else:
expr += f'{init};\n'
expr += 'while (1) {\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
expr += f'if (!({cond}))\n'
expr += 'break;\n'
expr += f'{update};\n'
expr += '}\n'
else:
expr += f'for ({init}; {cond}; {update}) {{\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
Expand Down
24 changes: 22 additions & 2 deletions dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
from dace.codegen.prettycode import CodeIOStream
from dace.codegen.common import codeblock_to_cpp, sym2cpp
from dace.codegen.targets.target import TargetCodeGenerator
from dace.codegen.tools.type_inference import infer_expr_type
from dace.frontend.python import astutils
from dace.sdfg import SDFG, SDFGState, nodes
from dace.sdfg import scope as sdscope
from dace.sdfg import utils
from dace.sdfg.analysis import cfg as cfg_analysis
from dace.sdfg.state import ControlFlowRegion
from dace.transformation.passes.analysis import StateReachability
from dace.sdfg.state import ControlFlowRegion, LoopRegion
from dace.transformation.passes.analysis import StateReachability, loop_analysis


def _get_or_eval_sdfg_first_arg(func, sdfg):
Expand Down Expand Up @@ -916,6 +918,24 @@ def generate_code(self,
interstate_symbols.update(symbols)
global_symbols.update(symbols)

if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None:
init_assignment = cfr.init_statement.code[0]
update_assignment = cfr.update_statement.code[0]
if isinstance(init_assignment, astutils.ast.Assign):
init_assignment = init_assignment.value
if isinstance(update_assignment, astutils.ast.Assign):
update_assignment = update_assignment.value
if not cfr.loop_variable in interstate_symbols:
l_end = loop_analysis.get_loop_end(cfr)
l_start = loop_analysis.get_init_assignment(cfr)
l_step = loop_analysis.get_loop_stride(cfr)
sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols),
infer_expr_type(l_step, global_symbols),
infer_expr_type(l_end, global_symbols))
interstate_symbols[cfr.loop_variable] = sym_type
if not cfr.loop_variable in global_symbols:
global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable]

for isvarName, isvarType in interstate_symbols.items():
if isvarType is None:
raise TypeError(f'Type inference failed for symbol {isvarName}')
Expand Down
7 changes: 2 additions & 5 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2565,8 +2565,7 @@ def visit_If(self, node: ast.If):
self._on_block_added(cond_block)

if_body = ControlFlowRegion(cond_block.label + '_body', sdfg=self.sdfg)
cond_block.branches.append((CodeBlock(cond), if_body))
if_body.parent_graph = self.cfg_target
cond_block.add_branch(CodeBlock(cond), if_body)

# Visit recursively
self._recursive_visit(node.body, 'if', node.lineno, if_body, False)
Expand All @@ -2575,9 +2574,7 @@ def visit_If(self, node: ast.If):
if len(node.orelse) > 0:
else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}',
sdfg=self.sdfg)
#cond_block.branches.append((CodeBlock(cond_else), else_body))
cond_block.branches.append((None, else_body))
else_body.parent_graph = self.cfg_target
cond_block.add_branch(None, else_body)
# Visit recursively
self._recursive_visit(node.orelse, 'else', node.lineno, else_body, False)

Expand Down
2 changes: 2 additions & 0 deletions dace/frontend/python/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,8 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF
sdutils.inline_control_flow_regions(nsdfg)
sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks

sdfg.reset_cfg_list()

# Apply simplification pass automatically
if not cached and (simplify == True or
(simplify is None and Config.get_bool('optimizer', 'automatic_simplification'))):
Expand Down
15 changes: 11 additions & 4 deletions dace/sdfg/analysis/schedule_tree/treenodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,17 @@ def as_string(self, indent: int = 0):
loop = self.header.loop
if loop.update_statement and loop.init_statement and loop.loop_variable:
if loop.inverted:
pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n'
header = indent * INDENTATION + 'do:\n'
pre_footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n'
footer = indent * INDENTATION + f'while {loop.loop_condition.as_string}'
if loop.update_before_condition:
pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n'
header = indent * INDENTATION + 'do:\n'
pre_footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n'
footer = indent * INDENTATION + f'while {loop.loop_condition.as_string}'
else:
pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n'
header = indent * INDENTATION + 'while True:\n'
pre_footer = (indent + 1) * INDENTATION + f'if (not {loop.loop_condition.as_string}):\n'
pre_footer += (indent + 2) * INDENTATION + 'break\n'
footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n'
return pre_header + header + super().as_string(indent) + '\n' + pre_footer + footer
else:
result = (indent * INDENTATION +
Expand Down
Loading
Loading