16 changes: 16 additions & 0 deletions tensorflow/core/data/service/BUILD
@@ -295,6 +295,7 @@ cc_library(
],
deps = [
":cache_utils",
":local_decision_utils",
":common",
":common_proto_cc",
":credentials_factory",
@@ -938,6 +939,21 @@ cc_library(
],
)

cc_library(
name = "local_decision_utils",
srcs = ["easl/local_decision_utils.cc"],
hdrs = [
"easl/local_decision_utils.h",
],
deps = [
":common_proto_cc",
":dispatcher_state",
":metadata_store",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
],
)

cc_library(
name = "cache_model",
srcs = ["easl/cache_model.cc"],
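The header for the new target is not part of this diff. Based purely on the call sites added to dispatcher_impl.cc further down, easl/local_decision_utils.h plausibly declares something like the following sketch; the namespaces, includes, constness, and parameter types are inferred, not confirmed by the diff:

#ifndef TENSORFLOW_CORE_DATA_SERVICE_EASL_LOCAL_DECISION_UTILS_H_
#define TENSORFLOW_CORE_DATA_SERVICE_EASL_LOCAL_DECISION_UTILS_H_

#include <string>

#include "tensorflow/core/data/service/easl/metadata_store.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/service_config.pb.h"

namespace tensorflow {
namespace data {
namespace service {
namespace easl {
namespace local_decision {

// Decides, from the previous epoch's metrics, whether the job's throughput
// allows serving part of it from client-local workers to save network
// bandwidth.
Status DecideIfLocal(const experimental::DispatcherConfig& config,
                     const MetadataStore& metadata_store,
                     const std::string& dataset_key,
                     bool& want_to_use_local_workers);

// Chooses remote/local worker targets by grid search over the available
// remote and local worker counts.
Status DecideTargetWorkersGridSearch(
    const experimental::DispatcherConfig& config,
    const MetadataStore& metadata_store, const std::string& dataset_key,
    int64 available_remote_workers, int64 available_local_workers,
    int64& num_worker_remote_target, int64& num_worker_local_target);

// Chooses remote/local worker targets with the autoscaling policy.
Status DecideTargetWorkersAutoscaling(
    const experimental::DispatcherConfig& config,
    const MetadataStore& metadata_store, const std::string& dataset_key,
    int64 available_remote_workers, int64 available_local_workers,
    int64& num_worker_remote_target, int64& num_worker_local_target);

}  // namespace local_decision
}  // namespace easl
}  // namespace service
}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_SERVICE_EASL_LOCAL_DECISION_UTILS_H_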
9 changes: 7 additions & 2 deletions tensorflow/core/data/service/dispatcher.proto
@@ -119,7 +119,7 @@ message JobKey {
int64 job_name_index = 2;
}

// Next tag: 11
message GetOrCreateJobRequest {
reserved 2, 3, 4;
// The id of the dataset to create a job for.
@@ -136,6 +136,9 @@ GetOrCreateJobRequest {
}
// Specifies which workers the client of this job reads from.
TargetWorkers target_workers = 9;

// EASL (MUYU): addresses of the tf.data workers running co-located with the client.
repeated string local_workers = 10;
}

// Next tag: 2
@@ -186,7 +189,7 @@ message ClientHeartbeatRequest {
double avg_inter_arrival_time = 6;
}

// Next tag: 5
message ClientHeartbeatResponse {
// A list of all tasks that the client should read from.
repeated TaskInfo task_info = 1;
@@ -196,6 +199,8 @@ message ClientHeartbeatResponse {
}
// Whether the job has finished.
bool job_finished = 2;
// EASL: number of client-local workers this job should read from, decided by
// the dispatcher based on the previous epoch's metrics.
int64 num_worker_local_target = 4;
}

// Next tag: 3
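A minimal sketch of how the two new fields travel end to end, assuming the standard generated C++ API for this proto; the worker address and surrounding code are illustrative only:

#include "tensorflow/core/data/service/dispatcher.pb.h"

void ExampleRoundTrip() {
  // Client side: advertise the tf.data workers running on the client's host
  // when asking the dispatcher to create the job.
  tensorflow::data::GetOrCreateJobRequest req;
  req.set_dataset_id(42);
  req.add_local_workers("localhost:5050");  // made-up address

  // Client side, later: the heartbeat response tells the client how many
  // local workers the dispatcher decided this job should use.
  tensorflow::data::ClientHeartbeatResponse resp;
  if (resp.num_worker_local_target() > 0) {
    // Prefer reading from the co-located worker(s).
  }
}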
12 changes: 11 additions & 1 deletion tensorflow/core/data/service/dispatcher_client.cc
@@ -137,7 +137,9 @@ Status DataServiceDispatcherClient::GetOrCreateJob(
int64_t dataset_id, const ProcessingModeDef& processing_mode,
const absl::optional<JobKey>& job_key,
absl::optional<int64_t> num_consumers, TargetWorkers target_workers,
int64_t& job_client_id,
const std::vector<std::string>& local_workers) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetOrCreateJobRequest req;
req.set_dataset_id(dataset_id);
@@ -149,6 +151,13 @@ Status DataServiceDispatcherClient::GetOrCreateJob(
req.set_num_consumers(num_consumers.value());
}
req.set_target_workers(target_workers);

*req.mutable_local_workers() = {local_workers.begin(), local_workers.end()};

for (const auto& worker : local_workers) {
VLOG(1) << "EASL-MUYU (Client: GetOrCreateJob) local_workers: " << worker;
}

GetOrCreateJobResponse resp;
grpc::ClientContext client_ctx;
grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp);
@@ -159,6 +168,7 @@ Status DataServiceDispatcherClient::GetOrCreateJob(
status);
}
job_client_id = resp.job_client_id();

return Status::OK();
}

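A hypothetical call site for the extended signature could look like the fragment below; none of these identifiers appear in the diff, the address is made up, and the code is assumed to live in namespace tensorflow::data with the usual headers included:

Status CreateJobWithLocalWorker(DataServiceDispatcherClient& dispatcher,
                                int64_t dataset_id,
                                const ProcessingModeDef& processing_mode,
                                int64_t& job_client_id) {
  // Address of the tf.data worker co-located with this client.
  std::vector<std::string> local_workers = {"localhost:5050"};
  return dispatcher.GetOrCreateJob(
      dataset_id, processing_mode, /*job_key=*/absl::nullopt,
      /*num_consumers=*/absl::nullopt, TARGET_WORKERS_AUTO, job_client_id,
      local_workers);
}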
3 changes: 2 additions & 1 deletion tensorflow/core/data/service/dispatcher_client.h
@@ -77,7 +77,8 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
const ProcessingModeDef& processing_mode,
const absl::optional<JobKey>& job_key,
absl::optional<int64_t> num_consumers,
TargetWorkers target_workers, int64_t& job_client_id,
const std::vector<std::string>& local_workers);

// Releases a job client id, indicating that the id will no longer be used to
// read from the job.
73 changes: 63 additions & 10 deletions tensorflow/core/data/service/dispatcher_impl.cc
@@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/data/service/easl/cache_utils.h"
#include "tensorflow/core/data/service/easl/metadata_store.h"
#include "tensorflow/core/data/service/easl/local_decision_utils.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -880,7 +881,6 @@ Status DataServiceDispatcherImpl::CreateJob(
int64 dataset_id = request.dataset_id();

// EASL - Caching decision: should the job compute, write or read from cache?
std::string job_type;
std::shared_ptr<const Dataset> dataset;
TF_RETURN_IF_ERROR(state_.DatasetFromId(dataset_id, dataset));
@@ -890,16 +890,53 @@
service::easl::cache_utils::DetermineJobType(
config_, cache_state_, metadata_store_, dataset_fingerprint,
compute_dataset_key, job_id, job_type);
VLOG(0) << "EASL - Caching decision for dataset_key "
VLOG(0) << "EASL - Caching decision for dataset_key "
<< compute_dataset_key << ": " << job_type;

// EASL (MUYU): first, collect the client-local workers advertised in the request.
absl::flat_hash_set<std::string> local_workers;
local_workers.insert(request.local_workers().cbegin(),
request.local_workers().cend());

int64 num_worker_remote_target, num_worker_local_target;
int64 total_workers = state_.ListWorkers().size();
if(config_.scaling_policy() == 1) { // Old autoscaling prior to paper
// Infer the worker count for this job and job type
int64 worker_count;
TF_RETURN_IF_ERROR(service::easl::cache_utils::DetermineElasticity(job_type,
config_, metadata_store_, compute_dataset_key, total_workers, worker_count));
VLOG(0) << "EASL - Scalability decision for dataset_key "
<< compute_dataset_key << ": " << worker_count;

bool want_to_use_local_workers; // Do we have enough throughput to decide to use local workers to save network bandwidth?
TF_RETURN_IF_ERROR(service::easl::local_decision::DecideIfLocal(
config_, metadata_store_, compute_dataset_key, want_to_use_local_workers
));

if(want_to_use_local_workers && local_workers.size() >= 1) {
num_worker_remote_target = worker_count - 1;
num_worker_local_target = 1;
} else {
num_worker_remote_target = worker_count;
num_worker_local_target = 0;
}
} else if(config_.scaling_policy() == 2) { // Use all available workers
num_worker_remote_target = total_workers - local_workers.size();
num_worker_local_target = local_workers.size();
} else if(config_.scaling_policy() == 3) { // Grid search over local and remote workers
TF_RETURN_IF_ERROR(service::easl::local_decision::DecideTargetWorkersGridSearch(
config_, metadata_store_, compute_dataset_key,
total_workers - local_workers.size(), local_workers.size(),
num_worker_remote_target, num_worker_local_target
));
} else { // New paper autoscaling
TF_RETURN_IF_ERROR(service::easl::local_decision::DecideTargetWorkersAutoscaling(
config_, metadata_store_, compute_dataset_key,
total_workers - local_workers.size(), local_workers.size(),
num_worker_remote_target, num_worker_local_target
));
}

// EASL add job entry to metadata store
std::string dataset_key = service::easl::cache_utils::DatasetKey(
dataset->dataset_id, dataset->fingerprint, job_type);
@@ -919,8 +956,16 @@ Status DataServiceDispatcherImpl::CreateJob(
create_job->set_dataset_id(request.dataset_id());
*create_job->mutable_processing_mode_def() = request.processing_mode_def();
create_job->set_job_type(job_type);
create_job->set_num_worker_remote_target(num_worker_remote_target);
create_job->set_num_split_providers(num_split_providers);
create_job->set_num_worker_local_target(num_worker_local_target);
*create_job->mutable_local_workers() =
{local_workers.begin(), local_workers.end()};

for (const auto& worker : local_workers) {
VLOG(2) << "EASL-MUYU (CreateJob) local_workers: " << worker;
}

if (request.has_job_key()) {
NamedJobKeyDef* key = create_job->mutable_named_job_key();
key->set_name(request.job_key().job_name());
@@ -933,6 +978,11 @@ Status DataServiceDispatcherImpl::CreateJob(
create_job->set_target_workers(request.target_workers());
TF_RETURN_IF_ERROR(Apply(update));
TF_RETURN_IF_ERROR(state_.JobFromId(job_id, job));

for (const auto& worker : job->local_workers) {
VLOG(2) << "EASL-MUYU (CreateJob-after) local_workers: " << worker;
}

return Status::OK();
}

@@ -983,11 +1033,12 @@ Status DataServiceDispatcherImpl::CreateTasksForJob(
std::vector<std::shared_ptr<const Task>>& tasks)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::vector<std::shared_ptr<Worker>> workers = state_.ReserveWorkers(
job->job_id, job->num_worker_remote_target, job->num_worker_local_target, job->local_workers);
if (workers.size() < job->num_worker_remote_target + job->num_worker_local_target){
VLOG(0)
<< "EASL - Not enough workers for job. Elasticity policy requires "
<< job->num_worker_remote_target << " remote and " << job->num_worker_local_target
<< " local but got " << workers.size();
}
tasks.clear();
tasks.reserve(workers.size());
@@ -1217,6 +1268,8 @@ Status DataServiceDispatcherImpl::ClientHeartbeat(
task_info->set_starting_round(task->starting_round);
}
response->set_job_finished(job->finished);
response->set_num_worker_local_target(job->num_worker_local_target);

VLOG(4) << "Found " << response->task_info_size()
<< " tasks for job client id " << request->job_client_id();
return Status::OK();
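Condensed, the scaling-policy branch added to CreateJob above implements the following decision table. This is a stand-alone sketch with simplified types; the grid-search and autoscaling branches delegate to the new local_decision helpers and are only stubbed here:

#include <cstdint>

struct WorkerTargets { int64_t remote; int64_t local; };

WorkerTargets PickTargets(int scaling_policy, int64_t total_workers,
                          int64_t num_client_local_workers,
                          int64_t autoscaled_worker_count, bool want_local) {
  switch (scaling_policy) {
    case 1:  // Pre-paper autoscaling: give the job one local slot if worthwhile.
      if (want_local && num_client_local_workers >= 1)
        return {autoscaled_worker_count - 1, 1};  // e.g. 4 workers -> 3 remote + 1 local
      return {autoscaled_worker_count, 0};
    case 2:  // Use every available worker, local ones included.
      return {total_workers - num_client_local_workers, num_client_local_workers};
    case 3:  // Grid search: DecideTargetWorkersGridSearch picks both targets.
    default: // Paper autoscaling: DecideTargetWorkersAutoscaling picks both targets.
      return {/*remote=*/0, /*local=*/0};  // placeholder for the delegated result
  }
}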
66 changes: 49 additions & 17 deletions tensorflow/core/data/service/dispatcher_state.cc
@@ -133,7 +133,13 @@ void DispatcherState::CreateJob(const CreateJobUpdate& create_job) {
create_job.processing_mode_def(),
create_job.num_split_providers(),
named_job_key, num_consumers, create_job.target_workers(),
create_job.job_type(), create_job.num_worker_remote_target(),
create_job.num_worker_local_target());

for (const auto& worker : create_job.local_workers()) {
VLOG(1) << "EASL-MUYU (DispatcherState::CreateJob): worker " << worker;
job->local_workers.insert(worker);
}

DCHECK(!jobs_.contains(job_id));
jobs_[job_id] = job;
@@ -368,30 +374,53 @@ DispatcherState::ListAvailableWorkers() const {
return workers;
}

// Reserves available workers for a particular job, split between remote and
// client-local workers. If the combined target is non-positive or exceeds the
// number of available workers, all available workers are reserved.
std::vector<std::shared_ptr<DispatcherState::Worker>>
DispatcherState::ReserveWorkers(
int64 job_id, int64 num_worker_remote_target,
int64 num_worker_local_target,
const absl::flat_hash_set<std::string>& local_workers) {
// If the combined target is non-positive or exceeds availability, reserve all
// available workers, whether local or remote.
int64 num_worker_target = num_worker_remote_target + num_worker_local_target;
if (num_worker_target <= 0 || num_worker_target > avail_workers_.size()) {
num_worker_remote_target = avail_workers_.size();
num_worker_local_target = avail_workers_.size();
}

std::vector<std::shared_ptr<Worker>> workers;
workers.reserve(avail_workers_.size());
VLOG(0) << "DSL (ReserveWorkers) Available remote: " << avail_workers_.size() << "\n"
<< "Available local: " << local_workers.size() << "\n"
<< "Target remote: " << num_worker_remote_target << "\n"
<< "Target local: " << num_worker_local_target << "\n";

for (auto it = avail_workers_.begin(); it != avail_workers_.end(); ) {
//bool is_local = std::count(it->second->tags.begin(), it->second->tags.end(), "COLOCATED"); // Tag based
bool is_local = local_workers.count(it->first);
if (is_local) {
VLOG(0) << "EASL-DSL (ReserveWorkers) Worker_L: " << it->first;
if (num_worker_local_target <= 0) { // No additional local workers needed
it++;
continue;
} else {
num_worker_local_target--;
}
} else {
VLOG(0) << "EASL-DSL (ReserveWorkers) Worker_R: " << it->first;
if (num_worker_remote_target <= 0) { // No additional remote workers needed
it++;
continue;
} else {
num_worker_remote_target--;
}
}
workers.push_back(it->second);
VLOG(0) << "(ReserveWorkers) Assigning worker at address "
<< it->second->address << " to job " << job_id;
workers_by_job_[job_id].push_back(it->second);
jobs_by_worker_[it->second->address][job_id] = jobs_[job_id];
avail_workers_.erase(it++);
}
VLOG(0) << "(ReserveWorkers) Number of workers for job " << job_id << " is: "
<< workers_by_job_[job_id].size();
@@ -411,7 +440,8 @@ void DispatcherState::ReassignFreeWorkers() {
// Get a job in need of workers
std::shared_ptr<Job> job = job_iter->second;
int64 num_assigned_workers = workers_by_job_[job->job_id].size();
while (job->finished || num_assigned_workers ==
job->num_worker_remote_target + job->num_worker_local_target){
job_iter++;
if(job_iter == jobs_.end()){
// Went through all jobs, can return
Expand All @@ -423,10 +453,12 @@ void DispatcherState::ReassignFreeWorkers() {
// Assign one worker to the job
workers_by_job_[job->job_id].push_back(it->second);
jobs_by_worker_[it->second->address][job->job_id] = jobs_[job->job_id];

VLOG(0) << "EASL - (ReassignFreeWorkers) Reassigned worker "
<< it->second->address << " to job " << job->job_id;

avail_workers_.erase(it++);
}
}

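The selection logic inside the rewritten ReserveWorkers can be illustrated in isolation: walk the pool of available workers, take a worker only while its category (local vs. remote) still has budget, and remove reserved workers from the pool. This is a simplified stand-alone sketch, not the dispatcher's actual types:

#include <cstdint>
#include <iostream>
#include <set>
#include <string>
#include <vector>

std::vector<std::string> Reserve(std::set<std::string>& available,
                                 int64_t remote_target, int64_t local_target,
                                 const std::set<std::string>& local_addresses) {
  std::vector<std::string> reserved;
  for (auto it = available.begin(); it != available.end();) {
    const bool is_local = local_addresses.count(*it) > 0;
    int64_t& budget = is_local ? local_target : remote_target;
    if (budget <= 0) { ++it; continue; }  // this category is already satisfied
    --budget;
    reserved.push_back(*it);
    it = available.erase(it);             // reserved workers leave the pool
  }
  return reserved;
}

int main() {
  std::set<std::string> pool = {"local:1", "remote:1", "remote:2"};
  // Ask for 1 remote and 1 local worker; "remote:2" stays in the pool.
  for (const auto& addr : Reserve(pool, 1, 1, {"local:1"})) {
    std::cout << addr << "\n";  // prints "local:1" then "remote:1"
  }
}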
23 changes: 18 additions & 5 deletions tensorflow/core/data/service/dispatcher_state.h
@@ -150,15 +150,21 @@ class DispatcherState {
absl::optional<int64> num_consumers,
TargetWorkers target_workers,
const std::string& job_type,
int64 num_worker_remote_target,
int64 num_worker_local_target,
absl::flat_hash_set<std::string> local_workers = {})
: job_id(job_id),
dataset_id(dataset_id),
processing_mode(processing_mode),
named_job_key(named_job_key),
num_consumers(num_consumers),
job_type(job_type),
num_worker_remote_target(num_worker_remote_target),
target_workers(target_workers),
num_worker_local_target(num_worker_local_target),
local_workers(local_workers) {
if (IsDynamicShard(processing_mode)) {
distributed_epoch_state = DistributedEpochState(num_split_providers);
}
@@ -189,7 +195,11 @@ class DispatcherState {
bool garbage_collected = false;
// EASL
const std::string job_type;
const int64 num_worker_remote_target;
// EASL: target number of client-local workers for this job
const int64 num_worker_local_target;
// EASL: addresses of the workers co-located with the client
absl::flat_hash_set<std::string> local_workers;
};

struct Task {
@@ -243,7 +253,10 @@ class DispatcherState {
// is lower than or equal to 0, then the reserved number of workers is equal
// to all the available workers.
std::vector<std::shared_ptr<Worker>> ReserveWorkers(int64 job_id,
int64 num_worker_remote_target = 0,
int64 num_worker_local_target = 0,
const absl::flat_hash_set<std::string>& local_workers = {});

// Returns the next available job id.
int64_t NextAvailableJobId() const;