@@ -405,20 +405,75 @@ end
405
405
406
406
varinfo (state:: GibbsState ) = state. vi
407
407
408
- function DynamicPPL. initialstep (
408
+ """
409
+ Initialise a VarInfo for the Gibbs sampler.
410
+
411
+ This is straight up copypasta from DynamicPPL's src/sampler.jl. It is repeated here to
412
+ support calling both step and step_warmup as the initial step. DynamicPPL initialstep is
413
+ incompatible with step_warmup.
414
+ """
415
+ function initial_varinfo (rng, model, spl, initial_params)
416
+ vi = DynamicPPL. default_varinfo (rng, model, spl)
417
+
418
+ # Update the parameters if provided.
419
+ if initial_params != = nothing
420
+ vi = DynamicPPL. initialize_parameters!! (vi, initial_params, spl, model)
421
+
422
+ # Update joint log probability.
423
+ # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
424
+ # and https://github.com/TuringLang/Turing.jl/issues/1563
425
+ # to avoid that existing variables are resampled
426
+ vi = last (DynamicPPL. evaluate!! (model, vi, DynamicPPL. DefaultContext ()))
427
+ end
428
+ return vi
429
+ end
430
+
431
+ function AbstractMCMC. step (
409
432
rng:: Random.AbstractRNG ,
410
433
model:: DynamicPPL.Model ,
411
- spl:: DynamicPPL.Sampler{<:Gibbs} ,
412
- vi:: DynamicPPL.AbstractVarInfo ;
434
+ spl:: DynamicPPL.Sampler{<:Gibbs} ;
413
435
initial_params= nothing ,
414
436
kwargs... ,
415
437
)
416
438
alg = spl. alg
417
439
varnames = alg. varnames
418
440
samplers = alg. samplers
441
+ vi = initial_varinfo (rng, model, spl, initial_params)
419
442
420
443
vi, states = gibbs_initialstep_recursive (
421
- rng, model, varnames, samplers, vi; initial_params= initial_params, kwargs...
444
+ rng,
445
+ model,
446
+ AbstractMCMC. step,
447
+ varnames,
448
+ samplers,
449
+ vi;
450
+ initial_params= initial_params,
451
+ kwargs... ,
452
+ )
453
+ return Transition (model, vi), GibbsState (vi, states)
454
+ end
455
+
456
+ function AbstractMCMC. step_warmup (
457
+ rng:: Random.AbstractRNG ,
458
+ model:: DynamicPPL.Model ,
459
+ spl:: DynamicPPL.Sampler{<:Gibbs} ;
460
+ initial_params= nothing ,
461
+ kwargs... ,
462
+ )
463
+ alg = spl. alg
464
+ varnames = alg. varnames
465
+ samplers = alg. samplers
466
+ vi = initial_varinfo (rng, model, spl, initial_params)
467
+
468
+ vi, states = gibbs_initialstep_recursive (
469
+ rng,
470
+ model,
471
+ AbstractMCMC. step_warmup,
472
+ varnames,
473
+ samplers,
474
+ vi;
475
+ initial_params= initial_params,
476
+ kwargs... ,
422
477
)
423
478
return Transition (model, vi), GibbsState (vi, states)
424
479
end
427
482
Take the first step of MCMC for the first component sampler, and call the same function
428
483
recursively on the remaining samplers, until no samplers remain. Return the global VarInfo
429
484
and a tuple of initial states for all component samplers.
485
+
486
+ The `step_function` argument should always be either AbstractMCMC.step or
487
+ AbstractMCMC.step_warmup.
430
488
"""
431
489
function gibbs_initialstep_recursive (
432
- rng, model, varname_vecs, samplers, vi, states= (); initial_params= nothing , kwargs...
490
+ rng,
491
+ model,
492
+ step_function:: Function ,
493
+ varname_vecs,
494
+ samplers,
495
+ vi,
496
+ states= ();
497
+ initial_params= nothing ,
498
+ kwargs... ,
433
499
)
434
500
# End recursion
435
501
if isempty (varname_vecs) && isempty (samplers)
@@ -450,7 +516,7 @@ function gibbs_initialstep_recursive(
450
516
conditioned_model, context = make_conditional (model, varnames, vi)
451
517
452
518
# Take initial step with the current sampler.
453
- _, new_state = AbstractMCMC . step (
519
+ _, new_state = step_function (
454
520
rng,
455
521
conditioned_model,
456
522
sampler;
@@ -470,6 +536,7 @@ function gibbs_initialstep_recursive(
470
536
return gibbs_initialstep_recursive (
471
537
rng,
472
538
model,
539
+ step_function,
473
540
varname_vecs_tail,
474
541
samplers_tail,
475
542
vi,
@@ -493,7 +560,29 @@ function AbstractMCMC.step(
493
560
states = state. states
494
561
@assert length (samplers) == length (state. states)
495
562
496
- vi, states = gibbs_step_recursive (rng, model, varnames, samplers, states, vi; kwargs... )
563
+ vi, states = gibbs_step_recursive (
564
+ rng, model, AbstractMCMC. step, varnames, samplers, states, vi; kwargs...
565
+ )
566
+ return Transition (model, vi), GibbsState (vi, states)
567
+ end
568
+
569
+ function AbstractMCMC. step_warmup (
570
+ rng:: Random.AbstractRNG ,
571
+ model:: DynamicPPL.Model ,
572
+ spl:: DynamicPPL.Sampler{<:Gibbs} ,
573
+ state:: GibbsState ;
574
+ kwargs... ,
575
+ )
576
+ vi = varinfo (state)
577
+ alg = spl. alg
578
+ varnames = alg. varnames
579
+ samplers = alg. samplers
580
+ states = state. states
581
+ @assert length (samplers) == length (state. states)
582
+
583
+ vi, states = gibbs_step_recursive (
584
+ rng, model, AbstractMCMC. step_warmup, varnames, samplers, states, vi; kwargs...
585
+ )
497
586
return Transition (model, vi), GibbsState (vi, states)
498
587
end
499
588
@@ -620,10 +709,14 @@ end
620
709
"""
621
710
Run a Gibbs step for the first varname/sampler/state tuple, and recursively call the same
622
711
function on the tail, until there are no more samplers left.
712
+
713
+ The `step_function` argument should always be either AbstractMCMC.step or
714
+ AbstractMCMC.step_warmup.
623
715
"""
624
716
function gibbs_step_recursive (
625
717
rng:: Random.AbstractRNG ,
626
718
model:: DynamicPPL.Model ,
719
+ step_function:: Function ,
627
720
varname_vecs,
628
721
samplers,
629
722
states,
@@ -657,7 +750,7 @@ function gibbs_step_recursive(
657
750
state = setparams_varinfo!! (conditioned_model, sampler, state, vi)
658
751
659
752
# Take a step with the local sampler.
660
- new_state = last (AbstractMCMC . step (rng, conditioned_model, sampler, state; kwargs... ))
753
+ new_state = last (step_function (rng, conditioned_model, sampler, state; kwargs... ))
661
754
662
755
new_vi_local = varinfo (new_state)
663
756
# Merge the latest values for all the variables in the current sampler.
@@ -668,6 +761,7 @@ function gibbs_step_recursive(
668
761
return gibbs_step_recursive (
669
762
rng,
670
763
model,
764
+ step_function,
671
765
varname_vecs_tail,
672
766
samplers_tail,
673
767
states_tail,
0 commit comments