Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: DDE removing read from access_set in read/write nodes #1964

Merged
merged 2 commits into from
Mar 13, 2025

Conversation

FlorianDeconinck
Copy link
Contributor

Cherry picking fix to the data flow elimination path from V1/maintenance. Original PR: #1955

## Description

For this bug to show, we need two separate states with a transient
produced in one and subsequently read and written (but not read again).
It is important that `StateFusion` isn't able to merge these two state.
I've put a dummy if/else in the middle. Before DDE this might look like


![image](https://github.com/user-attachments/assets/4ac52b3d-8cd8-4035-bc20-fba2258a7fd7)

where `tmp_computed` is transient and `tmp` is a given variable. DDE
will now go and see that `tmp_computed` can be removed as an output of
the `read_write` tasklet. The currently faulty update of `access_set`
will remove `tmp_computed` from the list of reads in `block` state. This
will then propagate (badly) up to the `start` state where `tmp_computed`
is marked as never read again, removing the whole tasklet, leaving the
`block` state to read an uninitialized `tmp_computed` (if we were to
codegen).


![image](https://github.com/user-attachments/assets/64892133-1f2f-4bf5-bee6-8a80c192a0cc)



### Repro

```python
import dace
import os

# Create an empty SDFG
sdfg = dace.SDFG(os.path.basename(__file__).removesuffix(".py").replace("-", "_"))

sdfg.add_scalar("tmp", dace.float32)
sdfg.add_scalar("tmp_computed", dace.float32, transient=True)

start_state = sdfg.add_state("start", is_start_block=True)

read_tmp = start_state.add_read("tmp")
write_computed = start_state.add_write("tmp_computed")

# upstream tasklet that writes a transient (to be read in a separate state)
write = start_state.add_tasklet("write", {"IN_tmp"}, {"OUT_computed"}, "OUT_computed = IN_tmp * 2 + 1")
start_state.add_memlet_path(read_tmp, write, dst_conn="IN_tmp", memlet=dace.Memlet(data="tmp"))
start_state.add_memlet_path(write, write_computed, src_conn="OUT_computed", memlet=dace.Memlet(data="tmp_computed"))

# Add a condition to avoid fusing next_state and start_state separate
guard = sdfg.add_state_after(start_state, "guard_state")
true_state = sdfg.add_state("true_state")
false_state = sdfg.add_state("false_state")
fs_read = false_state.add_read("tmp")
fs_write = false_state.add_write("tmp")
fs_tasklet = false_state.add_tasklet("abs", {"IN_tmp"}, {"OUT_tmp"}, "OUT_tmp = -IN_tmp")
false_state.add_memlet_path(fs_read, fs_tasklet, dst_conn="IN_tmp", memlet=dace.Memlet("tmp"))
false_state.add_memlet_path(fs_tasklet, fs_write, src_conn="OUT_tmp", memlet=dace.Memlet("tmp"))
merge = sdfg.add_state("merge_state")

sdfg.add_edge(guard, true_state, dace.InterstateEdge("tmp >= 0"))
sdfg.add_edge(guard, false_state, dace.InterstateEdge("tmp < 0"))
sdfg.add_edge(true_state, merge, dace.InterstateEdge())
sdfg.add_edge(false_state, merge, dace.InterstateEdge())

next_state = sdfg.add_state_after(merge)

write_computed = next_state.add_write("tmp_computed")
read_computed = next_state.add_read("tmp_computed")
write_tmp = next_state.add_write("tmp")

# downstream tasklet that reads _and_ writes a transient
consume = next_state.add_tasklet("read_write", {"IN_computed"}, {"OUT_tmp", "OUT_computed"}, "OUT_computed = 2 * IN_computed\nOUT_tmp = OUT_computed + IN_computed")
next_state.add_memlet_path(read_computed, consume, dst_conn="IN_computed", memlet=dace.Memlet(data="tmp_computed"))
next_state.add_memlet_path(consume, write_tmp, src_conn="OUT_tmp", memlet=dace.Memlet(data="tmp"))
next_state.add_memlet_path(consume, write_computed, src_conn="OUT_computed", memlet=dace.Memlet(data="tmp_computed"))

sdfg.validate()

sdfg.simplify(verbose=True, validate=True)

assert len(list(filter(lambda node: isinstance(node, dace.sdfg.nodes.Tasklet), sdfg.start_block.nodes()))) == 1, "write tasklet in start_block is gone"
```

### Finishing this PR

I'll need help to evaluate whether or not the proposed solution is a
good one. Questions that I have:

- Is this the right place to fix it?
- Should we manually change `access_sets` or would it be simpler/more
reliable to redo the analysis step?
- The repro case translates to a unit test. Is it a good one or should I
e.g. call DDE directly and write assertions for the output of the pass?

---------

Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com>
Copy link
Contributor

@acalotoiu acalotoiu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thank you!

Co-authored-by: Roman Cattaneo <romanc@users.noreply.github.com>
@phschaad phschaad added this pull request to the merge queue Mar 13, 2025
Merged via the queue into spcl:main with commit 3ae5027 Mar 13, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants