Skip to content

Commit

Permalink
build: replace pandas with polars (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
korikuzma committed Oct 10, 2023
1 parent cbd712f commit b616b92
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 47 deletions.
2 changes: 1 addition & 1 deletion Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ aiofiles = "*"
asyncpg = "*"
boto3 = "*"
pyliftover = "*"
pandas = "*"
polars = "*"
hgvs = "*"
"biocommons.seqrepo" = "*"
pydantic = "*"
Expand Down
60 changes: 39 additions & 21 deletions cool_seq_tool/mappers/mane_transcript.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import math
from typing import Dict, List, Optional, Set, Tuple, Union

import pandas as pd
import polars as pl

from cool_seq_tool.handlers.seqrepo_access import SeqRepoAccess
from cool_seq_tool.schemas import (
Expand Down Expand Up @@ -457,35 +457,51 @@ def _validate_index(
else:
return False

def _get_prioritized_transcripts_from_gene(
self, df: pd.core.frame.DataFrame
) -> List:
def _get_prioritized_transcripts_from_gene(self, df: pl.DataFrame) -> List:
"""Sort and filter transcripts from gene to get priority list
:param pd.core.frame.DataFrame df: Data frame containing transcripts from gene
:param df: Data frame containing transcripts from gene
data
:return: List of prioritized transcripts for a given gene. Sort by latest
assembly, longest length of transcript, with first-published transcripts
breaking ties. If there are multiple transcripts for a given accession, the
most recent version of a transcript associated with an assembly will be kept
"""
copy_df = df.copy(deep=True)
copy_df = copy_df.drop(columns="alt_ac").drop_duplicates()
copy_df["ac_no_version_as_int"] = copy_df["tx_ac"].apply(
lambda x: int(x.split(".")[0].split("NM_")[1])
copy_df = df.clone()
copy_df = copy_df.drop(columns="alt_ac").unique()
copy_df = copy_df.with_columns(
[
pl.col("tx_ac")
.str.split(".")
.list.get(0)
.str.split("NM_")
.list.get(1)
.cast(pl.Int64)
.alias("ac_no_version_as_int"),
pl.col("tx_ac")
.str.split(".")
.list.get(1)
.cast(pl.Int16)
.alias("ac_version"),
]
)
copy_df["ac_version"] = copy_df["tx_ac"].apply(lambda x: x.split(".")[1])
copy_df = copy_df.sort_values(
["ac_no_version_as_int", "ac_version"], ascending=[False, False]
copy_df = copy_df.sort(
by=["ac_no_version_as_int", "ac_version"], descending=[True, True]
)
copy_df = copy_df.drop_duplicates(["ac_no_version_as_int"], keep="first")
copy_df.loc[:, "len_of_tx"] = copy_df.loc[:, "tx_ac"].apply(
lambda ac: len(self.seqrepo_access.get_reference_sequence(ac)[0])
copy_df = copy_df.unique(["ac_no_version_as_int"], keep="first")

copy_df = copy_df.with_columns(
copy_df.map_rows(
lambda x: len(self.seqrepo_access.get_reference_sequence(x[1])[0])
)
.to_series()
.alias("len_of_tx")
)
copy_df = copy_df.sort_values(
["len_of_tx", "ac_no_version_as_int"], ascending=[False, True]

copy_df = copy_df.sort(
by=["len_of_tx", "ac_no_version_as_int"], descending=[True, False]
)
return list(copy_df["tx_ac"])
return copy_df.select("tx_ac").to_series().to_list()

async def get_longest_compatible_transcript(
self,
Expand Down Expand Up @@ -537,7 +553,7 @@ async def get_longest_compatible_transcript(
df = await self.uta_db.get_transcripts_from_gene(
gene, start_pos, end_pos, use_tx_pos=False, alt_ac=alt_ac
)
if df.empty:
if df.is_empty():
logger.warning(f"Unable to get transcripts from gene {gene}")
return None

Expand All @@ -551,8 +567,10 @@ async def get_longest_compatible_transcript(

for tx_ac in prioritized_tx_acs:
# Only need to check the one row since we do liftover in _c_to_g
tmp_df = df.loc[df["tx_ac"] == tx_ac].sort_values("alt_ac", ascending=False)
row = tmp_df.iloc[0]
tmp_df = df.filter(pl.col("tx_ac") == tx_ac).sort(
by="alt_ac", descending=True
)
row = tmp_df[0].to_dicts()[0]

if alt_ac is None:
alt_ac = row["alt_ac"]
Expand Down
34 changes: 17 additions & 17 deletions cool_seq_tool/sources/mane_transcript_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd
import polars as pl

from cool_seq_tool.paths import MANE_SUMMARY_PATH

Expand All @@ -20,19 +20,19 @@ def __init__(self, mane_data_path: Path = MANE_SUMMARY_PATH) -> None:
self.mane_data_path = mane_data_path
self.df = self._load_mane_transcript_data()

def _load_mane_transcript_data(self) -> pd.core.frame.DataFrame:
def _load_mane_transcript_data(self) -> pl.DataFrame:
"""Load RefSeq MANE data file into DataFrame.
:return: DataFrame containing RefSeq MANE Transcript data
"""
return pd.read_csv(self.mane_data_path, delimiter="\t")
return pl.read_csv(self.mane_data_path, separator="\t")

def get_gene_mane_data(self, gene_symbol: str) -> Optional[List[Dict]]:
"""Return MANE Transcript data for a gene.
:param str gene_symbol: HGNC Gene Symbol
:return: MANE Transcript data (Transcript accessions,
gene, and location information)
"""
data = self.df.loc[self.df["symbol"] == gene_symbol.upper()]
data = self.df.filter(pl.col("symbol") == gene_symbol.upper())

if len(data) == 0:
logger.warning(
Expand All @@ -41,20 +41,19 @@ def get_gene_mane_data(self, gene_symbol: str) -> Optional[List[Dict]]:
return None

# Ordering: MANE Plus Clinical (If it exists), MANE Select
data = data.sort_values("MANE_status")
return data.to_dict("records")
data = data.sort(by="MANE_status", descending=False)
return data.to_dicts()

def get_mane_from_transcripts(self, transcripts: List[str]) -> List[Dict]:
"""Get mane transcripts from a list of transcripts
:param List[str] transcripts: RefSeq transcripts on c. coordinate
:return: MANE data
"""
mane_rows = self.df["RefSeq_nuc"].isin(transcripts)
result = self.df[mane_rows]
if len(result) == 0:
mane_rows = self.df.filter(pl.col("RefSeq_nuc").is_in(transcripts))
if len(mane_rows) == 0:
return []
return result.to_dict("records")
return mane_rows.to_dicts()

def get_mane_data_from_chr_pos(
self, alt_ac: str, start: int, end: int
Expand All @@ -66,12 +65,13 @@ def get_mane_data_from_chr_pos(
:return: List of MANE data. Will return sorted list:
MANE Select then MANE Plus Clinical.
"""
mane_rows = self.df[
(start >= self.df["chr_start"].astype(int))
& (end <= self.df["chr_end"].astype(int))
& (self.df["GRCh38_chr"] == alt_ac)
]
mane_rows = self.df.filter(
(start >= pl.col("chr_start"))
& (end <= pl.col("chr_end"))
& (pl.col("GRCh38_chr") == alt_ac)
)
if len(mane_rows) == 0:
return []
mane_rows = mane_rows.sort_values("MANE_status", ascending=False)
return mane_rows.to_dict("records")

mane_rows = mane_rows.sort(by="MANE_status", descending=True)
return mane_rows.to_dicts()
13 changes: 8 additions & 5 deletions cool_seq_tool/sources/uta_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import asyncpg
import boto3
import pandas as pd
import polars as pl
from asyncpg.exceptions import InterfaceError, InvalidAuthorizationSpecificationError
from botocore.exceptions import ClientError
from pyliftover import LiftOver
Expand Down Expand Up @@ -870,7 +870,7 @@ async def get_transcripts_from_gene(
end_pos: int,
use_tx_pos: bool = True,
alt_ac: Optional[str] = None,
) -> pd.core.frame.DataFrame:
) -> pl.DataFrame:
"""Get transcripts associated to a gene.
:param str gene: Gene symbol
Expand Down Expand Up @@ -923,9 +923,12 @@ async def get_transcripts_from_gene(
{order_by_cond}
"""
results = await self.execute_query(query)
return pd.DataFrame(
results, columns=["pro_ac", "tx_ac", "alt_ac", "cds_start_i"]
).drop_duplicates()
results = [
(r["pro_ac"], r["tx_ac"], r["alt_ac"], r["cds_start_i"]) for r in results
]
return pl.DataFrame(
results, schema=["pro_ac", "tx_ac", "alt_ac", "cds_start_i"]
).unique()

async def get_chr_assembly(self, ac: str) -> Optional[Tuple[str, str]]:
"""Get chromosome and assembly for NC accession if not in GRCh38.
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ install_requires =
aiofiles
boto3
pyliftover
pandas
polars
hgvs
biocommons.seqrepo
pydantic
Expand Down
4 changes: 2 additions & 2 deletions tests/mappers/test_mane_transcript.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Module for testing MANE Transcript class."""
import copy

import pandas as pd
import polars as pl
import pytest
from mock import patch

Expand Down Expand Up @@ -469,7 +469,7 @@ def get_reference_sequence(ac):
["NM_001378472.1", 1, "NC_000007.14"],
["NM_001374258.2", 1, "NC_000007.14"],
]
test_df = pd.DataFrame(data, columns=["tx_ac", "len_of_tx", "alt_ac"])
test_df = pl.DataFrame(data, schema=["tx_ac", "len_of_tx", "alt_ac"])

resp = test_mane_transcript._get_prioritized_transcripts_from_gene(test_df)
assert resp == ["NM_004333.6", "NM_001374258.2", "NM_001378472.1"]
Expand Down

0 comments on commit b616b92

Please sign in to comment.