diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 8aff261..89ba2be 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -80,7 +80,7 @@ Individual classes will accept arguments upon initialization to set parameters r * - ``SEQREPO_ROOT_DIR`` - Path to SeqRepo directory (i.e. contains ``aliases.sqlite3`` database file, and ``sequences`` directory). Used by :py:class:`SeqRepoAccess `. If not defined, defaults to ``/usr/local/share/seqrepo/latest``. * - ``UTA_DB_URL`` - - A `libpq connection string `_, i.e. of the form ``postgresql://:@://``, used by the :py:class:`UtaDatabase ` class. By default, it is set to ``postgresql://anonymous@localhost:5432/uta/uta_20241220``. + - A `libpq connection URI `_, i.e. of the form ``postgresql://:@:/?options=search_path%3D,public``, used by the :py:class:`UtaDatabase ` class. By default, it is set to ``postgresql://anonymous@localhost:5432/uta?options=-csearch_path%3Duta_20241220,public``. * - ``LIFTOVER_CHAIN_37_TO_38`` - A path to a `chainfile `_ for lifting from GRCh37 to GRCh38. Used by the :py:class:`LiftOver ` class as input to `agct `_. If not provided, agct will fetch it automatically from UCSC. * - ``LIFTOVER_CHAIN_38_TO_37`` diff --git a/pyproject.toml b/pyproject.toml index c407919..d8093c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ description = "Common Operation on Lots of Sequences Tool" license = "MIT" license-files = ["LICENSE"] dependencies = [ - "asyncpg", + "psycopg[pool]", "boto3", "agct >= 0.2.0rc1", "polars ~= 1.0", @@ -37,6 +37,9 @@ dependencies = [ dynamic = ["version"] [project.optional-dependencies] +pg_binary = [ + "psycopg[binary, pool]", +] dev = [ "prek>=0.2.23", "ipython", diff --git a/src/cool_seq_tool/app.py b/src/cool_seq_tool/app.py index 6b1b456..34a33bb 100644 --- a/src/cool_seq_tool/app.py +++ b/src/cool_seq_tool/app.py @@ -2,10 +2,10 @@ data handler and mapping resources for straightforward access. """ -import logging from pathlib import Path from biocommons.seqrepo import SeqRepo +from psycopg_pool import AsyncConnectionPool from cool_seq_tool.handlers.seqrepo_access import SEQREPO_ROOT_DIR, SeqRepoAccess from cool_seq_tool.mappers import ( @@ -16,9 +16,7 @@ ) from cool_seq_tool.sources.mane_transcript_mappings import ManeTranscriptMappings from cool_seq_tool.sources.transcript_mappings import TranscriptMappings -from cool_seq_tool.sources.uta_database import UTA_DB_URL, UtaDatabase - -_logger = logging.getLogger(__name__) +from cool_seq_tool.sources.uta_database import LazyUtaDatabase, UtaDatabase class CoolSeqTool: @@ -40,7 +38,7 @@ def __init__( transcript_file_path: Path | None = None, lrg_refseqgene_path: Path | None = None, mane_data_path: Path | None = None, - db_url: str = UTA_DB_URL, + uta_connection_pool: AsyncConnectionPool | None = None, sr: SeqRepo | None = None, force_local_files: bool = False, ) -> None: @@ -75,8 +73,10 @@ def __init__( :param transcript_file_path: The path to ``transcript_mapping.tsv`` :param lrg_refseqgene_path: The path to the LRG_RefSeqGene file :param mane_data_path: Path to RefSeq MANE summary data - :param db_url: PostgreSQL connection URL - Format: ``driver://user:password@host/database/schema`` + :param uta_connection_pool: pyscopg connection pool to UTA instance. If not + provided, a lazy UTA connection will be used, meaning the connection won't + be initiated until the first attempted UTA query, and will use environment + configs/library defaults :param sr: SeqRepo instance. If this is not provided, will create a new instance :param force_local_files: if ``True``, don't check for or try to acquire latest versions of static data files -- just use most recently available, if any @@ -92,7 +92,10 @@ def __init__( self.mane_transcript_mappings = ManeTranscriptMappings( mane_data_path=mane_data_path, from_local=force_local_files ) - self.uta_db = UtaDatabase(db_url=db_url) + if uta_connection_pool: + self.uta_db = UtaDatabase(uta_connection_pool) + else: + self.uta_db = LazyUtaDatabase() self.alignment_mapper = AlignmentMapper( self.seqrepo_access, self.transcript_mappings, self.uta_db ) diff --git a/src/cool_seq_tool/mappers/alignment.py b/src/cool_seq_tool/mappers/alignment.py index 70c7432..81ce4ff 100644 --- a/src/cool_seq_tool/mappers/alignment.py +++ b/src/cool_seq_tool/mappers/alignment.py @@ -48,7 +48,8 @@ async def p_to_c( * Warning, if unable to translate to cDNA representation. Else ``None`` """ # Get cDNA accession - temp_c_ac = await self.uta_db.p_to_c_ac(p_ac) + async with self.uta_db.repository() as uta: + temp_c_ac = await uta.p_to_c_ac(p_ac) if temp_c_ac: c_ac = temp_c_ac[-1] else: @@ -89,7 +90,8 @@ async def _get_cds_start(self, c_ac: str) -> tuple[int | None, str | None]: - CDS start site if found. Else ``None`` - Warning, if unable to get CDS start. Else ``None`` """ - cds_start_end = await self.uta_db.get_cds_start_end(c_ac) + async with self.uta_db.repository() as uta: + cds_start_end = await uta.get_cds_start_end(c_ac) if not cds_start_end: cds_start = None warning = f"Accession {c_ac} not found in UTA db" @@ -149,12 +151,13 @@ async def c_to_g( c_start_pos -= 1 # Get aligned genomic and transcript data - genomic_tx_data = await self.uta_db.get_genomic_tx_data( - c_ac, - (c_start_pos + cds_start, c_end_pos + cds_start), - AnnotationLayer.CDNA, - target_genome_assembly=target_genome_assembly, - ) + async with self.uta_db.repository() as uta: + genomic_tx_data = await uta.get_genomic_tx_data( + c_ac, + (c_start_pos + cds_start, c_end_pos + cds_start), + AnnotationLayer.CDNA, + target_genome_assembly=target_genome_assembly, + ) if not genomic_tx_data: warning = ( diff --git a/src/cool_seq_tool/mappers/exon_genomic_coords.py b/src/cool_seq_tool/mappers/exon_genomic_coords.py index e00a8bf..efc0d2d 100644 --- a/src/cool_seq_tool/mappers/exon_genomic_coords.py +++ b/src/cool_seq_tool/mappers/exon_genomic_coords.py @@ -19,7 +19,11 @@ TranscriptPriority, ) from cool_seq_tool.sources.mane_transcript_mappings import ManeTranscriptMappings -from cool_seq_tool.sources.uta_database import GenomicAlnData, UtaDatabase +from cool_seq_tool.sources.uta_database import ( + GenomicAlnData, + NoMatchingAlignmentError, + UtaDatabase, +) from cool_seq_tool.utils import service_meta _logger = logging.getLogger(__name__) @@ -637,29 +641,43 @@ async def _get_all_exon_coords( The list will be ordered by ascending exon number. """ if genomic_ac: - query = f""" + query = """ SELECT DISTINCT ord, tx_start_i, tx_end_i, alt_start_i, alt_end_i, alt_strand - FROM {self.uta_db.schema}.tx_exon_aln_mv - WHERE tx_ac = '{tx_ac}' + FROM tx_exon_aln_mv + WHERE tx_ac = %(tx_ac)s AND alt_aln_method = 'splign' - AND alt_ac = '{genomic_ac}' - ORDER BY ord ASC - """ # noqa: S608 + AND alt_ac = %(genomic_ac)s + ORDER BY ord ASC; + """ else: - query = f""" + query = """ SELECT DISTINCT ord, tx_start_i, tx_end_i, alt_start_i, alt_end_i, alt_strand - FROM {self.uta_db.schema}.tx_exon_aln_mv as t - INNER JOIN {self.uta_db.schema}._seq_anno_most_recent as s + FROM tx_exon_aln_mv as t + INNER JOIN _seq_anno_most_recent as s ON t.alt_ac = s.ac WHERE s.descr = '' - AND t.tx_ac = '{tx_ac}' + AND t.tx_ac = %(tx_ac)s AND t.alt_aln_method = 'splign' - AND t.alt_ac like 'NC_000%' - ORDER BY ord ASC - """ # noqa: S608 + AND t.alt_ac like 'NC_000%%' + ORDER BY ord ASC; + """ - results = await self.uta_db.execute_query(query) - return [_ExonCoord(**r) for r in results] + async with self.uta_db.repository() as uta: + cursor = await uta.execute_query( + query, {"tx_ac": tx_ac, "genomic_ac": genomic_ac} + ) + results = await cursor.fetchall() + return [ + _ExonCoord( + ord=r[0], + tx_start_i=r[1], + tx_end_i=r[2], + alt_start_i=r[3], + alt_end_i=r[4], + alt_strand=r[5], + ) + for r in results + ] async def _get_genomic_aln_coords( self, @@ -690,13 +708,14 @@ async def _get_genomic_aln_coords( aligned_coords = {"start": None, "end": None} for exon, key in [(tx_exon_start, "start"), (tx_exon_end, "end")]: if exon: - aligned_coord, warning = await self.uta_db.get_alt_ac_start_or_end( - tx_ac, exon.tx_start_i, exon.tx_end_i, gene=gene - ) - if aligned_coord: + async with self.uta_db.repository() as uta: + try: + aligned_coord = await uta.get_alt_ac_start_or_end( + tx_ac, exon.tx_start_i, exon.tx_end_i, gene=gene + ) + except NoMatchingAlignmentError as e: + return None, None, str(e) aligned_coords[key] = aligned_coord - else: - return None, None, warning return *aligned_coords.values(), None @@ -827,19 +846,22 @@ async def _genomic_to_tx_segment( # Validate inputs exist in UTA if gene: - gene_validation = await self.uta_db.gene_exists(gene) + async with self.uta_db.repository() as uta: + gene_validation = await uta.gene_exists(gene) if not gene_validation: return GenomicTxSeg(errors=[f"Gene does not exist in UTA: {gene}"]) if transcript: - transcript_validation = await self.uta_db.transcript_exists(transcript) + async with self.uta_db.repository() as uta: + transcript_validation = await uta.transcript_exists(transcript) if not transcript_validation: return GenomicTxSeg( errors=[f"Transcript does not exist in UTA: {transcript}"] ) if genomic_ac: - grch38_ac = await self.uta_db.get_newest_assembly_ac(genomic_ac) + async with self.uta_db.repository() as uta: + grch38_ac = await uta.get_newest_assembly_ac(genomic_ac) if grch38_ac: genomic_ac = grch38_ac[0] else: @@ -888,16 +910,20 @@ async def _genomic_to_tx_segment( transcript = results.refseq else: # Run if gene is for a noncoding transcript - query = f""" + query = """ SELECT DISTINCT tx_ac - FROM {self.uta_db.schema}.tx_exon_aln_mv - WHERE hgnc = '{gene}' - AND alt_ac = '{genomic_ac}' - """ # noqa: S608 - result = await self.uta_db.execute_query(query) + FROM tx_exon_aln_mv + WHERE hgnc = %(gene)s + AND alt_ac = %(genomic_ac)s; + """ + async with self.uta_db.repository() as uta: + cursor = await uta.execute_query( + query, {"gene": gene, "genomic_ac": genomic_ac} + ) + result = await cursor.fetchone() if result: - transcript = result[0]["tx_ac"] + transcript = result[0] else: return GenomicTxSeg( errors=[ @@ -955,13 +981,14 @@ async def _genomic_to_tx_segment( ) else: is_exonic = True - exon_data = await self.uta_db.get_tx_exon_aln_data( - transcript, - genomic_pos, - genomic_pos, - alt_ac=genomic_ac, - use_tx_pos=False, - ) + async with self.uta_db.repository() as uta: + exon_data = await uta.get_tx_exon_aln_data( + transcript, + genomic_pos, + genomic_pos, + alt_ac=genomic_ac, + use_tx_pos=False, + ) exon_num = exon_data[0].ord offset = self._get_exon_offset( @@ -1030,20 +1057,24 @@ async def _validate_genomic_breakpoint( for the transcript, ``False`` if not. Breakpoints past this threshold are likely erroneous. """ - query = f""" + query = """ WITH tx_boundaries AS ( SELECT MIN(alt_start_i) AS min_start, MAX(alt_end_i) AS max_end - FROM {self.uta_db.schema}.tx_exon_aln_mv - WHERE tx_ac = '{tx_ac}' - AND alt_ac = '{genomic_ac}' + FROM tx_exon_aln_mv + WHERE tx_ac = %(tx_ac)s + AND alt_ac = %(genomic_ac)s ) SELECT * FROM tx_boundaries - WHERE {pos} between (tx_boundaries.min_start - 150) and (tx_boundaries.max_end + 150) - """ # noqa: S608 - results = await self.uta_db.execute_query(query) - return bool(results) + WHERE %(pos)s between (tx_boundaries.min_start - 150) and (tx_boundaries.max_end + 150); + """ + async with self.uta_db.repository() as uta: + cursor = await uta.execute_query( + query, {"tx_ac": tx_ac, "genomic_ac": genomic_ac, "pos": pos} + ) + result = await cursor.fetchone() + return bool(result) async def _get_tx_ac_gene( self, @@ -1058,18 +1089,20 @@ async def _get_tx_ac_gene( :return: HGNC gene symbol associated to transcript and warning """ - query = f""" + query = """ SELECT DISTINCT hgnc - FROM {self.uta_db.schema}.tx_exon_aln_mv - WHERE tx_ac = '{tx_ac}' + FROM tx_exon_aln_mv + WHERE tx_ac = %(tx_ac)s ORDER BY hgnc LIMIT 1; - """ # noqa: S608 - results = await self.uta_db.execute_query(query) - if not results: + """ + async with self.uta_db.repository() as uta: + cursor = await uta.execute_query(query, {"tx_ac": tx_ac}) + result = await cursor.fetchone() + if not result: return None, f"No gene(s) found given {tx_ac}" - return results[0]["hgnc"], None + return result[0], None @staticmethod def _is_exonic_breakpoint(pos: int, tx_genomic_coords: list[_ExonCoord]) -> bool: diff --git a/src/cool_seq_tool/mappers/mane_transcript.py b/src/cool_seq_tool/mappers/mane_transcript.py index dc1a10c..5f0c999 100644 --- a/src/cool_seq_tool/mappers/mane_transcript.py +++ b/src/cool_seq_tool/mappers/mane_transcript.py @@ -164,7 +164,8 @@ async def _p_to_c( :return: [cDNA transcript accession, [cDNA pos start, cDNA pos end]] """ # TODO: Check version mappings 1 to 1 relationship - temp_ac = await self.uta_db.p_to_c_ac(ac) + async with self.uta_db.repository() as uta: + temp_ac = await uta.p_to_c_ac(ac) if temp_ac: ac = temp_ac[-1] else: @@ -209,7 +210,8 @@ async def _c_to_g(self, ac: str, pos: tuple[int, int]) -> GenomicTxMetadata | No temp_ac = ac # c. coordinate does not contain cds start, so we need to add it - cds_start_end = await self.uta_db.get_cds_start_end(temp_ac) + async with self.uta_db.repository() as uta: + cds_start_end = await uta.get_cds_start_end(temp_ac) if not cds_start_end: _logger.warning("Accession %s not found in UTA", temp_ac) return None @@ -226,18 +228,22 @@ async def _liftover_to_38(self, genomic_tx_data: GenomicTxMetadata) -> None: :param genomic_tx_data: Metadata for genomic and transcript accessions. This will be mutated in-place if not GRCh38 assembly. """ - descr = await self.uta_db.get_chr_assembly(genomic_tx_data.alt_ac) + async with self.uta_db.repository() as uta: + descr = await uta.get_chr_assembly(genomic_tx_data.alt_ac) if descr is None: # already grch38 return chromosome, _ = descr - query = f""" - SELECT DISTINCT alt_ac - FROM {self.uta_db.schema}.tx_exon_aln_mv - WHERE tx_ac = '{genomic_tx_data.tx_ac}'; - """ # noqa: S608 - nc_acs = await self.uta_db.execute_query(query) + async with self.uta_db.repository() as uta: + cursor = await uta.execute_query( + """SELECT DISTINCT alt_ac + FROM tx_exon_aln_mv + WHERE tx_ac = %(tx_ac)s; + """, + {"tx_ac": genomic_tx_data.tx_ac}, + ) + nc_acs = await cursor.fetchall() nc_acs = [nc_ac[0] for nc_ac in nc_acs] if nc_acs == [genomic_tx_data.alt_ac]: _logger.warning( @@ -267,11 +273,15 @@ async def _liftover_to_38(self, genomic_tx_data: GenomicTxMetadata) -> None: """ query = f""" SELECT alt_ac - FROM {self.uta_db.schema}.genomic - WHERE alt_ac LIKE '{genomic_tx_data.alt_ac.split(".")[0]}%' + FROM genomic + WHERE alt_ac LIKE %(ac_pattern)s {order_by_cond} """ # noqa: S608 - nc_acs = await self.uta_db.execute_query(query) + async with self.uta_db.repository() as uta: + cursor = await uta.execute_query( + query, {"ac_pattern": f"{genomic_tx_data.alt_ac.split('.')[0]}%"} + ) + nc_acs = await cursor.fetchall() genomic_tx_data.alt_ac = nc_acs[0][0] def _set_liftover( @@ -331,9 +341,10 @@ async def _get_and_validate_genomic_tx_data( :return: Metadata for genomic and transcript accessions if found and validated, else None """ - genomic_tx_data = await self.uta_db.get_genomic_tx_data( - tx_ac, pos, annotation_layer, alt_ac=alt_ac - ) + async with self.uta_db.repository() as uta: + genomic_tx_data = await uta.get_genomic_tx_data( + tx_ac, pos, annotation_layer, alt_ac=alt_ac + ) if not genomic_tx_data: _logger.warning( "Unable to find genomic_tx_data for %s at position %s on annotation layer %s", @@ -470,13 +481,14 @@ async def _g_to_c( tx_g_pos = g.alt_pos_range tx_pos_range = g.tx_pos_range else: - result = await self.uta_db.get_tx_exon_aln_data( - refseq_c_ac, - g.alt_pos_change_range[0], - g.alt_pos_change_range[1], - alt_ac=alt_ac if alt_ac else g.alt_ac, - use_tx_pos=False, - ) + async with self.uta_db.repository() as uta: + result = await uta.get_tx_exon_aln_data( + refseq_c_ac, + g.alt_pos_change_range[0], + g.alt_pos_change_range[1], + alt_ac=alt_ac or g.alt_ac, + use_tx_pos=False, + ) if not result: _logger.warning( @@ -487,7 +499,8 @@ async def _g_to_c( tx_g_pos = result.alt_start_i, result.alt_end_i tx_pos_range = result.tx_start_i, result.tx_end_i - cds_start_end = await self.uta_db.get_cds_start_end(refseq_c_ac) + async with self.uta_db.repository() as uta: + cds_start_end = await uta.get_cds_start_end(refseq_c_ac) if not cds_start_end: return None coding_start_site = cds_start_end[0] @@ -790,13 +803,15 @@ def _get_protein_rep( # Data Frame that contains transcripts associated to a gene if is_p_or_c_start_anno: - df = await self.uta_db.get_transcripts( - c_start_pos, c_end_pos, gene=gene, use_tx_pos=True, alt_ac=alt_ac - ) + async with self.uta_db.repository() as uta: + df = await uta.get_transcripts( + c_start_pos, c_end_pos, gene=gene, use_tx_pos=True, alt_ac=alt_ac + ) else: - df = await self.uta_db.get_transcripts( - start_pos, end_pos, gene=gene, use_tx_pos=False, alt_ac=alt_ac - ) + async with self.uta_db.repository() as uta: + df = await uta.get_transcripts( + start_pos, end_pos, gene=gene, use_tx_pos=False, alt_ac=alt_ac + ) if df.is_empty(): _logger.warning("Unable to get transcripts from gene %s", gene) @@ -1151,7 +1166,8 @@ async def g_to_grch38( start_pos, end_pos = get_inter_residue_pos(start_pos, end_pos, coordinate_type) # Checking to see what chromosome and assembly we're on - descr = await self.uta_db.get_chr_assembly(ac) + async with self.uta_db.repository() as uta: + descr = await uta.get_chr_assembly(ac) if not descr: # Already GRCh38 assembly if self.validate_index(ac, (start_pos, end_pos), 0): @@ -1190,7 +1206,8 @@ async def g_to_grch38( else: end_pos = start_pos - newest_ac = await self.uta_db.get_newest_assembly_ac(ac) + async with self.uta_db.repository() as uta: + newest_ac = await uta.get_newest_assembly_ac(ac) if newest_ac: ac = newest_ac[0] if self.validate_index(ac, (start_pos, end_pos), 0): @@ -1262,7 +1279,9 @@ async def g_to_mane_c( start_pos, end_pos = get_inter_residue_pos(start_pos, end_pos, coordinate_type) coordinate_type = CoordinateType.INTER_RESIDUE - if not await self.uta_db.validate_genomic_ac(ac): + async with self.uta_db.repository() as uta: + validation_result = uta.validate_genomic_ac(ac) + if not validation_result: _logger.warning("Genomic accession does not exist: %s", ac) return None @@ -1284,15 +1303,17 @@ async def g_to_mane_c( mane_tx_genomic_data = None if grch38: # GRCh38 -> MANE C - mane_tx_genomic_data = await self.uta_db.get_mane_c_genomic_data( - mane_c_ac, grch38.ac, grch38.pos[0], grch38.pos[1] - ) + async with self.uta_db.repository() as uta: + mane_tx_genomic_data = await uta.get_mane_c_genomic_data( + mane_c_ac, grch38.ac, grch38.pos[0], grch38.pos[1] + ) if not grch38 or not mane_tx_genomic_data: # GRCh38 did not work, so let's try original assembly (37) - mane_tx_genomic_data = await self.uta_db.get_mane_c_genomic_data( - mane_c_ac, ac, start_pos, end_pos - ) + async with self.uta_db.repository() as uta: + mane_tx_genomic_data = await uta.get_mane_c_genomic_data( + mane_c_ac, ac, start_pos, end_pos + ) if not mane_tx_genomic_data: continue _logger.info("Not using most recent assembly") @@ -1378,9 +1399,13 @@ async def grch38_to_mane_c_p( mane_transcripts |= {mane_c_ac, current_mane_data["Ensembl_nuc"]} # GRCh38 -> MANE C - mane_tx_genomic_data = await self.uta_db.get_mane_c_genomic_data( - ac=mane_c_ac, alt_ac=mane_alt_ac, start_pos=start_pos, end_pos=end_pos - ) + async with self.uta_db.repository() as uta: + mane_tx_genomic_data = await uta.get_mane_c_genomic_data( + ac=mane_c_ac, + alt_ac=mane_alt_ac, + start_pos=start_pos, + end_pos=end_pos, + ) if not mane_tx_genomic_data: continue diff --git a/src/cool_seq_tool/sources/__init__.py b/src/cool_seq_tool/sources/__init__.py index e48a429..288be12 100644 --- a/src/cool_seq_tool/sources/__init__.py +++ b/src/cool_seq_tool/sources/__init__.py @@ -1,7 +1,17 @@ """Module for providing basic acquisition/setup for the various resources""" -from .mane_transcript_mappings import ManeTranscriptMappings -from .transcript_mappings import TranscriptMappings -from .uta_database import UtaDatabase +from cool_seq_tool.sources.mane_transcript_mappings import ManeTranscriptMappings +from cool_seq_tool.sources.transcript_mappings import TranscriptMappings +from cool_seq_tool.sources.uta_database import ( + UtaDatabase, + UtaRepository, + create_uta_connection_pool, +) -__all__ = ["ManeTranscriptMappings", "TranscriptMappings", "UtaDatabase"] +__all__ = [ + "ManeTranscriptMappings", + "TranscriptMappings", + "UtaDatabase", + "UtaRepository", + "create_uta_connection_pool", +] diff --git a/src/cool_seq_tool/sources/uta_database.py b/src/cool_seq_tool/sources/uta_database.py index 5f7ede7..ed5ee6a 100644 --- a/src/cool_seq_tool/sources/uta_database.py +++ b/src/cool_seq_tool/sources/uta_database.py @@ -1,17 +1,36 @@ -"""Provide transcript lookup and metadata tools via the UTA database.""" +"""Provide transcript lookup and metadata tools via the UTA database. + +In an asyncio runtime: + + >>> from cool_seq_tool.sources.uta_database import ( + ... create_uta_connection_pool, + ... UtaDatabase, + ... ) + >>> pool = await create_uta_connection_pool() + >>> uta_db = UtaDatabase(pool) + >>> async with uta_db.repository() as uta: + ... braf_exists = await uta.gene_exists("BRAF") + >>> braf_exists + True + +""" import ast import logging -from os import environ -from typing import Any, Literal, TypeVar +import os +import warnings +from collections.abc import AsyncIterator, Mapping, Sequence +from contextlib import asynccontextmanager +from typing import Literal from urllib.parse import ParseResult as UrlLibParseResult -from urllib.parse import unquote, urlparse, urlunparse +from urllib.parse import urlparse, urlunparse -import asyncpg import boto3 import polars as pl -from asyncpg.exceptions import InterfaceError, InvalidAuthorizationSpecificationError from botocore.exceptions import ClientError +from psycopg import AsyncConnection, AsyncCursor +from psycopg.errors import UndefinedTable +from psycopg_pool import AsyncConnectionPool from pydantic import Field, StrictInt, StrictStr from cool_seq_tool.schemas import ( @@ -23,26 +42,9 @@ Strand, ) -# use `bound` to upper-bound UtaDatabase or child classes -UTADatabaseType = TypeVar("UTADatabaseType", bound="UtaDatabase") - -UTA_DB_URL = environ.get( - "UTA_DB_URL", "postgresql://anonymous@localhost:5432/uta/uta_20241220" -) - _logger = logging.getLogger(__name__) -class DbConnectionArgs(BaseModelForbidExtra): - """Represent database connection arguments""" - - host: str - port: int - user: str - password: str - database: str - - class GenomicAlnData(BaseModelForbidExtra): """Represent genomic alignment data from UTA tx_exon_aln_mv view""" @@ -79,251 +81,185 @@ class TxExonAlnData(GenomicAlnData): alt_exon_id: StrictInt = Field(..., description="`alt_ac` exon identifier.") -class UtaDatabase: - """Provide transcript lookup and metadata tools via the Universal Transcript Archive - (UTA) database. +class NoMatchingAlignmentError(Exception): + """Raise for failure to find alignment matching user parameters""" - Users should use the ``create()`` method to construct a new instance. Note that - almost all public methods are defined as ``async`` -- see the :ref:`Usage section ` - for more information. - >>> import asyncio - >>> from cool_seq_tool.sources.uta_database import UtaDatabase - >>> uta_db = asyncio.run(UtaDatabase.create()) - """ +class UtaRepository: + """Connection-scoped repository for issuing queries against UTA. - def __init__(self, db_url: str = UTA_DB_URL) -> None: - """Initialize DB class. Should only be used by ``create()`` method, and not - be called directly by a user. + This class encapsulates predefined UTA queries and related result parsing. + It operates on an active psycopg async connection provided at initialization + time and does not manage connection lifecycle or pooling. - :param db_url: PostgreSQL connection URL - Format: ``driver://user:password@host/database/schema`` - """ - self.schema = None - self._connection_pool = None - self.db_url = db_url - self.args = self._get_conn_args() + Instances are intended to be short-lived and used within the scope of a + checked-out connection (e.g., from a connection pool). + """ - def _get_conn_args(self) -> DbConnectionArgs: - """Return connection arguments. + def __init__(self, conn: AsyncConnection) -> None: + """Initialize the repository with an active database connection. - :param db_url: raw connection URL - :return: Database connection arguments + :param conn: Active psycopg async connection to a UTA database. + The caller is responsible for connection lifecycle management. """ - if "UTA_DB_PROD" in environ: - secret = ast.literal_eval(self.get_secret()) - - password = secret["password"] - username = secret["username"] - port = secret["port"] - host = secret["host"] - database = secret["dbname"] - schema = secret["schema"] - self.schema = schema - - environ["PGPASSWORD"] = password - environ["UTA_DB_URL"] = ( - f"postgresql://{username}@{host}:{port}/{database}/{schema}" - ) - return DbConnectionArgs( - host=host, - port=int(port), - database=database, - user=username, - password=password, - ) - - url = ParseResult(urlparse(self.db_url)) - self.schema = url.schema - password = unquote(url.password) if url.password else "" - return DbConnectionArgs( - host=url.hostname, - port=url.port, - database=url.database, - user=url.username, - password=password, - ) - - async def create_pool(self) -> None: - """Create connection pool if not already created.""" - if not self._connection_pool: - self.args = self._get_conn_args() - try: - self._connection_pool = await asyncpg.create_pool( - min_size=1, - max_size=10, - max_inactive_connection_lifetime=3, - command_timeout=60, - host=self.args.host, - port=self.args.port, - user=self.args.user, - password=self.args.password, - database=self.args.database, - ) - except InterfaceError as e: - _logger.exception( - "While creating connection pool, encountered exception" - ) - msg = "Could not create connection pool" - raise Exception(msg) from e # noqa: TRY002 - - @classmethod - async def create( - cls: type[UTADatabaseType], db_url: str = UTA_DB_URL - ) -> UTADatabaseType: - """Manufacture a fully-initialized class instance (a la factory pattern). This - method should be used instead of calling the class directly to create a new - instance. - - >>> import asyncio - >>> from cool_seq_tool.sources.uta_database import UtaDatabase - >>> uta_db = asyncio.run(UtaDatabase.create()) - - :param cls: supplied implicitly - :param db_url: PostgreSQL connection URL - Format: ``driver://user:password@host/database/schema`` - :return: UTA DB access class instance + self._conn = conn + + async def execute_query( + self, q: str, params: Sequence | Mapping | None = None + ) -> AsyncCursor: + """Execute an arbitrary query against the UTA DB + + This method is marked as public so that downstream applications can run custom + queries using the same DB connection. However, that means they are responsible + for managing the cursor themselves. + + :param q: raw query. May need to specify schema depending on connection context. + :param params: query variables, if needed. These should not be hard-coded into the query. + :return: query result cursor + :raise UndefinedTable: if queried table isn't in the search_path -- this likely + indicates a UTA schema/search path config issue """ - self = cls(db_url) - await self._create_genomic_table() - await self.create_pool() - return self - - async def execute_query(self, query: str) -> Any: # noqa: ANN401 - """Execute a query and return its result. - - :param query: Query to make on database - :return: Query's result - """ - - async def _execute_query(q: str) -> Any: # noqa: ANN401 - async with ( - self._connection_pool.acquire() as connection, - connection.transaction(), - ): - return await connection.fetch(q) - - if not self._connection_pool: - await self.create_pool() try: - return await _execute_query(query) - except InvalidAuthorizationSpecificationError: - self._connection_pool = None - await self.create_pool() - return await _execute_query(query) - - async def _create_genomic_table(self) -> None: - """Create table containing genomic accession information.""" - check_table_exists = f""" - SELECT EXISTS ( - SELECT FROM information_schema.tables - WHERE table_schema = '{self.schema}' - AND table_name = 'genomic' - ); - """ # noqa: S608 - genomic_table_exists = await self.execute_query(check_table_exists) - genomic_table_exists = genomic_table_exists[0].get("exists") - if genomic_table_exists is None: - _logger.critical( - "SELECT EXISTS query in UtaDatabase._create_genomic_table " - "returned invalid response" + return await self._conn.execute(q, params) + except UndefinedTable: + search_path = await ( + await self._conn.execute("SHOW search_path;") + ).fetchone() + _logger.exception( + "Query '%s' raised UndefinedTable error. Is search_path set to include correct UTA schema? Current search_path value is '%s'", + q, + search_path, ) - msg = "SELECT EXISTS query returned invalid response" - raise ValueError(msg) - if not genomic_table_exists: - create_genomic_table = f""" - CREATE TABLE {self.schema}.genomic AS - SELECT t.hgnc, aes.alt_ac, aes.alt_aln_method, - aes.alt_strand, ae.start_i AS alt_start_i, - ae.end_i AS alt_end_i - FROM ((((({self.schema}.transcript t - JOIN {self.schema}.exon_set tes ON (((t.ac = tes.tx_ac) - AND (tes.alt_aln_method = 'transcript'::text)))) - JOIN {self.schema}.exon_set aes ON (((t.ac = aes.tx_ac) - AND (aes.alt_aln_method <> 'transcript'::text)))) - JOIN {self.schema}.exon te ON - ((tes.exon_set_id = te.exon_set_id))) - JOIN {self.schema}.exon ae ON - (((aes.exon_set_id = ae.exon_set_id) - AND (te.ord = ae.ord)))) - LEFT JOIN {self.schema}.exon_aln ea ON - (((te.exon_id = ea.tx_exon_id) AND - (ae.exon_id = ea.alt_exon_id)))); - """ # noqa: S608 - await self.execute_query(create_genomic_table) - - indexes = [ - f"""CREATE INDEX alt_pos_index ON {self.schema}.genomic (alt_ac, alt_start_i, alt_end_i);""", - f"""CREATE INDEX gene_alt_index ON {self.schema}.genomic (hgnc, alt_ac);""", - f"""CREATE INDEX alt_ac_index ON {self.schema}.genomic (alt_ac);""", - ] - for create_index in indexes: - await self.execute_query(create_index) - - @staticmethod - def _transform_list(li: list) -> list[list[Any]]: - """Transform list to only contain field values + raise - :param li: List of asyncpg.Record objects - :return: List of list of objects + async def create_genomic_table(self) -> None: + """Create the derived ``genomic`` table in the current schema if needed.""" + create_genomic_table = """ + CREATE TABLE IF NOT EXISTS genomic AS + SELECT + t.hgnc, + aes.alt_ac, + aes.alt_aln_method, + aes.alt_strand, + ae.start_i AS alt_start_i, + ae.end_i AS alt_end_i + FROM transcript t + JOIN exon_set tes + ON t.ac = tes.tx_ac + AND tes.alt_aln_method = 'transcript' + JOIN exon_set aes + ON t.ac = aes.tx_ac + AND aes.alt_aln_method <> 'transcript' + JOIN exon te + ON tes.exon_set_id = te.exon_set_id + JOIN exon ae + ON aes.exon_set_id = ae.exon_set_id + AND te.ord = ae.ord + LEFT JOIN exon_aln ea + ON te.exon_id = ea.tx_exon_id + AND ae.exon_id = ea.alt_exon_id; """ - return [list(i) for i in li] + await self.execute_query(create_genomic_table) + + indexes = [ + """ + CREATE INDEX IF NOT EXISTS alt_pos_index + ON genomic (alt_ac, alt_start_i, alt_end_i); + """, + """ + CREATE INDEX IF NOT EXISTS gene_alt_index + ON genomic (hgnc, alt_ac); + """, + """ + CREATE INDEX IF NOT EXISTS alt_ac_index + ON genomic (alt_ac); + """, + ] + for create_index in indexes: + await self.execute_query(create_index) async def get_alt_ac_start_or_end( self, tx_ac: str, tx_exon_start: int, tx_exon_end: int, gene: str | None - ) -> tuple[GenomicAlnData | None, str | None]: + ) -> GenomicAlnData: """Get genomic data for related transcript exon start or end. :param tx_ac: Transcript accession :param tx_exon_start: Transcript's exon start coordinate :param tx_exon_end: Transcript's exon end coordinate - :param gene: HGNC gene symbol - :return: Genomic alignment data and warnings if found + :param gene: HGNC gene symbol, if available + :return: Genomic alignment data if match found + :raise NoMatchingAlignmentError: if unable to find alignment matching given params + """ + query = """ + SELECT + T.hgnc, + T.alt_ac, + T.alt_start_i, + T.alt_end_i, + T.alt_strand, + T.ord + FROM _cds_exons_fp_v AS C + JOIN tx_exon_aln_mv AS T ON T.tx_ac = C.tx_ac + WHERE T.tx_ac = %(tx_ac)s + AND (%(gene)s::text IS NULL OR T.hgnc = %(gene)s::text) + AND %(tx_exon_start)s BETWEEN T.tx_start_i AND T.tx_end_i + AND %(tx_exon_end)s BETWEEN T.tx_start_i AND T.tx_end_i + AND T.alt_aln_method = 'splign' + AND T.alt_ac LIKE 'NC_00%%' + ORDER BY CAST( + SUBSTR( + T.alt_ac, + POSITION('.' IN T.alt_ac) + 1, + LENGTH(T.alt_ac) + ) AS INT + ) DESC; """ - gene_query = f"AND T.hgnc = '{gene}'" if gene else "" - query = f""" - SELECT T.hgnc, T.alt_ac, T.alt_start_i, T.alt_end_i, T.alt_strand, T.ord - FROM {self.schema}._cds_exons_fp_v as C - JOIN {self.schema}.tx_exon_aln_mv as T ON T.tx_ac = C.tx_ac - WHERE T.tx_ac = '{tx_ac}' - {gene_query} - AND {tx_exon_start} BETWEEN T.tx_start_i AND T.tx_end_i - AND {tx_exon_end} BETWEEN T.tx_start_i AND T.tx_end_i - AND T.alt_aln_method = 'splign' - AND T.alt_ac LIKE 'NC_00%' - ORDER BY CAST(SUBSTR(T.alt_ac, position('.' in T.alt_ac) + 1, - LENGTH(T.alt_ac)) AS INT) DESC; - """ # noqa: S608 - result = await self.execute_query(query) - if not result: - msg = ( - f"Unable to find a result where {tx_ac} has transcript " - f"coordinates {tx_exon_start} and {tx_exon_end} between " - f"an exon's start and end coordinates" - ) - if gene_query: - msg += f" on gene {gene}" + params = { + "tx_ac": tx_ac, + "tx_exon_start": tx_exon_start, + "tx_exon_end": tx_exon_end, + "gene": gene, + } + + cur = await self.execute_query(query, params) + row = await cur.fetchone() + + if not row: + msg = f"Unable to find a result where {tx_ac} has transcript coordinates ({tx_exon_start=}, {tx_exon_end=}) between an exon's start and end coordinates on {gene=}" _logger.warning(msg) - return None, msg - return GenomicAlnData(**result[0]), None + raise NoMatchingAlignmentError(msg) + + return GenomicAlnData( + hgnc=row[0], + alt_ac=row[1], + alt_start_i=row[2], + alt_end_i=row[3], + alt_strand=row[4], + ord=row[5], + ) async def get_cds_start_end(self, tx_ac: str) -> tuple[int, int] | None: - """Get coding start and end site + """Return CDS start/end coordinates for a transcript. + + Strips version from Ensembl accessions (``ENS*``) since UTA stores them + unversioned. :param tx_ac: Transcript accession - :return: [Coding start site, Coding end site] + :return: (cds_start_i, cds_end_i) if both exist, else None """ - if tx_ac.startswith("ENS"): - tx_ac = tx_ac.split(".")[0] - query = f""" + # As of 2026-03, Ensembl transcripts in UTA are unversioned, so we need to drop + # the version specifier + tx_ac = tx_ac.split(".", 1)[0] if tx_ac.startswith("ENS") else tx_ac + query = """ SELECT cds_start_i, cds_end_i - FROM {self.schema}.transcript - WHERE ac='{tx_ac}'; - """ # noqa: S608 - cds_start_end = await self.execute_query(query) + FROM transcript + WHERE ac=%(ac)s; + """ + cds_start_end = await ( + await self.execute_query(query, {"ac": tx_ac}) + ).fetchone() if cds_start_end: - cds_start_end = cds_start_end[0] if cds_start_end[0] is not None and cds_start_end[1] is not None: return cds_start_end[0], cds_start_end[1] else: @@ -333,33 +269,40 @@ async def get_cds_start_end(self, tx_ac: str) -> tuple[int, int] | None: return None async def get_newest_assembly_ac(self, ac: str) -> list[str]: - """Find accession associated to latest genomic assembly + """Return newest accession versions matching the given prefix + + If the accession is Ensembl (``EN`` prefix), results are ordered lexicographically. + Otherwise, RefSeq-style accessions are ordered by version number in descending order. - :param ac: Accession - :return: List of accessions associated to latest genomic assembly. Order by - desc + :param ac: Accession (versioned or unversioned) + :return: List of matching accessions, newest version first """ - # Ensembl accessions do not have versions + prefix = ac.split(".", 1)[0] + if ac.startswith("EN"): - order_by_cond = "ORDER BY ac;" + query = """ + SELECT ac + FROM _seq_anno_most_recent + WHERE ac LIKE %(ac_prefix)s + AND (descr IS NULL OR descr = '') + ORDER BY ac; + """ else: - order_by_cond = ( - "ORDER BY SUBSTR(ac, 0, position('.' in ac))," - "CAST(SUBSTR(ac, position('.' in ac) + 1, LENGTH(ac)) AS INT) DESC;" - ) - - query = f""" - SELECT ac - FROM {self.schema}._seq_anno_most_recent - WHERE ac LIKE '{ac.split(".")[0]}%' - AND ((descr IS NULL) OR (descr = '')) - {order_by_cond} - """ # noqa: S608 - results = await self.execute_query(query) - if not results: - return [] + query = """ + SELECT ac + FROM _seq_anno_most_recent + WHERE ac LIKE %(ac_prefix)s + AND (descr IS NULL OR descr = '') + ORDER BY + SUBSTR(ac, 0, POSITION('.' IN ac)), + CAST(SUBSTR(ac, POSITION('.' IN ac) + 1, LENGTH(ac)) AS INT) DESC; + """ - return [r["ac"] for r in results] + params = { + "ac_prefix": f"{prefix}%", + } + results = await (await self.execute_query(query, params)).fetchall() + return [r[0] for r in results] async def validate_genomic_ac(self, ac: str) -> bool: """Return whether or not genomic accession exists. @@ -367,15 +310,16 @@ async def validate_genomic_ac(self, ac: str) -> bool: :param ac: Genomic accession :return: ``True`` if genomic accession exists. ``False`` otherwise. """ - query = f""" + query = """ SELECT EXISTS( SELECT ac - FROM {self.schema}._seq_anno_most_recent - WHERE ac = '{ac}' + FROM _seq_anno_most_recent + WHERE ac = %(ac)s ); - """ # noqa: S608 - result = await self.execute_query(query) - return result[0][0] + """ + cursor = await self.execute_query(query, {"ac": ac}) + result = await cursor.fetchone() + return result[0] async def gene_exists(self, gene: str) -> bool: """Return whether or not a gene symbol exists in UTA gene table @@ -383,15 +327,16 @@ async def gene_exists(self, gene: str) -> bool: :param gene: Gene symbol :return ``True`` if gene symbol exists in UTA, ``False`` if not """ - query = f""" + query = """ SELECT EXISTS( SELECT hgnc - FROM {self.schema}.gene - WHERE hgnc = '{gene}' + FROM gene + WHERE hgnc = %(gene)s ); - """ # noqa: S608 - result = await self.execute_query(query) - return result[0][0] + """ + cursor = await self.execute_query(query, {"gene": gene}) + result = await cursor.fetchone() + return result[0] async def transcript_exists(self, transcript: str) -> bool: """Return whether or not a transcript exists in the UTA ``tx_exon_aln_mv`` table @@ -399,42 +344,41 @@ async def transcript_exists(self, transcript: str) -> bool: :param transcript: A transcript accession :return: ``True`` if transcript exists in UTA, ``False`` if not """ - query = f""" + query = """ SELECT EXISTS( SELECT tx_ac - FROM {self.schema}.tx_exon_aln_mv - WHERE tx_ac = '{transcript}' + FROM tx_exon_aln_mv + WHERE tx_ac = %(tx_ac)s ); - """ # noqa: S608 - result = await self.execute_query(query) - return result[0][0] + """ + cursor = await self.execute_query(query, {"tx_ac": transcript}) + result = await cursor.fetchone() + return result[0] async def get_ac_descr(self, ac: str) -> str | None: - """Return accession description. This is typically available only for accessions - from older (pre-GRCh38) builds. - - >>> import asyncio - >>> from cool_seq_tool.sources.uta_database import UtaDatabase - >>> async def describe(): - ... uta_db = await UtaDatabase.create() - ... result = await uta_db.get_ac_descr("NC_000001.10") - ... return result - >>> asyncio.run(describe()) + """Return free-text accession description + + This is typically available only for accessions from older (pre-GRCh38) builds. + + >>> async with uta.repository() as uta: + ... result = await uta.get_ac_descr("NC_000001.10") + >>> result 'Homo sapiens chromosome 1, GRCh37.p13 Primary Assembly' :param ac: chromosome accession, e.g. ``"NC_000001.10"`` - :return: Description containing assembly and chromosome + :return: Free-text description provided by source, generally containing assembly and chromosome """ - query = f""" + query = """ SELECT descr - FROM {self.schema}._seq_anno_most_recent - WHERE ac = '{ac}'; - """ # noqa: S608 - result = await self.execute_query(query) + FROM _seq_anno_most_recent + WHERE ac = %(ac)s; + """ + cursor = await self.execute_query(query, {"ac": ac}) + result = await cursor.fetchone() if not result: - _logger.warning("Accession %s does not have a description", ac) + _logger.warning("No description entry found for accession %s", ac) return None - result = result[0][0] + result = result[0] if result == "": result = None return result @@ -465,53 +409,73 @@ async def get_tx_exon_aln_data( ``False`` if tx_condition will be exact match :return: List of transcript exon alignment data """ + params: dict = {"start_pos": start_pos, "end_pos": end_pos} if tx_ac.startswith("EN"): - temp_ac = tx_ac.split(".")[0] - aln_method = f"AND alt_aln_method='genebuild'" # noqa: F541 + params["tx_ac"] = tx_ac.split(".")[0] + params["alt_aln_method"] = "genebuild" else: - temp_ac = tx_ac - aln_method = f"AND alt_aln_method='splign'" # noqa: F541 + params["tx_ac"] = tx_ac + params["alt_aln_method"] = "splign" if like_tx_ac: - tx_q = f"WHERE tx_ac LIKE '{temp_ac}%'" + params["tx_ac"] = f"{params['tx_ac']}%" + tx_q = "WHERE tx_ac LIKE %(tx_ac)s" else: - tx_q = f"WHERE tx_ac='{temp_ac}'" + tx_q = "WHERE tx_ac=%(tx_ac)s" order_by_cond = "ORDER BY CAST(SUBSTR(alt_ac, position('.' in alt_ac) + 1, LENGTH(alt_ac)) AS INT)" if alt_ac: - alt_ac_q = f"AND alt_ac = '{alt_ac}'" + alt_ac_q = "AND alt_ac = %(alt_ac)s" + params["alt_ac"] = alt_ac if alt_ac.startswith("EN"): order_by_cond = "ORDER BY alt_ac" else: - alt_ac_q = f"AND alt_ac LIKE 'NC_00%'" # noqa: F541 + alt_ac_q = "AND alt_ac LIKE 'NC_00%%'" if use_tx_pos: - pos_q = f"""tx_start_i AND tx_end_i""" # noqa: F541 + pos_q = """tx_start_i AND tx_end_i""" else: - pos_q = f"""alt_start_i AND alt_end_i""" # noqa: F541 + pos_q = """alt_start_i AND alt_end_i""" query = f""" SELECT hgnc, tx_ac, tx_start_i, tx_end_i, alt_ac, alt_start_i, alt_end_i, alt_strand, alt_aln_method, ord, tx_exon_id, alt_exon_id - FROM {self.schema}.tx_exon_aln_mv + FROM tx_exon_aln_mv {tx_q} {alt_ac_q} - {aln_method} - AND {start_pos} BETWEEN {pos_q} - AND {end_pos} BETWEEN {pos_q} + AND alt_aln_method = %(alt_aln_method)s + AND %(start_pos)s BETWEEN {pos_q} + AND %(end_pos)s BETWEEN {pos_q} {order_by_cond} """ # noqa: S608 - result = await self.execute_query(query) - if not result: + cursor = await self.execute_query(query, params) + results = await cursor.fetchall() + if not results: _logger.warning("Unable to find transcript alignment for query: %s", query) return [] - if alt_ac and not use_tx_pos and len(result) > 1: + if alt_ac and not use_tx_pos and len(results) > 1: _logger.debug( "Found more than one match for tx_ac %s and alt_ac = %s", - temp_ac, + params["tx_ac"], alt_ac, ) - return [TxExonAlnData(**r) for r in result] + return [ + TxExonAlnData( + hgnc=r[0], + tx_ac=r[1], + tx_start_i=r[2], + tx_end_i=r[3], + alt_ac=r[4], + alt_start_i=r[5], + alt_end_i=r[6], + alt_strand=r[7], + alt_aln_method=r[8], + ord=r[9], + tx_exon_id=r[10], + alt_exon_id=r[11], + ) + for r in results + ] @staticmethod def data_from_result(result: TxExonAlnData) -> GenomicTxData | None: @@ -548,19 +512,14 @@ async def get_mane_c_genomic_data( representation. This function parses queried data from the tx_exon_aln_mv table, and sorts the queried data by the most recent genomic build - >>> import asyncio - >>> from cool_seq_tool.sources import UtaDatabase - >>> async def get_braf_mane(): - ... uta_db = await UtaDatabase.create() - ... result = await uta_db.get_mane_c_genomic_data( + >>> async with uta_db.repository() as uta: + ... result = await uta.get_mane_c_genomic_data( ... "NM_004333.6", ... None, ... 140753335, ... 140753335, ... ) - ... return result - >>> braf = asyncio.run(get_braf_mane()) - >>> braf["alt_ac"] + >>> result.alt_ac 'NC_000007.14' :param ac: MANE transcript accession @@ -692,36 +651,28 @@ async def get_ac_from_gene(self, gene: str) -> list[str]: :param gene: Gene symbol :return: List of genomic accessions, sorted in desc order """ - query = f""" + query = """ SELECT DISTINCT alt_ac - FROM {self.schema}.genomic - WHERE hgnc = '{gene}' - AND alt_ac LIKE 'NC_00%' + FROM genomic + WHERE hgnc = %(gene)s + AND alt_ac LIKE 'NC_00%%' ORDER BY alt_ac; - """ # noqa: S608 - - records = await self.execute_query(query) - if not records: - return [] + """ - alt_acs = [r["alt_ac"] for r in records] + cursor = await self.execute_query(query, {"gene": gene}) + results = await cursor.fetchall() + alt_acs = [r[0] for r in results] alt_acs.sort(key=lambda x: int(x.split(".")[-1]), reverse=True) return alt_acs async def get_gene_from_ac( - self, ac: str, start_pos: int, end_pos: int + self, ac: str, start_pos: int, end_pos: int | None ) -> list[str] | None: """Get gene(s) within the provided coordinate range - >>> import asyncio - >>> from cool_seq_tool.sources import UtaDatabase - >>> async def get_gene(): - ... uta_db = await UtaDatabase.create() - ... result = await uta_db.get_gene_from_ac( - ... "NC_000017.11", 43044296, 43045802 - ... ) - ... return result - >>> asyncio.run(get_gene()) + >>> async with uta_db.repository() as uta: + ... result = await uta.get_gene_from_ac("NC_000017.11", 43044296, 43045802) + >>> result ['BRCA1'] :param ac: NC accession, e.g. ``"NC_000001.11"`` @@ -731,14 +682,17 @@ async def get_gene_from_ac( """ if end_pos is None: end_pos = start_pos - query = f""" + query = """ SELECT DISTINCT hgnc - FROM {self.schema}.genomic - WHERE alt_ac = '{ac}' - AND {start_pos} BETWEEN alt_start_i AND alt_end_i - AND {end_pos} BETWEEN alt_start_i AND alt_end_i; - """ # noqa: S608 - results = await self.execute_query(query) + FROM genomic + WHERE alt_ac = %(ac)s + AND %(start_pos)s BETWEEN alt_start_i AND alt_end_i + AND %(end_pos)s BETWEEN alt_start_i AND alt_end_i; + """ + cursor = await self.execute_query( + query, {"ac": ac, "start_pos": start_pos, "end_pos": end_pos} + ) + results = await cursor.fetchall() if not results: _logger.warning( "Unable to find gene between %s and %s on %s", start_pos, end_pos, ac @@ -787,16 +741,16 @@ async def get_transcripts( pos_cond = "" if start_pos is not None and end_pos is not None: if use_tx_pos: - pos_cond = f""" - AND {start_pos} + T.cds_start_i + pos_cond = """ + AND %(start_pos)s + T.cds_start_i BETWEEN ALIGN.tx_start_i AND ALIGN.tx_end_i - AND {end_pos} + T.cds_start_i + AND %(end_pos)s + T.cds_start_i BETWEEN ALIGN.tx_start_i AND ALIGN.tx_end_i """ else: - pos_cond = f""" - AND {start_pos} BETWEEN ALIGN.alt_start_i AND ALIGN.alt_end_i - AND {end_pos} BETWEEN ALIGN.alt_start_i AND ALIGN.alt_end_i + pos_cond = """ + AND %(start_pos)s BETWEEN ALIGN.alt_start_i AND ALIGN.alt_end_i + AND %(end_pos)s BETWEEN ALIGN.alt_start_i AND ALIGN.alt_end_i """ order_by_cond = """ @@ -806,29 +760,36 @@ async def get_transcripts( ALIGN.tx_end_i - ALIGN.tx_start_i DESC; """ if alt_ac: - alt_ac_cond = f"AND ALIGN.alt_ac = '{alt_ac}'" + alt_ac_cond = "AND ALIGN.alt_ac = %(alt_ac)s" if alt_ac.startswith("EN"): order_by_cond = "ORDER BY ALIGN.alt_ac;" else: - alt_ac_cond = "AND ALIGN.alt_ac LIKE 'NC_00%'" + alt_ac_cond = "AND ALIGN.alt_ac LIKE 'NC_00%%'" - gene_cond = f"AND T.hgnc = '{gene}'" if gene else "" + gene_cond = "AND T.hgnc = %(gene)s" if gene else "" query = f""" SELECT AA.pro_ac, AA.tx_ac, ALIGN.alt_ac, T.cds_start_i - FROM {self.schema}.associated_accessions as AA - JOIN {self.schema}.transcript as T ON T.ac = AA.tx_ac - JOIN {self.schema}.tx_exon_aln_mv as ALIGN ON T.ac = ALIGN.tx_ac + FROM associated_accessions as AA + JOIN transcript as T ON T.ac = AA.tx_ac + JOIN tx_exon_aln_mv as ALIGN ON T.ac = ALIGN.tx_ac WHERE ALIGN.alt_aln_method = 'splign' {gene_cond} {alt_ac_cond} {pos_cond} {order_by_cond} """ # noqa: S608 - results = await self.execute_query(query) - results = [ - (r["pro_ac"], r["tx_ac"], r["alt_ac"], r["cds_start_i"]) for r in results - ] + cursor = await self.execute_query( + query, + { + "start_pos": start_pos, + "end_pos": end_pos, + "gene": gene, + "alt_ac": alt_ac, + }, + ) + results = await cursor.fetchall() + results = [(r[0], r[1], r[2], r[3]) for r in results] results_df = pl.DataFrame(results, schema=schema, orient="row") if results: results_df = results_df.unique() @@ -837,10 +798,8 @@ async def get_transcripts( async def get_chr_assembly(self, ac: str) -> tuple[str, Assembly] | None: """Get chromosome and assembly for NC accession if not in GRCh38. - >>> import asyncio - >>> from cool_seq_tool.sources.uta_database import UtaDatabase - >>> uta_db = asyncio.run(UtaDatabase.create()) - >>> result = asyncio.run(uta_db.get_chr_assembly("NC_000007.13")) + >>> async with uta_db.repository() as uta: + ... result = await uta.get_chr_assembly("NC_000007.13") >>> result ('chr7', ) @@ -885,14 +844,13 @@ async def p_to_c_ac(self, p_ac: str) -> list[str]: query = f""" SELECT tx_ac - FROM {self.schema}.associated_accessions - WHERE pro_ac = '{p_ac}' + FROM associated_accessions + WHERE pro_ac = %(p_ac)s {order_by_cond} """ # noqa: S608 - result = await self.execute_query(query) - if result: - result = [r["tx_ac"] for r in result] - return result + cursor = await self.execute_query(query, {"p_ac": p_ac}) + result = await cursor.fetchall() + return [r[0] for r in result] async def get_transcripts_from_genomic_pos( self, alt_ac: str, g_pos: int @@ -903,48 +861,22 @@ async def get_transcripts_from_genomic_pos( :param g_pos: Genomic position :return: RefSeq transcripts on c. coordinate """ - query = f""" - SELECT distinct tx_ac - FROM {self.schema}.tx_exon_aln_mv - WHERE alt_ac = '{alt_ac}' - AND {g_pos} BETWEEN alt_start_i AND alt_end_i - AND tx_ac LIKE 'NM_%'; - """ # noqa: S608 - results = await self.execute_query(query) - if not results: - return [] + query = """ + SELECT distinct tx_ac + FROM tx_exon_aln_mv + WHERE alt_ac = %(alt_ac)s + AND %(g_pos)s BETWEEN alt_start_i AND alt_end_i + AND tx_ac LIKE 'NM_%%'; + """ + cursor = await self.execute_query(query, {"alt_ac": alt_ac, "g_pos": g_pos}) + results = await cursor.fetchall() return [item for sublist in results for item in sublist] - @staticmethod - def get_secret() -> str: - """Get secrets for UTA DB instances. Used for deployment on AWS. - - :raises ClientError: If unable to retrieve secret value due to decryption - decryption failure, internal service error, invalid parameter, invalid - request, or resource not found. - """ - secret_name = environ["UTA_DB_SECRET"] - region_name = "us-east-2" - - # Create a Secrets Manager client - session = boto3.session.Session() - client = session.client(service_name="secretsmanager", region_name=region_name) - - try: - get_secret_value_response = client.get_secret_value(SecretId=secret_name) - except ClientError: - # For a list of exceptions thrown, see - # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html - _logger.exception("Encountered AWS client error fetching UTA DB secret") - raise - else: - return get_secret_value_response["SecretString"] - class ParseResult(UrlLibParseResult): - """Subclass of url.ParseResult that adds database and schema methods, - and provides stringification. - Source: https://github.com/biocommons/hgvs + """Subclass of url.ParseResult that adds database and schema methods, and provides stringification. + + Inspired by: https://github.com/biocommons/hgvs """ def __new__(cls, pr): # noqa: ANN001, ANN204 @@ -987,3 +919,183 @@ def sanitized_url(self) -> str: self.fragment, ) ) + + +def _get_secret_args() -> str: + """Get secrets connection args for UTA DB instances. Used for deployment on AWS. + + This function is tightly coupled to our internal deployment policies; + it is subject to change or removal in the (distant) future. + + :return: connection URL consisting of params from secrets + :raises ClientError: If unable to retrieve secret value due to decryption + decryption failure, internal service error, invalid parameter, invalid + request, or resource not found. + """ + warnings.warn( + "Deprecated; subject to change in future releases, someday", + DeprecationWarning, + stacklevel=2, + ) + + secret_name = os.environ["UTA_DB_SECRET"] + region_name = "us-east-2" + + session = boto3.session.Session() + client = session.client(service_name="secretsmanager", region_name=region_name) + try: + get_secret_value_response = client.get_secret_value(SecretId=secret_name) + except ClientError: + # For a list of exceptions thrown, see + # https://docs.aws.amazon.com/secretsmanager/latest/apireference/API_GetSecretValue.html + _logger.exception("Encountered AWS client error fetching UTA DB secret") + raise + secret_val = get_secret_value_response["SecretString"] + secret = ast.literal_eval(secret_val) + + username, password = secret["username"], secret["password"] + port, host = secret["port"], secret["host"] + database = secret["dbname"] + schema = secret["schema"] + return f"postgresql://{username}{':' + password if password else ''}@{host}:{port}/{database}?options=-csearch_path%3D{schema},public" + + +DEFAULT_UTA_DB_URL = "postgresql://uta_admin@localhost:5432/uta?options=-csearch_path%3Duta_20241220,public" + + +async def create_uta_connection_pool( + db_url: str | None = None, initialize_genomic_table: bool = True +) -> AsyncConnectionPool: + """Create and initialize a UTA connection pool. + + Connection parameters are resolved in the following order: + + 1. If the ``UTA_DB_PROD`` environment variable is set, credentials and schema + are retrieved from a secret manager via ``_get_secret_args()``. + 2. Otherwise, if the ``db_url`` arg is defined, it's used + 3. If not provided, fall back to environment variable ``UTA_DB_URL`` + 4. If not declared, then use default value + + After opening the pool, a one-time initialization step is performed to ensure that + required genomic tables are present. + + :param db_url: PostgreSQL connection URI (e.g., ``postgresql://user@host:port/db?options=-csearch_path%3Duta_schema,public``). + If not provided, resolved from environment or defaults. + :param initialize_genomic_table: whether to attempt initialization of the ``genomic`` + table which is used/managed by coolseqtool. + :return: An open ``AsyncConnectionPool`` configured for the UTA database + """ + if "UTA_DB_PROD" in os.environ: + db_url = _get_secret_args() + elif db_url is None: + db_url = os.environ.get("UTA_DB_URL", DEFAULT_UTA_DB_URL) + _logger.info( + "Creating connection pool with db_uri '%s'", + ParseResult(urlparse(db_url)).sanitized_url, + ) + pool = AsyncConnectionPool(conninfo=db_url, open=False) + await pool.open() + if initialize_genomic_table: + try: + async with pool.connection() as conn: + await UtaRepository(conn).create_genomic_table() + # catch all exceptions -- this is probably a critical error, it's good to + # close the pool first + except: + await pool.close() + raise + return pool + + +class ClosedUtaConnectionError(Exception): + """Raise for attempts to access a UTA connection when it's been closed/deleted""" + + +class UtaDatabase: + """Provide pooled access to connection-scoped UTA repositories. + + This class owns or borrows an async psycopg connection pool and yields + ``UtaRepository`` instances bound to checked-out connections. + """ + + def __init__(self, pool: AsyncConnectionPool) -> None: + """Initialize access wrapper. + + :param pool: Existing async connection pool to use. If omitted, a default + pool is created lazily on first use. + """ + self._connection_pool = pool + + @asynccontextmanager + async def repository(self) -> AsyncIterator[UtaRepository]: + """Yield a ``UtaRepository`` backed by a pooled connection. + + If no pool has been provided yet, a default one is created on first use. + + :yield: Repository bound to an active pooled connection + :raise ClosedUtaConnectionError: if connection associated w/ this instance is closed + or nullified + """ + if self._connection_pool is None: + raise ClosedUtaConnectionError + async with self._connection_pool.connection() as conn: + yield UtaRepository(conn) + + async def close(self) -> None: + """Close the owned connection pool, if present.""" + if self._connection_pool is None: + _logger.info("Attempted to close nonexistent UTA access connection pool") + return + + await self._connection_pool.close() + self._connection_pool = None + + +class LazyUtaDatabase(UtaDatabase): + """UTA access wrapper with lazy connection pool initialization. + + This variant defers creation of the underlying connection pool until first use. + It exists primarily for backward compatibility with earlier APIs that did not + require explicit pool construction. + + Because configuration is resolved at runtime (via environment variables or + defaults), this class can introduce implicit behavior and is not recommended + for applications that require explicit control over database connections. + """ + + def __init__(self, pool: AsyncConnectionPool | None = None) -> None: + """Initialize the lazy access wrapper. + + :param pool: Optional existing async connection pool. If not provided, + a pool will be created on first use using environment variables + or default configuration. + """ + if pool is None: + _logger.info( + "LazyUtaDatabase initialized without a connection pool; " + "a pool will be created on first use from environment/default settings." + ) + self._connection_pool = pool + + async def open(self) -> None: + """Ensure that a connection pool has been initialized. + + If no pool is currently set, one is created using default configuration. + """ + if self._connection_pool is None: + _logger.debug("Creating UTA connection pool lazily on first use") + self._connection_pool = await create_uta_connection_pool() + + @asynccontextmanager + async def repository(self) -> AsyncIterator[UtaRepository]: + """Yield a repository backed by a pooled UTA connection. + + This method ensures that a connection pool exists, creating one if + necessary, and then yields a ``UtaRepository`` bound to a checked-out + connection. + + :yield: Repository bound to an active pooled connection + """ + await self.open() + async with self._connection_pool.connection() as conn: + yield UtaRepository(conn) diff --git a/tests/mappers/test_exon_genomic_coords.py b/tests/mappers/test_exon_genomic_coords.py index 821c051..13b6d91 100644 --- a/tests/mappers/test_exon_genomic_coords.py +++ b/tests/mappers/test_exon_genomic_coords.py @@ -1671,9 +1671,7 @@ async def test_invalid(test_egc_mapper, caplog): ) genomic_tx_seg_service_checks(resp, is_valid=False) assert resp.errors == [ - "Unable to find a result where NM_152263.3 has transcript coordinates" - " 0 and 234 between an exon's start and end coordinates on gene " - "NTKR1" + "Unable to find a result where NM_152263.3 has transcript coordinates (tx_exon_start=0, tx_exon_end=234) between an exon's start and end coordinates on gene='NTKR1'" ] # No exons given diff --git a/tests/sources/test_uta_database.py b/tests/sources/test_uta_database.py index 2686ca1..9c6d7a9 100644 --- a/tests/sources/test_uta_database.py +++ b/tests/sources/test_uta_database.py @@ -3,16 +3,28 @@ from urllib.parse import urlparse import pytest +import pytest_asyncio from cool_seq_tool.schemas import Strand from cool_seq_tool.sources.uta_database import ( GenomicTxData, GenomicTxMetadata, + NoMatchingAlignmentError, ParseResult, TxExonAlnData, + UtaRepository, + create_uta_connection_pool, ) +@pytest_asyncio.fixture +async def uta_repo(): + pool = await create_uta_connection_pool() + async with pool.connection() as conn: + yield UtaRepository(conn) + await pool.close() + + @pytest.fixture(scope="module") def tx_exon_aln_data(): """Create test fixture for tx_exon_aln_data test.""" @@ -48,92 +60,92 @@ def data_from_result(): @pytest.mark.asyncio -async def test_get_cds_start_end(test_db): +async def test_get_cds_start_end(uta_repo: UtaRepository): """Test that get_cds_start_end works correctly.""" expected = (61, 2362) - resp = await test_db.get_cds_start_end("NM_004333.4") + resp = await uta_repo.get_cds_start_end("NM_004333.4") assert resp == expected - resp = await test_db.get_cds_start_end("ENST00000288602.6") + resp = await uta_repo.get_cds_start_end("ENST00000288602.6") assert resp == expected - resp = await test_db.get_cds_start_end("NM_004333.999") + resp = await uta_repo.get_cds_start_end("NM_004333.999") assert resp is None @pytest.mark.asyncio -async def test_get_newest_assembly_ac(test_db): +async def test_get_newest_assembly_ac(uta_repo: UtaRepository): """Test that get_newest_assembly_ac works correctly.""" - resp = await test_db.get_newest_assembly_ac("NC_000007.13") + resp = await uta_repo.get_newest_assembly_ac("NC_000007.13") assert resp == ["NC_000007.14"] - resp = await test_db.get_newest_assembly_ac("NC_000011.9") + resp = await uta_repo.get_newest_assembly_ac("NC_000011.9") assert resp == ["NC_000011.10"] - resp = await test_db.get_newest_assembly_ac("NC_000011.10") + resp = await uta_repo.get_newest_assembly_ac("NC_000011.10") assert resp == ["NC_000011.10"] - resp = await test_db.get_newest_assembly_ac("ENST00000288602") + resp = await uta_repo.get_newest_assembly_ac("ENST00000288602") assert resp == ["ENST00000288602"] - resp = await test_db.get_newest_assembly_ac("NC_0000077.1") + resp = await uta_repo.get_newest_assembly_ac("NC_0000077.1") assert resp == [] @pytest.mark.asyncio -async def test_validate_genomic_ac(test_db): +async def test_validate_genomic_ac(uta_repo: UtaRepository): """Test that validate_genomic_ac""" - resp = await test_db.validate_genomic_ac("NC_000007.13") + resp = await uta_repo.validate_genomic_ac("NC_000007.13") assert resp is True - resp = await test_db.validate_genomic_ac("NC_000007.17") + resp = await uta_repo.validate_genomic_ac("NC_000007.17") assert resp is False @pytest.mark.asyncio -async def test_validate_gene_exists(test_db): +async def test_validate_gene_exists(uta_repo: UtaRepository): """Test validate_gene_symbol""" - resp = await test_db.gene_exists("TPM3") + resp = await uta_repo.gene_exists("TPM3") assert resp is True - resp = await test_db.gene_exists("dummy gene") + resp = await uta_repo.gene_exists("dummy gene") assert resp is False @pytest.mark.asyncio -async def test_validate_transcript_exists(test_db): +async def test_validate_transcript_exists(uta_repo: UtaRepository): """Tests validate_transcript""" - resp = await test_db.transcript_exists("NM_152263.3") + resp = await uta_repo.transcript_exists("NM_152263.3") assert resp is True - resp = await test_db.transcript_exists("NM_152263 3") + resp = await uta_repo.transcript_exists("NM_152263 3") assert resp is False @pytest.mark.asyncio -async def test_get_ac_descr(test_db): +async def test_get_ac_descr(uta_repo: UtaRepository): """Test that get_ac_descr works correctly.""" - resp = await test_db.get_ac_descr("NC_000007.13") + resp = await uta_repo.get_ac_descr("NC_000007.13") assert resp is not None - resp = await test_db.get_ac_descr("NC_000007.14") + resp = await uta_repo.get_ac_descr("NC_000007.14") assert resp is None @pytest.mark.asyncio -async def test_get_tx_exon_aln_data(test_db, tx_exon_aln_data): +async def test_get_tx_exon_aln_data(uta_repo: UtaRepository, tx_exon_aln_data): """Test that get_tx_exon_aln_data""" - resp = await test_db.get_tx_exon_aln_data( + resp = await uta_repo.get_tx_exon_aln_data( "NM_004333.4", 140453136, 140453136, alt_ac="NC_000007.13", use_tx_pos=False ) assert resp == [tx_exon_aln_data] - resp = await test_db.get_tx_exon_aln_data( + resp = await uta_repo.get_tx_exon_aln_data( "NM_004333.4", 140453136, 140453136, alt_ac=None, use_tx_pos=False ) assert resp == [tx_exon_aln_data] - resp = await test_db.get_tx_exon_aln_data( + resp = await uta_repo.get_tx_exon_aln_data( "NM_004333.4", 1860, 1860, alt_ac=None, use_tx_pos=True ) assert resp == [ @@ -169,16 +181,9 @@ async def test_get_tx_exon_aln_data(test_db, tx_exon_aln_data): @pytest.mark.asyncio -async def test_data_from_result(test_db, tx_exon_aln_data, data_from_result): - """Test that data_from_result works correctly.""" - resp = test_db.data_from_result(tx_exon_aln_data) - assert resp == data_from_result - - -@pytest.mark.asyncio -async def test_mane_c_genomic_data(test_db): +async def test_mane_c_genomic_data(uta_repo: UtaRepository): """Test that get_mane_c_genomic_data works correctly.""" - resp = await test_db.get_mane_c_genomic_data( + resp = await uta_repo.get_mane_c_genomic_data( "NM_001374258.1", None, 140753335, 140753335 ) expected_params = { @@ -199,7 +204,7 @@ async def test_mane_c_genomic_data(test_db): assert resp == GenomicTxMetadata(**expected_params) # Test example where sorting of tx_exon_aln_mv is needed - resp = await test_db.get_mane_c_genomic_data( + resp = await uta_repo.get_mane_c_genomic_data( "NM_000077.5", "NC_000009.12", 21971186, 21971187 ) expected_params = { @@ -220,17 +225,17 @@ async def test_mane_c_genomic_data(test_db): assert resp == GenomicTxMetadata(**expected_params) # Test case where chromosomal accession is not provided - resp = await test_db.get_mane_c_genomic_data( + resp = await uta_repo.get_mane_c_genomic_data( "NM_000077.5", None, 21971186, 21971187 ) assert resp == GenomicTxMetadata(**expected_params) @pytest.mark.asyncio -async def test_get_genomic_tx_data(test_db, genomic_tx_data): +async def test_get_genomic_tx_data(uta_repo: UtaRepository): """Test that get_genomic_tx_data works correctly.""" # Positive strand transcript - resp = await test_db.get_genomic_tx_data("NM_004327.3", (3595, 3596)) + resp = await uta_repo.get_genomic_tx_data("NM_004327.3", (3595, 3596)) expected_params = { "gene": "BCR", "strand": Strand.POSITIVE, @@ -247,7 +252,7 @@ async def test_get_genomic_tx_data(test_db, genomic_tx_data): assert resp == GenomicTxMetadata(**expected_params) # Negative strand transcript - resp = await test_db.get_genomic_tx_data("NM_004333.4", (2144, 2145)) + resp = await uta_repo.get_genomic_tx_data("NM_004333.4", (2144, 2145)) expected_params = { "gene": "BRAF", "strand": Strand.NEGATIVE, @@ -265,123 +270,116 @@ async def test_get_genomic_tx_data(test_db, genomic_tx_data): @pytest.mark.asyncio -async def test_get_ac_from_gene(test_db): +async def test_get_ac_from_gene(uta_repo: UtaRepository): """Test that get_ac_from_gene works correctly.""" - resp = await test_db.get_ac_from_gene("BRAF") + resp = await uta_repo.get_ac_from_gene("BRAF") assert resp == ["NC_000007.14", "NC_000007.13"] - resp = await test_db.get_ac_from_gene("HRAS") + resp = await uta_repo.get_ac_from_gene("HRAS") assert resp == ["NC_000011.10", "NC_000011.9"] - resp = await test_db.get_ac_from_gene("dummy") + resp = await uta_repo.get_ac_from_gene("dummy") assert resp == [] @pytest.mark.asyncio -async def test_get_gene_from_ac(test_db): +async def test_get_gene_from_ac(uta_repo: UtaRepository): """Tet that get_gene_from_ac works correctly.""" - resp = await test_db.get_gene_from_ac("NC_000007.13", 140453136, None) + resp = await uta_repo.get_gene_from_ac("NC_000007.13", 140453136, None) assert resp == ["BRAF"] - resp = await test_db.get_gene_from_ac("NC_000007.14", 140753336, None) + resp = await uta_repo.get_gene_from_ac("NC_000007.14", 140753336, None) assert resp == ["BRAF"] - resp = await test_db.get_gene_from_ac("NC_000007.13", 55249071, None) + resp = await uta_repo.get_gene_from_ac("NC_000007.13", 55249071, None) assert resp == ["EGFR", "EGFR-AS1"] - resp = await test_db.get_gene_from_ac("NC_0000078.1", 140453136, None) + resp = await uta_repo.get_gene_from_ac("NC_0000078.1", 140453136, None) assert resp is None @pytest.mark.asyncio -async def test_get_transcripts_from_gene(test_db): +async def test_get_transcripts_from_gene(uta_repo: UtaRepository): """Test that get_transcripts works correctly.""" - resp = await test_db.get_transcripts(start_pos=2145, end_pos=2145, gene="BRAF") + resp = await uta_repo.get_transcripts(start_pos=2145, end_pos=2145, gene="BRAF") assert len(resp) == 32 # using no start/end pos - resp = await test_db.get_transcripts(gene="BRAF") + resp = await uta_repo.get_transcripts(gene="BRAF") assert len(resp) == 32 # using 0 start/end pos - resp = await test_db.get_transcripts(gene="BRAF", start_pos=0, end_pos=0) + resp = await uta_repo.get_transcripts(gene="BRAF", start_pos=0, end_pos=0) assert len(resp) == 32 # using 0 genomic start/end pos - resp = await test_db.get_transcripts( + resp = await uta_repo.get_transcripts( gene="BRAF", start_pos=0, end_pos=0, use_tx_pos=False ) assert len(resp) == 0 # using gene with genomic pos - resp = await test_db.get_transcripts( + resp = await uta_repo.get_transcripts( gene="BRAF", start_pos=140753336, end_pos=140753336, use_tx_pos=False ) assert len(resp) == 16 - resp = await test_db.get_transcripts( + resp = await uta_repo.get_transcripts( gene="BRAF", start_pos=140453136, end_pos=140453136 ) assert len(resp) == 0 # No gene and no alt_ac - resp = await test_db.get_transcripts(start_pos=140453136, end_pos=140453136) + resp = await uta_repo.get_transcripts(start_pos=140453136, end_pos=140453136) assert len(resp) == 0 @pytest.mark.asyncio -async def test_get_chr_assembly(test_db): +async def test_get_chr_assembly(uta_repo: UtaRepository): """Test that get_chr_assembly works correctly.""" - resp = await test_db.get_chr_assembly("NC_000007.13") + resp = await uta_repo.get_chr_assembly("NC_000007.13") assert resp == ("chr7", "GRCh37") - resp = await test_db.get_chr_assembly("NC_000007.14") + resp = await uta_repo.get_chr_assembly("NC_000007.14") assert resp is None # Invalid ac - resp = await test_db.get_chr_assembly("NC_00000714") + resp = await uta_repo.get_chr_assembly("NC_00000714") assert resp is None @pytest.mark.asyncio -async def test_p_to_c_ac(test_db): +async def test_p_to_c_ac(uta_repo: UtaRepository): """Test that p_to_c_ac works correctly.""" - resp = await test_db.p_to_c_ac("NP_004324.2") + resp = await uta_repo.p_to_c_ac("NP_004324.2") assert resp == ["NM_004333.4", "NM_004333.5", "NM_004333.6"] - resp = await test_db.p_to_c_ac("NP_064502.9") + resp = await uta_repo.p_to_c_ac("NP_064502.9") assert resp == ["NM_020117.9", "NM_020117.10", "NM_020117.11"] - resp = await test_db.p_to_c_ac("NP_004324.22") + resp = await uta_repo.p_to_c_ac("NP_004324.22") assert resp == [] @pytest.mark.asyncio async def test_get_alt_ac_start_or_end( - test_db, tpm3_1_8_start_genomic, tpm3_1_8_end_genomic + uta_repo: UtaRepository, tpm3_1_8_start_genomic, tpm3_1_8_end_genomic ): """Test that get_alt_ac_start_or_end works correctly.""" - resp = await test_db.get_alt_ac_start_or_end("NM_152263.3", 117, 234, None) - assert resp[0] == tpm3_1_8_start_genomic - assert resp[1] is None - - resp = await test_db.get_alt_ac_start_or_end("NM_152263.3", 822, 892, None) - assert resp[0] == tpm3_1_8_end_genomic - assert resp[1] is None - - resp = await test_db.get_alt_ac_start_or_end("NM_152263.63", 822, 892, None) - assert resp[0] is None - assert ( - resp[1] == "Unable to find a result where NM_152263.63 has " - "transcript coordinates 822 and 892 between an exon's " - "start and end coordinates" - ) + resp = await uta_repo.get_alt_ac_start_or_end("NM_152263.3", 117, 234, None) + assert resp == tpm3_1_8_start_genomic + + resp = await uta_repo.get_alt_ac_start_or_end("NM_152263.3", 822, 892, None) + assert resp == tpm3_1_8_end_genomic + + with pytest.raises(NoMatchingAlignmentError): + await uta_repo.get_alt_ac_start_or_end("NM_152263.63", 822, 892, None) @pytest.mark.asyncio -async def test_get_mane_transcripts_from_genomic_pos(test_db): +async def test_get_mane_transcripts_from_genomic_pos(uta_repo: UtaRepository): """Test that get_mane_transcripts_from_genomic_pos works correctly""" - resp = await test_db.get_transcripts_from_genomic_pos("NC_000007.14", 140753336) + resp = await uta_repo.get_transcripts_from_genomic_pos("NC_000007.14", 140753336) assert set(resp) == { "NM_001354609.1", "NM_001354609.2", @@ -402,11 +400,11 @@ async def test_get_mane_transcripts_from_genomic_pos(test_db): } # invalid pos - resp = await test_db.get_transcripts_from_genomic_pos("NC_000007.14", 150753336) + resp = await uta_repo.get_transcripts_from_genomic_pos("NC_000007.14", 150753336) assert resp == [] # invalid ac - resp = await test_db.get_transcripts_from_genomic_pos("NC_000007.14232", 140753336) + resp = await uta_repo.get_transcripts_from_genomic_pos("NC_000007.14232", 140753336) assert resp == []