Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
}
]
}
93 changes: 93 additions & 0 deletions development-code/bitbirchX/.idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 35 additions & 2 deletions development-code/bitbirch_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@
### Joel Nothman <joel.nothman@gmail.com>
### License: BSD 3 clause

import time
import numpy as np
from scipy import sparse
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.metrics import pairwise_distances_argmin

def jt_distances(X):
"""Calculates the matrix of Tanimoto distances
Expand Down Expand Up @@ -439,6 +442,9 @@ class BitBirch():
subcluster_labels_ : ndarray
Labels assigned to the centroids of the subclusters after
they are clustered globally.

labels_ : ndarray of shape (n_samples,)
Array of labels assigned to the input data.

Notes
-----
Expand All @@ -464,6 +470,8 @@ def __init__(
n_clusters=3,
compute_labels=True,
copy=True,
perform_clustering=False,
clustering_type=""
):
self.threshold = threshold
self.branching_factor = branching_factor
Expand All @@ -472,6 +480,8 @@ def __init__(
self.copy = copy
self.index_tracker = 0
self.first_call = True
self.perform_clustering = perform_clustering
self.clustering_type = clustering_type

def fit(self, X, y=None):
"""
Expand Down Expand Up @@ -557,8 +567,8 @@ def _fit(self, X, partial):
self.subcluster_centers_ = centroids
self._n_features_out = self.subcluster_centers_.shape[0]

# TODO: Incorporate global_clustering option
#self._global_clustering(X)
if(self.perform_clustering):
self._global_clustering(X)
self.first_call = False
return self

Expand All @@ -581,3 +591,26 @@ def _get_leaves(self):
def retrieveVal(self):
    """Emit a single empty line on standard output.

    The method takes no arguments besides ``self``, reads no instance
    state and returns ``None``.
    """
    print("")

def _global_clustering(self, X):
    """
    Global clustering for the subclusters obtained after fitting.

    Groups the subcluster centroids stored in ``self.subcluster_centers_``
    into ``self.n_clusters`` clusters and stores one label per centroid in
    ``self.subcluster_labels_``.

    Parameters
    ----------
    X : array-like or None
        Input data. When it is not None and ``self.compute_labels`` is
        truthy, every sample is assigned the label of its nearest
        centroid and the result is stored in ``self.labels_``.

    Notes
    -----
    * ``self.clustering_type`` selects the algorithm: ``"kmeans"`` or
      ``"hierarchical"``. Any other value, or a non-integer
      ``n_clusters`` (e.g. None), skips the global step: each subcluster
      keeps its own label and ``labels_`` is not assigned.
    * If fewer subclusters than ``n_clusters`` were found, the global
      step cannot run; a warning is issued and each subcluster keeps its
      own label instead of letting the clusterer raise.
    """
    import warnings
    from numbers import Integral

    n_clusters = self.n_clusters
    centroids = self.subcluster_centers_
    clustering_type = self.clustering_type
    compute_labels = (X is not None) and self.compute_labels

    # Integral (rather than int) also accepts NumPy integer scalars.
    if clustering_type not in ("kmeans", "hierarchical") or not isinstance(
        n_clusters, Integral
    ):
        # n_clusters is None and/or clustering_type == "" (skip global clustering)
        self.subcluster_labels_ = np.arange(len(centroids))
        return

    if len(centroids) < n_clusters:
        # Both clusterers reject n_clusters larger than the number of
        # points they are given, so keep one label per subcluster.
        warnings.warn(
            "Number of subclusters found (%d) by BitBirch is less than "
            "n_clusters (%d). Skipping the global clustering step."
            % (len(centroids), n_clusters)
        )
        self.subcluster_labels_ = np.arange(len(centroids))
    else:
        if clustering_type == "kmeans":
            clusterer = KMeans(n_clusters=int(n_clusters))
        else:
            clusterer = AgglomerativeClustering(n_clusters=int(n_clusters))
        self.subcluster_labels_ = clusterer.fit_predict(centroids)

    if compute_labels:
        # NOTE(review): pairwise_distances_argmin is called with its default
        # metric; confirm that is the intended notion of "nearest centroid"
        # for this data.
        argmin = pairwise_distances_argmin(X, centroids)
        self.labels_ = self.subcluster_labels_[argmin]
2 changes: 2 additions & 0 deletions jt_fit_label.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
1000 0.386942
2000 0.20582