Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for structures nested in (nested) struct-arrays #1534

Merged
merged 7 commits into from
Feb 27, 2024
8 changes: 5 additions & 3 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from dace.sdfg import (ScopeSubgraphView, SDFG, scope_contains_scope, is_array_stream_view, NodeNotExpandedError,
dynamic_map_inputs, local_transients)
from dace.sdfg.scope import is_devicelevel_gpu, is_devicelevel_fpga, is_in_scope
from dace.sdfg.validation import validate_memlet_data
from typing import Union
from dace.codegen.targets import fpga

Expand All @@ -40,7 +41,7 @@ def _visit_structure(struct: data.Structure, args: dict, prefix: str = ''):
_visit_structure(v, args, f'{prefix}->{k}')
elif isinstance(v, data.ContainerArray):
_visit_structure(v.stype, args, f'{prefix}->{k}')
elif isinstance(v, data.Data):
if isinstance(v, data.Data):
args[f'{prefix}->{k}'] = v

# Keeps track of generated connectors, so we know how to access them in nested scopes
Expand Down Expand Up @@ -620,6 +621,7 @@ def copy_memory(
callsite_stream,
)


def _emit_copy(
self,
sdfg,
Expand All @@ -637,9 +639,9 @@ def _emit_copy(
orig_vconn = vconn

# Determine memlet directionality
if isinstance(src_node, nodes.AccessNode) and memlet.data == src_node.data:
if isinstance(src_node, nodes.AccessNode) and validate_memlet_data(memlet.data, src_node.data):
write = True
elif isinstance(dst_node, nodes.AccessNode) and memlet.data == dst_node.data:
elif isinstance(dst_node, nodes.AccessNode) and validate_memlet_data(memlet.data, dst_node.data):
write = False
elif isinstance(src_node, nodes.CodeNode) and isinstance(dst_node, nodes.CodeNode):
# Code->Code copy (not read nor write)
Expand Down
7 changes: 6 additions & 1 deletion dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ def generate_fileheader(self, sdfg: SDFG, global_stream: CodeIOStream, backend:
if arr is not None:
datatypes.add(arr.dtype)

emitted = set()

def _emit_definitions(dtype: dtypes.typeclass, wrote_something: bool) -> bool:
if isinstance(dtype, dtypes.pointer):
wrote_something = _emit_definitions(dtype._typeclass, wrote_something)
Expand All @@ -164,7 +166,10 @@ def _emit_definitions(dtype: dtypes.typeclass, wrote_something: bool) -> bool:
if hasattr(dtype, 'emit_definition'):
if not wrote_something:
global_stream.write("", sdfg)
global_stream.write(dtype.emit_definition(), sdfg)
if dtype not in emitted:
global_stream.write(dtype.emit_definition(), sdfg)
wrote_something = True
emitted.add(dtype)
return wrote_something

# Emit unique definitions
Expand Down
6 changes: 4 additions & 2 deletions dace/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,8 +1449,10 @@ def validate_name(name):
return False
if name in {'True', 'False', 'None'}:
return False
if namere.match(name) is None:
return False
tokens = name.split('.')
for token in tokens:
if namere.match(token) is None:
return False
return True


Expand Down
4 changes: 3 additions & 1 deletion dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,9 @@ def replace_dict(self,

# Replace in arrays and symbols (if a variable name)
if replace_keys:
for name, new_name in repldict.items():
# Filter out nested data names, as we cannot and do not want to replace names in nested data descriptors
repldict_filtered = {k: v for k, v in repldict.items() if '.' not in k}
for name, new_name in repldict_filtered.items():
if validate_name(new_name):
_replace_dict_keys(self._arrays, name, new_name)
_replace_dict_keys(self.symbols, name, new_name)
Expand Down
17 changes: 17 additions & 0 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,3 +981,20 @@ def __str__(self):
locinfo += f'\nInvalid SDFG saved for inspection in {os.path.abspath(self.path)}'

return f'{self.message} (at state {state.label}{edgestr}){locinfo}'


def validate_memlet_data(memlet_data: str, access_data: str) -> bool:
""" Validates that the src/dst access node data matches the memlet data.

:param memlet_data: The data of the memlet.
:param access_data: The data of the access node.
:return: True if the memlet data matches the access node data.
"""
if memlet_data == access_data:
return True
if memlet_data is None or access_data is None:
return False
access_tokens = access_data.split('.')
memlet_tokens = memlet_data.split('.')
mem_root = '.'.join(memlet_tokens[:len(access_tokens)])
return mem_root == access_data
59 changes: 59 additions & 0 deletions tests/sdfg/data/container_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,67 @@ def test_two_levels():
assert np.allclose(ref, B[0])


def test_multi_nested_containers():

M, N = dace.symbol('M'), dace.symbol('N')
sdfg = dace.SDFG('tester')
float_desc = dace.data.Scalar(dace.float32)
E_desc = dace.data.Structure({'F': dace.float32[N], 'G':float_desc}, 'InnerStruct')
B_desc = dace.data.ContainerArray(E_desc, [M])
A_desc = dace.data.Structure({'B': B_desc, 'C': dace.float32[M], 'D': float_desc}, 'OuterStruct')
sdfg.add_datadesc('A', A_desc)
sdfg.add_datadesc_view('vB', B_desc)
sdfg.add_datadesc_view('vE', E_desc)
sdfg.add_array('out', [M, N], dace.float32)

state = sdfg.add_state()
rA = state.add_read('A')
vB = state.add_access('vB')
vE = state.add_access('vE')
wout = state.add_write('out')

me, mx = state.add_map('outer_product', dict(i='0:M', j='0:N'))
tasklet = state.add_tasklet('outer_product', {'__in_A_B_E_F', '__in_A_B_E_G', '__in_A_C', '__in_A_D'}, {'__out'},
'__out = (__in_A_B_E_F + __in_A_B_E_G) * (__in_A_C + __in_A_D)')

state.add_edge(rA, None, vB, 'views', dace.Memlet('A.B'))
state.add_memlet_path(vB, me, vE, dst_conn='views', memlet=dace.Memlet('vB[i]'))
state.add_edge(vE, None, tasklet, '__in_A_B_E_F', dace.Memlet('vE.F[j]'))
state.add_edge(vE, None, tasklet, '__in_A_B_E_G', dace.Memlet(data='vE.G', subset='0'))
state.add_memlet_path(rA, me, tasklet, dst_conn='__in_A_C', memlet=dace.Memlet('A.C[i]'))
state.add_memlet_path(rA, me, tasklet, dst_conn='__in_A_D', memlet=dace.Memlet(data='A.D', subset='0'))
state.add_memlet_path(tasklet, mx, wout, src_conn='__out', memlet=dace.Memlet('out[i, j]'))

c_data = np.arange(5, dtype=np.float32)
f_data = np.arange(5 * 3, dtype=np.float32).reshape(5, 3)

e_class = E_desc.dtype._typeclass.as_ctypes()
b_obj = []
b_data = np.ndarray((5, ), dtype=ctypes.c_void_p)
for i in range(5):
f_obj = f_data[i].__array_interface__['data'][0]
e_obj = e_class(F=f_obj, G=ctypes.c_float(0.1))
b_obj.append(e_obj) # NOTE: This is needed to keep the object alive ...
b_data[i] = ctypes.addressof(e_obj)
a_dace = A_desc.dtype._typeclass.as_ctypes()(B=b_data.__array_interface__['data'][0],
C=c_data.__array_interface__['data'][0],
D=ctypes.c_float(0.2))




out_dace = np.empty((5, 3), dtype=np.float32)
ref = np.empty((5, 3), dtype=np.float32)
for i in range(5):
ref[i] = (f_data[i] + 0.1) * (c_data[i] + 0.2)

sdfg(A=a_dace, out=out_dace, M=5, N=3)
assert np.allclose(out_dace, ref)


if __name__ == '__main__':
test_read_struct_array()
test_write_struct_array()
test_jagged_container_array()
test_two_levels()
test_multi_nested_containers()
Loading