diff --git a/scripts/embedding_utils.py b/scripts/embedding_utils.py index 1d363321..daa323d4 100644 --- a/scripts/embedding_utils.py +++ b/scripts/embedding_utils.py @@ -57,9 +57,21 @@ def save_embeddings_from_structures( for mlip_name in mlip_names: graph_embeddings = [] node_embeddings = [] + if f"e_above_hull_{mlip_name}" in structures[0].properties.keys(): + stability = [] for structure in structures: # Extract graph embeddings + try: + if structure.properties[f"e_above_hull_{mlip_name}"] == 0.0: + stability.append([0,0,1]) + if 0.0 < structure.properties[f"e_above_hull_{mlip_name}"] <= 0.1: + stability.append([0,1,0]) + else: + stability.append([1,0,0]) + except: + pass + graph_emb_key = f"graph_embedding_{mlip_name}" if ( graph_emb_key in structure.properties @@ -94,7 +106,13 @@ def save_embeddings_from_structures( logger.info( f"Extracted {len(node_embeddings)} node embeddings for {mlip_name}" ) - + if stability: + embeddings[f"{mlip_name}_stability"] = np.array(stability) + if logger: + logger.info( + f"Extracted {len(stability)} stability for {mlip_name}" + ) + if embeddings: # Create filename with timestamp timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") @@ -184,60 +202,71 @@ def generate_embedding_plots( methods = ["pca", "tsne", "umap"] for emb_name, emb_array in embeddings.items(): - if logger: - logger.info(f"📊 Creating plots for {emb_name} ({emb_array.shape})") - - for method in methods: - try: - # Perform dimensionality reduction - if method == "pca": - reducer = PCA(n_components=2, random_state=42) - elif method == "tsne": - reducer = TSNE( - n_components=2, - random_state=42, - perplexity=min(30, len(emb_array) // 4), - ) - elif method == "umap": - # Reasonable defaults; n_neighbors scales mildly with dataset size - n_neighbors = max(5, min(30, len(emb_array) // 20)) - reducer = UMAP( - n_components=2, - random_state=42, - n_neighbors=n_neighbors, - min_dist=0.1, - metric="euclidean", - ) - else: - raise ValueError( - f"Unknown dimensionality reduction method: {method}" + if "stability" not in emb_name: + if logger: + logger.info(f"📊 Creating plots for {emb_name} ({emb_array.shape})") + + for method in methods: + try: + # Perform dimensionality reduction + if method == "pca": + reducer = PCA(n_components=2, random_state=42) + elif method == "tsne": + reducer = TSNE( + n_components=2, + random_state=42, + perplexity=min(30, len(emb_array) // 4), + ) + elif method == "umap": + # Reasonable defaults; n_neighbors scales mildly with dataset size + n_neighbors = max(5, min(30, len(emb_array) // 20)) + reducer = UMAP( + n_components=2, + random_state=42, + n_neighbors=n_neighbors, + min_dist=0.1, + metric="euclidean", + ) + else: + raise ValueError( + f"Unknown dimensionality reduction method: {method}" + ) + reduced = reducer.fit_transform(emb_array) + mlip_name = emb_name.split("_")[0] + + # Create plot + labels = ["Stable", "Metastable", "Unstable"] + colors = [[0,0,1], [0,1,0], [1,0,0]] + + handles = [plt.Line2D([], [], color=c, marker='o', linestyle='None', label=l) + for c, l in zip(colors, labels)] + + + color_map = embeddings[f"{mlip_name}_stability"] + + plt.figure(figsize=(10, 8)) + plt.scatter(reduced[:, 0], reduced[:, 1], alpha=0.6, s=30, c=color_map) + plt.legend(handles=handles, title="Stability") + plt.xlabel(f"{method.upper()} Component 1") + plt.ylabel(f"{method.upper()} Component 2") + plt.title( + f"{emb_name} Embeddings ({method.upper()})\n{len(emb_array)} structures" ) + plt.grid(True, alpha=0.3) - reduced = reducer.fit_transform(emb_array) - - # Create plot - plt.figure(figsize=(10, 8)) - plt.scatter(reduced[:, 0], reduced[:, 1], alpha=0.6, s=30) - plt.xlabel(f"{method.upper()} Component 1") - plt.ylabel(f"{method.upper()} Component 2") - plt.title( - f"{emb_name} Embeddings ({method.upper()})\n{len(emb_array)} structures" - ) - plt.grid(True, alpha=0.3) + # Save plot + plot_file = output_dir / f"{method}_{emb_name}_plot.png" + plt.savefig(plot_file, dpi=300, bbox_inches="tight") + plt.close() - # Save plot - plot_file = output_dir / f"{method}_{emb_name}_plot.png" - plt.savefig(plot_file, dpi=300, bbox_inches="tight") - plt.close() - - if logger: - logger.info(f"📊 Saved plot: {plot_file}") + if logger: + logger.info(f"📊 Saved plot: {plot_file}") - except Exception as e: - if logger: - logger.warning( - f"📊 Failed to create {method} plot for {emb_name}: {e}" - ) + except Exception as e: + if logger: + logger.warning( + f"📊 Failed to create {method} plot for {emb_name}: {e}" + ) - if logger: - logger.info(f"📊 All plots saved to: {output_dir}") + if logger: + logger.info(f"📊 All plots saved to: {output_dir}")