Multi-Tensor Input in Servo-Beam #10
Conversation
…mn option and tests
Integrate Arrow as internal processing container
tfx_bsl/beam/run_inference.py
Outdated
```python
self._api_client = discovery.build('ml', 'v1')


def _extract_from_recordBatch(self, elements: pa.RecordBatch):
  serialized_examples = bsl_util.ExtractSerializedExampleFromRecordBatch(elements)
```
This seems to be the same in the Batch and Remote DoFns. Maybe extract it out to the base class, and only get model_input in _extract_from_recordBatch?
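A minimal sketch of the suggested refactor: the shared extraction lives in a base class, and subclasses only build their model input (class and method names here are illustrative, not the PR's actual code):

```python
import apache_beam as beam
import pyarrow as pa

from tfx_bsl.beam import bsl_util


class _BaseDoFn(beam.DoFn):
  """Hypothetical base class holding the extraction shared by both DoFns."""

  def _extract_from_record_batch(self, elements: pa.RecordBatch):
    # Identical in the batch and remote paths, so it lives here.
    return bsl_util.ExtractSerializedExampleFromRecordBatch(elements)


class _RemoteDoFn(_BaseDoFn):

  def process(self, elements: pa.RecordBatch):
    serialized_examples = self._extract_from_record_batch(elements)
    # Only the subclass-specific step remains: turning the serialized
    # examples into this DoFn's model_input (helper is illustrative).
    model_input = self._build_model_input(serialized_examples)
    yield model_input
```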
tfx_bsl/beam/run_inference.py
Outdated
```python
) -> Mapping[Text, np.ndarray]:
  self._check_elements(elements)
  outputs = self._run_tf_operations(elements)
    self, tensors: Mapping[Any, Any]) -> Mapping[Text, np.ndarray]:
```
Add a comment on what's expected in `tensors`. And is the Mapping key a `Text`?
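The requested documentation might look something like this (a sketch; the exact contract is for the author to confirm):

```python
from typing import Mapping, Text

import numpy as np


def _run_tf_operations(
    self, tensors: Mapping[Text, np.ndarray]) -> Mapping[Text, np.ndarray]:
  """Runs the model over a batch of already-extracted input tensors.

  Args:
    tensors: A mapping from input tensor name (assumed here to be Text)
      to the batched values for that tensor, as an np.ndarray.

  Returns:
    A mapping from output tensor alias to the np.ndarray produced for it.
  """
```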
```python
    self, elements: Mapping[Any, Any],
    outputs: Mapping[Text, np.ndarray]
) -> Iterable[Tuple[Union[str, bytes], classification_pb2.Classifications]]:
  serialized_examples, = elements.values()
```
This won't give the right answer in the multi-tensor case.
```python
    self, elements: Mapping[Any, Any],
    outputs: Mapping[Text, np.ndarray]
) -> Iterable[Tuple[Union[str, bytes], classification_pb2.Classifications]]:
  serialized_examples, = elements.values()
```
Is `elements.values()` the serialized examples?
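One way to make the intent explicit instead of relying on single-value unpacking (a sketch; it reuses the `_io_tensor_spec` attribute quoted elsewhere in this PR):

```python
def _get_serialized_examples(self, elements):
  # `serialized_examples, = elements.values()` raises when `elements`
  # holds more than one entry, and otherwise grabs whichever single
  # value happens to be present. Keying on the known tensor name is
  # unambiguous:
  input_name = self._io_tensor_spec.input_tensor_names[0]
  return elements[input_name]
```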
```python
raise ValueError('Expected to have one name and one alias per tensor')

include_request = True
if len(input_tensor_names) == 1:
```
Can we make the determination of a single input string tensor an internal utility function inside BaseDoFn?
The input tensor names are not in the base DoFn.
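If the tensor names were threaded through to the base class, the helper could be as small as this (a sketch under that assumption):

```python
def _has_single_string_input(self) -> bool:
  """Hypothetical BaseDoFn helper: True when the model takes exactly one
  input tensor, i.e. serialized tf.Examples can be fed to it directly."""
  return len(self._io_tensor_spec.input_tensor_names) == 1
```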
```python
include_request = True
if len(input_tensor_names) == 1:
  serialized_examples, = elements.values()
```
Shall we also check that the type of `elements.values()` is string/bytes?
It's checked in `_extract_from_recordBatch`.
tfx_bsl/beam/run_inference.py
Outdated
```python
else:
  input_tensor_proto.tensor_shape.dim.add().size = len(elements[tensor_name][0])
```
Why is the dim size `len(elements[tensor_name][0])` instead of:

```python
for s in elements[tensor_name][0].shape:
  input_tensor_proto.tensor_shape.dim.add().size = s
```
We have an np.ndarray; I don't think we will have a shape parameter.
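For reference, NumPy arrays do expose a `.shape` attribute, so the loop from the review comment would run on an ndarray; a standalone illustration (not the PR's code):

```python
import numpy as np
from tensorflow.core.framework import tensor_pb2

arr = np.zeros((4, 3), dtype=np.float32)  # stand-in for one input value
tensor_proto = tensor_pb2.TensorProto()
for s in arr.shape:  # ndarray.shape is always defined; here (4, 3)
  tensor_proto.tensor_shape.dim.add().size = s
# tensor_shape now records dims [4, 3], whereas len(arr) would flatten
# this to the single dim 4.
```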
tfx_bsl/beam/run_inference.py
Outdated
```python
for alias, tensor_name in zip(input_tensor_alias, input_tensor_names):
  input_tensor_proto = predict_log_tmpl.request.inputs[alias]
  input_tensor_proto.dtype = tf.as_dtype(input_tensor_types[alias]).as_datatype_enum
  if len(input_tensor_alias) == 1:
```
Could the single input case be handled separately?
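That is, hoisting the length check out of the per-tensor loop, roughly (a sketch assembled from the quoted lines):

```python
if len(input_tensor_alias) == 1:
  # Single-input case handled once, outside the loop.
  alias = input_tensor_alias[0]
  input_tensor_proto = predict_log_tmpl.request.inputs[alias]
  input_tensor_proto.dtype = tf.as_dtype(
      input_tensor_types[alias]).as_datatype_enum
else:
  for alias, tensor_name in zip(input_tensor_alias, input_tensor_names):
    input_tensor_proto = predict_log_tmpl.request.inputs[alias]
    input_tensor_proto.dtype = tf.as_dtype(
        input_tensor_types[alias]).as_datatype_enum
```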
tfx_bsl/beam/run_inference.py
Outdated
```python
alias = input_tensor_alias[0]
predict_log.request.inputs[alias].string_val.append(process_elements[i])
else:
  for alias, tensor_name in zip(input_tensor_alias, input_tensor_names):
```
Is this correct, given it's already inside the loop over `alias, tensor_name`?
```python
) -> Iterable[Tuple[tf.train.Example, inference_pb2.MultiInferenceResponse]]:
    self, elements: Mapping[Any, Any],
    outputs: Mapping[Text, np.ndarray]
) -> Iterable[Tuple[Union[str, bytes], inference_pb2.MultiInferenceResponse]]:
```
Can this just be `bytes` instead of `Union[str, bytes]`?
str is the same as 'bytes' in py2.
Just wanted to make sure it's compatible with py2
```python
model_input = None
if (len(self._io_tensor_spec.input_tensor_names) == 1):
  model_input = {self._io_tensor_spec.input_tensor_names[0]: serialized_examples}
```
Can we just leave this in _BaseBatchSavedModelDoFn and move the rest to _BatchPredictDoFn?
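A sketch of the proposed split (method names are illustrative; the attribute follows the snippets quoted above):

```python
class _BaseBatchSavedModelDoFn(_BaseDoFn):

  def _build_model_input(self, serialized_examples):
    # Kept in the base class: the common single-input-tensor case.
    if len(self._io_tensor_spec.input_tensor_names) == 1:
      return {self._io_tensor_spec.input_tensor_names[0]:
                  serialized_examples}
    return None


class _BatchPredictDoFn(_BaseBatchSavedModelDoFn):
  # "The rest" (the multi-tensor handling) would move here, since only
  # the Predict signature accepts arbitrary input tensors.
  ...
```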
tfx_bsl/public/beam/run_inference.py
Outdated
```python
Args:
  examples: A PCollection containing examples.
  inference_spec_type: Model inference endpoint.
  Schema [optional]: required for models that requires
```
Mention this is only available for the Predict method.
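The docstring entry might then read (a wording sketch only):

```
Args:
  examples: A PCollection containing examples.
  inference_spec_type: Model inference endpoint.
  schema: (Optional) A schema; required for models that take multi-tensor
    input. Only available for the Predict method.
```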
tfx_bsl/beam/bsl_util.py
Outdated
```python
_KERAS_INPUT_SUFFIX = '_input'


def ExtractSerializedExampleFromRecordBatch(elements: pa.RecordBatch) -> List[Text]:
```
ExtractSerializedExamplesFromRecordBatch
```python
def ExtractSerializedExampleFromRecordBatch(elements: pa.RecordBatch) -> List[Text]:
  serialized_examples = None
  for column_name, column_array in zip(elements.schema.names, elements.columns):
    if column_name == _RECORDBATCH_COLUMN:
```
Should _RECORDBATCH_COLUMN be passed as an argument to the API?
If we use a constant here, it would mean users would have to use this same constant when creating the TFXIO.
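That is, a default-valued parameter so callers can match whatever column name their TFXIO produced (a sketch; the `column_name` parameter is illustrative, and the constant's value here is a stand-in):

```python
from typing import List, Text

import pyarrow as pa

_RECORDBATCH_COLUMN = '__RAW_RECORD__'  # stand-in for the PR's constant


def ExtractSerializedExamplesFromRecordBatch(
    elements: pa.RecordBatch,
    column_name: Text = _RECORDBATCH_COLUMN) -> List[Text]:
  """Extracts serialized examples from the named RecordBatch column."""
  for name, column_array in zip(elements.schema.names, elements.columns):
    if name == column_name:
      return column_array.to_pylist()
  raise ValueError('Column %r not found in the RecordBatch.' % column_name)
```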
```python
    tf.train.SequenceExample])
@beam.typehints.with_input_types(tf.train.Example)
@beam.typehints.with_output_types(prediction_log_pb2.PredictionLog)
def RunInference(  # pylint: disable=invalid-name
```
Is the long-term plan to deprecate the tf.Example API and only have a RecordBatch API?
If so, mention it in a comment.
```python
if prepare_instances_serialized:
  return [{'b64': base64.b64encode(value).decode()} for value in df[_RECORDBATCH_COLUMN]]
else:
  as_binary = df.columns.str.endswith("_bytes")
```
Why does the name end with "_bytes"?
User-specified byte columns; it's consistent with the original implementation.
This is required by Cloud AI Platform, which indicates a bytes feature with the '_bytes' suffix.
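For context, Cloud AI Platform's online prediction JSON wraps binary values as {"b64": ...} and relies on the "_bytes" alias suffix to know they need decoding; a standalone illustration of the convention:

```python
import base64

raw_image = b'\x89PNG\r\n'  # stand-in for a real binary payload

# The "_bytes" suffix on the input alias tells the service the value is
# binary; the value itself must be sent as {"b64": <base64 string>}.
instance = {
    'image_bytes': {'b64': base64.b64encode(raw_image).decode()},
    'label': 3,
}
```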
```python
@beam.typehints.with_input_types(tf.train.Example)
@beam.typehints.with_output_types(prediction_log_pb2.PredictionLog)
def RunInferenceImpl(  # pylint: disable=invalid-name
def RunInferenceOnExamples(  # pylint: disable=invalid-name
```
Let's use the first option for the public API here, to have a polymorphic RunInference and RunInferenceImpl.
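A rough sketch of that option, with one public RunInference dispatching to the type-specific implementations (the dispatch helper is hypothetical; RunInferenceImpl and RunInferenceOnExamples are assumed to be ptransform_fns as in the diff):

```python
def RunInference(  # pylint: disable=invalid-name
    inputs, inference_spec_type, schema=None):
  """Polymorphic entry point accepting tf.Example or RecordBatch input."""
  if _is_record_batch(inputs):  # hypothetical type check
    return inputs | 'RunInferenceImpl' >> RunInferenceImpl(
        inference_spec_type, schema)
  return inputs | 'RunInferenceOnExamples' >> RunInferenceOnExamples(
      inference_spec_type)
```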
Internally uses Arrow RecordBatch for processing and supports multi-tensor input.