Skip to content

Commit 1397d69

Browse files
authored
Make Gibbs work with step_warmup (#2502)
* Make Gibbs work with step_warmup * Bump patch version to 0.36.3 * Fix a Gibbs bug
1 parent 3cab967 commit 1397d69

File tree

4 files changed

+226
-9
lines changed

4 files changed

+226
-9
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.36.2"
3+
version = "0.36.3"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

src/mcmc/gibbs.jl

+102-8
Original file line numberDiff line numberDiff line change
@@ -405,20 +405,75 @@ end
405405

406406
varinfo(state::GibbsState) = state.vi
407407

408-
function DynamicPPL.initialstep(
408+
"""
409+
Initialise a VarInfo for the Gibbs sampler.
410+
411+
This is copied verbatim from DynamicPPL's src/sampler.jl. It is repeated here to
412+
support calling both step and step_warmup as the initial step. DynamicPPL initialstep is
413+
incompatible with step_warmup.
414+
"""
415+
function initial_varinfo(rng, model, spl, initial_params)
416+
vi = DynamicPPL.default_varinfo(rng, model, spl)
417+
418+
# Update the parameters if provided.
419+
if initial_params !== nothing
420+
vi = DynamicPPL.initialize_parameters!!(vi, initial_params, spl, model)
421+
422+
# Update joint log probability.
423+
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
424+
# and https://github.com/TuringLang/Turing.jl/issues/1563
425+
# to prevent existing variables from being resampled
426+
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.DefaultContext()))
427+
end
428+
return vi
429+
end
430+
431+
function AbstractMCMC.step(
409432
rng::Random.AbstractRNG,
410433
model::DynamicPPL.Model,
411-
spl::DynamicPPL.Sampler{<:Gibbs},
412-
vi::DynamicPPL.AbstractVarInfo;
434+
spl::DynamicPPL.Sampler{<:Gibbs};
413435
initial_params=nothing,
414436
kwargs...,
415437
)
416438
alg = spl.alg
417439
varnames = alg.varnames
418440
samplers = alg.samplers
441+
vi = initial_varinfo(rng, model, spl, initial_params)
419442

420443
vi, states = gibbs_initialstep_recursive(
421-
rng, model, varnames, samplers, vi; initial_params=initial_params, kwargs...
444+
rng,
445+
model,
446+
AbstractMCMC.step,
447+
varnames,
448+
samplers,
449+
vi;
450+
initial_params=initial_params,
451+
kwargs...,
452+
)
453+
return Transition(model, vi), GibbsState(vi, states)
454+
end
455+
456+
function AbstractMCMC.step_warmup(
457+
rng::Random.AbstractRNG,
458+
model::DynamicPPL.Model,
459+
spl::DynamicPPL.Sampler{<:Gibbs};
460+
initial_params=nothing,
461+
kwargs...,
462+
)
463+
alg = spl.alg
464+
varnames = alg.varnames
465+
samplers = alg.samplers
466+
vi = initial_varinfo(rng, model, spl, initial_params)
467+
468+
vi, states = gibbs_initialstep_recursive(
469+
rng,
470+
model,
471+
AbstractMCMC.step_warmup,
472+
varnames,
473+
samplers,
474+
vi;
475+
initial_params=initial_params,
476+
kwargs...,
422477
)
423478
return Transition(model, vi), GibbsState(vi, states)
424479
end
@@ -427,9 +482,20 @@ end
427482
Take the first step of MCMC for the first component sampler, and call the same function
428483
recursively on the remaining samplers, until no samplers remain. Return the global VarInfo
429484
and a tuple of initial states for all component samplers.
485+
486+
The `step_function` argument should always be either AbstractMCMC.step or
487+
AbstractMCMC.step_warmup.
430488
"""
431489
function gibbs_initialstep_recursive(
432-
rng, model, varname_vecs, samplers, vi, states=(); initial_params=nothing, kwargs...
490+
rng,
491+
model,
492+
step_function::Function,
493+
varname_vecs,
494+
samplers,
495+
vi,
496+
states=();
497+
initial_params=nothing,
498+
kwargs...,
433499
)
434500
# End recursion
435501
if isempty(varname_vecs) && isempty(samplers)
@@ -450,7 +516,7 @@ function gibbs_initialstep_recursive(
450516
conditioned_model, context = make_conditional(model, varnames, vi)
451517

452518
# Take initial step with the current sampler.
453-
_, new_state = AbstractMCMC.step(
519+
_, new_state = step_function(
454520
rng,
455521
conditioned_model,
456522
sampler;
@@ -470,6 +536,7 @@ function gibbs_initialstep_recursive(
470536
return gibbs_initialstep_recursive(
471537
rng,
472538
model,
539+
step_function,
473540
varname_vecs_tail,
474541
samplers_tail,
475542
vi,
@@ -493,7 +560,29 @@ function AbstractMCMC.step(
493560
states = state.states
494561
@assert length(samplers) == length(state.states)
495562

496-
vi, states = gibbs_step_recursive(rng, model, varnames, samplers, states, vi; kwargs...)
563+
vi, states = gibbs_step_recursive(
564+
rng, model, AbstractMCMC.step, varnames, samplers, states, vi; kwargs...
565+
)
566+
return Transition(model, vi), GibbsState(vi, states)
567+
end
568+
569+
function AbstractMCMC.step_warmup(
570+
rng::Random.AbstractRNG,
571+
model::DynamicPPL.Model,
572+
spl::DynamicPPL.Sampler{<:Gibbs},
573+
state::GibbsState;
574+
kwargs...,
575+
)
576+
vi = varinfo(state)
577+
alg = spl.alg
578+
varnames = alg.varnames
579+
samplers = alg.samplers
580+
states = state.states
581+
@assert length(samplers) == length(state.states)
582+
583+
vi, states = gibbs_step_recursive(
584+
rng, model, AbstractMCMC.step_warmup, varnames, samplers, states, vi; kwargs...
585+
)
497586
return Transition(model, vi), GibbsState(vi, states)
498587
end
499588

@@ -620,10 +709,14 @@ end
620709
"""
621710
Run a Gibbs step for the first varname/sampler/state tuple, and recursively call the same
622711
function on the tail, until there are no more samplers left.
712+
713+
The `step_function` argument should always be either AbstractMCMC.step or
714+
AbstractMCMC.step_warmup.
623715
"""
624716
function gibbs_step_recursive(
625717
rng::Random.AbstractRNG,
626718
model::DynamicPPL.Model,
719+
step_function::Function,
627720
varname_vecs,
628721
samplers,
629722
states,
@@ -657,7 +750,7 @@ function gibbs_step_recursive(
657750
state = setparams_varinfo!!(conditioned_model, sampler, state, vi)
658751

659752
# Take a step with the local sampler.
660-
new_state = last(AbstractMCMC.step(rng, conditioned_model, sampler, state; kwargs...))
753+
new_state = last(step_function(rng, conditioned_model, sampler, state; kwargs...))
661754

662755
new_vi_local = varinfo(new_state)
663756
# Merge the latest values for all the variables in the current sampler.
@@ -668,6 +761,7 @@ function gibbs_step_recursive(
668761
return gibbs_step_recursive(
669762
rng,
670763
model,
764+
step_function,
671765
varname_vecs_tail,
672766
samplers_tail,
673767
states_tail,

src/mcmc/repeat_sampler.jl

+27
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,30 @@ function AbstractMCMC.step(
6060
end
6161
return transition, state
6262
end
63+
64+
function AbstractMCMC.step_warmup(
65+
rng::Random.AbstractRNG,
66+
model::AbstractMCMC.AbstractModel,
67+
sampler::RepeatSampler;
68+
kwargs...,
69+
)
70+
return AbstractMCMC.step_warmup(rng, model, sampler.sampler; kwargs...)
71+
end
72+
73+
function AbstractMCMC.step_warmup(
74+
rng::Random.AbstractRNG,
75+
model::AbstractMCMC.AbstractModel,
76+
sampler::RepeatSampler,
77+
state;
78+
kwargs...,
79+
)
80+
transition, state = AbstractMCMC.step_warmup(
81+
rng, model, sampler.sampler, state; kwargs...
82+
)
83+
for _ in 2:(sampler.num_repeat)
84+
transition, state = AbstractMCMC.step_warmup(
85+
rng, model, sampler.sampler, state; kwargs...
86+
)
87+
end
88+
return transition, state
89+
end

test/mcmc/gibbs.jl

+96
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,102 @@ end
268268
@test chain1.value == chain2.value
269269
end
270270

271+
@testset "Gibbs warmup" begin
272+
# An inference algorithm, for testing purposes, that records how many warm-up steps
273+
# and how many non-warm-up steps have been taken.
274+
mutable struct WarmupCounter <: Inference.InferenceAlgorithm
275+
warmup_init_count::Int
276+
non_warmup_init_count::Int
277+
warmup_count::Int
278+
non_warmup_count::Int
279+
280+
WarmupCounter() = new(0, 0, 0, 0)
281+
end
282+
283+
Turing.Inference.drop_space(wuc::WarmupCounter) = wuc
284+
Turing.Inference.getspace(::WarmupCounter) = ()
285+
Turing.Inference.isgibbscomponent(::WarmupCounter) = true
286+
287+
# A trivial state that holds nothing but a VarInfo, to be used with WarmupCounter.
288+
struct VarInfoState{T}
289+
vi::T
290+
end
291+
292+
Turing.Inference.varinfo(state::VarInfoState) = state.vi
293+
function Turing.Inference.setparams_varinfo!!(
294+
::DynamicPPL.Model,
295+
::DynamicPPL.Sampler,
296+
::VarInfoState,
297+
params::DynamicPPL.AbstractVarInfo,
298+
)
299+
return VarInfoState(params)
300+
end
301+
302+
function AbstractMCMC.step(
303+
::Random.AbstractRNG,
304+
model::DynamicPPL.Model,
305+
spl::DynamicPPL.Sampler{<:WarmupCounter};
306+
kwargs...,
307+
)
308+
spl.alg.non_warmup_init_count += 1
309+
return Turing.Inference.Transition(nothing, 0.0),
310+
VarInfoState(DynamicPPL.VarInfo(model))
311+
end
312+
313+
function AbstractMCMC.step_warmup(
314+
::Random.AbstractRNG,
315+
model::DynamicPPL.Model,
316+
spl::DynamicPPL.Sampler{<:WarmupCounter};
317+
kwargs...,
318+
)
319+
spl.alg.warmup_init_count += 1
320+
return Turing.Inference.Transition(nothing, 0.0),
321+
VarInfoState(DynamicPPL.VarInfo(model))
322+
end
323+
324+
function AbstractMCMC.step(
325+
::Random.AbstractRNG,
326+
::DynamicPPL.Model,
327+
spl::DynamicPPL.Sampler{<:WarmupCounter},
328+
s::VarInfoState;
329+
kwargs...,
330+
)
331+
spl.alg.non_warmup_count += 1
332+
return Turing.Inference.Transition(nothing, 0.0), s
333+
end
334+
335+
function AbstractMCMC.step_warmup(
336+
::Random.AbstractRNG,
337+
::DynamicPPL.Model,
338+
spl::DynamicPPL.Sampler{<:WarmupCounter},
339+
s::VarInfoState;
340+
kwargs...,
341+
)
342+
spl.alg.warmup_count += 1
343+
return Turing.Inference.Transition(nothing, 0.0), s
344+
end
345+
346+
@model f() = x ~ Normal()
347+
m = f()
348+
349+
num_samples = 10
350+
num_warmup = 3
351+
wuc = WarmupCounter()
352+
sample(m, Gibbs(:x => wuc), num_samples; num_warmup=num_warmup)
353+
@test wuc.warmup_init_count == 1
354+
@test wuc.non_warmup_init_count == 0
355+
@test wuc.warmup_count == num_warmup
356+
@test wuc.non_warmup_count == num_samples - 1
357+
358+
num_reps = 2
359+
wuc = WarmupCounter()
360+
sample(m, Gibbs(:x => RepeatSampler(wuc, num_reps)), num_samples; num_warmup=num_warmup)
361+
@test wuc.warmup_init_count == 1
362+
@test wuc.non_warmup_init_count == 0
363+
@test wuc.warmup_count == num_warmup * num_reps
364+
@test wuc.non_warmup_count == (num_samples - 1) * num_reps
365+
end
366+
271367
@testset "Testing gibbs.jl with $adbackend" for adbackend in ADUtils.adbackends
272368
@info "Starting Gibbs tests with $adbackend"
273369
@testset "Deprecated Gibbs constructors" begin

0 commit comments

Comments
 (0)