Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More NumPy operation implementations #1498

Merged
merged 26 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
fdd73cf
Implement numpy concatenation and stacking
tbennun Jan 7, 2024
531576f
Implement numpy.linspace, add argument checks to numpy.arange
tbennun Jan 7, 2024
9687cb6
Fix complex to scalar type inference
tbennun Jan 7, 2024
73413da
Fix parsing of nested attributes
tbennun Jan 7, 2024
06e5cac
Implement np.clip ufunc
tbennun Jan 7, 2024
abd0fe6
Fix attribute evaluation for non-arrays
tbennun Jan 7, 2024
a2e47c6
Implement numpy.split and its variants
tbennun Jan 7, 2024
66ef621
Fix fast transposition for mismatched input/output types
tbennun Jan 8, 2024
80ce041
Fix arange result type, fix numpy.full variants for scalar shapes
tbennun Jan 8, 2024
a052d42
Fix another case of attribute misparsing
tbennun Jan 8, 2024
e7ebad5
Safer creation of complex values in codegen
tbennun Jan 8, 2024
df33fdd
Support complex gemv in CPU BLAS replacement
tbennun Jan 8, 2024
0eb1398
Fix parsing of builtin values and symbolic expressions
tbennun Jan 8, 2024
b4bd72e
Implement len builtin for constants
tbennun Jan 8, 2024
b835cd2
Further fix for callbacks
tbennun Jan 8, 2024
7574fe9
Fix tests
tbennun Jan 8, 2024
21b191d
Implement `numpy.fft.{fft,ifft}` and library node
tbennun Jan 8, 2024
2119c42
Cast cblas_transpose correctly
tbennun Jan 8, 2024
293f354
Merge branch 'master' into numpy-extension
alexnick83 Feb 20, 2024
949572f
Merge branch 'master' into numpy-extension
tbennun Feb 25, 2024
b89b33a
Merge branch 'master' into numpy-extension
tbennun Mar 23, 2024
2c9fcc6
Merge branch 'main' into numpy-extension
tbennun Oct 29, 2024
7fc04c5
Fix field name
tbennun Oct 29, 2024
3ae6d0a
Fix special-case handling of intracomms
tbennun Oct 29, 2024
d0f2994
Merge remote-tracking branch 'origin/main' into numpy-extension
tbennun Oct 30, 2024
95b5cd5
Fix process grids in Python visitor's defined variables
tbennun Oct 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions dace/codegen/cppunparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,8 @@ def _Num(self, t):
# For complex values, use ``dtype_to_typeclass``
if isinstance(t_n, complex):
dtype = dtypes.dtype_to_typeclass(complex)
repr_n = f'{dtype}({t_n.real}, {t_n.imag})'


# Handle large integer values
if isinstance(t_n, int):
Expand All @@ -765,10 +767,8 @@ def _Num(self, t):
elif bits >= 64:
warnings.warn(f'Value wider than 64 bits encountered in expression ({t_n}), emitting as-is')

if repr_n.endswith("j"):
self.write("%s(0, %s)" % (dtype, repr_n.replace("inf", INFSTR)[:-1]))
else:
self.write(repr_n.replace("inf", INFSTR))
repr_n = repr_n.replace("inf", INFSTR)
self.write(repr_n)

def _List(self, t):
raise NotImplementedError('Invalid C++')
Expand Down
4 changes: 4 additions & 0 deletions dace/distr_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def _validate(self):
raise ValueError('Color must have only logical true (1) or false (0) values.')
return True

@property
def dtype(self):
return type(self)

def to_json(self):
attrs = serialize.all_properties_to_json(self)
retdict = {"type": type(self).__name__, "attributes": attrs}
Expand Down
32 changes: 16 additions & 16 deletions dace/frontend/common/distr.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ def _cart_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, dims: Shape


@oprepo.replaces_method('Intracomm', 'Create_cart')
def _intracomm_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', dims: ShapeType):
def _intracomm_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, dims: ShapeType):
""" Equivalent to `dace.comm.Cart_create(dims).
:param dims: Shape of the process-grid (see `dims` parameter of `MPI_Cart_create`), e.g., [2, 3, 3].
:return: Name of the new process-grid descriptor.
"""

from mpi4py import MPI
icomm_name, icomm_obj = icomm
icomm_name, icomm_obj = icomm, pv.globals[icomm]
if icomm_obj != MPI.COMM_WORLD:
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
return _cart_create(pv, sdfg, state, dims)
Expand Down Expand Up @@ -186,13 +186,13 @@ def _bcast(pv: ProgramVisitor,
def _intracomm_bcast(pv: 'ProgramVisitor',
sdfg: SDFG,
state: SDFGState,
comm: Tuple[str, 'Comm'],
comm: str,
buffer: str,
root: Union[str, sp.Expr, Number] = 0):
""" Equivalent to `dace.comm.Bcast(buffer, root)`. """

from mpi4py import MPI
comm_name, comm_obj = comm
comm_name, comm_obj = comm, pv.globals[comm]
if comm_obj == MPI.COMM_WORLD:
return _bcast(pv, sdfg, state, buffer, root)
# NOTE: Highly experimental
Expand Down Expand Up @@ -267,12 +267,12 @@ def _alltoall(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, inbuffer: str,


@oprepo.replaces_method('Intracomm', 'Alltoall')
def _intracomm_alltoall(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', inp_buffer: str,
def _intracomm_alltoall(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, inp_buffer: str,
out_buffer: str):
""" Equivalent to `dace.comm.Alltoall(inp_buffer, out_buffer)`. """

from mpi4py import MPI
icomm_name, icomm_obj = icomm
icomm_name, icomm_obj = icomm, pv.globals[icomm]
if icomm_obj != MPI.COMM_WORLD:
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
return _alltoall(pv, sdfg, state, inp_buffer, out_buffer)
Expand Down Expand Up @@ -303,12 +303,12 @@ def _allreduce(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, buffer: str, op


@oprepo.replaces_method('Intracomm', 'Allreduce')
def _intracomm_allreduce(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', inp_buffer: 'InPlace',
def _intracomm_allreduce(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, inp_buffer: 'InPlace',
out_buffer: str, op: str):
""" Equivalent to `dace.comm.Allreduce(out_buffer, op)`. """

from mpi4py import MPI
icomm_name, icomm_obj = icomm
icomm_name, icomm_obj = icomm, pv.globals[icomm]
if icomm_obj != MPI.COMM_WORLD:
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
if inp_buffer != MPI.IN_PLACE:
Expand Down Expand Up @@ -470,12 +470,12 @@ def _send(pv: ProgramVisitor,


@oprepo.replaces_method('Intracomm', 'Send')
def _intracomm_send(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', buffer: str,
def _intracomm_send(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, buffer: str,
dst: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr, Number]):
""" Equivalent to `dace.comm.end(buffer, dst, tag)`. """

from mpi4py import MPI
icomm_name, icomm_obj = icomm
icomm_name, icomm_obj = icomm, pv.globals[icomm]
if icomm_obj != MPI.COMM_WORLD:
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
return _send(pv, sdfg, state, buffer, dst, tag)
Expand Down Expand Up @@ -592,12 +592,12 @@ def _isend(pv: ProgramVisitor,


@oprepo.replaces_method('Intracomm', 'Isend')
def _intracomm_isend(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', buffer: str,
def _intracomm_isend(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, buffer: str,
dst: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr, Number]):
""" Equivalent to `dace.comm.Isend(buffer, dst, tag, req)`. """

from mpi4py import MPI
icomm_name, icomm_obj = icomm
icomm_name, icomm_obj = icomm, pv.globals[icomm]
if icomm_obj != MPI.COMM_WORLD:
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
req, _ = sdfg.add_array("isend_req", [1], dace.dtypes.opaque("MPI_Request"), transient=True, find_new_name=True)
Expand Down Expand Up @@ -690,12 +690,12 @@ def _recv(pv: ProgramVisitor,


@oprepo.replaces_method('Intracomm', 'Recv')
def _intracomm_Recv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', buffer: str,
def _intracomm_Recv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, buffer: str,
src: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr, Number]):
""" Equivalent to `dace.comm.Recv(buffer, src, tagq)`. """

from mpi4py import MPI
icomm_name, icomm_obj = icomm
icomm_name, icomm_obj = icomm, pv.globals[icomm]
if icomm_obj != MPI.COMM_WORLD:
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
return _recv(pv, sdfg, state, buffer, src, tag)
Expand Down Expand Up @@ -810,12 +810,12 @@ def _irecv(pv: ProgramVisitor,


@oprepo.replaces_method('Intracomm', 'Irecv')
def _intracomm_irecv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: 'Intracomm', buffer: str,
def _intracomm_irecv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, icomm: str, buffer: str,
src: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr, Number]):
""" Equivalent to `dace.comm.Irecv(buffer, src, tag, req)`. """

from mpi4py import MPI
icomm_name, icomm_obj = icomm
icomm_name, icomm_obj = icomm, pv.globals[icomm]
if icomm_obj != MPI.COMM_WORLD:
raise ValueError('Only the mpi4py.MPI.COMM_WORLD Intracomm is supported in DaCe Python programs.')
req, _ = sdfg.add_array("irecv_req", [1], dace.dtypes.opaque("MPI_Request"), transient=True, find_new_name=True)
Expand Down
57 changes: 40 additions & 17 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,7 +1342,7 @@ def defined(self):

# MPI-related stuff
result.update({
k: self.sdfg.process_grids[v]
v: self.sdfg.process_grids[v]
for k, v in self.variables.items() if v in self.sdfg.process_grids
})
try:
Expand Down Expand Up @@ -4461,7 +4461,14 @@ def visit_Call(self, node: ast.Call, create_callbacks=False):
func = node.func.value

if func is None:
funcname = rname(node)
func_result = self.visit(node.func)
if isinstance(func_result, str):
if isinstance(node.func, ast.Attribute):
funcname = f'{func_result}.{node.func.attr}'
else:
funcname = func_result
else:
funcname = rname(node)
# Check if the function exists as an SDFG in a different module
modname = until(funcname, '.')
if ('.' in funcname and len(modname) > 0 and modname in self.globals
Expand Down Expand Up @@ -4576,7 +4583,7 @@ def visit_Call(self, node: ast.Call, create_callbacks=False):
arg = self.scope_vars[modname]
else:
# Fallback to (name, object)
arg = (modname, self.defined[modname])
arg = modname
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change handled already in the code somehow, i.e., is self.defined queried on demand?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It’s a code path that we didn’t cover before in coverage. The rest of the code assumes a string return type rather than a tuple, and crashes somewhere else.

args.append(arg)
# Otherwise, try to find a default implementation for the SDFG
elif not found_ufunc:
Expand Down Expand Up @@ -4795,12 +4802,18 @@ def _visitname(self, name: str, node: ast.AST):
self.sdfg.add_symbol(result.name, result.dtype)
return result

if name in self.closure.callbacks:
return name

if name in self.sdfg.arrays:
return name

if name in self.sdfg.symbols:
return name

if name in __builtins__:
return name

if name not in self.scope_vars:
raise DaceSyntaxError(self, node, 'Use of undefined variable "%s"' % name)
rname = self.scope_vars[name]
Expand Down Expand Up @@ -4845,33 +4858,43 @@ def visit_NameConstant(self, node: NameConstant):
return self.visit_Constant(node)

def visit_Attribute(self, node: ast.Attribute):
# If visiting an attribute, return attribute value if it's of an array or global
name = until(astutils.unparse(node), '.')
result = self._visitname(name, node)
result = self.visit(node.value)
if isinstance(result, (tuple, list, dict)):
if len(result) > 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the fix for attributes on expressions?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly

raise DaceSyntaxError(
self, node.value, f'{type(result)} object cannot use attributes. Try storing the '
'object to a different variable first (e.g., ``a = result; a.attribute``')
else:
result = result[0]

tmpname = f"{result}.{astutils.unparse(node.attr)}"
if tmpname in self.sdfg.arrays:
return tmpname

if isinstance(result, str) and result in self.sdfg.arrays:
arr = self.sdfg.arrays[result]
elif isinstance(result, str) and result in self.scope_arrays:
arr = self.scope_arrays[result]
else:
return result
arr = None

# Try to find sub-SDFG attribute
func = oprepo.Replacements.get_attribute(type(arr), node.attr)
if func is not None:
# A new state is likely needed here, e.g., for transposition (ndarray.T)
self._add_state('%s_%d' % (type(node).__name__, node.lineno))
self.last_block.set_default_lineinfo(self.current_lineinfo)
result = func(self, self.sdfg, self.last_block, result)
self.last_block.set_default_lineinfo(None)
return result
if arr is not None:
func = oprepo.Replacements.get_attribute(type(arr), node.attr)
if func is not None:
# A new state is likely needed here, e.g., for transposition (ndarray.T)
self._add_state('%s_%d' % (type(node).__name__, node.lineno))
self.last_block.set_default_lineinfo(self.current_lineinfo)
result = func(self, self.sdfg, self.last_block, result)
self.last_block.set_default_lineinfo(None)
return result

# Otherwise, try to find compile-time attribute (such as shape)
try:
return getattr(arr, node.attr)
except KeyError:
if arr is not None:
return getattr(arr, node.attr)
return getattr(result, node.attr)
except (AttributeError, KeyError):
return result

def visit_List(self, node: ast.List):
Expand Down
2 changes: 2 additions & 0 deletions dace/frontend/python/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,8 @@ def global_value_to_node(self,
elif isinstance(value, symbolic.symbol):
# Symbols resolve to the symbol name
newnode = ast.Name(id=value.name, ctx=ast.Load())
elif isinstance(value, sympy.Basic): # Symbolic or constant expression
newnode = ast.parse(symbolic.symstr(value)).body[0].value
elif isinstance(value, ast.Name):
newnode = ast.Name(id=value.id, ctx=ast.Load())
elif (dtypes.isconstant(value) or isinstance(value, (StringLiteral, SDFG)) or hasattr(value, '__sdfg__')):
Expand Down
Loading
Loading