diff --git a/tfbpshiny/misc/binding_perturbation_heatmap_module.py b/tfbpshiny/misc/binding_perturbation_heatmap_module.py
new file mode 100644
index 0000000..1a89b36
--- /dev/null
+++ b/tfbpshiny/misc/binding_perturbation_heatmap_module.py
@@ -0,0 +1,346 @@
+from logging import Logger
+from typing import Literal
+
+import numpy as np
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
+from shiny import Inputs, Outputs, Session, module, reactive, render, ui
+from shinywidgets import output_widget, render_widget
+
+from ..utils.source_name_lookup import BindingSource, PerturbationSource
+
+
+@module.ui
+def heatmap_comparison_ui():
+ return ui.div(
+ ui.row(
+ ui.column(
+ 8,
+ ui.card(
+ ui.card_header(
+ ui.div(
+ ui.output_text("heatmap_title"),
+ style="font-weight: bold; font-size: 1.1em;",
+ )
+ ),
+ ui.card_body(
+ output_widget("comparison_heatmap"), style="height: 500px;"
+ ),
+ ui.card_footer(
+ ui.div(
+ ui.input_radio_buttons(
+ "comparison_type",
+ label="Comparison Type:",
+ choices={
+ "regulators": "Number of Common Regulators",
+ "correlation": "Median Rank Correlation",
+ },
+ selected="regulators",
+ inline=True,
+ ),
+ ui.p(
+ "Click on any cell to see detailed information. "
+ "Diagonal cells show single dataset statistics.",
+ style="margin: 5px 0 0 0; font-size: 0.9em; color:"
+ "#666;",
+ ),
+ )
+ ),
+ ),
+ ),
+ ui.column(
+ 4,
+ ui.card(
+ ui.card_header("Selection Details"),
+ ui.card_body(
+ ui.output_ui("selection_details"),
+ style="max-height: 500px; overflow-y: auto;",
+ ),
+ ),
+ ),
+ )
+ )
+
+
+@module.server
+def heatmap_comparison_server(
+ input: Inputs,
+ output: Outputs,
+ session: Session,
+ *,
+ metadata_result: reactive.ExtendedTask,
+ source_name_dict: dict[str, str],
+ data_type: Literal["binding", "perturbation"],
+ correlation_data: pd.DataFrame | None = None,
+ logger: Logger,
+) -> reactive.Value:
+ """
+ Server for interactive heatmap comparisons between datasets.
+
+ :param metadata_result: Reactive extended task with metadata
+ :param source_name_dict: Dictionary mapping source names to display names
+ :param data_type: Type of data ("binding" or "perturbation_response")
+ :param correlation_data: DataFrame with correlation data for correlation comparisons
+ :param logger: Logger object
+ :return: Reactive value with selected cell information
+
+ """
+
+ # Store selected cell information
+ selected_cell = reactive.Value(None)
+
+ @reactive.calc
+ def processed_metadata():
+ """Process metadata and organize by source."""
+ df = metadata_result.result()
+ if df.empty:
+ return {}
+
+ # Filter for sources in our dict
+ df_filtered = df[df["source_name"].isin(source_name_dict.keys())]
+
+ # Group regulators by source
+ source_regulators = {}
+ for source, display_name in source_name_dict.items():
+ source_data = df_filtered[df_filtered["source_name"] == source]
+ regulators = set(source_data["regulator_symbol"].unique())
+ source_regulators[display_name] = regulators
+
+ return source_regulators
+
+ @reactive.calc
+ def comparison_matrix():
+ """Create comparison matrix based on selected type."""
+ source_regulators = processed_metadata()
+ comparison_type = input.comparison_type()
+
+ if not source_regulators:
+ return pd.DataFrame()
+
+ # Define explicit order for sources dynamically from enums
+ # This controls both row order (top to bottom) and column order (left to right)
+ if data_type == "binding":
+ desired_order = [source.value for source in BindingSource]
+ elif data_type == "perturbation_response":
+ desired_order = [source.value for source in PerturbationSource]
+ else:
+ desired_order = [] # Fallback to alphabetical if unknown data type
+
+ # Filter to only include sources that exist in the data and match desired order
+ sources = [source for source in desired_order if source in source_regulators]
+
+ # Add any remaining sources not in the desired order (as fallback)
+ remaining_sources = [
+ source for source in source_regulators.keys() if source not in sources
+ ]
+ sources.extend(sorted(remaining_sources))
+
+ n_sources = len(sources)
+
+ # Initialize matrix
+ matrix = np.zeros((n_sources, n_sources))
+
+ # Fill the matrix based on whether the comparison is by regulators
+ # or correlation
+ for i, source1 in enumerate(sources):
+ for j, source2 in enumerate(sources):
+ if i == j:
+ # Diagonal: total regulators in single dataset
+ matrix[i, j] = (
+ len(source_regulators[source1])
+ if comparison_type == "regulators"
+ else 1.0
+ )
+ elif i > j:
+ # Off-diagonal: common regulators (lower triangular)
+ matrix[i, j] = (
+ len(source_regulators[source1] & source_regulators[source2])
+ if comparison_type == "regulators"
+ else calculate_median_correlation(
+ source1, source2, source_regulators, correlation_data
+ )
+ )
+ else:
+ matrix[i, j] = np.nan
+
+ # Convert to DataFrame with explicit ordering
+ # index controls row order (top to bottom)
+ # columns controls column order (left to right)
+ df = pd.DataFrame(matrix, index=sources, columns=sources)
+
+ return df
+
+ def calculate_median_correlation(
+ source1: str, source2: str, source_regulators: dict, corr_data: pd.DataFrame
+ ) -> float:
+ """Calculate median correlation between two sources."""
+ common_regulators = source_regulators[source1] & source_regulators[source2]
+
+ if len(common_regulators) == 0 or corr_data is None:
+ return 0.0
+
+ # This is a placeholder - you'll need to implement based on your correlation
+ # data structure
+ # The correlation_data should contain pairwise correlations between
+ # regulators across sources
+ correlations: list[float] = []
+ for regulator in common_regulators:
+ # Extract correlation for this regulator between the two sources
+ # This depends on how your correlation data is structured
+ pass
+
+ return np.median(correlations) if correlations else 0.0
+
+ @render.text
+ def heatmap_title():
+ comp_type = input.comparison_type()
+ data_name = data_type.replace("_", " ").title()
+ if comp_type == "regulators":
+ return f"{data_name} Dataset Overlap: Number of Common Regulators"
+ else:
+ return f"{data_name} Dataset Correlation: Median Rank Correlations"
+
+ @render_widget
+ def comparison_heatmap():
+ matrix_df = comparison_matrix()
+ comparison_type = input.comparison_type()
+
+ if matrix_df.empty:
+ return px.scatter(title="No data available")
+
+ # Create mask for lower triangular + diagonal
+ mask = np.triu(np.ones_like(matrix_df.values, dtype=bool), k=1)
+ matrix_masked = matrix_df.values.copy()
+ matrix_masked[mask] = np.nan
+
+ # Convert to Python list and replace np.nan with None for JSON compliance
+ matrix_list: list[list[float | None]] = []
+ text_list: list[list[str]] = []
+ for row in matrix_masked:
+ matrix_row: list[float | None] = []
+ text_row = []
+ for val in row:
+ if np.isnan(val):
+ matrix_row.append(None)
+ text_row.append("")
+ else:
+ matrix_row.append(val)
+ text_row.append(
+ f"{int(val)}"
+ if comparison_type == "regulators"
+ else f"{val:.2f}"
+ )
+ matrix_list.append(matrix_row)
+ text_list.append(text_row)
+
+ # Create the heatmap
+ heatmap = go.Heatmap(
+ z=matrix_list,
+ x=matrix_df.columns.tolist(),
+ y=matrix_df.index.tolist(),
+ colorscale="Blues",
+ showscale=True,
+ hoverongaps=False,
+ text=text_list,
+ texttemplate="%{text}",
+ textfont={"size": 12},
+ hovertemplate="%{y} vs %{x}
Value: %{z}",
+ connectgaps=False,
+ )
+
+ # Create figure with just the heatmap
+ fig = go.Figure(data=[heatmap])
+
+ # Update layout
+ fig.update_layout(
+ xaxis_title="Dataset",
+ yaxis_title="Dataset",
+ height=450,
+ margin=dict(l=100, r=50, t=50, b=100),
+ xaxis=dict(
+ tickangle=45,
+ side="bottom",
+ autorange=True,
+ ),
+ yaxis=dict(
+ autorange="reversed", # This ensures rows go top to bottom correctly
+ ),
+ plot_bgcolor="rgba(0,0,0,0)",
+ paper_bgcolor="white",
+ )
+
+ return fig
+
+ @render.ui
+ def selection_details():
+ """Generate details panel content based on selection."""
+ cell_info = selected_cell.get()
+ logger.info(f"Selection details: {cell_info}")
+ source_regulators = processed_metadata()
+
+ if not cell_info:
+ # Show all regulators when no selection
+ all_regulators = set()
+ for regs in source_regulators.values():
+ all_regulators.update(regs)
+
+ return ui.div(
+ ui.h5("All Regulators"),
+ ui.p(f"Total: {len(all_regulators)} regulators"),
+ ui.div(
+ create_searchable_list(sorted(all_regulators)),
+ style="max-height: 300px; overflow-y: auto;",
+ ),
+ )
+
+ x_source = cell_info["x"]
+ y_source = cell_info["y"]
+ value = cell_info["value"]
+ comp_type = cell_info["comparison_type"]
+
+ if x_source == y_source:
+ # Diagonal cell - single dataset
+ regulators = sorted(source_regulators[x_source])
+ return ui.div(
+ ui.h5(f"{x_source} Dataset"),
+ ui.p(f"Total regulators: {len(regulators)}"),
+ create_searchable_list(regulators),
+ )
+ else:
+ # Off-diagonal cell - comparison
+ common_regs = sorted(
+ source_regulators[x_source] & source_regulators[y_source]
+ )
+
+ if comp_type == "regulators":
+ return ui.div(
+ ui.h5(f"{y_source} ∩ {x_source}"),
+ ui.p(f"Common regulators: {len(common_regs)}"),
+ create_searchable_list(common_regs),
+ )
+ else:
+ # For correlation, show histogram (placeholder)
+ return ui.div(
+ ui.h5(f"Correlation: {y_source} vs {x_source}"),
+ ui.p(f"Median correlation: {value:.3f}"),
+ ui.p("Distribution histogram would go here"),
+ create_searchable_list(common_regs),
+ )
+
+ def create_searchable_list(items: list[str]) -> ui.Tag:
+ """Create a searchable list of items."""
+ if not items:
+ return ui.p("No items to display")
+
+ # For now, create a simple scrollable list
+ # Could be enhanced with actual search functionality
+ list_items = [ui.tags.li(item) for item in items]
+
+ return ui.div(
+ ui.tags.ul(*list_items, style="list-style-type: none; padding-left: 0;"),
+ style="border: 1px solid #ddd; border-radius: 4px; padding: 10px; "
+ "background-color: #f9f9f9;",
+ )
+
+ return selected_cell
diff --git a/tfbpshiny/tabs/binding_module.py b/tfbpshiny/tabs/binding_module.py
index d3fcd9a..5144732 100644
--- a/tfbpshiny/tabs/binding_module.py
+++ b/tfbpshiny/tabs/binding_module.py
@@ -3,7 +3,10 @@
import pandas as pd
from shiny import Inputs, Outputs, Session, module, reactive, ui
-from ..misc.binding_perturbation_upset_module import upset_plot_server, upset_plot_ui
+from ..misc.binding_perturbation_heatmap_module import (
+ heatmap_comparison_server,
+ heatmap_comparison_ui,
+)
from ..misc.correlation_plot_module import (
correlation_matrix_server,
correlation_matrix_ui,
@@ -17,7 +20,7 @@ def binding_ui():
# First row: Description
ui.div(
ui.p(
- "This page displays the UpSet plot and correlation matrix for "
+ "This page displays pairwise comparisons and correlation matrix for "
"TF the binding datasets. The current binding datasets are: "
),
ui.div(
@@ -107,46 +110,43 @@ def binding_ui():
),
id="binding-description",
),
- # Second row: Plot area container
+ # Second row: Heatmap comparison
ui.div(
- # Left: UpSet plot
- ui.div(
- ui.card(
- ui.card_header("Binding UpSet Plot"),
- ui.card_body(upset_plot_ui("binding_upset")),
- ui.card_footer(
- ui.p(
- "Click any one of the sets to show what proportion of "
- "the regulators in the selected set are also present "
- "in the other sets.",
- class_="text-muted",
- ),
+ ui.card(
+ ui.card_header("Binding Dataset Comparisons"),
+ ui.card_body(
+ heatmap_comparison_ui("binding_heatmap"),
+ ),
+ ui.card_footer(
+ ui.p(
+ "Interactive heatmap showing pairwise comparisons between "
+ "binding datasets. Click cells to explore common regulators "
+ "or correlation distributions.",
+ class_="text-muted",
),
),
- id="binding-upset-container",
),
- # Right: Correlation matrix
- ui.div(
- ui.card(
- ui.card_header("Binding Correlation Matrix"),
- ui.card_body(
- ui.div(
- correlation_matrix_ui("binding_corr_matrix"),
- id="binding-corr-plot-wrapper",
- ),
- id="binding-corr-body",
+ style="margin-bottom: 2rem;",
+ ),
+ # Third row: Correlation matrix
+ ui.div(
+ ui.card(
+ ui.card_header("Binding Correlation Matrix"),
+ ui.card_body(
+ ui.div(
+ correlation_matrix_ui("binding_corr_matrix"),
+ style="display: flex; justify-content: center; "
+ "align-items: center; height: 500px;",
),
- ui.card_footer(
- ui.p(
- "Click and drag to zoom in on a specific region of the "
- "correlation matrix. Double click to reset the zoom.",
- class_="text-muted",
- ),
+ ),
+ ui.card_footer(
+ ui.p(
+ "Click and drag to zoom in on a specific region of the "
+ "correlation matrix. Double click to reset the zoom.",
+ class_="text-muted",
),
),
- id="binding-corr-container",
),
- id="binding-plot-row",
),
# Add styles at the bottom
ui.tags.style(
@@ -164,15 +164,6 @@ def binding_ui():
gap: 2rem;
}
- #binding-upset-container {
- flex: 1.2;
- min-width: 0;
- min-height: 500px;
- display: flex;
- flex-direction: column;
- height: 100%;
- }
-
#binding-corr-container {
flex: 0 0 500px;
display: flex;
@@ -222,33 +213,38 @@ def binding_server(
*,
binding_metadata_task: reactive.ExtendedTask,
logger: Logger,
-) -> reactive.calc:
+) -> reactive.Value:
"""
This function produces the reactive/render functions necessary to producing the
- binding upset plot and correlation matrix.
+ binding heatmap comparison and correlation matrix.
:param binding_metadata_task: This is the result from a reactive.extended_task.
Result can be retrieved with .result()
:param logger: A logger object
- :return: A reactive.calc with the metadata filtered for the selected upset plot sets
- (note that this is not currently working b/c of something to do with the upset
- plot server)
+ :return: A reactive.Value with the selected cell information from the heatmap
"""
- selected_binding_sets = upset_plot_server(
- "binding_upset",
- metadata_result=binding_metadata_task,
- source_name_dict=get_source_name_dict("binding"),
- logger=logger,
- )
# TODO: retrieving the predictors should be from the db as a reactive.extended_task
tf_binding_df = pd.read_csv("tmp/shiny_data/cc_predictors_normalized.csv")
tf_binding_df.set_index("target_symbol", inplace=True)
+
+ # Set up correlation matrix
correlation_matrix_server(
"binding_corr_matrix",
tf_binding_df=tf_binding_df,
logger=logger,
)
- return selected_binding_sets
+ # Set up heatmap comparison
+ selected_cell = heatmap_comparison_server(
+ "binding_heatmap",
+ metadata_result=binding_metadata_task,
+ source_name_dict=get_source_name_dict("binding"),
+ data_type="binding",
+ correlation_data=tf_binding_df, # Pass correlation data for correlation
+ # comparisons
+ logger=logger,
+ )
+
+ return selected_cell
diff --git a/tfbpshiny/tabs/perturbation_response_module.py b/tfbpshiny/tabs/perturbation_response_module.py
index 416c014..18df43f 100644
--- a/tfbpshiny/tabs/perturbation_response_module.py
+++ b/tfbpshiny/tabs/perturbation_response_module.py
@@ -3,7 +3,10 @@
import pandas as pd
from shiny import Inputs, Outputs, Session, module, reactive, ui
-from ..misc.binding_perturbation_upset_module import upset_plot_server, upset_plot_ui
+from ..misc.binding_perturbation_heatmap_module import (
+ heatmap_comparison_server,
+ heatmap_comparison_ui,
+)
from ..misc.correlation_plot_module import (
correlation_matrix_server,
correlation_matrix_ui,
@@ -17,7 +20,7 @@ def perturbation_response_ui():
# First row: Description
ui.div(
ui.p(
- "This page displays the UpSet plot and correlation matrix for "
+ "This page displays pairwise comparisons and correlation matrix for "
"TF perturbation response datasets. The current datasets include "
"data derived from gene deletions and overexpression methods."
),
@@ -121,46 +124,43 @@ def perturbation_response_ui():
),
id="perturbation-description",
),
- # Second row: Plot area container
+ # Second row: Heatmap comparison
ui.div(
- # Left: UpSet plot
- ui.div(
- ui.card(
- ui.card_header("Perturbation Response UpSet Plot"),
- ui.card_body(upset_plot_ui("perturbation_response_upset")),
- ui.card_footer(
- ui.p(
- "Click any one of the sets to show what proportion of "
- "the regulators in the selected set are also present "
- "in the other sets.",
- class_="text-muted",
- ),
+ ui.card(
+ ui.card_header("Perturbation Response Dataset Comparisons"),
+ ui.card_body(
+ heatmap_comparison_ui("perturbation_heatmap"),
+ ),
+ ui.card_footer(
+ ui.p(
+ "Interactive heatmap showing pairwise comparisons between "
+ "perturbation response datasets. Click cells to explore "
+ "common regulators or correlation distributions.",
+ class_="text-muted",
),
),
- id="perturbation-upset-container",
),
- # Right: Correlation matrix
- ui.div(
- ui.card(
- ui.card_header("Perturbation Response Correlation Matrix"),
- ui.card_body(
- ui.div(
- correlation_matrix_ui("perturbation_corr_matrix"),
- id="perturbation-corr-plot-wrapper",
- ),
- id="perturbation-corr-body",
+ style="margin-bottom: 2rem;",
+ ),
+ # Third row: Correlation matrix
+ ui.div(
+ ui.card(
+ ui.card_header("Perturbation Response Correlation Matrix"),
+ ui.card_body(
+ ui.div(
+ correlation_matrix_ui("perturbation_corr_matrix"),
+ style="display: flex; justify-content: center; "
+ "align-items: center; height: 500px;",
),
- ui.card_footer(
- ui.p(
- "Click and drag to zoom in on a specific region of the "
- "correlation matrix. Double click to reset the zoom.",
- class_="text-muted",
- ),
+ ),
+ ui.card_footer(
+ ui.p(
+ "Click and drag to zoom in on a specific region of the "
+ "correlation matrix. Double click to reset the zoom.",
+ class_="text-muted",
),
),
- id="perturbation-corr-container",
),
- id="perturbation-plot-row",
),
# Add styles at the bottom
ui.tags.style(
@@ -178,15 +178,6 @@ def perturbation_response_ui():
gap: 2rem;
}
- #perturbation-upset-container {
- flex: 1.2;
- min-width: 0;
- min-height: 500px;
- display: flex;
- flex-direction: column;
- height: 100%;
- }
-
#perturbation-corr-container {
flex: 0 0 500px;
display: flex;
@@ -236,39 +227,42 @@ def perturbation_response_server(
*,
pr_metadata_task: reactive.ExtendedTask,
logger: Logger,
-) -> reactive.calc:
+) -> reactive.Value:
"""
This function produces the reactive/render functions necessary to producing the
- perturbation response upset plot and correlation matrix for the perturbation
+ perturbation response heatmap comparison and correlation matrix for the perturbation
response data.
:param pr_metadata_task: This is the result from a reactive.extended_task. Result
can be retrieved with .result()
:param logger: A logger object
- :return: A reactive.calc with the metadata filtered for the selected upset plot sets
- (note that this is not currently working b/c of something to do with the upset
- plot server)
+ :return: A reactive.Value with the selected cell information from the heatmap
"""
# TODO: retrieving the response should be from the db as a reactive.extended_task
tf_pr_df = pd.read_csv("tmp/shiny_data/response_data.csv")
tf_pr_df.set_index("target_symbol", inplace=True)
+
+ # Set up correlation matrix
correlation_matrix_server(
"perturbation_corr_matrix",
tf_binding_df=tf_pr_df,
logger=logger,
)
- selected_pr_sets = upset_plot_server(
- "perturbation_response_upset",
+ # Set up heatmap comparison
+ selected_cell = heatmap_comparison_server(
+ "perturbation_heatmap",
metadata_result=pr_metadata_task,
source_name_dict=get_source_name_dict("perturbation_response"),
+ data_type="perturbation_response",
+ correlation_data=tf_pr_df, # Pass correlation data for correlation comparisons
logger=logger,
)
@reactive.effect
def _():
- logger.info(f"Selected perturbation response sets: {selected_pr_sets()}")
+ logger.info(f"Selected perturbation response cell: {selected_cell.get()}")
- return selected_pr_sets
+ return selected_cell