Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complete coverage for reference-to-view pass #1488

Merged
merged 1 commit into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions dace/transformation/passes/reference_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def find_candidates(
# Check if any of the symbols is a scope symbol
entry = state.entry_node(node)
while entry is not None:
if fsyms & entry.new_symbols(sdfg, state, {}):
if fsyms & entry.new_symbols(sdfg, state, {}).keys():
result.remove(cand)
break
entry = state.entry_node(entry)
Expand Down Expand Up @@ -183,11 +183,12 @@ def remove_refsets(

# Modify the state graph as necessary
for e in edges_to_remove:
state.remove_edge_and_connectors(e)
state.remove_memlet_path(e)
for n in nodes_to_remove:
state.remove_node(n)
for e in edges_to_add:
state.add_edge(*e)
if len(state.edges_between(e[0], e[2])) == 0:
state.add_edge(*e)
for n in affected_nodes: # Orphaned nodes
if n in nodes_to_remove:
continue
Expand Down
59 changes: 59 additions & 0 deletions tests/sdfg/reference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,61 @@ def test_reference_loop_nonfree_internal_use():
assert np.allclose(ref, A)


@pytest.mark.parametrize(('array_outside_scope', 'depends_on_iterate'), ((False, True), (False, True)))
def test_ref2view_refset_in_scope(array_outside_scope, depends_on_iterate):
sdfg = dace.SDFG('reftest')
sdfg.add_array('A', [20], dace.float64)
sdfg.add_array('B', [20], dace.float64)
sdfg.add_reference('ref', [1], dace.float64)

memlet_string = 'A[i]' if depends_on_iterate else 'A[3]'

state = sdfg.add_state()
me, mx = state.add_map('somemap', dict(i='0:20'))
arr = state.add_access('A')
ref = state.add_access('ref')
write = state.add_write('B')

if array_outside_scope:
state.add_edge_pair(me, ref, arr, dace.Memlet(memlet_string), internal_connector='set')
else:
state.add_nedge(me, arr, dace.Memlet())
state.add_edge(arr, None, ref, 'set', dace.Memlet(memlet_string))

t = state.add_tasklet('addone', {'inp'}, {'out'}, 'out = inp + 1')
state.add_edge(ref, None, t, 'inp', dace.Memlet('ref'))
state.add_edge_pair(mx, t, write, dace.Memlet('B[i]'), internal_connector='out')

# Test sources
sources = FindReferenceSources().apply_pass(sdfg, {})
assert len(sources) == 1 # There is only one SDFG
sources = sources[0]
assert len(sources) == 1
assert sources['ref'] == {dace.Memlet(memlet_string)}

# Test correctness before pass
A = np.random.rand(20)
B = np.random.rand(20)
ref = (A + 1) if depends_on_iterate else (A[3] + 1)
sdfg(A=A, B=B)
assert np.allclose(B, ref)

# Test reference-to-view - should fail to apply
result = Pipeline([ReferenceToView()]).apply_pass(sdfg, {})
if depends_on_iterate:
assert 'ReferenceToView' not in result or not result['ReferenceToView']
else:
assert result['ReferenceToView'] == {'ref'}

# Test correctness after pass
if not depends_on_iterate:
A = np.random.rand(20)
B = np.random.rand(20)
ref = (A + 1) if depends_on_iterate else (A[3] + 1)
sdfg(A=A, B=B)
assert np.allclose(B, ref)


if __name__ == '__main__':
test_unset_reference()
test_reference_branch()
Expand All @@ -603,3 +658,7 @@ def test_reference_loop_nonfree_internal_use():
test_reference_loop_internal_use(False)
test_reference_loop_internal_use(True)
test_reference_loop_nonfree_internal_use()
test_ref2view_refset_in_scope(False, False)
test_ref2view_refset_in_scope(False, True)
test_ref2view_refset_in_scope(True, False)
test_ref2view_refset_in_scope(True, True)