Merged
66 changes: 48 additions & 18 deletions src/tower/_tables.py
@@ -2,6 +2,7 @@
from dataclasses import dataclass

from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.exceptions import CommitFailedException

TTable = TypeVar("TTable", bound="Table")

@@ -25,6 +26,9 @@
namespace_or_default,
)

import time
import random


@dataclass
class RowsAffectedInformation:
@@ -178,24 +182,38 @@ def insert(self, data: pa.Table) -> TTable:
self._stats.inserts += data.num_rows
return self

def upsert(self, data: pa.Table, join_cols: Optional[list[str]] = None) -> TTable:
def upsert(
self,
data: pa.Table,
join_cols: Optional[list[str]] = None,
max_retries: int = 5,
retry_delay_seconds: float = 0.5,
) -> TTable:
"""
Performs an upsert operation (update or insert) on the Iceberg table.
Performs an upsert operation (update or insert) on the Iceberg table. If a commit conflicts with a concurrent write, it reloads the table metadata and retries.

This method will:
- Update existing rows if they match the join columns
- Insert new rows if no match is found
- Retry up to max_retries times if a commit fails
All operations are case-sensitive by default.

Args:
data (pa.Table): The data to upsert into the table. The schema of this table
must match the schema of the target table.
join_cols (Optional[list[str]]): The columns that form the key to match rows on.
If not provided, all columns will be used for matching.
max_retries (int): Maximum number of retry attempts on commit conflicts.
Defaults to 5.
retry_delay_seconds (float): Wait time in seconds between retries.
Defaults to 0.5 seconds.

Returns:
TTable: The table instance with the upserted rows, allowing for method chaining.

Raises:
CommitFailedException: If all retry attempts are exhausted.

Note:
- The operation is always case-sensitive
- When a match is found, all columns are updated
@@ -217,22 +235,34 @@ def upsert(self, data: pa.Table, join_cols: Optional[list[str]] = None) -> TTable:
>>> print(f"Updated {stats.updates} rows")
>>> print(f"Inserted {stats.inserts} rows")
"""
res = self._table.upsert(
data,
join_cols=join_cols,
# All upserts will always be case sensitive. Perhaps we'll add this
# as a parameter in the future?
case_sensitive=True,
# These are the defaults, but we're including them to be complete.
when_matched_update_all=True,
when_not_matched_insert_all=True,
)

# Update the stats with the results of the relevant upsert.
self._stats.updates += res.rows_updated
self._stats.inserts += res.rows_inserted

return self
last_exception = None

for attempt in range(max_retries + 1):
try:
if attempt > 0:
self._table.refresh()

res = self._table.upsert(
data,
join_cols=join_cols,
# All upserts will always be case sensitive. Perhaps we'll add this
# as a parameter in the future?
case_sensitive=True,
# These are the defaults, but we're including them to be complete.
when_matched_update_all=True,
when_not_matched_insert_all=True,
)

self._stats.updates += res.rows_updated
self._stats.inserts += res.rows_inserted
return self

except CommitFailedException as e:
last_exception = e
if attempt < max_retries:
time.sleep(retry_delay_seconds)

raise last_exception

def delete(self, filters: Union[str, List[pc.Expression]]) -> TTable:
"""
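A note on the retry pacing: the loop above sleeps for a fixed retry_delay_seconds between attempts, and the import hunk adds random without the visible diff ever using it. If contention turns out to be heavy, a common refinement is full-jitter exponential backoff. The sketch below illustrates that variant; it is not part of this change, and backoff_delay is a hypothetical helper.

import random

def backoff_delay(attempt: int, base: float = 0.5, cap: float = 8.0) -> float:
    # Full jitter: pick a random delay in [0, min(cap, base * 2**attempt)]
    # so writers that collide do not retry in lockstep.
    return random.uniform(0.0, min(cap, base * (2**attempt)))

With this helper, the time.sleep(retry_delay_seconds) call in the except CommitFailedException branch would become time.sleep(backoff_delay(attempt)).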
155 changes: 155 additions & 0 deletions tests/tower/test_tables.py
@@ -5,11 +5,15 @@
import pathlib
from urllib.parse import urljoin
from urllib.request import pathname2url
import threading

# We import all the things we need from Tower.
import tower.polars as pl
import pyarrow as pa
from pyiceberg.catalog.memory import InMemoryCatalog
from pyiceberg.catalog.sql import SqlCatalog

import concurrent.futures

# Imports the library under test
import tower
@@ -42,6 +46,28 @@ def in_memory_catalog():
pass


@pytest.fixture
def sql_catalog():
temp_dir = tempfile.mkdtemp()  # string path; no automatic cleanup
abs_path = pathlib.Path(temp_dir).absolute()
file_url = urljoin("file:", pathname2url(str(abs_path)))

catalog = SqlCatalog(
"test.sql.catalog",
**{
"uri": f"sqlite:///{abs_path}/catalog.db?check_same_thread=False",
"warehouse": file_url,
},
)

yield catalog

try:
shutil.rmtree(abs_path)
except FileNotFoundError:
pass
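
An aside on the fixture shape: tempfile.mkdtemp() plus a manual shutil.rmtree() works, but a context-managed variant also cleans up when SqlCatalog construction raises before the yield. A sketch with the same imports and behavior:

@pytest.fixture
def sql_catalog():
    # TemporaryDirectory removes the directory on exit, even if catalog
    # construction fails before the test body runs.
    with tempfile.TemporaryDirectory() as temp_dir:
        abs_path = pathlib.Path(temp_dir).absolute()
        file_url = urljoin("file:", pathname2url(str(abs_path)))
        yield SqlCatalog(
            "test.sql.catalog",
            **{
                "uri": f"sqlite:///{abs_path}/catalog.db?check_same_thread=False",
                "warehouse": file_url,
            },
        )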


def test_reading_and_writing_to_tables(in_memory_catalog):
schema = pa.schema(
[
@@ -166,6 +192,135 @@ def test_upsert_to_tables(in_memory_catalog):
assert res["age"].item() == 26


def test_upsert_concurrent_writes_with_retry(sql_catalog):
"""Test that concurrent upserts succeed with retry logic handling conflicts."""
schema = pa.schema(
[
pa.field("ticker", pa.string()),
pa.field("date", pa.string()),
pa.field("price", pa.float64()),
]
)

ref = tower.tables("concurrent_test", catalog=sql_catalog)
table = ref.create_if_not_exists(schema)

initial_data = pa.Table.from_pylist(
[
{"ticker": "AAPL", "date": "2024-01-01", "price": 100.0},
{"ticker": "GOOGL", "date": "2024-01-01", "price": 200.0},
{"ticker": "MSFT", "date": "2024-01-01", "price": 300.0},
],
schema=schema,
)
table.insert(initial_data)

retry_count = {"value": 0}
retry_lock = threading.Lock()

def upsert_ticker(ticker: str, new_price: float):
t = tower.tables("concurrent_test", catalog=sql_catalog).load()

original_refresh = t._table.refresh

def tracked_refresh():
with retry_lock:
retry_count["value"] += 1
return original_refresh()

t._table.refresh = tracked_refresh

data = pa.Table.from_pylist(
[{"ticker": ticker, "date": "2024-01-01", "price": new_price}],
schema=schema,
)
t.upsert(data, join_cols=["ticker", "date"])
return ticker

with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
futures = [
executor.submit(upsert_ticker, "AAPL", 150.0),
executor.submit(upsert_ticker, "GOOGL", 250.0),
executor.submit(upsert_ticker, "MSFT", 350.0),
]
results = [f.result() for f in concurrent.futures.as_completed(futures)]

assert len(results) == 3
assert (
retry_count["value"] > 0
), "Expected at least one retry due to concurrent conflicts"

final_table = tower.tables("concurrent_test", catalog=sql_catalog).load()
df = final_table.read()

assert len(df) == 3

ticker_prices = {row["ticker"]: row["price"] for row in df.iter_rows(named=True)}

assert ticker_prices["AAPL"] == 150.0
assert ticker_prices["GOOGL"] == 250.0
assert ticker_prices["MSFT"] == 350.0


def test_upsert_concurrent_writes_same_row(sql_catalog):
"""Test concurrent upserts to the SAME row - last write wins."""
schema = pa.schema(
[
pa.field("id", pa.int64()),
pa.field("counter", pa.int64()),
]
)

ref = tower.tables("concurrent_same_row_test", catalog=sql_catalog)
table = ref.create_if_not_exists(schema)

initial_data = pa.Table.from_pylist(
[{"id": 1, "counter": 0}],
schema=schema,
)
table.insert(initial_data)

retry_count = {"value": 0}
retry_lock = threading.Lock()

def upsert_counter(value: int):
t = tower.tables("concurrent_same_row_test", catalog=sql_catalog).load()

original_refresh = t._table.refresh

def tracked_refresh():
with retry_lock:
retry_count["value"] += 1
return original_refresh()

t._table.refresh = tracked_refresh

data = pa.Table.from_pylist(
[{"id": 1, "counter": value}],
schema=schema,
)
t.upsert(data, join_cols=["id"])
return value

with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
futures = [executor.submit(upsert_counter, i) for i in range(1, 6)]
results = [f.result() for f in concurrent.futures.as_completed(futures)]

assert len(results) == 5

assert (
retry_count["value"] > 0
), "Expected at least one retry due to concurrent conflicts"

final_table = tower.tables("concurrent_same_row_test", catalog=sql_catalog).load()
df = final_table.read()

assert len(df) == 1

final_counter = df.select("counter").item()
assert final_counter in [1, 2, 3, 4, 5]


def test_delete_from_tables(in_memory_catalog):
schema = pa.schema(
[