Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gemm, Gemv: Support symbolic expressions #1316

Closed
wants to merge 1 commit into from

Conversation

lukastruemper
Copy link
Contributor

No description provided.

Copy link
Contributor

@alexnick83 alexnick83 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks good, but is it really solving the underlying issue? It may be better to change validation to throw warnings instead of errors if the sizes don't match (but potentially could, depending on what values symbols take). I will bring it up in the next meeting because I would like everyone's opinion.

I will not approve this yet, but ping me if it is blocking you.

@@ -1010,15 +1010,30 @@ def validate(self, sdfg, state):
if dst_conn == '_a':
subset = dc(memlet.subset)
subset.squeeze()
size0 = subset.size()
size0 = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about we add to dace.symbolic.evaluate a throw_error flag with default value True. When set to False, it returns the symbolic expression instead of raising TypeError. Then you could rewrite this (and everything below) as follows:

size0 = dace.symbolic.evaluate(subset, sdfg.constants, throw_error=False)

I will bring it up in the next meeting.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dace.symbolic.symplify(expr, symbol_map) ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using dace.symbolic.simplify with a symbols map sounds like an elegant solution for this. The throw_error flag may be a bit confusing imo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is then that one cannot re-specialize symbols because some memlet etc. may already be simplified. But this is the a usage error imo

@lukastruemper
Copy link
Contributor Author

It looks good, but is it really solving the underlying issue? It may be better to change validation to throw warnings instead of errors if the sizes don't match (but potentially could, depending on what values symbols take). I will bring it up in the next meeting because I would like everyone's opinion.

I will not approve this yet, but ping me if it is blocking you.

Yes, I didn't want to make a bigger change, but I totally agree. It would be cool to have a function that "simplifies" expressions by substituting symbols with constants, but not necessarily evaluating it to a constant

phschaad
phschaad previously approved these changes Jul 13, 2023
Copy link
Collaborator

@phschaad phschaad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also generally LGTM. Other point as already discussed between the two of you - I agree.

@@ -1010,15 +1010,30 @@ def validate(self, sdfg, state):
if dst_conn == '_a':
subset = dc(memlet.subset)
subset.squeeze()
size0 = subset.size()
size0 = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using dace.symbolic.simplify with a symbols map sounds like an elegant solution for this. The throw_error flag may be a bit confusing imo.

@phschaad phschaad dismissed their stale review July 13, 2023 12:19

Comment, not approve.

@tbennun
Copy link
Collaborator

tbennun commented Jul 13, 2023

I don't think this is the right solution. Instead, we should warn on those indeterminate cases rather than error out.
In SymPy:

  • M == M - 3 returns False
  • M == M returns True
  • M == N returns M == N aka indeterminate

if we check (size1 == size2) == False (note the == False rather than is False because SymPy is SymPy), that would be preferable.

@lukastruemper
Copy link
Contributor Author

I don't think this is the right solution. Instead, we should warn on those indeterminate cases rather than error out. In SymPy:

  • M == M - 3 returns False
  • M == M returns True
  • M == N returns M == N aka indeterminate

if we check (size1 == size2) == False (note the == False rather than is False because SymPy is SymPy), that would be preferable.

I would even say that the class should not be responsible for incorrect usage, i.e. not matching dimensions. This is the job of the front end/transformation/developer

@alexnick83
Copy link
Contributor

alexnick83 commented Jul 15, 2023

@lukastruemper @tbennun @phschaad

I made #1321 as an alternative. If you agree, we can merge it, and Lukas can continue adding tests and amending the validation methods.

@phschaad
Copy link
Collaborator

phschaad commented Jul 17, 2023

@lukastruemper @tbennun @phschaad

I made #1321 as an alternative. If you agree, we can merge it, and Lukas can continue adding tests and amending the validation methods.

Very nice, I'll be glad to review too once ready. Thanks!

@lukastruemper lukastruemper deleted the users/lukas/blas-fixes branch November 24, 2023 13:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants