diff --git a/astra/torch/al/acquisitions/max_entropy.py b/astra/torch/al/acquisitions/max_entropy.py new file mode 100644 index 0000000..7f86a8f --- /dev/null +++ b/astra/torch/al/acquisitions/max_entropy.py @@ -0,0 +1,16 @@ +import torch +from base import EnsembleAcquisition, MCAcquisition + +class MaxEntropy(EnsembleAcquisition, MCAcquisition): + def acquire_scores(self, logits: torch.Tensor): + + #calculate entropy for each pool datapoint for each model + probs=torch.softmax(logits,dim=2) + entropy=-torch.sum(probs*torch.log(probs),dim=2) + + + score=torch.sum(entropy,dim=0) + + return score + + \ No newline at end of file diff --git a/tests/torch/acquisitions/test_max_entropy.ipynb b/tests/torch/acquisitions/test_max_entropy.ipynb new file mode 100644 index 0000000..f5ae46d --- /dev/null +++ b/tests/torch/acquisitions/test_max_entropy.ipynb @@ -0,0 +1,83 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ensemble Acquisition Scores:\n", + "tensor([2.0118, 2.0402, 1.9574])\n" + ] + } + ], + "source": [ + "import torch\n", + "from astra.torch.al.acquisitions.base import MCAcquisition, EnsembleAcquisition\n", + "\n", + "\n", + "class MaxEntropyAcquisition(MCAcquisition, EnsembleAcquisition):\n", + " def acquire_scores(self, logits: torch.Tensor):\n", + " probs = torch.softmax(logits, dim=2)\n", + " entropy = -torch.sum(probs * torch.log(probs), dim=2)\n", + " score = torch.sum(entropy, dim=0)\n", + " return score\n", + "\n", + "\n", + "# Create an instance of MaxEntropyAcquisition\n", + "max_entropy_acquisition = MaxEntropyAcquisition()\n", + "\n", + "logits = torch.tensor(\n", + " [\n", + " [[0.2, 0.8], [0.7, 0.3], [0.4, 0.6]],\n", + " [[0.6, 0.4], [0.3, 0.7], [0.8, 0.2]],\n", + " [[0.3, 0.7], [0.5, 0.5], [0.9, 0.1]],\n", + " ],\n", + " dtype=torch.float32,\n", + ")\n", + "\n", + "\n", + "# Calculate acquisition scores using the ensemble context\n", + "ensemble_scores = max_entropy_acquisition.acquire_scores(logits)\n", + "\n", + "# Calculate acquisition scores using the Monte Carlo context\n", + "mc_scores = max_entropy_acquisition.acquire_scores(mc_logits)\n", + "\n", + "# Print the results\n", + "print(\"Ensemble Acquisition Scores:\")\n", + "print(ensemble_scores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "torch_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}