Skip to content

Mean_std Aquisition function#8

Open
VannshJani wants to merge 8 commits intosustainability-lab:mainfrom
VannshJani:main
Open

Mean_std Aquisition function#8
VannshJani wants to merge 8 commits intosustainability-lab:mainfrom
VannshJani:main

Conversation

@VannshJani
Copy link

Implementing mean_std aquisition function using ensemble and MC strategy

@coveralls
Copy link

Pull Request Test Coverage Report for Build 6698196722

  • 0 of 13 (0.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-1.7%) to 58.531%

Changes Missing Coverage Covered Lines Changed/Added Lines %
astra/torch/al/acquisitions/Mean_std.py 0 13 0.0%
Totals Coverage Status
Change from base Build 6692078825: -1.7%
Covered Lines: 271
Relevant Lines: 463

💛 - Coveralls

# Mean-STD acquisition function
# (n_nets/n_mc_samples, pool_dim, n_classes) logits shape
pool_num = logits.shape[1]
assert len(logits.shape) == 3, "logits shape must be 3-Dimensional"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This lines goes first

# std = torch.std(logits, dim=0) # standard deviation over model parameters, shape (pool_dim, n_classes)
expectaion_of_squared = torch.mean(ab**2,dim=0)
expectation_squared = torch.mean(ab,dim=0)**2
std = torch.sqrt(expectation_of_squared - expectation_squared)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a direct method of calculating std in torch. You can use that. Also, logits should be converted to probs before using this method. What is ab?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had used the direct method first bhaiya but it does not produce the same result as calculating (E[x**2] - E[x]**2)**0.5. I manually verified this with an example.



# maximum mean standard deviation aquisition function
class Mean_std(EnsembleAcquisition,MCAcquisition):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Class names follow Camel Case i.e. MeanStd

expectation_squared = torch.mean(ab,dim=0)**2
std = torch.sqrt(expectation_of_squared - expectation_squared)
scores = torch.mean(std, dim=1) # mean over classes, shape (pool_dim)
assert len(scores.shape) == 1 and scores.shape[0]==pool_num, "scores shape must be 1-Dimensional and must have length equal to that of pool dataset"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed because, it is done by developer (You). We should have asserts to prevent users from passing invalid arguments.

@VannshJani
Copy link
Author

I have updated rest of the changes bhaiya.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants