-
Notifications
You must be signed in to change notification settings - Fork 133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
More NumPy operation implementations #1498
Merged
Merged
Changes from all commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
fdd73cf
Implement numpy concatenation and stacking
tbennun 531576f
Implement numpy.linspace, add argument checks to numpy.arange
tbennun 9687cb6
Fix complex to scalar type inference
tbennun 73413da
Fix parsing of nested attributes
tbennun 06e5cac
Implement np.clip ufunc
tbennun abd0fe6
Fix attribute evaluation for non-arrays
tbennun a2e47c6
Implement numpy.split and its variants
tbennun 66ef621
Fix fast transposition for mismatched input/output types
tbennun 80ce041
Fix arange result type, fix numpy.full variants for scalar shapes
tbennun a052d42
Fix another case of attribute misparsing
tbennun e7ebad5
Safer creation of complex values in codegen
tbennun df33fdd
Support complex gemv in CPU BLAS replacement
tbennun 0eb1398
Fix parsing of builtin values and symbolic expressions
tbennun b4bd72e
Implement len builtin for constants
tbennun b835cd2
Further fix for callbacks
tbennun 7574fe9
Fix tests
tbennun 21b191d
Implement `numpy.fft.{fft,ifft}` and library node
tbennun 2119c42
Cast cblas_transpose correctly
tbennun 293f354
Merge branch 'master' into numpy-extension
alexnick83 949572f
Merge branch 'master' into numpy-extension
tbennun b89b33a
Merge branch 'master' into numpy-extension
tbennun 2c9fcc6
Merge branch 'main' into numpy-extension
tbennun 7fc04c5
Fix field name
tbennun 3ae6d0a
Fix special-case handling of intracomms
tbennun d0f2994
Merge remote-tracking branch 'origin/main' into numpy-extension
tbennun 95b5cd5
Fix process grids in Python visitor's defined variables
tbennun File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1342,7 +1342,7 @@ def defined(self): | |
|
||
# MPI-related stuff | ||
result.update({ | ||
k: self.sdfg.process_grids[v] | ||
v: self.sdfg.process_grids[v] | ||
for k, v in self.variables.items() if v in self.sdfg.process_grids | ||
}) | ||
try: | ||
|
@@ -4461,7 +4461,14 @@ def visit_Call(self, node: ast.Call, create_callbacks=False): | |
func = node.func.value | ||
|
||
if func is None: | ||
funcname = rname(node) | ||
func_result = self.visit(node.func) | ||
if isinstance(func_result, str): | ||
if isinstance(node.func, ast.Attribute): | ||
funcname = f'{func_result}.{node.func.attr}' | ||
else: | ||
funcname = func_result | ||
else: | ||
funcname = rname(node) | ||
# Check if the function exists as an SDFG in a different module | ||
modname = until(funcname, '.') | ||
if ('.' in funcname and len(modname) > 0 and modname in self.globals | ||
|
@@ -4576,7 +4583,7 @@ def visit_Call(self, node: ast.Call, create_callbacks=False): | |
arg = self.scope_vars[modname] | ||
else: | ||
# Fallback to (name, object) | ||
arg = (modname, self.defined[modname]) | ||
arg = modname | ||
args.append(arg) | ||
# Otherwise, try to find a default implementation for the SDFG | ||
elif not found_ufunc: | ||
|
@@ -4795,12 +4802,18 @@ def _visitname(self, name: str, node: ast.AST): | |
self.sdfg.add_symbol(result.name, result.dtype) | ||
return result | ||
|
||
if name in self.closure.callbacks: | ||
return name | ||
|
||
if name in self.sdfg.arrays: | ||
return name | ||
|
||
if name in self.sdfg.symbols: | ||
return name | ||
|
||
if name in __builtins__: | ||
return name | ||
|
||
if name not in self.scope_vars: | ||
raise DaceSyntaxError(self, node, 'Use of undefined variable "%s"' % name) | ||
rname = self.scope_vars[name] | ||
|
@@ -4845,33 +4858,43 @@ def visit_NameConstant(self, node: NameConstant): | |
return self.visit_Constant(node) | ||
|
||
def visit_Attribute(self, node: ast.Attribute): | ||
# If visiting an attribute, return attribute value if it's of an array or global | ||
name = until(astutils.unparse(node), '.') | ||
result = self._visitname(name, node) | ||
result = self.visit(node.value) | ||
if isinstance(result, (tuple, list, dict)): | ||
if len(result) > 1: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this the fix for attributes on expressions? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Exactly |
||
raise DaceSyntaxError( | ||
self, node.value, f'{type(result)} object cannot use attributes. Try storing the ' | ||
'object to a different variable first (e.g., ``a = result; a.attribute``') | ||
else: | ||
result = result[0] | ||
|
||
tmpname = f"{result}.{astutils.unparse(node.attr)}" | ||
if tmpname in self.sdfg.arrays: | ||
return tmpname | ||
|
||
if isinstance(result, str) and result in self.sdfg.arrays: | ||
arr = self.sdfg.arrays[result] | ||
elif isinstance(result, str) and result in self.scope_arrays: | ||
arr = self.scope_arrays[result] | ||
else: | ||
return result | ||
arr = None | ||
|
||
# Try to find sub-SDFG attribute | ||
func = oprepo.Replacements.get_attribute(type(arr), node.attr) | ||
if func is not None: | ||
# A new state is likely needed here, e.g., for transposition (ndarray.T) | ||
self._add_state('%s_%d' % (type(node).__name__, node.lineno)) | ||
self.last_block.set_default_lineinfo(self.current_lineinfo) | ||
result = func(self, self.sdfg, self.last_block, result) | ||
self.last_block.set_default_lineinfo(None) | ||
return result | ||
if arr is not None: | ||
func = oprepo.Replacements.get_attribute(type(arr), node.attr) | ||
if func is not None: | ||
# A new state is likely needed here, e.g., for transposition (ndarray.T) | ||
self._add_state('%s_%d' % (type(node).__name__, node.lineno)) | ||
self.last_block.set_default_lineinfo(self.current_lineinfo) | ||
result = func(self, self.sdfg, self.last_block, result) | ||
self.last_block.set_default_lineinfo(None) | ||
return result | ||
|
||
# Otherwise, try to find compile-time attribute (such as shape) | ||
try: | ||
return getattr(arr, node.attr) | ||
except KeyError: | ||
if arr is not None: | ||
return getattr(arr, node.attr) | ||
return getattr(result, node.attr) | ||
except (AttributeError, KeyError): | ||
return result | ||
|
||
def visit_List(self, node: ast.List): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this change handled already in the code somehow, i.e., is
self.defined
queried on demand?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It’s a code path that we didn’t cover before in coverage. The rest of the code assumes a string return type rather than a tuple, and crashes somewhere else.