Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pipelines:
- type: RunInference
config:
model_handler:
type: "HuggingFacePipeline"
type: "HuggingFacePipelineModelHandler"
config:
task: "text-classification"
inference_fn:
Expand Down
52 changes: 48 additions & 4 deletions sdks/python/apache_beam/yaml/yaml_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,8 @@ def inference_output_type(self):
('model_id', Optional[str])])


@ModelHandlerProvider.register_handler_type('HuggingFacePipeline')
class HuggingFacePipelineProvider(ModelHandlerProvider):
@ModelHandlerProvider.register_handler_type('HuggingFacePipelineModelHandler')
class HuggingFacePipelineModelHandlerProvider(ModelHandlerProvider):
def __init__(
self,
task: Optional[str] = None,
Expand All @@ -294,6 +294,49 @@ def __init__(
inference_fn: Optional[dict[str, str]] = None,
load_pipeline_args: Optional[dict[str, Any]] = None,
**kwargs):
"""
ModelHandler for Hugging Face Pipelines.

This Model Handler can be used with RunInference to load a model using
Hugging Face pipelines. Hugging Face pipelines provide a simple way to
perform inference on various tasks (e.g. text classification, token
classification, text generation).

This Model Handler requires either a `task` or `model` to be specified.
Preprocessing and Postprocessing are described in more detail in the
RunInference docs:
https://beam.apache.org/releases/yamldoc/current/#runinference

For example: ::

- type: RunInference
config:
model_handler:
type: HuggingFacePipelineModelHandler
config:
task: text-classification
model: distilbert-base-uncased-finetuned-sst-2-english
preprocess:
callable: 'lambda x: x.text'

Args:
task: The task for the pipeline. See Hugging Face documentation for
a list of supported tasks.
model: The model name on Hugging Face hub or a path to a local directory.
If the model already defines the task, no need to specify the task.
preprocess: A python callable, defined either inline, or using a file,
that is invoked on the input row before sending to the model to be
loaded by this ModelHandler.
postprocess: A python callable, defined either inline, or using a file,
that is invoked on the PredictionResult output by the ModelHandler
before parsing into the output Beam Row.
device: The device to run the pipeline on (e.g., 'cpu', 'cuda', 'cuda:0').
Defaults to CPU.
inference_fn: The custom inference function to use.
load_pipeline_args: Extra arguments to pass to the Hugging Face pipeline
loader (e.g. `transformers.pipeline`).
**kwargs: Extra arguments to pass to the model handler.
"""
try:
from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler
except ImportError:
Expand Down Expand Up @@ -324,7 +367,7 @@ def __init__(
def validate(config):
if not config or (not config.get('task') and not config.get('model')):
raise ValueError(
"HuggingFacePipeline requires either 'task' or "
"HuggingFacePipelineModelHandler requires either 'task' or "
"'model' to be specified.")

def inference_output_type(self):
Expand Down Expand Up @@ -488,10 +531,11 @@ def fn(x: PredictionResult):

Args:
model_handler: Specifies the parameters for the respective
enrichment_handler in a YAML/JSON format. To see the full set of
model_handler in a YAML/JSON format. To see the full set of
handler_config parameters, see their corresponding doc pages:

- [VertexAIModelHandlerJSON](https://beam.apache.org/releases/pydoc/current/apache_beam.yaml.yaml_ml.VertexAIModelHandlerJSONProvider) # pylint: disable=line-too-long
- [HuggingFacePipelineModelHandler](https://beam.apache.org/releases/pydoc/current/apache_beam.yaml.yaml_ml.HuggingFacePipelineModelHandlerProvider) # pylint: disable=line-too-long
inference_tag: The tag to use for the returned inference. Default is
'inference'.
inference_args: Extra arguments for models whose inference call requires
Expand Down
Loading