Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
korikuzma committed Aug 20, 2024
1 parent 465e81d commit 2a9c24f
Showing 1 changed file with 89 additions and 44 deletions.
133 changes: 89 additions & 44 deletions src/cool_seq_tool/mappers/exon_genomic_coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,57 @@
_logger = logging.getLogger(__name__)


def _check_errors(
values: dict,
required_fields: list[str],
either_or_fields: list[tuple[str, str]] | None = None,
) -> dict:
"""Ensure that required fields are set if `errors` field is empty
:param values: Values in model
:param required_fields: List of field names that are required if there are no errors
:param either_or_fields: List of field names where at least one field is required if
there are no errors
:raises ValueError: If required or either/or fields are not provided when there are
no errors
:return: Values in model
"""
if not values.get("errors"):
if not all(values.get(required_field) for required_field in required_fields):
err_msg = f"{required_fields} must all be provided"
raise ValueError(err_msg)

if either_or_fields:
for field1, field2 in either_or_fields:
if not (values.get(field1) or values.get(field2)):
err_msg = f"At least one of {field1} or {field2} must be provided"
raise ValueError(err_msg)

return values


class _Grch38Data(BaseModelForbidExtra):
    """Model representing GRCh38 accession and position, with errors"""

    # Both coordinate fields default to None so that an error-only instance can be
    # built; the validator below requires them whenever `errors` is empty.
    accession: StrictStr | None = Field(
        default=None, description="GRCh38 genomic RefSeq accession."
    )
    position: StrictInt | None = Field(
        default=None, description="GRCh38 genomic position on `genomic_ac`."
    )
    errors: list[StrictStr] = Field(default=[], description="Error messages.")

    @model_validator(mode="before")
    def check_errors(cls, values: dict) -> dict:  # noqa: N805
        """Require `accession` and `position` unless errors were reported

        :param values: Values in model
        :raises ValueError: If `accession` and `position` are not provided when
            there are no errors
        :return: Values in model
        """
        required = ["accession", "position"]
        return _check_errors(values, required_fields=required)


class ExonCoord(BaseModelForbidExtra):
"""Model for representing exon coordinate data"""

Expand Down Expand Up @@ -101,17 +152,9 @@ def check_errors(cls, values: dict) -> dict: # noqa: N805
provided when there are no errors
:return: Values in model
"""
if not values.get("errors") and not all(
(
values.get("seg"),
values.get("gene"),
values.get("genomic_ac"),
values.get("tx_ac"),
)
):
err_msg = "`seg`, `gene`, `genomic_ac` and `tx_ac` must be provided"
raise ValueError(err_msg)
return values
return _check_errors(
values, required_fields=["seg", "gene", "genomic_ac", "tx_ac"]
)

model_config = ConfigDict(
json_schema_extra={
Expand Down Expand Up @@ -159,18 +202,11 @@ def add_meta_check_errors(cls, values: dict) -> dict: # noqa: N805
:return: Values in model, including service metadata
"""
values["service_meta"] = service_meta()
if not values.get("errors") and not all(
(
values.get("gene"),
values.get("genomic_ac"),
values.get("tx_ac"),
values.get("seg_start") or values.get("seg_end"),
)
):
err_msg = "`gene`, `genomic_ac`, `tx_ac` and `seg_start` or `seg_end` must be provided"
raise ValueError(err_msg)

return values
return _check_errors(
values,
required_fields=["gene", "genomic_ac", "tx_ac"],
either_or_fields=[("seg_start", "seg_end")],
)

model_config = ConfigDict(
json_schema_extra={
Expand Down Expand Up @@ -720,12 +756,11 @@ async def _genomic_to_tx_segment(
)
genomic_ac = genomic_acs[0]

# We should always try to liftover
genomic_ac, genomic_pos, err_msg = await self._get_grch38_ac_pos(
genomic_ac, genomic_pos
)
if err_msg:
return _GenomicTxSeg(errors=[err_msg])
# Always liftover to GRCh38
grch38_data = await self._get_grch38_ac_pos(genomic_ac, genomic_pos)
if grch38_data.errors:
return _GenomicTxSeg(errors=grch38_data.errors)
genomic_ac, genomic_pos = grch38_data.accession, grch38_data.position

if not transcript:
# Select a transcript if not provided
Expand Down Expand Up @@ -874,20 +909,24 @@ async def _genomic_to_tx_segment(

async def _get_grch38_ac_pos(
self, genomic_ac: str, genomic_pos: int, grch38_ac: str | None = None
) -> tuple[str | None, int | None, str | None]:
) -> _Grch38Data:
"""Get GRCh38 genomic representation for accession and position
:param genomic_ac: RefSeq genomic accession
:param genomic_pos: Genomic position on ``genomic_ac``
:param grch38_ac: GRCh38 genomic accession for ``genomic_ac``. If not provided,
will get associated GRCh38 accession.
:return: Tuple containing GRCh38 accession, GRCh38 position, and errors if
unable to get GRCh38 representation
:return: GRCh38 accession, GRCh38 position, and errors if unable to get GRCh38
representation
"""
if not grch38_ac:
grch38_ac = await self.uta_db.get_newest_assembly_ac(genomic_ac)
if not grch38_ac:
return None, None, f"Invalid genomic accession: {genomic_ac}"
return _Grch38Data(
accession=None,
position=None,
errors=[f"Invalid genomic accession: {genomic_ac}"],
)

grch38_ac = grch38_ac[0]

Expand All @@ -897,23 +936,29 @@ async def _get_grch38_ac_pos(
genomic_ac, Assembly.GRCH37.value
)
if not chromosome:
return None, None, "`genomic_ac` must use GRCh37 or GRCh38"
return _Grch38Data(
accession=None,
position=None,
errors=["`genomic_ac` must use GRCh37 or GRCh38"],
)

chromosome = chromosome[-1].split(":")[-1]
liftover_data = self.liftover.get_liftover(
chromosome, genomic_pos, Assembly.GRCH38
)
if liftover_data is None:
return (
None,
None,
f"Position {genomic_pos} does not exist on chromosome {chromosome}",
return _Grch38Data(
accession=None,
position=None,
errors=[
f"Position {genomic_pos} does not exist on chromosome {chromosome}"
],
)

genomic_pos = liftover_data[1]
genomic_ac = grch38_ac

return genomic_ac, genomic_pos, None
return _Grch38Data(accession=genomic_ac, position=genomic_pos)

def _get_tx_segment(
self,
Expand Down Expand Up @@ -1033,11 +1078,11 @@ async def _get_tx_seg_genomic_metadata(
tx_ac = mane_data["RefSeq_nuc"]
grch38_ac = mane_data["GRCh38_chr"]

genomic_ac, genomic_pos, err_msg = await self._get_grch38_ac_pos(
genomic_ac, genomic_pos, grch38_ac=grch38_ac
)
if err_msg:
return _GenomicTxSeg(errors=[err_msg])
# Always liftover to GRCh38
grch38_data = await self._get_grch38_ac_pos(genomic_ac, genomic_pos)
if grch38_data.errors:
return _GenomicTxSeg(errors=grch38_data.errors)
genomic_ac, genomic_pos = grch38_data.accession, grch38_data.position

tx_exons = await self._get_all_exon_coords(tx_ac, genomic_ac=grch38_ac)
if not tx_exons:
Expand Down

0 comments on commit 2a9c24f

Please sign in to comment.