Skip to content

Commit

Permalink
Implement basic copy function
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromekelleher committed Nov 26, 2024
1 parent d5c4f52 commit 9d73f6b
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 20 deletions.
4 changes: 1 addition & 3 deletions sc2ts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,11 @@ def import_metadata(dataset, metadata, viridian, verbose):
df_in = pd.read_csv(metadata, sep="\t")
# TODO do we need to do this?
# , dtype={"Artic_primer_version": str})
date_field = "date"
index_field = "Run"
if viridian:
df_in = sc2ts.massage_viridian_metadata(df_in)
date_field = "Collection_date"
df = df_in.set_index(index_field)
sc2ts.Dataset.add_metadata(dataset, df, date_field=date_field)
sc2ts.Dataset.add_metadata(dataset, df)


@click.command()
Expand Down
61 changes: 48 additions & 13 deletions sc2ts/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,18 @@ def __len__(self):


class CachedMetadataMapping(collections.abc.Mapping):
def __init__(self, root, sample_id_map):
def __init__(self, root, sample_id_map, date_field):
# NOTE: this is definitely wasteful. We shouldn't load the sample_id field
# from zarr more than once, and reading in all the metadata arrays in
# one go unconditionally is also too slow.
self.sample_id_map = sample_id_map
self.sample_date = root["sample_date"][:].astype(str)
self.sample_date = root[f"sample_{date_field}"][:].astype(str)
self.sample_id = root["sample_id"][:].astype(str)
self.arrays = {}
prefix = "sample_"
# We might need to do this on a chunk-aware basis
for k, v in root.items():
if k.startswith(prefix) and k not in ("sample_id", "sample_date"):
if k.startswith(prefix) and k != "sample_id":
name = k[len(prefix) :]
logger.debug(f"Decompressing metadata {name}")
self.arrays[name] = v[:]
Expand Down Expand Up @@ -140,6 +140,11 @@ def __len__(self):
def samples_for_date(self, date):
    """Return the IDs of all samples whose sample_date equals ``date``.

    ``date`` is compared as a string against the decoded sample_date
    array, and the matching entries of sample_id are returned.
    """
    selected = self.sample_date == date
    return self.sample_id[selected]

def as_dataframe(self):
    """Return all cached metadata as a pandas DataFrame indexed by sample_id.

    The columns are the decompressed metadata arrays (one per ``sample_*``
    zarr array, with the prefix stripped), in the order they were loaded.
    """
    columns = {"sample_id": self.sample_id}
    columns.update(self.arrays)
    frame = pd.DataFrame(columns)
    return frame.set_index("sample_id")


@dataclasses.dataclass
class Variant:
Expand All @@ -150,7 +155,7 @@ class Variant:

class Dataset:

def __init__(self, path, chunk_cache_size=1):
def __init__(self, path, chunk_cache_size=1, date_field="date"):
self.path = pathlib.Path(path)
if self.path.suffix == ".zip":
self.store = zarr.ZipStore(path)
Expand All @@ -166,7 +171,7 @@ def __init__(self, path, chunk_cache_size=1):
self.alignments = CachedAlignmentMapping(
self.root, self.sample_id_map, chunk_cache_size
)
self.metadata = CachedMetadataMapping(self.root, self.sample_id_map)
self.metadata = CachedMetadataMapping(self.root, self.sample_id_map, date_field)

@property
def num_samples(self):
Expand Down Expand Up @@ -206,8 +211,40 @@ def variants(self, sample_id, position):
)
j += 1

def copy(
    self, path, samples_chunk_size=None, variants_chunk_size=None, sample_id=None
):
    """
    Copy this dataset to the specified path.

    :param path: destination directory for the new zarr dataset.
    :param samples_chunk_size: samples chunking for the destination;
        ``None`` means the ``Dataset.new`` default (10_000).
    :param variants_chunk_size: variants chunking for the destination;
        ``None`` means the ``Dataset.new`` default (100).
    :param sample_id: if specified, only include these samples, in the
        specified order; by default all samples are copied.
    """
    if sample_id is None:
        sample_id = self.root["sample_id"][:]
    Dataset.new(
        path,
        samples_chunk_size=samples_chunk_size,
        variants_chunk_size=variants_chunk_size,
    )
    # Mirror Dataset.new's default so the buffer is actually flushed in
    # chunks. With the previous `len(...) == None` comparison the flush
    # condition was never True, and every alignment in the dataset was
    # buffered in memory before a single giant append.
    flush_size = 10_000 if samples_chunk_size is None else samples_chunk_size
    alignments = {}
    for s in sample_id:
        alignments[s] = self.alignments[s]
        if len(alignments) == flush_size:
            Dataset.append_alignments(path, alignments)
            alignments = {}
    # Flush any remainder; append_alignments is a no-op on an empty dict.
    Dataset.append_alignments(path, alignments)

    df = self.metadata.as_dataframe()
    Dataset.add_metadata(path, df)

@staticmethod
def new(path, samples_chunk_size=10_000, variants_chunk_size=100):
def new(path, samples_chunk_size=None, variants_chunk_size=None):

if samples_chunk_size is None:
samples_chunk_size = 10_000
if variants_chunk_size is None:
variants_chunk_size = 100

logger.info(f"Creating new dataset at {path}")
L = core.REFERENCE_SEQUENCE_LENGTH - 1
N = 0 # Samples must be added
Expand Down Expand Up @@ -293,6 +330,8 @@ def append_alignments(path, alignments):
Append alignments to the store. If this method fails then the store
should be considered corrupt.
"""
if len(alignments) == 0:
return
store = zarr.DirectoryStore(path)
root = zarr.open(store, mode="a")

Expand All @@ -316,13 +355,10 @@ def append_alignments(path, alignments):
zarr.consolidate_metadata(store)

@staticmethod
def add_metadata(path, df, date_field):
def add_metadata(path, df):
"""
Add metadata from the specified dataframe, indexed by sample ID.
Each column will be added as a new array with prefix "sample_"
A "sample_date" field will be added as a copy of the given
date_field.
"""
store = zarr.DirectoryStore(path)
root = zarr.open(store, mode="a")
Expand All @@ -332,11 +368,10 @@ def add_metadata(path, df, date_field):
if samples.shape[0] == 0:
raise ValueError("Cannot add metadata to empty dataset")
df = df.loc[samples].copy()
df["date"] = df[date_field]
for colname in df:
data = df[colname].to_numpy()
dtype = data.dtype
if dtype == int:
if dtype.kind == "i":
max_v = data.max()
if max_v < 127:
dtype = "i1"
Expand Down Expand Up @@ -380,5 +415,5 @@ def tmp_dataset(path, alignments, date="2020-01-01"):
Dataset.new(path)
Dataset.append_alignments(path, alignments)
df = pd.DataFrame({"strain": alignments.keys(), "date": [date] * len(alignments)})
Dataset.add_metadata(path, df.set_index("strain"), "date")
Dataset.add_metadata(path, df.set_index("strain"))
return Dataset(path)
2 changes: 1 addition & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def test_defaults(self, fx_dataset):
catch_exceptions=False,
)
assert result.exit_code == 0
assert "with 55 samples and 25 metadata fields" in result.stdout
assert "with 55 samples and 26 metadata fields" in result.stdout

def test_zarr(self, fx_dataset):
runner = ct.CliRunner(mix_stderr=False)
Expand Down
21 changes: 18 additions & 3 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def test_add_metadata(self, tmp_path, fx_encoded_alignments, fx_metadata_df):
path = tmp_path / "dataset.vcz"
ds = sc2ts.Dataset.new(path)
sc2ts.Dataset.append_alignments(path, fx_encoded_alignments)
sc2ts.Dataset.add_metadata(path, fx_metadata_df, "date")
sc2ts.Dataset.add_metadata(path, fx_metadata_df)

sg_ds = sgkit.load_dataset(path)
assert dict(sg_ds.sizes) == {
Expand All @@ -147,7 +147,7 @@ def test_create_zip(self, tmp_path, fx_encoded_alignments, fx_metadata_df):
path = tmp_path / "dataset.vcz"
sc2ts.Dataset.new(path)
sc2ts.Dataset.append_alignments(path, fx_encoded_alignments)
sc2ts.Dataset.add_metadata(path, fx_metadata_df, "date")
sc2ts.Dataset.add_metadata(path, fx_metadata_df)
zip_path = tmp_path / "dataset.vcz.zip"
sc2ts.Dataset.create_zip(path, zip_path)

Expand All @@ -159,6 +159,13 @@ def test_create_zip(self, tmp_path, fx_encoded_alignments, fx_metadata_df):
for k in alignments1.keys():
nt.assert_array_equal(alignments1[k], alignments2[k])

def test_copy(self, tmp_path, fx_dataset):
    """A straight copy preserves the sample set and its metadata index."""
    path = tmp_path / "dataset.vcz"
    fx_dataset.copy(path)
    ds = sc2ts.Dataset(path)
    # The original version only printed the dataset (asserting nothing);
    # check the copy actually round-trips samples and metadata.
    assert ds.num_samples == fx_dataset.num_samples
    nt.assert_array_equal(ds.metadata.sample_id, fx_dataset.metadata.sample_id)
    nt.assert_array_equal(ds.metadata.sample_date, fx_dataset.metadata.sample_date)


class TestDatasetAlignments:

Expand Down Expand Up @@ -205,7 +212,7 @@ def test_chunk_size_cache_size(
path = tmp_path / "dataset.vcz"
sc2ts.Dataset.new(path, samples_chunk_size=chunk_size)
sc2ts.Dataset.append_alignments(path, fx_encoded_alignments)
sc2ts.Dataset.add_metadata(path, fx_metadata_df, "date")
sc2ts.Dataset.add_metadata(path, fx_metadata_df)
ds = sc2ts.Dataset(path, chunk_cache_size=cache_size)
for k, v in fx_encoded_alignments.items():
nt.assert_array_equal(v, ds.alignments[k])
Expand Down Expand Up @@ -238,6 +245,14 @@ def test_samples_for_date(self, fx_dataset):
samples = fx_dataset.metadata.samples_for_date("2020-01-19")
assert samples == ["SRR11772659"]

def test_as_dataframe(self, fx_dataset, fx_metadata_df):
    """as_dataframe round-trips the metadata stored from fx_metadata_df."""
    df1 = fx_dataset.metadata.as_dataframe()
    df2 = fx_metadata_df.loc[df1.index]
    assert df1.shape[0] == df2.shape[0]
    for col, expected in df2.items():
        # BUG FIX: the original compared df2[col] against itself
        # (data2 was taken from df2, not df1), so the test could
        # never detect a metadata mismatch.
        actual = df1[col]
        nt.assert_array_equal(actual.to_numpy(), expected.to_numpy())


class TestEncodeAlignment:
@pytest.mark.parametrize(
Expand Down

0 comments on commit 9d73f6b

Please sign in to comment.