Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix writes to nested definitions #1287

Merged
merged 2 commits into from
Jun 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3067,11 +3067,11 @@ def _add_write_access(self,
arr_type: data.Data = None):

if name in self.sdfg.arrays:
return (name, None)
return (name, rng)
if (name, rng, 'w') in self.accesses:
return self.accesses[(name, rng, 'w')]
elif name in self.variables:
return (self.variables[name], None)
return (self.variables[name], rng)
elif (name, rng, 'r') in self.accesses or name in self.scope_vars:
return self._add_access(name, rng, 'w', target, new_name, arr_type)
else:
Expand Down
85 changes: 76 additions & 9 deletions tests/python_frontend/nested_name_accesses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_nested_name_accesses():
diff_norm = np.linalg.norm(dc_out - np_out)
ref_norm = np.linalg.norm(np_out)
rel_err = diff_norm / ref_norm
assert (rel_err < 1e-7)
assert rel_err < 1e-7


def test_nested_offset_access():
Expand All @@ -42,7 +42,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 5]):
inp = np.reshape(np.arange(6 * 5 * 5, dtype=np.float64), (6, 5, 5)).copy()
out = nested_offset_access(inp)
ref = nested_offset_access.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_nested_offset_access_dappy():
Expand All @@ -62,7 +62,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 5]):
inp = np.reshape(np.arange(6 * 5 * 5, dtype=np.float64), (6, 5, 5)).copy()
out = nested_offset_access(inp)
ref = nested_offset_access.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_nested_multi_offset_access():
Expand All @@ -79,7 +79,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 10]):
inp = np.reshape(np.arange(6 * 5 * 10, dtype=np.float64), (6, 5, 10)).copy()
out = nested_offset_access(inp)
ref = nested_offset_access.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_nested_multi_offset_access_dappy():
Expand All @@ -100,7 +100,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 10]):
inp = np.reshape(np.arange(6 * 5 * 10, dtype=np.float64), (6, 5, 10)).copy()
out = nested_offset_access(inp)
ref = nested_offset_access.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_nested_dec_offset_access():
Expand All @@ -116,7 +116,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 5]):
inp = np.reshape(np.arange(6 * 5 * 5, dtype=np.float64), (6, 5, 5)).copy()
out = nested_offset_access(inp)
ref = nested_offset_access.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_nested_dec_offset_access_dappy():
Expand All @@ -136,7 +136,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 5]):
inp = np.reshape(np.arange(6 * 5 * 5, dtype=np.float64), (6, 5, 5)).copy()
out = nested_offset_access(inp)
ref = nested_offset_access.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_nested_offset_access_nested_dependency():
Expand All @@ -157,7 +157,7 @@ def nested_offset_access_nested_dep(inp: dc.float64[6, 5, 5]):
out = nested_offset_access_nested_dep(inp)
os.environ['DACE_testing_serialization'] = last_value
ref = nested_offset_access_nested_dep.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_nested_offset_access_nested_dependency_dappy():
Expand All @@ -178,9 +178,74 @@ def nested_offset_access_nested_dep(inp: dc.float64[6, 5, 10]):
inp = np.reshape(np.arange(6 * 5 * 10, dtype=np.float64), (6, 5, 10)).copy()
out = nested_offset_access_nested_dep(inp)
ref = nested_offset_access_nested_dep.f(inp)
assert (np.allclose(out, ref))
assert np.allclose(out, ref)


def test_access_to_nested_transient():

KLEV = 3
KLON = 4
NBLOCKS = 5

@dc.program
def small_wip(inp: dc.float64[KLEV+1, KLON, NBLOCKS], out: dc.float64[KLEV, KLON, NBLOCKS]):
for jn in dc.map[0:NBLOCKS]:
tmp = np.zeros([KLEV+1, KLON])
for jl in range(KLON):
for jk in range(KLEV):
tmp[jk, jl] = inp[jk, jl, jn] + inp[jk+1, jl, jn]

for jl in range(KLON):
for jk in range(KLEV):
out[jk, jl, jn] = tmp[jk, jl] + tmp[jk+1, jl]

rng = np.random.default_rng(42)
inp = rng.random((KLEV+1, KLON, NBLOCKS))
ref = np.zeros((KLEV, KLON, NBLOCKS))
val = np.zeros((KLEV, KLON, NBLOCKS))

small_wip(inp, val)
small_wip.f(inp, ref)

assert np.allclose(val, ref)


def test_access_to_nested_transient_dappy():

KLEV = 3
KLON = 4
NBLOCKS = 5

@dc.program
def small_wip_dappy(inp: dc.float64[KLEV+1, KLON, NBLOCKS], out: dc.float64[KLEV, KLON, NBLOCKS]):
for jn in dc.map[0:NBLOCKS]:
tmp = np.zeros([KLEV+1, KLON])
for jl in range(KLON):
for jk in range(KLEV):
with dc.tasklet():
in1 << inp[jk, jl, jn]
in2 << inp[jk+1, jl, jn]
out1 >> tmp[jk, jl]
out1 = in1 + in2

for jl in range(KLON):
for jk in range(KLEV):
with dc.tasklet():
in1 << tmp[jk, jl]
in2 << tmp[jk+1, jl]
out1 >> out[jk, jl, jn]
out1 = in1 + in2

rng = np.random.default_rng(42)
inp = rng.random((KLEV+1, KLON, NBLOCKS))
ref = np.zeros((KLEV, KLON, NBLOCKS))
val = np.zeros((KLEV, KLON, NBLOCKS))

small_wip_dappy(inp, val)
small_wip_dappy.f(inp, ref)

assert np.allclose(val, ref)


if __name__ == "__main__":
test_nested_name_accesses()
Expand All @@ -192,3 +257,5 @@ def nested_offset_access_nested_dep(inp: dc.float64[6, 5, 10]):
test_nested_dec_offset_access_dappy()
test_nested_offset_access_nested_dependency()
test_nested_offset_access_nested_dependency_dappy()
test_access_to_nested_transient()
test_access_to_nested_transient_dappy()