Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Record counts #417

Merged
merged 4 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions sc2ts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ def initial_ts(problematic_sites=list()):
"date": {},
"node": {},
},
"num_samples_processed": {},
"samples_processed": {},
"samples_rejected": {},
"retro_groups": [],
}
}
Expand Down Expand Up @@ -338,6 +339,7 @@ class Sample:
strain: str
date: str = "1999-01-01"
pango: str = "Unknown"
scorpio: str = "Unknown"
metadata: Dict = dataclasses.field(default_factory=dict)
alignment_composition: Dict = None
haplotype: List = None
Expand Down Expand Up @@ -620,6 +622,7 @@ def extend(
)
# FIXME parametrise
pango_lineage_key = "Viridian_pangolin"
scorpio_key = "Viridian_scorpio"

include_strains = set(include_samples)
unconditional_include_samples = []
Expand All @@ -631,6 +634,7 @@ def extend(
md = metadata_matches[s.strain]
s.metadata = md
s.pango = md.get(pango_lineage_key, "Unknown")
s.scorpio = md.get(scorpio_key, "Unknown")
s.date = date
num_missing_sites = s.num_missing_sites
num_deletion_sites = s.num_deletion_sites
Expand Down Expand Up @@ -722,20 +726,48 @@ def extend(
f"Add retro group {group.summary()}:"
f"{group.tree_quality_metrics.summary()}"
)
return update_top_level_metadata(ts, date, groups, len(samples))
return update_top_level_metadata(ts, date, groups, samples)


def update_top_level_metadata(ts, date, retro_groups, num_samples):
def update_top_level_metadata(ts, date, retro_groups, samples):
tables = ts.dump_tables()
md = tables.metadata
md["sc2ts"]["date"] = date
num_samples = len(samples)
samples_strain = md["sc2ts"]["samples_strain"]
new_samples = ts.samples()[len(samples_strain) :]
inserted_samples = set()
for u in new_samples:
node = ts.node(u)
samples_strain.append(node.metadata["strain"])
s = node.metadata["strain"]
samples_strain.append(s)
inserted_samples.add(s)

overall_processed = collections.Counter()
overall_hmm_cost = collections.Counter()
rejected = collections.Counter()
rejected_hmm_cost = collections.Counter()
for sample in samples:
overall_processed[sample.scorpio] += 1
overall_hmm_cost[sample.scorpio] += float(sample.hmm_match.cost)
if sample.strain not in inserted_samples and sample.hmm_match.cost > 0:
rejected[sample.scorpio] += 1
rejected[sample.scorpio] += float(sample.hmm_match.cost)

for scorpio in overall_processed.keys():
overall_hmm_cost[scorpio] /= overall_processed[scorpio]
for scorpio in rejected.keys():
rejected_hmm_cost[scorpio] /= rejected[scorpio]

md["sc2ts"]["samples_strain"] = samples_strain
md["sc2ts"]["num_samples_processed"][date] = num_samples
md["sc2ts"]["samples_processed"][date] = {
"count": dict(overall_processed),
"mean_hmm_cost": dict(overall_hmm_cost),
}
md["sc2ts"]["samples_rejected"][date] = {
"count": dict(rejected),
"mean_hmm_cost": dict(rejected_hmm_cost),
}
existing_retro_groups = md["sc2ts"].get("retro_groups", [])
if isinstance(existing_retro_groups, dict):
# Hack to implement metadata format change
Expand Down Expand Up @@ -1605,7 +1637,7 @@ def get_closest_mutation(node, site_id):
mutation.is_immediate_reversion = (
closest_mutation.node == seg.parent
)
logger.debug(f"Characterised {num_mutations}")
logger.debug(f"Characterised {num_mutations} mutations")


def attach_tree(
Expand Down
139 changes: 113 additions & 26 deletions sc2ts/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,41 @@
logger = logging.getLogger(__name__)


def scorpio_to_major_voc(scorpio):
    """
    Map a scorpio constellation string to a major variant-of-concern label.

    Returns the first major lineage key contained in ``scorpio``,
    "Other_Omicron" for Omicron strings that match no specific key, and
    "Other" for everything else.
    """
    # Use a fixed-order tuple rather than a set: set iteration order is
    # arbitrary, which would make the result nondeterministic if a string
    # ever contained more than one key.
    keys = ("Alpha", "Delta", "BA.1", "BA.2", "BA.4", "BA.5", "BQ.1", "XBB")
    for k in keys:
        if k in scorpio:
            return k
    if "Omicron" in scorpio:
        return "Other_Omicron"
    return "Other"


def merge_scorpio_columns(df):
    """
    Collapse per-scorpio "*_count" columns into one column per major VoC.

    Each column whose name ends in "_count" is summed into the major
    variant-of-concern bucket given by ``scorpio_to_major_voc``. The
    "date" column is passed through unchanged; all other columns are
    dropped from the result.
    """
    merged = {}
    suffix = "_count"
    for name, series in df.items():
        if name == "date":
            merged[name] = series
            continue
        if not name.endswith(suffix):
            continue
        bucket = scorpio_to_major_voc(name[: -len(suffix)])
        if bucket in merged:
            merged[bucket] += series
        else:
            merged[bucket] = series.copy()
    return pd.DataFrame(merged)


def compute_fractions(df):
    """
    Return a copy of ``df`` with each value divided by its row-wise total.

    The input frame is not modified; NaNs in any column propagate into
    the row total, matching elementwise Series addition semantics.
    """
    # Elementwise sum across all columns (a Series of row totals).
    row_totals = sum(col for _, col in df.items())
    out = df.copy()
    for name, col in df.items():
        out[name] = col / row_totals
    return out


@dataclasses.dataclass
class LineageDetails:
"""
Expand Down Expand Up @@ -862,17 +897,32 @@ def site_summary(self, position):
def samples_summary(self):
    """
    Return a per-day DataFrame summarising sample processing.

    Each row covers one day (working backwards from time zero) with
    columns for samples in the ARG, total processed samples, mean HMM
    cost, exact matches, and per-scorpio "{scorpio}_count" /
    "{scorpio}_hmm_cost" columns. Missing values are filled with 0.
    """
    top_md = self.ts.metadata["sc2ts"]
    processed_md = top_md["samples_processed"]
    exact_by_date = top_md["exact_matches"]["date"]
    rows = []
    num_days = self.num_samples_per_day.shape[0]
    for days_ago in range(num_days):
        day = self.time_zero_as_date - days_ago
        day_key = str(day)
        day_md = processed_md.get(day_key, {})
        counts = day_md.get("count", {})
        costs = day_md.get("mean_hmm_cost", {})
        per_scorpio = {}
        processed_total = 0
        cost_total = 0
        for scorpio, n in counts.items():
            per_scorpio[f"{scorpio}_count"] = n
            per_scorpio[f"{scorpio}_hmm_cost"] = costs[scorpio]
            processed_total += n
            cost_total += n * costs[scorpio]
        rows.append(
            {
                "date": day,
                "samples_in_arg": self.num_samples_per_day[days_ago],
                "samples_processed": processed_total,
                # Guard against division by zero on days with no samples.
                "mean_hmm_cost": cost_total / max(1, processed_total),
                "exact_matches": exact_by_date.get(day_key, 0),
                **per_scorpio,
            }
        )
    return pd.DataFrame(rows).fillna(0)

def recombinants_summary(self):
data = []
Expand Down Expand Up @@ -1505,29 +1555,60 @@ def plot_deletion_overlaps(self, annotate_threshold=0.9):
def plot_samples_per_day(self, start_date="2020-04-01"):
    """
    Plot per-day sample statistics from ``start_date`` onwards.

    Produces four stacked panels: absolute counts, fractions of
    processed samples, fraction excluded with mean HMM cost on a twin
    axis, and the stacked fractional composition of major VoCs with
    first-sample annotations. Returns (fig, [ax1, ax2, ax3, ax4]).
    """
    df = self.samples_summary()
    df = df[df.date >= start_date]

    fig, (ax1, ax2, ax3, ax4) = self._wide_plot(4, height=12, sharex=True)
    exact_col = "tab:red"
    in_col = "tab:purple"

    # Panel 1: absolute counts per day.
    ax1.plot(df.date, df.samples_in_arg, label="In ARG", color=in_col)
    ax1.plot(df.date, df.samples_processed, label="Processed")
    ax1.plot(df.date, df.exact_matches, label="Exact matches", color=exact_col)

    # Panel 2: fractions of processed samples.
    ax2.plot(
        df.date,
        df.samples_in_arg / df.samples_processed,
        label="Fraction processed in ARG",
        color=in_col,
    )
    ax2.plot(
        df.date,
        df.exact_matches / df.samples_processed,
        label="Fraction processed exact matches",
        color=exact_col,
    )

    # Panel 3: excluded fraction, with mean HMM cost on a secondary
    # y-axis so the two scales don't clash.
    excluded = df.samples_processed - df.exact_matches - df.samples_in_arg
    ax3.plot(df.date, excluded / df.samples_processed, label="Fraction excluded")
    ax3_2 = ax3.twinx()
    ax3_2.plot(df.date, df.mean_hmm_cost, color="tab:orange")
    ax3.set_xlabel("Date")
    ax3_2.set_ylabel("Mean HMM cost")
    ax1.set_ylabel("Number of samples")
    ax1.legend()
    ax2.legend()
    ax3.legend()

    # Panel 4: stacked fractional composition of the major VoCs.
    df_major_voc = merge_scorpio_columns(df).set_index("date")
    df_voc = compute_fractions(df_major_voc)
    ax4.stackplot(
        df_voc.index,
        *[df_voc[voc] for voc in df_voc],
        labels=[" ".join(k.split("_")) for k in df_voc],
        alpha=0.7,
    )
    ax4.legend()

    for lin in major_lineages:
        if lin.date < self.date and lin.who_label not in ["Beta", "Gamma"]:
            x_first = None
            if len(self.pango_lineage_samples[lin.pango_lineage]) > 0:
                first_sample_date = self.nodes_metadata[
                    self.pango_lineage_samples[lin.pango_lineage][0]
                ]["date"]
                x_first = np.array([first_sample_date], dtype="datetime64[D]")[0]
            # Only annotate lineages with at least one sample: passing
            # x_first=None to annotate/axvline would raise.
            if x_first is not None:
                ax4.annotate(
                    f"first\n{lin.pango_lineage}", xy=(x_first, 0), xycoords="data"
                )
                ax4.axvline(x_first, color="grey", alpha=0.5)

    return fig, [ax1, ax2, ax3, ax4]

def plot_resources(self, start_date="2020-04-01"):
ts = self.ts
Expand All @@ -1538,6 +1619,7 @@ def plot_resources(self, start_date="2020-04-01"):
# Should be able to do this with join, but I failed
df["samples_in_arg"] = dfs.loc[df.index]["samples_in_arg"]
df["samples_processed"] = dfs.loc[df.index]["samples_processed"]
df["mean_hmm_cost"] = dfs.loc[df.index]["mean_hmm_cost"]

df = df[df.index >= start_date]
df["cpu_time"] = df.user_time + df.sys_time
Expand Down Expand Up @@ -1565,14 +1647,20 @@ def plot_resources(self, start_date="2020-04-01"):
ax_twin.legend()
ax[1].plot(x, df.elapsed_time / df.samples_processed)
ax[1].set_ylabel("Elapsed time per sample (s)")
ax_twin = ax[1].twinx()
ax_twin.plot(
x, df.mean_hmm_cost, color="tab:purple", alpha=0.5, label="HMM cost"
)
ax_twin.set_ylabel("HMM cost")
ax[2].plot(x, df.max_memory / 1024**3)
ax[2].set_ylabel("Max memory (GiB)")
return fig, ax

def resources_summary(self):
ts = self.ts
data = []
dates = sorted(list(ts.metadata["sc2ts"]["num_samples_processed"].keys()))
samples_processed = ts.metadata["sc2ts"]["samples_processed"]
dates = sorted(list(samples_processed.keys()))
assert len(dates) == ts.num_provenances - 1
for j in range(1, ts.num_provenances):
p = ts.provenance(j)
Expand Down Expand Up @@ -1707,14 +1795,13 @@ def draw_subtree(
time = tables.mutations.time
time[:] = tskit.UNKNOWN_TIME
tables.mutations.time = time
#rescale to negative times
# rescale to negative times
if date_format == "from_zero":
for node in reversed(self.ts.nodes(order="timeasc")):
if node.is_sample():
break
tables.nodes.time = tables.nodes.time - node.time
ts = tables.tree_sequence()


tracked_nodes = []
if tracked_pango is not None:
Expand Down Expand Up @@ -1802,9 +1889,7 @@ def draw_subtree(
inherited_state = parent.derived_state
parent_inherited_state = site.ancestral_state
if parent.parent >= 0:
parent_inherited_state = ts.mutation(
parent.parent
).derived_state
parent_inherited_state = ts.mutation(parent.parent).derived_state
if parent_inherited_state == mut.derived_state:
reverted_mutations.append(mut.id)
pos = int(site.position)
Expand All @@ -1814,8 +1899,9 @@ def draw_subtree(
# If more than one mutation has the same label, add a prefix with the counts
if append_mutation_recurrence:
num_recurrent = {
m_id: (i+1, len(ids))
for ids in recurrent_mutations.values() for i, m_id in enumerate(ids)
m_id: (i + 1, len(ids))
for ids in recurrent_mutations.values()
for i, m_id in enumerate(ids)
if len(ids) > 1
}
for m_id, (i, n) in num_recurrent.items():
Expand Down Expand Up @@ -1862,8 +1948,9 @@ def draw_subtree(
}
else:
# only place ticks at the sample nodes
y_ticks = {t: ts.node(u).metadata.get("date", "")
for u, t in zip(shown_nodes, shown_times)
y_ticks = {
t: ts.node(u).metadata.get("date", "")
for u, t in zip(shown_nodes, shown_times)
}
return tree.draw_svg(
time_scale=time_scale,
Expand Down
13 changes: 13 additions & 0 deletions tests/test_info.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import inspect

import pytest
Expand Down Expand Up @@ -218,6 +219,18 @@ def test_resources_summary(self, fx_ti_2020_02_13):
assert df.shape[0] == 20
assert np.all(df.date.str.startswith("2020"))

def test_samples_summary(self, fx_ti_2020_02_13):
    # Sanity-check the per-day processing summary produced from the
    # 2020-02-13 test tree-sequence fixture.
    df = fx_ti_2020_02_13.samples_summary()
    # NOTE: just doing this subsetting here to get rid of the annoying reference
    # as-sample issue, which mucks up counts. Should be able to get rid of this
    # after closing https://github.com/jeromekelleher/sc2ts/issues/413
    df = df[df["date"] >= datetime.datetime.fromisoformat("2020-01-01")]
    # Every processed sample is either inserted into the ARG, an exact
    # match, or rejected, so processed must bound the sum of the first two.
    assert np.all(
        df["samples_processed"] >= (df["samples_in_arg"] + df["exact_matches"])
    )
    assert df.shape[0] > 0
    # NOTE(review): "._count" is presumably the per-scorpio count column for
    # samples whose scorpio value is "." in this fixture (so every processed
    # sample falls in that one bucket) — confirm against the fixture metadata.
    assert np.all(df["samples_processed"] == df["._count"])


class TestSampleGroupInfo:
def test_draw_svg(self, fx_ti_2020_02_13):
Expand Down
Loading