diff --git a/src/tower/_tables.py b/src/tower/_tables.py
index 451ebed6..87b34f14 100644
--- a/src/tower/_tables.py
+++ b/src/tower/_tables.py
@@ -2,6 +2,7 @@
 from dataclasses import dataclass
 
 from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.exceptions import CommitFailedException
 
 TTable = TypeVar("TTable", bound="Table")
@@ -25,6 +26,8 @@
    namespace_or_default,
 )
+import time
+
 
 @dataclass
 class RowsAffectedInformation:
@@ -178,13 +182,21 @@ def insert(self, data: pa.Table) -> TTable:
         self._stats.inserts += data.num_rows
         return self
 
-    def upsert(self, data: pa.Table, join_cols: Optional[list[str]] = None) -> TTable:
+    def upsert(
+        self,
+        data: pa.Table,
+        join_cols: Optional[list[str]] = None,
+        max_retries: int = 5,
+        retry_delay_seconds: float = 0.5,
+    ) -> TTable:
         """
-        Performs an upsert operation (update or insert) on the Iceberg table.
+        Performs an upsert operation (update or insert) on the Iceberg table,
+        reloading the metadata and retrying when a commit conflict occurs.
 
         This method will:
         - Update existing rows if they match the join columns
         - Insert new rows if no match is found
+        - Retry up to max_retries times if the commit fails with a conflict
 
         All operations are case-sensitive by default.
 
         Args:
@@ -192,10 +203,17 @@ def upsert(self, data: pa.Table, join_cols: Optional[list[str]] = None) -> TTable:
                 must match the schema of the target table.
             join_cols (Optional[list[str]]): The columns that form the key to match
                 rows on. If not provided, all columns will be used for matching.
+            max_retries (int): Maximum number of retry attempts on commit conflicts.
+                Defaults to 5.
+            retry_delay_seconds (float): Wait time in seconds between retries.
+                Defaults to 0.5 seconds.
 
         Returns:
             TTable: The table instance with the upserted rows, allowing for
                 method chaining.
 
+        Raises:
+            CommitFailedException: If all retry attempts are exhausted.
+
         Note:
         - The operation is always case-sensitive
         - When a match is found, all columns are updated
@@ -217,22 +235,36 @@ def upsert(self, data: pa.Table, join_cols: Optional[list[str]] = None) -> TTable:
         >>> print(f"Updated {stats.updates} rows")
         >>> print(f"Inserted {stats.inserts} rows")
         """
-        res = self._table.upsert(
-            data,
-            join_cols=join_cols,
-            # All upserts will always be case sensitive. Perhaps we'll add this
-            # as a parameter in the future?
-            case_sensitive=True,
-            # These are the defaults, but we're including them to be complete.
-            when_matched_update_all=True,
-            when_not_matched_insert_all=True,
-        )
-
-        # Update the stats with the results of the relevant upsert.
-        self._stats.updates += res.rows_updated
-        self._stats.inserts += res.rows_inserted
-
-        return self
+        last_exception = None
+
+        for attempt in range(max_retries + 1):
+            try:
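+                # After a failed commit, refresh the table metadata so the
+                # next attempt runs against the latest snapshot.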
+                if attempt > 0:
+                    self._table.refresh()
+
+                res = self._table.upsert(
+                    data,
+                    join_cols=join_cols,
+                    # All upserts will always be case sensitive. Perhaps we'll add this
+                    # as a parameter in the future?
+                    case_sensitive=True,
+                    # These are the defaults, but we're including them to be complete.
+                    when_matched_update_all=True,
+                    when_not_matched_insert_all=True,
+                )
+
+                self._stats.updates += res.rows_updated
+                self._stats.inserts += res.rows_inserted
+                return self
+
+            except CommitFailedException as e:
+                last_exception = e
+                if attempt < max_retries:
+                    time.sleep(retry_delay_seconds)
+
+        raise last_exception
 
     def delete(self, filters: Union[str, List[pc.Expression]]) -> TTable:
         """
diff --git a/tests/tower/test_tables.py b/tests/tower/test_tables.py
index 6aa36f2c..1203b669 100644
--- a/tests/tower/test_tables.py
+++ b/tests/tower/test_tables.py
@@ -5,11 +5,14 @@
 import pathlib
 from urllib.parse import urljoin
 from urllib.request import pathname2url
+import concurrent.futures
+import threading
 
 # We import all the things we need from Tower.
 import tower.polars as pl
 import pyarrow as pa
 from pyiceberg.catalog.memory import InMemoryCatalog
+from pyiceberg.catalog.sql import SqlCatalog
 
 # Imports the library under test
 import tower
@@ -42,6 +46,28 @@ def in_memory_catalog():
         pass
 
 
+@pytest.fixture
+def sql_catalog():
+    temp_dir = tempfile.mkdtemp()  # returns a path string; no automatic cleanup
+    abs_path = pathlib.Path(temp_dir).absolute()
+    file_url = urljoin("file:", pathname2url(str(abs_path)))
+
+    catalog = SqlCatalog(
+        "test.sql.catalog",
+        **{
+            "uri": f"sqlite:///{abs_path}/catalog.db?check_same_thread=False",
+            "warehouse": file_url,
+        },
+    )
+
+    yield catalog
+
+    try:
+        shutil.rmtree(abs_path)
+    except FileNotFoundError:
+        pass
+
+
 def test_reading_and_writing_to_tables(in_memory_catalog):
     schema = pa.schema(
         [
@@ -166,6 +192,137 @@ def test_upsert_to_tables(in_memory_catalog):
     assert res["age"].item() == 26
 
 
+def test_upsert_concurrent_writes_with_retry(sql_catalog):
+    """Test that concurrent upserts succeed with retry logic handling conflicts."""
+    schema = pa.schema(
+        [
+            pa.field("ticker", pa.string()),
+            pa.field("date", pa.string()),
+            pa.field("price", pa.float64()),
+        ]
+    )
+
+    ref = tower.tables("concurrent_test", catalog=sql_catalog)
+    table = ref.create_if_not_exists(schema)
+
+    initial_data = pa.Table.from_pylist(
+        [
+            {"ticker": "AAPL", "date": "2024-01-01", "price": 100.0},
+            {"ticker": "GOOGL", "date": "2024-01-01", "price": 200.0},
+            {"ticker": "MSFT", "date": "2024-01-01", "price": 300.0},
+        ],
+        schema=schema,
+    )
+    table.insert(initial_data)
+
+    retry_count = {"value": 0}
+    retry_lock = threading.Lock()
+
+    def upsert_ticker(ticker: str, new_price: float):
+        t = tower.tables("concurrent_test", catalog=sql_catalog).load()
+
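+        # Wrap Table.refresh so the test can count how many times the retry
+        # path had to reload metadata after a commit conflict.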
+        original_refresh = t._table.refresh
+
+        def tracked_refresh():
+            with retry_lock:
+                retry_count["value"] += 1
+            return original_refresh()
+
+        t._table.refresh = tracked_refresh
+
+        data = pa.Table.from_pylist(
+            [{"ticker": ticker, "date": "2024-01-01", "price": new_price}],
+            schema=schema,
+        )
+        t.upsert(data, join_cols=["ticker", "date"])
+        return ticker
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
+        futures = [
+            executor.submit(upsert_ticker, "AAPL", 150.0),
+            executor.submit(upsert_ticker, "GOOGL", 250.0),
+            executor.submit(upsert_ticker, "MSFT", 350.0),
+        ]
+        results = [f.result() for f in concurrent.futures.as_completed(futures)]
+
+    assert len(results) == 3
+    assert (
+        retry_count["value"] > 0
+    ), "Expected at least one retry due to concurrent conflicts"
+
+    final_table = tower.tables("concurrent_test", catalog=sql_catalog).load()
+    df = final_table.read()
+
+    assert len(df) == 3
+
+    ticker_prices = {row["ticker"]: row["price"] for row in df.iter_rows(named=True)}
+
+    assert ticker_prices["AAPL"] == 150.0
+    assert ticker_prices["GOOGL"] == 250.0
+    assert ticker_prices["MSFT"] == 350.0
+
+
+def test_upsert_concurrent_writes_same_row(sql_catalog):
+    """Test concurrent upserts to the same row; the last write wins."""
+    schema = pa.schema(
+        [
+            pa.field("id", pa.int64()),
+            pa.field("counter", pa.int64()),
+        ]
+    )
+
+    ref = tower.tables("concurrent_same_row_test", catalog=sql_catalog)
+    table = ref.create_if_not_exists(schema)
+
+    initial_data = pa.Table.from_pylist(
+        [{"id": 1, "counter": 0}],
+        schema=schema,
+    )
+    table.insert(initial_data)
+
+    retry_count = {"value": 0}
+    retry_lock = threading.Lock()
+
+    def upsert_counter(value: int):
+        t = tower.tables("concurrent_same_row_test", catalog=sql_catalog).load()
+
+        original_refresh = t._table.refresh
+
+        def tracked_refresh():
+            with retry_lock:
+                retry_count["value"] += 1
+            return original_refresh()
+
+        t._table.refresh = tracked_refresh
+
+        data = pa.Table.from_pylist(
+            [{"id": 1, "counter": value}],
+            schema=schema,
+        )
+        t.upsert(data, join_cols=["id"])
+        return value
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
+        futures = [executor.submit(upsert_counter, i) for i in range(1, 6)]
+        results = [f.result() for f in concurrent.futures.as_completed(futures)]
+
+    assert len(results) == 5
+
+    assert (
+        retry_count["value"] > 0
+    ), "Expected at least one retry due to concurrent conflicts"
+
+    final_table = tower.tables("concurrent_same_row_test", catalog=sql_catalog).load()
+    df = final_table.read()
+
+    assert len(df) == 1
+
+    final_counter = df.select("counter").item()
+    assert final_counter in [1, 2, 3, 4, 5]
+
+
 def test_delete_from_tables(in_memory_catalog):
     schema = pa.schema(
         [