Switch to RunInferenceCore #11
Conversation
tfx_bsl/beam/run_inference.py
Outdated
def _BatchQueries(queries: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
  """Groups queries into batches."""


def _add_key(query: QueryType) -> Tuple[bytes, QueryType]:
Before streaming is enabled, the model will be the same during inference. Is this still needed?
I've reduced _BatchQueries to the minimum required to operate on identical inference specs. #13 can introduce keyed batching and inference spec serialization.
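For context, a minimal sketch of what a `_BatchQueries` reduced along these lines could look like (illustrative only; the transform names and the plain `beam.BatchElements` batching are assumptions, not this PR's exact code):

```python
import apache_beam as beam


@beam.ptransform_fn
def _BatchQueriesSketch(queries: beam.pvalue.PCollection) -> beam.pvalue.PCollection:
  """Groups (spec, example) queries into batches.

  Assumes every incoming query carries an identical inference spec, so
  plain element batching is enough and no keying by spec is required.
  """
  return (
      queries
      | 'BatchElements' >> beam.BatchElements()
      # Re-attach the (shared) spec so downstream DoFns receive
      # (spec, [examples]) batches.
      | 'RebuildBatch' >> beam.Map(
          lambda batch: (batch[0][0], [example for _, example in batch])))
```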
tfx_bsl/beam/run_inference.py
Outdated
if fixed_inference_spec_type is None:
  tagged = pcoll | 'Tag inference type' >> _TagUsingInProcessInference()
  tagged['remote'] | 'NotImplemented' >> _NotImplementedTransform()
  raw_predictions = (
      tagged['local']
      | 'Regress' >> beam.ParDo(_BatchRegressDoFn(shared.Shared())))
else:
  raise NotImplementedError

if _using_in_process_inference(fixed_inference_spec_type):
  raw_predictions = (
      pcoll
      | 'Regress' >> beam.ParDo(_BatchRegressDoFn(shared.Shared(),
          fixed_inference_spec_type=fixed_inference_spec_type)))
else:
  raise NotImplementedError
It seems this part is repeated several times; can it be extracted into a function?
Refactored this section into a single operation constructor:
tfx-bsl/tfx_bsl/beam/run_inference.py
Lines 940 to 1020 in a88ecc1
def _BuildInferenceOperation(
    name: str,
    in_process_dofn: _BaseBatchSavedModelDoFn,
    remote_dofn: Optional[_BaseDoFn],
    build_prediction_log_dofn: beam.DoFn
):
  """Construct an operation specific inference sub-pipeline.

  Args:
    name: name of the operation (e.g. "Classify")
    in_process_dofn: a _BaseBatchSavedModelDoFn class to use for in-process
      inference
    remote_dofn: an optional DoFn that is used for remote inference
    build_prediction_log_dofn: a DoFn that can build prediction logs from the
      output of `in_process_dofn` and `remote_dofn`

  Returns:
    A PTransform of the type (_QueryBatchType -> PredictionLog)
  """
  @beam.ptransform_fn
  @beam.typehints.with_input_types(_QueryBatchType)
  @beam.typehints.with_output_types(prediction_log_pb2.PredictionLog)
  def _Op(
      pcoll: beam.pvalue.PCollection,
      fixed_inference_spec_type: model_spec_pb2.InferenceSpecType = None
  ):  # pylint: disable=invalid-name
    raw_result = None

    if fixed_inference_spec_type is None:
      tagged = pcoll | 'TagInferenceType' >> _TagUsingInProcessInference()

      in_process_result = (
          tagged['in_process']
          | ('InProcess%s' % name) >> beam.ParDo(
              in_process_dofn(shared.Shared())))

      if remote_dofn:
        remote_result = (
            tagged['remote']
            | ('Remote%s' % name) >> beam.ParDo(
                remote_dofn(pcoll.pipeline.options)))

        raw_result = (
            [in_process_result, remote_result]
            | 'FlattenResult' >> beam.Flatten())
      else:
        raw_result = in_process_result
    else:
      if _using_in_process_inference(fixed_inference_spec_type):
        raw_result = (
            pcoll
            | ('InProcess%s' % name) >> beam.ParDo(in_process_dofn(
                shared.Shared(),
                fixed_inference_spec_type=fixed_inference_spec_type)))
      else:
        raw_result = (
            pcoll
            | ('Remote%s' % name) >> beam.ParDo(remote_dofn(
                pcoll.pipeline.options,
                fixed_inference_spec_type=fixed_inference_spec_type)))

    return (
        raw_result
        | ('BuildPredictionLogFor%s' % name) >> beam.ParDo(
            build_prediction_log_dofn()))

  return _Op


_Classify = _BuildInferenceOperation(
    'Classify', _BatchClassifyDoFn, None,
    _BuildPredictionLogForClassificationsDoFn)
_Regress = _BuildInferenceOperation(
    'Regress', _BatchRegressDoFn, None,
    _BuildPredictionLogForRegressionsDoFn)
_Predict = _BuildInferenceOperation(
    'Predict', _BatchPredictDoFn, _RemotePredictDoFn,
    _BuildPredictionLogForPredictionsDoFn)
_MultiInference = _BuildInferenceOperation(
    'MultiInference', _BatchMultiInferenceDoFn, None,
    _BuildMultiInferenceLogDoFn)
tfx_bsl/beam/run_inference.py
Outdated
if self._use_fixed_model:
  self._setup_model(self._fixed_inference_spec_type)
Why is this in setup instead of __init__?
Moving this to __init__ would require the API client to be serializable (as part of the DoFn). This may be possible, but it may also lead to some strange issues if multiple DoFns sharing an API client is a problem. The original code configured this in setup, so this PR is just maintaining that convention.
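To illustrate the pickling concern, a sketch with a hypothetical `_make_api_client` helper (not the PR's code):

```python
import apache_beam as beam


def _make_api_client(inference_spec_type):
  # Hypothetical stand-in for building a remote prediction client.
  return object()


class _RemoteDoFnSketch(beam.DoFn):
  """Sketch: keep __init__ picklable; build the API client in setup()."""

  def __init__(self, fixed_inference_spec_type=None):
    # Only picklable configuration is stored here; the DoFn instance
    # itself is serialized and shipped to workers.
    self._fixed_inference_spec_type = fixed_inference_spec_type
    self._api_client = None

  def setup(self):
    # Runs on the worker after deserialization, so the (possibly
    # unpicklable) client never has to survive pickling.
    if self._fixed_inference_spec_type is not None:
      self._api_client = _make_api_client(self._fixed_inference_spec_type)

  def process(self, batch):
    # In the dynamic-model case the client would instead be created
    # lazily here, once the batch's inference spec is known.
    yield batch
```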
tfx_bsl/beam/run_inference_test.py
Outdated
    saved_model_spec=model_spec_pb2.SavedModelSpec(
        model_path=model_path))


def test_batch_queries_single_model(self):
What does this test?
This was intended to contrast with the TODO test below, but I think it is redundant, so I've removed it.
Benchmarks showed that TagByOperation was a performance bottleneck, as it requires disk access per query batch. To mitigate this, I implemented operation caching inside the DoFn. For readability, I also renamed this operation to "SplitByOperation", as that more accurately describes its purpose. On a dataset with 1M examples, TagByOperation took ~25% of the total wall time; after implementing caching, this was reduced to ~2%.
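Roughly, the caching amounts to something like the following (a sketch; `_get_operation_type` and the cache layout are assumptions, not the actual implementation):

```python
import apache_beam as beam


def _get_operation_type(inference_spec_type):
  # Hypothetical helper: inspects the SavedModel signature on disk to
  # decide which operation (classify/regress/predict/...) applies.
  return 'Regress'


class _SplitByOperationSketch(beam.DoFn):
  """Sketch: cache the operation per model path so the SavedModel
  signature is only read from disk on a cache miss, not once per batch."""

  def setup(self):
    self._operation_cache = {}  # model_path -> operation name

  def process(self, batch):
    spec, examples = batch
    model_path = spec.saved_model_spec.model_path
    operation = self._operation_cache.get(model_path)
    if operation is None:
      operation = _get_operation_type(spec)  # disk access, cache misses only
      self._operation_cache[model_path] = operation
    yield beam.pvalue.TaggedOutput(operation, batch)
```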
tfx_bsl/beam/run_inference.py
Outdated
super(_BaseDoFn, self).__init__()
self._clock = None
self._metrics_collector = self._MetricsCollector(inference_spec_type)
self._metrics_collector = self._MetricsCollector(fixed_inference_spec_type)
Can this be moved to setup_model()? That way we can know the proximity and operation_type in the model streaming case.
My concern with this is that we might run into issues initializing Beam metrics outside of __init__. That might be a question for someone with more experience with Beam/Dataflow.
For the purposes of metrics collection, I think the end goal here is to collect metrics per operation type (e.g. classify, regress, ...). The fixed_inference_spec_type is only used here to determine the operation type so we could get rid of that and use a different solution where we initialize unique metric collectors for each operation type (which does not require knowing the inference spec). Then at runtime we choose which metric collector to use based on the inference spec (which will be available).
e.g. here we would have something like:
self._classify_metrics = self._MetricsCollector(OperationType.CLASSIFICATION, _METRICS_DESCRIPTOR_IN_PROCESS)
self._regress_metrics = self._MetricsCollector(OperationType.REGRESSION, _METRICS_DESCRIPTOR_IN_PROCESS)
...

and maybe expose a new method:

class _BaseDoFn(beam.DoFn):
  ...
  def _metrics_collector_for_inference_spec(inference_spec_type: InferenceSpecType) -> _MetricsCollector:
    ...

Internally, we could replace:

self._metrics_collector.update(...)

with

self._metrics_collector_for_inference_spec(inference_spec).update(...)

Does this sound reasonable?
It sounds good. Also note that we need to toggle between _METRICS_DESCRIPTOR_IN_PROCESS and _METRICS_DESCRIPTOR_CLOUD_AI_PREDICTION.
I found a slightly cleaner solution here: I refactored MetricsCollector to accept an operation type and proximity directly: https://github.com/hgarrereyn/tfx-bsl/blob/core/tfx_bsl/beam/run_inference.py#L299-L309
Each subclass DoFn is responsible for configuring the correct operation type and proximity (for this we don't need the inference spec).
I refactored the MetricsCollector methods for readability, so now there is:
- `update_inference` for updating inference metrics
- `update_model_load` + `commit_cached_metrics` for updating model loading metrics. The first function caches the metrics and the second commits them. In the fixed model case, `update_model_load` is called in `DoFn.setup` and `commit_cached_metrics` is called in `DoFn.finish_bundle`, and I added some documentation explaining why. In the dynamic model case, both functions are called sequentially in `DoFn.process`.
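For illustration, the described call pattern in a simplified form (the collector internals and numeric arguments are placeholders, not the real implementation):

```python
class _MetricsCollectorSketch:
  """Simplified stand-in for the refactored collector (names only)."""

  def __init__(self, operation_type, proximity):
    self._operation_type = operation_type
    self._proximity = proximity
    self._cached_load_seconds = []

  def update_inference(self, num_elements, latency_secs):
    pass  # would update Beam inference counters/distributions

  def update_model_load(self, load_seconds):
    # Cache the model-load metric; commit_cached_metrics flushes it later.
    self._cached_load_seconds.append(load_seconds)

  def commit_cached_metrics(self):
    self._cached_load_seconds = []  # would flush to Beam metrics here


collector = _MetricsCollectorSketch('Regress', 'in_process')

# Fixed-model case: the load metric is cached in DoFn.setup ...
collector.update_model_load(load_seconds=1.2)  # conceptually in setup()
# ... and committed once the bundle finishes.
collector.commit_cached_metrics()              # conceptually in finish_bundle()

# Dynamic-model case: both calls happen back to back in DoFn.process.
collector.update_model_load(load_seconds=0.8)
collector.commit_cached_metrics()
```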
Thanks Harrison! It looks good in general. Just some minor comments.
Overview
This PR moves the internal implementation of `RunInferenceImpl` to `_RunInferenceCore`. This core component accepts `query` tuples, i.e. `(InferenceSpecType, Example)`. This PR has no public facing changes but will allow for the implementation of a streaming model API.

What was changed:

- `RunInferenceImpl` is now a wrapper around `_RunInferenceCore`.
- `_RunInferenceCore` and the internal PTransforms that previously took `inference_spec_type` arguments now accept an optional `fixed_inference_spec_type` that can be `None`.
  - When called through `RunInferenceImpl`, `fixed_inference_spec_type` is set and the internal PTransforms can take a "fast" path that includes collapsing down to a single sequence of PTransforms and loading the model during the `setup` method of the DoFn (similar to the current implementation).
  - Otherwise, `_RunInferenceCore` will build a graph containing all possible operations and types (local/remote) and queries will be batched and routed to the correct operation at runtime. In this case, models will be loaded during the `process` method of a DoFn (but caching is still possible).
- `_BaseBatchSavedModelDoFn` and `_RemotePredictDoFn` were restructured:
  - Model setup is consolidated into `_setup_model(self, inference_spec_type: model_spec_pb2.InferenceSpecType)`. This function includes code previously in both `__init__` and `setup`. It will be called either in `setup` or `process` depending on whether or not the inference spec is available at pipeline construction time.
  - Previously, subclasses (e.g. `_BatchClassifyDoFn`) could implement operation-specific model validation by overloading the `setup` method and optionally raising an error. This check occurs after the model signature is available but before it has been loaded. Since all this logic is now contained in `_setup_model`, there is a new `_validate_model(self)` method that is unimplemented in the base class and can be overloaded to perform validation logic.
- New types: `ExampleType`, `QueryType`, `_QueryBatchType` (the first two types will be public facing after the model streaming API is implemented).
- Batching currently uses `beam.BatchElements`; when working with queries, it is necessary to also perform a grouping operation by model spec (a rough sketch of this appears after this list). `beam.GroupIntoBatches` is currently experimental but contains this functionality. Unfortunately, BEAM-2717 currently blocks `RunInference` in the GCP Dataflow v2 runner, and the v1 runner does not support stateful DoFns, which `GroupIntoBatches` requires. Currently `BatchElements` is used as a temporary replacement with the understanding that the current implementation will not use more than one model at a time.
- New tests:
  - `_BatchQueries` and a TODO test that addresses the comment above
  - `_RunInferenceCore` with raw queries
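As referenced in the batching bullet above, a rough sketch of grouping queries by model spec before batching (illustrative only, not this PR's implementation; `beam.GroupIntoBatches` could eventually replace the `GroupByKey`-based splitting here):

```python
import apache_beam as beam


def _split_group(keyed_group, batch_size):
  """Splits one (serialized_spec, queries) group into (spec, [examples]) batches."""
  _, grouped_queries = keyed_group
  grouped_queries = list(grouped_queries)
  spec = grouped_queries[0][0]
  examples = [example for _, example in grouped_queries]
  for i in range(0, len(examples), batch_size):
    yield (spec, examples[i:i + batch_size])


@beam.ptransform_fn
def _BatchQueriesByModelSketch(queries, batch_size=64):
  """Groups (spec, example) queries by model spec, then batches them."""
  return (
      queries
      # Key each query by its serialized inference spec so examples for
      # different models never land in the same batch.
      | 'KeyBySpec' >> beam.Map(
          lambda query: (query[0].SerializeToString(), query))
      | 'GroupBySpec' >> beam.GroupByKey()
      | 'SplitIntoBatches' >> beam.FlatMap(_split_group, batch_size))
```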
Benchmarks:

A test set of 1,000,000 examples (Chicago taxi example) was run in a small Beam pipeline on Dataflow (v1 runner). These are the total wall times for 3 separate runs of the `RunInference` component:

(Current)

(RunInferenceCore)