1
1
# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved.
2
2
3
3
from dace .frontend .fortran import ast_components , ast_internal_classes
4
- from typing import List , Tuple , Set
4
+ from typing import Dict , List , Optional , Tuple , Set
5
5
import copy
6
6
7
7
@@ -310,6 +310,65 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
310
310
311
311
return ast_internal_classes .Execution_Part_Node (execution = newbody )
312
312
313
+ class ParentScopeAssigner (NodeVisitor ):
314
+ """
315
+ For each node, it assigns its parent scope - program, subroutine, function.
316
+
317
+ If the parent node is one of the "parent" types, we assign it as the parent.
318
+ Otherwise, we look for the parent of my parent to cover nested AST nodes within
319
+ a single scope.
320
+ """
321
+ def __init__ (self ):
322
+ pass
323
+
324
+ def visit (self , node : ast_internal_classes .FNode , parent_node : Optional [ast_internal_classes .FNode ] = None ):
325
+
326
+ parent_node_types = [
327
+ ast_internal_classes .Subroutine_Subprogram_Node ,
328
+ ast_internal_classes .Function_Subprogram_Node ,
329
+ ast_internal_classes .Main_Program_Node ,
330
+ ast_internal_classes .Module_Node
331
+ ]
332
+
333
+ if parent_node is not None and type (parent_node ) in parent_node_types :
334
+ node .parent = parent_node
335
+ elif parent_node is not None :
336
+ node .parent = parent_node .parent
337
+
338
+ # Copied from `generic_visit` to recursively parse all leafs
339
+ for field , value in iter_fields (node ):
340
+ if isinstance (value , list ):
341
+ for item in value :
342
+ if isinstance (item , ast_internal_classes .FNode ):
343
+ self .visit (item , node )
344
+ elif isinstance (value , ast_internal_classes .FNode ):
345
+ self .visit (value , node )
346
+
347
+ class ScopeVarsDeclarations (NodeVisitor ):
348
+ """
349
+ Creates a mapping (scope name, variable name) -> variable declaration.
350
+
351
+ The visitor is used to access information on variable dimension, sizes, and offsets.
352
+ """
353
+
354
+ def __init__ (self ):
355
+
356
+ self .scope_vars : Dict [Tuple [str , str ], ast_internal_classes .FNode ] = {}
357
+
358
+ def get_var (self , scope : ast_internal_classes .FNode , variable_name : str ) -> ast_internal_classes .FNode :
359
+ return self .scope_vars [(self ._scope_name (scope ), variable_name )]
360
+
361
+ def visit_Var_Decl_Node (self , node : ast_internal_classes .Var_Decl_Node ):
362
+
363
+ parent_name = self ._scope_name (node .parent )
364
+ var_name = node .name
365
+ self .scope_vars [(parent_name , var_name )] = node
366
+
367
+ def _scope_name (self , scope : ast_internal_classes .FNode ) -> str :
368
+ if isinstance (scope , ast_internal_classes .Main_Program_Node ):
369
+ return scope .name .name .name
370
+ else :
371
+ return scope .name .name
313
372
314
373
class IndexExtractorNodeLister (NodeVisitor ):
315
374
"""
@@ -336,9 +395,20 @@ class IndexExtractor(NodeTransformer):
336
395
Uses the IndexExtractorNodeLister to find all array subscript expressions
337
396
in the AST node and its children that have to be extracted into independent expressions
338
397
It then creates a new temporary variable for each of them and replaces the index expression with the variable.
398
+
399
+ Before parsing the AST, the transformation first runs:
400
+ - ParentScopeAssigner to ensure that each node knows its scope assigner.
401
+ - ScopeVarsDeclarations to aggregate all variable declarations for each function.
339
402
"""
340
- def __init__ (self , count = 0 ):
403
+ def __init__ (self , ast : ast_internal_classes .FNode , normalize_offsets : bool = False , count = 0 ):
404
+
341
405
self .count = count
406
+ self .normalize_offsets = normalize_offsets
407
+
408
+ if normalize_offsets :
409
+ ParentScopeAssigner ().visit (ast )
410
+ self .scope_vars = ScopeVarsDeclarations ()
411
+ self .scope_vars .visit (ast )
342
412
343
413
def visit_Call_Expr_Node (self , node : ast_internal_classes .Call_Expr_Node ):
344
414
if node .name .name in ["sqrt" , "exp" , "pow" , "max" , "min" , "abs" , "tanh" ]:
@@ -367,9 +437,11 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
367
437
lister .visit (child )
368
438
res = lister .nodes
369
439
temp = self .count
440
+
441
+
370
442
if res is not None :
371
443
for j in res :
372
- for i in j .indices :
444
+ for idx , i in enumerate ( j .indices ) :
373
445
if isinstance (i , ast_internal_classes .ParDecl_Node ):
374
446
continue
375
447
else :
@@ -383,16 +455,34 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
383
455
line_number = child .line_number )
384
456
],
385
457
line_number = child .line_number ))
386
- newbody .append (
387
- ast_internal_classes .BinOp_Node (
388
- op = "=" ,
389
- lval = ast_internal_classes .Name_Node (name = tmp_name ),
390
- rval = ast_internal_classes .BinOp_Node (
391
- op = "-" ,
392
- lval = i ,
393
- rval = ast_internal_classes .Int_Literal_Node (value = "1" ),
394
- line_number = child .line_number ),
395
- line_number = child .line_number ))
458
+ if self .normalize_offsets :
459
+
460
+ # Find the offset of a variable to which we are assigning
461
+ var_name = child .lval .name .name
462
+ variable = self .scope_vars .get_var (child .parent , var_name )
463
+ offset = variable .offsets [idx ]
464
+
465
+ newbody .append (
466
+ ast_internal_classes .BinOp_Node (
467
+ op = "=" ,
468
+ lval = ast_internal_classes .Name_Node (name = tmp_name ),
469
+ rval = ast_internal_classes .BinOp_Node (
470
+ op = "-" ,
471
+ lval = i ,
472
+ rval = ast_internal_classes .Int_Literal_Node (value = str (offset )),
473
+ line_number = child .line_number ),
474
+ line_number = child .line_number ))
475
+ else :
476
+ newbody .append (
477
+ ast_internal_classes .BinOp_Node (
478
+ op = "=" ,
479
+ lval = ast_internal_classes .Name_Node (name = tmp_name ),
480
+ rval = ast_internal_classes .BinOp_Node (
481
+ op = "-" ,
482
+ lval = i ,
483
+ rval = ast_internal_classes .Int_Literal_Node (value = "1" ),
484
+ line_number = child .line_number ),
485
+ line_number = child .line_number ))
396
486
newbody .append (self .visit (child ))
397
487
return ast_internal_classes .Execution_Part_Node (execution = newbody )
398
488
@@ -646,6 +736,7 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node,
646
736
rangepos : list ,
647
737
count : int ,
648
738
newbody : list ,
739
+ scope_vars : ScopeVarsDeclarations ,
649
740
declaration = True ,
650
741
is_sum_to_loop = False ):
651
742
"""
@@ -662,16 +753,40 @@ def par_Decl_Range_Finder(node: ast_internal_classes.Array_Subscript_Node,
662
753
663
754
currentindex = 0
664
755
indices = []
665
- for i in node .indices :
756
+ offsets = scope_vars .get_var (node .parent , node .name .name ).offsets
757
+
758
+ for idx , i in enumerate (node .indices ):
666
759
if isinstance (i , ast_internal_classes .ParDecl_Node ):
760
+
667
761
if i .type == "ALL" :
668
- ranges .append ([
669
- ast_internal_classes .Int_Literal_Node (value = "1" ),
670
- ast_internal_classes .Name_Range_Node (name = "f2dace_MAX" ,
671
- type = "INTEGER" ,
672
- arrname = node .name ,
673
- pos = currentindex )
674
- ])
762
+
763
+ lower_boundary = None
764
+ if offsets [idx ] != 1 :
765
+ lower_boundary = ast_internal_classes .Int_Literal_Node (value = str (offsets [idx ]))
766
+ else :
767
+ lower_boundary = ast_internal_classes .Int_Literal_Node (value = "1" )
768
+
769
+ upper_boundary = ast_internal_classes .Name_Range_Node (name = "f2dace_MAX" ,
770
+ type = "INTEGER" ,
771
+ arrname = node .name ,
772
+ pos = currentindex )
773
+ """
774
+ When there's an offset, we add MAX_RANGE + offset.
775
+ But since the generated loop has `<=` condition, we need to subtract 1.
776
+ """
777
+ if offsets [idx ] != 1 :
778
+ upper_boundary = ast_internal_classes .BinOp_Node (
779
+ lval = upper_boundary ,
780
+ op = "+" ,
781
+ rval = ast_internal_classes .Int_Literal_Node (value = str (offsets [idx ]))
782
+ )
783
+ upper_boundary = ast_internal_classes .BinOp_Node (
784
+ lval = upper_boundary ,
785
+ op = "-" ,
786
+ rval = ast_internal_classes .Int_Literal_Node (value = "1" )
787
+ )
788
+ ranges .append ([lower_boundary , upper_boundary ])
789
+
675
790
else :
676
791
ranges .append ([i .range [0 ], i .range [1 ]])
677
792
rangepos .append (currentindex )
@@ -693,9 +808,13 @@ class ArrayToLoop(NodeTransformer):
693
808
"""
694
809
Transforms the AST by removing array expressions and replacing them with loops
695
810
"""
696
- def __init__ (self ):
811
+ def __init__ (self , ast ):
697
812
self .count = 0
698
813
814
+ ParentScopeAssigner ().visit (ast )
815
+ self .scope_vars = ScopeVarsDeclarations ()
816
+ self .scope_vars .visit (ast )
817
+
699
818
def visit_Execution_Part_Node (self , node : ast_internal_classes .Execution_Part_Node ):
700
819
newbody = []
701
820
for child in node .execution :
@@ -709,15 +828,15 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
709
828
val = child .rval
710
829
ranges = []
711
830
rangepos = []
712
- par_Decl_Range_Finder (current , ranges , rangepos , self .count , newbody , True )
831
+ par_Decl_Range_Finder (current , ranges , rangepos , self .count , newbody , self . scope_vars , True )
713
832
714
833
if res_range is not None and len (res_range ) > 0 :
715
834
rvals = [i for i in mywalk (val ) if isinstance (i , ast_internal_classes .Array_Subscript_Node )]
716
835
for i in rvals :
717
836
rangeposrval = []
718
837
rangesrval = []
719
838
720
- par_Decl_Range_Finder (i , rangesrval , rangeposrval , self .count , newbody , False )
839
+ par_Decl_Range_Finder (i , rangesrval , rangeposrval , self .count , newbody , self . scope_vars , False )
721
840
722
841
for i , j in zip (ranges , rangesrval ):
723
842
if i != j :
@@ -791,8 +910,11 @@ class SumToLoop(NodeTransformer):
791
910
"""
792
911
Transforms the AST by removing array sums and replacing them with loops
793
912
"""
794
- def __init__ (self ):
913
+ def __init__ (self , ast ):
795
914
self .count = 0
915
+ ParentScopeAssigner ().visit (ast )
916
+ self .scope_vars = ScopeVarsDeclarations ()
917
+ self .scope_vars .visit (ast )
796
918
797
919
def visit_Execution_Part_Node (self , node : ast_internal_classes .Execution_Part_Node ):
798
920
newbody = []
@@ -811,7 +933,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
811
933
rangeposrval = []
812
934
rangesrval = []
813
935
814
- par_Decl_Range_Finder (val , rangesrval , rangeposrval , self .count , newbody , False , True )
936
+ par_Decl_Range_Finder (val , rangesrval , rangeposrval , self .count , newbody , self . scope_vars , False , True )
815
937
816
938
range_index = 0
817
939
body = ast_internal_classes .BinOp_Node (lval = current ,
0 commit comments