Skip to content

Diversity Acquisition Functions#12

Open
jaiswalsuraj487 wants to merge 20 commits intosustainability-lab:mainfrom
jaiswalsuraj487:feature_diversity
Open

Diversity Acquisition Functions#12
jaiswalsuraj487 wants to merge 20 commits intosustainability-lab:mainfrom
jaiswalsuraj487:feature_diversity

Conversation

@jaiswalsuraj487
Copy link

Implemented Furthest Acquisition and Centroid Acquisition on commit a6e59ff.

Files Added:

  1. astra/torch/al/acquisitions/furthest.py: contain implementation of Furthest acquisition
  2. astra/torch/al/acquisitions/centroid.py: contain implementation of Centroid acquisition
  3. astra/torch/al/strategies/diversity.py: modified this file as per the need
  4. tests/torch/acquisitions/test_furthest.py: contains test for furthest acquisition function.
  5. tests/torch/acquisitions/test_centroid.py: contains test for centroid acquisition function.

Passes all test cases, including those already existing(commit: a6e59ff).

Explanation:

  1. furthest.py: For the furthest acquisition function, we use the furthest_first method of Class distil.active_learning_strategies.core_set.CoreSet link where we pass dummy object strategy as an argument along with labeled_embeddings, unlabeled_embeddings and n. This returns list of indices of n data points that are furthest from all.
  2. centroid.py: For the centroid acquisition function: For the Centroid Acquisition function, we pass labeled_embeddings, unlabeled_embeddings , and n as input.

Below lines initializes min_dist as tensor with all values infinity of size [len(n_pool)] when our n_train is 0.

    if labeled_embeddings.shape[0] == 0:
        min_dist = torch.full((unlabeled_embeddings.shape[0],), float("inf"))

Else we find centroid of train data and then pairwise distance between centroid and all pool data.

    else:
        centroid_embedding = torch.mean(labeled_embeddings, dim=0).unsqueeze(0)
        dist_ctr = torch.cdist(unlabeled_embeddings, centroid_embedding, p=2)
        min_dist = torch.min(dist_ctr, dim=1)[0]

We find index of n points from pool data, which has max distance.

    idxs = []
    for i in range(n):
        idx = torch.argmax(min_dist)
        idxs.append(idx.item())
        dist_new_ctr = torch.cdist(unlabeled_embeddings, unlabeled_embeddings[[idx], :])
        min_dist = torch.minimum(min_dist, dist_new_ctr[:, 0])
    return idxs
  1. diversity.py: Since the acquisition function implemented in link takes (unlabeled_embeddings, labeled_embeddings, n) as parameters, I did same and modified diversity.py instead of using (features, pool_indices, context_indices) suggested in diversity.py of sustainability-lab/ASTRA
  2. and 5. test_furthest.py and test_centroid: Used CIFAR10 to test. Here we want to pass features extractor of model instead of forward pass of model, so I implemented feature extractor as below:
# Define the model
net = CNN(32, 3, 3, [4, 8], [2, 3], 10).to(device)

def extract_features(net):
    def feature_extractor(input_tensor):
        # Initialize features with the input tensor
        features = input_tensor

        # Apply each layer, activation, and max-pooling
        for layer in net.feature_extractor:
            features = layer(features)
            features = net.activation(features)
            features = net.max_pool(features)
        features = net.flatten(features)
        return features

    return feature_extractor

# Create a feature extractor callable from the network
feature_extractor = extract_features(net)

This feature_extractor gives us features ie. embedding of input.

# example: this snippet is not included in code
# input shape: (data_dim, height, width, channels)
input = input.permute(0, 3, 1, 2) #input shape: (data_dim, channels, height, width)
features = feature_extractor(input) # shape (data_dim, feature_dim)

We then pass this feature_extractor in strategy.query() which gives best_indices based on furthest or centroid acquisition provided.

# Query the strategy
best_indices = strategy.query(
    feature_extractor, pool_indices, train_indices, n_query_samples=n_query_samples
)

@patel-zeel
Copy link
Member

@jaiswalsuraj487 Now that our plan is broadened, let's not use distil library. Use your own implementation. Can you visually show if your acquisition is picking the correct points?

@jaiswalsuraj487
Copy link
Author

@patel-zeel I have made the required changes as per the current version of sustainability-lab:main and added sandbox/diveristy_acquisition_demo.ipynb to show a visual of selected data points using corresponding acquisition functions on dummy data.

@jaiswalsuraj487
Copy link
Author

@patel-zeel Added AL notebook for diversity acquisitions notebooks/al/diversity_acq_AL.ipynb

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants