diff --git a/pai/estimator.py b/pai/estimator.py index 2ad604e..13ad8dd 100644 --- a/pai/estimator.py +++ b/pai/estimator.py @@ -724,6 +724,7 @@ def __init__( experiment_config=experiment_config, resource_id=resource_id, session=session, + **kwargs, ) def training_image_uri(self) -> str: diff --git a/pai/model/_model.py b/pai/model/_model.py index c519153..d3a408b 100644 --- a/pai/model/_model.py +++ b/pai/model/_model.py @@ -2124,8 +2124,8 @@ def get_estimator( ts.scheduler.max_running_time_in_seconds if ts.scheduler else None ) - resource_id = kwargs.get("resource_id") - instance_spec = kwargs.get("instance_spec") + resource_id = kwargs.pop("resource_id", None) + instance_spec = kwargs.pop("instance_spec", None) compute_resource = ts.compute_resource if resource_id: if instance_type: @@ -2146,6 +2146,7 @@ def get_estimator( ) else: if instance_spec: + instance_spec = None logger.warning( "The instance spec is ignored when resource_id is not provided." ) @@ -2182,6 +2183,7 @@ def get_estimator( instance_type=instance_type, instance_count=instance_count, instance_spec=instance_spec, + resource_id=resource_id, output_path=output_path, labels=labels, **kwargs,