Skip to content

Commit

Permalink
Merge pull request #417 from jeromekelleher/record-counts
Browse files Browse the repository at this point in the history
Record counts
  • Loading branch information
jeromekelleher authored Dec 3, 2024
2 parents 4461b21 + fe49a51 commit 2e0ccb3
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 32 deletions.
44 changes: 38 additions & 6 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ def initial_ts(problematic_sites=list()):
"date": {},
"node": {},
},
"num_samples_processed": {},
"samples_processed": {},
"samples_rejected": {},
"retro_groups": [],
}
}
Expand Down Expand Up @@ -338,6 +339,7 @@ class Sample:
strain: str
date: str = "1999-01-01"
pango: str = "Unknown"
scorpio: str = "Unknown"
metadata: Dict = dataclasses.field(default_factory=dict)
alignment_composition: Dict = None
haplotype: List = None
Expand Down Expand Up @@ -620,6 +622,7 @@ def extend(
)
# FIXME parametrise
pango_lineage_key = "Viridian_pangolin"
scorpio_key = "Viridian_scorpio"

include_strains = set(include_samples)
unconditional_include_samples = []
Expand All @@ -631,6 +634,7 @@ def extend(
md = metadata_matches[s.strain]
s.metadata = md
s.pango = md.get(pango_lineage_key, "Unknown")
s.scorpio = md.get(scorpio_key, "Unknown")
s.date = date
num_missing_sites = s.num_missing_sites
num_deletion_sites = s.num_deletion_sites
Expand Down Expand Up @@ -722,20 +726,48 @@ def extend(
f"Add retro group {group.summary()}:"
f"{group.tree_quality_metrics.summary()}"
)
return update_top_level_metadata(ts, date, groups, len(samples))
return update_top_level_metadata(ts, date, groups, samples)


def update_top_level_metadata(ts, date, retro_groups, num_samples):
def update_top_level_metadata(ts, date, retro_groups, samples):
tables = ts.dump_tables()
md = tables.metadata
md["sc2ts"]["date"] = date
num_samples = len(samples)
samples_strain = md["sc2ts"]["samples_strain"]
new_samples = ts.samples()[len(samples_strain) :]
inserted_samples = set()
for u in new_samples:
node = ts.node(u)
samples_strain.append(node.metadata["strain"])
s = node.metadata["strain"]
samples_strain.append(s)
inserted_samples.add(s)

overall_processed = collections.Counter()
overall_hmm_cost = collections.Counter()
rejected = collections.Counter()
rejected_hmm_cost = collections.Counter()
for sample in samples:
overall_processed[sample.scorpio] += 1
overall_hmm_cost[sample.scorpio] += float(sample.hmm_match.cost)
if sample.strain not in inserted_samples and sample.hmm_match.cost > 0:
rejected[sample.scorpio] += 1
rejected[sample.scorpio] += float(sample.hmm_match.cost)

for scorpio in overall_processed.keys():
overall_hmm_cost[scorpio] /= overall_processed[scorpio]
for scorpio in rejected.keys():
rejected_hmm_cost[scorpio] /= rejected[scorpio]

md["sc2ts"]["samples_strain"] = samples_strain
md["sc2ts"]["num_samples_processed"][date] = num_samples
md["sc2ts"]["samples_processed"][date] = {
"count": dict(overall_processed),
"mean_hmm_cost": dict(overall_hmm_cost),
}
md["sc2ts"]["samples_rejected"][date] = {
"count": dict(rejected),
"mean_hmm_cost": dict(rejected_hmm_cost),
}
existing_retro_groups = md["sc2ts"].get("retro_groups", [])
if isinstance(existing_retro_groups, dict):
# Hack to implement metadata format change
Expand Down Expand Up @@ -1605,7 +1637,7 @@ def get_closest_mutation(node, site_id):
mutation.is_immediate_reversion = (
closest_mutation.node == seg.parent
)
logger.debug(f"Characterised {num_mutations}")
logger.debug(f"Characterised {num_mutations} mutations")


def attach_tree(
Expand Down
139 changes: 113 additions & 26 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,41 @@
logger = logging.getLogger(__name__)


def scorpio_to_major_voc(scorpio):
    """Map a scorpio constellation string to a major variant-of-concern label.

    Returns the first key from a fixed list ("Alpha", "Delta", "BA.1", "BA.2",
    "BA.4", "BA.5", "BQ.1", "XBB") that occurs as a substring of ``scorpio``,
    "Other_Omicron" for any other Omicron designation, and "Other" otherwise.
    """
    # A tuple rather than a set so the scan order is deterministic: if a
    # scorpio string ever contains more than one key, the same key is
    # returned on every run.
    keys = ("Alpha", "Delta", "BA.1", "BA.2", "BA.4", "BA.5", "BQ.1", "XBB")
    for k in keys:
        if k in scorpio:
            return k
    if "Omicron" in scorpio:
        return "Other_Omicron"
    return "Other"


def merge_scorpio_columns(df):
    """Collapse per-scorpio ``<scorpio>_count`` columns into major-VoC totals.

    Each column named ``<scorpio>_count`` is summed into a single column for
    its major variant-of-concern (see ``scorpio_to_major_voc``). The ``date``
    column is carried through unchanged; every other column is dropped.
    """
    suffix = "_count"
    merged = {}
    for name in df.columns:
        series = df[name]
        if name == "date":
            # Pass the date column through untouched.
            merged[name] = series
            continue
        if not name.endswith(suffix):
            # Non-count columns (e.g. *_hmm_cost) are dropped.
            continue
        voc = scorpio_to_major_voc(name[: -len(suffix)])
        if voc in merged:
            merged[voc] = merged[voc] + series
        else:
            merged[voc] = series.copy()
    return pd.DataFrame(merged)


def compute_fractions(df):
    """Return a copy of ``df`` with each row normalised to sum to 1.

    Every value is divided by its row total, so each entry becomes the
    fraction that its column contributes to that row. The input frame is
    not modified.
    """
    # Vectorised row-wise normalisation; equivalent to dividing each column
    # by the element-wise sum of all columns, but without the manual loop
    # (and with the commented-out debug prints removed).
    return df.div(df.sum(axis=1), axis=0)


@dataclasses.dataclass
class LineageDetails:
"""
Expand Down Expand Up @@ -862,17 +897,32 @@ def site_summary(self, position):
def samples_summary(self):
    """Build a per-day DataFrame summarising sample processing.

    One row per day, counting back from time zero: samples present in the
    ARG, the total number processed and their count-weighted mean HMM cost,
    exact-match counts, plus per-scorpio ``<scorpio>_count`` and
    ``<scorpio>_hmm_cost`` columns. Days missing a per-scorpio entry get 0
    via ``fillna``.
    """
    md = self.ts.metadata["sc2ts"]
    processed = md["samples_processed"]
    exact_by_date = md["exact_matches"]["date"]
    rows = []
    for days_ago in np.arange(self.num_samples_per_day.shape[0]):
        day = self.time_zero_as_date - days_ago
        key = str(day)
        day_info = processed.get(key, {})
        counts = day_info.get("count", {})
        mean_costs = day_info.get("mean_hmm_cost", {})
        total = 0
        weighted_cost = 0
        scorpio_cols = {}
        for scorpio, n in counts.items():
            scorpio_cols[f"{scorpio}_count"] = n
            scorpio_cols[f"{scorpio}_hmm_cost"] = mean_costs[scorpio]
            total += n
            # Recover the summed cost from the stored per-scorpio mean so
            # the overall mean is weighted by count.
            weighted_cost += n * mean_costs[scorpio]
        rows.append(
            {
                "date": day,
                "samples_in_arg": self.num_samples_per_day[days_ago],
                "samples_processed": total,
                # max(1, total) guards the division on days with nothing
                # processed (weighted_cost is 0 then too).
                "mean_hmm_cost": weighted_cost / max(1, total),
                "exact_matches": exact_by_date.get(key, 0),
                **scorpio_cols,
            }
        )
    return pd.DataFrame(rows).fillna(0)

def recombinants_summary(self):
data = []
Expand Down Expand Up @@ -1505,29 +1555,60 @@ def plot_deletion_overlaps(self, annotate_threshold=0.9):
def plot_samples_per_day(self, start_date="2020-04-01"):
    """Plot four stacked, date-aligned panels of per-day sample statistics.

    Panels: (1) absolute counts in ARG / processed / exact matches;
    (2) fractions of processed samples in the ARG and exact-matched;
    (3) fraction excluded, with mean HMM cost on a twin axis;
    (4) stacked major-VoC composition with first-sample annotations.
    Returns ``(fig, [ax1, ax2, ax3, ax4])``.
    """
    df = self.samples_summary()
    df = df[df.date >= start_date]
    fig, (ax1, ax2, ax3, ax4) = self._wide_plot(4, height=12, sharex=True)
    # Shared colours so the same series is recognisable across panels.
    exact_col = "tab:red"
    in_col = "tab:purple"
    ax1.plot(df.date, df.samples_in_arg, label="In ARG", color=in_col)
    ax1.plot(df.date, df.samples_processed, label="Processed")
    ax1.plot(df.date, df.exact_matches, label="Exact matches", color=exact_col)

    ax2.plot(
        df.date,
        df.samples_in_arg / df.samples_processed,
        label="Fraction processed in ARG",
        color=in_col,
    )
    ax2.plot(
        df.date,
        df.exact_matches / df.samples_processed,
        label="Fraction processed exact matches",
        color=exact_col,
    )
    # Samples neither exact-matched nor inserted into the ARG.
    excluded = df.samples_processed - df.exact_matches - df.samples_in_arg
    ax3.plot(df.date, excluded / df.samples_processed, label="Fraction excluded")
    # Twin y-axis so mean HMM cost shares the date axis with the
    # excluded fraction.
    ax3_2 = ax3.twinx()
    ax3_2.plot(df.date, df.mean_hmm_cost, color="tab:orange")
    ax3.set_xlabel("Date")
    ax3_2.set_ylabel("Mean HMM cost")
    ax1.set_ylabel("Number of samples")
    ax1.legend()
    ax2.legend()
    ax3.legend()

    # Collapse per-scorpio counts to major VoCs and plot their fractional
    # composition per day.
    df_major_voc = merge_scorpio_columns(df).set_index("date")
    df_voc = compute_fractions(df_major_voc)
    ax4.stackplot(
        df_voc.index,
        *[df_voc[voc] for voc in df_voc],
        labels=[" ".join(k.split("_")) for k in df_voc],
        alpha=0.7,
    )
    ax4.legend()

    # Mark the first sample of each major lineage on the composition panel.
    # NOTE(review): Beta and Gamma are skipped — presumably because they are
    # not prominent in this dataset; confirm.
    for lin in major_lineages:
        if lin.date < self.date and lin.who_label not in ["Beta", "Gamma"]:
            x_first = None
            if len(self.pango_lineage_samples[lin.pango_lineage]) > 0:
                first_sample_date = self.nodes_metadata[
                    self.pango_lineage_samples[lin.pango_lineage][0]
                ]["date"]
                x_first = np.array([first_sample_date], dtype="datetime64[D]")[0]
            # NOTE(review): if a lineage has no samples, x_first is still
            # None here and is passed to annotate/axvline below — these two
            # calls look like they belong inside the if above; confirm.
            ax4.annotate(
                f"first\n{lin.pango_lineage}", xy=(x_first, 0), xycoords="data"
            )
            ax4.axvline(x_first, color="grey", alpha=0.5)

    return fig, [ax1, ax2, ax3, ax4]

def plot_resources(self, start_date="2020-04-01"):
ts = self.ts
Expand All @@ -1538,6 +1619,7 @@ def plot_resources(self, start_date="2020-04-01"):
# Should be able to do this with join, but I failed
df["samples_in_arg"] = dfs.loc[df.index]["samples_in_arg"]
df["samples_processed"] = dfs.loc[df.index]["samples_processed"]
df["mean_hmm_cost"] = dfs.loc[df.index]["mean_hmm_cost"]

df = df[df.index >= start_date]
df["cpu_time"] = df.user_time + df.sys_time
Expand Down Expand Up @@ -1565,14 +1647,20 @@ def plot_resources(self, start_date="2020-04-01"):
ax_twin.legend()
ax[1].plot(x, df.elapsed_time / df.samples_processed)
ax[1].set_ylabel("Elapsed time per sample (s)")
ax_twin = ax[1].twinx()
ax_twin.plot(
x, df.mean_hmm_cost, color="tab:purple", alpha=0.5, label="HMM cost"
)
ax_twin.set_ylabel("HMM cost")
ax[2].plot(x, df.max_memory / 1024**3)
ax[2].set_ylabel("Max memory (GiB)")
return fig, ax

def resources_summary(self):
ts = self.ts
data = []
dates = sorted(list(ts.metadata["sc2ts"]["num_samples_processed"].keys()))
samples_processed = ts.metadata["sc2ts"]["samples_processed"]
dates = sorted(list(samples_processed.keys()))
assert len(dates) == ts.num_provenances - 1
for j in range(1, ts.num_provenances):
p = ts.provenance(j)
Expand Down Expand Up @@ -1707,14 +1795,13 @@ def draw_subtree(
time = tables.mutations.time
time[:] = tskit.UNKNOWN_TIME
tables.mutations.time = time
#rescale to negative times
# rescale to negative times
if date_format == "from_zero":
for node in reversed(self.ts.nodes(order="timeasc")):
if node.is_sample():
break
tables.nodes.time = tables.nodes.time - node.time
ts = tables.tree_sequence()


tracked_nodes = []
if tracked_pango is not None:
Expand Down Expand Up @@ -1802,9 +1889,7 @@ def draw_subtree(
inherited_state = parent.derived_state
parent_inherited_state = site.ancestral_state
if parent.parent >= 0:
parent_inherited_state = ts.mutation(
parent.parent
).derived_state
parent_inherited_state = ts.mutation(parent.parent).derived_state
if parent_inherited_state == mut.derived_state:
reverted_mutations.append(mut.id)
pos = int(site.position)
Expand All @@ -1814,8 +1899,9 @@ def draw_subtree(
# If more than one mutation has the same label, add a prefix with the counts
if append_mutation_recurrence:
num_recurrent = {
m_id: (i+1, len(ids))
for ids in recurrent_mutations.values() for i, m_id in enumerate(ids)
m_id: (i + 1, len(ids))
for ids in recurrent_mutations.values()
for i, m_id in enumerate(ids)
if len(ids) > 1
}
for m_id, (i, n) in num_recurrent.items():
Expand Down Expand Up @@ -1862,8 +1948,9 @@ def draw_subtree(
}
else:
# only place ticks at the sample nodes
y_ticks = {t: ts.node(u).metadata.get("date", "")
for u, t in zip(shown_nodes, shown_times)
y_ticks = {
t: ts.node(u).metadata.get("date", "")
for u, t in zip(shown_nodes, shown_times)
}
return tree.draw_svg(
time_scale=time_scale,
Expand Down
13 changes: 13 additions & 0 deletions tests/test_info.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import inspect

import pytest
Expand Down Expand Up @@ -218,6 +219,18 @@ def test_resources_summary(self, fx_ti_2020_02_13):
assert df.shape[0] == 20
assert np.all(df.date.str.startswith("2020"))

def test_samples_summary(self, fx_ti_2020_02_13):
    # Sanity checks on the per-day samples_summary() DataFrame produced
    # from the 2020-02-13 fixture tree sequence.
    df = fx_ti_2020_02_13.samples_summary()
    # NOTE: just doing this subsetting here to get rid of the annoying reference
    # as-sample issue, which mucks up counts. Should be able to get rid of this
    # after closing https://github.com/jeromekelleher/sc2ts/issues/413
    df = df[df["date"] >= datetime.datetime.fromisoformat("2020-01-01")]
    # Every processed sample is either in the ARG, an exact match, or excluded,
    # so processed must dominate the sum of the first two.
    assert np.all(
        df["samples_processed"] >= (df["samples_in_arg"] + df["exact_matches"])
    )
    assert df.shape[0] > 0
    # NOTE(review): the column name "._count" looks like an extraction
    # artifact — presumably a per-scorpio "<scorpio>_count" column (e.g.
    # "Unknown_count" for the fixture data); confirm against the repository.
    assert np.all(df["samples_processed"] == df["._count"])


class TestSampleGroupInfo:
def test_draw_svg(self, fx_ti_2020_02_13):
Expand Down

0 comments on commit 2e0ccb3

Please sign in to comment.