
Speed up calculation of truncated normal mean and cdf #652


Merged: 12 commits merged into main, Jan 24, 2024

Conversation

@DanTanAtAims (Contributor)

Instead of using the built-in mean, cdf, and truncation functions provided by Distributions.jl, use the explicit formula of the truncated normal distribution.

Uses the approximated error function provided by SpecialFunctions.jl (added as a dependency).

This does not resolve the memory issues mentioned in issue #572.
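
For reference, a minimal sketch of the explicit formula being described, assuming the usual closed form for a Normal(mu, stdev) truncated to [lower, upper] (names here are illustrative, not the exact functions added in this PR):

```julia
using SpecialFunctions: erf

# Sketch only: explicit mean of X ~ Normal(mu, stdev) truncated to [lower, upper]:
#   E[X] = mu + stdev * (phi(alpha) - phi(beta)) / (Phi(beta) - Phi(alpha))
# where alpha = (lower - mu) / stdev and beta = (upper - mu) / stdev.
normal_pdf(z) = exp(-z^2 / 2) / sqrt(2 * pi)    # standard normal pdf
normal_cdf(z) = (1 + erf(z / sqrt(2))) / 2      # standard normal cdf via the error function

function explicit_truncnorm_mean(mu, stdev, lower, upper)
    alpha = (lower - mu) / stdev
    beta = (upper - mu) / stdev
    return mu + stdev * (normal_pdf(alpha) - normal_pdf(beta)) /
                        (normal_cdf(beta) - normal_cdf(alpha))
end
```

Evaluating the formula directly avoids constructing Distributions.jl objects on every call, which is what the profiling later in this thread points to as the main cost.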

  • Instead of using built-in mean, cdf and truncation functions provided by Distributions.jl, use the explicit formula of the truncated normal distribution.
  • Uses the approximated error function provided by SpecialFunctions.jl.
  • Removed unused variable
  • Removed unused distribution variable
@ConnectedSystems (Collaborator)

General comments:

  • There are a few other spots (at least two, from memory) where the truncated normal mean is used. Please check the spec document I sent you earlier.
  • Could you add some high-level tests? Something that checks the method produces values within some small error threshold of the original approach, just in case someone accidentally changes something in the future (see the sketch after this list).
  • There are some minor formatting issues we'll check when we next chat.
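
A hedged sketch of what such a test could look like, assuming the PR's truncated_normal_mean(mu, stdev, lower, upper) signature (test cases and tolerance below are placeholders):

```julia
using Test
using Distributions: Normal, truncated, mean

# Regression-style check: the fast implementation should stay within a small
# error threshold of the original Distributions.jl approach.
@testset "truncated_normal_mean matches Distributions.jl" begin
    cases = [(0.0, 1.0, -1.0, 1.0), (2.0, 0.5, 1.5, 3.0), (5.0, 2.0, 0.0, 4.0)]
    for (mu, stdev, lower, upper) in cases
        expected = mean(truncated(Normal(mu, stdev), lower, upper))
        @test abs(truncated_normal_mean(mu, stdev, lower, upper) - expected) < 1e-6
    end
end
```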

@ConnectedSystems (Collaborator) left a comment:

Some minor concerns, mostly to do with formatting.

@ConnectedSystems (Collaborator) commented Jan 21, 2024

Some performance notes:

Trial runs with the Moore domain (256 scenarios).

Prior to the changes in this PR, bleaching_mortality() took ~38% of runtime, with adjust_DHW_distribution() taking another ~13% (51% total).

Initial runtime was: ~1 min 40 secs (estimate - forgot to actually write it down)
Second run took: ~1 min 20 secs

[Pasted image 20240119194729]

With changes:

  • bleaching_mortality() takes ~13% + 9% (22%),
  • adjust_DHW_distribution() takes another 12% (total of 34% of runtime).

Trajectories look as expected:

[Pasted image 20240119210342]

The image above is misleading as samples are biased towards guided scenarios, so here's one where I make sure there is an equivalent number of samples for each scenario type.

[Pasted image 20240119210951]

@DanTanAtAims (Contributor, Author)

I'll change interventions/seeding.jl to use the new calculation.

Should I export the functions truncated_normal_mean and truncated_normal_cdf so that they can be used in interventions/seeding.jl and in the tests?

@ConnectedSystems (Collaborator)

Shouldn't need to export those, no - I think the functions need to be moved elsewhere (e.g., outside of corals/growth.jl) because they're not specific to growth.

@ConnectedSystems (Collaborator)

Right now I'm thinking Ecosystem.jl, does that sound good to you?

@DanTanAtAims (Contributor, Author)

I've come across a potential solution to the numerical issues we were having with the new implementation. I found the implementation here that Distributions.jl currently uses.

It turns out mean(truncated(Normal(...), ...)) uses this implementation, but profiling shows the calls to truncated() and Normal() themselves are big slowdowns.

I implemented a nearly identical version of it, excluding some error checks we don't need, and benchmarked both. The result is nearly twice as fast as our previous attempt and agrees with mean(truncated(Normal())), which is unsurprising as it's the same calculation. The benchmark results are attached: the first benchmark is the original implementation and the second is the new implementation from Distributions.jl minus error checks.

[tnm_bench]
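
For illustration, a sketch of the kind of calculation being referred to, working with standardised bounds and using expm1 to keep the numerator accurate when the bounds are close together; the actual Distributions.jl internals include additional case handling omitted here, and the names are hypothetical:

```julia
using SpecialFunctions: erf

# Sketch: mean of a standard normal truncated to [a, b],
#   E[Z | a < Z < b] = sqrt(2/pi) * (exp(-a^2/2) - exp(-b^2/2)) / (erf(b/sqrt(2)) - erf(a/sqrt(2)))
# with the numerator rewritten via expm1 to avoid cancellation.
function tn_mean_standardised(a, b)
    num = -exp(-a^2 / 2) * expm1(-(b^2 - a^2) / 2)   # == exp(-a^2/2) - exp(-b^2/2)
    den = erf(b / sqrt(2)) - erf(a / sqrt(2))
    return sqrt(2 / pi) * num / den
end

# Shift and scale back to the original Normal(mu, stdev):
tn_mean(mu, stdev, lower, upper) =
    mu + stdev * tn_mean_standardised((lower - mu) / stdev, (upper - mu) / stdev)
```

Calling a closed form like this directly also skips the truncated() and Normal() construction that the profiling flagged as the main slowdown.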

@ConnectedSystems (Collaborator)

Hmm, interesting - the potential maximum is about double your earlier implementation, but the median is indeed half.
Either way, very happy with this as it's ~17% of runtime compared to the usual mean(truncated()) approach.

Let me know when this is ready for a final review :)

  • Updated new truncated normal mean calculations with tests. The new calculation is the same as the Distributions.jl calculation minus checks.
  • Added truncated normal cdf calculations and tests; added functions
  • Swapped use of mean(Truncated(Normal(mu, stdev), lower, upper)) to use truncated_normal_mean
@DanTanAtAims (Contributor, Author)

The pull request is ready for final review.

The truncated normal mean has nearly equivalent accuracy to the built-in implementation we were using, but is much faster.

The truncated normal cdf function has equivalent performance, and I haven't found any speed-ups that are stable for large deviations from the normal mean. However, the new truncated_normal_cdf doesn't return NaN for those large values, unlike the built-in cdf we were using.
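
For context, a minimal sketch of a direct erf-based truncated normal cdf (illustrative only; the merged truncated_normal_cdf may handle edge cases, such as x outside [lower, upper] or extreme bounds, differently):

```julia
using SpecialFunctions: erf

# Sketch: cdf at x of Normal(mu, stdev) truncated to [lower, upper],
#   F(x) = (Phi(xi) - Phi(alpha)) / (Phi(beta) - Phi(alpha)),
# written directly in terms of erf. Assumes lower <= x <= upper.
function truncnorm_cdf_sketch(x, mu, stdev, lower, upper)
    s = stdev * sqrt(2)
    num = erf((x - mu) / s) - erf((lower - mu) / s)
    den = erf((upper - mu) / s) - erf((lower - mu) / s)
    return num / den
end
```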

@@ -336,6 +336,7 @@ end
include("clustering.jl")
include("data_loading.jl")
include("domain.jl")
include("Ecosystem.jl")
@ConnectedSystems (Collaborator) commented on the diff:

Not sure how I feel about this, but okay to leave as is.

@ConnectedSystems (Collaborator) left a comment:

Thanks @DanTanAtAims

Some mostly style-related issues to fix, then we're good to go.

  • Addressed style issues.
  • Used random sampling for testing of truncated normal functions
  • Fixed grammar error
@DanTanAtAims (Contributor, Author)

Thanks @ConnectedSystems for the comments.

I've addressed the issues raised and swapped the testing to draw random numbers using the same testing bounds as before.

Not sure if I mentioned this earlier, but testing the truncated normal cdf becomes difficult when the bounds exceed 10 standard deviations from the given mean, as the built-in function sometimes returns NaN unexpectedly. As far as I know these bounds aren't exceeded in ADRIA; however, the function we're currently using won't throw an error.

I could add a warning to the cdf function indicating possible loss of accuracy if we ever hit these bounds?

@ConnectedSystems (Collaborator) commented Jan 24, 2024

> I could add a warning to the cdf function indicating possible loss of accuracy if we ever hit these bounds?

Could you add it as a @debug level log please?

Same usage as @info and @warn

https://docs.julialang.org/en/v1/stdlib/Logging/
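
A minimal sketch of how that could be wired in (the function name and threshold check are hypothetical; only the @debug usage itself follows the linked Logging docs):

```julia
# @debug is used like @info / @warn but is hidden by default; enable it with, e.g.,
# ENV["JULIA_DEBUG"] = "Main" or a custom logger from the Logging stdlib.
function check_truncation_bounds(mu, stdev, lower, upper)
    if max(abs(lower - mu), abs(upper - mu)) > 10 * stdev
        @debug "Truncation bounds exceed 10 standard deviations from the mean; possible loss of accuracy in truncated_normal_cdf."
    end
    return nothing
end
```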

  • Debug log possible loss of accuracy when truncated bounds exceed 10 standard deviations from normal mean.
  • Fixed doc string function signature and corrected comment grammar
  • fix spelling error
@ConnectedSystems (Collaborator) left a comment:

Looking good, thanks!

@ConnectedSystems merged commit 61e7c5f into main on Jan 24, 2024
@ConnectedSystems deleted the truncated-norm-speedup branch on January 24, 2024 at 04:53
@ConnectedSystems (Collaborator) commented Jan 24, 2024

Hmm, unfortunately we seem to be almost back at square one.

bleaching_mortality!() + adjust_DHW_distribution() take up ~50% of runtime.

[image]

Most of the time is spent in erf() and logerf(). Let's leave this for now and see if we can come up with anything else.

@ConnectedSystems (Collaborator)

Just had a thought: it could be that there is a small speed-up, but the reason it's spending so much time in those functions is that they're called a lot; it's just the nature of the use context. We can discuss more when Pedro gets back.

@DanTanAtAims mentioned this pull request on Aug 1, 2024