2 changes: 1 addition & 1 deletion paimon-python/dev/lint-python.sh
@@ -107,7 +107,7 @@ function collect_checks() {
function get_all_supported_checks() {
_OLD_IFS=$IFS
IFS=$'\n'
SUPPORT_CHECKS=("flake8_check" "pytest_torch_check" "pytest_check" "mixed_check") # control the calling sequence
SUPPORT_CHECKS=("flake8_check" "pytest_check" "pytest_torch_check" "mixed_check") # control the calling sequence
for fun in $(declare -F); do
if [[ `regexp_match "$fun" "_check$"` = true ]]; then
check_name="${fun:11}"
5 changes: 0 additions & 5 deletions paimon-python/pypaimon/read/reader/field_bunch.py
@@ -82,11 +82,6 @@ def add(self, file: DataFileMeta) -> None:
"Blob file with overlapping row id should have decreasing sequence number."
)
return
elif first_row_id > self.expected_next_first_row_id:
raise ValueError(
f"Blob file first row id should be continuous, expect "
f"{self.expected_next_first_row_id} but got {first_row_id}"
)

if file.schema_id != self._files[0].schema_id:
raise ValueError(
99 changes: 99 additions & 0 deletions paimon-python/pypaimon/read/reader/sample_batch_reader.py
@@ -0,0 +1,99 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional

from pyarrow import RecordBatch

from pypaimon.read.reader.format_blob_reader import FormatBlobReader
from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader


class SampleBatchReader(RecordBatchReader):
"""
A reader that reads a subset of rows from a data file based on specified sample positions.

This reader wraps another RecordBatchReader and only returns rows at the specified
sample positions, enabling efficient random sampling of data without reading all rows.

The reader supports two modes:
1. For blob readers: Directly reads specific rows by index
2. For other readers: Reads batches sequentially and extracts only the sampled rows

Attributes:
reader: The underlying RecordBatchReader to read data from
sample_positions: A sorted list of row indices to sample (0-based)
sample_idx: Current index in the sample_positions list
current_pos: Current absolute row position in the data file
"""

def __init__(self, reader, sample_positions):
"""
Initialize the SampleBatchReader.

Args:
reader: The underlying RecordBatchReader to read data from
sample_positions: A sorted collection of row indices to sample (0-based),
e.g. a list or a pyroaring BitMap; positions must be in ascending order for correct behavior.
"""
self.reader = reader
self.sample_positions = sample_positions
self.sample_idx = 0
self.current_pos = 0

def read_arrow_batch(self) -> Optional[RecordBatch]:
"""
Read the next batch containing sampled rows.

This method reads data from the underlying reader and returns only the rows
at the specified sample positions. The behavior differs based on reader type:

- For FormatBlobReader: Directly reads individual rows by index
- For other readers: Reads batches sequentially and extracts sampled rows
using PyArrow's take() method
"""
if self.sample_idx >= len(self.sample_positions):
return None
if isinstance(self.reader.format_reader, FormatBlobReader):
# For blob reader, pass start_idx and end_idx parameters
batch = self.reader.read_arrow_batch(start_idx=self.sample_positions[self.sample_idx],
end_idx=self.sample_positions[self.sample_idx] + 1)
self.sample_idx += 1
return batch
else:
while True:
batch = self.reader.read_arrow_batch()
if batch is None:
return None

batch_begin = self.current_pos
self.current_pos += batch.num_rows
take_idxes = []

sample_pos = self.sample_positions[self.sample_idx]
while batch_begin <= sample_pos < self.current_pos:
take_idxes.append(sample_pos - batch_begin)
self.sample_idx += 1
if self.sample_idx >= len(self.sample_positions):
break
sample_pos = self.sample_positions[self.sample_idx]

if take_idxes:
return batch.take(take_idxes)
# batch is outside the desired range, continue to next batch

def close(self):
self.reader.close()
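
For orientation, a minimal usage sketch of the sequential-extraction path (the non-blob branch). The in-memory stub reader, its batch size, and the sample positions below are illustrative assumptions, not part of this PR:

import pyarrow as pa

from pypaimon.read.reader.sample_batch_reader import SampleBatchReader


class _InMemoryReader:
    """Hypothetical stand-in for the wrapped reader; format_reader=None selects the take() path."""

    def __init__(self, table: pa.Table, batch_size: int = 4):
        self._batches = table.to_batches(max_chunksize=batch_size)
        self._next = 0
        self.format_reader = None  # not a FormatBlobReader

    def read_arrow_batch(self):
        if self._next >= len(self._batches):
            return None
        batch = self._batches[self._next]
        self._next += 1
        return batch

    def close(self):
        pass


table = pa.table({"id": list(range(10))})
reader = SampleBatchReader(_InMemoryReader(table), sample_positions=[1, 5, 8])
while (batch := reader.read_arrow_batch()) is not None:
    print(batch.column(0).to_pylist())  # [1], then [5], then [8]
reader.close()
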
103 changes: 103 additions & 0 deletions paimon-python/pypaimon/read/sampled_split.py
@@ -0,0 +1,103 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, List

from pyroaring import BitMap

from pypaimon.read.split import Split


class SampledSplit(Split):
"""
A Split wrapper that contains sampled row indexes for each file.

This class wraps a data split and maintains a mapping from file names to
bitmaps of sampled row indexes. It is used for random sampling scenarios where
only specific rows from each file need to be read.

Attributes:
_data_split: The underlying data split being wrapped.
_sampled_file_idx_map: A dictionary mapping file names to bitmaps of
sampled row indexes within each file.
"""

def __init__(
self,
data_split: 'Split',
sampled_file_idx_map: Dict[str, BitMap]
):
self._data_split = data_split
self._sampled_file_idx_map = sampled_file_idx_map

def data_split(self) -> 'Split':
return self._data_split

def sampled_file_idx_map(self) -> Dict[str, BitMap]:
return self._sampled_file_idx_map

@property
def files(self) -> List['DataFileMeta']:
return self._data_split.files

@property
def partition(self) -> 'GenericRow':
return self._data_split.partition

@property
def bucket(self) -> int:
return self._data_split.bucket

@property
def row_count(self) -> int:
if not self._sampled_file_idx_map:
return self._data_split.row_count

total_rows = 0
for file in self._data_split.files:
positions = self._sampled_file_idx_map[file.file_name]
total_rows += len(positions)

return total_rows

@property
def file_paths(self):
return self._data_split.file_paths

@property
def file_size(self):
return self._data_split.file_size

@property
def raw_convertible(self):
return self._data_split.raw_convertible

@property
def data_deletion_files(self):
return self._data_split.data_deletion_files

def __eq__(self, other):
if not isinstance(other, SampledSplit):
return False
return (self._data_split == other._data_split and
self._sampled_file_idx_map == other._sampled_file_idx_map)

def __hash__(self):
return hash((id(self._data_split), tuple(sorted(self._sampled_file_idx_map.items()))))

def __repr__(self):
return (f"SampledSplit(data_split={self._data_split}, "
f"sampled_file_idx_map={self._sampled_file_idx_map})")
73 changes: 58 additions & 15 deletions paimon-python/pypaimon/read/scanner/append_table_split_generator.py
@@ -15,11 +15,15 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import random
from collections import defaultdict
from typing import List, Dict, Tuple

from pyroaring import BitMap

from pypaimon.manifest.schema.data_file_meta import DataFileMeta
from pypaimon.manifest.schema.manifest_entry import ManifestEntry
from pypaimon.read.sampled_split import SampledSplit
from pypaimon.read.scanner.split_generator import AbstractSplitGenerator
from pypaimon.read.split import Split
from pypaimon.read.sliced_split import SlicedSplit
@@ -41,13 +45,15 @@ def create_splits(self, file_entries: List[ManifestEntry]) -> List[Split]:
if self.start_pos_of_this_subtask is not None:
# shard data range: [plan_start_pos, plan_end_pos)
partitioned_files, plan_start_pos, plan_end_pos = \
self.__filter_by_slice(
self._filter_by_slice(
partitioned_files,
self.start_pos_of_this_subtask,
self.end_pos_of_this_subtask
)
elif self.idx_of_this_subtask is not None:
partitioned_files, plan_start_pos, plan_end_pos = self._filter_by_shard(partitioned_files)
elif self.sample_num_rows is not None:
partitioned_files, file_positions = self._filter_by_sample(partitioned_files)

def weight_func(f: DataFileMeta) -> int:
return max(f.file_size, self.open_file_cost)
@@ -68,6 +74,8 @@ def weight_func(f: DataFileMeta) -> int:

if self.start_pos_of_this_subtask is not None or self.idx_of_this_subtask is not None:
splits = self._wrap_to_sliced_splits(splits, plan_start_pos, plan_end_pos)
elif self.sample_num_rows is not None:
splits = self._wrap_to_sampled_splits(splits, file_positions)

return splits

@@ -76,12 +84,12 @@ def _wrap_to_sliced_splits(self, splits: List[Split], plan_start_pos: int, plan_
file_end_pos = 0 # end row position of current file in all splits data

for split in splits:
shard_file_idx_map = self.__compute_split_file_idx_map(
shard_file_idx_map = self._compute_split_shard_file_idx_map(
plan_start_pos, plan_end_pos, split, file_end_pos
)
file_end_pos = shard_file_idx_map[self.NEXT_POS_KEY]
del shard_file_idx_map[self.NEXT_POS_KEY]

if shard_file_idx_map:
sliced_splits.append(SlicedSplit(split, shard_file_idx_map))
else:
@@ -90,10 +98,21 @@ def __filter_by_slice(
return sliced_splits

@staticmethod
def __filter_by_slice(
partitioned_files: defaultdict,
start_pos: int,
end_pos: int
def _wrap_to_sampled_splits(splits: List[Split], file_positions: Dict[str, BitMap]) -> List[Split]:
# Set sample file positions for each split
sampled_splits = []
for split in splits:
sampled_file_idx_map = {}
for file in split.files:
sampled_file_idx_map[file.file_name] = file_positions[file.file_name]
sampled_splits.append(SampledSplit(split, sampled_file_idx_map))
return sampled_splits

@staticmethod
def _filter_by_slice(
partitioned_files: defaultdict,
start_pos: int,
end_pos: int
) -> tuple:
plan_start_pos = 0
plan_end_pos = 0
@@ -142,21 +161,45 @@ def _filter_by_shard(self, partitioned_files: defaultdict) -> tuple:
# Calculate shard range using shared helper
start_pos, end_pos = self._compute_shard_range(total_row)

return self.__filter_by_slice(partitioned_files, start_pos, end_pos)
return self._filter_by_slice(partitioned_files, start_pos, end_pos)

def _filter_by_sample(self, partitioned_files) -> Tuple[defaultdict, Dict[str, BitMap]]:
"""
Randomly sample sample_num_rows rows from partitioned_files:
1. Generate sample_num_rows distinct global row indexes at random.
2. Iterate through partitioned_files, find the file entries that contain those indexes,
add them to filtered_partitioned_files, and record each index (relative to its file) in file_positions.
"""
# Calculate total number of rows
total_rows = 0
for key, file_entries in partitioned_files.items():
for entry in file_entries:
total_rows += entry.file.row_count

# Generate random sample indexes
sample_indexes = sorted(random.sample(range(total_rows), self.sample_num_rows))

# Map each sample index to its corresponding file and local index
filtered_partitioned_files = defaultdict(list)
file_positions = {} # {file_name: BitMap of local_indexes}
self._compute_file_sample_idx_map(partitioned_files, filtered_partitioned_files,
file_positions,
sample_indexes, is_blob=False)
return filtered_partitioned_files, file_positions
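
_compute_file_sample_idx_map is defined elsewhere in this PR; conceptually it walks the files in order and converts each sorted global sample index into a per-file local index, roughly like the standalone sketch below (file names and row counts are illustrative, not taken from the PR):

from pyroaring import BitMap

def map_samples_to_files(files, sample_indexes):
    """files: list of (file_name, row_count); sample_indexes: globally sorted row indexes."""
    positions, offset, i = {}, 0, 0
    for name, row_count in files:
        local = BitMap()
        while i < len(sample_indexes) and sample_indexes[i] < offset + row_count:
            local.add(sample_indexes[i] - offset)
            i += 1
        if local:
            positions[name] = local
        offset += row_count
    return positions

# Two files of 100 rows each; global sample rows 5, 150 and 199 map to
# {'f0': BitMap([5]), 'f1': BitMap([50, 99])}
print(map_samples_to_files([("f0", 100), ("f1", 100)], [5, 150, 199]))
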

@staticmethod
def __compute_split_file_idx_map(
plan_start_pos: int,
plan_end_pos: int,
split: Split,
file_end_pos: int
def _compute_split_shard_file_idx_map(
plan_start_pos: int,
plan_end_pos: int,
split: Split,
file_end_pos: int
) -> Dict[str, Tuple[int, int]]:
"""
Compute file index map for a split, determining which rows to read from each file.

"""
shard_file_idx_map = {}

for file in split.files:
file_begin_pos = file_end_pos # Starting row position of current file in all data
file_end_pos += file.row_count # Update to row position after current file
Expand All @@ -165,7 +208,7 @@ def __compute_split_file_idx_map(
file_range = AppendTableSplitGenerator._compute_file_range(
plan_start_pos, plan_end_pos, file_begin_pos, file.row_count
)

if file_range is not None:
shard_file_idx_map[file.file_name] = file_range
