Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Parallel Map Fusion #1965

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
Realized that I by accident deleted the validation arguments.
philip-paul-mueller committed Mar 6, 2025
commit f4b12c52462f5a757ffb3912f5a786d108df886a
23 changes: 21 additions & 2 deletions dace/transformation/passes/full_map_fusion.py
Original file line number Diff line number Diff line change
@@ -18,7 +18,8 @@ class FullMapFusion(ppl.Pass):
will speedup the fusion, as the SDFG has to be scanned only once.
The pass accepts the same options as `MapFusion`, for a detailed description
see there.
The main difference is that parallel map fusion on by default.
The main difference is that parallel map fusion is on by default and that
the single use data can not be passed as an argument.
:param only_inner_maps: Only match Maps that are internal, i.e. inside another Map.
:param only_toplevel_maps: Only consider Maps that are at the top.
@@ -28,6 +29,8 @@ class FullMapFusion(ppl.Pass):
:param allow_parallel_map_fusion: Allow to merge parallel maps, by default `True`.
:param only_if_common_ancestor: In parallel map fusion mode, only fuse if both map
have a common direct ancestor.
:param validate: Validate the SDFG after the pass as finished.
:param validate_all: Validate the SDFG after every transformation.
:todo: Implement a faster matcher as the pattern is constant.
"""
@@ -55,7 +58,6 @@ class FullMapFusion(ppl.Pass):
default=False,
desc="If `True` then all intermediates will be classified as shared.",
)

allow_serial_map_fusion = properties.Property(
dtype=bool,
default=True,
@@ -72,6 +74,16 @@ class FullMapFusion(ppl.Pass):
default=False,
desc="If `True` restrict parallel map fusion to maps that have a direct common ancestor.",
)
validate = properties.Property(
dtype=bool,
default=True,
desc='If True, validates the SDFG after all transformations have been applied.',
)
validate_all = properties.Property(
dtype=bool,
default=False,
desc='If True, validates the SDFG after each transformation applies.'
)


def __init__(
@@ -83,8 +95,11 @@ def __init__(
allow_serial_map_fusion: Optional[bool] = None,
allow_parallel_map_fusion: Optional[bool] = None,
only_if_common_ancestor: Optional[bool] = None,
validate: Optional[bool] = None,
validate_all: Optional[bool] = None,
**kwargs: Any,
) -> None:
breakpoint()
super().__init__(**kwargs)
if only_toplevel_maps is not None:
self.only_toplevel_maps = only_toplevel_maps
@@ -100,6 +115,10 @@ def __init__(
self.allow_parallel_map_fusion = allow_parallel_map_fusion
if only_if_common_ancestor is not None:
self.only_if_common_ancestor = only_if_common_ancestor
if validate is not None:
self.validate = validate
if validate_all is not None:
self.validate_all = validate_all

def modifies(self) -> ppl.Modifies:
return ppl.Modifies.Scopes | ppl.Modifies.AccessNodes | ppl.Modifies.Memlets