Fix year-weighted sampling #1778

Open · wants to merge 5 commits into master
2 changes: 2 additions & 0 deletions CHANGES.md
@@ -5,8 +5,10 @@
### Bug fixes

* titers: Improve error messages when titer models do not have enough data. [#1769][] (@huddlej)
+* filter: Fixed an error with weighted sampling by `year`. [#1776][] (@victorlin)

[#1769]: https://github.com/nextstrain/augur/pull/1769
+[#1776]: https://github.com/nextstrain/augur/issues/1776

## 29.0.0 (26 February 2025)

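To make the `year` fix recorded above concrete, here is a minimal, hypothetical sketch (not code from this PR) of the kind of dtype mismatch the change addresses: years derived from metadata dates end up as integers while values read from a weights file are strings, so pandas cannot match them. The column names below are illustrative only.

```python
import pandas as pd

# Illustrative data: integer years from parsed metadata dates vs. string
# years from a weights file read with dtype="string".
metadata_groups = pd.DataFrame({"year": [2001, 2002]})
weights = pd.DataFrame({"year": ["2001", "2002"], "weight": [2, 2]})

try:
    print(pd.merge(metadata_groups, weights, on="year"))
except ValueError as error:
    # Recent pandas refuses to merge int64 and string/object keys outright.
    print(f"merge failed: {error}")

# Casting both sides to strings, as this PR does, makes the keys comparable.
metadata_groups["year"] = metadata_groups["year"].astype(str)
print(pd.merge(metadata_groups, weights, on="year"))
```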
8 changes: 4 additions & 4 deletions augur/filter/subsample.py
@@ -51,7 +51,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
>>> group_by = ["year", "month"]
>>> group_by_strain = get_groups_for_subsampling(strains, metadata, group_by)
>>> group_by_strain
-{'strain1': (2020, '2020-01'), 'strain2': (2020, '2020-02')}
+{'strain1': ('2020', '2020-01'), 'strain2': ('2020', '2020-02')}

If we omit the grouping columns, the result will group by a dummy column.

@@ -73,7 +73,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):
>>> group_by = ["year", "month", "missing_column"]
>>> group_by_strain = get_groups_for_subsampling(strains, metadata, group_by)
>>> group_by_strain
-{'strain1': (2020, '2020-01', 'unknown'), 'strain2': (2020, '2020-02', 'unknown')}
+{'strain1': ('2020', '2020-01', 'unknown'), 'strain2': ('2020', '2020-02', 'unknown')}

We can group metadata without any non-ID columns.

@@ -142,7 +142,7 @@ def get_groups_for_subsampling(strains, metadata, group_by=None):

# Generate columns.
if constants.DATE_YEAR_COLUMN in generated_columns_requested:
-metadata[constants.DATE_YEAR_COLUMN] = metadata[f'{temp_prefix}year']
+metadata[constants.DATE_YEAR_COLUMN] = metadata[f'{temp_prefix}year'].astype(str)
if constants.DATE_MONTH_COLUMN in generated_columns_requested:
metadata[constants.DATE_MONTH_COLUMN] = metadata.apply(lambda row: get_year_month(
row[f'{temp_prefix}year'],
@@ -386,7 +386,7 @@ def _add_unweighted_columns(
values_for_unweighted_columns[column].add(column_to_value_map[column])

# Create a DataFrame for all permutations of values in unweighted columns.
-lists = [list(values_for_unweighted_columns[column]) for column in unweighted_columns]
+lists = [sorted(values_for_unweighted_columns[column]) for column in unweighted_columns]
unweighted_permutations = pd.DataFrame(list(itertools.product(*lists)), columns=unweighted_columns)

return pd.merge(unweighted_permutations, weights, how='cross')
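The `sorted()` change above is about determinism: iterating a Python set can yield a different order between runs (string hashing is randomized per process), which would shuffle the permutation table and change which sequences a fixed `--subsample-seed` picks. A rough sketch of the pattern, using made-up column names and values rather than the real augur internals:

```python
import itertools
import pandas as pd

# Made-up values standing in for the unweighted columns found in the metadata.
values_for_unweighted_columns = {"month": {"2020-03", "2020-01", "2020-02"}}
unweighted_columns = ["month"]
weights = pd.DataFrame({"year": ["2001", "default"], "weight": [2, 1]})

# sorted() keeps the row order stable across runs, unlike list(set).
lists = [sorted(values_for_unweighted_columns[column]) for column in unweighted_columns]
unweighted_permutations = pd.DataFrame(list(itertools.product(*lists)), columns=unweighted_columns)

# Every permutation of unweighted values is paired with every weights row.
print(pd.merge(unweighted_permutations, weights, how="cross"))
```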
22 changes: 16 additions & 6 deletions augur/filter/weights_file.py
@@ -14,16 +14,17 @@ def __init__(self, file, error_message):


def read_weights_file(weights_file):
-weights = pd.read_csv(weights_file, delimiter='\t', comment='#')
+weights = pd.read_csv(weights_file, delimiter='\t', comment='#', dtype="string")

-if not pd.api.types.is_numeric_dtype(weights[WEIGHTS_COLUMN]):
-    non_numeric_weight_lines = [index + 2 for index in weights[~weights[WEIGHTS_COLUMN].str.isnumeric()].index.tolist()]
+if non_numeric_weight_lines := [index + 2 for index in weights[~weights[WEIGHTS_COLUMN].str.lstrip("-").str.isnumeric()].index.tolist()]:
raise InvalidWeightsFile(weights_file, dedent(f"""\
Found non-numeric weights on the following lines: {non_numeric_weight_lines}
{WEIGHTS_COLUMN!r} column must be numeric."""))

-if any(weights[WEIGHTS_COLUMN] < 0):
-    negative_weight_lines = [index + 2 for index in weights[weights[WEIGHTS_COLUMN] < 0].index.tolist()]
+# Cast weights to numeric for calculations
+weights[WEIGHTS_COLUMN] = pd.to_numeric(weights[WEIGHTS_COLUMN])
+
+if negative_weight_lines := [index + 2 for index in weights[weights[WEIGHTS_COLUMN] < 0].index.tolist()]:
raise InvalidWeightsFile(weights_file, dedent(f"""\
Found negative weights on the following lines: {negative_weight_lines}
{WEIGHTS_COLUMN!r} column must be non-negative."""))
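The reworked `read_weights_file` reads every column as a string first, reports offending rows by file line number (index + 2 accounts for the header row and 1-based line numbers), and only then converts the weight column to numbers. A self-contained sketch of that validate-then-convert pattern, with the literal column name `weight` standing in for the module's `WEIGHTS_COLUMN` constant:

```python
import pandas as pd
from io import StringIO

# In-memory stand-in for a weights file with one bad value on line 3.
weights_tsv = StringIO("year\tweight\n2001\t2\n2002\tx\n")
weights = pd.read_csv(weights_tsv, delimiter="\t", comment="#", dtype="string")

# Flag rows whose weight is not an integer string; a leading minus sign is
# stripped so negative values are reported separately, not as non-numeric.
bad_lines = [index + 2 for index in weights[~weights["weight"].str.lstrip("-").str.isnumeric()].index.tolist()]
if bad_lines:
    print(f"non-numeric weights on lines: {bad_lines}")
else:
    # Safe to cast only once every value is known to be numeric.
    weights["weight"] = pd.to_numeric(weights["weight"])
```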
@@ -47,7 +48,16 @@ def get_weighted_columns(weights_file):


def get_default_weight(weights: pd.DataFrame, weighted_columns: List[str]):
-default_weight_values = weights[(weights[weighted_columns] == COLUMN_VALUE_FOR_DEFAULT_WEIGHT).all(axis=1)][WEIGHTS_COLUMN].unique()
+# Match weighted columns with 'default' value. Multiple values can be matched for 2 reasons:
+# 1. Repeated rows following additional permutation with unweighted columns.
+#    This is handled by unique() since the value should be the same.
+# 2. Multiple default rows specified in the weights file.
+#    This is a user error.
+mask = (
+    weights[weighted_columns].eq(COLUMN_VALUE_FOR_DEFAULT_WEIGHT).all(axis=1) &
+    weights[weighted_columns].notna().all(axis=1)
+)
+default_weight_values = weights.loc[mask, WEIGHTS_COLUMN].unique()

if len(default_weight_values) > 1:
# TODO: raise InvalidWeightsFile, not AugurError. This function takes
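The comment added in `get_default_weight` explains why several rows can match the default marker once the weights table has been cross-merged with unweighted permutations. Below is a small illustrative frame (the column names and the literal 'default' stand in for the module's constants) showing how the mask collapses them back to one value:

```python
import pandas as pd

# After the cross merge, the single 'default' row from the weights file is
# repeated once per unweighted permutation (here: once per month).
weights = pd.DataFrame({
    "year": ["2001", "2002", "default", "default"],
    "month": ["2020-01", "2020-02", "2020-01", "2020-02"],
    "weight": [2, 2, 1, 1],
})
weighted_columns = ["year"]

# Rows where every weighted column equals 'default' and none is missing.
mask = (
    weights[weighted_columns].eq("default").all(axis=1) &
    weights[weighted_columns].notna().all(axis=1)
)
default_weight_values = weights.loc[mask, "weight"].unique()
print(default_weight_values)  # one value; more than one would be a user error
```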
30 changes: 15 additions & 15 deletions tests/filter/test_subsample.py
@@ -70,11 +70,11 @@ def test_filter_groupby_invalid_warn(self, valid_metadata: pd.DataFrame, capsys)
strains = metadata.index.tolist()
group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
-'SEQ_1': ('A', 2020, '2020-01', 'unknown'),
-'SEQ_2': ('A', 2020, '2020-02', 'unknown'),
-'SEQ_3': ('B', 2020, '2020-03', 'unknown'),
-'SEQ_4': ('B', 2020, '2020-04', 'unknown'),
-'SEQ_5': ('B', 2020, '2020-05', 'unknown')
+'SEQ_1': ('A', '2020', '2020-01', 'unknown'),
+'SEQ_2': ('A', '2020', '2020-02', 'unknown'),
+'SEQ_3': ('B', '2020', '2020-03', 'unknown'),
+'SEQ_4': ('B', '2020', '2020-04', 'unknown'),
+'SEQ_5': ('B', '2020', '2020-05', 'unknown')
}
captured = capsys.readouterr()
assert captured.err == "WARNING: Some of the specified group-by categories couldn't be found: invalid\nFiltering by group may behave differently than expected!\n"
@@ -136,11 +136,11 @@ def test_filter_groupby_only_year_provided(self, valid_metadata: pd.DataFrame):
strains = metadata.index.tolist()
group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
-'SEQ_1': ('A', 2020),
-'SEQ_2': ('A', 2020),
-'SEQ_3': ('B', 2020),
-'SEQ_4': ('B', 2020),
-'SEQ_5': ('B', 2020)
+'SEQ_1': ('A', '2020'),
+'SEQ_2': ('A', '2020'),
+'SEQ_3': ('B', '2020'),
+'SEQ_4': ('B', '2020'),
+'SEQ_5': ('B', '2020')
}

def test_filter_groupby_only_year_month_provided(self, valid_metadata: pd.DataFrame):
@@ -150,9 +150,9 @@ def test_filter_groupby_only_year_month_provided(self, valid_metadata: pd.DataFr
strains = metadata.index.tolist()
group_by_strain = augur.filter.subsample.get_groups_for_subsampling(strains, metadata, group_by=groups)
assert group_by_strain == {
-'SEQ_1': ('A', 2020, '2020-01'),
-'SEQ_2': ('A', 2020, '2020-01'),
-'SEQ_3': ('B', 2020, '2020-01'),
-'SEQ_4': ('B', 2020, '2020-01'),
-'SEQ_5': ('B', 2020, '2020-01')
+'SEQ_1': ('A', '2020', '2020-01'),
+'SEQ_2': ('A', '2020', '2020-01'),
+'SEQ_3': ('B', '2020', '2020-01'),
+'SEQ_4': ('B', '2020', '2020-01'),
+'SEQ_5': ('B', '2020', '2020-01')
}
@@ -104,7 +104,7 @@ requested 17, so the total number of sequences outputted is lower than requested
> --output-strains strains.txt
Sampling with weights defined by weights-A1B1.tsv.
NOTE: Weights were not provided for the column 'year'. Using equal weights across values in that column.
-WARNING: Targeted 17 sequences for group ['year=2002', "location='A'"] but only 1 is available.
+WARNING: Targeted 17 sequences for group ["year='2002'", "location='A'"] but only 1 is available.
168 strains were dropped during filtering
168 were dropped because of subsampling criteria
83 strains passed all filters
40 changes: 40 additions & 0 deletions tests/functional/filter/cram/subsample-weighted-year.t
@@ -0,0 +1,40 @@
Setup

$ source "$TESTDIR"/_setup.sh

Set up files.

$ cat >metadata.tsv <<~~
> strain date
> SEQ1 2001-01-01
> SEQ2 2001-01-01
> SEQ3 2002-01-01
> SEQ4 2002-01-01
> SEQ5 2003-01-01
> SEQ6 2003-01-01
> ~~

$ cat >weights.tsv <<~~
> year weight
> 2001 2
> 2002 2
> default 1
> ~~

Subsample with year weights.

$ ${AUGUR} filter \
> --metadata metadata.tsv \
> --group-by year \
> --group-by-weights weights.tsv \
> --subsample-max-sequences 5 \
> --subsample-seed 0 \
> --output-strains strains.txt
Sampling with weights defined by weights.tsv.
WARNING: The input metadata contains these values under the following columns that are not directly covered by 'weights.tsv':
- 'year': ['2003']
The default weight of 1 will be used for all groups defined by those values.
NOTE: Skipping 1 group due to lack of entries in metadata.
1 strain was dropped during filtering
1 was dropped because of subsampling criteria
5 strains passed all filters
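As a rough sanity check on the expected output above (a back-of-the-envelope sketch, not augur's actual allocation code, which handles rounding and under-filled groups more carefully): spreading 5 target sequences over weights of 2:2:1 gives per-year targets of 2, 2, and 1, so one of the six input strains is dropped.

```python
# Hypothetical proportional allocation; the real implementation may round and
# redistribute differently when groups lack enough sequences.
weights = {"2001": 2, "2002": 2, "2003": 1}  # 2003 falls back to the default weight of 1
target_total = 5
total_weight = sum(weights.values())

targets = {year: round(target_total * weight / total_weight) for year, weight in weights.items()}
print(targets)  # {'2001': 2, '2002': 2, '2003': 1}
```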