Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More robust loop detection #1646

Merged
merged 7 commits into from
Sep 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
465 changes: 445 additions & 20 deletions dace/transformation/interstate/loop_detection.py

Large diffs are not rendered by default.

15 changes: 7 additions & 8 deletions dace/transformation/interstate/loop_peeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,16 @@ def _modify_cond(self, condition, var, step):
def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG):
####################################################################
# Obtain loop information
guard: sd.SDFGState = self.loop_guard
begin: sd.SDFGState = self.loop_begin
after_state: sd.SDFGState = self.exit_state

# Obtain iteration variable, range, and stride
condition_edge = graph.edges_between(guard, begin)[0]
not_condition_edge = graph.edges_between(guard, after_state)[0]
itervar, rng, loop_struct = find_for_loop(graph, guard, begin)
condition_edge = self.loop_condition_edge()
not_condition_edge = self.loop_exit_edge()
itervar, rng, loop_struct = self.loop_information()

# Get loop states
loop_states = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != guard))
loop_states = self.loop_body()
first_id = loop_states.index(begin)
last_state = loop_struct[1]
last_id = loop_states.index(last_state)
Expand All @@ -104,7 +103,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG):
init_edges = []
before_states = loop_struct[0]
for before_state in before_states:
init_edge = graph.edges_between(before_state, guard)[0]
init_edge = self.loop_init_edge()
init_edge.data.assignments[itervar] = str(rng[0] + self.count * rng[2])
init_edges.append(init_edge)
append_states = before_states
Expand Down Expand Up @@ -133,7 +132,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG):
if append_state not in before_states:
for init_edge in init_edges:
graph.remove_edge(init_edge)
graph.add_edge(append_state, guard, init_edges[0].data)
graph.add_edge(append_state, init_edge.dst, init_edges[0].data)
else:
# If begin, change initialization assignment and prepend states before
# guard
Expand Down Expand Up @@ -164,4 +163,4 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG):
# Reconnect edge to guard state from last peeled iteration
if prepend_state != after_state:
graph.remove_edge(not_condition_edge)
graph.add_edge(guard, prepend_state, not_condition_edge.data)
graph.add_edge(not_condition_edge.src, prepend_state, not_condition_edge.data)
114 changes: 69 additions & 45 deletions dace/transformation/interstate/loop_to_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,16 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi
if not super().can_be_applied(graph, expr_index, sdfg, permissive):
return False

guard = self.loop_guard
begin = self.loop_begin

# Guard state should not contain any dataflow
if len(guard.nodes()) != 0:
return False
if expr_index <= 1:
guard = self.loop_guard
if len(guard.nodes()) != 0:
return False

# If loop cannot be detected, fail
found = find_for_loop(graph, guard, begin, itervar=self.itervar)
found = self.loop_information(itervar=self.itervar)
if not found:
return False

Expand All @@ -123,7 +124,7 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi
return False

# Find all loop-body states
states: List[SDFGState] = list(sdutil.dfs_conditional(sdfg, [begin], lambda _, c: c is not guard))
states: List[SDFGState] = self.loop_body()

assert (body_end in states)

Expand Down Expand Up @@ -349,22 +350,15 @@ def apply(self, _, sdfg: sd.SDFG):
from dace.sdfg.propagation import align_memlet

# Obtain loop information
guard: sd.SDFGState = self.loop_guard
itervar, (start, end, step), (_, body_end) = self.loop_information(itervar=self.itervar)
states = self.loop_body()
body: sd.SDFGState = self.loop_begin
after: sd.SDFGState = self.exit_state

# Obtain iteration variable, range, and stride
itervar, (start, end, step), (_, body_end) = find_for_loop(sdfg, guard, body, itervar=self.itervar)

# Find all loop-body states
states = set()
to_visit = [body]
while to_visit:
state = to_visit.pop(0)
for _, dst, _ in sdfg.out_edges(state):
if dst not in states and dst is not guard:
to_visit.append(dst)
states.add(state)
exit_state = self.exit_state
entry_edge = self.loop_condition_edge()
init_edge = self.loop_init_edge()
after_edge = self.loop_exit_edge()
condition_edge = self.loop_condition_edge()
increment_edge = self.loop_increment_edge()

nsdfg = None

Expand Down Expand Up @@ -425,7 +419,7 @@ def apply(self, _, sdfg: sd.SDFG):
nsdfg = SDFG("loop_body", constants=sdfg.constants_prop, parent=new_body)
nsdfg.add_node(body, is_start_state=True)
body.parent = nsdfg
exit_state = nsdfg.add_state('exit')
nexit_state = nsdfg.add_state('exit')
nsymbols = dict()
for state in states:
if state is body:
Expand All @@ -438,20 +432,48 @@ def apply(self, _, sdfg: sd.SDFG):
for src, dst, data in sdfg.in_edges(state):
nsymbols.update({s: sdfg.symbols[s] for s in data.assignments.keys() if s in sdfg.symbols})
nsdfg.add_edge(src, dst, data)
nsdfg.add_edge(body_end, exit_state, InterstateEdge())
nsdfg.add_edge(body_end, nexit_state, InterstateEdge())

# Move guard -> body edge to guard -> new_body
for src, dst, data, in sdfg.edges_between(guard, body):
sdfg.add_edge(src, new_body, data)
# Move body_end -> guard edge to new_body -> guard
for src, dst, data in sdfg.edges_between(body_end, guard):
sdfg.add_edge(new_body, dst, data)
increment_edge = None

# Delete loop-body states and edges from parent SDFG
for state in states:
for e in sdfg.all_edges(state):
# Specific instructions for loop type
if self.expr_index <= 1: # Natural loop with guard
guard = self.loop_guard

# Move guard -> body edge to guard -> new_body
for e in sdfg.edges_between(guard, body):
sdfg.remove_edge(e)
condition_edge = sdfg.add_edge(e.src, new_body, e.data)
# Move body_end -> guard edge to new_body -> guard
for e in sdfg.edges_between(body_end, guard):
sdfg.remove_edge(e)
sdfg.remove_node(state)
increment_edge = sdfg.add_edge(new_body, e.dst, e.data)


elif 1 < self.expr_index <= 3: # Rotated loop
entrystate = self.entry_state
latch = self.loop_latch

# Move entry edge to entry -> new_body
for src, dst, data, in sdfg.edges_between(entrystate, body):
init_edge = sdfg.add_edge(src, new_body, data)

# Move latch -> exit_state edge to new_body -> exit_state
for src, dst, data in sdfg.edges_between(latch, exit_state):
after_edge = sdfg.add_edge(new_body, dst, data)

elif self.expr_index == 4: # Self-loop
entrystate = self.entry_state

# Move entry edge to entry -> new_body
for src, dst, data in sdfg.edges_between(entrystate, body):
init_edge = sdfg.add_edge(src, new_body, data)
for src, dst, data in sdfg.edges_between(body, exit_state):
after_edge = sdfg.add_edge(new_body, dst, data)


# Delete loop-body states and edges from parent SDFG
sdfg.remove_nodes_from(states)

# Add NestedSDFG arrays
for name in read_set | write_set:
Expand Down Expand Up @@ -490,12 +512,13 @@ def apply(self, _, sdfg: sd.SDFG):
# correct map with a positive increment
start, end, step = end, start, -step

reentry_assignments = {k: v for k, v in condition_edge.data.assignments.items() if k != itervar}

# If necessary, make a nested SDFG with assignments
isedge = sdfg.edges_between(guard, body)[0]
symbols_to_remove = set()
if len(isedge.data.assignments) > 0:
if len(reentry_assignments) > 0:
nsdfg = helpers.nest_state_subgraph(sdfg, body, gr.SubgraphView(body, body.nodes()))
for sym in isedge.data.free_symbols:
for sym in entry_edge.data.free_symbols:
if sym in nsdfg.symbol_mapping or sym in nsdfg.in_connectors:
continue
if sym in sdfg.symbols:
Expand All @@ -522,12 +545,12 @@ def apply(self, _, sdfg: sd.SDFG):
nstate = nsdfg.sdfg.node(0)
init_state = nsdfg.sdfg.add_state_before(nstate)
nisedge = nsdfg.sdfg.edges_between(init_state, nstate)[0]
nisedge.data.assignments = isedge.data.assignments
nisedge.data.assignments = reentry_assignments
symbols_to_remove = set(nisedge.data.assignments.keys())
for k in nisedge.data.assignments.keys():
if k in nsdfg.symbol_mapping:
del nsdfg.symbol_mapping[k]
isedge.data.assignments = {}
condition_edge.data.assignments = {}

source_nodes = body.source_nodes()
sink_nodes = body.sink_nodes()
Expand All @@ -541,8 +564,8 @@ def apply(self, _, sdfg: sd.SDFG):
continue
# Arrays written with subsets that do not depend on the loop variable must be thread-local
map_dependency = False
for e in state.in_edges(node):
subset = e.data.get_dst_subset(e, state)
for e in body.in_edges(node):
subset = e.data.get_dst_subset(e, body)
if any(str(s) == itervar for s in subset.free_symbols):
map_dependency = True
break
Expand Down Expand Up @@ -644,25 +667,26 @@ def apply(self, _, sdfg: sd.SDFG):
if not source_nodes and not sink_nodes:
body.add_nedge(entry, exit, memlet.Memlet())

# Get rid of the loop exit condition edge
after_edge = sdfg.edges_between(guard, after)[0]
# Get rid of the loop exit condition edge (it will be re-added below)
sdfg.remove_edge(after_edge)

# Remove the assignment on the edge to the guard
for e in sdfg.in_edges(guard):
for e in [init_edge, increment_edge]:
if e is None:
continue
if itervar in e.data.assignments:
del e.data.assignments[itervar]

# Remove the condition on the entry edge
condition_edge = sdfg.edges_between(guard, body)[0]
condition_edge.data.condition = CodeBlock("1")

# Get rid of backedge to guard
sdfg.remove_edge(sdfg.edges_between(body, guard)[0])
if increment_edge is not None:
sdfg.remove_edge(increment_edge)

# Route body directly to after state, maintaining any other assignments
# it might have had
sdfg.add_edge(body, after, sd.InterstateEdge(assignments=after_edge.data.assignments))
sdfg.add_edge(body, exit_state, sd.InterstateEdge(assignments=after_edge.data.assignments))

# If this had made the iteration variable a free symbol, we can remove
# it from the SDFG symbols
Expand Down
14 changes: 6 additions & 8 deletions dace/transformation/interstate/loop_unroll.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
if not super().can_be_applied(graph, expr_index, sdfg, permissive):
return False

guard = self.loop_guard
begin = self.loop_begin
found = find_for_loop(graph, guard, begin)
found = self.loop_information()

# If loop cannot be detected, fail
if not found:
Expand All @@ -49,20 +47,19 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):

def apply(self, graph: ControlFlowRegion, sdfg):
# Obtain loop information
guard: sd.SDFGState = self.loop_guard
begin: sd.SDFGState = self.loop_begin
after_state: sd.SDFGState = self.exit_state

# Obtain iteration variable, range, and stride, together with the last
# state(s) before the loop and the last loop state.
itervar, rng, loop_struct = find_for_loop(graph, guard, begin)
itervar, rng, loop_struct = self.loop_information()

# Loop must be fully unrollable for now.
if self.count != 0:
raise NotImplementedError # TODO(later)

# Get loop states
loop_states = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != guard))
loop_states = self.loop_body()
first_id = loop_states.index(begin)
last_state = loop_struct[1]
last_id = loop_states.index(last_state)
Expand Down Expand Up @@ -91,7 +88,7 @@ def apply(self, graph: ControlFlowRegion, sdfg):
unrolled_states.append((new_states[first_id], new_states[last_id]))

# Get any assignments that might be on the edge to the after state
after_assignments = (graph.edges_between(guard, after_state)[0].data.assignments)
after_assignments = self.loop_exit_edge().data.assignments

# Connect new states to before and after states without conditions
if unrolled_states:
Expand All @@ -101,7 +98,8 @@ def apply(self, graph: ControlFlowRegion, sdfg):
graph.add_edge(unrolled_states[-1][1], after_state, sd.InterstateEdge(assignments=after_assignments))

# Remove old states from SDFG
graph.remove_nodes_from([guard] + loop_states)
guard_or_latch = self.loop_meta_states()
graph.remove_nodes_from(guard_or_latch + loop_states)

def instantiate_loop(
self,
Expand Down
52 changes: 27 additions & 25 deletions dace/transformation/interstate/move_loop_into_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,10 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
return False

# Obtain loop information
guard: sd.SDFGState = self.loop_guard
body: sd.SDFGState = self.loop_begin
after: sd.SDFGState = self.exit_state

# Obtain iteration variable, range, and stride
loop_info = find_for_loop(sdfg, guard, body)
loop_info = self.loop_information()
if not loop_info:
return False
itervar, (start, end, step), (_, body_end) = loop_info
Expand Down Expand Up @@ -157,11 +155,10 @@ def test_subset_dependency(subset: sbs.Subset, mparams: Set[int]) -> Tuple[bool,

def apply(self, _, sdfg: sd.SDFG):
# Obtain loop information
guard: sd.SDFGState = self.loop_guard
body: sd.SDFGState = self.loop_begin

# Obtain iteration variable, range, and stride
itervar, (start, end, step), _ = find_for_loop(sdfg, guard, body)
itervar, (start, end, step), _ = self.loop_information()

forward_loop = step > 0

Expand Down Expand Up @@ -194,26 +191,31 @@ def apply(self, _, sdfg: sd.SDFG):
else:
guard_body_edge = e

for body_inedge in sdfg.in_edges(body):
if body_inedge.src is guard:
guard_body_edge.data.assignments.update(body_inedge.data.assignments)
sdfg.remove_edge(body_inedge)
for body_outedge in sdfg.out_edges(body):
sdfg.remove_edge(body_outedge)
for guard_inedge in sdfg.in_edges(guard):
before_guard_edge.data.assignments.update(guard_inedge.data.assignments)
guard_inedge.data.assignments = {}
sdfg.add_edge(guard_inedge.src, body, guard_inedge.data)
sdfg.remove_edge(guard_inedge)
for guard_outedge in sdfg.out_edges(guard):
if guard_outedge.dst is body:
guard_body_edge.data.assignments.update(guard_outedge.data.assignments)
else:
guard_after_edge.data.assignments.update(guard_outedge.data.assignments)
guard_outedge.data.condition = CodeBlock("1")
sdfg.add_edge(body, guard_outedge.dst, guard_outedge.data)
sdfg.remove_edge(guard_outedge)
sdfg.remove_node(guard)
if self.expr_index <= 1:
guard = self.loop_guard
for body_inedge in sdfg.in_edges(body):
if body_inedge.src is guard:
guard_body_edge.data.assignments.update(body_inedge.data.assignments)
sdfg.remove_edge(body_inedge)
for body_outedge in sdfg.out_edges(body):
sdfg.remove_edge(body_outedge)
for guard_inedge in sdfg.in_edges(guard):
before_guard_edge.data.assignments.update(guard_inedge.data.assignments)
guard_inedge.data.assignments = {}
sdfg.add_edge(guard_inedge.src, body, guard_inedge.data)
sdfg.remove_edge(guard_inedge)
for guard_outedge in sdfg.out_edges(guard):
if guard_outedge.dst is body:
guard_body_edge.data.assignments.update(guard_outedge.data.assignments)
else:
guard_after_edge.data.assignments.update(guard_outedge.data.assignments)
guard_outedge.data.condition = CodeBlock("1")
sdfg.add_edge(body, guard_outedge.dst, guard_outedge.data)
sdfg.remove_edge(guard_outedge)
sdfg.remove_node(guard)
else: # Rotated or self loops
raise NotImplementedError('MoveLoopIntoMap not implemented for rotated and self-loops')

if itervar in nsdfg.symbol_mapping:
del nsdfg.symbol_mapping[itervar]
if itervar in sdfg.symbols:
Expand Down
Loading
Loading