diff --git a/xla/pjrt/distributed/coordination/coordination_client.h b/xla/pjrt/distributed/coordination/coordination_client.h index 989c368dddc8e..2cbdb6858ce3f 100644 --- a/xla/pjrt/distributed/coordination/coordination_client.h +++ b/xla/pjrt/distributed/coordination/coordination_client.h @@ -50,8 +50,8 @@ using xla::coordination::ShutdownTaskRequest; using xla::coordination::ShutdownTaskResponse; using xla::coordination::TryGetKeyValueRequest; using xla::coordination::TryGetKeyValueResponse; -using xla::coordination::WatchJobStateRequest; -using xla::coordination::WatchJobStateResponse; +using xla::coordination::WatchTasksRequest; +using xla::coordination::WatchTasksResponse; // Base class of client interface for communicating with coordination service. // Can be implemented by a variety of transports such as gRPC. @@ -74,10 +74,10 @@ class CoordinationClient { ShutdownTaskResponse* response, tsl::StatusCallback done) = 0; - virtual void WatchJobStateAsync(tsl::CallOptions* call_opts, - const WatchJobStateRequest* request, - WatchJobStateResponse* response, - tsl::StatusCallback done) = 0; + virtual void WatchTasksAsync(tsl::CallOptions* call_opts, + const WatchTasksRequest* request, + WatchTasksResponse* response, + tsl::StatusCallback done) = 0; virtual void InsertKeyValueAsync(const InsertKeyValueRequest* request, InsertKeyValueResponse* response, diff --git a/xla/pjrt/distributed/coordination/coordination_service.cc b/xla/pjrt/distributed/coordination/coordination_service.cc index 4abf01d4a46ef..a5180ed638a03 100644 --- a/xla/pjrt/distributed/coordination/coordination_service.cc +++ b/xla/pjrt/distributed/coordination/coordination_service.cc @@ -642,7 +642,7 @@ std::vector CoordinationService::GetJobState() { return states_info; } -void CoordinationService::NotifyWatchJobStateCallbacks() { +void CoordinationService::NotifyWatchTasksCallbacks() { for (auto& callback : watch_job_state_callbacks_) { callback(GetJobState(), cluster_state_version_number_); } @@ 
-651,11 +651,11 @@ void CoordinationService::NotifyWatchJobStateCallbacks() { void CoordinationService::ClusterStateUpdated() { cluster_state_version_number_++; - NotifyWatchJobStateCallbacks(); + NotifyWatchTasksCallbacks(); } -void CoordinationService::WatchJobState(std::optional version_number, - WatchJobStateCallback callback) { +void CoordinationService::WatchTasks(std::optional version_number, + WatchTasksCallback callback) { absl::MutexLock l(state_mu_); int64_t v = version_number.value_or(-1); CHECK_GE(cluster_state_version_number_, v); diff --git a/xla/pjrt/distributed/coordination/coordination_service.h b/xla/pjrt/distributed/coordination/coordination_service.h index b9aab8c80c5eb..93d360b00f09d 100644 --- a/xla/pjrt/distributed/coordination/coordination_service.h +++ b/xla/pjrt/distributed/coordination/coordination_service.h @@ -160,10 +160,9 @@ class CoordinationService { absl::Status ReportTaskError(TaskId task, const absl::Status& error); // Watches the state and the error status of the job. - using WatchJobStateCallback = absl::AnyInvocable, int64_t)>; - void WatchJobState(std::optional version_number, - WatchJobStateCallback); + void WatchTasks(std::optional version_number, WatchTasksCallback); // Insert a configuration key-value in the coordination service. // For now, a key-value can only be inserted once and cannot be updated. @@ -554,11 +553,11 @@ class CoordinationService { std::vector GetJobState() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); - // Notifies all callbacks registered via WatchJobState. - void NotifyWatchJobStateCallbacks() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); + // Notifies all callbacks registered via WatchTasks. + void NotifyWatchTasksCallbacks() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); // This method should be called whenever the cluster state changes in a way - // such that NotifyWatchJobStateCallbacks should be called. + // such that NotifyWatchTasksCallbacks should be called. 
void ClusterStateUpdated() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_); tsl::Env& env_; @@ -573,7 +572,7 @@ class CoordinationService { absl::flat_hash_map> cluster_state_ ABSL_GUARDED_BY(state_mu_); int64_t cluster_state_version_number_ ABSL_GUARDED_BY(state_mu_) = 0; - std::vector watch_job_state_callbacks_ + std::vector watch_job_state_callbacks_ ABSL_GUARDED_BY(state_mu_); KeyValueStore store_; diff --git a/xla/pjrt/distributed/coordination/coordination_service.proto b/xla/pjrt/distributed/coordination/coordination_service.proto index 6cd1a455ef219..cd6376f2c80e1 100644 --- a/xla/pjrt/distributed/coordination/coordination_service.proto +++ b/xla/pjrt/distributed/coordination/coordination_service.proto @@ -102,38 +102,38 @@ message ShutdownTaskResponse {} // fails. This new job state has version 10. Another task fails, leading to job // state 22. // -// A WatchJobStateRequest with version_number v blocks until the latest job -// state has a version number greater than v. It then returns the latest job -// state along with its version number. +// A WatchTasksRequest with version_number v blocks until the latest job state +// has a version number greater than v. It then returns the latest job state +// along with its version number. // // This API can be used to (1) get the latest job state immediately or (2) watch // the job state for changes. // -// 1. To get the latest job state, issue a WatchJobStateRequest with a +// 1. To get the latest job state, issue a WatchTasksRequest with a // version number of -1. -// 2. To watch the job state for changes, issue a WatchJobStateRequest with -// the version number returned by the previous WatchJobStateResponse. +// 2. To watch the job state for changes, issue a WatchTasksRequest with +// the version number returned by the previous WatchTasksResponse. // // Though version numbers are ints, they should be treated as opaque ids. 
The // only valid version numbers are -1 and the version numbers returned in a -// WatchJobStateResponse. You should NOT issue a WatchJobStateRequest with some +// WatchTasksResponse. You should NOT issue a WatchTasksRequest with some // arbitrarily chosen version number. // // Note the following two subtleties with this API: // -// 1. This API *cannot* be used to witness *every* change in job state. If -// you issue a WatchJobStateRequest with version v, you may receive a job -// state that has changed multiple times since version v. -// 2. If you issue a WatchJobStateRequest with version number v, the job -// state you receive may not be different than the job state at version -// v. For example, a task may have failed and later recovered leading to -// an identical job state. -message WatchJobStateRequest { - reserved 1; - int64 version_number = 2; +// 1. This API *cannot* be used to witness *every* change in job state. If you +// issue a WatchTasksRequest with version v, you may receive a job state that +// has changed multiple times since version v. +// +// 2. If you issue a WatchTasksRequest with version number v, the job state you +// receive may not be different than the job state at version v. For +// example, a task may have failed and later recovered leading to an +// identical job state. +message WatchTasksRequest { + int64 version_number = 1; } -message WatchJobStateResponse { +message WatchTasksResponse { repeated TaskInfo task_state = 1; int64 version_number = 2; } @@ -280,7 +280,7 @@ service CoordinationService { } // Watches the state of every task in a remote job. - rpc WatchJobState(WatchJobStateRequest) returns (WatchJobStateResponse); + rpc WatchTasks(WatchTasksRequest) returns (WatchTasksResponse); // Insert configuration key-value that will be accessible to all cluster // tasks. The key can be formatted as Unix file path with hierarchy. 
The diff --git a/xla/pjrt/distributed/coordination/coordination_service_agent.cc b/xla/pjrt/distributed/coordination/coordination_service_agent.cc index 12360e8a678f3..531cdb4daf42f 100644 --- a/xla/pjrt/distributed/coordination/coordination_service_agent.cc +++ b/xla/pjrt/distributed/coordination/coordination_service_agent.cc @@ -319,19 +319,18 @@ void CoordinationServiceAgent::PollForErrorAsync(tsl::StatusCallback done) { }); } -std::shared_ptr CoordinationServiceAgent::WatchJobStateAsync( +std::shared_ptr CoordinationServiceAgent::WatchTasksAsync( std::optional version_number, - std::function< - void(absl::StatusOr)> + std::function)> callback) { - auto request = std::make_shared(); - auto response = std::make_shared(); + auto request = std::make_shared(); + auto response = std::make_shared(); auto call_opts = std::make_shared(); - WatchJobStateRequest* request_ptr = request.get(); - WatchJobStateResponse* response_ptr = response.get(); + WatchTasksRequest* request_ptr = request.get(); + WatchTasksResponse* response_ptr = response.get(); request->set_version_number(version_number.value_or(-1)); - leader_client_->WatchJobStateAsync( + leader_client_->WatchTasksAsync( call_opts.get(), request_ptr, response_ptr, [request = std::move(request), response = std::move(response), callback = std::move(callback)](const absl::Status& s) mutable { @@ -344,17 +343,16 @@ std::shared_ptr CoordinationServiceAgent::WatchJobStateAsync( return call_opts; } -absl::StatusOr -CoordinationServiceAgent::WatchJobState(std::optional version_number) { - absl::StatusOr response; +absl::StatusOr +CoordinationServiceAgent::WatchTasks(std::optional version_number) { + absl::StatusOr response; absl::Notification done; - WatchJobStateAsync( - version_number, - [&response, - &done](absl::StatusOr r) { - response = std::move(r); - done.Notify(); - }); + WatchTasksAsync(version_number, + [&response, &done]( + absl::StatusOr r) { + response = std::move(r); + done.Notify(); + }); 
done.WaitForNotification(); return response; } diff --git a/xla/pjrt/distributed/coordination/coordination_service_agent.h b/xla/pjrt/distributed/coordination/coordination_service_agent.h index 31f9c3e9954e1..8a8a8c85381ac 100644 --- a/xla/pjrt/distributed/coordination/coordination_service_agent.h +++ b/xla/pjrt/distributed/coordination/coordination_service_agent.h @@ -140,15 +140,14 @@ class CoordinationServiceAgent { CoordinationService::TaskId task_id() const { return task_id_; } // Watches the status of a remote job. - absl::StatusOr WatchJobState( + absl::StatusOr WatchTasks( std::optional version_number); // Note: Cancel the underlying RPC call with `call_opts->StartCancel()` and // `call_opts->ClearCancelCallback()`. - std::shared_ptr WatchJobStateAsync( + std::shared_ptr WatchTasksAsync( std::optional version_number, - std::function< - void(absl::StatusOr)> + std::function)> callback); // Report error to coordination service. This will invoke the error callback. diff --git a/xla/pjrt/distributed/coordination/coordination_service_agent_test.cc b/xla/pjrt/distributed/coordination/coordination_service_agent_test.cc index ad6f5288dbe2f..bf601b3877c49 100644 --- a/xla/pjrt/distributed/coordination/coordination_service_agent_test.cc +++ b/xla/pjrt/distributed/coordination/coordination_service_agent_test.cc @@ -117,9 +117,9 @@ class TestCoordinationClient : public CoordinationClient { (const GetAliveTasksRequest*, GetAliveTasksResponse*, tsl::StatusCallback), (override)); - MOCK_METHOD(void, WatchJobStateAsync, - (tsl::CallOptions*, const WatchJobStateRequest*, - WatchJobStateResponse*, tsl::StatusCallback), + MOCK_METHOD(void, WatchTasksAsync, + (tsl::CallOptions*, const WatchTasksRequest*, WatchTasksResponse*, + tsl::StatusCallback), (override)); MOCK_METHOD(void, HeartbeatAsync, (tsl::CallOptions*, const HeartbeatRequest*, HeartbeatResponse*, diff --git a/xla/pjrt/distributed/coordination/coordination_service_rpc_handler.cc 
b/xla/pjrt/distributed/coordination/coordination_service_rpc_handler.cc index 2b17ed377ae15..3242298a7c5c8 100644 --- a/xla/pjrt/distributed/coordination/coordination_service_rpc_handler.cc +++ b/xla/pjrt/distributed/coordination/coordination_service_rpc_handler.cc @@ -103,10 +103,9 @@ void CoordinationServiceRpcHandler::ShutdownTaskAsync( [done](absl::Status s) { done(s); }); } -void CoordinationServiceRpcHandler::WatchJobStateAsync( - const xla::coordination::WatchJobStateRequest* request, - xla::coordination::WatchJobStateResponse* response, - tsl::StatusCallback done) { +void CoordinationServiceRpcHandler::WatchTasksAsync( + const xla::coordination::WatchTasksRequest* request, + xla::coordination::WatchTasksResponse* response, tsl::StatusCallback done) { absl::ReaderMutexLock l(mu_); if (service_ == nullptr) { done(MakeCoordinationError( @@ -118,7 +117,7 @@ void CoordinationServiceRpcHandler::WatchJobStateAsync( if (request->version_number() >= 0) { version_number.emplace(request->version_number()); } - service_->WatchJobState( + service_->WatchTasks( version_number, [response, done](std::vector info, int64_t version_number) { diff --git a/xla/pjrt/distributed/coordination/coordination_service_rpc_handler.h b/xla/pjrt/distributed/coordination/coordination_service_rpc_handler.h index 0501032c6c2bb..6d389f3bdcf6d 100644 --- a/xla/pjrt/distributed/coordination/coordination_service_rpc_handler.h +++ b/xla/pjrt/distributed/coordination/coordination_service_rpc_handler.h @@ -44,10 +44,9 @@ class CoordinationServiceRpcHandler { xla::coordination::ShutdownTaskResponse* response, tsl::StatusCallback done); - void WatchJobStateAsync( - const xla::coordination::WatchJobStateRequest* request, - xla::coordination::WatchJobStateResponse* response, - tsl::StatusCallback done); + void WatchTasksAsync(const xla::coordination::WatchTasksRequest* request, + xla::coordination::WatchTasksResponse* response, + tsl::StatusCallback done); void InsertKeyValueAsync( const 
xla::coordination::InsertKeyValueRequest* request, diff --git a/xla/pjrt/distributed/coordination/coordination_service_test.cc b/xla/pjrt/distributed/coordination/coordination_service_test.cc index 7bd682cefcc12..d69d41dd5d0bf 100644 --- a/xla/pjrt/distributed/coordination/coordination_service_test.cc +++ b/xla/pjrt/distributed/coordination/coordination_service_test.cc @@ -124,7 +124,7 @@ class TestCoordinationClient : public CoordinationClient { UNIMPLEMENTED_WITH_CALL_OPTS(Heartbeat); UNIMPLEMENTED_WITH_CALL_OPTS(ShutdownTask); UNIMPLEMENTED_WITH_CALL_OPTS(PollForError); - UNIMPLEMENTED_WITH_CALL_OPTS(WatchJobState); + UNIMPLEMENTED_WITH_CALL_OPTS(WatchTasks); #undef UNIMPLEMENTED_WITH_CALL_OPTS private: @@ -437,8 +437,8 @@ xla::coordination::TaskInfo info(CoordinationService::TaskId task, return info; } -TEST_F(CoordinateTwoTasksTest, WatchJobStateSucceeds) { - // This test calls WatchJobState on two successfully connected tasks. +TEST_F(CoordinateTwoTasksTest, WatchTasksSucceeds) { + // This test calls WatchTasks on two successfully connected tasks. // Connect the tasks. EnableCoordinationService(); @@ -447,7 +447,7 @@ TEST_F(CoordinateTwoTasksTest, WatchJobStateSucceeds) { // Watch the job state, which should return immediately. absl::Notification done; - coord_service_->WatchJobState( + coord_service_->WatchTasks( std::nullopt, [&, this](std::vector got, int64_t version_number) { using State = xla::coordination::TaskState; @@ -461,8 +461,8 @@ TEST_F(CoordinateTwoTasksTest, WatchJobStateSucceeds) { done.WaitForNotification(); } -TEST_F(CoordinateTwoTasksTest, WatchJobStateReturnsDisconnected) { - // This test calls WatchJobState on one successfully connected task and one +TEST_F(CoordinateTwoTasksTest, WatchTasksReturnsDisconnected) { + // This test calls WatchTasks on one successfully connected task and one // disconnected task. // Connect the tasks. Disconnect task 1. 
@@ -473,7 +473,7 @@ TEST_F(CoordinateTwoTasksTest, WatchJobStateReturnsDisconnected) { // Watch the job state, which should return immediately. absl::Notification done; - coord_service_->WatchJobState( + coord_service_->WatchTasks( std::nullopt, [&, this](std::vector got, int64_t version_number) { using State = xla::coordination::TaskState; @@ -488,8 +488,8 @@ TEST_F(CoordinateTwoTasksTest, WatchJobStateReturnsDisconnected) { done.WaitForNotification(); } -TEST_F(CoordinateTwoTasksTest, WatchJobStateReturnsNewIncarnation) { - // This test calls WatchJobState after one task has restarted with a new +TEST_F(CoordinateTwoTasksTest, WatchTasksReturnsNewIncarnation) { + // This test calls WatchTasks after one task has restarted with a new // incarnation. EnableCoordinationService(); ASSERT_OK(coord_service_->RegisterTask(0, incarnation_0_)); @@ -499,7 +499,7 @@ TEST_F(CoordinateTwoTasksTest, WatchJobStateReturnsNewIncarnation) { // Watch the job state, which should return immediately. absl::Notification done; - coord_service_->WatchJobState( + coord_service_->WatchTasks( std::nullopt, [&, this](std::vector got, int64_t version_number) { using State = xla::coordination::TaskState; @@ -514,8 +514,8 @@ TEST_F(CoordinateTwoTasksTest, WatchJobStateReturnsNewIncarnation) { done.WaitForNotification(); } -TEST_F(CoordinateTwoTasksTest, WatchJobStateBlocksUntilChange) { - // This test calls checks that WatchJobState blocks until the job state +TEST_F(CoordinateTwoTasksTest, WatchTasksBlocksUntilChange) { + // This test checks that WatchTasks blocks until the job state // changes. // Connect the tasks. Disconnect task 1. @@ -526,7 +526,7 @@ TEST_F(CoordinateTwoTasksTest, WatchJobStateBlocksUntilChange) { // Watch the job state, which should return immediately. 
absl::Notification done_1; int64_t version_number = -1; - coord_service_->WatchJobState( + coord_service_->WatchTasks( std::nullopt, [&](std::vector got, int64_t v) { EXPECT_THAT(v, Ge(0)); @@ -537,7 +537,7 @@ TEST_F(CoordinateTwoTasksTest, WatchJobStateBlocksUntilChange) { // Watch the job state again, which should block. absl::Notification done_2; - coord_service_->WatchJobState( + coord_service_->WatchTasks( version_number, [&, this](std::vector got, int64_t v) { using State = xla::coordination::TaskState; @@ -558,8 +558,8 @@ TEST_F(CoordinateTwoTasksTest, WatchJobStateBlocksUntilChange) { done_2.WaitForNotification(); } -TEST_F(CoordinateTwoTasksTest, WatchJobStateAfterTwoStateChanges) { - // This test calls WatchJobState after two state changes. +TEST_F(CoordinateTwoTasksTest, WatchTasksAfterTwoStateChanges) { + // This test calls WatchTasks after two state changes. EnableCoordinationService(); ASSERT_OK(coord_service_->RegisterTask(0, incarnation_0_)); ASSERT_OK(coord_service_->RegisterTask(1, incarnation_1_)); @@ -567,7 +567,7 @@ TEST_F(CoordinateTwoTasksTest, WatchJobStateAfterTwoStateChanges) { // Watch the job state, which should return immediately. absl::Notification done_1; int64_t version_number = -1; - coord_service_->WatchJobState( + coord_service_->WatchTasks( std::nullopt, [&, this](std::vector got, int64_t v) { using State = xla::coordination::TaskState; @@ -590,7 +590,7 @@ TEST_F(CoordinateTwoTasksTest, WatchJobStateAfterTwoStateChanges) { // Watch the job state, which should return immediately because the state has // already changed. 
absl::Notification done_2; - coord_service_->WatchJobState( + coord_service_->WatchTasks( version_number, [&, this](std::vector got, int64_t v) { using State = xla::coordination::TaskState; diff --git a/xla/pjrt/distributed/coordination/grpc_coordination_client.cc b/xla/pjrt/distributed/coordination/grpc_coordination_client.cc index 784f93428ed7e..8005f06fb3f78 100644 --- a/xla/pjrt/distributed/coordination/grpc_coordination_client.cc +++ b/xla/pjrt/distributed/coordination/grpc_coordination_client.cc @@ -62,8 +62,8 @@ using xla::coordination::ShutdownTaskRequest; using xla::coordination::ShutdownTaskResponse; using xla::coordination::TryGetKeyValueRequest; using xla::coordination::TryGetKeyValueResponse; -using xla::coordination::WatchJobStateRequest; -using xla::coordination::WatchJobStateResponse; +using xla::coordination::WatchTasksRequest; +using xla::coordination::WatchTasksResponse; class GrpcCoordinationClientThread { public: @@ -143,12 +143,12 @@ class GrpcCoordinationClient : public CoordinationClient { /*fail_fast=*/true, &target_); } - void WatchJobStateAsync(tsl::CallOptions* call_opts, - const WatchJobStateRequest* request, - WatchJobStateResponse* response, - tsl::StatusCallback done) override { + void WatchTasksAsync(tsl::CallOptions* call_opts, + const WatchTasksRequest* request, + WatchTasksResponse* response, + tsl::StatusCallback done) override { new tsl::RPCState( - &stub_, cq_, "/xla.coordination.CoordinationService/WatchJobState", + &stub_, cq_, "/xla.coordination.CoordinationService/WatchTasks", *request, response, std::move(done), call_opts, /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true, &target_); diff --git a/xla/pjrt/distributed/coordination/grpc_coordination_service_impl.cc b/xla/pjrt/distributed/coordination/grpc_coordination_service_impl.cc index 5fb2b4437f791..e76d7e1bd8eb2 100644 --- a/xla/pjrt/distributed/coordination/grpc_coordination_service_impl.cc +++ 
b/xla/pjrt/distributed/coordination/grpc_coordination_service_impl.cc @@ -50,7 +50,7 @@ void GrpcCoordinationServiceImpl::HandleRPCsLoop() { ENQUEUE_REQUEST(RegisterTask); ENQUEUE_REQUEST(ShutdownTask); ENQUEUE_REQUEST(Heartbeat); - ENQUEUE_REQUEST(WatchJobState); + ENQUEUE_REQUEST(WatchTasks); ENQUEUE_REQUEST(InsertKeyValue); ENQUEUE_REQUEST(GetKeyValue); ENQUEUE_REQUEST(TryGetKeyValue); diff --git a/xla/pjrt/distributed/coordination/grpc_coordination_service_impl.h b/xla/pjrt/distributed/coordination/grpc_coordination_service_impl.h index 612f22c2ddb95..823c2441b7837 100644 --- a/xla/pjrt/distributed/coordination/grpc_coordination_service_impl.h +++ b/xla/pjrt/distributed/coordination/grpc_coordination_service_impl.h @@ -88,7 +88,7 @@ class GrpcCoordinationServiceImpl : public tsl::AsyncServiceInterface { HANDLER(RegisterTask); HANDLER(ShutdownTask); HANDLER(Heartbeat); - HANDLER(WatchJobState); + HANDLER(WatchTasks); HANDLER(InsertKeyValue); HANDLER(GetKeyValue); HANDLER(TryGetKeyValue); diff --git a/xla/python/pjrt_ifrt/pjrt_client.cc b/xla/python/pjrt_ifrt/pjrt_client.cc index 3003ceb10bd20..45b5de303a6bf 100644 --- a/xla/python/pjrt_ifrt/pjrt_client.cc +++ b/xla/python/pjrt_ifrt/pjrt_client.cc @@ -1565,22 +1565,22 @@ absl::Status PjRtClient::WatchGlobalProcessInfo( int64_t version_number = -1; // latest job state version while (true) { - // Call WatchJobStateAsync. - VLOG(3) << "Calling WatchJobStateAsync for task " << task_id + // Call WatchTasksAsync. 
+ VLOG(3) << "Calling WatchTasksAsync for task " << task_id << " with version number " << version_number; - absl::StatusOr response; + absl::StatusOr response; bool done = false; - std::shared_ptr call_opts = agent.WatchJobStateAsync( + std::shared_ptr call_opts = agent.WatchTasksAsync( version_number, [this, &response, - &done](absl::StatusOr r) { + &done](absl::StatusOr r) { response = std::move(r); absl::MutexLock lock(shutting_down_mu_); done = true; }); { - // Wait for the WatchJobStateAsync call to finish or for us to shut down, + // Wait for the WatchTasksAsync call to finish or for us to shut down, // whichever happens first. absl::MutexLock lock(shutting_down_mu_); auto done_or_shutting_down = [this, &done]() { @@ -1590,7 +1590,7 @@ absl::Status PjRtClient::WatchGlobalProcessInfo( shutting_down_mu_.Await(absl::Condition(&done_or_shutting_down)); if (shutting_down_) { - // Cancel the call the WatchJobStateAsync and wait for it to terminate. + // Cancel the call to WatchTasksAsync and wait for it to terminate. VLOG(3) << "WatchGlobalProcessInfo shutting down for task " << task_id; call_opts->StartCancel(); shutting_down_mu_.Await(absl::Condition(&done)); @@ -1601,7 +1601,7 @@ absl::Status PjRtClient::WatchGlobalProcessInfo( // Sleep to avoid repeatedly issuing a request that fails immediately. // // TODO: mwhittaker - Perform exponential backoff. - LOG(WARNING) << "WatchJobStateAsync failed for task " << task_id << ": " + LOG(WARNING) << "WatchTasksAsync failed for task " << task_id << ": " << response.status(); shutting_down_mu_.AwaitWithTimeout(absl::Condition(&shutting_down_), absl::Seconds(1));