From 69d751ce2d3d87ce9b70009d69ab416f39c351f6 Mon Sep 17 00:00:00 2001 From: rishabh-mondal Date: Sat, 28 Oct 2023 23:20:29 +0530 Subject: [PATCH 1/3] Max entropy acquasition for ensemble mc dropout --- astra/torch/al/acquisitions/max_entropy.py | 28 ++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 astra/torch/al/acquisitions/max_entropy.py diff --git a/astra/torch/al/acquisitions/max_entropy.py b/astra/torch/al/acquisitions/max_entropy.py new file mode 100644 index 0000000..5423fe5 --- /dev/null +++ b/astra/torch/al/acquisitions/max_entropy.py @@ -0,0 +1,28 @@ +import torch +from base import EnsembleAcquisition, MCAcquisition + +class MaxEntropy(EnsembleAcquisition, MCAcquisition): + def acquire_scores(self, logits: torch.Tensor): + + #this is ensemble strategy + + + if isinstance(self,EnsembleAcquisition): + #calculate entropy for each pool datapoint for each model + entropy=-torch.sum(logits*torch.log(logits),dim=2) + + score=torch.sum(entropy,dim=0) + + return score + + elif isinstance(self,MCAcquisition): + #calculate entropy for each pool datapoint accross all Monte Carlo samples + entropy=-torch.sum(logits*torch.log(logits),dim=2) + + score=torch.mean(entropy,dim=0) + + return score + + else: + raise NotImplementedError("Unknown acquisition strategy") + \ No newline at end of file From 790d60f3c9758941298824d59674d4a0a572523b Mon Sep 17 00:00:00 2001 From: rishabh-mondal Date: Sat, 28 Oct 2023 23:33:17 +0530 Subject: [PATCH 2/3] Max entropy acquasition for ensemble and mc drop --- astra/torch/al/acquisitions/max_entropy.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/astra/torch/al/acquisitions/max_entropy.py b/astra/torch/al/acquisitions/max_entropy.py index 5423fe5..bfa0247 100644 --- a/astra/torch/al/acquisitions/max_entropy.py +++ b/astra/torch/al/acquisitions/max_entropy.py @@ -9,7 +9,9 @@ def acquire_scores(self, logits: torch.Tensor): if isinstance(self,EnsembleAcquisition): #calculate entropy for each pool datapoint for each model - entropy=-torch.sum(logits*torch.log(logits),dim=2) + probs=torch.softmax(logits,dim=2) + entropy=-torch.sum(probs*torch.log(probs),dim=2) + score=torch.sum(entropy,dim=0) @@ -17,7 +19,9 @@ def acquire_scores(self, logits: torch.Tensor): elif isinstance(self,MCAcquisition): #calculate entropy for each pool datapoint accross all Monte Carlo samples - entropy=-torch.sum(logits*torch.log(logits),dim=2) + probs = torch.softmax(logits, dim=2) + + entropy=-torch.sum(probs*torch.log(probs),dim=2) score=torch.mean(entropy,dim=0) From 519fc6dc884dcb03c4fa4ed8c0859a807966afdd Mon Sep 17 00:00:00 2001 From: rishabh-mondal Date: Sun, 29 Oct 2023 00:36:25 +0530 Subject: [PATCH 3/3] max entropy acquasition and test case --- astra/torch/al/acquisitions/max_entropy.py | 26 ++---- .../torch/acquisitions/test_max_entropy.ipynb | 83 +++++++++++++++++++ 2 files changed, 88 insertions(+), 21 deletions(-) create mode 100644 tests/torch/acquisitions/test_max_entropy.ipynb diff --git a/astra/torch/al/acquisitions/max_entropy.py b/astra/torch/al/acquisitions/max_entropy.py index bfa0247..7f86a8f 100644 --- a/astra/torch/al/acquisitions/max_entropy.py +++ b/astra/torch/al/acquisitions/max_entropy.py @@ -4,29 +4,13 @@ class MaxEntropy(EnsembleAcquisition, MCAcquisition): def acquire_scores(self, logits: torch.Tensor): - #this is ensemble strategy + #calculate entropy for each pool datapoint for each model + probs=torch.softmax(logits,dim=2) + entropy=-torch.sum(probs*torch.log(probs),dim=2) - if isinstance(self,EnsembleAcquisition): - #calculate entropy for each pool datapoint for each model - probs=torch.softmax(logits,dim=2) - entropy=-torch.sum(probs*torch.log(probs),dim=2) + score=torch.sum(entropy,dim=0) + return score - score=torch.sum(entropy,dim=0) - - return score - - elif isinstance(self,MCAcquisition): - #calculate entropy for each pool datapoint accross all Monte Carlo samples - probs = torch.softmax(logits, dim=2) - - entropy=-torch.sum(probs*torch.log(probs),dim=2) - - score=torch.mean(entropy,dim=0) - - return score - - else: - raise NotImplementedError("Unknown acquisition strategy") \ No newline at end of file diff --git a/tests/torch/acquisitions/test_max_entropy.ipynb b/tests/torch/acquisitions/test_max_entropy.ipynb new file mode 100644 index 0000000..f5ae46d --- /dev/null +++ b/tests/torch/acquisitions/test_max_entropy.ipynb @@ -0,0 +1,83 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ensemble Acquisition Scores:\n", + "tensor([2.0118, 2.0402, 1.9574])\n" + ] + } + ], + "source": [ + "import torch\n", + "from astra.torch.al.acquisitions.base import MCAcquisition, EnsembleAcquisition\n", + "\n", + "\n", + "class MaxEntropyAcquisition(MCAcquisition, EnsembleAcquisition):\n", + " def acquire_scores(self, logits: torch.Tensor):\n", + " probs = torch.softmax(logits, dim=2)\n", + " entropy = -torch.sum(probs * torch.log(probs), dim=2)\n", + " score = torch.sum(entropy, dim=0)\n", + " return score\n", + "\n", + "\n", + "# Create an instance of MaxEntropyAcquisition\n", + "max_entropy_acquisition = MaxEntropyAcquisition()\n", + "\n", + "logits = torch.tensor(\n", + " [\n", + " [[0.2, 0.8], [0.7, 0.3], [0.4, 0.6]],\n", + " [[0.6, 0.4], [0.3, 0.7], [0.8, 0.2]],\n", + " [[0.3, 0.7], [0.5, 0.5], [0.9, 0.1]],\n", + " ],\n", + " dtype=torch.float32,\n", + ")\n", + "\n", + "\n", + "# Calculate acquisition scores using the ensemble context\n", + "ensemble_scores = max_entropy_acquisition.acquire_scores(logits)\n", + "\n", + "# Calculate acquisition scores using the Monte Carlo context\n", + "mc_scores = max_entropy_acquisition.acquire_scores(mc_logits)\n", + "\n", + "# Print the results\n", + "print(\"Ensemble Acquisition Scores:\")\n", + "print(ensemble_scores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "torch_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}