
Conversation


@huangrh99 (Collaborator) commented May 10, 2021

  1. Add ps_strategy.
  2. Add ps_strategy to the JAX example.
  3. Add typing to ps_strategy.py, allreduce_strategy.py and base_strategy.py.
  4. Add save/load states to the strategy and the JAX operator.
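
For context, here is a minimal usage sketch of the new strategy, assuming the constructor arguments visible in the diff below; the import path, the JAXTrainingOperator name, the config keys, and the train/save entry points are assumptions, not the confirmed API.

# Hedged sketch only; everything not shown in the diff below is an assumption.
from distml.strategy.ps_strategy import ParameterServerStrategy  # assumed path

strategy = ParameterServerStrategy(
    training_operator_cls=JAXTrainingOperator,  # hypothetical JAX operator class
    operator_config={"batch_size": 128},        # assumed config key
    num_workers=2,
    num_ps=1,
    num_cpus_per_server=1)
strategy.train()                   # assumed training entry point
strategy.save_states("ckpt.pkl")   # save/load added in this PR; method name assumed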



class ParameterServerStrategy(BaseStrategy):
    """Strategy that trains a model via collective AllReduce.
Collaborator

Change this docstring summary? It still describes collective AllReduce rather than the parameter server strategy.
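
A possible replacement summary, as a sketch (the wording is only a suggestion):

class ParameterServerStrategy(BaseStrategy):
    """Strategy that trains a model via parameter servers.

    Workers compute gradients and push them to one or more parameter
    servers, which apply the updates and serve the refreshed parameters
    back to the workers.
    """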

training_operator_cls,
operator_config=None,
initialization_hook=None,
num_workers=1,
Collaborator

num_worker


assert num_ps
self.num_ps = num_ps
self.num_workers = num_workers
Collaborator

same here... don't use the plural form

assert num_ps
self.num_ps = num_ps
self.num_workers = num_workers
self.num_cpus_per_server = num_cpus_per_server
Collaborator

and here, and in the lines that follow (e.g. num_cpus_per_server)
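
For concreteness, a sketch of the constructor with the singular naming suggested above; only the names change, and any defaults not visible in the diff are assumptions:

def __init__(self,
             training_operator_cls,
             operator_config=None,
             initialization_hook=None,
             num_worker=1,
             num_ps=1,                # default not visible in the diff; assumed
             num_cpu_per_server=1):   # default not visible in the diff; assumed
    assert num_ps
    self.num_ps = num_ps
    self.num_worker = num_worker
    self.num_cpu_per_server = num_cpu_per_server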

ray.get([server.set_params.remote(this_shard_ref)])

def _start_workers(self):
"""Create worker(actor), maybe need worker group to manager these workers.
Collaborator

rewrite this docstring
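
One possible rewrite, as a sketch (the exact wording is up to the author):

    def _start_workers(self):
        """Create the worker actors and collect them into a worker group.

        The worker group owns the lifecycle of the workers (start, setup,
        shutdown), so the strategy does not manage individual actors.
        """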

"""
# TODO (Hao): infer the per-replica batch size here...

# so here we get two set of params that will be passed around:
Collaborator

You can clean up this comment, as it is redundant with the ones I left in AllReduceStrategy.
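
On the TODO about the per-replica batch size, a minimal sketch; the "batch_size" and "per_replica_batch_size" keys and the operator_config attribute are assumptions, while the other names follow the diff where visible:

# Split an assumed global "batch_size" evenly across workers so each replica
# processes global_batch_size // num_workers samples per step.
global_batch_size = self.operator_config.get("batch_size", 0)
if global_batch_size:
    if global_batch_size % self.num_workers != 0:
        raise ValueError("batch_size must be divisible by the number of workers.")
    self.operator_config["per_replica_batch_size"] = \
        global_batch_size // self.num_workers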

}

# Should we make two groups for worker and server?
self.worker_group = DataParallelGroup(**workergroup_init_args)
Collaborator

This is strange. Is this the same DataParallelGroup as the one in AllReduceStrategy?
If yes, then fine.
If not, is there any way we can share the same class? If that is hard, then we should at least use a different class name.
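
If it is meant to be the same class, one way to make that explicit is to define it once in a shared module and import it from both strategies instead of redefining it; a sketch, where the module path and the server-group arguments are assumptions:

# Shared definition (module path assumed), imported by both
# allreduce_strategy.py and ps_strategy.py:
from distml.strategy.base_strategy import DataParallelGroup

self.worker_group = DataParallelGroup(**workergroup_init_args)
self.server_group = DataParallelGroup(**servergroup_init_args)  # hypothetical args dict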

self.server_group.start_actors(
    self.num_ps)  # server at the last num_ps processes.

worker_rets = self.worker_group.test_connection()
Collaborator

Is testing the connection necessary? If not, probably move it to a DEBUG mode.
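
A sketch of gating the connection test behind debug logging so it only runs when explicitly enabled; the ray.get call assumes test_connection() returns object refs:

import logging

logger = logging.getLogger(__name__)

# Only exercise the collective group when debug logging is enabled.
if logger.isEnabledFor(logging.DEBUG):
    worker_rets = self.worker_group.test_connection()
    ray.get(worker_rets)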


def setup_operator(self):
    # figure out the signature of training_operator_cls later.
    self.training_operator = self.training_operator_cls(
Collaborator

I am not sure whether we should set up the whole operator on the server side. One drawback is that this will take a lot of GPU memory.
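
One way to limit that, as a sketch: let the server build only the parameter state instead of the full operator. The operator signature, the "params_only" flag, and the operator_config attribute are assumptions, since the diff notes the signature is not finalized:

def setup_operator(self, is_server=False):
    config = dict(self.operator_config or {})
    if is_server:
        # Servers only hold their parameter shard, so skip building the
        # full model/optimizer (and its GPU buffers) on the server side.
        config["params_only"] = True
    self.training_operator = self.training_operator_cls(operator_config=config)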
