diff --git a/examples/asha_example/experiment_1.yaml b/examples/asha_example/experiment_1.yaml new file mode 100644 index 0000000..7c313a2 --- /dev/null +++ b/examples/asha_example/experiment_1.yaml @@ -0,0 +1,49 @@ +seml: + executable: main.py + name: asha_import_test + output_dir: logs + project_root_dir: . + +slurm: + experiments_per_job: 1 + max_simultaneous_jobs: 10 + sbatch_options: + gres: gpu:0 + mem: 2G + cpus-per-task: 1 + time: 0-01:00 + + +fixed: + dataset: 'mnist' + base_shared_dir: "./shared/experiments" + asha_collection_name: asha_import_test + num_stages: 10 + asha: + eta: 3 + min_r: 1 + max_r: 20 + metric_increases: True #True or False + +random: + samples: 3 # Run 3 random configurations + seed: 42 + + hidden_units: + type: choice + options: + - [64] + - [128, 64] + - [256, 128, 64] + + dropout: + type: uniform + min: 0.0 + max: 0.5 + + learning_rate: + type: loguniform + min: 1e-4 + max: 1e-2 + + diff --git a/examples/asha_example/main.py b/examples/asha_example/main.py new file mode 100644 index 0000000..6dd9f88 --- /dev/null +++ b/examples/asha_example/main.py @@ -0,0 +1,156 @@ +import uuid + +import torch +from model import SimpleNN +from seml import Experiment +from seml.database import get_mongodb_config +from seml.utils import ASHA # Import asha class +from torch import nn, optim +from torch.utils.data import DataLoader, random_split +from torchvision import datasets, transforms + +# def seed_everything(job_id): +# # Combine job_id and entropy to make a unique seed +# entropy = f"{job_id}-{time.time()}-{os.urandom(8)}".encode("utf-8") +# seed = int(hashlib.sha256(entropy).hexdigest(), 16) % (2**32) +# print(f"[Seed] Using seed: {seed}") +# random.seed(seed) +# np.random.seed(seed) +# torch.manual_seed(seed) +# if torch.cuda.is_available(): +# torch.cuda.manual_seed_all(seed) +# return seed + +experiment = Experiment() + + +@experiment.config +def default_config(): + num_stages = 10 + dataset = "mnist" + hidden_units = [64] + dropout = 0.3 
+ learning_rate = 1e-3 + base_shared_dir = "./shared/experiments" # Parent directory shared across jobs + job_id = None # Must be unique per job + seed = 42 + asha_collection_name = "unknown_experiment" + samples = 5 + asha = {"eta": 3, "min_r": 1, "max_r": 20, "metric_increases": True} # NOTE: keys must match experiment_1.yaml and the asha[...] lookups in main() + + +@experiment.automain +def main( + num_stages, + dataset, + hidden_units, + dropout, + learning_rate, + base_shared_dir, + job_id, + asha, + _log, + _run, +): + mongodb_configurations = get_mongodb_config() + + print( + f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}" + ) + if job_id is None: + job_id = str(uuid.uuid4()) + # job_id = str(_run._id) + + asha_collection_name = _run.config.get("asha_collection_name", "unknown_experiment") + print("Run info:", _run.experiment_info) + + # Create model + model = SimpleNN(hidden_units=hidden_units, dropout=dropout) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model.to(device) + + # Prepare dataset and loaders + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) + full_dataset = datasets.MNIST( + root="./data", train=True, download=True, transform=transform + ) + train_size = int(0.8 * len(full_dataset)) + val_size = len(full_dataset) - train_size + train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size]) + + train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False) + + # Loss and optimizer + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(model.parameters(), lr=learning_rate) + + # ASHA setup + eta = asha["eta"] + min_r = asha["min_r"] + max_r = asha["max_r"] + metric_increases = asha["metric_increases"] + tracker = ASHA( + asha_collection_name=asha_collection_name, + eta=eta, + min_r=min_r, + max_r=max_r, + metric_increases=metric_increases, + 
mongodb_configurations=mongodb_configurations, + _log=_log, + ) + for stage in range(num_stages): + # Training + model.train() + for inputs, targets in train_loader: + inputs, targets = inputs.to(device), targets.to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = criterion(outputs, targets) + loss.backward() + optimizer.step() + + # Validation + model.eval() + correct = 0 + total = 0 + with torch.no_grad(): + for inputs, targets in val_loader: + inputs, targets = inputs.to(device), targets.to(device) + outputs = model(inputs) + _, predicted = torch.max(outputs, 1) + correct += (predicted == targets).sum().item() + total += targets.size(0) + + metric = correct / total + print(f"[Epoch {stage}] Validation Accuracy: {metric:.4f}") + + if stage < (num_stages - 1): + should_stop = tracker.store_stage_metric(stage, metric) + if should_stop: + print("We should end this process here") + print( + f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}" + ) + break + else: + print("job finished") + print( + f"job parameters, hiddenunits:{hidden_units}, dropout:{dropout}, learningrate:{learning_rate}" + ) + + return { + "asha_collection_name": asha_collection_name, + "job_id": job_id, + "metric_history": tracker.metric_history, + "final_metric": tracker.metric_history[-1], + "hidden_units": hidden_units, + "dropout": dropout, + "learning_rate": learning_rate, + "num_stages": num_stages, + "dataset": dataset, + "device": str(device), + "final_stage": len(tracker.metric_history), + } diff --git a/examples/asha_example/model.py b/examples/asha_example/model.py new file mode 100644 index 0000000..eafdfac --- /dev/null +++ b/examples/asha_example/model.py @@ -0,0 +1,22 @@ +import torch +import torch.nn as nn + + +# Define your custom PyTorch model +class SimpleNN(nn.Module): + def __init__(self, hidden_units=[128, 64], dropout=0.0): + super().__init__() + layers = [] + input_size = 28 * 28 + for h in hidden_units: + 
layers.append(nn.Linear(input_size, h)) + layers.append(nn.ReLU()) + if dropout > 0: + layers.append(nn.Dropout(dropout)) + input_size = h + layers.append(nn.Linear(input_size, 10)) # output layer + self.net = nn.Sequential(*layers) + + def forward(self, x): + x = torch.flatten(x, 1) + return self.net(x) diff --git a/src/seml/utils/asha.py b/src/seml/utils/asha.py new file mode 100644 index 0000000..a132d3a --- /dev/null +++ b/src/seml/utils/asha.py @@ -0,0 +1,270 @@ +import math +import uuid +from logging import Logger +from typing import Any + +from pymongo import MongoClient +from pymongo.collection import Collection + + +class ASHA: + def __init__( + self, + asha_collection_name: str, + eta: int | float, + min_r: int, + max_r: int, + metric_increases: bool, + mongodb_configurations: dict[str, Any], + _log: Logger, + ) -> None: + """Asynchronous Successive Halving (ASHA) tracker backed by a shared MongoDB collection. + + Args: + asha_collection_name (str): Name of the MongoDB collection in which all jobs of this search report their per-stage metrics. + eta (int | float): Reduction factor; only roughly the best 1/eta of jobs at a rung are promoted. + min_r (int): Resource (stage index) of the first rung. + max_r (int): Maximum resource; rungs are generated up to this value. + metric_increases (bool): True if larger metric values are better (e.g. accuracy), False if smaller is better (e.g. loss). + mongodb_configurations (dict[str, Any]): SEML MongoDB settings (host, port, db_name, optional username/password/authSource); _log is the Logger used for all output. + """ + # ! TODO: Adding argument verification + self.asha_collection_name = asha_collection_name + self._log = _log + self.job_uuid = str(uuid.uuid4()) + # ! TODO: No printing, use _log instead, also please make sure to properly distinguish between info/warning/debug, etc. + self._log.info( + f'--------------------------JobUUID:{self.job_uuid}---------------------' + ) + self.metric_history = [] + self.others_metric_at_stage = {} + self.eta = eta + self.min_r = min_r + self.max_r = max_r + self.metric_increases = metric_increases + self.rungs = self.generate_rungs(self.min_r, self.eta, self.max_r) + self.mongodb_configurations = mongodb_configurations + self.samples = 5 # <- Not sure what this does? 
But it appears like it was hardcoded to 5 in the original main.py, comment: this was to make the isbest function that doesn't work yet + self.collection = self._get_mongo_collection( + self.mongodb_configurations, self.asha_collection_name + ) + + def _get_mongo_collection( + self, mongodb_configurations: dict[str, Any], experiment_name: str + ) -> Collection: + """ + Connect to MongoDB using credentials from the SEML config + and return the requested collection. + """ + # ? TODO: Adding some retry logic if we experience transient connection issues? + auth_source = mongodb_configurations.get( + 'authSource', mongodb_configurations['db_name'] + ) + if mongodb_configurations.get('username') and mongodb_configurations.get( + 'password' + ): + uri = f'mongodb://{mongodb_configurations["username"]}:{mongodb_configurations["password"]}@{mongodb_configurations["host"]}:{mongodb_configurations["port"]}/?authSource={auth_source}' + else: + uri = f'mongodb://{mongodb_configurations["host"]}:{mongodb_configurations["port"]}' + + client = MongoClient(uri, serverSelectionTimeoutMS=5000) # 5 sec timeout + client.admin.command('ping') # check connection + + db = client[mongodb_configurations['db_name']] + collection = db[experiment_name] + self._log.debug( + f"Connected to MongoDB and accessed collection '{experiment_name}' successfully." + ) + return collection + + def save_metric_to_db( + self, collection: Collection, job_id: str, stage: int, metric: float + ) -> None: + """ + Insert or update metric for the given job_id and stage in the MongoDB collection. + """ + collection.update_one( + {'job_id': job_id, 'stage': stage}, + {'$set': {'metric': metric}}, + upsert=True, + ) + + def store_stage_metric(self, stage: int, metric: float) -> bool: + """ + Accuracy added and other metrics compared, + probably should break this into different functions, + as of now this is our running function + """ + other_job_metrics = {} + + # ? 
TODO: Should we really submit and pull from the database on every stage, even if it isn't a decision rung yet? + self._log.debug('Trying MongoDB access...') + # ? TODO: Keep the connection open instead of recreating it each stage? + self.collection = self._get_mongo_collection( + self.mongodb_configurations, self.asha_collection_name + ) + self.save_metric_to_db(self.collection, self.job_uuid, stage, metric) + self._log.debug('storage in mongodb') + other_job_metrics = self.get_metric_at_stage_db( + self.collection, stage, self.job_uuid + ) + + self.metric_history.append(metric) + self.others_metric_at_stage[str(stage)] = other_job_metrics + + promote = True + should_terminate = False + + # ! TODO: Asha should not be required to know about the number of stages + # if stage == self.num_stages - 1: + # self.set_status_db("Completed") + # # self.isbest() + # elif stage in self.rungs: + if stage in self.rungs: + self._log.info(f'checking stage {stage}') + self._print_stage_info(stage, metric, other_job_metrics) + promote = self._job_promotion(metric, other_job_metrics, self.eta) + if promote: + self._log.info(f'this job was promoted at {stage}') + pass + else: + self._log.info(f'this job should be terminated at {stage}') + should_terminate = True + self.set_status_db('Completed') + return should_terminate + + def metric_in_rungs(self, stage: int) -> bool: + """ + Return True if the given stage/resource is one of the promotion rungs. + """ + if stage in self.rungs: + return True + else: + return False + + def get_metric_at_stage_db( + self, collection: Collection, stage: int, current_job_id: str | None = None + ) -> dict[str, float]: + """ + Retrieve metrics of all jobs at the specified stage from the MongoDB collection. + Returns a dict: {job_id: metric}, excluding current_job_id if provided. 
+ """ + results = collection.find({'stage': stage}) + metrics = {} + for doc in results: + job_id = doc.get('job_id') + metric = doc.get('metric', -1.0) + if job_id and job_id != current_job_id: + metrics[job_id] = metric + return metrics + + def _print_stage_info( + self, stage: int, metric: float, other_job_metrics: dict[str, float] + ) -> None: + self._log.info(f'[Epoch {stage}] Own metric: {metric}') + self._log.info(f"[Epoch {stage}] Other jobs' metrics: {other_job_metrics}") + pass + + def _job_promotion( + self, metric: float, other_job_metrics: dict[str, float], eta: int | float + ) -> bool: + """ + Return True if this job should be promoted at the current rung, i.e. its metric is within the best 1/eta of all reported metrics (ties decided by math.isclose); False otherwise. + """ + self._log.info('Checking if this job progresses') + + valid_metrics = [acc for acc in other_job_metrics.values() if acc > -1] + [ + metric + ] + sorted_vals = sorted(valid_metrics, reverse=True) + cutoff_metric = 0.0 + promotion = 'True' # dead placeholder (string); overwritten with a bool in both branches below + + if self.metric_increases: + # ? TODO: Should this be ceil, round or floor + top_k = max(1, math.floor(len(sorted_vals) // eta)) + cutoff_metric = sorted_vals[top_k - 1] + promotion = metric >= cutoff_metric or math.isclose( + metric, cutoff_metric, rel_tol=1e-9 + ) + else: + sorted_vals = sorted(valid_metrics) # ascending: lowest first + # ? 
TODO: Should this be ceil, round or floor + bottom_k = max(1, math.floor(len(sorted_vals) // eta)) + cutoff_metric = sorted_vals[bottom_k - 1] # kth lowest value + promotion = metric <= cutoff_metric or math.isclose( + metric, cutoff_metric, rel_tol=1e-9 + ) + + self._log.info(f'Valid metrics (sorted): {sorted_vals}') + self._log.info(f'Cutoff metric for promotion: {cutoff_metric:.8f}') + self._log.info(f'Current job metric: {metric:.8f}') + self._log.info(f'Promotion decision: {promotion}') + self._log.info('--------------------------------------------------') + + return promotion + + def generate_rungs(self, min_r: int, eta: int | float, max_r: int) -> list[int]: + """ + generates rungs at which promotion will be checked + """ + rungs = [] + resource = min_r + if min_r > max_r: + raise ValueError('min_r must be <= max_r') + + while resource <= max_r: + # Rounding allows for eta to be a floating point value + resource = int(round(resource)) + rungs.append(resource) + resource *= eta + + self._log.info(f'Generated rungs of the following shape: {rungs}') + per_sample_avg_stages = sum( + [stages / (eta**i) for (i, stages) in enumerate(rungs)] + + [max_r / (eta ** len(rungs))] + ) + self._log.info( + f'Given this ASHA configuration, the expected average number of stages per sample is: {per_sample_avg_stages:.2f}' + ) + return rungs + + def set_status_db(self, status: str) -> None: + """ + set status in mongodb collection to mark if process is still running + """ + self.collection.update_one( + {'job_id': self.job_uuid}, {'$set': {'Status': status}}, upsert=True + ) + + # def isbest(self,metric,other_job_metrics): + def isbest(self) -> None: + """ + NOTE(review): this function is known to be incorrect / work in progress; + intended to use the Status field to check that all jobs completed and then mark the best one + """ + + completed_jobs = list(self.collection.find({'Status': 'Completed'})) + if len(completed_jobs) != self.samples: + return + + best_job = max(completed_jobs, key=lambda doc: doc.get('metric', -1.0)) + best_job_id = 
best_job.get('job_id') + + # Mark all completed jobs as not best + self.collection.update_many({'Status': 'Completed'}, {'$set': {'BEST': 'NO'}}) + + # Mark the best job + self.collection.update_one( + {'job_id': best_job_id}, {'$set': {'BEST': 'YES'}}, upsert=True + ) + + # if best_job_id == self.job_uuid: + # self._run.log_scalar("Finished Last but was the best") + + # valid_metrics = [acc for acc in other_job_metrics.values() if acc > -1] + [metric] + # if len(other_job_metrics)+1 == len(valid_metrics): + # if metric == max(valid_metrics): + # return True + # else: + # return False