
Fix: DDE removing read from access_set in read/write nodes #1955

Merged

Conversation


@romanc romanc commented Feb 27, 2025

Description

For this bug to show, we need two separate states, with a transient produced in one and subsequently read and written (but not read again) in the other. It is important that `StateFusion` isn't able to merge these two states; I've put a dummy if/else in the middle. Before DDE, this might look like:

![image](https://github.com/user-attachments/assets/4ac52b3d-8cd8-4035-bc20-fba2258a7fd7)

where `tmp_computed` is a transient and `tmp` is a given variable. DDE will now see that `tmp_computed` can be removed as an output of the `read_write` tasklet. The currently faulty update of `access_set` removes `tmp_computed` from the list of reads in the `block` state. This then propagates (incorrectly) up to the `start` state, where `tmp_computed` is marked as never read again, removing the whole tasklet and leaving the `block` state to read an uninitialized `tmp_computed` (if we were to run codegen).

![image](https://github.com/user-attachments/assets/64892133-1f2f-4bf5-bee6-8a80c192a0cc)

Repro

import dace
import os

# Create an empty SDFG
sdfg = dace.SDFG(os.path.basename(__file__).removesuffix(".py").replace("-", "_"))

sdfg.add_scalar("tmp", dace.float32)
sdfg.add_scalar("tmp_computed", dace.float32, transient=True)

start_state = sdfg.add_state("start", is_start_block=True)

read_tmp = start_state.add_read("tmp")
write_computed = start_state.add_write("tmp_computed")

# upstream tasklet that writes a transient (to be read in a separate state)
write = start_state.add_tasklet("write", {"IN_tmp"}, {"OUT_computed"}, "OUT_computed = IN_tmp * 2 + 1")
start_state.add_memlet_path(read_tmp, write, dst_conn="IN_tmp", memlet=dace.Memlet(data="tmp"))
start_state.add_memlet_path(write, write_computed, src_conn="OUT_computed", memlet=dace.Memlet(data="tmp_computed"))

# Add a conditional branch to prevent next_state and start_state from being fused
guard = sdfg.add_state_after(start_state, "guard_state")
true_state = sdfg.add_state("true_state")
false_state = sdfg.add_state("false_state")
fs_read = false_state.add_read("tmp")
fs_write = false_state.add_write("tmp")
fs_tasklet = false_state.add_tasklet("abs", {"IN_tmp"}, {"OUT_tmp"}, "OUT_tmp = -IN_tmp")
false_state.add_memlet_path(fs_read, fs_tasklet, dst_conn="IN_tmp", memlet=dace.Memlet("tmp"))
false_state.add_memlet_path(fs_tasklet, fs_write, src_conn="OUT_tmp", memlet=dace.Memlet("tmp"))
merge = sdfg.add_state("merge_state")

sdfg.add_edge(guard, true_state, dace.InterstateEdge("tmp >= 0"))
sdfg.add_edge(guard, false_state, dace.InterstateEdge("tmp < 0"))
sdfg.add_edge(true_state, merge, dace.InterstateEdge())
sdfg.add_edge(false_state, merge, dace.InterstateEdge())

next_state = sdfg.add_state_after(merge)

write_computed = next_state.add_write("tmp_computed")
read_computed = next_state.add_read("tmp_computed")
write_tmp = next_state.add_write("tmp")

# downstream tasklet that reads _and_ writes a transient
consume = next_state.add_tasklet("read_write", {"IN_computed"}, {"OUT_tmp", "OUT_computed"}, "OUT_computed = 2 * IN_computed\nOUT_tmp = OUT_computed + IN_computed")
next_state.add_memlet_path(read_computed, consume, dst_conn="IN_computed", memlet=dace.Memlet(data="tmp_computed"))
next_state.add_memlet_path(consume, write_tmp, src_conn="OUT_tmp", memlet=dace.Memlet(data="tmp"))
next_state.add_memlet_path(consume, write_computed, src_conn="OUT_computed", memlet=dace.Memlet(data="tmp_computed"))

sdfg.validate()

sdfg.simplify(verbose=True, validate=True)

assert len(list(filter(lambda node: isinstance(node, dace.sdfg.nodes.Tasklet), sdfg.start_block.nodes()))) == 1, "write tasklet in start_block is gone"

Finishing this PR

I'll need help evaluating whether or not the proposed solution is a good one. Questions that I have:

  • Is this the right place to fix it?
  • Should we manually change access_sets or would it be simpler/more reliable to redo the analysis step? (A sketch of the latter follows after this list.)
  • The repro case translates to a unit test. Is it a good one or should I e.g. call DDE directly and write assertions for the output of the pass?
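
For reference on the second question, here is a minimal, hypothetical sketch of what redoing the analysis step could look like, assuming DaCe's `AccessSets` analysis pass from `dace.transformation.passes.analysis` (not part of this PR; the exact return structure may differ between versions):

```python
# Hypothetical sketch: recompute access sets instead of patching them in place.
# Assumes the `sdfg` built in the repro above and DaCe's AccessSets analysis pass.
from dace.transformation.passes.analysis import AccessSets

# apply_pass returns, per control-flow block, a (reads, writes) pair of data names.
access_sets = AccessSets().apply_pass(sdfg, {})
for block, (reads, writes) in access_sets.items():
    print(block.label, "reads:", sorted(reads), "writes:", sorted(writes))
```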


tbennun commented Feb 27, 2025

Good find! Looks like a bug in DDE indeed.

I need to review these remaining access nodes a bit more, but so far it looks like a good solution.

If you can create a pipeline that calls DeadDataflowElimination, that would be better. You can check out the existing DDE tests in tests/passes/dead_dataflow_elimination_test.py to see how they invoke DDE.
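
For illustration, a minimal sketch of such a test, modeled on how single passes are typically wrapped in a pipeline (the exact invocation used in the test file may differ):

```python
# Hypothetical sketch: run only DeadDataflowElimination on the repro SDFG.
import dace
from dace.transformation.pass_pipeline import Pipeline
from dace.transformation.passes.dead_dataflow_elimination import DeadDataflowElimination

# `sdfg` is the SDFG constructed in the repro above.
Pipeline([DeadDataflowElimination()]).apply_pass(sdfg, {})

# The upstream "write" tasklet in the start state must survive the pass.
tasklets = [n for n in sdfg.start_block.nodes() if isinstance(n, dace.sdfg.nodes.Tasklet)]
assert len(tasklets) == 1, "write tasklet in start_block is gone"
```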

@romanc romanc force-pushed the romanc/dde-transient-propagation branch 2 times, most recently from e47839b to 2aa315f on February 28, 2025 08:52

romanc commented Feb 28, 2025

Just running DDE instead of the full pipeline was a good idea. It allowed me to simplify the test case and narrowly test for the expected cleanup.

(Screenshot: simplified test case)

@romanc romanc marked this pull request as ready for review February 28, 2025 08:57
@romanc romanc force-pushed the romanc/dde-transient-propagation branch from 2aa315f to cfc8e9d on February 28, 2025 09:22

romanc commented Feb 28, 2025

PS: the same issue also seems present in main. Branch https://github.com/romanc/dace/tree/romanc/dde-fix-mainline ports this fix to mainline. Let's have all the discussions here first; once we all agree, the mainline version will just be a port.

@tbennun tbennun added this pull request to the merge queue Mar 5, 2025
Merged via the queue into spcl:v1/maintenance with commit 6e2585b Mar 5, 2025
10 checks passed
FlorianDeconinck pushed a commit to FlorianDeconinck/dace that referenced this pull request Mar 5, 2025
@romanc romanc deleted the romanc/dde-transient-propagation branch March 6, 2025 08:41
github-merge-queue bot pushed a commit that referenced this pull request Mar 13, 2025
Cherry picking the fix to the dead dataflow elimination pass from
v1/maintenance. Original PR: #1955

---------

Co-authored-by: Roman Cattaneo <romanc@users.noreply.github.com>
Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com>
romanc added a commit to GridTools/gt4py that referenced this pull request Mar 18, 2025
## Description

This PR refactors the GT4Py/DaCe bridge to expose control flow elements
(`if` statements and `while` loops) to DaCe. Previously, the whole
contents of a vertical loop were put into one big Tasklet. With this PR,
that Tasklet is broken apart wherever control flow is found, so that the
control flow becomes visible in the SDFG. This allows DaCe to better
analyze the code and will be crucial in future (within the current
milestone) performance optimization work.

The main ideas in this PR are the following

1. Introduce `oir.CodeBlock` to recursively break down
`oir.HorizontalExecution`s into smaller pieces that are either control
flow or evaluated in (smaller) Tasklets.
2. Introduce `dcir.Condition` and `dcir.WhileLoop` to represent if
statements and while loops that are translated into SDFG states. We keep
the current `dcir.MaskStmt` / `dcir.While` for if statements / while
loops inside horizontal regions, which aren't yet exposed to DaCe (see
#1900).
3. Add support for `if` statements and `while` loops in the state
machine of `sdfg_builder.py`.
4. We are breaking up vertical loops inside stencils into multiple
Tasklets. It might thus happen that we write a "local" scalar in one
Tasklet and read it in another (downstream) Tasklet. We therefore create
output connectors for all scalar writes in a Tasklet and input
connectors for all reads (unless previously written in the same
Tasklet); see the sketch after this list.
5. Memlets can't be generated per horizontal execution anymore and need
to be more fine-grained. `TaskletAccessInfoCollector` does this work for
us, duplicating some logic in `AccessInfoCollector`. A refactor task has
been logged to fix/re-evaluate this later.
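
To illustrate point 4 above, here is a minimal, hypothetical sketch (plain DaCe, not the actual bridge code) of how a "local" scalar written in one Tasklet can be routed to a downstream Tasklet through an output connector, an access node, and an input connector:

```python
# Hypothetical sketch: a scalar written by one Tasklet and read by another.
import dace

sdfg = dace.SDFG("scalar_between_tasklets")
sdfg.add_scalar("local_scalar", dace.float64, transient=True)
sdfg.add_array("out", [1], dace.float64)
state = sdfg.add_state(is_start_block=True)

# Producer tasklet writes the local scalar through an output connector.
producer = state.add_tasklet("producer", set(), {"out_val"}, "out_val = 42.0")
acc = state.add_access("local_scalar")
state.add_edge(producer, "out_val", acc, None, dace.Memlet(data="local_scalar"))

# Consumer tasklet reads the same scalar through an input connector.
consumer = state.add_tasklet("consumer", {"in_val"}, {"res"}, "res = in_val + 1.0")
state.add_edge(acc, None, consumer, "in_val", dace.Memlet(data="local_scalar"))
state.add_edge(consumer, "res", state.add_write("out"), None, dace.Memlet("out[0]"))

sdfg.validate()
```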

This PR depends on the following (downstream) DaCe fixes

- spcl/dace#1954
- spcl/dace#1955

which have been merged by now.

Follow-up issues

- unrelated changes have been moved to
#1895
- #1896
- #1898
- #1900

Related issue: GEOS-ESM/NDSL#53

## Requirements

- [x] All fixes and/or new features come with corresponding tests.
Added new tests and increased coverage of horizontal regions with PRs
#1807 and #1851.
- [ ] Important design decisions have been documented in the appropriate
ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md)
folder.
Docs are [in our knowledge
base](https://geos-esm.github.io/SMT-Nebulae/technical/backend/dace-bridge/)
for now. Will be ported.

---------

Co-authored-by: Roman Cattaneo <>
Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com>