Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
--machine_type=n1-standard-2
--num_workers=20
--max_num_workers=250
--timeout_ms=600000
--disk_size_gb=50
--autoscaling_algorithm=THROUGHPUT_BASED
--staging_location=gs://temp-storage-for-perf-tests/loadtests
Expand All @@ -31,5 +32,7 @@
--device=CPU
--input_file=gs://apache-beam-ml/testing/inputs/sentences_50k.txt
--runner=DataflowRunner
--sdk_location=container
--sdk_container_image=us.gcr.io/apache-beam-testing/python-postcommit-it/tensor_rt@sha256:884d67e96d9a3c22fb21fcd412c10a012d4c82a7c723f1c1ffe41fca609b5a6a
--model_path=distilbert-base-uncased-finetuned-sst-2-english
--model_state_dict_path=gs://apache-beam-ml/models/huggingface.sentiment.distilbert-base-uncased.pth
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
--machine_type=n1-standard-2
--num_workers=20
--max_num_workers=250
--timeout_ms=600000
--disk_size_gb=50
--autoscaling_algorithm=THROUGHPUT_BASED
--staging_location=gs://temp-storage-for-perf-tests/loadtests
Expand All @@ -31,6 +32,8 @@
--device=CPU
--input_file=gs://apache-beam-ml/testing/inputs/sentences_50k.txt
--runner=DataflowRunner
--sdk_location=container
--sdk_container_image=us.gcr.io/apache-beam-testing/python-postcommit-it/tensor_rt@sha256:884d67e96d9a3c22fb21fcd412c10a012d4c82a7c723f1c1ffe41fca609b5a6a
--dataflow_service_options=worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver
--model_path=distilbert-base-uncased-finetuned-sst-2-english
--model_state_dict_path=gs://apache-beam-ml/models/huggingface.sentiment.distilbert-base-uncased.pth
80 changes: 61 additions & 19 deletions sdks/python/apache_beam/examples/inference/pytorch_sentiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@

class SentimentPostProcessor(beam.DoFn):
"""Processes PredictionResult to extract sentiment label and confidence."""
def __init__(self, tokenizer: DistilBertTokenizerFast):
self.tokenizer = tokenizer

def process(self, element: tuple[str, PredictionResult]) -> Iterable[dict]:
text, prediction_result = element
logits = prediction_result.inference['logits']
Expand All @@ -62,16 +59,34 @@ def process(self, element: tuple[str, PredictionResult]) -> Iterable[dict]:
}


def tokenize_text(text: str,
tokenizer: DistilBertTokenizerFast) -> tuple[str, dict]:
"""Tokenizes input text using the specified tokenizer."""
tokenized = tokenizer(
text,
padding='max_length',
truncation=True,
max_length=128,
return_tensors="pt")
return text, {k: torch.squeeze(v) for k, v in tokenized.items()}
class TokenizeTextDoFn(beam.DoFn):
  """Tokenizes input strings into keyed tensors for RunInference.

  The tokenizer is built lazily in ``setup`` so every worker constructs its
  own instance from ``model_path`` instead of deserializing one shipped from
  the launcher environment.
  """
  def __init__(self, model_path: str):
    self.model_path = model_path
    self.tokenizer = None

  def setup(self):
    # Runs once per DoFn instance on the worker, before any bundle is
    # processed — the right place for non-picklable, heavyweight state.
    self.tokenizer = DistilBertTokenizerFast.from_pretrained(self.model_path)
    if self.tokenizer.pad_token is None:
      self.tokenizer.pad_token = '[PAD]'

  def process(self, text: str) -> Iterable[tuple[str, dict]]:
    encoded = self.tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors="pt")
    # Drop the leading batch dimension from each tensor so RunInference can
    # re-batch the keyed tensors itself.
    squeezed = {
        name: torch.squeeze(tensor, 0)
        for name, tensor in encoded.items()
    }
    yield text, squeezed


class DistilBertForSequenceClassificationCompat(
    DistilBertForSequenceClassification):
  """Builds its config in the worker runtime to avoid cross-env config drift.

  Constructing the ``DistilBertConfig`` here — rather than serializing one
  from the launcher — keeps the config consistent with whichever
  transformers version is actually installed on the worker.
  """
  def __init__(self, model_name: str, num_labels: int = 2):
    base_config = DistilBertConfig.from_pretrained(
        model_name, num_labels=num_labels)
    super().__init__(_ensure_transformers_config_compat(base_config))


class RateLimitDoFn(beam.DoFn):
Expand All @@ -83,6 +98,31 @@ def process(self, element):
yield element


def _ensure_transformers_config_compat(
    config: DistilBertConfig) -> DistilBertConfig:
  """Backfills config attributes that other transformers versions expect.

  The benchmark can run with container images whose transformers version
  differs from the launcher environment, and some versions read attributes
  that the other side's config object never set. Rather than patching one
  missing field at a time (e.g. ``torchscript``, ``output_attentions``),
  any attribute present on a freshly built default config — i.e. canonical
  for the worker's transformers version — is copied onto ``config`` when
  it is absent.

  Args:
    config: The config to patch; it is mutated in place.

  Returns:
    The same ``config`` object, with missing attributes filled in.
  """
  canonical_defaults = DistilBertConfig().to_dict()
  for attr_name, default_value in canonical_defaults.items():
    if not hasattr(config, attr_name):
      setattr(config, attr_name, default_value)

  # Certain fields are excluded from serialization by some transformers
  # releases; ensure they exist regardless of which mix of versions is used.
  for attr_name, fallback in (
      ('pruned_heads', {}),
      ('torchscript', False),
      ('return_dict', True),
  ):
    if not hasattr(config, attr_name):
      setattr(config, attr_name, fallback)
  return config


def parse_known_args(argv):
"""Parses command-line arguments for pipeline execution."""
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -235,13 +275,14 @@ def run(
pipeline_options.view_as(StandardOptions).streaming = True

model_handler = PytorchModelHandlerKeyedTensor(
model_class=DistilBertForSequenceClassification,
model_params={'config': DistilBertConfig(num_labels=2)},
model_class=DistilBertForSequenceClassificationCompat,
model_params={
'model_name': known_args.model_path,
'num_labels': 2,
},
state_dict_path=known_args.model_state_dict_path,
device='GPU')

tokenizer = DistilBertTokenizerFast.from_pretrained(known_args.model_path)

pipeline = test_pipeline or beam.Pipeline(options=pipeline_options)

# Main pipeline: read, process, write result to BigQuery output table
Expand All @@ -264,9 +305,9 @@ def run(

_ = (
input
| 'Tokenize' >> beam.Map(lambda text: tokenize_text(text, tokenizer))
| 'Tokenize' >> beam.ParDo(TokenizeTextDoFn(known_args.model_path))
| 'RunInference' >> RunInference(KeyedModelHandler(model_handler))
| 'PostProcess' >> beam.ParDo(SentimentPostProcessor(tokenizer))
| 'PostProcess' >> beam.ParDo(SentimentPostProcessor())
| 'WriteToBigQuery' >> beam.io.WriteToBigQuery(
known_args.output_table,
schema='text:STRING, sentiment:STRING, confidence:FLOAT',
Expand All @@ -277,6 +318,7 @@ def run(
result = pipeline.run()
result.wait_until_finish(duration=1800000) # 30 min
result.cancel()
result.wait_until_finish(duration=600000) # up to 10 min to settle cancel

cleanup_pubsub_resources(
project=known_args.project,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class TableRowInferenceOptions(
@classmethod
def _add_argparse_args(cls, parser):
parser.add_argument('--mode', default='batch')
parser.add_argument('--input_subscription')
parser.add_argument('--input_file')
parser.add_argument('--input_subscription', default='')
parser.add_argument('--input_file', default='')
parser.add_argument('--output_table')
parser.add_argument('--model_path')
parser.add_argument('--feature_columns')
Expand Down
Loading