-
Notifications
You must be signed in to change notification settings - Fork 133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Gemm, Gemv: Support symbolic expressions #1316
Conversation
8215004
to
3e7d472
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks good, but is it really solving the underlying issue? It may be better to change validation to throw warnings instead of errors if the sizes don't match (but potentially could, depending on what values symbols take). I will bring it up in the next meeting because I would like everyone's opinion.
I will not approve this yet, but ping me if it is blocking you.
@@ -1010,15 +1010,30 @@ def validate(self, sdfg, state): | |||
if dst_conn == '_a': | |||
subset = dc(memlet.subset) | |||
subset.squeeze() | |||
size0 = subset.size() | |||
size0 = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about we add to dace.symbolic.evaluate
a throw_error
flag with default value True.
When set to False,
it returns the symbolic expression instead of raising TypeError.
Then you could rewrite this (and everything below) as follows:
size0 = dace.symbolic.evaluate(subset, sdfg.constants, throw_error=False)
I will bring it up in the next meeting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
dace.symbolic.symplify(expr, symbol_map) ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using dace.symbolic.simplify
with a symbols map sounds like an elegant solution for this. The throw_error
flag may be a bit confusing imo.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem is then that one cannot re-specialize symbols because some memlet etc. may already be simplified. But this is the a usage error imo
Yes, I didn't want to make a bigger change, but I totally agree. It would be cool to have a function that "simplifies" expressions by substituting symbols with constants, but not necessarily evaluating it to a constant |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also generally LGTM. Other point as already discussed between the two of you - I agree.
@@ -1010,15 +1010,30 @@ def validate(self, sdfg, state): | |||
if dst_conn == '_a': | |||
subset = dc(memlet.subset) | |||
subset.squeeze() | |||
size0 = subset.size() | |||
size0 = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using dace.symbolic.simplify
with a symbols map sounds like an elegant solution for this. The throw_error
flag may be a bit confusing imo.
I don't think this is the right solution. Instead, we should warn on those indeterminate cases rather than error out.
if we check |
I would even say that the class should not be responsible for incorrect usage, i.e. not matching dimensions. This is the job of the front end/transformation/developer |
@lukastruemper @tbennun @phschaad I made #1321 as an alternative. If you agree, we can merge it, and Lukas can continue adding tests and amending the validation methods. |
Very nice, I'll be glad to review too once ready. Thanks! |
No description provided.