-
Notifications
You must be signed in to change notification settings - Fork 5
enh(Agent): Allow for periodic historical tracking of X state #231
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
b1c6d4c
enh(Agent): Periodic historical x tracking
Simpag be73521
docs(Advanded): Add advanced developer guide
Simpag 2261f24
test(Agent): Test inplace operators and some user docs update
Simpag bb25c0c
ref(Docs): Move advanced dev guide to new PR
Simpag e758a77
fix(PR): Fix PR comments
Simpag 0e49f32
fix(Agent): PR issues
Simpag 8e27b5a
fix(Agent): Fix PR comments
Simpag 3c202e2
fix(Agent): Change history_period to state_snapshot_period
Simpag d9f5372
fix(Test/Agent): Fix initial argument in test
Simpag 92fae58
fix(Plots): Fix plotting
Simpag 14f41e0
ref(Plot): Reorder functions
Simpag c681d28
Merge branch 'team-decent:main' into x-tracking
Simpag fe4b547
fix(Networks): Add snapshot period to FED network creation
Simpag 620b038
ref(Plot): Update arg doc
Simpag a28e95c
fix(Plot): Allow matplotlib to handle layout
Simpag 55d1a65
docs(User): Update example plot
Simpag File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,13 +6,32 @@ | |
| from numpy import linalg as la | ||
| from numpy.linalg import LinAlgError | ||
| from numpy.typing import NDArray | ||
| from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn | ||
|
|
||
| import decent_bench.utils.interoperability as iop | ||
| from decent_bench.agents import AgentMetricsView | ||
| from decent_bench.benchmark_problem import BenchmarkProblem | ||
| from decent_bench.utils.array import Array | ||
|
|
||
|
|
||
| class MetricProgressBar(Progress): | ||
| """ | ||
| Progress bar for metric calculations. | ||
|
|
||
| Make sure to set the field *status* in the task to show custom status messages. | ||
|
|
||
| """ | ||
|
|
||
| def __init__(self) -> None: | ||
| super().__init__( | ||
| TextColumn("[progress.description]{task.description}"), | ||
| BarColumn(), | ||
| TaskProgressColumn(), | ||
| TimeRemainingColumn(elapsed_when_finished=True), | ||
| TextColumn("{task.fields[status]}"), | ||
| ) | ||
|
|
||
|
|
||
| def single(values: Sequence[float]) -> float: | ||
| """ | ||
| Assert that *values* contain exactly one element and return it. | ||
|
|
@@ -37,7 +56,11 @@ def x_mean(agents: tuple[AgentMetricsView, ...], iteration: int = -1) -> Array: | |
| ValueError: if no agent reached *iteration* | ||
|
|
||
| """ | ||
| all_x_at_iter = [a.x_history[iteration] for a in agents if len(a.x_history) > iteration] | ||
| if iteration == -1: | ||
| all_x_at_iter = [a.x_history[max(a.x_history)] for a in agents if len(a.x_history) > 0] | ||
| else: | ||
| all_x_at_iter = [a.x_history[iteration] for a in agents if iteration in a.x_history] | ||
|
|
||
| if len(all_x_at_iter) == 0: | ||
| raise ValueError(f"No agent reached iteration {iteration}") | ||
|
|
||
|
|
@@ -56,7 +79,7 @@ def regret(agents: list[AgentMetricsView], problem: BenchmarkProblem, iteration: | |
| mean_x = x_mean(tuple(agents), iteration) | ||
| optimal_cost = sum(a.cost.function(x_opt) for a in agents) | ||
| actual_cost = sum(a.cost.function(mean_x) for a in agents) | ||
| return abs(optimal_cost - actual_cost) | ||
| return actual_cost - optimal_cost | ||
|
|
||
|
|
||
| def gradient_norm(agents: list[AgentMetricsView], iteration: int = -1) -> float: | ||
|
|
@@ -83,7 +106,7 @@ def x_error(agent: AgentMetricsView, problem: BenchmarkProblem) -> NDArray[float | |
| where :math:`\mathbf{x}_k` is the agent's local x at iteration k, | ||
| and :math:`\mathbf{x}^\star` is the optimal x defined in the *problem*. | ||
| """ | ||
| x_per_iteration = np.asarray([iop.to_numpy(x) for x in agent.x_history]) | ||
| x_per_iteration = np.asarray([iop.to_numpy(x) for _, x in sorted(agent.x_history.items())]) | ||
| opt_x = problem.x_optimal | ||
| errors: NDArray[float64] = la.norm(x_per_iteration - opt_x, axis=tuple(range(1, x_per_iteration.ndim))) | ||
| return errors | ||
|
|
@@ -131,3 +154,22 @@ def iterative_convergence_rate_and_order(agent: AgentMetricsView, problem: Bench | |
| except LinAlgError: | ||
| rate, order = np.nan, np.nan | ||
| return rate, order | ||
|
|
||
|
|
||
| def common_sorted_iterations(agents: Sequence[AgentMetricsView]) -> list[int]: | ||
| """ | ||
| Get a sorted list of all common iterations reached by agents in *agents*. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would maybe add a bit more context, like: "since the agents can sample periodically, and potentially at different times, this function can be used to find the numbers of iterations where all agents have recorded their states, which can then be used to compute the metrics" |
||
|
|
||
| Since the agents can sample their states periodically, and may sample at different iterations, | ||
| this function returns only the iterations that are common to all agents. These iterations can then be used | ||
| to compute metrics that require synchronized iterations. | ||
|
|
||
| Args: | ||
| agents: sequence of agents to get the common iterations from | ||
|
|
||
| Returns: | ||
| sorted list of iterations reached by all agents | ||
|
|
||
| """ | ||
| common_iters = set.intersection(*(set(a.x_history.keys()) for a in agents)) if agents else set() | ||
| return sorted(common_iters) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.