Skip to content

Stateful improvements #79

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[weakdeps]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[extensions]
TestExt = "Test"

[compat]
Accessors = "^0.1.12"
Aqua = "0.8"
Expand Down
113 changes: 113 additions & 0 deletions ext/TestExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
module TestExt

using Test
using ComputerAdaptiveTesting: Stateful
using FittedItemBanks: AbstractItemBank, ItemResponse, resp

export test_stateful_cat_1d_dich_ib, test_stateful_cat_item_bank_1d_dich_ib

"""
    test_stateful_cat_1d_dich_ib(cat, item_bank_length;
                                 supports_ranked_and_criteria = true,
                                 supports_rollback = true)

Exercise the `Stateful` interface of `cat` with a series of `@testset`s.

The checks record responses to items 1 and 2 (`false` and `true` respectively), so
`item_bank_length` must be at least 3 to leave an item for `next_item` to propose.

# Arguments
- `cat`: the CAT under test. The first test set asserts that it holds no responses
  on entry; it is reset again before this function returns.
- `item_bank_length`: number of items in the bank backing `cat`.

# Keywords
- `supports_ranked_and_criteria`: when `false`, the `ranked_items` /
  `item_criteria` checks are skipped.
- `supports_rollback`: when `false`, every check that calls `rollback!` is skipped.

# Throws
- `ErrorException` if `item_bank_length < 3`.
"""
function test_stateful_cat_1d_dich_ib(
    cat::Stateful.StatefulCat,
    item_bank_length;
    supports_ranked_and_criteria = true,
    supports_rollback = true
)
    if item_bank_length < 3
        error("Item bank length must be at least 3.")
    end

    @testset "response round trip" begin
        responses_before = Stateful.get_responses(cat)
        @test length(responses_before.indices) == 0
        @test length(responses_before.values) == 0

        Stateful.add_response!(cat, 1, false)
        Stateful.add_response!(cat, 2, true)

        responses_after_add = Stateful.get_responses(cat)
        @test length(responses_after_add.indices) == 2
        @test length(responses_after_add.values) == 2

        Stateful.reset!(cat)
        responses_after_reset = Stateful.get_responses(cat)
        @test length(responses_after_reset.indices) == 0
        @test length(responses_after_reset.values) == 0
    end

    @testset "basic next_item tests" begin
        # Start from a known state so this test set does not depend on the
        # previous one having completed successfully.
        Stateful.reset!(cat)
        Stateful.add_response!(cat, 1, false)
        Stateful.add_response!(cat, 2, true)

        item = Stateful.next_item(cat)
        @test isa(item, Integer)
        # Items 1 and 2 have already been answered, so the proposal must be
        # one of the remaining items.
        @test item >= 3
        @test item <= item_bank_length
    end

    if supports_ranked_and_criteria
        @testset "basic ranked/criteria tests" begin
            # Same two-response state the previous test set used.
            Stateful.reset!(cat)
            Stateful.add_response!(cat, 1, false)
            Stateful.add_response!(cat, 2, true)

            items = Stateful.ranked_items(cat)
            @test length(items) == item_bank_length

            criteria = Stateful.item_criteria(cat)
            @test length(criteria) == item_bank_length
        end
    end

    if supports_rollback
        @testset "basic rollback tests" begin
            Stateful.reset!(cat)
            Stateful.add_response!(cat, 1, false)
            Stateful.add_response!(cat, 2, true)
            Stateful.rollback!(cat)
            responses_after_rollback = Stateful.get_responses(cat)
            @test length(responses_after_rollback.indices) == 1
            @test length(responses_after_rollback.values) == 1
            # Rollback must drop the most recent response, leaving the first.
            @test first(responses_after_rollback.indices) == 1
            @test first(responses_after_rollback.values) == false
        end
    end

    @testset "basic get_ability tests" begin
        Stateful.reset!(cat)
        Stateful.add_response!(cat, 1, false)
        Stateful.add_response!(cat, 2, true)
        ability = Stateful.get_ability(cat)
        @test isa(ability, Tuple)
        @test length(ability) == 2
        @test isa(ability[1], Float64)
    end

    if supports_rollback
        @testset "rollback ability tests" begin
            Stateful.reset!(cat)
            Stateful.add_response!(cat, 1, false)
            ability1 = Stateful.get_ability(cat)
            Stateful.add_response!(cat, 2, true)
            ability2 = Stateful.get_ability(cat)
            # Rolling back and re-adding the same response must reproduce the
            # earlier ability values exactly.
            Stateful.rollback!(cat)
            @test Stateful.get_ability(cat) == ability1
            Stateful.add_response!(cat, 2, true)
            @test Stateful.get_ability(cat) == ability2
        end
    end

    # Do not leave test responses behind in the caller's CAT.
    Stateful.reset!(cat)
    return nothing
end

"""
    test_stateful_cat_item_bank_1d_dich_ib(cat, item_bank,
                                           points = [-.78, 0.0, .78], margin = 0.05)

Check that `Stateful.item_response_function(cat, i, true, point)` agrees with
`resp(ItemResponse(item_bank, i), true, point)` for every item index `i` of
`item_bank` and every ability value in `points`.

# Arguments
- `cat`: the CAT under test.
- `item_bank`: the reference item bank to compare against.
- `points`: ability values at which both response functions are evaluated.
- `margin`: relative tolerance (`rtol`) used for the `≈` comparison.

# Throws
- `ErrorException` if `length(item_bank)` differs from `Stateful.item_bank_size(cat)`.
"""
function test_stateful_cat_item_bank_1d_dich_ib(
    cat::Stateful.StatefulCat,
    item_bank::AbstractItemBank,
    points=[-.78, 0.0, .78],
    margin=0.05,
)
    if length(item_bank) != Stateful.item_bank_size(cat)
        error("Item bank length does not match the cat's item bank size.")
    end
    # Group the comparisons so that one mismatch is recorded and the remaining
    # items/points are still checked, even when the caller is not itself
    # inside a test set.
    @testset "item response function matches item bank" begin
        for i in 1:length(item_bank)
            for point in points
                cat_prob = Stateful.item_response_function(cat, i, true, point)
                ib_prob = resp(ItemResponse(item_bank, i), true, point)
                @test cat_prob ≈ ib_prob rtol=margin
            end
        end
    end
    return nothing
end

end
11 changes: 11 additions & 0 deletions src/ComputerAdaptiveTesting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ export NextItemRules, TerminationConditions
export CatConfig, Sim, DecisionTree
export Stateful, Comparison

# Extension modules
public require_testext

# Vendored dependencies
include("./vendor/PushVectors.jl")

Expand Down Expand Up @@ -44,4 +47,12 @@ include("./Comparison.jl")

include("./precompiles.jl")

"""
    require_testext()

Return the `TestExt` package extension module of this package.

The extension is looked up with `Base.get_extension`. It is declared as an
extension triggered by the `Test` weak dependency, so it is only available once
`Test` has been loaded.

# Throws
- `ErrorException` if the extension module has not been loaded.
"""
function require_testext()
    ext = Base.get_extension(@__MODULE__, :TestExt)
    if ext === nothing
        # `get_extension` returns `nothing` when the extension is not loaded;
        # tell the caller how to trigger it instead of just reporting failure.
        error(
            "Failed to load extension module TestExt. " *
            "Load the `Test` standard library first (`using Test`)."
        )
    end
    return ext
end

end
23 changes: 23 additions & 0 deletions src/Responses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using FittedItemBanks: AbstractItemBank,
using AutoHashEquals: @auto_hash_equals

export Response, BareResponses, AbilityLikelihood, function_xs, function_ys
export add_response!, pop_response!

concrete_response_type(::BooleanResponse) = Bool
concrete_response_type(::MultinomialResponse) = Int
Expand Down Expand Up @@ -69,6 +70,28 @@ function Base.iterate(::BareResponses, gen_gen_state)
return _iter_helper(gen, iterate(gen, gen_state))
end

"""
    empty!(responses::BareResponses) -> responses

Remove all recorded responses by emptying both the `indices` and the `values`
vectors of `responses`, and return `responses` itself.
"""
function Base.empty!(responses::BareResponses)
    Base.empty!(responses.indices)
    Base.empty!(responses.values)
    # Follow the `empty!` convention of returning the emptied collection rather
    # than leaking the inner `values` vector.
    return responses
end

"""
    add_response!(responses::BareResponses, response::Response) -> BareResponses

Append `response` to `responses` by pushing its `index` and `value` onto the
`indices` and `values` vectors, then return `responses`.
"""
function add_response!(responses::BareResponses, response::Response)::BareResponses
    item_index, item_value = response.index, response.value
    push!(responses.indices, item_index)
    push!(responses.values, item_value)
    return responses
end

"""
    pop_response!(responses::BareResponses) -> BareResponses

Remove the last entry from both the `indices` and the `values` vectors of
`responses`, then return `responses`.
"""
function pop_response!(responses::BareResponses)::BareResponses
    # Same order as before: indices first, then values.
    foreach(pop!, (responses.indices, responses.values))
    return responses
end

"""
    sizehint!(bare_responses::BareResponses, n) -> bare_responses

Forward the capacity hint `n` to both the `indices` and the `values` vectors of
`bare_responses`, and return `bare_responses` itself.
"""
function Base.sizehint!(bare_responses::BareResponses, n)
    sizehint!(bare_responses.indices, n)
    sizehint!(bare_responses.values, n)
    # Follow the `sizehint!` convention of returning the collection that was
    # hinted rather than the inner `values` vector.
    return bare_responses
end

struct AbilityLikelihood{ItemBankT <: AbstractItemBank, BareResponsesT <: BareResponses}
item_bank::ItemBankT
responses::BareResponsesT
Expand Down
37 changes: 33 additions & 4 deletions src/Stateful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ module Stateful

using DocStringExtensions

using FittedItemBanks: AbstractItemBank, ResponseType
using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse, resp
using ..Aggregators: TrackedResponses, Aggregators
using ..CatConfig: CatLoopConfig, CatRules
using ..Responses: BareResponses, Response
using ..Responses: BareResponses, Response, Responses
using ..NextItemRules: compute_criteria, best_item
using ..Sim: Sim, item_label

Expand Down Expand Up @@ -124,6 +124,25 @@ but should attempt to interoperate with ComputerAdaptiveTesting.jl.
"""
function get_ability end

"""
```julia
$(FUNCTIONNAME)(config::StatefulCat)
```

Return the number of items in the current item bank.
"""
function item_bank_size end

"""
```julia
$(FUNCTIONNAME)(config::StatefulCat, index::IndexT, response::ResponseT, ability::AbilityT) -> Float
```

Return the probability of a `response` to the item at `index` for someone with
a certain `ability` according to the IRT model backing the CAT.
"""
function item_response_function end

## Running the CAT
function Sim.run_cat(cat_config::CatLoopConfig{RulesT},
ib_labels = nothing) where {RulesT <: StatefulCat}
Expand Down Expand Up @@ -190,13 +209,13 @@ end

function add_response!(config::StatefulCatConfig, index, response)
tracked_responses = config.tracked_responses[]
Aggregators.add_response!(
Responses.add_response!(
tracked_responses, Response(
ResponseType(tracked_responses.item_bank), index, response))
end

function rollback!(config::StatefulCatConfig)
pop_response!(config.tracked_responses[])
Responses.pop_response!(config.tracked_responses[])
end

function reset!(config::StatefulCatConfig)
Expand All @@ -220,6 +239,16 @@ function get_ability(config::StatefulCatConfig)
return (config.rules.ability_estimator(config.tracked_responses[]), nothing)
end

# Size of the item bank held by the config's current tracked responses.
item_bank_size(config::StatefulCatConfig) = length(config.tracked_responses[].item_bank)

# Evaluate the response function of item `index` in the config's item bank for
# the given `response` and `ability`.
function item_response_function(config::StatefulCatConfig, index, response, ability)
    tracked = config.tracked_responses[]
    return resp(ItemResponse(tracked.item_bank, index), response, ability)
end

## TODO: Implementation for MaterializedDecisionTree

end
5 changes: 2 additions & 3 deletions src/aggregators/Aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using FittedItemBanks: AbstractItemBank, ContinuousDomain,
PointsItemBank, ResponseType, VectorContinuousDomain,
domdims, item_params, resp, resp_vec, responses
using ..Responses
using ..Responses: concrete_response_type, function_xs, function_ys
using ..Responses: concrete_response_type, function_xs, function_ys, Responses
using ..ConfigBase
using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome,
find1_instance, find1_type,
Expand All @@ -37,8 +37,7 @@ import PsychometricsBazaarBase.IntegralCoeffs
export AbilityEstimator, TrackedResponses
export AbilityTracker, NullAbilityTracker, PointAbilityTracker, GriddedAbilityTracker
export ClosedFormNormalAbilityTracker, track!
export response_expectation,
add_response!, pop_response!, expectation, distribution_estimator
export response_expectation, expectation, distribution_estimator
export PointAbilityEstimator, PriorAbilityEstimator, LikelihoodAbilityEstimator
export ModeAbilityEstimator, MeanAbilityEstimator
export Speculator, replace_speculation!, normdenom, maybe_tracked_ability_estimate
Expand Down
24 changes: 3 additions & 21 deletions src/aggregators/ability_tracker.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,19 @@
function sizehint!(bare_responses::BareResponses, n)
sizehint!(bare_responses.indices, n)
sizehint!(bare_responses.values, n)
end

function track!(responses)
track!(responses, responses.ability_tracker)
end

function add_response!(responses::BareResponses, response::Response)::BareResponses
push!(responses.indices, response.index)
push!(responses.values, response.value)
responses
end

function add_response!(tracked_responses::TrackedResponses, response::Response)
function Responses.add_response!(tracked_responses::TrackedResponses, response::Response)
add_response!(tracked_responses.responses, response)
track!(tracked_responses)
end

function pop_response!(responses::BareResponses)::BareResponses
pop!(responses.indices)
pop!(responses.values)
responses
end

function pop_response!(tracked_responses::TrackedResponses)::TrackedResponses
function Responses.pop_response!(tracked_responses::TrackedResponses)::TrackedResponses
pop_response!(tracked_responses.responses)
tracked_responses
end

function Base.empty!(tracked_responses::TrackedResponses)
Base.empty!(tracked_responses.responses.indices)
Base.empty!(tracked_responses.responses.values)
Base.empty!(tracked_responses.responses)
end

function response_expectation(ability_estimator::DistributionAbilityEstimator,
Expand Down
2 changes: 1 addition & 1 deletion src/decision_tree/DecisionTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using ComputerAdaptiveTesting.ConfigBase: CatConfigBase
using ComputerAdaptiveTesting.PushVectors
using ComputerAdaptiveTesting.NextItemRules
using ComputerAdaptiveTesting.Aggregators
using ComputerAdaptiveTesting.Responses: BareResponses, Response
using ComputerAdaptiveTesting.Responses: BareResponses, Response, add_response!, pop_response!
using FittedItemBanks: AbstractItemBank, BooleanResponse, ResponseType

# TODO: Remove ability tracking from here?
Expand Down
25 changes: 23 additions & 2 deletions test/stateful.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition
using ComputerAdaptiveTesting.NextItemRules: RandomNextItemRule
using ComputerAdaptiveTesting: Stateful
using ComputerAdaptiveTesting: require_testext
using ResumableFunctions
using Test: @test, @testset

Expand All @@ -26,7 +27,7 @@
@testset "StatefulCatConfig basic usage" begin
rules = CatRules(
FixedItemsTerminationCondition(2),
Dummy.DummyAbilityEstimator(0),
Dummy.DummyAbilityEstimator(0.0),
RandomNextItemRule()
)

Expand Down Expand Up @@ -54,7 +55,7 @@
@testset "Stateful next item selection" begin
rules = CatRules(
FixedItemsTerminationCondition(2),
Dummy.DummyAbilityEstimator(0),
Dummy.DummyAbilityEstimator(0.0),
RandomNextItemRule()
)
cat_config = Stateful.StatefulCatConfig(rules, item_bank)
Expand All @@ -69,4 +70,24 @@
@test 1 <= second_item <= 4
@test second_item != first_item # Should select different item
end

@testset "Standard interface tests" begin
    rules = CatRules(
        FixedItemsTerminationCondition(2),
        Dummy.DummyAbilityEstimator(0.0),
        RandomNextItemRule()
    )

    # Initialize config
    cat_config = Stateful.StatefulCatConfig(rules, item_bank)

    # Run the standard interface tests shipped in the Test extension.
    TestExt = require_testext()
    TestExt.test_stateful_cat_1d_dich_ib(
        cat_config,
        # Derive the size from the bank itself so the two cannot drift apart.
        length(item_bank);
        supports_ranked_and_criteria = false,
    )
    TestExt.test_stateful_cat_item_bank_1d_dich_ib(cat_config, item_bank)
end
end
Loading