"""
Interactive pairwise-comparison heatmap for binding / perturbation datasets.

This Shiny module (UI + server) draws a lower-triangular heatmap whose cells hold
either the number of regulators shared by two datasets or a median rank
correlation, and a side panel listing the regulators behind the clicked cell.

"""

from logging import Logger
from typing import Literal

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from shiny import Inputs, Outputs, Session, module, reactive, render, ui
from shinywidgets import output_widget, render_widget

from ..utils.source_name_lookup import BindingSource, PerturbationSource


def _regulators_by_display_name(
    df: pd.DataFrame, source_name_dict: dict[str, str]
) -> dict[str, set[str]]:
    """
    Collect the unique regulator symbols of every configured source.

    :param df: Metadata with ``source_name`` and ``regulator_symbol`` columns
    :param source_name_dict: Dictionary mapping source names to display names
    :return: Dictionary mapping each display name to its set of regulator symbols

    """
    regulators_by_name: dict[str, set[str]] = {}
    for source, display_name in source_name_dict.items():
        symbols = df.loc[df["source_name"] == source, "regulator_symbol"]
        # Missing symbols are dropped: a NaN inside the set would make the
        # later ``sorted()`` calls fail when comparing float with str.
        regulators_by_name[display_name] = set(symbols.dropna().unique())
    return regulators_by_name


def _ordered_sources(
    source_regulators: dict[str, set[str]], data_type: str
) -> list[str]:
    """
    Return the dataset names in display order (rows top-down, columns left-right).

    Names listed by the matching enum come first, in enum order; every other
    name follows alphabetically.

    :param source_regulators: Dictionary mapping dataset names to regulator sets
    :param data_type: "binding", "perturbation" or "perturbation_response"
    :return: Ordered list of the keys of ``source_regulators``

    """
    # NOTE(review): this assumes the enum ``.value``s equal the display names
    # used as keys of ``source_regulators`` -- confirm against
    # ``source_name_lookup``. Names that do not match simply fall through to
    # the alphabetical tail.
    if data_type == "binding":
        preferred = [member.value for member in BindingSource]
    elif data_type in ("perturbation", "perturbation_response"):
        preferred = [member.value for member in PerturbationSource]
    else:
        preferred = []  # Unknown data type: purely alphabetical order

    ordered = [name for name in preferred if name in source_regulators]
    already_placed = set(ordered)
    ordered.extend(
        sorted(name for name in source_regulators if name not in already_placed)
    )
    return ordered


def _median_correlation(
    source1: str,
    source2: str,
    source_regulators: dict[str, set[str]],
    corr_data: pd.DataFrame | None,
) -> float:
    """
    Calculate the median correlation between two sources.

    :param source1: Name of the first dataset
    :param source2: Name of the second dataset
    :param source_regulators: Dictionary mapping dataset names to regulator sets
    :param corr_data: Correlation data, or None when unavailable
    :return: Median correlation; currently always 0.0 (see TODO)

    """
    common_regulators = source_regulators[source1] & source_regulators[source2]
    if not common_regulators or corr_data is None:
        return 0.0

    # TODO: this is still a placeholder. Fill ``correlations`` with the
    # per-regulator correlation between the two sources once the structure of
    # ``corr_data`` is settled; until then the function returns 0.0.
    correlations: list[float] = []
    return float(np.median(correlations)) if correlations else 0.0


def _build_comparison_matrix(
    sources: list[str],
    source_regulators: dict[str, set[str]],
    comparison_type: str,
    corr_data: pd.DataFrame | None,
) -> pd.DataFrame:
    """
    Build the lower-triangular comparison matrix.

    The diagonal holds the regulator count of a single dataset (or 1.0 for
    correlations), the lower triangle holds the pairwise value and the upper
    triangle is NaN.

    :param sources: Dataset names, in display order
    :param source_regulators: Dictionary mapping dataset names to regulator sets
    :param comparison_type: "regulators" or "correlation"
    :param corr_data: Correlation data handed to the correlation placeholder
    :return: Square DataFrame indexed and labelled by ``sources``

    """
    by_regulators = comparison_type == "regulators"
    size = len(sources)
    matrix = np.full((size, size), np.nan)

    for i, row_source in enumerate(sources):
        matrix[i, i] = len(source_regulators[row_source]) if by_regulators else 1.0
        for j, col_source in enumerate(sources[:i]):
            if by_regulators:
                matrix[i, j] = len(
                    source_regulators[row_source] & source_regulators[col_source]
                )
            else:
                matrix[i, j] = _median_correlation(
                    row_source, col_source, source_regulators, corr_data
                )

    return pd.DataFrame(matrix, index=sources, columns=sources)


def _heatmap_cells(
    matrix_df: pd.DataFrame, comparison_type: str
) -> tuple[list[list[float | None]], list[list[str]]]:
    """
    Convert the comparison matrix into JSON-safe ``z`` values and cell labels.

    :param matrix_df: Square comparison matrix
    :param comparison_type: "regulators" (integer labels) or anything else
        (labels with two decimals)
    :return: Tuple ``(z, text)`` of nested lists; NaN cells become ``None`` / ""

    """
    values = matrix_df.to_numpy(dtype=float, copy=True)
    # Only the diagonal and the lower triangle are ever displayed.
    values[np.triu_indices_from(values, k=1)] = np.nan

    as_counts = comparison_type == "regulators"
    z_rows: list[list[float | None]] = []
    text_rows: list[list[str]] = []
    for row in values:
        z_row: list[float | None] = []
        text_row: list[str] = []
        for val in row:
            if np.isnan(val):
                # None (not NaN) keeps the payload valid JSON
                z_row.append(None)
                text_row.append("")
            else:
                z_row.append(float(val))
                text_row.append(f"{int(val)}" if as_counts else f"{val:.2f}")
        z_rows.append(z_row)
        text_rows.append(text_row)
    return z_rows, text_rows


def _regulator_list(items: list[str]) -> ui.Tag:
    """
    Render the items as a simple bordered list.

    :param items: Labels to display, already in display order
    :return: UI tag holding the list, or a short notice when ``items`` is empty

    """
    if not items:
        return ui.p("No items to display")

    # For now a plain list; could be enhanced with actual search functionality
    return ui.div(
        ui.tags.ul(
            *(ui.tags.li(item) for item in items),
            style="list-style-type: none; padding-left: 0;",
        ),
        style="border: 1px solid #ddd; border-radius: 4px; padding: 10px; "
        "background-color: #f9f9f9;",
    )


@module.ui
def heatmap_comparison_ui():
    """
    UI for the heatmap comparison module.

    :return: A div with the heatmap card (title, plot, comparison-type radio
        buttons) next to a "Selection Details" card

    """
    heatmap_card = ui.card(
        ui.card_header(
            ui.div(
                ui.output_text("heatmap_title"),
                style="font-weight: bold; font-size: 1.1em;",
            )
        ),
        ui.card_body(output_widget("comparison_heatmap"), style="height: 500px;"),
        ui.card_footer(
            ui.div(
                ui.input_radio_buttons(
                    "comparison_type",
                    label="Comparison Type:",
                    choices={
                        "regulators": "Number of Common Regulators",
                        "correlation": "Median Rank Correlation",
                    },
                    selected="regulators",
                    inline=True,
                ),
                ui.p(
                    "Click on any cell to see detailed information. "
                    "Diagonal cells show single dataset statistics.",
                    style="margin: 5px 0 0 0; font-size: 0.9em; color:#666;",
                ),
            )
        ),
    )
    details_card = ui.card(
        ui.card_header("Selection Details"),
        ui.card_body(
            ui.output_ui("selection_details"),
            style="max-height: 500px; overflow-y: auto;",
        ),
    )
    return ui.div(
        ui.row(
            ui.column(8, heatmap_card),
            ui.column(4, details_card),
        )
    )


@module.server
def heatmap_comparison_server(
    input: Inputs,
    output: Outputs,
    session: Session,
    *,
    metadata_result: reactive.ExtendedTask,
    source_name_dict: dict[str, str],
    data_type: Literal["binding", "perturbation", "perturbation_response"],
    correlation_data: pd.DataFrame | None = None,
    logger: Logger,
) -> reactive.Value:
    """
    Server for interactive heatmap comparisons between datasets.

    :param metadata_result: Reactive extended task with metadata
    :param source_name_dict: Dictionary mapping source names to display names
    :param data_type: Type of data ("binding" or "perturbation_response";
        "perturbation" is accepted as an alias of the latter)
    :param correlation_data: DataFrame with correlation data for correlation
        comparisons
    :param logger: Logger object
    :return: Reactive value with selected cell information: None until a cell is
        clicked, afterwards a dict with the keys "x", "y", "value" and
        "comparison_type"

    """

    # Store selected cell information
    selected_cell = reactive.Value(None)

    @reactive.calc
    def processed_metadata() -> dict[str, set[str]]:
        """Process metadata and organize the regulators by source."""
        df = metadata_result.result()
        if df is None or df.empty:
            return {}
        return _regulators_by_display_name(df, source_name_dict)

    @reactive.calc
    def comparison_matrix() -> pd.DataFrame:
        """Create the comparison matrix for the selected comparison type."""
        source_regulators = processed_metadata()
        comparison_type = input.comparison_type()

        if not source_regulators:
            return pd.DataFrame()

        return _build_comparison_matrix(
            _ordered_sources(source_regulators, data_type),
            source_regulators,
            comparison_type,
            correlation_data,
        )

    @render.text
    def heatmap_title():
        data_name = data_type.replace("_", " ").title()
        if input.comparison_type() == "regulators":
            return f"{data_name} Dataset Overlap: Number of Common Regulators"
        return f"{data_name} Dataset Correlation: Median Rank Correlations"

    @render_widget
    def comparison_heatmap():
        matrix_df = comparison_matrix()
        comparison_type = input.comparison_type()

        if matrix_df.empty:
            return px.scatter(title="No data available")

        z_values, cell_text = _heatmap_cells(matrix_df, comparison_type)

        heatmap = go.Heatmap(
            z=z_values,
            x=matrix_df.columns.tolist(),
            y=matrix_df.index.tolist(),
            colorscale="Blues",
            showscale=True,
            hoverongaps=False,
            text=cell_text,
            texttemplate="%{text}",
            textfont={"size": 12},
            # NOTE(review): in the reviewed text this literal was broken by a
            # raw line break (markup apparently stripped); reconstructed with
            # plotly's <br> / <extra></extra> markup -- confirm the wording.
            hovertemplate="%{y} vs %{x}<br>Value: %{z}<extra></extra>",
            connectgaps=False,
        )

        # A FigureWidget (rather than a plain Figure) is needed so that a
        # Python-side click callback can be attached to the trace.
        fig = go.FigureWidget(data=[heatmap])
        fig.update_layout(
            xaxis_title="Dataset",
            yaxis_title="Dataset",
            height=450,
            margin=dict(l=100, r=50, t=50, b=100),
            xaxis=dict(
                tickangle=45,
                side="bottom",
                autorange=True,
            ),
            yaxis=dict(
                autorange="reversed",  # Rows run top to bottom in index order
            ),
            plot_bgcolor="rgba(0,0,0,0)",
            paper_bgcolor="white",
        )

        def _handle_click(trace, points, state):
            """Record the clicked cell; clicks on empty cells are ignored."""
            if not points.xs or not points.ys:
                return
            x_label, y_label = points.xs[0], points.ys[0]
            try:
                value = matrix_df.loc[y_label, x_label]
            except KeyError:
                return
            if pd.isna(value):
                return  # Upper-triangle gap
            selected_cell.set(
                {
                    "x": x_label,
                    "y": y_label,
                    "value": float(value),
                    "comparison_type": comparison_type,
                }
            )

        fig.data[0].on_click(_handle_click)

        return fig

    @render.ui
    def selection_details():
        """Generate the details panel content based on the selection."""
        cell_info = selected_cell.get()
        logger.info("Selection details: %s", cell_info)
        source_regulators = processed_metadata()

        # No selection yet, or a selection naming a dataset that is no longer
        # part of the metadata: show the overview of all regulators.
        if (
            not cell_info
            or cell_info["x"] not in source_regulators
            or cell_info["y"] not in source_regulators
        ):
            all_regulators = set().union(*source_regulators.values())
            return ui.div(
                ui.h5("All Regulators"),
                ui.p(f"Total: {len(all_regulators)} regulators"),
                ui.div(
                    _regulator_list(sorted(all_regulators)),
                    style="max-height: 300px; overflow-y: auto;",
                ),
            )

        x_source = cell_info["x"]
        y_source = cell_info["y"]

        if x_source == y_source:
            # Diagonal cell - single dataset
            regulators = sorted(source_regulators[x_source])
            return ui.div(
                ui.h5(f"{x_source} Dataset"),
                ui.p(f"Total regulators: {len(regulators)}"),
                _regulator_list(regulators),
            )

        # Off-diagonal cell - comparison of two datasets
        common_regs = sorted(source_regulators[x_source] & source_regulators[y_source])

        if cell_info["comparison_type"] == "regulators":
            return ui.div(
                ui.h5(f"{y_source} ∩ {x_source}"),
                ui.p(f"Common regulators: {len(common_regs)}"),
                _regulator_list(common_regs),
            )

        # For correlation, show histogram (placeholder)
        return ui.div(
            ui.h5(f"Correlation: {y_source} vs {x_source}"),
            ui.p(f"Median correlation: {cell_info['value']:.3f}"),
            ui.p("Distribution histogram would go here"),
            _regulator_list(common_regs),
        )

    return selected_cell
The current binding datasets are: " ), ui.div( @@ -107,46 +110,43 @@ def binding_ui(): ), id="binding-description", ), - # Second row: Plot area container + # Second row: Heatmap comparison ui.div( - # Left: UpSet plot - ui.div( - ui.card( - ui.card_header("Binding UpSet Plot"), - ui.card_body(upset_plot_ui("binding_upset")), - ui.card_footer( - ui.p( - "Click any one of the sets to show what proportion of " - "the regulators in the selected set are also present " - "in the other sets.", - class_="text-muted", - ), + ui.card( + ui.card_header("Binding Dataset Comparisons"), + ui.card_body( + heatmap_comparison_ui("binding_heatmap"), + ), + ui.card_footer( + ui.p( + "Interactive heatmap showing pairwise comparisons between " + "binding datasets. Click cells to explore common regulators " + "or correlation distributions.", + class_="text-muted", ), ), - id="binding-upset-container", ), - # Right: Correlation matrix - ui.div( - ui.card( - ui.card_header("Binding Correlation Matrix"), - ui.card_body( - ui.div( - correlation_matrix_ui("binding_corr_matrix"), - id="binding-corr-plot-wrapper", - ), - id="binding-corr-body", + style="margin-bottom: 2rem;", + ), + # Third row: Correlation matrix + ui.div( + ui.card( + ui.card_header("Binding Correlation Matrix"), + ui.card_body( + ui.div( + correlation_matrix_ui("binding_corr_matrix"), + style="display: flex; justify-content: center; " + "align-items: center; height: 500px;", ), - ui.card_footer( - ui.p( - "Click and drag to zoom in on a specific region of the " - "correlation matrix. Double click to reset the zoom.", - class_="text-muted", - ), + ), + ui.card_footer( + ui.p( + "Click and drag to zoom in on a specific region of the " + "correlation matrix. 
Double click to reset the zoom.", + class_="text-muted", ), ), - id="binding-corr-container", ), - id="binding-plot-row", ), # Add styles at the bottom ui.tags.style( @@ -164,15 +164,6 @@ def binding_ui(): gap: 2rem; } - #binding-upset-container { - flex: 1.2; - min-width: 0; - min-height: 500px; - display: flex; - flex-direction: column; - height: 100%; - } - #binding-corr-container { flex: 0 0 500px; display: flex; @@ -222,33 +213,38 @@ def binding_server( *, binding_metadata_task: reactive.ExtendedTask, logger: Logger, -) -> reactive.calc: +) -> reactive.Value: """ This function produces the reactive/render functions necessary to producing the - binding upset plot and correlation matrix. + binding heatmap comparison and correlation matrix. :param binding_metadata_task: This is the result from a reactive.extended_task. Result can be retrieved with .result() :param logger: A logger object - :return: A reactive.calc with the metadata filtered for the selected upset plot sets - (note that this is not currently working b/c of something to do with the upset - plot server) + :return: A reactive.Value with the selected cell information from the heatmap """ - selected_binding_sets = upset_plot_server( - "binding_upset", - metadata_result=binding_metadata_task, - source_name_dict=get_source_name_dict("binding"), - logger=logger, - ) # TODO: retrieving the predictors should be from the db as a reactive.extended_task tf_binding_df = pd.read_csv("tmp/shiny_data/cc_predictors_normalized.csv") tf_binding_df.set_index("target_symbol", inplace=True) + + # Set up correlation matrix correlation_matrix_server( "binding_corr_matrix", tf_binding_df=tf_binding_df, logger=logger, ) - return selected_binding_sets + # Set up heatmap comparison + selected_cell = heatmap_comparison_server( + "binding_heatmap", + metadata_result=binding_metadata_task, + source_name_dict=get_source_name_dict("binding"), + data_type="binding", + correlation_data=tf_binding_df, # Pass correlation data for 
correlation + # comparisons + logger=logger, + ) + + return selected_cell diff --git a/tfbpshiny/tabs/perturbation_response_module.py b/tfbpshiny/tabs/perturbation_response_module.py index 416c014..18df43f 100644 --- a/tfbpshiny/tabs/perturbation_response_module.py +++ b/tfbpshiny/tabs/perturbation_response_module.py @@ -3,7 +3,10 @@ import pandas as pd from shiny import Inputs, Outputs, Session, module, reactive, ui -from ..misc.binding_perturbation_upset_module import upset_plot_server, upset_plot_ui +from ..misc.binding_perturbation_heatmap_module import ( + heatmap_comparison_server, + heatmap_comparison_ui, +) from ..misc.correlation_plot_module import ( correlation_matrix_server, correlation_matrix_ui, @@ -17,7 +20,7 @@ def perturbation_response_ui(): # First row: Description ui.div( ui.p( - "This page displays the UpSet plot and correlation matrix for " + "This page displays pairwise comparisons and correlation matrix for " "TF perturbation response datasets. The current datasets include " "data derived from gene deletions and overexpression methods." ), @@ -121,46 +124,43 @@ def perturbation_response_ui(): ), id="perturbation-description", ), - # Second row: Plot area container + # Second row: Heatmap comparison ui.div( - # Left: UpSet plot - ui.div( - ui.card( - ui.card_header("Perturbation Response UpSet Plot"), - ui.card_body(upset_plot_ui("perturbation_response_upset")), - ui.card_footer( - ui.p( - "Click any one of the sets to show what proportion of " - "the regulators in the selected set are also present " - "in the other sets.", - class_="text-muted", - ), + ui.card( + ui.card_header("Perturbation Response Dataset Comparisons"), + ui.card_body( + heatmap_comparison_ui("perturbation_heatmap"), + ), + ui.card_footer( + ui.p( + "Interactive heatmap showing pairwise comparisons between " + "perturbation response datasets. 
Click cells to explore " + "common regulators or correlation distributions.", + class_="text-muted", ), ), - id="perturbation-upset-container", ), - # Right: Correlation matrix - ui.div( - ui.card( - ui.card_header("Perturbation Response Correlation Matrix"), - ui.card_body( - ui.div( - correlation_matrix_ui("perturbation_corr_matrix"), - id="perturbation-corr-plot-wrapper", - ), - id="perturbation-corr-body", + style="margin-bottom: 2rem;", + ), + # Third row: Correlation matrix + ui.div( + ui.card( + ui.card_header("Perturbation Response Correlation Matrix"), + ui.card_body( + ui.div( + correlation_matrix_ui("perturbation_corr_matrix"), + style="display: flex; justify-content: center; " + "align-items: center; height: 500px;", ), - ui.card_footer( - ui.p( - "Click and drag to zoom in on a specific region of the " - "correlation matrix. Double click to reset the zoom.", - class_="text-muted", - ), + ), + ui.card_footer( + ui.p( + "Click and drag to zoom in on a specific region of the " + "correlation matrix. Double click to reset the zoom.", + class_="text-muted", ), ), - id="perturbation-corr-container", ), - id="perturbation-plot-row", ), # Add styles at the bottom ui.tags.style( @@ -178,15 +178,6 @@ def perturbation_response_ui(): gap: 2rem; } - #perturbation-upset-container { - flex: 1.2; - min-width: 0; - min-height: 500px; - display: flex; - flex-direction: column; - height: 100%; - } - #perturbation-corr-container { flex: 0 0 500px; display: flex; @@ -236,39 +227,42 @@ def perturbation_response_server( *, pr_metadata_task: reactive.ExtendedTask, logger: Logger, -) -> reactive.calc: +) -> reactive.Value: """ This function produces the reactive/render functions necessary to producing the - perturbation response upset plot and correlation matrix for the perturbation + perturbation response heatmap comparison and correlation matrix for the perturbation response data. :param pr_metadata_task: This is the result from a reactive.extended_task. 
Result can be retrieved with .result() :param logger: A logger object - :return: A reactive.calc with the metadata filtered for the selected upset plot sets - (note that this is not currently working b/c of something to do with the upset - plot server) + :return: A reactive.Value with the selected cell information from the heatmap """ # TODO: retrieving the response should be from the db as a reactive.extended_task tf_pr_df = pd.read_csv("tmp/shiny_data/response_data.csv") tf_pr_df.set_index("target_symbol", inplace=True) + + # Set up correlation matrix correlation_matrix_server( "perturbation_corr_matrix", tf_binding_df=tf_pr_df, logger=logger, ) - selected_pr_sets = upset_plot_server( - "perturbation_response_upset", + # Set up heatmap comparison + selected_cell = heatmap_comparison_server( + "perturbation_heatmap", metadata_result=pr_metadata_task, source_name_dict=get_source_name_dict("perturbation_response"), + data_type="perturbation_response", + correlation_data=tf_pr_df, # Pass correlation data for correlation comparisons logger=logger, ) @reactive.effect def _(): - logger.info(f"Selected perturbation response sets: {selected_pr_sets()}") + logger.info(f"Selected perturbation response cell: {selected_cell.get()}") - return selected_pr_sets + return selected_cell