def test_check_paths():
    """Check that ``StateFusion._check_paths`` rejects fusion regardless of match order.

    Reduced from the NASA GFDL_1M microphysics SDFG.

    Layout of the two states:

    * ``block_0``: ``qm -> q -> (m1, qm)``
    * ``block_5``: ``(precip_fall, qm) -> q`` and ``(m1, precip_fall) -> m1``

    ``m1`` is written in both states, so the states must not be fusable.
    The earlier implementation stopped scanning as soon as one match with a
    path was found, so it exited early when ``qm`` was examined *before*
    ``m1`` and never inspected the conflicting ``m1`` access.
    """
    sdfg = dace.SDFG("state_fusion_check_path_test")
    for array_name in ("m1", "precip_fall", "q", "qm"):
        sdfg.add_array(array_name, [1], dace.int32)

    # ---- First state: qm -> q -> (m1, qm) ----
    block_0 = sdfg.add_state()
    q_b0_w = block_0.add_write("q")
    qm_b0 = block_0.add_read("qm")
    qm_b0_w = block_0.add_write("qm")
    tasklet_b0_on_q = block_0.add_tasklet(
        "tasklet_b0_on_q",
        {"p_qm"},
        {"p_q_w"},
        "p_q_w = p_qm",
    )
    block_0.add_edge(qm_b0, None, tasklet_b0_on_q, "p_qm", dace.Memlet("qm[0]"))
    block_0.add_edge(tasklet_b0_on_q, "p_q_w", q_b0_w, None, dace.Memlet("q[0]"))

    m1_b0_w = block_0.add_write("m1")
    tasklet_b0_on_m1 = block_0.add_tasklet(
        "tasklet_b0_on_m1_qm",
        {"p_q"},
        {"p_m1_w", "p_qm_w"},
        # Assign every declared output connector so the tasklet is well-formed.
        "p_m1_w = p_q\np_qm_w = p_q",
    )
    block_0.add_edge(q_b0_w, None, tasklet_b0_on_m1, "p_q", dace.Memlet("q[0]"))
    block_0.add_edge(tasklet_b0_on_m1, "p_m1_w", m1_b0_w, None, dace.Memlet("m1[0]"))
    block_0.add_edge(tasklet_b0_on_m1, "p_qm_w", qm_b0_w, None, dace.Memlet("qm[0]"))

    # ---- Second state: (precip_fall, qm) -> q and (m1, precip_fall) -> m1 ----
    block_5 = sdfg.add_state_after(block_0)
    precip_fall_b5 = block_5.add_read("precip_fall")
    qm_b5 = block_5.add_read("qm")
    q_b5_w = block_5.add_write("q")
    tasklet_b5_on_q = block_5.add_tasklet(
        "tasklet_b5_on_q",
        {"p_precip_fall", "p_qm"},
        {"p_q_w"},
        # Only reference connectors that this tasklet actually declares.
        "p_q_w = p_qm + 1",
    )
    block_5.add_edge(
        precip_fall_b5,
        None,
        tasklet_b5_on_q,
        "p_precip_fall",
        dace.Memlet("precip_fall[0]"),
    )
    block_5.add_edge(qm_b5, None, tasklet_b5_on_q, "p_qm", dace.Memlet("qm[0]"))
    block_5.add_edge(tasklet_b5_on_q, "p_q_w", q_b5_w, None, dace.Memlet("q[0]"))

    m1_b5 = block_5.add_read("m1")
    m1_b5_w = block_5.add_write("m1")
    tasklet_b5_on_m1 = block_5.add_tasklet(
        "tasklet_b5_on_m1",
        {"p_m1", "p_precip_fall"},
        {"p_m1_w"},
        # Write to the declared output connector (``p_m1_w``).
        "p_m1_w = p_m1 + 1",
    )
    block_5.add_edge(m1_b5, None, tasklet_b5_on_m1, "p_m1", dace.Memlet("m1[0]"))
    block_5.add_edge(
        precip_fall_b5,
        None,
        tasklet_b5_on_m1,
        "p_precip_fall",
        dace.Memlet("precip_fall[0]"),
    )
    block_5.add_edge(tasklet_b5_on_m1, "p_m1_w", m1_b5_w, None, dace.Memlet("m1[0]"))

    # ``qm`` is deliberately listed before ``m1``: this is the ordering under
    # which the early exit used to hide the conflicting ``m1`` write.
    do_fuse = StateFusion()._check_paths(
        first_state=block_0,
        second_state=block_5,
        match_nodes={qm_b0_w: qm_b5, m1_b0_w: m1_b5},
        nodes_first=[q_b0_w],
        nodes_second=[q_b5_w],
        second_input={precip_fall_b5, m1_b5, qm_b5},
        first_read=False,
        second_read=False,
    )
    assert not do_fuse