Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions squeeze/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,8 @@ def spearman_distance_correlation(
The embedded data.

sample_size : int, optional
If provided, sample this many point pairs to compute correlation.
Useful for large datasets to speed up computation.
If provided, subsample this many points (rows) before computing the
correlation. Useful for large datasets to speed up computation.

Returns
-------
Expand All @@ -327,6 +327,17 @@ def spearman_distance_correlation(
"""
from scipy.stats import spearmanr

X = np.asarray(X)
X_embedded = np.asarray(X_embedded)

# Subsample points (not pairs) for large datasets. This makes the estimate
# significantly more stable than sampling arbitrary point-pairs.
n_samples = X.shape[0]
if sample_size is not None and n_samples > sample_size:
indices = np.random.choice(n_samples, size=sample_size, replace=False)
X = X[indices]
X_embedded = X_embedded[indices]

# Compute pairwise distances
D_original = pairwise_distances(X, metric="euclidean")
D_embedded = pairwise_distances(X_embedded, metric="euclidean")
Expand All @@ -338,14 +349,6 @@ def spearman_distance_correlation(
original_distances = D_original[indices]
embedded_distances = D_embedded[indices]

# If dataset is large, sample pairs
if sample_size is not None and len(original_distances) > sample_size:
sample_indices = np.random.choice(
len(original_distances), size=sample_size, replace=False
)
original_distances = original_distances[sample_indices]
embedded_distances = embedded_distances[sample_indices]

# Compute Spearman correlation
correlation, _ = spearmanr(original_distances, embedded_distances)

Expand Down
Loading