Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ If you want to contribute to `superintendent`, you will need to install the test
dependencies as well. You can do so with
`pip install superintendent[tests,examples]`


## Acknowledgements
## Acknowledgements

Much of the initial work on `superintendent` was done during my time at
[Faculty AI](https://faculty.ai/).
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ description-file = "README.md"
classifiers = [
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"License :: OSI Approved :: MIT License",
"Intended Audience :: Science/Research",
"Framework :: Jupyter",
Expand All @@ -34,7 +35,8 @@ requires = [
"psycopg2-binary>=2.8",
"flask>=1.0",
"ipyevents>=0.6.0",
"typing-extensions"
"typing-extensions",
"merge-args",
]

[tool.flit.metadata.requires-extra]
Expand Down
90 changes: 10 additions & 80 deletions src/superintendent/acquisition_functions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,92 +1,22 @@
"""
Functions to prioritise labelling data points (to drive active learning).
"""
from typing import Dict, Callable
import numpy as np
import scipy.stats

from .decorators import make_acquisition_function

__all__ = ["entropy", "margin", "certainty"]


@make_acquisition_function(handle_multioutput=None)  # noqa: D002
def entropy(probabilities: np.ndarray) -> np.ndarray:
    """
    Sort by the entropy of the probabilities (high to low).

    Parameters
    ----------
    probabilities : np.ndarray
        An array of probabilities, with the shape n_samples,
        n_classes

    Other Parameters
    ----------------
    shuffle_prop : float (default=0.1)
        The proportion of data points that should be randomly shuffled. This
        means the sorting retains some randomness, to avoid biasing your
        new labels and catching any minority classes the algorithm currently
        classifies as a different label.

    """
    # scipy's entropy reduces along the first axis, so transpose to get one
    # entropy value per sample; negate so that the most uncertain
    # (highest-entropy) samples sort to the front.
    return -scipy.stats.entropy(probabilities.T)

from typing import Dict, Callable, Union, List
import numpy as np
from .functions import entropy, margin, certainty, bald, random

@make_acquisition_function(handle_multioutput="mean")  # noqa: D002
def margin(probabilities: np.ndarray) -> np.ndarray:
    """
    Sort by the margin between the top two predictions (low to high).

    Parameters
    ----------
    probabilities : np.ndarray
        An array of probabilities, with the shape n_samples,
        n_classes

    Other Parameters
    ----------------
    shuffle_prop : float
        The proportion of data points that should be randomly shuffled. This
        means the sorting retains some randomness, to avoid biasing your
        new labels and catching any minority classes the algorithm currently
        classifies as a different label.
    """
    # Sort each row once and reuse the result (the original sorted twice);
    # the margin is the gap between the largest and second-largest class
    # probability for each sample.
    sorted_probabilities = np.sort(probabilities, axis=1)
    return sorted_probabilities[:, -1] - sorted_probabilities[:, -2]


@make_acquisition_function(handle_multioutput="mean")  # noqa: D002
def certainty(probabilities: np.ndarray):
    """
    Sort by the certainty of the maximum prediction.

    Parameters
    ----------
    probabilities : np.ndarray
        An array of probabilities, with the shape n_samples,
        n_classes

    Other Parameters
    ----------------
    shuffle_prop : float
        The proportion of data points that should be randomly shuffled. This
        means the sorting retains some randomness, to avoid biasing your
        new labels and catching any minority classes the algorithm currently
        classifies as a different label.

    """
    # The confidence of each sample's prediction is its largest class
    # probability.
    return np.max(probabilities, axis=-1)
__all__ = ["entropy", "margin", "certainty", "bald", "functions", "random"]

AcquisitionFunction = Callable[
[Union[np.ndarray, List[np.ndarray]]], np.ndarray
]

functions: Dict[str, Callable] = {
functions: Dict[str, AcquisitionFunction] = {
"entropy": entropy,
"margin": margin,
"certainty": certainty,
"bald": bald,
"random": random,
}
"""A dictionary of functions to prioritise data."""
92 changes: 86 additions & 6 deletions src/superintendent/acquisition_functions/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
import typing

import numpy as np
from merge_args import merge_args


def _dummy_fn(shuffle_prop: float = 0.1):
...


def _shuffle_subset(data: np.ndarray, shuffle_prop: float) -> np.ndarray:
Expand All @@ -25,6 +30,7 @@ def _get_indices(scores: np.ndarray, shuffle_prop: float) -> np.ndarray:
def _is_multioutput(
probabilities: typing.Union[np.ndarray, typing.List[np.ndarray]]
):
"""Test whether predictions are for single- or multi-output"""
if isinstance(probabilities, list) and (
isinstance(probabilities[0], np.ndarray)
and probabilities[0].ndim == 2
Expand All @@ -37,6 +43,69 @@ def _is_multioutput(
raise ValueError("Unknown probability format.")


def _is_distribution(probabilities: np.ndarray):
    """
    Test whether predictions are single value per outcome, or a distribution.
    """
    # Multi-output input is a list of per-output arrays; inspect the first
    # one (the outputs are assumed to share the same layout).
    if _is_multioutput(probabilities):
        return _is_distribution(probabilities[0])
    # A distribution carries an extra trailing axis beyond
    # (n_samples, n_classes).
    return probabilities.ndim > 2


# Maps the ``handle_multioutput`` option name to the numpy reduction used to
# combine per-output scores into a single score per sample.
multioutput_reduce_fns = {"mean": np.mean, "max": np.max}


def require_point_estimate(fn: typing.Callable) -> typing.Callable:
    """
    Mark a function as requiring point estimate predictions.

    If distributions of predictions get passed, the distribution will be
    averaged first.

    Parameters
    ----------
    fn
        The function to decorate.
    """

    @functools.wraps(fn)
    def wrapped_fn(probabilities: np.ndarray, *args, **kwargs):
        # Point estimates already: pass straight through.
        if not _is_distribution(probabilities):
            return fn(probabilities, *args, **kwargs)
        # Collapse the trailing distribution axis to its mean so the wrapped
        # function only ever sees point estimates.
        if _is_multioutput(probabilities):
            point_estimates = [dist.mean(axis=-1) for dist in probabilities]
        else:
            point_estimates = probabilities.mean(axis=-1)
        return fn(point_estimates, *args, **kwargs)

    return wrapped_fn


def require_distribution(fn: typing.Callable) -> typing.Callable:
    """
    Mark a function as requiring distribution output.

    If non-distribution output gets passed, this function will now raise an
    error.

    Parameters
    ----------
    fn
        The function to decorate.
    """

    @functools.wraps(fn)
    def wrapped_fn(probabilities, *args, **kwargs):
        # Only distribution-shaped predictions are acceptable here; reject
        # plain point estimates loudly rather than computing nonsense.
        if _is_distribution(probabilities):
            return fn(probabilities, *args, **kwargs)
        raise ValueError(
            f"Acquisition function {fn.__name__} "
            "requires distribution output."
        )

    return wrapped_fn


def make_acquisition_function(handle_multioutput="mean"):
"""Wrap an acquisition function.

Expand All @@ -50,23 +119,34 @@ def make_acquisition_function(handle_multioutput="mean"):
comes as a list of binary classifier outputs.
"""

def decorator(fn):
if handle_multioutput == "mean": # define fn where scores are avgd
def decorator(
fn: typing.Callable[[np.ndarray], np.ndarray]
) -> typing.Callable[[np.ndarray, float], np.ndarray]:
if handle_multioutput is not None:

reduce_fn = multioutput_reduce_fns[handle_multioutput]

@merge_args(_dummy_fn)
@functools.wraps(fn)
def wrapped_fn(probabilities, shuffle_prop=0.1):
def wrapped_fn(
probabilities: np.ndarray, shuffle_prop: float = 0.1
):
if _is_multioutput(probabilities):
scores = np.stack(
tuple(fn(prob) for prob in probabilities), axis=0
).mean(axis=0)
tuple(fn(prob) for prob in probabilities), axis=0,
)
scores = reduce_fn(scores, axis=0)
else:
scores = fn(probabilities)
return _get_indices(scores, shuffle_prop)

else: # raise error if list is passed

@merge_args(_dummy_fn)
@functools.wraps(fn)
def wrapped_fn(probabilities, shuffle_prop=0.1):
def wrapped_fn(
probabilities: np.ndarray, shuffle_prop: float = 0.1
):
if _is_multioutput(probabilities):
raise ValueError(
"The input probabilities is a list of arrays, "
Expand Down
Loading