-
Notifications
You must be signed in to change notification settings - Fork 133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix: DDE removing read from access_set in read/write nodes #1955
Fix: DDE removing read from access_set in read/write nodes #1955
Conversation
Good find! Looks like a bug in DDE indeed. I need to review these remaining access nodes a bit more, but so far looks like a good solution. If you can create a pipeline that calls DeadDataflowElimination that would be better. You can check out the rest of the DDE tests in the relevant file ( |
e47839b
to
2aa315f
Compare
2aa315f
to
cfc8e9d
Compare
PS: the same issue also seem present in |
## Description For this bug to show, we need two separate states with a transient produced in one and subsequently read and written (but not read again). It is important that `StateFusion` isn't able to merge these two state. I've put a dummy if/else in the middle. Before DDE this might look like  where `tmp_computed` is transient and `tmp` is a given variable. DDE will now go and see that `tmp_computed` can be removed as an output of the `read_write` tasklet. The currently faulty update of `access_set` will remove `tmp_computed` from the list of reads in `block` state. This will then propagate (badly) up to the `start` state where `tmp_computed` is marked as never read again, removing the whole tasklet, leaving the `block` state to read an uninitialized `tmp_computed` (if we were to codegen).  ### Repro ```python import dace import os # Create an empty SDFG sdfg = dace.SDFG(os.path.basename(__file__).removesuffix(".py").replace("-", "_")) sdfg.add_scalar("tmp", dace.float32) sdfg.add_scalar("tmp_computed", dace.float32, transient=True) start_state = sdfg.add_state("start", is_start_block=True) read_tmp = start_state.add_read("tmp") write_computed = start_state.add_write("tmp_computed") # upstream tasklet that writes a transient (to be read in a separate state) write = start_state.add_tasklet("write", {"IN_tmp"}, {"OUT_computed"}, "OUT_computed = IN_tmp * 2 + 1") start_state.add_memlet_path(read_tmp, write, dst_conn="IN_tmp", memlet=dace.Memlet(data="tmp")) start_state.add_memlet_path(write, write_computed, src_conn="OUT_computed", memlet=dace.Memlet(data="tmp_computed")) # Add a condition to avoid fusing next_state and start_state separate guard = sdfg.add_state_after(start_state, "guard_state") true_state = sdfg.add_state("true_state") false_state = sdfg.add_state("false_state") fs_read = false_state.add_read("tmp") fs_write = false_state.add_write("tmp") fs_tasklet = false_state.add_tasklet("abs", {"IN_tmp"}, {"OUT_tmp"}, "OUT_tmp = -IN_tmp") false_state.add_memlet_path(fs_read, fs_tasklet, dst_conn="IN_tmp", memlet=dace.Memlet("tmp")) false_state.add_memlet_path(fs_tasklet, fs_write, src_conn="OUT_tmp", memlet=dace.Memlet("tmp")) merge = sdfg.add_state("merge_state") sdfg.add_edge(guard, true_state, dace.InterstateEdge("tmp >= 0")) sdfg.add_edge(guard, false_state, dace.InterstateEdge("tmp < 0")) sdfg.add_edge(true_state, merge, dace.InterstateEdge()) sdfg.add_edge(false_state, merge, dace.InterstateEdge()) next_state = sdfg.add_state_after(merge) write_computed = next_state.add_write("tmp_computed") read_computed = next_state.add_read("tmp_computed") write_tmp = next_state.add_write("tmp") # downstream tasklet that reads _and_ writes a transient consume = next_state.add_tasklet("read_write", {"IN_computed"}, {"OUT_tmp", "OUT_computed"}, "OUT_computed = 2 * IN_computed\nOUT_tmp = OUT_computed + IN_computed") next_state.add_memlet_path(read_computed, consume, dst_conn="IN_computed", memlet=dace.Memlet(data="tmp_computed")) next_state.add_memlet_path(consume, write_tmp, src_conn="OUT_tmp", memlet=dace.Memlet(data="tmp")) next_state.add_memlet_path(consume, write_computed, src_conn="OUT_computed", memlet=dace.Memlet(data="tmp_computed")) sdfg.validate() sdfg.simplify(verbose=True, validate=True) assert len(list(filter(lambda node: isinstance(node, dace.sdfg.nodes.Tasklet), sdfg.start_block.nodes()))) == 1, "write tasklet in start_block is gone" ``` ### Finishing this PR I'll need help to evaluate whether or not the proposed solution is a good one. Questions that I have: - Is this the right place to fix it? - Should we manually change `access_sets` or would it be simpler/more reliable to redo the analysis step? - The repro case translates to a unit test. Is it a good one or should I e.g. call DDE directly and write assertions for the output of the pass? --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com>
Cherry picking fix to the data flow elimination path from V1/maintenance. Original PR: #1955 --------- Co-authored-by: Roman Cattaneo <romanc@users.noreply.github.com> Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com>
## Description This PR refactors the GT4Py/DaCe bridge to expose control flow elements (`if` statements and `while` loops) to DaCe. Previously, the whole contents of a vertical loop was put in one big Tasklet. With this PR, that Tasklet is broken apart in case control flow is found such that control flow is visible in the SDFG. This allows DaCe to better analyze code and will be crucial in future (within the current milestone) performance optimization work. The main ideas in this PR are the following 1. Introduce `oir.CodeBlock` to recursively break down `oir.HorizontalExecution`s into smaller pieces that are either code flow or evaluated in (smaller) Tasklets. 2. Introduce `dcir.Condition`and `dcir.WhileLoop` to represent if statements and while loops that are translated into SDFG states. We keep the current `dcir.MaskStmt` / `dcir.While` for if statements / while loops inside horizontal regions, which aren't yet exposed to DaCe (see #1900). 3. Add support for `if` statements and `while` loops in the state machine of `sdfg_builder.py` 4. We are breaking up vertical loops inside stencils in multiple Tasklets. It might thus happen that we write a "local" scalar in one Tasklet and read it in another Tasklet (downstream). We thus create output connectors for all scalar writes in a Tasklet and input connectors for all reads (unless previously written in the same Tasklet). 5. Memlets can't be generated per horizontal execution anymore and need to be more fine grained. `TaskletAccessInfoCollector` does this work for us, duplicating some logic in `AccessInfoCollector`. A refactor task has been logged to fix/re-evaluate this later. This PR depends on the following (downstream) DaCe fixes - spcl/dace#1954 - spcl/dace#1955 which have been merged by now. Follow-up issues - unrelated changes have been moved to #1895 - #1896 - #1898 - #1900 Related issue: GEOS-ESM/NDSL#53 ## Requirements - [x] All fixes and/or new features come with corresponding tests. Added new tests and increased coverage of horizontal regions with PRs #1807 and #1851. - [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. Docs are [in our knowledge base](https://geos-esm.github.io/SMT-Nebulae/technical/backend/dace-bridge/) for now. Will be ported. --------- Co-authored-by: Roman Cattaneo <> Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com>
Description
For this bug to show, we need two separate states with a transient produced in one and subsequently read and written (but not read again). It is important that
StateFusion
isn't able to merge these two state. I've put a dummy if/else in the middle. Before DDE this might look likewhere
tmp_computed
is transient andtmp
is a given variable. DDE will now go and see thattmp_computed
can be removed as an output of theread_write
tasklet. The currently faulty update ofaccess_set
will removetmp_computed
from the list of reads inblock
state. This will then propagate (badly) up to thestart
state wheretmp_computed
is marked as never read again, removing the whole tasklet, leaving theblock
state to read an uninitializedtmp_computed
(if we were to codegen).Repro
Finishing this PR
I'll need help to evaluate whether or not the proposed solution is a good one. Questions that I have:
access_sets
or would it be simpler/more reliable to redo the analysis step?