12 changes: 6 additions & 6 deletions common/src/apis.rs
@@ -95,7 +95,7 @@ pub struct Application {
pub command: Option<String>,
pub arguments: Vec<String>,
pub environments: HashMap<String, String>,
pub working_directory: String,
pub working_directory: Option<String>,
pub max_instances: u32,
pub delay_release: Duration,
pub schema: Option<ApplicationSchema>,
@@ -111,7 +111,7 @@ pub struct ApplicationAttributes {
pub command: Option<String>,
pub arguments: Vec<String>,
pub environments: HashMap<String, String>,
pub working_directory: String,
pub working_directory: Option<String>,
pub max_instances: u32,
pub delay_release: Duration,
pub schema: Option<ApplicationSchema>,
@@ -128,7 +128,7 @@ impl Default for ApplicationAttributes {
command: None,
arguments: vec![],
environments: HashMap::new(),
working_directory: "/tmp".to_string(),
working_directory: None,
max_instances: DEFAULT_MAX_INSTANCES,
delay_release: DEFAULT_DELAY_RELEASE,
schema: Some(ApplicationSchema::default()),
@@ -868,7 +868,7 @@ impl TryFrom<&rpc::Application> for Application {
.into_iter()
.map(|e| (e.name, e.value))
.collect(),
working_directory: spec.working_directory.unwrap_or(String::default()),
working_directory: spec.working_directory,
max_instances: spec.max_instances.unwrap_or(DEFAULT_MAX_INSTANCES),
delay_release: spec
.delay_release
@@ -901,7 +901,7 @@ impl From<&Application> for rpc::Application {
.into_iter()
.map(|(k, v)| rpc::Environment { name: k, value: v })
.collect(),
working_directory: Some(app.working_directory.clone()),
working_directory: app.working_directory.clone(),
max_instances: Some(app.max_instances),
delay_release: Some(app.delay_release.num_seconds()),
schema: app.schema.clone().map(rpc::ApplicationSchema::from),
@@ -939,7 +939,7 @@ impl From<rpc::ApplicationSpec> for ApplicationAttributes {
.into_iter()
.map(|e| (e.name, e.value))
.collect(),
working_directory: spec.working_directory.clone().unwrap_or_default(),
working_directory: spec.working_directory.clone(),
max_instances: spec.max_instances.unwrap_or(DEFAULT_MAX_INSTANCES),
delay_release: spec
.delay_release
3 changes: 1 addition & 2 deletions common/src/lib.rs
@@ -215,8 +215,7 @@ pub fn default_applications() -> HashMap<String, ApplicationAttributes> {
"The Flame Runner application for executing customized Python applications."
.to_string(),
),
working_directory: "/tmp".to_string(),
command: Some("/usr/bin/uv".to_string()),
command: Some("/bin/uv".to_string()),
arguments: vec![
"run".to_string(),
"--with".to_string(),
2 changes: 1 addition & 1 deletion docker/Dockerfile.console
@@ -14,7 +14,7 @@ COPY --from=builder /usr/local/cargo/bin/flmping /usr/local/bin/flmping
COPY --from=builder /usr/local/cargo/bin/flmctl /usr/local/bin/flmctl
COPY --from=builder /usr/local/cargo/bin/flmexec /usr/local/bin/flmexec

COPY --from=ghcr.io/astral-sh/uv:0.9.18 /uv /uvx /bin/
COPY --from=ghcr.io/astral-sh/uv:0.9.26 /uv /uvx /bin/

RUN chmod +x /usr/local/bin/*

2 changes: 1 addition & 1 deletion docker/Dockerfile.fem
@@ -16,7 +16,7 @@ RUN mkdir -p /usr/local/flame/bin /usr/local/flame/work /usr/local/flame/sdk

WORKDIR /usr/local/flame/work

COPY --from=ghcr.io/astral-sh/uv:0.9.18 /uv /uvx /bin/
COPY --from=ghcr.io/astral-sh/uv:0.9.26 /uv /uvx /bin/

COPY --from=builder /usr/local/cargo/bin/flame-executor-manager /usr/local/flame/bin/flame-executor-manager
COPY --from=builder /usr/local/cargo/bin/flmping-service /usr/local/flame/bin/flmping-service
1 change: 1 addition & 0 deletions examples/ps/.python-version
@@ -0,0 +1 @@
3.12
Empty file added examples/ps/README.md
Empty file.
Binary file added examples/ps/dist/ps-example.tar.gz
26 changes: 26 additions & 0 deletions examples/ps/main.py
@@ -0,0 +1,26 @@
from ps import ConvNet, get_data_loader, ParameterServer, DataWorker, evaluate
from flamepy.rl import Runner


if __name__ == "__main__":
model = ConvNet()
test_loader = get_data_loader()[1]
print("Running synchronous parameter server training.")

with Runner("ps-example") as rr:
ps_svc = rr.service(ParameterServer(1e-2))
workers_svc = [rr.service(DataWorker) for _ in range(4)]

current_weights = ps_svc.get_weights().get()
for i in range(20):
gradients = [worker.compute_gradients(current_weights) for worker in workers_svc]
# Calculate update after all gradients are available.
current_weights = ps_svc.apply_gradients(*gradients).get()

if i % 10 == 0:
# Evaluate the current model.
model.set_weights(current_weights)
accuracy = evaluate(model, test_loader)
print("Iter {}: \taccuracy is {:.1f}".format(i, accuracy))

print("Final accuracy is {:.1f}.".format(accuracy))
Reviewer comment (severity: medium):

The accuracy variable is only defined inside the if i % 10 == 0: block within the loop. If the loop range were smaller (e.g., range(0)), this line would raise an UnboundLocalError. It would be safer to initialize accuracy to a default value (e.g., 0.0) before the loop.
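A minimal sketch of that suggestion against the loop in main.py above; the 0.0 default is illustrative and not part of this PR:

accuracy = 0.0  # hypothetical default so the final print cannot raise UnboundLocalError
for i in range(20):
    gradients = [worker.compute_gradients(current_weights) for worker in workers_svc]
    # Calculate update after all gradients are available.
    current_weights = ps_svc.apply_gradients(*gradients).get()

    if i % 10 == 0:
        # Evaluate the current model.
        model.set_weights(current_weights)
        accuracy = evaluate(model, test_loader)
        print("Iter {}: \taccuracy is {:.1f}".format(i, accuracy))

print("Final accuracy is {:.1f}.".format(accuracy))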

119 changes: 119 additions & 0 deletions examples/ps/ps.py
@@ -0,0 +1,119 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from filelock import FileLock
import numpy as np


def get_data_loader():
"""Safely downloads data. Returns training/validation set dataloader."""
mnist_transforms = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# We add FileLock here because multiple workers will want to
# download data, and this may cause overwrites since
# DataLoader is not threadsafe.
with FileLock(os.path.expanduser("~/data.lock")):
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"~/data", train=True, download=True, transform=mnist_transforms
),
batch_size=128,
shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST("~/data", train=False, transform=mnist_transforms),
batch_size=128,
shuffle=True,
)
Comment on lines +19 to +31
Reviewer comment (severity: medium):

The path ~/data is not guaranteed to be expanded by torchvision.datasets.MNIST. It's better to use os.path.expanduser explicitly to ensure the path is resolved correctly. Also, it's good practice to place the lock file within the directory it's protecting to keep things organized. I've refactored this part to define data_dir once and use it for both the dataset and the lock file.

Suggested change — replace:

    with FileLock(os.path.expanduser("~/data.lock")):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                "~/data", train=True, download=True, transform=mnist_transforms
            ),
            batch_size=128,
            shuffle=True,
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST("~/data", train=False, transform=mnist_transforms),
            batch_size=128,
            shuffle=True,
        )

with:

    data_dir = os.path.expanduser("~/data")
    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since
    # DataLoader is not threadsafe.
    with FileLock(os.path.join(data_dir, ".lock")):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                data_dir, train=True, download=True, transform=mnist_transforms
            ),
            batch_size=128,
            shuffle=True,
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(data_dir, train=False, transform=mnist_transforms),
            batch_size=128,
            shuffle=True,
        )

return train_loader, test_loader


def evaluate(model, test_loader):
"""Evaluates the accuracy of the model on a validation dataset."""
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(test_loader):
# This is only set to finish evaluation faster.
if batch_idx * len(data) > 1024:
break
outputs = model(data)
_, predicted = torch.max(outputs.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
return 100.0 * correct / total


class ConvNet(nn.Module):
"""Small ConvNet for MNIST."""

def __init__(self):
super(ConvNet, self).__init__()
self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
self.fc = nn.Linear(192, 10)

def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 3))
x = x.view(-1, 192)
x = self.fc(x)
return F.log_softmax(x, dim=1)

def get_weights(self):
return {k: v.cpu() for k, v in self.state_dict().items()}

def set_weights(self, weights):
self.load_state_dict(weights)

def get_gradients(self):
grads = []
for p in self.parameters():
grad = None if p.grad is None else p.grad.data.cpu().numpy()
grads.append(grad)
return grads

def set_gradients(self, gradients):
for g, p in zip(gradients, self.parameters()):
if g is not None:
p.grad = torch.from_numpy(g)


class ParameterServer(object):
def __init__(self, lr):
self.model = ConvNet()
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)

def apply_gradients(self, *gradients):
summed_gradients = [
np.stack(gradient_zip).sum(axis=0) for gradient_zip in zip(*gradients)
]
self.optimizer.zero_grad()
self.model.set_gradients(summed_gradients)
self.optimizer.step()
return self.model.get_weights()

def get_weights(self):
return self.model.get_weights()


class DataWorker(object):
def __init__(self):
self.model = ConvNet()
self.data_iterator = iter(get_data_loader()[0])

def compute_gradients(self, weights):
self.model.set_weights(weights)
try:
data, target = next(self.data_iterator)
except StopIteration: # When the epoch ends, start a new epoch.
self.data_iterator = iter(get_data_loader()[0])
data, target = next(self.data_iterator)
self.model.zero_grad()
output = self.model(data)
loss = F.nll_loss(output, target)
loss.backward()
return self.model.get_gradients()
22 changes: 22 additions & 0 deletions examples/ps/pyproject.toml
@@ -0,0 +1,22 @@
[project]
name = "ps-example"
version = "0.1.0"
description = "Parameter Server by flamepy.Runner"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"torch",
"torchvision",
"numpy",
"filelock"
]

[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
py-modules = ["ps", "main"]

[tool.uv.sources]
flamepy = { path = "/usr/local/flame/sdk/python" }
12 changes: 8 additions & 4 deletions executor_manager/src/shims/host_shim.rs
@@ -15,6 +15,7 @@ use std::env;
use std::fs::{self, create_dir_all, File, OpenOptions};
use std::future::Future;
use std::os::unix::process::CommandExt;
use std::path::Path;
use std::pin::Pin;
use std::process::{self, Command, Stdio};
use std::sync::Arc;
@@ -99,10 +100,13 @@ impl HostShim {
// Spawn child process
let mut cmd = tokio::process::Command::new(&command);

let cur_dir = app
.working_directory
.clone()
.unwrap_or(FLAME_WORKING_DIRECTORY.to_string());
// If application doesn't specify working_directory, use executor-specific directory
let cur_dir = app.working_directory.clone().unwrap_or_else(|| {
let executor_working_directory = env::current_dir()
.unwrap_or(Path::new(FLAME_WORKING_DIRECTORY).to_path_buf())
.join(executor.id.as_str());
executor_working_directory.to_string_lossy().to_string()
});

tracing::debug!("Current directory of application instance: {cur_dir}");

43 changes: 36 additions & 7 deletions sdk/python/src/flamepy/rl/runpy.py
@@ -15,6 +15,7 @@
import inspect
import logging
import os
import shutil
import site
import subprocess
import sys
@@ -120,7 +121,12 @@ def _extract_archive(self, archive_path: str, extract_to: str) -> str:
logger.info(f"Extracting archive: {archive_path} to {extract_to}")

try:
# Create extraction directory if it doesn't exist
# Remove old extracted directory if it exists to ensure clean extraction
if os.path.exists(extract_to):
logger.info(f"Removing existing extracted directory: {extract_to}")
shutil.rmtree(extract_to)

# Create extraction directory
os.makedirs(extract_to, exist_ok=True)

# Determine archive type and extract
@@ -193,27 +199,50 @@ def _install_package_from_url(self, url: str) -> None:
install_path = extracted_dir
logger.info(f"Will install from extracted directory: {install_path}")

# Debug: List contents of extracted directory
try:
contents = os.listdir(install_path)
logger.debug(f"Extracted directory contents: {contents}")

# Check for pyproject.toml or setup.py
if "pyproject.toml" in contents:
pyproject_path = os.path.join(install_path, "pyproject.toml")
with open(pyproject_path, "r") as f:
Reviewer comment (severity: medium):

It's a good practice to explicitly specify the file encoding when opening text files. This avoids potential UnicodeDecodeError if the file contains non-ASCII characters and the system's default encoding is not UTF-8. Please add encoding="utf-8".

Suggested change — replace:

    with open(pyproject_path, "r") as f:

with:

    with open(pyproject_path, "r", encoding="utf-8") as f:

pyproject_content = f.read()
logger.debug(f"pyproject.toml content:\n{pyproject_content}")
if "setup.py" in contents:
logger.debug("Found setup.py in extracted directory")
except Exception as e:
logger.warning(f"Failed to list extracted directory contents: {e}")

# Use sys.executable -m pip to install into the current virtual environment
# pip install will upgrade the package if it's already installed
logger.info(f"Installing package: {install_path}")
install_args = [sys.executable, "-m", "pip", "install", install_path]
logger.debug(f"Python executable: {sys.executable}")
logger.debug(f"Current working directory: {os.getcwd()}")
install_args = [sys.executable, "-m", "pip", "install", "--upgrade", install_path]
logger.debug(f"Install command: {' '.join(install_args)}")

try:
result = subprocess.run(install_args, capture_output=True, text=True, check=True)
logger.info(f"Package installation output: {result.stdout}")
logger.info("Package installation succeeded")
logger.debug(f"Package installation stdout:\n{result.stdout}")
if result.stderr:
logger.warning(f"Package installation stderr: {result.stderr}")
logger.debug(f"Package installation stderr:\n{result.stderr}")
logger.info(f"Successfully installed package from: {install_path}")

# Reload site packages to make the newly installed package available
# This is necessary because the Python interpreter has already started
logger.info("Reloading site packages to pick up newly installed package")
importlib.reload(site)
logger.info(f"Updated sys.path: {sys.path}")
logger.debug(f"Updated sys.path: {sys.path}")

except subprocess.CalledProcessError as e:
logger.error(f"Failed to install package: {e}")
logger.error(f"stdout: {e.stdout}")
logger.error(f"stderr: {e.stderr}")
logger.error(f"Return code: {e.returncode}")
logger.error(f"Install command was: {' '.join(install_args)}")
logger.error(f"Package installation stdout:\n{e.stdout}")
logger.error(f"Package installation stderr:\n{e.stderr}")
raise RuntimeError(f"Package installation failed: {e}")
finally:
# Clean up extracted directory if it was created
2 changes: 1 addition & 1 deletion session_manager/src/scheduler/mod.rs
@@ -84,7 +84,7 @@ mod tests {
description: None,
labels: Vec::new(),
arguments: Vec::new(),
working_directory: "/tmp".to_string(),
working_directory: Some("/tmp".to_string()),
environments: HashMap::new(),
shim: Shim::Host,
max_instances: 10,