feat(Empirical): Empirical risk cost #240
nicola-bastianello merged 10 commits into team-decent:main
Conversation
Pull request overview
This PR refactors the monolithic costs.py file into a modular structure and introduces empirical risk cost functions with support for stochastic gradients and batch sampling. The key changes enable mini-batch training by allowing indices to be specified in gradient/function evaluations.
Changes:
- Introduced `EmpiricalRiskCost` base class with batch sampling capabilities for stochastic gradient methods
- Refactored cost functions into a modular structure with `_base/` and `_empirical_risk/` subdirectories
- Converted `LinearRegressionCost` and `LogisticRegressionCost` to inherit from `EmpiricalRiskCost` with an optional batch size parameter
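To make the mini-batch idea concrete, here is a minimal numpy sketch of evaluating a least-squares gradient on a subset of sample indices. This is illustrative only, not the decent_bench API; the `gradient` function and its `indices` parameter are assumptions for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((100, 5))   # features, one row per sample
b = rng.standard_normal(100)        # targets

def gradient(x, indices=None):
    """Least-squares gradient over the given sample indices (all samples if None)."""
    idx = np.arange(A.shape[0]) if indices is None else indices
    Ai, bi = A[idx], b[idx]
    return Ai.T @ (Ai @ x - bi) / len(idx)

x = np.zeros(5)
full_grad = gradient(x)                                            # full-batch gradient
batch_grad = gradient(x, indices=rng.choice(100, size=10, replace=False))  # mini-batch
```

Passing `indices=None` recovers the previous full-batch behavior, which mirrors the `batch_size=None` default discussed below.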
Reviewed changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated 11 comments.
| File | Description |
|---|---|
| decent_bench/costs/_base/_cost.py | Base Cost class extracted from monolithic costs.py |
| decent_bench/costs/_base/_sum_cost.py | SumCost class for composing multiple costs |
| decent_bench/costs/_base/_quadratic_cost.py | QuadraticCost class extracted and refactored |
| decent_bench/costs/_empirical_risk/_empirical_risk_cost.py | New base class for empirical risk with batch sampling support |
| decent_bench/costs/_empirical_risk/_linear_regression_cost.py | Refactored LinearRegressionCost with batch support |
| decent_bench/costs/_empirical_risk/_logistic_regression_cost.py | Refactored LogisticRegressionCost with batch support |
| decent_bench/costs/_empirical_risk/__init__.py | Module exports for empirical risk costs |
| decent_bench/costs/_base/__init__.py | Module exports for base costs |
| decent_bench/costs/__init__.py | Updated top-level cost module exports |
| decent_bench/costs.py | Removed monolithic file |
| decent_bench/agents.py | Updated call counting methods to forward *args/**kwargs |
| decent_bench/benchmark_problem.py | Updated imports for new module structure |
| decent_bench/distributed_algorithms.py | Minor formatting adjustment in docstring |
| docs/source/api/decent_bench.costs.rst | Added private members to documentation |
decent_bench/costs/_empirical_risk/_logistic_regression_cost.py
| """ | ||
| Proximal at x solved using an iterative method. | ||
|
|
||
| Note: |
I'm now debating whether proximal should have support for indices at all, not even in linear regression (which is essentially the only empirical risk cost that has a closed-form proximal). The idea could be: 1) in linear regression we always return the full proximal, ignoring batch size; 2) in any other empirical risk cost, we put a note saying that if there is no closed-form solution, the proximal will be approximated using stochastic gradients. What do you think?
I've been thinking about the proximal for a while. I was wondering if there's any point in implementing an approximation for the proximal at all? If there is no closed-form solution to the proximal, then the approximation strategy for solving it will affect the performance of an algorithm using the proximal, which makes it somewhat ambiguous to assess performance. I personally think it should be up to the algorithm to implement this approximation, so that the performance reflects how well the algorithm performs, not our approximation of the proximal.
Mainly, the approximation of the proximal was implemented to run distributed ADMM; while it is basically the only algorithm using the proximal, it is an important one, so I do like to have a default implementation. But I do agree that this can lead to "weird" results, where ADMM converges inexactly when it is expected to converge exactly if the proximal is exact. This could create confusion.
So my proposal is the following: we make proximal no longer an abstractmethod, and instead in the base cost definition we raise NotImplementedError(), maybe with the error message "see centralized_algorithms.proximal_solver for an implementation of the approximate proximal computation". We could also choose to move proximal_solver to a "utils" subfolder of costs. Then, we implement proximal in the costs that have a closed-form solution (quadratic and linear regression), so that we still have examples of this in the docs. Finally, for linear regression I would go for the computation that uses the full dataset, adding a note that this is the case.
What do you think? It requires a bit more changes, but maybe it's neater this way.
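A minimal sketch of this proposal, under the names used in this thread (`Cost`, `LinearRegressionCost`, `centralized_algorithms.proximal_solver`); the bodies are assumptions, not the PR's code. The closed-form proximal for f(w) = ½‖Aw − b‖² follows from setting the gradient of f(w) + (ρ/2)‖w − x‖² to zero:

```python
import numpy as np

class Cost:
    """Sketch of the proposed base class: proximal is no longer abstract."""

    def proximal(self, x, rho):
        raise NotImplementedError(
            "see centralized_algorithms.proximal_solver for an implementation "
            "of the approximate proximal computation"
        )

class LinearRegressionCost(Cost):
    """f(w) = 0.5 * ||A w - b||^2 with a closed-form proximal over the full dataset."""

    def __init__(self, A, b):
        self.A, self.b = A, b

    def proximal(self, x, rho):
        # argmin_w f(w) + (rho/2)||w - x||^2  =>  (A^T A + rho I) w = A^T b + rho x
        n = self.A.shape[1]
        return np.linalg.solve(self.A.T @ self.A + rho * np.eye(n),
                               self.A.T @ self.b + rho * x)
```

Costs without a closed form inherit the `NotImplementedError`, which points users at the approximate solver instead of silently approximating.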
We can think about moving the proximal_solver at a later date, as cost functions might represent their data in different ways. I can do that when I change how the dataset is represented and make the proximal solver assume the representation of Dataset. I do agree that we should raise NotImplementedError. For the case of linear regression: because the proximal solver creates a SumCost, we cannot pass self since it might be batched; what I'll do is create a deepcopy of the cost and just set batch_size = n_samples.
I also assume that when you wrote "Finally, for linear regression I would go for the computation that uses the full dataset, adding a note that this is the case" you mean logistic regression
Sorry I didn't explain too well; to clarify:
- linear regression has a closed form proximal, so for that we can just use it, using the full dataset
- logistic regression doesn't have the closed form solution, so for that we can use the proximal solver (like you say with a deepcopy that has batch_size = n_samples)
So you want to remove the indices option from all proximal computations and always use the full dataset or just in logistic regression case?
Yes, I think it's best to remove indices from all proximal computations; there are stochastic versions of proximal evaluations, but they are not very common and I think it might just be confusing. So for logistic regression I would show the use of proximal_solver, making a note in the docstring that it will use stochastic gradients, resulting in inaccurate proximal computations. And we could show how to implement a version that uses the full dataset, maybe with the deepcopy you mentioned. It's just a way to have an example of using proximal_solver in the docs; I expect it will be rarely (or never) used.
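For costs without a closed form, the full-dataset approximation discussed here could look roughly like the following. `proximal_solver` is a hypothetical stand-in (plain gradient descent on the proximal objective), not the library's actual implementation, and the `gradient`/`batch_size`/`n_samples` interface is assumed:

```python
import copy
import numpy as np

def proximal_solver(cost, x, rho, steps=500, lr=0.05):
    """Approximate prox: minimize f(w) + (rho/2)||w - x||^2 by gradient descent.

    Sketch only: assumes `cost` exposes a `gradient(w)` method plus
    `batch_size`/`n_samples` attributes; the real solver may differ.
    """
    full = copy.deepcopy(cost)            # deepcopy so the caller's cost stays batched
    full.batch_size = full.n_samples      # force full-dataset gradients
    w = x.copy()
    for _ in range(steps):
        w -= lr * (full.gradient(w) + rho * (w - x))
    return w
```

For a cost with a known proximal, the iterate converges to the exact answer, which is a handy sanity check for the docs example.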
in addition to my comments, I think the docstring of EmpiricalRisk should clarify very well how to use
Questions:
You understood the behavior correctly, and I agree that this could be better documented. Currently this is only documented in the base class properties.

Regarding your questions: we can add the option for indices to be a string literal, but this will complicate things a bit. I have a solution to this which might make the base class a bit messier but will make it easier for users to create new costs.

I can't really see a reason why someone would want to change the batch size. Usually batch size is decided based on available resources, so having that scheduled in some way seems unlikely. But we can always add it if you think it would be of use.

The reason why I set batch_size to allow for None in init is because if we don't allow it to be None, we can't have batch_size as a default argument, since we cannot know the size of the dataset. We can of course set batch_size = n_samples during runtime in the init. I had that originally, but I changed it to None to make the
great, thank you
ok, then let's see the solution you propose; if we think it's too messy we can also avoid having the string literal option
Good point; then I agree it's not needed. In any case, if an algorithm needs to change the batch size, it can do so by passing the
ok, then if the performance improvement is small, I would prefer to set batch_size = n_samples at runtime. Feels like the solution that will be easier to work with in the future (and if a user accesses the property, I think it's more expected to see an int instead of None).
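The runtime resolution being agreed on here could be sketched as follows (constructor signature and attribute names are assumptions, not the merged code):

```python
import numpy as np

class EmpiricalRiskCost:
    """Sketch: batch_size defaults to None in the signature but is stored as an int."""

    def __init__(self, A, b, batch_size=None):
        self.A, self.b = A, b
        self.n_samples = A.shape[0]
        # None is only needed at definition time, because the dataset size is
        # unknown there; once we have the data, resolve it to a concrete int.
        self.batch_size = self.n_samples if batch_size is None else batch_size
```

With this, a user inspecting the property always sees an int, matching the expectation above.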
Yes, the indices argument does not care about batch_size
I agree with you, I have made some changes that I will push shortly. I think with the new way indices work we won't need to document the behavior too extensively.
Sphinx seems to fail because it randomly decided to use version

Edit: It seems like the theme is the issue, forcing an install of an older version; their build is failing (updated 3h ago, so that lines up): https://github.com/pydata/pydata-sphinx-theme?tab=readme-ov-file
It seems like we are always building the theme from their latest main branch and not their stable releases. Let me update the config file to use the stable releases and hope that the checks use the one in the PR.
could you try fixing this? I guess only ci.yaml needs to change. we can also do it later |
decent_bench/agents.py

```diff
-    def _call_counting_function(self, x: Array) -> float:
+    def _call_counting_function(self, x: Array, *args: Any, **kwargs: Any) -> float:  # noqa: ANN401
         self._n_function_calls += 1
-        return self._cost.__class__.function(self.cost, x)
+        return self._cost.__class__.function(self.cost, x, *args, **kwargs)

-    def _call_counting_gradient(self, x: Array) -> Array:
+    def _call_counting_gradient(self, x: Array, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
         self._n_gradient_calls += 1
-        return self._cost.__class__.gradient(self.cost, x)
+        return self._cost.__class__.gradient(self.cost, x, *args, **kwargs)

-    def _call_counting_hessian(self, x: Array) -> Array:
+    def _call_counting_hessian(self, x: Array, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
         self._n_hessian_calls += 1
-        return self._cost.__class__.hessian(self.cost, x)
+        return self._cost.__class__.hessian(self.cost, x, *args, **kwargs)

-    def _call_counting_proximal(self, x: Array, rho: float) -> Array:
+    def _call_counting_proximal(self, x: Array, rho: float, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
         self._n_proximal_calls += 1
-        return self._cost.__class__.proximal(self.cost, x, rho)
+        return self._cost.__class__.proximal(self.cost, x, rho, *args, **kwargs)
```
Consider `def _call_counting_x[**P](self, x: Array, *args: P.args, **kwargs: P.kwargs)` instead of `Any` and `noqa`.
I agree, that's better. However, mypy does complain that it is unbound, and afaik there is no bound for ParamSpec?
I have not found a solution to this that makes mypy happy, maybe you can find one since you have more experience with mypy?
otherwise it's ok to use noqa
```diff
@@ -1 +1 @@
-git+https://github.com/pydata/pydata-sphinx-theme.git@main
+pydata-sphinx-theme
```
We'll lose search-as-you-type in the docs' search functionality with this change. Is it really needed?
Always building from their git source breaks sometimes. Yesterday they pushed an update which forced Sphinx to use an older version, which broke the docs and the GitHub checks. Switching also solved some weird interactions causing some elements to appear and disappear in the local preview of the docs, which is really annoying when trying to read the docs as the entire page moves around. I think it is worth sacrificing search-as-you-type for better stability and not having to worry about our theme randomly breaking the docs.
I see, let's use the latest working commit: git+https://github.com/pydata/pydata-sphinx-theme.git@236d4af4b40edd607021e17342f4bf5870d64f0d
I found that the bugginess of the docs stopped once the switch to the pip package was made. I personally don't think the search-as-you-type feature is that important and I'd rather have a more stable experience, but let's see what @nicola-bastianello has to say.
It might be some issue on my end if neither of you have experienced it
I think I prefer to prioritize stability in the docs as much as possible, so I think the current solution is good
@Simpag do you remember when it started to bug for you? If yes, just pick a commit from slightly before then and do git+https://github.com/pydata/pydata-sphinx-theme.git@236d4af4b40edd607021e17342f4bf5870d64f0d instead of git+https://github.com/pydata/pydata-sphinx-theme.git@main. Here is the commit list: https://github.com/pydata/pydata-sphinx-theme/commits/main/
This way, we get stability without sacrificing search-as-you-type which imho is super nice when searching the docs.
If I'm not mistaken, it's been buggy for me the entire time. Have either of you experienced this? I am running Ubuntu in WSL, so that might be why, if I'm the only one experiencing this.
I'm not sure I noticed the same bugs, but if they show up in even one setup, it's better to go for the solution that solves them. It also looks like the current solution is more hands-free: if we select a specific commit from their GitHub repo, then once a new stable commit comes around, we need to change the pinned commit to reference it. This is for sure something that will get forgotten.
We might be able to find another theme with this built in. This feature seems to have been added around mid 2025. From my minor research it seems like the read-the-docs theme has this feature but I haven’t had time to verify or look for a more similar theme
Nicola has a good point about forgetting to update the version when the next release comes out. Let's keep the change.
```rst
   :private-members:
      _sample_batch_indices,
      _get_batch_data,
```
If we want user docs for these methods/properties, perhaps they should be public
They are not meant to be used outside of the cost function class, but `_get_batch_data` is required to be implemented by subclasses and should make use of `_get_batch_indices`. We thought it would be nice to have those documented so users can see that they are required.
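As an illustration of how these two hooks could relate, here is a sketch; the signatures and bodies are assumptions for the example, not the actual decent_bench implementation:

```python
import numpy as np

class EmpiricalRiskCost:
    """Sketch: the base class provides index sampling, subclasses slice their data."""

    def __init__(self, n_samples, batch_size, seed=0):
        self.n_samples = n_samples
        self.batch_size = batch_size
        self._rng = np.random.default_rng(seed)

    def _sample_batch_indices(self):
        """Provided by the base class: sample a mini-batch of sample indices."""
        return self._rng.choice(self.n_samples, size=self.batch_size, replace=False)

    def _get_batch_data(self, indices):
        """Required hook: subclasses return the data rows for these indices."""
        raise NotImplementedError

class LinearRegressionCost(EmpiricalRiskCost):
    def __init__(self, A, b, batch_size):
        super().__init__(A.shape[0], batch_size)
        self.A, self.b = A, b

    def _get_batch_data(self, indices):
        return self.A[indices], self.b[indices]
```

Documenting both hooks in the API reference lets users see that `_get_batch_data` is the one method a new empirical risk cost must supply.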
```diff
@@ -0,0 +1,137 @@
+from __future__ import annotations
```
Maybe out of scope for now but something to keep in mind:
Ruff doesn't enforce code documentation for private modules (e.g. D101-D103). This is one of the reasons why the current structure is very flat and only has public modules (except for the new interoperability package). I think it's better to keep modules public until Ruff fixes this (see astral-sh/ruff#9946 and astral-sh/ruff#9561). An alternative is to configure pydocstyle (which supports doc enforcement in private modules) as a complement to Ruff, but that feels out of scope for this PR.
ok, thanks for letting us know of this. let's keep this in mind for future PRs
This PR adds an empirical risk cost function and minor refactoring of the monolithic costs.py file.
The empirical risk cost function implements support for stochastic gradients and batch sampling. Linear and logistic regression costs have been converted to empirical risk cost functions. Setting batch_size to None will yield the same results as previously.
closes #141