From afb3ffe4b97e744aea87c955736758642d6afb39 Mon Sep 17 00:00:00 2001 From: aIbrahiim Date: Sat, 11 Apr 2026 22:24:34 +0200 Subject: [PATCH] Fix PyTorch sentiment Dataflow inference device and CLI args Parse --device and --input_file in the example, stop injecting --device into Beam pipeline args, and default inference to CPU. Drop T4 worker_accelerator from the streaming load-test options so CPU jobs do not request GPUs. Made-with: Cursor --- ...entiment_Streaming_DistilBert_Base_Uncased.txt | 1 - .../examples/inference/pytorch_sentiment.py | 15 +++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Sentiment_Streaming_DistilBert_Base_Uncased.txt b/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Sentiment_Streaming_DistilBert_Base_Uncased.txt index d10b9bb2dfcb..f3beb528f881 100644 --- a/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Sentiment_Streaming_DistilBert_Base_Uncased.txt +++ b/.github/workflows/load-tests-pipeline-options/beam_Inference_Python_Benchmarks_Dataflow_Pytorch_Sentiment_Streaming_DistilBert_Base_Uncased.txt @@ -31,6 +31,5 @@ --device=CPU --input_file=gs://apache-beam-ml/testing/inputs/sentences_50k.txt --runner=DataflowRunner ---dataflow_service_options=worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver --model_path=distilbert-base-uncased-finetuned-sst-2-english --model_state_dict_path=gs://apache-beam-ml/models/huggingface.sentiment.distilbert-base-uncased.pth diff --git a/sdks/python/apache_beam/examples/inference/pytorch_sentiment.py b/sdks/python/apache_beam/examples/inference/pytorch_sentiment.py index 3bb36930a045..fe876f1c85b9 100644 --- a/sdks/python/apache_beam/examples/inference/pytorch_sentiment.py +++ b/sdks/python/apache_beam/examples/inference/pytorch_sentiment.py @@ -98,7 +98,11 @@ def parse_known_args(argv): required=True, help="Path to the model's state_dict.") parser.add_argument( - '--input', required=True, help='Path to input file on GCS') + '--input', + '--input_file', + dest='input', + required=True, + help='Path to input file on GCS (load tests pass --input_file).') parser.add_argument( '--pubsub_topic', default='projects/apache-beam-testing/topics/test_sentiment_topic', @@ -117,6 +121,11 @@ def parse_known_args(argv): type=float, default=None, help='Elements per second to send to Pub/Sub') + parser.add_argument( + '--device', + default='CPU', + choices=['CPU', 'GPU'], + help='Device to use for inference. Choices are CPU or GPU.') return parser.parse_known_args(argv) @@ -183,8 +192,6 @@ def override_or_add(args, flag, value): def run_load_pipeline(known_args, pipeline_args): """Load data pipeline: read lines from GCS file and send to Pub/Sub.""" - - override_or_add(pipeline_args, '--device', 'CPU') override_or_add(pipeline_args, '--num_workers', '5') override_or_add(pipeline_args, '--max_num_workers', '10') override_or_add( @@ -238,7 +245,7 @@ def run( model_class=DistilBertForSequenceClassification, model_params={'config': DistilBertConfig(num_labels=2)}, state_dict_path=known_args.model_state_dict_path, - device='GPU') + device=known_args.device) tokenizer = DistilBertTokenizerFast.from_pretrained(known_args.model_path)