Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 82 additions & 53 deletions scripts/embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,21 @@ def save_embeddings_from_structures(
for mlip_name in mlip_names:
graph_embeddings = []
node_embeddings = []
if f"e_above_hull_{mlip_name}" in structures[0].properties.keys():
stability = []

for structure in structures:
# Classify stability from e_above_hull, then extract graph embeddings
try:
if structure.properties[f"e_above_hull_{mlip_name}"] == 0.0:
stability.append([0,0,1])
if 0.0 < structure.properties[f"e_above_hull_{mlip_name}"] <= 0.1:
stability.append([0,1,0])
else:
stability.append([1,0,0])
except:
pass

graph_emb_key = f"graph_embedding_{mlip_name}"
if (
graph_emb_key in structure.properties
Expand Down Expand Up @@ -94,7 +106,13 @@ def save_embeddings_from_structures(
logger.info(
f"Extracted {len(node_embeddings)} node embeddings for {mlip_name}"
)

if stability:
embeddings[f"{mlip_name}_stability"] = np.array(stability)
if logger:
logger.info(
f"Extracted {len(stability)} stability for {mlip_name}"
)

if embeddings:
# Create filename with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
Expand Down Expand Up @@ -184,60 +202,71 @@ def generate_embedding_plots(
methods = ["pca", "tsne", "umap"]

for emb_name, emb_array in embeddings.items():
if logger:
logger.info(f"📊 Creating plots for {emb_name} ({emb_array.shape})")

for method in methods:
try:
# Perform dimensionality reduction
if method == "pca":
reducer = PCA(n_components=2, random_state=42)
elif method == "tsne":
reducer = TSNE(
n_components=2,
random_state=42,
perplexity=min(30, len(emb_array) // 4),
)
elif method == "umap":
# Reasonable defaults; n_neighbors scales mildly with dataset size
n_neighbors = max(5, min(30, len(emb_array) // 20))
reducer = UMAP(
n_components=2,
random_state=42,
n_neighbors=n_neighbors,
min_dist=0.1,
metric="euclidean",
)
else:
raise ValueError(
f"Unknown dimensionality reduction method: {method}"
if "stability" not in emb_name:
if logger:
logger.info(f"📊 Creating plots for {emb_name} ({emb_array.shape})")

for method in methods:
try:
# Perform dimensionality reduction
if method == "pca":
reducer = PCA(n_components=2, random_state=42)
elif method == "tsne":
reducer = TSNE(
n_components=2,
random_state=42,
perplexity=min(30, len(emb_array) // 4),
)
elif method == "umap":
# Reasonable defaults; n_neighbors scales mildly with dataset size
n_neighbors = max(5, min(30, len(emb_array) // 20))
reducer = UMAP(
n_components=2,
random_state=42,
n_neighbors=n_neighbors,
min_dist=0.1,
metric="euclidean",
)
else:
raise ValueError(
f"Unknown dimensionality reduction method: {method}"
)
reduced = reducer.fit_transform(emb_array)
mlip_name = emb_name.split("_")[0]

# Create plot
labels = ["Stable", "Metastable", "Unstable"]
colors = [[0,0,1], [0,1,0], [1,0,0]]

handles = [plt.Line2D([], [], color=c, marker='o', linestyle='None', label=l)
for c, l in zip(colors, labels)]


color_map = embeddings[f"{mlip_name}_stability"]

plt.figure(figsize=(10, 8))
plt.scatter(reduced[:, 0], reduced[:, 1], alpha=0.6, s=30, c=color_map)
plt.legend(handles=handles, title="Stability")
plt.xlabel(f"{method.upper()} Component 1")
plt.ylabel(f"{method.upper()} Component 2")
plt.title(
f"{emb_name} Embeddings ({method.upper()})\n{len(emb_array)} structures"
)
plt.grid(True, alpha=0.3)

reduced = reducer.fit_transform(emb_array)

# Create plot
plt.figure(figsize=(10, 8))
plt.scatter(reduced[:, 0], reduced[:, 1], alpha=0.6, s=30)
plt.xlabel(f"{method.upper()} Component 1")
plt.ylabel(f"{method.upper()} Component 2")
plt.title(
f"{emb_name} Embeddings ({method.upper()})\n{len(emb_array)} structures"
)
plt.grid(True, alpha=0.3)
# Save plot
plot_file = output_dir / f"{method}_{emb_name}_plot.png"
plt.savefig(plot_file, dpi=300, bbox_inches="tight")
plt.close()

# Save plot
plot_file = output_dir / f"{method}_{emb_name}_plot.png"
plt.savefig(plot_file, dpi=300, bbox_inches="tight")
plt.close()

if logger:
logger.info(f"📊 Saved plot: {plot_file}")
if logger:
logger.info(f"📊 Saved plot: {plot_file}")

except Exception as e:
if logger:
logger.warning(
f"📊 Failed to create {method} plot for {emb_name}: {e}"
)
except Exception as e:
if logger:
logger.warning(
f"📊 Failed to create {method} plot for {emb_name}: {e}"
)

if logger:
logger.info(f"📊 All plots saved to: {output_dir}")
if logger:
logger.info(f"📊 All plots saved to: {output_dir}")
Loading