-
Notifications
You must be signed in to change notification settings - Fork 7
Speed up calculation of truncated normal mean and cdf #652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Instead of using bulit in mean, cdf and truncation functions provided by DIstributions.jl use the explicit formula of the truncated normal distributions. Uses the approximated error function provided by SpecialFunctions.jl. Removed unused variable removed unused distribution variable
General comments:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some minor concerns, mostly to do with formatting.
Instead of using bulit in mean, cdf and truncation functions provided by DIstributions.jl use the explicit formula of the truncated normal distributions. Uses the approximated error function provided by SpecialFunctions.jl. Removed unused variable removed unused distribution variable
8092086
to
183eaba
Compare
…/ADRIA.jl into truncated-norm-speedup
Added spaces between operators in new code.
Some performance notes: Trial runs with the Moore domain (256 scenarios). Prior to the changes in this PR Initial runtime was: ~1 min 40 secs (estimate - forgot to actually write it down) With changes:
Trajectories look as expected: The image above is misleading as samples are biased towards guided scenarios, so here's one where I make sure there are equivalent number of samples for each scenario type. |
I'll change interventions/seeding.jl to use the new calculation. Should I export the function truncated_normal_mean and truncated_normal_cdf so that in can be used in interventions/seeding.jl and tests? |
Shouldn't need to export those, no - I think the functions need to be moved elsewhere (e.g., outside of |
Right now I'm thinking |
I've come across a potential solution to the numerical issues we were having with the new implementation. I came across the implementation here that is currently the implementation Julia DIstribution.jl uses. It turns out mean(Truncated(Normal(...)...)) uses this implementation but it appears that the calls to truncated() and Normal() are big slow downs from profiling. I implemented a nearly identical version of it excluding some error checks which we don't need. And benchmarked them and the result is nearly twice as fast as our previous attempt and agrees with mean(truncated(Normal())) which is unsurprising as its the same. The Benchmark results are attached. 1st benchmark is the original implementation and the second is the new implementation from Distributions.jl minus error checks. ![]() |
Hmm interesting - the potential maximum is about double your earlier implementation, but median is indeed half. Let me know when this is ready for a final review :) |
updated new truncated normal mean calculations with tests. New calculations is the same as Distributions calculation minus checks.
Added truncated normal cdf calculations and tests added functions
Swapped use of mean(Truncated(Normal(mu, stdev), lower, upper)) to use truncated_normal_mean
The pull request is ready for final review. The Truncated Normal Mean has nearly equivalent accuracy to the in built implementation we were using but is much faster. The truncated normal cdf function has equivalent performance and I haven't found any speed ups that are stable for large deviations from the normal mean. However the new truncated_normal_cdf doesn't return NaN values for the large values unlike the inbuild cdf we were using. |
@@ -336,6 +336,7 @@ end | |||
include("clustering.jl") | |||
include("data_loading.jl") | |||
include("domain.jl") | |||
include("Ecosystem.jl") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure how I feel about this, but okay to leave as is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @DanTanAtAims
Some mostly style-related issues to fix then we're good to go.
Addressed style issues. Used random sampling for testing of truncated normal functions
Fixed grammar error
Thanks @ConnectedSystems for the comments. I've addressed the issues raised and swapped the testing to draw random numbers using the same testing bounds as before. Not sure if I mentioned this earlier but testing the tuncated normal cdf becomes difficult when the bounds start exceeding 10 standard deviations from the given mean as the built-in function sometimes returns NaN unexpectedly. These bounds aren't exceeded as far as I know in ADRIA however the implemented function we're are currently using won't throw an error. I can added a warning to the cdf function indicating possible loss of accuracy is we ever test these bounds? |
Could you add it as a Same usage as |
Debug log possible loss of accuracy when truncated bounds exceed 10 standard deviations from normal mean.
Fixed doc string function signature and corrected comment grammar
fix spelling error
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking good, thanks!
Just had a thought - it could be that that there is a small speed up, but the reason it's spending so much time in those functions is they're called a lot - it's just the nature of the use context. We can discuss more when Pedro gets back. |
Instead of using built in mean, cdf, and truncation functions provided by DIstributions.jl, use the explicit formula of the truncated normal distributions.
Uses the approximated error function provided by SpecialFunctions.jl (added as dependency).
This does not resolve memory issues mentioned in issue #572.