Skip to content

Commit 680a956

Browse files
authoredSep 11, 2023
Merge pull request #1367 from spcl/fortran_ast_parents
Add offset normalization to Fortran frontend
2 parents c9304d6 + 70c33dd commit 680a956

7 files changed

+621
-34
lines changed
 

‎dace/frontend/fortran/ast_internal_classes.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
2-
from typing import Any, List, Tuple, Type, TypeVar, Union, overload
2+
from typing import Any, List, Optional, Tuple, Type, TypeVar, Union, overload
33

44
# The node class is the base class for all nodes in the AST. It provides attributes including the line number and fields.
55
# Attributes are not used when walking the tree, but are useful for debugging and for code generation.
@@ -11,6 +11,14 @@ def __init__(self, *args, **kwargs): # real signature unknown
1111
self.integrity_exceptions = []
1212
self.read_vars = []
1313
self.written_vars = []
14+
self.parent: Optional[
15+
Union[
16+
Subroutine_Subprogram_Node,
17+
Function_Subprogram_Node,
18+
Main_Program_Node,
19+
Module_Node
20+
]
21+
] = None
1422
for k, v in kwargs.items():
1523
setattr(self, k, v)
1624

‎dace/frontend/fortran/ast_transforms.py

+148-26
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved.
22

33
from dace.frontend.fortran import ast_components, ast_internal_classes
4-
from typing import List, Tuple, Set
4+
from typing import Dict, List, Optional, Tuple, Set
55
import copy
66

77

@@ -310,6 +310,65 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
310310

311311
return ast_internal_classes.Execution_Part_Node(execution=newbody)
312312

313+
class ParentScopeAssigner(NodeVisitor):
314+
"""
315+
For each node, it assigns its parent scope - program, subroutine, function.
316+
317+
If the parent node is one of the "parent" types, we assign it as the parent.
318+
Otherwise, we look for the parent of my parent to cover nested AST nodes within
319+
a single scope.
320+
"""
321+
def __init__(self):
322+
pass
323+
324+
def visit(self, node: ast_internal_classes.FNode, parent_node: Optional[ast_internal_classes.FNode] = None):
325+
326+
parent_node_types = [
327+
ast_internal_classes.Subroutine_Subprogram_Node,
328+
ast_internal_classes.Function_Subprogram_Node,
329+
ast_internal_classes.Main_Program_Node,
330+
ast_internal_classes.Module_Node
331+
]
332+
333+
if parent_node is not None and type(parent_node) in parent_node_types:
334+
node.parent = parent_node
335+
elif parent_node is not None:
336+
node.parent = parent_node.parent
337+
338+
# Copied from `generic_visit` to recursively parse all leafs
339+
for field, value in iter_fields(node):
340+
if isinstance(value, list):
341+
for item in value:
342+
if isinstance(item, ast_internal_classes.FNode):
343+
self.visit(item, node)
344+
elif isinstance(value, ast_internal_classes.FNode):
345+
self.visit(value, node)
346+
347+
class ScopeVarsDeclarations(NodeVisitor):
348+
"""
349+
Creates a mapping (scope name, variable name) -> variable declaration.
350+
351+
The visitor is used to access information on variable dimension, sizes, and offsets.
352+
"""
353+
354+
def __init__(self):
355+
356+
self.scope_vars: Dict[Tuple[str, str], ast_internal_classes.FNode] = {}
357+
358+
def get_var(self, scope: ast_internal_classes.FNode, variable_name: str) -> ast_internal_classes.FNode:
359+
return self.scope_vars[(self._scope_name(scope), variable_name)]
360+
361+
def visit_Var_Decl_Node(self, node: ast_internal_classes.Var_Decl_Node):
362+
363+
parent_name = self._scope_name(node.parent)
364+
var_name = node.name
365+
self.scope_vars[(parent_name, var_name)] = node
366+
367+
def _scope_name(self, scope: ast_internal_classes.FNode) -> str:
368+
if isinstance(scope, ast_internal_classes.Main_Program_Node):
369+
return scope.name.name.name
370+
else:
371+
return scope.name.name
313372

314373
class IndexExtractorNodeLister(NodeVisitor):
315374
"""
@@ -336,9 +395,20 @@ class IndexExtractor(NodeTransformer):
336395
Uses the IndexExtractorNodeLister to find all array subscript expressions
337396
in the AST node and its children that have to be extracted into independent expressions
338397
It then creates a new temporary variable for each of them and replaces the index expression with the variable.
398+
399+
Before parsing the AST, the transformation first runs:
400+
- ParentScopeAssigner to ensure that each node knows its scope assigner.
401+
- ScopeVarsDeclarations to aggregate all variable declarations for each function.
339402
"""
340-
def __init__(self, count=0):
403+
def __init__(self, ast: ast_internal_classes.FNode, normalize_offsets: bool = False, count=0):
404+
341405
self.count = count
406+
self.normalize_offsets = normalize_offsets
407+
408+
if normalize_offsets:
409+
ParentScopeAssigner().visit(ast)
410+
self.scope_vars = ScopeVarsDeclarations()
411+
self.scope_vars.visit(ast)
342412

343413
def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node):
344414
if node.name.name in ["sqrt", "exp", "pow", "max", "min", "abs", "tanh"]:
@@ -367,9 +437,11 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
367437
lister.visit(child)
368438
res = lister.nodes
369439
temp = self.count
440+
441+
370442
if res is not None:
371443
for j in res:
372-
for i in j.indices:
444+
for idx, i in enumerate(j.indices):
373445
if isinstance(i, ast_internal_classes.ParDecl_Node):
374446
continue
375447
else:
@@ -383,16 +455,34 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
383455
line_number=child.line_number)
384456
],
385457
line_number=child.line_number))
386-
newbody.append(
387-
ast_internal_classes.BinOp_Node(
388-
op="=",
389-
lval=ast_internal_classes.Name_Node(name=tmp_name),
390-
rval=ast_internal_classes.BinOp_Node(
391-
op="-",
392-
lval=i,
393-
rval=ast_internal_classes.Int_Literal_Node(value="1"),
394-
line_number=child.line_number),
395-
line_number=child.line_number))
458+
if self.normalize_offsets:
459+
460+
# Find the offset of a variable to which we are assigning
461+
var_name = child.lval.name.name
462+
variable = self.scope_vars.get_var(child.parent, var_name)
463+
offset = variable.offsets[idx]
464+
465+
newbody.append(
466+
ast_internal_classes.BinOp_Node(
467+
op="=",
468+
lval=ast_internal_classes.Name_Node(name=tmp_name),
469+
rval=ast_internal_classes.BinOp_Node(
470+
op="-",
471+
lval=i,
472+
rval=ast_internal_classes.Int_Literal_Node(value=str(offset)),
473+
line_number=child.line_number),
474+
line_number=child.line_number))
475+
else:
476+
newbody.append(
477+
ast_internal_classes.BinOp_Node(
478+
op="=",
479+
lval=ast_internal_classes.Name_Node(name=tmp_name),
480+
rval=ast_internal_classes.BinOp_Node(
481+
op="-",
482+
lval=i,
483+
rval=ast_internal_classes.Int_Literal_Node(value="1"),
484+
line_number=child.line_number),
485+
line_number=child.line_number))
396486
newbody.append(self.visit(child))
397487
return ast_internal_classes.Execution_Part_Node(execution=newbody)
398488

@@ -646,6 +736,7 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node,
646736
rangepos: list,
647737
count: int,
648738
newbody: list,
739+
scope_vars: ScopeVarsDeclarations,
649740
declaration=True,
650741
is_sum_to_loop=False):
651742
"""
@@ -662,16 +753,40 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node,
662753

663754
currentindex = 0
664755
indices = []
665-
for i in node.indices:
756+
offsets = scope_vars.get_var(node.parent, node.name.name).offsets
757+
758+
for idx, i in enumerate(node.indices):
666759
if isinstance(i, ast_internal_classes.ParDecl_Node):
760+
667761
if i.type == "ALL":
668-
ranges.append([
669-
ast_internal_classes.Int_Literal_Node(value="1"),
670-
ast_internal_classes.Name_Range_Node(name="f2dace_MAX",
671-
type="INTEGER",
672-
arrname=node.name,
673-
pos=currentindex)
674-
])
762+
763+
lower_boundary = None
764+
if offsets[idx] != 1:
765+
lower_boundary = ast_internal_classes.Int_Literal_Node(value=str(offsets[idx]))
766+
else:
767+
lower_boundary = ast_internal_classes.Int_Literal_Node(value="1")
768+
769+
upper_boundary = ast_internal_classes.Name_Range_Node(name="f2dace_MAX",
770+
type="INTEGER",
771+
arrname=node.name,
772+
pos=currentindex)
773+
"""
774+
When there's an offset, we add MAX_RANGE + offset.
775+
But since the generated loop has `<=` condition, we need to subtract 1.
776+
"""
777+
if offsets[idx] != 1:
778+
upper_boundary = ast_internal_classes.BinOp_Node(
779+
lval=upper_boundary,
780+
op="+",
781+
rval=ast_internal_classes.Int_Literal_Node(value=str(offsets[idx]))
782+
)
783+
upper_boundary = ast_internal_classes.BinOp_Node(
784+
lval=upper_boundary,
785+
op="-",
786+
rval=ast_internal_classes.Int_Literal_Node(value="1")
787+
)
788+
ranges.append([lower_boundary, upper_boundary])
789+
675790
else:
676791
ranges.append([i.range[0], i.range[1]])
677792
rangepos.append(currentindex)
@@ -693,9 +808,13 @@ class ArrayToLoop(NodeTransformer):
693808
"""
694809
Transforms the AST by removing array expressions and replacing them with loops
695810
"""
696-
def __init__(self):
811+
def __init__(self, ast):
697812
self.count = 0
698813

814+
ParentScopeAssigner().visit(ast)
815+
self.scope_vars = ScopeVarsDeclarations()
816+
self.scope_vars.visit(ast)
817+
699818
def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node):
700819
newbody = []
701820
for child in node.execution:
@@ -709,15 +828,15 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
709828
val = child.rval
710829
ranges = []
711830
rangepos = []
712-
par_Decl_Range_Finder(current, ranges, rangepos, self.count, newbody, True)
831+
par_Decl_Range_Finder(current, ranges, rangepos, self.count, newbody, self.scope_vars, True)
713832

714833
if res_range is not None and len(res_range) > 0:
715834
rvals = [i for i in mywalk(val) if isinstance(i, ast_internal_classes.Array_Subscript_Node)]
716835
for i in rvals:
717836
rangeposrval = []
718837
rangesrval = []
719838

720-
par_Decl_Range_Finder(i, rangesrval, rangeposrval, self.count, newbody, False)
839+
par_Decl_Range_Finder(i, rangesrval, rangeposrval, self.count, newbody, self.scope_vars, False)
721840

722841
for i, j in zip(ranges, rangesrval):
723842
if i != j:
@@ -791,8 +910,11 @@ class SumToLoop(NodeTransformer):
791910
"""
792911
Transforms the AST by removing array sums and replacing them with loops
793912
"""
794-
def __init__(self):
913+
def __init__(self, ast):
795914
self.count = 0
915+
ParentScopeAssigner().visit(ast)
916+
self.scope_vars = ScopeVarsDeclarations()
917+
self.scope_vars.visit(ast)
796918

797919
def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node):
798920
newbody = []
@@ -811,7 +933,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
811933
rangeposrval = []
812934
rangesrval = []
813935

814-
par_Decl_Range_Finder(val, rangesrval, rangeposrval, self.count, newbody, False, True)
936+
par_Decl_Range_Finder(val, rangesrval, rangeposrval, self.count, newbody, self.scope_vars, False, True)
815937

816938
range_index = 0
817939
body = ast_internal_classes.BinOp_Node(lval=current,

‎dace/frontend/fortran/fortran_parser.py

+43-7
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def translate(self, node: ast_internal_classes.FNode, sdfg: SDFG):
133133
for i in node:
134134
self.translate(i, sdfg)
135135
else:
136-
warnings.warn("WARNING:", node.__class__.__name__)
136+
warnings.warn(f"WARNING: {node.__class__.__name__}")
137137

138138
def ast2sdfg(self, node: ast_internal_classes.Program_Node, sdfg: SDFG):
139139
"""
@@ -1015,10 +1015,46 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG):
10151015
if node.name not in self.contexts[sdfg.name].containers:
10161016
self.contexts[sdfg.name].containers.append(node.name)
10171017

1018+
def create_ast_from_string(
1019+
source_string: str,
1020+
sdfg_name: str,
1021+
transform: bool = False,
1022+
normalize_offsets: bool = False
1023+
):
1024+
"""
1025+
Creates an AST from a Fortran file in a string
1026+
:param source_string: The fortran file as a string
1027+
:param sdfg_name: The name to be given to the resulting SDFG
1028+
:return: The resulting AST
1029+
1030+
"""
1031+
parser = pf().create(std="f2008")
1032+
reader = fsr(source_string)
1033+
ast = parser(reader)
1034+
tables = SymbolTable
1035+
own_ast = ast_components.InternalFortranAst(ast, tables)
1036+
program = own_ast.create_ast(ast)
1037+
1038+
functions_and_subroutines_builder = ast_transforms.FindFunctionAndSubroutines()
1039+
functions_and_subroutines_builder.visit(program)
1040+
functions_and_subroutines = functions_and_subroutines_builder.nodes
1041+
1042+
if transform:
1043+
program = ast_transforms.functionStatementEliminator(program)
1044+
program = ast_transforms.CallToArray(functions_and_subroutines_builder.nodes).visit(program)
1045+
program = ast_transforms.CallExtractor().visit(program)
1046+
program = ast_transforms.SignToIf().visit(program)
1047+
program = ast_transforms.ArrayToLoop(program).visit(program)
1048+
program = ast_transforms.SumToLoop(program).visit(program)
1049+
program = ast_transforms.ForDeclarer().visit(program)
1050+
program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program)
1051+
1052+
return (program, own_ast)
10181053

10191054
def create_sdfg_from_string(
10201055
source_string: str,
10211056
sdfg_name: str,
1057+
normalize_offsets: bool = False
10221058
):
10231059
"""
10241060
Creates an SDFG from a fortran file in a string
@@ -1040,10 +1076,10 @@ def create_sdfg_from_string(
10401076
program = ast_transforms.CallToArray(functions_and_subroutines_builder.nodes).visit(program)
10411077
program = ast_transforms.CallExtractor().visit(program)
10421078
program = ast_transforms.SignToIf().visit(program)
1043-
program = ast_transforms.ArrayToLoop().visit(program)
1044-
program = ast_transforms.SumToLoop().visit(program)
1079+
program = ast_transforms.ArrayToLoop(program).visit(program)
1080+
program = ast_transforms.SumToLoop(program).visit(program)
10451081
program = ast_transforms.ForDeclarer().visit(program)
1046-
program = ast_transforms.IndexExtractor().visit(program)
1082+
program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program)
10471083
ast2sdfg = AST_translator(own_ast, __file__)
10481084
sdfg = SDFG(sdfg_name)
10491085
ast2sdfg.top_level = program
@@ -1082,10 +1118,10 @@ def create_sdfg_from_fortran_file(source_string: str):
10821118
program = ast_transforms.CallToArray(functions_and_subroutines_builder.nodes).visit(program)
10831119
program = ast_transforms.CallExtractor().visit(program)
10841120
program = ast_transforms.SignToIf().visit(program)
1085-
program = ast_transforms.ArrayToLoop().visit(program)
1086-
program = ast_transforms.SumToLoop().visit(program)
1121+
program = ast_transforms.ArrayToLoop(program).visit(program)
1122+
program = ast_transforms.SumToLoop(program).visit(program)
10871123
program = ast_transforms.ForDeclarer().visit(program)
1088-
program = ast_transforms.IndexExtractor().visit(program)
1124+
program = ast_transforms.IndexExtractor(program).visit(program)
10891125
ast2sdfg = AST_translator(own_ast, __file__)
10901126
sdfg = SDFG(source_string)
10911127
ast2sdfg.top_level = program

‎tests/fortran/array_to_loop_offset.py

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
2+
3+
import numpy as np
4+
5+
from dace.frontend.fortran import ast_transforms, fortran_parser
6+
7+
def test_fortran_frontend_arr2loop_without_offset():
8+
"""
9+
Tests that the generated array map correctly handles offsets.
10+
"""
11+
test_string = """
12+
PROGRAM index_offset_test
13+
implicit none
14+
double precision, dimension(5,3) :: d
15+
CALL index_test_function(d)
16+
end
17+
18+
SUBROUTINE index_test_function(d)
19+
double precision, dimension(5,3) :: d
20+
21+
do i=1,5
22+
d(i, :) = i * 2.0
23+
end do
24+
25+
END SUBROUTINE index_test_function
26+
"""
27+
28+
# Now test to verify it executes correctly with no offset normalization
29+
30+
sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False)
31+
sdfg.simplify(verbose=True)
32+
sdfg.compile()
33+
34+
assert len(sdfg.data('d').shape) == 2
35+
assert sdfg.data('d').shape[0] == 5
36+
assert sdfg.data('d').shape[1] == 3
37+
38+
a = np.full([5,9], 42, order="F", dtype=np.float64)
39+
sdfg(d=a)
40+
for i in range(1,6):
41+
for j in range(1,4):
42+
assert a[i-1, j-1] == i * 2
43+
44+
def test_fortran_frontend_arr2loop_1d_offset():
45+
"""
46+
Tests that the generated array map correctly handles offsets.
47+
"""
48+
test_string = """
49+
PROGRAM index_offset_test
50+
implicit none
51+
double precision, dimension(2:6) :: d
52+
CALL index_test_function(d)
53+
end
54+
55+
SUBROUTINE index_test_function(d)
56+
double precision, dimension(2:6) :: d
57+
58+
d(:) = 5
59+
60+
END SUBROUTINE index_test_function
61+
"""
62+
63+
# Now test to verify it executes correctly with no offset normalization
64+
65+
sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False)
66+
sdfg.simplify(verbose=True)
67+
sdfg.compile()
68+
69+
assert len(sdfg.data('d').shape) == 1
70+
assert sdfg.data('d').shape[0] == 5
71+
72+
a = np.full([6], 42, order="F", dtype=np.float64)
73+
sdfg(d=a)
74+
assert a[0] == 42
75+
for i in range(2,7):
76+
assert a[i-1] == 5
77+
78+
def test_fortran_frontend_arr2loop_2d_offset():
79+
"""
80+
Tests that the generated array map correctly handles offsets.
81+
"""
82+
test_string = """
83+
PROGRAM index_offset_test
84+
implicit none
85+
double precision, dimension(5,7:9) :: d
86+
CALL index_test_function(d)
87+
end
88+
89+
SUBROUTINE index_test_function(d)
90+
double precision, dimension(5,7:9) :: d
91+
92+
do i=1,5
93+
d(i, :) = i * 2.0
94+
end do
95+
96+
END SUBROUTINE index_test_function
97+
"""
98+
99+
# Now test to verify it executes correctly with no offset normalization
100+
101+
sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", False)
102+
sdfg.simplify(verbose=True)
103+
sdfg.compile()
104+
105+
assert len(sdfg.data('d').shape) == 2
106+
assert sdfg.data('d').shape[0] == 5
107+
assert sdfg.data('d').shape[1] == 3
108+
109+
a = np.full([5,9], 42, order="F", dtype=np.float64)
110+
sdfg(d=a)
111+
for i in range(1,6):
112+
for j in range(7,10):
113+
assert a[i-1, j-1] == i * 2
114+
115+
if __name__ == "__main__":
116+
117+
test_fortran_frontend_arr2loop_1d_offset()
118+
test_fortran_frontend_arr2loop_2d_offset()
119+
test_fortran_frontend_arr2loop_without_offset()

‎tests/fortran/offset_normalizer.py

+164
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
2+
3+
import numpy as np
4+
5+
from dace.frontend.fortran import ast_transforms, fortran_parser
6+
7+
def test_fortran_frontend_offset_normalizer_1d():
8+
"""
9+
Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct.
10+
"""
11+
test_string = """
12+
PROGRAM index_offset_test
13+
implicit none
14+
double precision, dimension(50:54) :: d
15+
CALL index_test_function(d)
16+
end
17+
18+
SUBROUTINE index_test_function(d)
19+
double precision, dimension(50:54) :: d
20+
21+
do i=50,54
22+
d(i) = i * 2.0
23+
end do
24+
25+
END SUBROUTINE index_test_function
26+
"""
27+
28+
# Test to verify that offset is normalized correctly
29+
ast, own_ast = fortran_parser.create_ast_from_string(test_string, "index_offset_test", True, True)
30+
31+
for subroutine in ast.subroutine_definitions:
32+
33+
loop = subroutine.execution_part.execution[1]
34+
idx_assignment = loop.body.execution[1]
35+
assert idx_assignment.rval.rval.value == "50"
36+
37+
# Now test to verify it executes correctly
38+
39+
sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True)
40+
sdfg.simplify(verbose=True)
41+
sdfg.compile()
42+
43+
assert len(sdfg.data('d').shape) == 1
44+
assert sdfg.data('d').shape[0] == 5
45+
46+
a = np.full([5], 42, order="F", dtype=np.float64)
47+
sdfg(d=a)
48+
for i in range(0,5):
49+
assert a[i] == (50+i)* 2
50+
51+
def test_fortran_frontend_offset_normalizer_2d():
52+
"""
53+
Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct.
54+
"""
55+
test_string = """
56+
PROGRAM index_offset_test
57+
implicit none
58+
double precision, dimension(50:54,7:9) :: d
59+
CALL index_test_function(d)
60+
end
61+
62+
SUBROUTINE index_test_function(d)
63+
double precision, dimension(50:54,7:9) :: d
64+
65+
do i=50,54
66+
do j=7,9
67+
d(i, j) = i * 2.0 + 3 * j
68+
end do
69+
end do
70+
71+
END SUBROUTINE index_test_function
72+
"""
73+
74+
# Test to verify that offset is normalized correctly
75+
ast, own_ast = fortran_parser.create_ast_from_string(test_string, "index_offset_test", True, True)
76+
77+
for subroutine in ast.subroutine_definitions:
78+
79+
loop = subroutine.execution_part.execution[1]
80+
nested_loop = loop.body.execution[1]
81+
82+
idx = nested_loop.body.execution[1]
83+
assert idx.lval.name == 'tmp_index_0'
84+
assert idx.rval.rval.value == "50"
85+
86+
idx2 = nested_loop.body.execution[3]
87+
assert idx2.lval.name == 'tmp_index_1'
88+
assert idx2.rval.rval.value == "7"
89+
90+
# Now test to verify it executes correctly
91+
92+
sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True)
93+
sdfg.simplify(verbose=True)
94+
sdfg.compile()
95+
96+
assert len(sdfg.data('d').shape) == 2
97+
assert sdfg.data('d').shape[0] == 5
98+
assert sdfg.data('d').shape[1] == 3
99+
100+
a = np.full([5,3], 42, order="F", dtype=np.float64)
101+
sdfg(d=a)
102+
for i in range(0,5):
103+
for j in range(0,3):
104+
assert a[i, j] == (50+i) * 2 + 3 * (7 + j)
105+
106+
def test_fortran_frontend_offset_normalizer_2d_arr2loop():
107+
"""
108+
Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct.
109+
"""
110+
test_string = """
111+
PROGRAM index_offset_test
112+
implicit none
113+
double precision, dimension(50:54,7:9) :: d
114+
CALL index_test_function(d)
115+
end
116+
117+
SUBROUTINE index_test_function(d)
118+
double precision, dimension(50:54,7:9) :: d
119+
120+
do i=50,54
121+
d(i, :) = i * 2.0
122+
end do
123+
124+
END SUBROUTINE index_test_function
125+
"""
126+
127+
# Test to verify that offset is normalized correctly
128+
ast, own_ast = fortran_parser.create_ast_from_string(test_string, "index_offset_test", True, True)
129+
130+
for subroutine in ast.subroutine_definitions:
131+
132+
loop = subroutine.execution_part.execution[1]
133+
nested_loop = loop.body.execution[1]
134+
135+
idx = nested_loop.body.execution[1]
136+
assert idx.lval.name == 'tmp_index_0'
137+
assert idx.rval.rval.value == "50"
138+
139+
idx2 = nested_loop.body.execution[3]
140+
assert idx2.lval.name == 'tmp_index_1'
141+
assert idx2.rval.rval.value == "7"
142+
143+
# Now test to verify it executes correctly with no normalization
144+
145+
sdfg = fortran_parser.create_sdfg_from_string(test_string, "index_offset_test", True)
146+
sdfg.save('test.sdfg')
147+
sdfg.simplify(verbose=True)
148+
sdfg.compile()
149+
150+
assert len(sdfg.data('d').shape) == 2
151+
assert sdfg.data('d').shape[0] == 5
152+
assert sdfg.data('d').shape[1] == 3
153+
154+
a = np.full([5,3], 42, order="F", dtype=np.float64)
155+
sdfg(d=a)
156+
for i in range(0,5):
157+
for j in range(0,3):
158+
assert a[i, j] == (50 + i) * 2
159+
160+
if __name__ == "__main__":
161+
162+
test_fortran_frontend_offset_normalizer_1d()
163+
test_fortran_frontend_offset_normalizer_2d()
164+
test_fortran_frontend_offset_normalizer_2d_arr2loop()

‎tests/fortran/parent_test.py

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved.
2+
3+
from dace.frontend.fortran import fortran_parser
4+
5+
import dace.frontend.fortran.ast_transforms as ast_transforms
6+
import dace.frontend.fortran.ast_internal_classes as ast_internal_classes
7+
8+
9+
def test_fortran_frontend_parent():
10+
"""
11+
Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct.
12+
"""
13+
test_string = """
14+
PROGRAM access_test
15+
implicit none
16+
double precision d(4)
17+
d(1)=0
18+
CALL array_access_test_function(d)
19+
end
20+
21+
SUBROUTINE array_access_test_function(d)
22+
double precision d(4)
23+
24+
d(2)=5.5
25+
26+
END SUBROUTINE array_access_test_function
27+
"""
28+
ast, functions = fortran_parser.create_ast_from_string(test_string, "array_access_test")
29+
ast_transforms.ParentScopeAssigner().visit(ast)
30+
31+
assert ast.parent is None
32+
assert ast.main_program.parent == None
33+
34+
main_program = ast.main_program
35+
# Both executed lines
36+
for execution in main_program.execution_part.execution:
37+
assert execution.parent == main_program
38+
# call to the function
39+
call_node = main_program.execution_part.execution[1]
40+
assert isinstance(call_node, ast_internal_classes.Call_Expr_Node)
41+
for arg in call_node.args:
42+
assert arg.parent == main_program
43+
44+
for subroutine in ast.subroutine_definitions:
45+
46+
assert subroutine.parent == None
47+
assert subroutine.execution_part.parent == subroutine
48+
for execution in subroutine.execution_part.execution:
49+
assert execution.parent == subroutine
50+
51+
def test_fortran_frontend_module():
52+
"""
53+
Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct.
54+
"""
55+
test_string = """
56+
module test_module
57+
implicit none
58+
! good enough approximation
59+
integer, parameter :: pi = 4
60+
end module test_module
61+
62+
PROGRAM access_test
63+
implicit none
64+
double precision d(4)
65+
d(1)=0
66+
CALL array_access_test_function(d)
67+
end
68+
69+
SUBROUTINE array_access_test_function(d)
70+
double precision d(4)
71+
72+
d(2)=5.5
73+
74+
END SUBROUTINE array_access_test_function
75+
"""
76+
ast, functions = fortran_parser.create_ast_from_string(test_string, "array_access_test")
77+
ast_transforms.ParentScopeAssigner().visit(ast)
78+
79+
assert ast.parent is None
80+
assert ast.main_program.parent == None
81+
82+
module = ast.modules[0]
83+
assert module.parent == None
84+
specification = module.specification_part.specifications[0]
85+
assert specification.parent == module
86+
87+
88+
if __name__ == "__main__":
89+
90+
test_fortran_frontend_parent()
91+
test_fortran_frontend_module()

‎tests/fortran/scope_arrays.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved.
2+
3+
from dace.frontend.fortran import fortran_parser
4+
5+
import dace.frontend.fortran.ast_transforms as ast_transforms
6+
import dace.frontend.fortran.ast_internal_classes as ast_internal_classes
7+
8+
9+
def test_fortran_frontend_parent():
10+
"""
11+
Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct.
12+
"""
13+
test_string = """
14+
PROGRAM scope_test
15+
implicit none
16+
double precision d(4)
17+
double precision, dimension(5) :: arr
18+
double precision, dimension(50:54) :: arr3
19+
CALL scope_test_function(d)
20+
end
21+
22+
SUBROUTINE scope_test_function(d)
23+
double precision d(4)
24+
double precision, dimension(50:54) :: arr4
25+
26+
d(2)=5.5
27+
28+
END SUBROUTINE scope_test_function
29+
"""
30+
31+
ast, functions = fortran_parser.create_ast_from_string(test_string, "array_access_test")
32+
ast_transforms.ParentScopeAssigner().visit(ast)
33+
visitor = ast_transforms.ScopeVarsDeclarations()
34+
visitor.visit(ast)
35+
36+
for var in ['d', 'arr', 'arr3']:
37+
assert ('scope_test', var) in visitor.scope_vars
38+
assert isinstance(visitor.scope_vars[('scope_test', var)], ast_internal_classes.Var_Decl_Node)
39+
assert visitor.scope_vars[('scope_test', var)].name == var
40+
41+
for var in ['d', 'arr4']:
42+
assert ('scope_test_function', var) in visitor.scope_vars
43+
assert visitor.scope_vars[('scope_test_function', var)].name == var
44+
45+
if __name__ == "__main__":
46+
47+
test_fortran_frontend_parent()

0 commit comments

Comments
 (0)
Please sign in to comment.