Skip to content

Commit

Permalink
Track processed samples in md
Browse files Browse the repository at this point in the history
Closes #403
  • Loading branch information
jeromekelleher committed Dec 1, 2024
1 parent d909cd1 commit 250fcc4
Showing 1 changed file with 34 additions and 3 deletions.
37 changes: 34 additions & 3 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ def initial_ts(problematic_sites=list()):
"date": {},
"node": {},
},
"num_samples_processed": {},
"samples_processed": {},
"samples_rejected": {},
"retro_groups": [],
}
}
Expand Down Expand Up @@ -338,6 +339,7 @@ class Sample:
strain: str
date: str = "1999-01-01"
pango: str = "Unknown"
scorpio: str = "Unknown"
metadata: Dict = dataclasses.field(default_factory=dict)
alignment_composition: Dict = None
haplotype: List = None
Expand Down Expand Up @@ -620,6 +622,7 @@ def extend(
)
# FIXME parametrise
pango_lineage_key = "Viridian_pangolin"
scorpio_key = "Viridian_scorpio"

include_strains = set(include_samples)
unconditional_include_samples = []
Expand All @@ -631,6 +634,7 @@ def extend(
md = metadata_matches[s.strain]
s.metadata = md
s.pango = md.get(pango_lineage_key, "Unknown")
s.scorpio = md.get(scorpio_key, "Unknown")
s.date = date
num_missing_sites = s.num_missing_sites
num_deletion_sites = s.num_deletion_sites
Expand Down Expand Up @@ -732,11 +736,38 @@ def update_top_level_metadata(ts, date, retro_groups, samples):
num_samples = len(samples)
samples_strain = md["sc2ts"]["samples_strain"]
new_samples = ts.samples()[len(samples_strain) :]
inserted_samples = set()
for u in new_samples:
node = ts.node(u)
samples_strain.append(node.metadata["strain"])
s = node.metadata["strain"]
samples_strain.append(s)
inserted_samples.add(s)

# Per-scorpio tallies for this batch of samples: the number processed and
# the summed HMM match cost (converted to a mean further down).
overall_processed = collections.Counter()
overall_hmm_cost = collections.Counter()
# The same two tallies, restricted to samples whose strain is not among the
# newly inserted samples and whose HMM match cost is non-zero.
rejected = collections.Counter()
rejected_hmm_cost = collections.Counter()
for sample in samples:
    cost = float(sample.hmm_match.cost)
    overall_processed[sample.scorpio] += 1
    overall_hmm_cost[sample.scorpio] += cost
    if sample.strain not in inserted_samples and sample.hmm_match.cost > 0:
        rejected[sample.scorpio] += 1
        # The cost must go into its own counter: adding it to ``rejected``
        # would inflate the count and leave the mean cost unrecorded.
        rejected_hmm_cost[sample.scorpio] += cost

# Turn the summed costs into per-scorpio means. Every key in a count
# counter was incremented at least once, so the divisor is never zero.
for scorpio in overall_processed.keys():
    overall_hmm_cost[scorpio] /= overall_processed[scorpio]
for scorpio in rejected.keys():
    rejected_hmm_cost[scorpio] /= rejected[scorpio]

md["sc2ts"]["samples_strain"] = samples_strain
md["sc2ts"]["num_samples_processed"][date] = num_samples
md["sc2ts"]["samples_processed"][date] = {
"count": dict(overall_processed),
"mean_hmm_cost": dict(overall_hmm_cost),
}
md["sc2ts"]["samples_rejected"][date] = {
"count": dict(rejected),
"mean_hmm_cost": dict(rejected_hmm_cost),
}
existing_retro_groups = md["sc2ts"].get("retro_groups", [])
if isinstance(existing_retro_groups, dict):
# Hack to implement metadata format change
Expand Down

0 comments on commit 250fcc4

Please sign in to comment.