diff --git a/backend/app/core/batch/gemini.py b/backend/app/core/batch/gemini.py index 3d317f1e8..f8a61fc6b 100644 --- a/backend/app/core/batch/gemini.py +++ b/backend/app/core/batch/gemini.py @@ -2,6 +2,7 @@ import json import logging +import mimetypes import os import tempfile import time @@ -11,8 +12,6 @@ from google import genai from google.genai import types -from app.core.storage_utils import get_mime_from_url - from .base import BATCH_KEY, BatchProvider logger = logging.getLogger(__name__) @@ -292,6 +291,46 @@ def _extract_text_from_response_dict(response: dict[str, Any]) -> str: """Extract text content from a Gemini response dictionary.""" return extract_text_from_response_dict(response) + def _upload_to_gemini( + self, + content: str | bytes, + suffix: str, + mime_type: str, + display_name: str, + ) -> types.File: + """Write content to a temp file, upload to Gemini, and clean up. + + Args: + content: File content (text or binary) + suffix: Temp file suffix (e.g., ".jsonl", ".mp3") + mime_type: MIME type for the upload + display_name: Display name in Gemini + + Returns: + Gemini File object + """ + # "w" for text (JSONL batch) or "wb" for binary (audio files for STT), + # since this method accepts content: str | bytes. + mode = "w" if isinstance(content, str) else "wb" + kwargs: dict[str, Any] = {"suffix": suffix, "delete": False, "mode": mode} + if mode == "w": + kwargs["encoding"] = "utf-8" + + with tempfile.NamedTemporaryFile(**kwargs) as tmp_file: + tmp_file.write(content) + tmp_path = tmp_file.name + + try: + return self._client.files.upload( + file=tmp_path, + config=types.UploadFileConfig( + display_name=display_name, + mime_type=mime_type, + ), + ) + finally: + os.unlink(tmp_path) + def upload_file(self, content: str, purpose: str = "batch") -> str: """Upload a JSONL file to Gemini Files API. 
@@ -305,35 +344,98 @@ def upload_file(self, content: str, purpose: str = "batch") -> str: logger.info(f"[upload_file] Uploading file to Gemini | bytes={len(content)}") try: - with tempfile.NamedTemporaryFile( - suffix=".jsonl", delete=False, mode="w", encoding="utf-8" - ) as tmp_file: - tmp_file.write(content) - tmp_path = tmp_file.name + uploaded_file = self._upload_to_gemini( + content=content, + suffix=".jsonl", + mime_type="jsonl", + display_name=f"batch-input-{int(time.time())}", + ) - try: - uploaded_file = self._client.files.upload( - file=tmp_path, - config=types.UploadFileConfig( - display_name=f"batch-input-{int(time.time())}", - mime_type="jsonl", - ), - ) + logger.info( + f"[upload_file] Uploaded file to Gemini | " + f"file_name={uploaded_file.name}" + ) - logger.info( - f"[upload_file] Uploaded file to Gemini | " - f"file_name={uploaded_file.name}" - ) + return uploaded_file.name + + except Exception as e: + logger.error(f"[upload_file] Failed to upload file to Gemini | {e}") + raise - return uploaded_file.name + def upload_audio_file( + self, + content: bytes, + mime_type: str, + display_name: str | None = None, + ) -> tuple[str, str]: + """Upload an audio file to Gemini File API. 
- finally: - os.unlink(tmp_path) + Args: + content: Raw audio file bytes + mime_type: MIME type of the audio (e.g., 'audio/mpeg') + display_name: Optional display name for the file + + Returns: + Tuple of (file_name, file_uri): + - file_name: Short name for API calls (e.g., "files/xxx") + - file_uri: Full URI for use in batch requests + (e.g., "https://generativelanguage.googleapis.com/v1beta/files/xxx") + """ + display_name = display_name or f"stt-audio-{int(time.time())}" + logger.info( + f"[upload_audio_file] Uploading audio to Gemini | " + f"bytes={len(content)} | mime_type={mime_type} | display_name={display_name}" + ) + + try: + uploaded_file = self._upload_to_gemini( + content=content, + suffix=mimetypes.guess_extension(mime_type) or ".bin", + mime_type=mime_type, + display_name=display_name, + ) + + logger.info( + f"[upload_audio_file] Uploaded audio to Gemini | " + f"file_name={uploaded_file.name} | file_uri={uploaded_file.uri}" + ) + + return uploaded_file.name, uploaded_file.uri except Exception as e: - logger.error(f"[upload_file] Failed to upload file to Gemini | {e}") + logger.error(f"[upload_audio_file] Failed to upload audio to Gemini | {e}") raise + def delete_files(self, file_names: list[str]) -> tuple[int, int]: + """Delete files from Gemini File API. + + Args: + file_names: List of Gemini file names to delete (e.g., ["files/xxx", ...]) + + Returns: + Tuple of (success_count, failure_count) + """ + success_count = 0 + failure_count = 0 + + for name in file_names: + try: + self._client.files.delete(name=name) + success_count += 1 + except Exception as e: + failure_count += 1 + logger.warning( + f"[delete_files] Failed to delete Gemini file | " + f"file_name={name} | error={e}" + ) + + logger.info( + f"[delete_files] Gemini file cleanup complete | " + f"deleted={success_count}, failed={failure_count}" + ) + + return success_count, failure_count + def download_file(self, file_id: str) -> str: """Download a file from Gemini Files API. 
@@ -387,18 +489,20 @@ def _extract_text_from_response(response: Any) -> str: def create_stt_batch_requests( - signed_urls: list[str], + file_uris: list[str], + mime_types: list[str], prompt: str, keys: list[str] | None = None, ) -> list[dict[str, Any]]: """ - Create batch API requests for Gemini STT using signed URLs. + Create batch API requests for Gemini STT using Gemini File API URIs. This function generates request payloads in Gemini's JSONL batch format - using signed URLs directly. MIME types are automatically detected from the URL path. + using file URIs from the Gemini File API. Args: - signed_urls: List of signed URLs pointing to audio files + file_uris: List of Gemini file URIs (e.g., "https://generativelanguage.googleapis.com/v1beta/files/abc123") + mime_types: List of MIME types corresponding to each file URI prompt: Transcription prompt/instructions for the model keys: Optional list of custom IDs for tracking results. If not provided, uses 0-indexed integers as strings. @@ -408,25 +512,25 @@ def create_stt_batch_requests( {"key": "sample-1", "request": {"contents": [...]}} Example: - >>> urls = ["https://bucket.s3.amazonaws.com/audio.mp3?..."] + >>> uris = ["https://generativelanguage.googleapis.com/v1beta/files/abc123"] + >>> mime_types = ["audio/mpeg"] >>> prompt = "Transcribe this audio file." 
- >>> requests = create_stt_batch_requests(urls, prompt, keys=["sample-1"]) + >>> requests = create_stt_batch_requests(uris, mime_types, prompt, keys=["sample-1"]) >>> provider.create_batch(requests, {"display_name": "stt-batch"}) """ - if keys is not None and len(keys) != len(signed_urls): + if len(file_uris) != len(mime_types): + raise ValueError( + f"Length of file_uris ({len(file_uris)}) must match mime_types ({len(mime_types)})" + ) + + if keys is not None and len(keys) != len(file_uris): raise ValueError( - f"Length of keys ({len(keys)}) must match signed_urls ({len(signed_urls)})" + f"Length of keys ({len(keys)}) must match file_uris ({len(file_uris)})" ) requests = [] - for i, url in enumerate(signed_urls): - mime_type = get_mime_from_url(url) - if mime_type is None: - logger.warning( - f"[create_stt_batch_requests] Could not determine MIME type for URL | " - f"index={i} | defaulting to audio/mpeg" - ) - mime_type = "audio/mpeg" + for i, uri in enumerate(file_uris): + mime_type = mime_types[i] # Use provided key or generate from index key = keys[i] if keys is not None else str(i) @@ -439,7 +543,7 @@ def create_stt_batch_requests( { "parts": [ {"text": prompt}, - {"file_data": {"mime_type": mime_type, "file_uri": url}}, + {"file_data": {"mime_type": mime_type, "file_uri": uri}}, ], "role": "user", } diff --git a/backend/app/crud/stt_evaluations/batch.py b/backend/app/crud/stt_evaluations/batch.py index 6275f8f16..a68d35d7d 100644 --- a/backend/app/crud/stt_evaluations/batch.py +++ b/backend/app/crud/stt_evaluations/batch.py @@ -2,7 +2,7 @@ import logging from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any +from typing import Any, NamedTuple from sqlmodel import Session @@ -12,15 +12,25 @@ create_stt_batch_requests, start_batch_job, ) -from app.models.batch_job import BatchJobType from app.core.cloud.storage import get_cloud_storage +from app.core.storage_utils import get_mime_from_url from app.crud.file import 
get_files_by_ids from app.crud.stt_evaluations.run import update_stt_run from app.models import EvaluationRun +from app.models.batch_job import BatchJobType from app.models.stt_evaluation import STTSample logger = logging.getLogger(__name__) + +class _UploadResult(NamedTuple): + sample: STTSample + file_uri: str | None + file_name: str | None + mime_type: str | None + error: str | None + + DEFAULT_TRANSCRIPTION_PROMPT = ( "Generate a verbatim transcript of the speech in this audio file. " "Return only the transcription text without any formatting, timestamps, or metadata." @@ -36,12 +46,12 @@ def start_stt_evaluation_batch( samples: list[STTSample], org_id: int, project_id: int, - signed_url_expires_in: int = 86400, ) -> dict[str, Any]: - """Generate signed URLs and submit Gemini batch jobs for STT evaluation. + """Upload audio files to Gemini and submit batch jobs for STT evaluation. - Submits one batch job per model. Each batch job is tracked via - its config containing evaluation_run_id and stt_provider. + Downloads audio from S3 and uploads to Gemini File API, then submits + one batch job per model. Each batch job is tracked via its config + containing evaluation_run_id and stt_provider. 
Args: session: Database session @@ -49,7 +59,6 @@ def start_stt_evaluation_batch( samples: List of STT samples to process org_id: Organization ID project_id: Project ID - signed_url_expires_in: Signed URL expiry in seconds (default: 24 hours) Returns: dict: Result with batch job information per model @@ -85,56 +94,91 @@ def start_stt_evaluation_batch( ) file_map = {f.id: f for f in file_records} - # Generate signed URLs for audio files concurrently (shared across all models) - signed_urls: list[str] = [] + # Upload audio files to Gemini File API concurrently (shared across all models) + upload_provider = GeminiBatchProvider(client=gemini_client.client) + + file_uris: list[str] = [] + mime_types: list[str] = [] sample_keys: list[str] = [] + gemini_file_names: list[str] = [] failed_samples: list[tuple[STTSample, str]] = [] - def _generate_signed_url( - sample: STTSample, - ) -> tuple[STTSample, str | None, str | None]: - """Generate a signed URL for a single sample. Thread-safe.""" + def _upload_to_gemini(sample: STTSample) -> _UploadResult: + """Download from S3 and upload to Gemini File API. 
Thread-safe.""" file_record = file_map.get(sample.file_id) if not file_record: - return sample, None, f"File record not found for file_id: {sample.file_id}" + return _UploadResult( + sample=sample, + file_uri=None, + file_name=None, + mime_type=None, + error=f"File record not found for file_id: {sample.file_id}", + ) try: - url = storage.get_signed_url( - file_record.object_store_url, expires_in=signed_url_expires_in + # Detect MIME type from S3 URL path + mime_type = get_mime_from_url(file_record.object_store_url) + if mime_type is None: + mime_type = file_record.content_type or "audio/mpeg" + + # Download audio from S3 + body = storage.stream(file_record.object_store_url) + audio_bytes = body.read() + + # Upload to Gemini File API + file_name, file_uri = upload_provider.upload_audio_file( + content=audio_bytes, + mime_type=mime_type, + display_name=f"stt-eval-{run.id}-sample-{sample.id}", + ) + return _UploadResult( + sample=sample, + file_uri=file_uri, + file_name=file_name, + mime_type=mime_type, + error=None, ) - return sample, url, None except Exception as e: - return sample, None, str(e) + return _UploadResult( + sample=sample, + file_uri=None, + file_name=None, + mime_type=None, + error=str(e), + ) - with ThreadPoolExecutor(max_workers=10) as executor: - sign_url_tasks = { - executor.submit(_generate_signed_url, sample): sample for sample in samples + with ThreadPoolExecutor(max_workers=5) as executor: + upload_tasks = { + executor.submit(_upload_to_gemini, sample): sample for sample in samples } - for completed_task in as_completed(sign_url_tasks): - sample, url, error = completed_task.result() - if url: - signed_urls.append(url) - sample_keys.append(str(sample.id)) + for completed_task in as_completed(upload_tasks): + result = completed_task.result() + if result.file_uri: + file_uris.append(result.file_uri) + mime_types.append(result.mime_type) + sample_keys.append(str(result.sample.id)) + gemini_file_names.append(result.file_name) else: - 
failed_samples.append((sample, error)) + failed_samples.append((result.sample, result.error)) logger.error( - f"[start_stt_evaluation_batch] Failed to generate signed URL | " - f"sample_id: {sample.id}, error: {error}" + f"[start_stt_evaluation_batch] Failed to upload to Gemini | " + f"sample_id: {result.sample.id}, error: {result.error}" ) if failed_samples: logger.warning( - f"[start_stt_evaluation_batch] Signed URL failures | " + f"[start_stt_evaluation_batch] Gemini upload failures | " f"run_id: {run.id}, failed_count: {len(failed_samples)}, " - f"succeeded_count: {len(signed_urls)}" + f"succeeded_count: {len(file_uris)}" ) - if not signed_urls: - raise Exception("Failed to generate signed URLs for any audio files") + if not file_uris: + raise Exception("Failed to upload audio files to Gemini for any samples") # Create JSONL batch requests (shared across all models) jsonl_data = create_stt_batch_requests( - signed_urls=signed_urls, + file_uris=file_uris, + mime_types=mime_types, prompt=DEFAULT_TRANSCRIPTION_PROMPT, keys=sample_keys, ) @@ -162,6 +206,7 @@ def _generate_signed_url( "model": model, "stt_provider": model, "evaluation_run_id": run.id, + "gemini_audio_files": gemini_file_names, }, ) @@ -199,12 +244,12 @@ def _generate_signed_url( logger.info( f"[start_stt_evaluation_batch] Batch submission complete | " f"run_id: {run.id}, models_submitted: {list(batch_jobs.keys())}, " - f"sample_count: {len(signed_urls)}" + f"sample_count: {len(file_uris)}" ) return { "success": True, "run_id": run.id, "batch_jobs": batch_jobs, - "sample_count": len(signed_urls), + "sample_count": len(file_uris), } diff --git a/backend/app/crud/stt_evaluations/cron.py b/backend/app/crud/stt_evaluations/cron.py index 995685eb8..0284639e1 100644 --- a/backend/app/crud/stt_evaluations/cron.py +++ b/backend/app/crud/stt_evaluations/cron.py @@ -7,6 +7,7 @@ processing runs, grouped by project_id for credential management. 
""" +import asyncio import logging from typing import Any @@ -128,7 +129,31 @@ async def _on_batch_succeeded(batch_job: BatchJob, provider_name: str) -> bool: action="no_change", ) - # All batch jobs are done - finalize the run + # All batch jobs are done - clean up Gemini audio files once + gemini_file_names: list[str] = [] + for bj in batch_jobs: + names = bj.config.get("gemini_audio_files", []) + if names: + gemini_file_names = names + break + + if gemini_file_names: + try: + deleted, failed = await asyncio.to_thread( + batch_provider.delete_files, gemini_file_names + ) + logger.info( + f"[poll_stt_run] Gemini file cleanup | " + f"run_id={run.id}, deleted={deleted}, failed={failed}" + ) + except Exception as e: + # Non-critical; Gemini files auto-expire after 48h + logger.warning( + f"[poll_stt_run] Gemini file cleanup failed | " + f"run_id={run.id}, error={str(e)}" + ) + + # Finalize the run status_counts = count_results_by_status(session=session, run_id=run.id) failed_count = status_counts.get(JobStatus.FAILED.value, 0) diff --git a/backend/app/tests/core/batch/test_gemini.py b/backend/app/tests/core/batch/test_gemini.py index c9281cfb2..59e46bdee 100644 --- a/backend/app/tests/core/batch/test_gemini.py +++ b/backend/app/tests/core/batch/test_gemini.py @@ -12,18 +12,20 @@ ) -class TestGeminiBatchProvider: - """Test cases for GeminiBatchProvider.""" +@pytest.fixture +def mock_genai_client(): + """Create a mock Gemini client.""" + return MagicMock() - @pytest.fixture - def mock_genai_client(self): - """Create a mock Gemini client.""" - return MagicMock() - @pytest.fixture - def provider(self, mock_genai_client): - """Create a GeminiBatchProvider instance with mock client.""" - return GeminiBatchProvider(client=mock_genai_client) +@pytest.fixture +def provider(mock_genai_client): + """Create a GeminiBatchProvider instance with mock client.""" + return GeminiBatchProvider(client=mock_genai_client) + + +class TestGeminiBatchProvider: + """Test cases for 
GeminiBatchProvider.""" @pytest.fixture def provider_with_model(self, mock_genai_client): @@ -390,15 +392,15 @@ def test_download_file_success(self, provider, mock_genai_client): def test_download_file_unicode_content(self, provider, mock_genai_client): """Test downloading file with unicode content.""" file_id = "files/abc123" - expected_content = '{"text":"Hello 世界 🌍"}' + expected_content = '{"text":"Hello \u4e16\u754c \U0001f30d"}' mock_genai_client.files.download.return_value = expected_content.encode("utf-8") content = provider.download_file(file_id) assert content == expected_content - assert "世界" in content - assert "🌍" in content + assert "\u4e16\u754c" in content + assert "\U0001f30d" in content def test_download_file_error(self, provider, mock_genai_client): """Test handling of error during file download.""" @@ -412,6 +414,118 @@ def test_download_file_error(self, provider, mock_genai_client): assert "File not found" in str(exc_info.value) +class TestUploadAudioFile: + """Test cases for GeminiBatchProvider.upload_audio_file.""" + + def test_upload_audio_file_success(self, provider, mock_genai_client): + """Test successful audio file upload to Gemini.""" + audio_bytes = b"\xff\xfb\x90\x00" * 100 # fake MP3 bytes + + mock_uploaded_file = MagicMock() + mock_uploaded_file.name = "files/audio-abc123" + mock_uploaded_file.uri = ( + "https://generativelanguage.googleapis.com/v1beta/files/audio-abc123" + ) + mock_genai_client.files.upload.return_value = mock_uploaded_file + + with patch("tempfile.NamedTemporaryFile") as mock_temp: + mock_temp_file = MagicMock() + mock_temp_file.name = "/tmp/test.mp3" + mock_temp.return_value.__enter__.return_value = mock_temp_file + + with patch("os.unlink"): + file_name, file_uri = provider.upload_audio_file( + content=audio_bytes, + mime_type="audio/mpeg", + display_name="stt-eval-1-sample-42", + ) + + assert file_name == "files/audio-abc123" + assert ( + file_uri + == 
"https://generativelanguage.googleapis.com/v1beta/files/audio-abc123" + ) + mock_genai_client.files.upload.assert_called_once() + + def test_upload_audio_file_error(self, provider, mock_genai_client): + """Test error handling during audio file upload.""" + mock_genai_client.files.upload.side_effect = Exception("Quota exceeded") + + with patch("tempfile.NamedTemporaryFile") as mock_temp: + mock_temp_file = MagicMock() + mock_temp_file.name = "/tmp/test.mp3" + mock_temp.return_value.__enter__.return_value = mock_temp_file + + with patch("os.unlink"): + with pytest.raises(Exception) as exc_info: + provider.upload_audio_file( + content=b"audio-data", + mime_type="audio/mpeg", + ) + + assert "Quota exceeded" in str(exc_info.value) + + def test_upload_audio_file_default_display_name(self, provider, mock_genai_client): + """Test that a default display name is generated when not provided.""" + mock_uploaded_file = MagicMock() + mock_uploaded_file.name = "files/audio-xyz" + mock_uploaded_file.uri = ( + "https://generativelanguage.googleapis.com/v1beta/files/audio-xyz" + ) + mock_genai_client.files.upload.return_value = mock_uploaded_file + + with patch("tempfile.NamedTemporaryFile") as mock_temp: + mock_temp_file = MagicMock() + mock_temp_file.name = "/tmp/test.wav" + mock_temp.return_value.__enter__.return_value = mock_temp_file + + with patch("os.unlink"): + file_name, file_uri = provider.upload_audio_file( + content=b"audio-data", + mime_type="audio/x-wav", + ) + + assert file_name == "files/audio-xyz" + assert "generativelanguage.googleapis.com" in file_uri + + +class TestDeleteFiles: + """Test cases for GeminiBatchProvider.delete_files.""" + + def test_delete_files_all_success(self, provider, mock_genai_client): + """Test successful deletion of all files.""" + file_names = ["files/a", "files/b", "files/c"] + + success, failed = provider.delete_files(file_names) + + assert success == 3 + assert failed == 0 + assert mock_genai_client.files.delete.call_count == 3 + + def 
test_delete_files_partial_failure(self, provider, mock_genai_client): + """Test that partial failures are handled gracefully.""" + file_names = ["files/a", "files/b", "files/c"] + + mock_genai_client.files.delete.side_effect = [ + None, + Exception("Not found"), + None, + ] + + success, failed = provider.delete_files(file_names) + + assert success == 2 + assert failed == 1 + + def test_delete_files_empty_list(self, provider, mock_genai_client): + """Test deletion with empty file list.""" + success, failed = provider.delete_files([]) + + assert success == 0 + assert failed == 0 + mock_genai_client.files.delete.assert_not_called() + + class TestBatchJobState: """Test cases for BatchJobState enum.""" @@ -430,14 +544,12 @@ class TestCreateSTTBatchRequests: def test_create_requests_with_keys(self): """Test creating batch requests with custom keys.""" - signed_urls = [ - "https://bucket.s3.amazonaws.com/audio1.mp3?signature=abc", - "https://bucket.s3.amazonaws.com/audio2.wav?signature=def", - ] + file_uris = ["files/audio1", "files/audio2"] + mime_types = ["audio/mpeg", "audio/x-wav"] prompt = "Transcribe this audio file." keys = ["sample-1", "sample-2"] - requests = create_stt_batch_requests(signed_urls, prompt, keys=keys) + requests = create_stt_batch_requests(file_uris, mime_types, prompt, keys=keys) assert len(requests) == 2 assert requests[0]["key"] == "sample-1" @@ -455,26 +567,22 @@ def test_create_requests_with_keys(self): def test_create_requests_without_keys(self): """Test creating batch requests without keys (auto-generated).""" - signed_urls = [ - "https://bucket.s3.amazonaws.com/audio.mp3?signature=xyz", - ] + file_uris = ["files/audio1"] + mime_types = ["audio/mpeg"] prompt = "Transcribe." 
- requests = create_stt_batch_requests(signed_urls, prompt) + requests = create_stt_batch_requests(file_uris, mime_types, prompt) assert len(requests) == 1 assert requests[0]["key"] == "0" - def test_create_requests_mime_type_detection(self): - """Test that MIME types are correctly detected from URLs.""" - signed_urls = [ - "https://bucket.s3.amazonaws.com/audio.mp3?sig=1", - "https://bucket.s3.amazonaws.com/audio.wav?sig=2", - "https://bucket.s3.amazonaws.com/audio.m4a?sig=3", - ] + def test_create_requests_mime_types_passed_through(self): + """Test that provided MIME types are used in requests.""" + file_uris = ["files/audio1", "files/audio2", "files/audio3"] + mime_types = ["audio/mpeg", "audio/x-wav", "audio/mp4"] prompt = "Transcribe." - requests = create_stt_batch_requests(signed_urls, prompt) + requests = create_stt_batch_requests(file_uris, mime_types, prompt) assert ( requests[0]["request"]["contents"][0]["parts"][1]["file_data"]["mime_type"] @@ -484,38 +592,46 @@ def test_create_requests_mime_type_detection(self): requests[1]["request"]["contents"][0]["parts"][1]["file_data"]["mime_type"] == "audio/x-wav" ) - # .m4a can return different MIME types depending on the system - m4a_mime = requests[2]["request"]["contents"][0]["parts"][1]["file_data"][ - "mime_type" - ] - assert m4a_mime in ("audio/mp4", "audio/mp4a-latm", "audio/x-m4a") + assert ( + requests[2]["request"]["contents"][0]["parts"][1]["file_data"]["mime_type"] + == "audio/mp4" + ) def test_create_requests_key_length_mismatch(self): - """Test that mismatched keys and URLs raise error.""" - signed_urls = [ - "https://example.com/audio1.mp3", - "https://example.com/audio2.mp3", - ] + """Test that mismatched keys and URIs raise error.""" + file_uris = ["files/audio1", "files/audio2"] + mime_types = ["audio/mpeg", "audio/mpeg"] keys = ["only-one-key"] prompt = "Transcribe." 
with pytest.raises(ValueError) as exc_info: - create_stt_batch_requests(signed_urls, prompt, keys=keys) + create_stt_batch_requests(file_uris, mime_types, prompt, keys=keys) assert "Length of keys" in str(exc_info.value) + def test_create_requests_mime_types_length_mismatch(self): + """Test that mismatched file_uris and mime_types raise error.""" + file_uris = ["files/audio1", "files/audio2"] + mime_types = ["audio/mpeg"] + prompt = "Transcribe." + + with pytest.raises(ValueError) as exc_info: + create_stt_batch_requests(file_uris, mime_types, prompt) + + assert "Length of file_uris" in str(exc_info.value) + def test_create_requests_file_uri_preserved(self): - """Test that signed URLs are preserved in file_uri.""" - signed_url = "https://bucket.s3.amazonaws.com/audio.mp3?X-Amz-Signature=abc123&X-Amz-Expires=3600" + """Test that Gemini file URIs are preserved in file_data.""" + file_uri = "files/abc123xyz" + mime_types = ["audio/mpeg"] prompt = "Transcribe." - requests = create_stt_batch_requests([signed_url], prompt) + requests = create_stt_batch_requests([file_uri], mime_types, prompt) - file_uri = requests[0]["request"]["contents"][0]["parts"][1]["file_data"][ + result_uri = requests[0]["request"]["contents"][0]["parts"][1]["file_data"][ "file_uri" ] - assert file_uri == signed_url - assert "X-Amz-Signature" in file_uri + assert result_uri == file_uri class TestExtractTextFromResponseDict: @@ -603,22 +719,3 @@ def test_extract_empty_response_no_text_no_candidates(self): text = GeminiBatchProvider._extract_text_from_response(response) assert text == "" - - -class TestCreateSttBatchRequestsMimeTypeFallback: - """Test cases for create_stt_batch_requests MIME type fallback.""" - - def test_unknown_mime_type_defaults_to_audio_mpeg(self): - """Test that unknown file extensions default to audio/mpeg.""" - # URL with no recognizable audio extension - signed_urls = ["https://bucket.s3.amazonaws.com/audio.unknown?signature=xyz"] - prompt = "Transcribe this audio." 
- - with patch("app.core.batch.gemini.get_mime_from_url", return_value=None): - requests = create_stt_batch_requests(signed_urls, prompt) - - assert len(requests) == 1 - # Check that the request was created with default mime type - # parts[0] is text prompt, parts[1] is file_data - file_data = requests[0]["request"]["contents"][0]["parts"][1]["file_data"] - assert file_data["mime_type"] == "audio/mpeg"