-
Notifications
You must be signed in to change notification settings - Fork 11
chore: ps example. #327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
chore: ps example. #327
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| 3.12 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| from ps import ConvNet, get_data_loader, ParameterServer, DataWorker, evaluate | ||
| from flamepy.rl import Runner | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| model = ConvNet() | ||
| test_loader = get_data_loader()[1] | ||
| print("Running synchronous parameter server training.") | ||
|
|
||
| with Runner("ps-example") as rr: | ||
| ps_svc = rr.service(ParameterServer(1e-2)) | ||
| workers_svc = [rr.service(DataWorker) for _ in range(4)] | ||
|
|
||
| current_weights = ps_svc.get_weights().get() | ||
| for i in range(20): | ||
| gradients = [worker.compute_gradients(current_weights) for worker in workers_svc] | ||
| # Calculate update after all gradients are available. | ||
| current_weights = ps_svc.apply_gradients(*gradients).get() | ||
|
|
||
| if i % 10 == 0: | ||
| # Evaluate the current model. | ||
| model.set_weights(current_weights) | ||
| accuracy = evaluate(model, test_loader) | ||
| print("Iter {}: \taccuracy is {:.1f}".format(i, accuracy)) | ||
|
|
||
| print("Final accuracy is {:.1f}.".format(accuracy)) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,119 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.nn as nn | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch.nn.functional as F | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torchvision import datasets, transforms | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from filelock import FileLock | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_data_loader(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Safely downloads data. Returns training/validation set dataloader.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mnist_transforms = transforms.Compose( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # We add FileLock here because multiple workers will want to | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # download data, and this may cause overwrites since | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # DataLoader is not threadsafe. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with FileLock(os.path.expanduser("~/data.lock")): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| train_loader = torch.utils.data.DataLoader( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| datasets.MNIST( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "~/data", train=True, download=True, transform=mnist_transforms | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_size=128, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shuffle=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| test_loader = torch.utils.data.DataLoader( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| datasets.MNIST("~/data", train=False, transform=mnist_transforms), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_size=128, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| shuffle=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+19
to
+31
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The path
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return train_loader, test_loader | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def evaluate(model, test_loader): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Evaluates the accuracy of the model on a validation dataset.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| model.eval() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| correct = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with torch.no_grad(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for batch_idx, (data, target) in enumerate(test_loader): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # This is only set to finish evaluation faster. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if batch_idx * len(data) > 1024: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| outputs = model(data) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _, predicted = torch.max(outputs.data, 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total += target.size(0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| correct += (predicted == target).sum().item() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return 100.0 * correct / total | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class ConvNet(nn.Module): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Small ConvNet for MNIST.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| super(ConvNet, self).__init__() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.conv1 = nn.Conv2d(1, 3, kernel_size=3) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.fc = nn.Linear(192, 10) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, x): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = F.relu(F.max_pool2d(self.conv1(x), 3)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = x.view(-1, 192) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| x = self.fc(x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return F.log_softmax(x, dim=1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_weights(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return {k: v.cpu() for k, v in self.state_dict().items()} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def set_weights(self, weights): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.load_state_dict(weights) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_gradients(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grads = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for p in self.parameters(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grad = None if p.grad is None else p.grad.data.cpu().numpy() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grads.append(grad) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return grads | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def set_gradients(self, gradients): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for g, p in zip(gradients, self.parameters()): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if g is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| p.grad = torch.from_numpy(g) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class ParameterServer(object): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, lr): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model = ConvNet() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def apply_gradients(self, *gradients): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| summed_gradients = [ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| np.stack(gradient_zip).sum(axis=0) for gradient_zip in zip(*gradients) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.optimizer.zero_grad() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model.set_gradients(summed_gradients) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.optimizer.step() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self.model.get_weights() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def get_weights(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self.model.get_weights() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class DataWorker(object): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model = ConvNet() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.data_iterator = iter(get_data_loader()[0]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def compute_gradients(self, weights): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model.set_weights(weights) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| data, target = next(self.data_iterator) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except StopIteration: # When the epoch ends, start a new epoch. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.data_iterator = iter(get_data_loader()[0]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| data, target = next(self.data_iterator) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model.zero_grad() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| output = self.model(data) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loss = F.nll_loss(output, target) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| loss.backward() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return self.model.get_gradients() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| [project] | ||
| name = "ps-example" | ||
| version = "0.1.0" | ||
| description = "Parameter Server by flamepy.Runner" | ||
| readme = "README.md" | ||
| requires-python = ">=3.12" | ||
| dependencies = [ | ||
| "torch", | ||
| "torchvision", | ||
| "numpy", | ||
| "filelock" | ||
| ] | ||
|
|
||
| [build-system] | ||
| requires = ["setuptools>=61.0", "wheel"] | ||
| build-backend = "setuptools.build_meta" | ||
|
|
||
| [tool.setuptools] | ||
| py-modules = ["ps", "main"] | ||
|
|
||
| [tool.uv.sources] | ||
| flamepy = { path = "/usr/local/flame/sdk/python" } |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |||||
| import inspect | ||||||
| import logging | ||||||
| import os | ||||||
| import shutil | ||||||
| import site | ||||||
| import subprocess | ||||||
| import sys | ||||||
|
|
@@ -120,7 +121,12 @@ def _extract_archive(self, archive_path: str, extract_to: str) -> str: | |||||
| logger.info(f"Extracting archive: {archive_path} to {extract_to}") | ||||||
|
|
||||||
| try: | ||||||
| # Create extraction directory if it doesn't exist | ||||||
| # Remove old extracted directory if it exists to ensure clean extraction | ||||||
| if os.path.exists(extract_to): | ||||||
| logger.info(f"Removing existing extracted directory: {extract_to}") | ||||||
| shutil.rmtree(extract_to) | ||||||
|
|
||||||
| # Create extraction directory | ||||||
| os.makedirs(extract_to, exist_ok=True) | ||||||
|
|
||||||
| # Determine archive type and extract | ||||||
|
|
@@ -193,27 +199,50 @@ def _install_package_from_url(self, url: str) -> None: | |||||
| install_path = extracted_dir | ||||||
| logger.info(f"Will install from extracted directory: {install_path}") | ||||||
|
|
||||||
| # Debug: List contents of extracted directory | ||||||
| try: | ||||||
| contents = os.listdir(install_path) | ||||||
| logger.debug(f"Extracted directory contents: {contents}") | ||||||
|
|
||||||
| # Check for pyproject.toml or setup.py | ||||||
| if "pyproject.toml" in contents: | ||||||
| pyproject_path = os.path.join(install_path, "pyproject.toml") | ||||||
| with open(pyproject_path, "r") as f: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a good practice to explicitly specify the file encoding when opening text files. This avoids potential
Suggested change
|
||||||
| pyproject_content = f.read() | ||||||
| logger.debug(f"pyproject.toml content:\n{pyproject_content}") | ||||||
| if "setup.py" in contents: | ||||||
| logger.debug("Found setup.py in extracted directory") | ||||||
| except Exception as e: | ||||||
| logger.warning(f"Failed to list extracted directory contents: {e}") | ||||||
|
|
||||||
| # Use sys.executable -m pip to install into the current virtual environment | ||||||
| # pip install will upgrade the package if it's already installed | ||||||
| logger.info(f"Installing package: {install_path}") | ||||||
| install_args = [sys.executable, "-m", "pip", "install", install_path] | ||||||
| logger.debug(f"Python executable: {sys.executable}") | ||||||
| logger.debug(f"Current working directory: {os.getcwd()}") | ||||||
| install_args = [sys.executable, "-m", "pip", "install", "--upgrade", install_path] | ||||||
| logger.debug(f"Install command: {' '.join(install_args)}") | ||||||
|
|
||||||
| try: | ||||||
| result = subprocess.run(install_args, capture_output=True, text=True, check=True) | ||||||
| logger.info(f"Package installation output: {result.stdout}") | ||||||
| logger.info("Package installation succeeded") | ||||||
| logger.debug(f"Package installation stdout:\n{result.stdout}") | ||||||
| if result.stderr: | ||||||
| logger.warning(f"Package installation stderr: {result.stderr}") | ||||||
| logger.debug(f"Package installation stderr:\n{result.stderr}") | ||||||
| logger.info(f"Successfully installed package from: {install_path}") | ||||||
|
|
||||||
| # Reload site packages to make the newly installed package available | ||||||
| # This is necessary because the Python interpreter has already started | ||||||
| logger.info("Reloading site packages to pick up newly installed package") | ||||||
| importlib.reload(site) | ||||||
| logger.info(f"Updated sys.path: {sys.path}") | ||||||
| logger.debug(f"Updated sys.path: {sys.path}") | ||||||
|
|
||||||
| except subprocess.CalledProcessError as e: | ||||||
| logger.error(f"Failed to install package: {e}") | ||||||
| logger.error(f"stdout: {e.stdout}") | ||||||
| logger.error(f"stderr: {e.stderr}") | ||||||
| logger.error(f"Return code: {e.returncode}") | ||||||
| logger.error(f"Install command was: {' '.join(install_args)}") | ||||||
| logger.error(f"Package installation stdout:\n{e.stdout}") | ||||||
| logger.error(f"Package installation stderr:\n{e.stderr}") | ||||||
| raise RuntimeError(f"Package installation failed: {e}") | ||||||
| finally: | ||||||
| # Clean up extracted directory if it was created | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
accuracyvariable is only defined inside theif i % 10 == 0:block within the loop. If the loop range were smaller (e.g.,range(0)), this line would raise anUnboundLocalError. It would be safer to initializeaccuracyto a default value (e.g.,0.0) before the loop.