Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions xla/pjrt/distributed/coordination/coordination_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ using xla::coordination::ShutdownTaskRequest;
using xla::coordination::ShutdownTaskResponse;
using xla::coordination::TryGetKeyValueRequest;
using xla::coordination::TryGetKeyValueResponse;
using xla::coordination::WatchJobStateRequest;
using xla::coordination::WatchJobStateResponse;
using xla::coordination::WatchTasksRequest;
using xla::coordination::WatchTasksResponse;

// Base class of client interface for communicating with coordination service.
// Can be implemented by a variety of transports such as gRPC.
Expand All @@ -74,10 +74,10 @@ class CoordinationClient {
ShutdownTaskResponse* response,
tsl::StatusCallback done) = 0;

virtual void WatchJobStateAsync(tsl::CallOptions* call_opts,
const WatchJobStateRequest* request,
WatchJobStateResponse* response,
tsl::StatusCallback done) = 0;
virtual void WatchTasksAsync(tsl::CallOptions* call_opts,
const WatchTasksRequest* request,
WatchTasksResponse* response,
tsl::StatusCallback done) = 0;

virtual void InsertKeyValueAsync(const InsertKeyValueRequest* request,
InsertKeyValueResponse* response,
Expand Down
8 changes: 4 additions & 4 deletions xla/pjrt/distributed/coordination/coordination_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ std::vector<TaskInfo> CoordinationService::GetJobState() {
return states_info;
}

void CoordinationService::NotifyWatchJobStateCallbacks() {
void CoordinationService::NotifyWatchTasksCallbacks() {
for (auto& callback : watch_job_state_callbacks_) {
callback(GetJobState(), cluster_state_version_number_);
}
Expand All @@ -651,11 +651,11 @@ void CoordinationService::NotifyWatchJobStateCallbacks() {

void CoordinationService::ClusterStateUpdated() {
cluster_state_version_number_++;
NotifyWatchJobStateCallbacks();
NotifyWatchTasksCallbacks();
}

void CoordinationService::WatchJobState(std::optional<int64_t> version_number,
WatchJobStateCallback callback) {
void CoordinationService::WatchTasks(std::optional<int64_t> version_number,
WatchTasksCallback callback) {
absl::MutexLock l(state_mu_);
int64_t v = version_number.value_or(-1);
CHECK_GE(cluster_state_version_number_, v);
Expand Down
13 changes: 6 additions & 7 deletions xla/pjrt/distributed/coordination/coordination_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,9 @@ class CoordinationService {
absl::Status ReportTaskError(TaskId task, const absl::Status& error);

// Watches the state and the error status of the job.
using WatchJobStateCallback = absl::AnyInvocable<void(
using WatchTasksCallback = absl::AnyInvocable<void(
std::vector<xla::coordination::TaskInfo>, int64_t)>;
void WatchJobState(std::optional<int64_t> version_number,
WatchJobStateCallback);
void WatchTasks(std::optional<int64_t> version_number, WatchTasksCallback);

// Insert a configuration key-value in the coordination service.
// For now, a key-value can only be inserted once and cannot be updated.
Expand Down Expand Up @@ -554,11 +553,11 @@ class CoordinationService {
std::vector<xla::coordination::TaskInfo> GetJobState()
ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);

// Notifies all callbacks registered via WatchJobState.
void NotifyWatchJobStateCallbacks() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);
// Notifies all callbacks registered via WatchTasks.
void NotifyWatchTasksCallbacks() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);

// This method should be called whenever the cluster state changes in a way
// such that NotifyWatchJobStateCallbacks should be called.
// such that NotifyWatchTasksCallbacks should be called.
void ClusterStateUpdated() ABSL_EXCLUSIVE_LOCKS_REQUIRED(state_mu_);

tsl::Env& env_;
Expand All @@ -573,7 +572,7 @@ class CoordinationService {
absl::flat_hash_map<TaskId, std::unique_ptr<TaskState>> cluster_state_
ABSL_GUARDED_BY(state_mu_);
int64_t cluster_state_version_number_ ABSL_GUARDED_BY(state_mu_) = 0;
std::vector<WatchJobStateCallback> watch_job_state_callbacks_
std::vector<WatchTasksCallback> watch_job_state_callbacks_
ABSL_GUARDED_BY(state_mu_);

KeyValueStore store_;
Expand Down
38 changes: 19 additions & 19 deletions xla/pjrt/distributed/coordination/coordination_service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -102,38 +102,38 @@ message ShutdownTaskResponse {}
// fails. This new job state has version 10. Another task fails, leading to job
// state 22.
//
// A WatchJobStateRequest with version_number v blocks until the latest job
// state has a version number greater than v. It then returns the latest job
// state along with its version number.
// A WatchTasksRequest with version_number v blocks until the latest job state
// has a version number greater than v. It then returns the latest job state
// along with its version number.
//
// This API can be used to (1) get the latest job state immediately or (2) watch
// the job state for changes.
//
// 1. To get the latest job state, issue a WatchJobStateRequest with a
// 1. To get the latest job state, issue a WatchTasksRequest with a
// version number of -1.
// 2. To watch the job state for changes, issue a WatchJobStateRequest with
// the version number returned by the previous WatchJobStateResponse.
// 2. To watch the job state for changes, issue a WatchTasksRequest with
// the version number returned by the previous WatchTasksResponse.
//
// Though version numbers are ints, they should be treated as opaque ids. The
// only valid version numbers are -1 and the version numbers returned in a
// WatchJobStateResponse. You should NOT issue a WatchJobStateRequest with some
// WatchTasksResponse. You should NOT issue a WatchTasksRequest with some
// arbitrarily chosen version number.
//
// Note the following two subtleties with this API:
//
// 1. This API *cannot* be used to witness *every* change in job state. If
// you issue a WatchJobStateRequest with version v, you may receive a job
// state that has changed multiple times since version v.
// 2. If you issue a WatchJobStateRequest with version number v, the job
// state you receive may not be different than the job state at version
// v. For example, a task may have failed and later recovered leading to
// an identical job state.
message WatchJobStateRequest {
reserved 1;
int64 version_number = 2;
// 1. This API *cannot* be used to witness *every* change in job state. If you
// issue a WatchTasksRequest with version v, you may receive a job state that
// has changed multiple times since version v.
//
// 2. If you issue a WatchTasksRequest with version number v, the job state you
// receive may not be different than the job state at version v. For
// example, a task may have failed and later recovered leading to an
// identical job state.
message WatchTasksRequest {
int64 version_number = 1;
}

message WatchJobStateResponse {
message WatchTasksResponse {
repeated TaskInfo task_state = 1;
int64 version_number = 2;
}
Expand Down Expand Up @@ -280,7 +280,7 @@ service CoordinationService {
}

// Watches the state of every task in a remote job.
rpc WatchJobState(WatchJobStateRequest) returns (WatchJobStateResponse);
rpc WatchTasks(WatchTasksRequest) returns (WatchTasksResponse);

// Insert configuration key-value that will be accessible to all cluster
// tasks. The key can be formatted as Unix file path with hierarchy. The
Expand Down
34 changes: 16 additions & 18 deletions xla/pjrt/distributed/coordination/coordination_service_agent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,19 +319,18 @@ void CoordinationServiceAgent::PollForErrorAsync(tsl::StatusCallback done) {
});
}

std::shared_ptr<tsl::CallOptions> CoordinationServiceAgent::WatchJobStateAsync(
std::shared_ptr<tsl::CallOptions> CoordinationServiceAgent::WatchTasksAsync(
std::optional<int64_t> version_number,
std::function<
void(absl::StatusOr<xla::coordination::WatchJobStateResponse>)>
std::function<void(absl::StatusOr<xla::coordination::WatchTasksResponse>)>
callback) {
auto request = std::make_shared<WatchJobStateRequest>();
auto response = std::make_shared<WatchJobStateResponse>();
auto request = std::make_shared<WatchTasksRequest>();
auto response = std::make_shared<WatchTasksResponse>();
auto call_opts = std::make_shared<tsl::CallOptions>();
WatchJobStateRequest* request_ptr = request.get();
WatchJobStateResponse* response_ptr = response.get();
WatchTasksRequest* request_ptr = request.get();
WatchTasksResponse* response_ptr = response.get();
request->set_version_number(version_number.value_or(-1));

leader_client_->WatchJobStateAsync(
leader_client_->WatchTasksAsync(
call_opts.get(), request_ptr, response_ptr,
[request = std::move(request), response = std::move(response),
callback = std::move(callback)](const absl::Status& s) mutable {
Expand All @@ -344,17 +343,16 @@ std::shared_ptr<tsl::CallOptions> CoordinationServiceAgent::WatchJobStateAsync(
return call_opts;
}

absl::StatusOr<xla::coordination::WatchJobStateResponse>
CoordinationServiceAgent::WatchJobState(std::optional<int64_t> version_number) {
absl::StatusOr<xla::coordination::WatchJobStateResponse> response;
absl::StatusOr<xla::coordination::WatchTasksResponse>
CoordinationServiceAgent::WatchTasks(std::optional<int64_t> version_number) {
absl::StatusOr<xla::coordination::WatchTasksResponse> response;
absl::Notification done;
WatchJobStateAsync(
version_number,
[&response,
&done](absl::StatusOr<xla::coordination::WatchJobStateResponse> r) {
response = std::move(r);
done.Notify();
});
WatchTasksAsync(version_number,
[&response, &done](
absl::StatusOr<xla::coordination::WatchTasksResponse> r) {
response = std::move(r);
done.Notify();
});
done.WaitForNotification();
return response;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,14 @@ class CoordinationServiceAgent {
CoordinationService::TaskId task_id() const { return task_id_; }

// Watches the status of a remote job.
absl::StatusOr<xla::coordination::WatchJobStateResponse> WatchJobState(
absl::StatusOr<xla::coordination::WatchTasksResponse> WatchTasks(
std::optional<int64_t> version_number);

// Note: Cancel the underlying RPC call with `call_opts->StartCancel()` and
// `call_opts->ClearCancelCallback()`.
std::shared_ptr<tsl::CallOptions> WatchJobStateAsync(
std::shared_ptr<tsl::CallOptions> WatchTasksAsync(
std::optional<int64_t> version_number,
std::function<
void(absl::StatusOr<xla::coordination::WatchJobStateResponse>)>
std::function<void(absl::StatusOr<xla::coordination::WatchTasksResponse>)>
callback);

// Report error to coordination service. This will invoke the error callback.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ class TestCoordinationClient : public CoordinationClient {
(const GetAliveTasksRequest*, GetAliveTasksResponse*,
tsl::StatusCallback),
(override));
MOCK_METHOD(void, WatchJobStateAsync,
(tsl::CallOptions*, const WatchJobStateRequest*,
WatchJobStateResponse*, tsl::StatusCallback),
MOCK_METHOD(void, WatchTasksAsync,
(tsl::CallOptions*, const WatchTasksRequest*, WatchTasksResponse*,
tsl::StatusCallback),
(override));
MOCK_METHOD(void, HeartbeatAsync,
(tsl::CallOptions*, const HeartbeatRequest*, HeartbeatResponse*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,9 @@ void CoordinationServiceRpcHandler::ShutdownTaskAsync(
[done](absl::Status s) { done(s); });
}

void CoordinationServiceRpcHandler::WatchJobStateAsync(
const xla::coordination::WatchJobStateRequest* request,
xla::coordination::WatchJobStateResponse* response,
tsl::StatusCallback done) {
void CoordinationServiceRpcHandler::WatchTasksAsync(
const xla::coordination::WatchTasksRequest* request,
xla::coordination::WatchTasksResponse* response, tsl::StatusCallback done) {
absl::ReaderMutexLock l(mu_);
if (service_ == nullptr) {
done(MakeCoordinationError(
Expand All @@ -118,7 +117,7 @@ void CoordinationServiceRpcHandler::WatchJobStateAsync(
if (request->version_number() >= 0) {
version_number.emplace(request->version_number());
}
service_->WatchJobState(
service_->WatchTasks(
version_number,
[response, done](std::vector<xla::coordination::TaskInfo> info,
int64_t version_number) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,9 @@ class CoordinationServiceRpcHandler {
xla::coordination::ShutdownTaskResponse* response,
tsl::StatusCallback done);

void WatchJobStateAsync(
const xla::coordination::WatchJobStateRequest* request,
xla::coordination::WatchJobStateResponse* response,
tsl::StatusCallback done);
void WatchTasksAsync(const xla::coordination::WatchTasksRequest* request,
xla::coordination::WatchTasksResponse* response,
tsl::StatusCallback done);

void InsertKeyValueAsync(
const xla::coordination::InsertKeyValueRequest* request,
Expand Down
Loading
Loading