Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 144 additions & 40 deletions backend/app/core/batch/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import json
import logging
import mimetypes
import os
import tempfile
import time
Expand All @@ -11,8 +12,6 @@
from google import genai
from google.genai import types

from app.core.storage_utils import get_mime_from_url

from .base import BATCH_KEY, BatchProvider

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -292,6 +291,46 @@ def _extract_text_from_response_dict(response: dict[str, Any]) -> str:
"""Extract text content from a Gemini response dictionary."""
return extract_text_from_response_dict(response)

def _upload_to_gemini(
self,
content: str | bytes,
suffix: str,
mime_type: str,
display_name: str,
) -> types.File:
"""Write content to a temp file, upload to Gemini, and clean up.

Args:
content: File content (text or binary)
suffix: Temp file suffix (e.g., ".jsonl", ".mp3")
mime_type: MIME type for the upload
display_name: Display name in Gemini

Returns:
Gemini File object
"""
# "w" for text (JSONL batch) or "wb" for binary (audio files for STT),
# since this method accepts content: str | bytes.
mode = "w" if isinstance(content, str) else "wb"
kwargs: dict[str, Any] = {"suffix": suffix, "delete": False, "mode": mode}
if mode == "w":
kwargs["encoding"] = "utf-8"

with tempfile.NamedTemporaryFile(**kwargs) as tmp_file:
tmp_file.write(content)
tmp_path = tmp_file.name

try:
return self._client.files.upload(
file=tmp_path,
config=types.UploadFileConfig(
display_name=display_name,
mime_type=mime_type,
),
)
finally:
os.unlink(tmp_path)

def upload_file(self, content: str, purpose: str = "batch") -> str:
"""Upload a JSONL file to Gemini Files API.

Expand All @@ -305,35 +344,98 @@ def upload_file(self, content: str, purpose: str = "batch") -> str:
logger.info(f"[upload_file] Uploading file to Gemini | bytes={len(content)}")

try:
with tempfile.NamedTemporaryFile(
suffix=".jsonl", delete=False, mode="w", encoding="utf-8"
) as tmp_file:
tmp_file.write(content)
tmp_path = tmp_file.name
uploaded_file = self._upload_to_gemini(
content=content,
suffix=".jsonl",
mime_type="jsonl",
display_name=f"batch-input-{int(time.time())}",
)

try:
uploaded_file = self._client.files.upload(
file=tmp_path,
config=types.UploadFileConfig(
display_name=f"batch-input-{int(time.time())}",
mime_type="jsonl",
),
)
logger.info(
f"[upload_file] Uploaded file to Gemini | "
f"file_name={uploaded_file.name}"
)

logger.info(
f"[upload_file] Uploaded file to Gemini | "
f"file_name={uploaded_file.name}"
)
return uploaded_file.name

except Exception as e:
logger.error(f"[upload_file] Failed to upload file to Gemini | {e}")
raise

return uploaded_file.name
def upload_audio_file(
self,
content: bytes,
mime_type: str,
display_name: str | None = None,
) -> tuple[str, str]:
"""Upload an audio file to Gemini File API.

finally:
os.unlink(tmp_path)
Args:
content: Raw audio file bytes
mime_type: MIME type of the audio (e.g., 'audio/mpeg')
display_name: Optional display name for the file

Returns:
Tuple of (file_name, file_uri):
- file_name: Short name for API calls (e.g., "files/xxx")
- file_uri: Full URI for use in batch requests
(e.g., "https://generativelanguage.googleapis.com/v1beta/files/xxx")
"""
display_name = display_name or f"stt-audio-{int(time.time())}"
logger.info(
f"[upload_audio_file] Uploading audio to Gemini | "
f"bytes={len(content)} | mime_type={mime_type} | display_name={display_name}"
)

try:
uploaded_file = self._upload_to_gemini(
content=content,
suffix=mimetypes.guess_extension(mime_type) or ".bin",
mime_type=mime_type,
display_name=display_name,
)

logger.info(
f"[upload_audio_file] Uploaded audio to Gemini | "
f"file_name={uploaded_file.name} | file_uri={uploaded_file.uri}"
)

return uploaded_file.name, uploaded_file.uri

except Exception as e:
logger.error(f"[upload_file] Failed to upload file to Gemini | {e}")
logger.error(f"[upload_audio_file] Failed to upload audio to Gemini | {e}")
raise

def delete_files(self, file_names: list[str]) -> tuple[int, int]:
"""Delete files from Gemini File API.

Args:
file_names: List of Gemini file names to delete (e.g., ["files/xxx", ...])

Returns:
Tuple of (success_count, failure_count)
"""
success_count = 0
failure_count = 0

for name in file_names:
try:
self._client.files.delete(name=name)
success_count += 1
except Exception as e:
failure_count += 1
logger.warning(
f"[delete_files] Failed to delete Gemini file | "
f"file_name={name} | error={e}"
)

logger.info(
f"[delete_files] Gemini file cleanup complete | "
f"deleted={success_count}, failed={failure_count}"
)

return success_count, failure_count

def download_file(self, file_id: str) -> str:
"""Download a file from Gemini Files API.

Expand Down Expand Up @@ -387,18 +489,20 @@ def _extract_text_from_response(response: Any) -> str:


def create_stt_batch_requests(
signed_urls: list[str],
file_uris: list[str],
mime_types: list[str],
prompt: str,
keys: list[str] | None = None,
) -> list[dict[str, Any]]:
"""
Create batch API requests for Gemini STT using signed URLs.
Create batch API requests for Gemini STT using Gemini File API URIs.

This function generates request payloads in Gemini's JSONL batch format
using signed URLs directly. MIME types are automatically detected from the URL path.
using file URIs from the Gemini File API.

Args:
signed_urls: List of signed URLs pointing to audio files
file_uris: List of Gemini file URIs (e.g., "files/abc123")
mime_types: List of MIME types corresponding to each file URI
prompt: Transcription prompt/instructions for the model
keys: Optional list of custom IDs for tracking results. If not provided,
uses 0-indexed integers as strings.
Expand All @@ -408,25 +512,25 @@ def create_stt_batch_requests(
{"key": "sample-1", "request": {"contents": [...]}}

Example:
>>> urls = ["https://bucket.s3.amazonaws.com/audio.mp3?..."]
>>> uris = ["files/abc123"]
>>> mime_types = ["audio/mpeg"]
>>> prompt = "Transcribe this audio file."
>>> requests = create_stt_batch_requests(urls, prompt, keys=["sample-1"])
>>> requests = create_stt_batch_requests(uris, mime_types, prompt, keys=["sample-1"])
>>> provider.create_batch(requests, {"display_name": "stt-batch"})
"""
if keys is not None and len(keys) != len(signed_urls):
if len(file_uris) != len(mime_types):
raise ValueError(
f"Length of file_uris ({len(file_uris)}) must match mime_types ({len(mime_types)})"
)

if keys is not None and len(keys) != len(file_uris):
raise ValueError(
f"Length of keys ({len(keys)}) must match signed_urls ({len(signed_urls)})"
f"Length of keys ({len(keys)}) must match file_uris ({len(file_uris)})"
)

requests = []
for i, url in enumerate(signed_urls):
mime_type = get_mime_from_url(url)
if mime_type is None:
logger.warning(
f"[create_stt_batch_requests] Could not determine MIME type for URL | "
f"index={i} | defaulting to audio/mpeg"
)
mime_type = "audio/mpeg"
for i, uri in enumerate(file_uris):
mime_type = mime_types[i]

# Use provided key or generate from index
key = keys[i] if keys is not None else str(i)
Expand All @@ -439,7 +543,7 @@ def create_stt_batch_requests(
{
"parts": [
{"text": prompt},
{"file_data": {"mime_type": mime_type, "file_uri": url}},
{"file_data": {"mime_type": mime_type, "file_uri": uri}},
],
"role": "user",
}
Expand Down
Loading
Loading