
feat(Empirical): Empirical risk cost #240

Merged
nicola-bastianello merged 10 commits into team-decent:main from Simpag:empirical-cost on Jan 27, 2026

Conversation

@Simpag (Contributor) commented Jan 20, 2026

This PR adds an empirical risk cost function and includes some minor refactoring of the monolithic costs.py file.

The empirical risk cost function supports stochastic gradients and batch sampling. The linear and logistic regression costs have been converted to empirical risk cost functions. Setting batch_size to None yields the same results as before.

closes #141
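
A minimal usage sketch (constructor and method signatures are illustrative, based on the discussion in this PR, not on the merged code):

```python
import numpy as np
from decent_bench.costs import LinearRegressionCost  # import path per this PR's refactor

rng = np.random.default_rng(0)
A, b = rng.normal(size=(100, 5)), rng.normal(size=100)  # toy data: 100 samples, 5 features

full_cost = LinearRegressionCost(A, b)                 # batch_size=None: same results as before
mini_cost = LinearRegressionCost(A, b, batch_size=10)  # stochastic gradients over batches of 10

x = np.zeros(5)
g_full = full_cost.gradient(x)                  # deterministic full-dataset gradient
g_mini = mini_cost.gradient(x)                  # gradient on a freshly sampled batch of 10
g_some = mini_cost.gradient(x, indices=[0, 1])  # gradient on explicitly chosen samples
```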

Copilot AI review requested due to automatic review settings January 20, 2026 16:40

Copilot AI left a comment


Pull request overview

This PR refactors the monolithic costs.py file into a modular structure and introduces empirical risk cost functions with support for stochastic gradients and batch sampling. The key changes enable mini-batch training by allowing indices to be specified in gradient/function evaluations.

Changes:

  • Introduced EmpiricalRiskCost base class with batch sampling capabilities for stochastic gradient methods
  • Refactored cost functions into a modular structure with _base/ and _empirical_risk/ subdirectories
  • Converted LinearRegressionCost and LogisticRegressionCost to inherit from EmpiricalRiskCost with optional batch size parameter

Reviewed changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 11 comments.

Show a summary per file:

  • decent_bench/costs/_base/_cost.py: Base Cost class extracted from monolithic costs.py
  • decent_bench/costs/_base/_sum_cost.py: SumCost class for composing multiple costs
  • decent_bench/costs/_base/_quadratic_cost.py: QuadraticCost class extracted and refactored
  • decent_bench/costs/_empirical_risk/_empirical_risk_cost.py: New base class for empirical risk with batch sampling support
  • decent_bench/costs/_empirical_risk/_linear_regression_cost.py: Refactored LinearRegressionCost with batch support
  • decent_bench/costs/_empirical_risk/_logistic_regression_cost.py: Refactored LogisticRegressionCost with batch support
  • decent_bench/costs/_empirical_risk/__init__.py: Module exports for empirical risk costs
  • decent_bench/costs/_base/__init__.py: Module exports for base costs
  • decent_bench/costs/__init__.py: Updated top-level cost module exports
  • decent_bench/costs.py: Removed monolithic file
  • decent_bench/agents.py: Updated call counting methods to forward *args/**kwargs
  • decent_bench/benchmark_problem.py: Updated imports for new module structure
  • decent_bench/distributed_algorithms.py: Minor formatting adjustment in docstring
  • docs/source/api/decent_bench.costs.rst: Added private members to documentation


"""
Proximal at x solved using an iterative method.

Note:
Member

I'm now debating whether proximal should have support for indices at all, even in linear regression (which is essentially the only empirical risk cost that has a closed-form proximal). The idea could be: 1) in linear regression we always return the full proximal, ignoring batch size; 2) in any other empirical risk cost, we put a note saying that if there is no closed-form solution, the proximal will be approximated using stochastic gradients. What do you think?

Contributor Author

I've been thinking about the proximal for a while. I was wondering if there's any point in implementing an approximation for the proximal at all. If there is no closed-form solution to the proximal, then the approximation strategy used for solving it will affect the performance of any algorithm using the proximal, which makes it somewhat ambiguous to assess performance. I personally think that it should be up to the algorithm to implement this approximation, so that the performance reflects how well the algorithm performs, not our approximation of the proximal.

Member

Mainly, the approximation of the proximal was implemented to run distributed ADMM; while it is basically the only algorithm using the proximal, it is an important one, so I do like having a default implementation. But I do agree that this can lead to "weird" results, where ADMM converges inexactly even though it is expected to converge exactly when the proximal is exact. This could create confusion.

So my proposal is the following: we make proximal no longer an abstractmethod, and instead in the base cost definition we raise NotImplementedError(), maybe with the error message "see centralized_algorithms.proximal_solver for an implementation of the approximate proximal computation". We could also choose to move proximal_solver to a "utils" subfolder of costs. Then, we implement proximal in the costs that have a closed-form solution (quadratic and linear regression), so that we still have examples of this in the docs. Finally, for linear regression I would go for the computation that uses the full dataset, adding a note that this is the case.

What do you think? It requires a few more changes, but maybe it's neater this way.
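
A sketch of that proposal (illustrative; only the error message is taken from this discussion):

```python
class Cost:
    def proximal(self, x, rho):
        # No longer an abstractmethod: costs without a closed-form proximal
        # simply do not override this and raise instead.
        raise NotImplementedError(
            "see centralized_algorithms.proximal_solver for an implementation "
            "of the approximate proximal computation"
        )

class QuadraticCost(Cost):
    def proximal(self, x, rho):
        ...  # closed-form solution implemented here, and likewise for linear regression
```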

Contributor Author

We can think about moving proximal_solver at a later date, as cost functions might represent their data in different ways. I can do that when I change how the dataset is represented and make the proximal solver assume the representation of Dataset. I do agree that we should raise NotImplementedError. For the case of linear regression: because the proximal solver creates a SumCost, we cannot pass self, since it might be batched. What I'll do instead is create a deepcopy of the cost and just set batch_size = n_samples.
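
A sketch of the deepcopy idea (illustrative only; proximal_solver, n_samples, and the exact wiring are assumptions based on this discussion, not the merged code):

```python
from copy import deepcopy

# Inside an EmpiricalRiskCost subclass without a closed-form proximal:
def proximal(self, x, rho):
    # The proximal is defined w.r.t. the full cost, so evaluate it on an
    # unbatched copy; passing self directly could use mini-batches.
    unbatched = deepcopy(self)
    unbatched.batch_size = unbatched.n_samples  # disable batching on the copy
    return proximal_solver(unbatched, x, rho)   # approximate iterative solver
```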

Contributor Author

I also assume that when you wrote "Finally, for linear regression I would go for the computation that uses the full dataset, adding a note that this is the case" you meant logistic regression.

Member

Sorry, I didn't explain that too well; to clarify:

  • linear regression has a closed-form proximal, so for that we can just use it, on the full dataset (see the formula below)
  • logistic regression doesn't have a closed-form solution, so for that we can use the proximal solver (like you say, with a deepcopy that has batch_size = n_samples)
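
For reference, a sketch of the closed form in question, under the convention that prox minimizes f(z) + (1/2ρ)‖z − x‖² with f(z) = ½‖Az − b‖² (the library's scaling convention may differ):

```math
\operatorname{prox}_{\rho f}(x) = \left(A^\top A + \tfrac{1}{\rho} I\right)^{-1}\left(A^\top b + \tfrac{1}{\rho} x\right)
```

That is, a single linear system solve over the full dataset, which is why ignoring the batch size here is both cheap and exact.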

Contributor Author

So you want to remove the indices option from all proximal computations and always use the full dataset, or just in the logistic regression case?

Member

Yes, I think it's best to remove indices from all proximal computations; there are stochastic versions of proximal evaluations, but they are not very common and I think they might just be confusing. So for logistic regression I would show the use of proximal_solver, making a note in the docstring that it will use stochastic gradients, resulting in inaccurate proximal computations. And we could show how to implement a version that uses the full dataset, maybe with the deepcopy you mentioned. It's just a way to have an example of using proximal_solver in the docs; I expect it will be rarely (or never) used.

@nicola-bastianello (Member)

In addition to my comments, I think the docstring of EmpiricalRiskCost should clarify very well how to use indices. If I'm not mistaken, the current behavior is: 1) if batch_size=None at init, the full dataset is always used; 2) if indices is an int or a list, the data corresponding to those indices are used; 3) if indices=None, a random batch of size batch_size is used. So we should mention all of this.
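
A sketch of those three cases (the helper name and body are illustrative, not the merged code):

```python
def _resolve_indices(self, indices):
    # Hypothetical helper mirroring the behavior described above.
    if indices is not None:  # 2) explicit indices: use exactly those samples
        return [indices] if isinstance(indices, int) else list(indices)
    if self.batch_size is None:  # 1) batch_size=None at init: always the full dataset
        return list(range(self.n_samples))
    return self._sample_batch_indices()  # 3) indices=None: random batch of size batch_size
```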

Questions:

  • The only ways to use the full dataset are either to set batch_size=None at init, or indices=range(n_samples). Would it make sense to allow for a simplified way to do indices=range(n_samples) without using range? Something like indices="all".
  • Should we implement a specific setter method for batch_size, so that users can change it, but when they do we can check that the proposed batch_size is in the allowed range (<= n_samples)?
  • The current behavior is to set batch_size=None to signify that the whole dataset should be used; could it make sense to set batch_size=n_samples instead? This way batch_size is always a number, never None; users could still set cost.batch_size=None, but the setter then converts it to cost.batch_size=n_samples. Basically, we only need to care about interpreting batch_size=None in the setter method and nowhere else, and we can use batch_size everywhere else (e.g. if we account for batch size in computational cost) knowing it's a number (see the sketch below).
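
A minimal sketch of what the last two bullets could look like (n_samples and the backing attribute are assumptions, not the merged code):

```python
@property
def batch_size(self) -> int:
    return self._batch_size

@batch_size.setter
def batch_size(self, value: int | None) -> None:
    # Interpret None as "full dataset" here, once, so batch_size is an int everywhere else.
    value = self.n_samples if value is None else value
    if not 1 <= value <= self.n_samples:
        raise ValueError(f"batch_size must be in [1, {self.n_samples}], got {value}")
    self._batch_size = value
```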

@Simpag (Contributor, Author) commented Jan 22, 2026

You understood the behavior correctly, and I agree that this could be better documented. Currently this is only documented in the base class properties.

Regarding your questions: we can add the option for indices to be a string literal, but this will complicate things a bit. I have a solution for this which might make the base class a bit messier, but it will make it easier for users to create new costs.

I can't really see a reason why someone would want to change the batch size. Usually batch size is decided based on available resources, so having it scheduled in some way seems unlikely. But we can always add it if you think it would be of use.

The reason I set batch_size to allow for None at init is that, if we don't allow it to be None, we can't have batch_size as a default argument, since we cannot know the size of the dataset. We can of course set batch_size = n_samples at runtime in the init. I had that originally, but I changed it to None to make the _sample_batch_indices method slightly faster; the improvement is minuscule though, so it shouldn't matter much.

@nicola-bastianello (Member)

> You understood the behavior correctly, and I agree that this could be better documented. Currently this is only documented in the base class properties.

Great, thank you.

> Regarding your questions: we can add the option for indices to be a string literal, but this will complicate things a bit. I have a solution for this which might make the base class a bit messier, but it will make it easier for users to create new costs.

OK, then let's see the solution you propose; if we think it's too messy we can also avoid having the string literal option.

> I can't really see a reason why someone would want to change the batch size. Usually batch size is decided based on available resources, so having it scheduled in some way seems unlikely. But we can always add it if you think it would be of use.

Good point; then I agree it's not needed. In any case, if an algorithm needs to change the batch size, it can do so by passing the indices directly, which ignores the batch_size property (right?)

> The reason I set batch_size to allow for None at init is that, if we don't allow it to be None, we can't have batch_size as a default argument, since we cannot know the size of the dataset. We can of course set batch_size = n_samples at runtime in the init. I had that originally, but I changed it to None to make the _sample_batch_indices method slightly faster; the improvement is minuscule though, so it shouldn't matter much.

OK, then if the performance improvement is small, I would prefer to set batch_size = n_samples at runtime. It feels like the solution that will be easier to work with in the future (and if a user accesses the property, I think it's more expected to see an int instead of None).

@Simpag (Contributor, Author) commented Jan 22, 2026

>> You understood the behavior correctly, and I agree that this could be better documented. Currently this is only documented in the base class properties.
>
> Great, thank you.

>> Regarding your questions: we can add the option for indices to be a string literal, but this will complicate things a bit. I have a solution for this which might make the base class a bit messier, but it will make it easier for users to create new costs.
>
> OK, then let's see the solution you propose; if we think it's too messy we can also avoid having the string literal option.

>> I can't really see a reason why someone would want to change the batch size. Usually batch size is decided based on available resources, so having it scheduled in some way seems unlikely. But we can always add it if you think it would be of use.
>
> Good point; then I agree it's not needed. In any case, if an algorithm needs to change the batch size, it can do so by passing the indices directly, which ignores the batch_size property (right?)

Yes, the indices argument does not care about batch_size.

>> The reason I set batch_size to allow for None at init is that, if we don't allow it to be None, we can't have batch_size as a default argument, since we cannot know the size of the dataset. We can of course set batch_size = n_samples at runtime in the init. I had that originally, but I changed it to None to make the _sample_batch_indices method slightly faster; the improvement is minuscule though, so it shouldn't matter much.
>
> OK, then if the performance improvement is small, I would prefer to set batch_size = n_samples at runtime. It feels like the solution that will be easier to work with in the future (and if a user accesses the property, I think it's more expected to see an int instead of None).

I agree with you; I have made some changes that I will push shortly. I think that with the new way indices work, we won't need to document the behavior too extensively.

@Simpag (Contributor, Author) commented Jan 22, 2026

Sphinx seems to fail because it randomly decided to use version 8.2.3 instead of 9.0.4, which it used earlier and which we use locally? Maybe we can wait a bit and rerun the test. I also noticed that we run mypy and Sphinx on macOS; we might want to switch to Linux, as it is most likely the most stable and up-to-date option.

Edit:
If I re-install my Sphinx environment, pip installs Sphinx 8.2.3 for some reason, even though Sphinx 9.1.0 is the latest version.

It seems like the theme is the issue, forcing an install of an older version; their build is failing (updated 3 hours ago, so that lines up): https://github.com/pydata/pydata-sphinx-theme?tab=readme-ov-file

@Simpag (Contributor, Author) commented Jan 22, 2026

It seems like we are always building the theme from their latest main branch and not from their stable releases. Let me update the config file to use the stable releases and hope that the checks use the one in the PR.

@nicola-bastianello (Member) commented Jan 23, 2026

> I also noticed that we run mypy and Sphinx on macOS; we might want to switch to Linux, as it is most likely the most stable and up-to-date option.

Could you try fixing this? I guess only ci.yaml needs to change. We can also do it later.

Comment on lines +120 to +134
```diff
-    def _call_counting_function(self, x: Array) -> float:
+    def _call_counting_function(self, x: Array, *args: Any, **kwargs: Any) -> float:  # noqa: ANN401
         self._n_function_calls += 1
-        return self._cost.__class__.function(self.cost, x)
+        return self._cost.__class__.function(self.cost, x, *args, **kwargs)

-    def _call_counting_gradient(self, x: Array) -> Array:
+    def _call_counting_gradient(self, x: Array, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
         self._n_gradient_calls += 1
-        return self._cost.__class__.gradient(self.cost, x)
+        return self._cost.__class__.gradient(self.cost, x, *args, **kwargs)

-    def _call_counting_hessian(self, x: Array) -> Array:
+    def _call_counting_hessian(self, x: Array, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
         self._n_hessian_calls += 1
-        return self._cost.__class__.hessian(self.cost, x)
+        return self._cost.__class__.hessian(self.cost, x, *args, **kwargs)

-    def _call_counting_proximal(self, x: Array, rho: float) -> Array:
+    def _call_counting_proximal(self, x: Array, rho: float, *args: Any, **kwargs: Any) -> Array:  # noqa: ANN401
         self._n_proximal_calls += 1
-        return self._cost.__class__.proximal(self.cost, x, rho)
+        return self._cost.__class__.proximal(self.cost, x, rho, *args, **kwargs)
```
Member
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider `def _call_counting_x[**P](self, x: Array, *args: P.args, **kwargs: P.kwargs)` instead of `Any` and `noqa`.

@Simpag (Contributor, Author) commented Jan 23, 2026

I agree, that's better. However, mypy does complain that it is unbound, and AFAIK there is no bound for a ParamSpec?
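
For context, a sketch of the constraint (my reading of PEP 612, not from the PR): P.args/P.kwargs are only valid when P is bound through a Callable in the same signature, which these wrapper methods don't have:

```python
from collections.abc import Callable
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

# Accepted: P is bound through the Callable[P, R] parameter.
def counted(fn: Callable[P, R], counter: list[int]) -> Callable[P, R]:
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        counter[0] += 1
        return fn(*args, **kwargs)
    return wrapper

# Rejected: nothing in the signature binds P, so mypy reports it as unbound:
# def _call_counting_gradient(self, x, *args: P.args, **kwargs: P.kwargs): ...
```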

Contributor Author

I have not found a solution to this that makes mypy happy; maybe you can find one, since you have more experience with mypy?

Member

Otherwise it's OK to use noqa.

```diff
@@ -1 +1 @@
-git+https://github.com/pydata/pydata-sphinx-theme.git@main
+pydata-sphinx-theme
```
Member

We'll lose search-as-you-type in the docs' search functionality with this change. Is it really needed?

@Simpag (Contributor, Author) commented Jan 23, 2026

Always building from their git source breaks sometimes. Yesterday they pushed an update which forced Sphinx to use an older version, which broke the docs and the GitHub checks. The change also solved some weird interactions causing some elements to appear and disappear in the local preview of the docs, which is really annoying when trying to read them, as the entire page moves around. I think it is worth sacrificing search-as-you-type for better stability and for not having to worry about our theme randomly breaking the docs.

@Simpag (Contributor, Author) commented Jan 23, 2026

I found that the bugginess of the docs stopped once the switch to the pip package was made. I personally don't think the search-as-you-type feature is that important, and I'd rather have a more stable experience, but let's see what @nicola-bastianello has to say.

It might be some issue on my end, if neither of you have experienced it.

Member

I think I prefer to prioritize stability in the docs as much as possible, so I think the current solution is good

@elramen (Member) commented Jan 23, 2026

@Simpag do you remember when it started to bug out for you? If yes, just pick a commit from slightly before then and do git+https://github.com/pydata/pydata-sphinx-theme.git@236d4af4b40edd607021e17342f4bf5870d64f0d instead of git+https://github.com/pydata/pydata-sphinx-theme.git@main. Here is the commit list: https://github.com/pydata/pydata-sphinx-theme/commits/main/

This way, we get stability without sacrificing search-as-you-type, which IMHO is super nice when searching the docs.

Contributor Author

If I’m not misstaken its been buggy for me the entire time. Have either of you experienced this? I am running ubuntu in wsl so that might be why if I’m the only one experiencing this.

Member

I'm not sure I noticed the same bugs, but if they show up in even one setup, it's better to go for the solution that solves them. It also looks like the current solution is more hands-free: if we pin a specific commit from their GitHub repo, then once a new stable commit comes around we need to update the reference to it. That is for sure something that will get forgotten.

Contributor Author

We might be able to find another theme with this built in. The feature seems to have been added around mid-2025. From my brief research it seems like the read-the-docs theme has it, but I haven't had time to verify that or to look for a more similar theme.

Member

Nicola has a good point about forgetting to update the version when the next release comes out. Let's keep the change.

Comment on lines +8 to +10
```rst
:private-members:
    _sample_batch_indices,
    _get_batch_data,
```
Member

If we want user docs for these methods/properties, perhaps they should be public.

@Simpag (Contributor, Author) commented Jan 23, 2026

They are not meant to be used outside of the cost function class, but _get_batch_data is required to be implemented, and implementations should make use of _sample_batch_indices. We thought it would be nice to have them documented so users can see that they are required.
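
For illustration, a minimal sketch of that pattern (only the two method names come from this PR; the class shape and data attributes are hypothetical):

```python
from decent_bench.costs import EmpiricalRiskCost  # assuming the class is exported here

class MyCost(EmpiricalRiskCost):
    def _get_batch_data(self, indices=None):
        # Required override: return the data subset to evaluate on,
        # sampling a fresh batch when no explicit indices are given.
        if indices is None:
            indices = self._sample_batch_indices()
        return self._A[indices], self._b[indices]  # _A/_b: hypothetical sample storage
```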

```diff
@@ -0,0 +1,137 @@
+from __future__ import annotations
```

Member

Maybe out of scope for now, but something to keep in mind:

Ruff doesn't enforce docstring rules (e.g. D101-D103) for private modules. This is one of the reasons why the current structure is very flat and only has public modules (except for the new interoperability package). I think it's better to keep modules public until Ruff fixes this (see astral-sh/ruff#9946 and astral-sh/ruff#9561). An alternative is to configure pydocstyle (which supports doc enforcement in private modules) as a complement to Ruff, but that feels out of scope for this PR.

Member

OK, thanks for letting us know about this. Let's keep it in mind for future PRs.

@Simpag Simpag changed the title from "feat(Empirical): Add empirical risk cost" to "feat(Empirical): Empirical risk cost" on Jan 26, 2026
@nicola-bastianello nicola-bastianello merged commit a918f5d into team-decent:main Jan 27, 2026
7 checks passed