diff --git a/docs/huggingface_datacard.md b/docs/huggingface_datacard.md
index c41b062..3049ee5 100644
--- a/docs/huggingface_datacard.md
+++ b/docs/huggingface_datacard.md
@@ -157,6 +157,173 @@ configs:
   # ... rest of config
 ```
 
+### Automatic Metadata Joins
+
+**NEW**: When metadata is stored in separate files (using `applies_to`), the system infers join keys from the columns the configs have in common and joins the metadata automatically in queries.
+
+```yaml
+configs:
+# Data config
+- config_name: binding_data
+  dataset_type: annotated_features
+  data_files:
+  - split: train
+    path: binding_scores.parquet
+  dataset_info:
+    features:
+    - name: sample_id
+      dtype: string
+      description: Sample identifier
+    - name: gene_id
+      dtype: string
+      description: Gene identifier
+    - name: binding_score
+      dtype: float64
+      description: Binding score value
+
+# Metadata config - join keys inferred from common columns
+- config_name: experiment_metadata
+  dataset_type: metadata
+  applies_to: ["binding_data"]
+  data_files:
+  - split: train
+    path: metadata.parquet
+  dataset_info:
+    features:
+    - name: sample_id  # Common with binding_data - used for JOIN
+      dtype: string
+      description: Sample identifier
+    - name: cell_type
+      dtype: string
+      description: Cell type used in the experiment
+    - name: treatment
+      dtype: string
+      description: Treatment condition
+```
+
+With this configuration, you can query metadata fields directly without manually writing JOINs:
+
+```python
+from tfbpapi import HfQueryAPI
+
+api = HfQueryAPI("username/dataset-repo")
+
+# Query a metadata field directly - the JOIN is performed automatically
+df = api.query(
+    "SELECT * FROM binding_data WHERE cell_type = 'K562'",
+    "binding_data"
+)
+# Behind the scenes, the system automatically:
+# 1. Detects that 'cell_type' is not in binding_data
+# 2. Finds that 'cell_type' is in experiment_metadata
+# 3. Identifies 'sample_id' as the common column for joining
+# 4. Loads the metadata view
+# 5. Rewrites the SQL to: SELECT * FROM binding_data
+#    LEFT JOIN experiment_metadata USING (sample_id)
+#    WHERE cell_type = 'K562'
+```
+
+#### Composite Join Keys
+
+When multiple columns are common between data and metadata configs, they are all used as join keys:
+
+```yaml
+- config_name: sample_level_metadata
+  dataset_type: metadata
+  applies_to: ["binding_data"]
+  data_files:
+  - split: train
+    path: sample_metadata.parquet
+  dataset_info:
+    features:
+    - name: sample_id  # Common column 1
+      dtype: string
+    - name: gene_id  # Common column 2
+      dtype: string
+    - name: replicate
+      dtype: int64
+      description: Biological replicate number
+```
+
+The system will automatically join on BOTH `sample_id` AND `gene_id`, as shown in the sketch below.
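+
+For instance, with `sample_level_metadata` above, a query on `replicate` triggers a composite-key join (a sketch; the values are illustrative and `api` is the `HfQueryAPI` instance from the earlier example):
+
+```python
+# 'replicate' exists only in sample_level_metadata, so the query is
+# rewritten with a composite-key join before filtering:
+df = api.query(
+    "SELECT * FROM binding_data WHERE replicate = 2",
+    "binding_data"
+)
+# Approximate rewritten SQL (the real JOIN uses the internal DuckDB
+# table name for the metadata view):
+# SELECT * FROM binding_data
+# LEFT JOIN sample_level_metadata USING (gene_id, sample_id)
+# WHERE replicate = 2
+```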
+
+#### Multiple Metadata Configs
+
+A data config can have multiple metadata configs applied to it, each with join keys inferred from the columns it shares with the data config:
+
+```yaml
+configs:
+- config_name: binding_data
+  dataset_type: annotated_features
+  dataset_info:
+    features:
+    - name: sample_id  # For joining with experiment_metadata
+    - name: gene_id  # For joining with gene_annotations
+    - name: binding_score
+
+- config_name: experiment_metadata
+  dataset_type: metadata
+  applies_to: ["binding_data"]
+  dataset_info:
+    features:
+    - name: sample_id  # Common with binding_data
+    - name: cell_type
+    - name: treatment
+
+- config_name: gene_annotations
+  dataset_type: metadata
+  applies_to: ["binding_data"]
+  dataset_info:
+    features:
+    - name: gene_id  # Common with binding_data
+    - name: gene_name
+    - name: gene_biotype
+```
+
+Queries can reference columns from multiple metadata sources:
+
+```python
+# Automatically joins BOTH metadata configs
+df = api.query(
+    "SELECT * FROM binding_data WHERE cell_type = 'K562' AND gene_biotype = 'protein_coding'",
+    "binding_data"
+)
+```
+
+#### Disabling Automatic Joins
+
+If you prefer to write JOINs manually, you can disable automatic metadata joining:
+
+```python
+df = api.query(
+    "SELECT * FROM binding_data",
+    "binding_data",
+    auto_join_metadata=False  # Disable automatic joins
+)
+```
+
+#### How Join Keys Are Inferred
+
+Join keys are automatically determined by finding the **intersection of column names** between:
+- The data config's features
+- The metadata config's features
+
+For example:
+- If `binding_data` has columns: `[sample_id, gene_id, binding_score]`
+- And `experiment_metadata` has columns: `[sample_id, cell_type, treatment]`
+- The join key will be: `[sample_id]` (the only common column)
+
+**Important**: Make sure your common columns have the same name in both configs. The system uses exact name matching.
+
+When multiple columns are common between configs, **all** of them are used as join keys (see "Composite Join Keys" above). For example:
+- If `annotated_features` has: `[id, batch, regulator_symbol, expression_value]`
+- And `sample_metadata` has: `[id, batch, cell_type, data_usable]`
+- The join keys will be: `[batch, id]` (both common columns, sorted alphabetically)
+
+The system uses the SQL `USING` clause for joins, which automatically deduplicates the join key columns in the result. This means you won't see duplicate columns like `id` and `id_1` in your results.
+
 ### Embedded Metadata with `metadata_fields`
 
 When no explicit metadata config exists, you can extract metadata directly from the dataset's own files using the `metadata_fields` field. This specifies which fields should be treated as metadata.
diff --git a/tfbpapi/HfQueryAPI.py b/tfbpapi/HfQueryAPI.py
index 02ae684..bf00d1b 100644
--- a/tfbpapi/HfQueryAPI.py
+++ b/tfbpapi/HfQueryAPI.py
@@ -88,26 +88,40 @@ def _validate_metadata_fields(
         self, config_name: str, field_names: list[str]
     ) -> None:
         """
-        Validate that field names exist in the config's metadata columns.
+        Validate that field names exist in the config's columns or joinable metadata.
+
+        Checks both:
+        1. The config's own columns
+        2. 
Columns from metadata configs that have join_keys defined :param config_name: Configuration name to validate against :param field_names: List of field names to validate - :raises InvalidFilterFieldError: If any fields don't exist in metadata + :raises InvalidFilterFieldError: If any fields don't exist in available columns """ if not field_names: return try: - metadata_df = self.get_metadata(config_name) - if metadata_df.empty: - raise InvalidFilterFieldError( - config_name=config_name, - invalid_fields=field_names, - available_fields=[], - ) + # Get columns from the base config + base_columns = self._get_columns_from_config(config_name) + available_fields = set(base_columns) + + # Add columns from any metadata configs with join_keys + relationships = self.get_metadata_relationships() + data_relationships = [ + r for r in relationships if r.data_config == config_name + ] + + for rel in data_relationships: + if rel.relationship_type == "explicit" and rel.join_keys: + # This metadata can be auto-joined, include its columns + metadata_columns = self._get_columns_from_config( + rel.metadata_config + ) + available_fields.update(metadata_columns) - available_fields = list(metadata_df.columns) + # Check for invalid fields invalid_fields = [ field for field in field_names if field not in available_fields ] @@ -116,7 +130,7 @@ def _validate_metadata_fields( raise InvalidFilterFieldError( config_name=config_name, invalid_fields=invalid_fields, - available_fields=available_fields, + available_fields=list(available_fields), ) except Exception as e: if isinstance(e, InvalidFilterFieldError): @@ -177,7 +191,8 @@ def _extract_fields_from_sql(self, sql_where: str) -> list[str]: i += 1 continue - # Handle quoted strings - could be identifiers or values depending on context + # Handle quoted strings - could be identifiers or + # values depending on context if token.startswith(("'", '"')): # Extract the content inside quotes quoted_content = token[1:-1] @@ -195,7 +210,8 @@ def _extract_fields_from_sql(self, sql_where: str) -> list[str]: # Check what comes after this quoted string if next_significant_token: - # If followed by comparison operators or SQL keywords, it's a field name + # If followed by comparison operators or SQL keywords, + # it's a field name if ( next_significant_token in ["=", "!=", "<>", "<", ">", "<=", ">="] @@ -218,7 +234,8 @@ def _extract_fields_from_sql(self, sql_where: str) -> list[str]: break # If preceded by a comparison operator, could be a field name - # But we need to be very careful not to treat string literals as field names + # But we need to be very careful not to treat string + # literals as field names if prev_significant_token and prev_significant_token in [ "=", "!=", @@ -228,7 +245,8 @@ def _extract_fields_from_sql(self, sql_where: str) -> list[str]: "<=", ">=", ]: - # Only treat as field name if it looks like a database identifier + # Only treat as field name if it looks like a + # database identifier # AND doesn't look like a typical string value if self._looks_like_identifier( quoted_content @@ -264,7 +282,8 @@ def _extract_fields_from_sql(self, sql_where: str) -> list[str]: # Check if this looks like an identifier (field name) if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", token): - # Check the context - if the next non-whitespace token is a comparison operator, + # Check the context - if the next non-whitespace token is a + # comparison operator, # then this is likely a field name next_significant_token = None for j in range(i + 1, len(tokens)): @@ -273,7 +292,8 @@ def 
_extract_fields_from_sql(self, sql_where: str) -> list[str]: next_significant_token = next_token break - # Check if followed by a comparison operator or SQL keyword that indicates a field + # Check if followed by a comparison operator or SQL keyword that + # indicates a field is_field = False if next_significant_token: @@ -378,15 +398,18 @@ def _looks_like_identifier(self, content: str) -> bool: if not content: return False - # Basic identifier pattern: starts with letter/underscore, contains only alphanumeric/underscore + # Basic identifier pattern: starts with letter/underscore, contains only + # alphanumeric/underscore if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", content): return True - # Extended identifier pattern: could contain spaces if it's a column name like "quoted field" + # Extended identifier pattern: could contain spaces if it's a column name + # like "quoted field" # but not if it contains many special characters or looks like natural language if " " in content: # If it contains spaces, it should still look identifier-like - # Allow simple cases like "quoted field" but not "this is a long string value" + # Allow simple cases like "quoted field" but not "this is a long string + # value" words = content.split() if len(words) <= 3 and all( re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", word) for word in words @@ -453,10 +476,13 @@ def get_metadata( Supports three types of metadata retrieval: 1. Direct metadata configs: config_name is itself a metadata config 2. Embedded metadata: config_name has metadata_fields defined - 3. Applied metadata: config_name appears in another metadata config's applies_to list + 3. Applied metadata: config_name appears in another metadata config's + applies_to list - For explicit metadata configs (types 1 & 3), returns all rows from metadata table. - For embedded metadata (type 2), returns distinct combinations of metadata fields. + For explicit metadata configs (types 1 & 3), returns all rows from metadata + table. + For embedded metadata (type 2), returns distinct combinations of metadata + fields. 
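+
+        Example (hypothetical repo and config names, shown for illustration):
+            api = HfQueryAPI("username/dataset-repo")
+            # 'binding_data' is a data config with an applied metadata
+            # config, so this returns that metadata table's rows:
+            meta_df = api.get_metadata("binding_data")
+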
:param config_name: Specific config name to retrieve metadata for :param refresh_cache: If True, force refresh from remote instead of using cache @@ -470,13 +496,15 @@ def get_metadata( relevant_relationships = None - # First priority: data_config matches (config_name is a data config with metadata) + # First priority: data_config matches (config_name is a data config + # with metadata) data_config_matches = [r for r in relationships if r.data_config == config_name] if data_config_matches: relevant_relationships = data_config_matches else: - # Second priority: metadata_config matches (config_name is itself a metadata config) + # Second priority: metadata_config matches (config_name is itself a + # metadata config) metadata_config_matches = [ r for r in relationships if r.metadata_config == config_name ] @@ -600,7 +628,8 @@ def set_sql_filter( Example: api.set_sql_filter("hackett_2020", "time IN (15, 30) AND mechanism = 'ZEV'") # To skip validation for complex SQL: - api.set_sql_filter("hackett_2020", "complex_expression(...)", validate_fields=False) + api.set_sql_filter("hackett_2020", "complex_expression(...)", + validate_fields=False) """ if not sql_where.strip(): @@ -637,17 +666,25 @@ def get_current_filter(self, config_name: str) -> str | None: return self._table_filters.get(config_name) def query( - self, sql: str, config_name: str, refresh_cache: bool = False + self, + sql: str, + config_name: str, + refresh_cache: bool = False, + auto_join_metadata: bool = True, ) -> pd.DataFrame: """ - Execute SQL query with automatic filter application. + Execute SQL query with automatic filter application and metadata joins. Loads the specified configuration, applies any stored filters, - and executes the query. + and executes the query. If auto_join_metadata is True (default), + automatically detects metadata columns in the query and joins + the appropriate metadata tables. 
:param sql: SQL query to execute :param config_name: Configuration name to query (table will be loaded if needed) :param refresh_cache: If True, force refresh from remote instead of using cache + :param auto_join_metadata: If True, automatically join metadata tables + when needed :return: DataFrame with query results :raises ValueError: If config_name not found or query fails @@ -657,6 +694,12 @@ def query( FROM hackett_2020", "hackett_2020") # Automatically applies: WHERE time = 15 AND mechanism = 'ZEV' + Example with metadata: + # If cell_type is in experiment_metadata that applies_to binding_data: + df = api.query("SELECT * FROM binding_data WHERE cell_type = 'K562'", + "binding_data") + # Automatically joins experiment_metadata and filters by cell_type + """ # Validate config exists if config_name not in [c.config_name for c in self.configs]: @@ -687,6 +730,49 @@ def query( # Replace config name with actual table name in SQL for user convenience sql_with_table = sql.replace(config_name, table_name) + # Handle automatic metadata joins if enabled + if auto_join_metadata: + # Extract column references from the query + referenced_columns = self._extract_column_references(sql_with_table) + + # Also check for columns in stored filters + if config_name in self._table_filters: + filter_sql = self._table_filters[config_name] + filter_columns = self._extract_column_references(filter_sql) + referenced_columns.update(filter_columns) + + # Get columns from the base config + base_columns = self._get_columns_from_config(config_name) + + # Find columns that aren't in the base config + missing_columns = referenced_columns - base_columns + + if missing_columns: + # Find metadata configs that might have these columns + metadata_matches = self._find_metadata_for_columns( + config_name, missing_columns + ) + + if metadata_matches: + # Load metadata views and build JOIN clauses + metadata_joins = [] + for metadata_config, join_keys in metadata_matches: + metadata_table = self._load_metadata_view( + metadata_config, refresh_cache=refresh_cache + ) + metadata_joins.append( + (metadata_config, metadata_table, join_keys) + ) + self.logger.info( + f"Auto-joining metadata '{metadata_config}' " + f"on keys: {join_keys}" + ) + + # Rewrite SQL to include JOINs + sql_with_table = self._build_join_sql( + sql_with_table, table_name, metadata_joins + ) + # Apply stored filters final_sql = self._apply_filter_to_sql(sql_with_table, config_name) @@ -737,3 +823,250 @@ def _apply_filter_to_sql(self, sql: str, config_name: str) -> str: f"{sql[:insert_position].rstrip()} " f"WHERE {filter_clause} {sql[insert_position:]}" ) + + def _get_columns_from_config(self, config_name: str) -> set[str]: + """ + Get all column names from a config's schema. + + :param config_name: Configuration name + :return: Set of column names + + """ + config = self.get_config(config_name) + if not config: + return set() + return {feature.name for feature in config.dataset_info.features} + + def _extract_column_references(self, sql: str) -> set[str]: + """ + Extract column references from SQL query. + + Simple regex-based extraction that looks for identifiers in common SQL contexts. + Not a full SQL parser, but good enough for most queries. 
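+
+        Example (behavior implied by the patterns below; not exhaustive):
+            "SELECT a FROM t WHERE b = 'x'"  ->  {"a", "b"}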
+ + :param sql: SQL query string + :return: Set of potential column names + + """ + # Remove string literals to avoid false positives + sql_no_strings = re.sub(r"'[^']*'", "", sql) + sql_no_strings = re.sub(r'"[^"]*"', "", sql_no_strings) + + # Extract identifiers that appear in typical column contexts: + # - After SELECT, WHERE, GROUP BY, ORDER BY, HAVING + # - In comparisons (=, !=, <, >, etc.) + # - After AS keyword + column_patterns = [ + r"\b(?:SELECT|WHERE|AND|OR|ON|GROUP BY|ORDER BY|HAVING)\s+[\w.]+", + r"[\w.]+\s*(?:=|!=|<>|<|>|<=|>=|LIKE|IN|IS)", + r"AS\s+([\w.]+)", + ] + + columns = set() + for pattern in column_patterns: + matches = re.finditer(pattern, sql_no_strings, re.IGNORECASE) + for match in matches: + # Extract the identifier part + text = match.group(0) + # Remove SQL keywords and operators + for keyword in [ + "SELECT", + "WHERE", + "AND", + "OR", + "ON", + "GROUP BY", + "ORDER BY", + "HAVING", + "AS", + "=", + "!=", + "<>", + "<", + ">", + "<=", + ">=", + "LIKE", + "IN", + "IS", + ]: + text = re.sub( + r"\b" + keyword + r"\b", "", text, flags=re.IGNORECASE + ) + # Extract remaining identifiers + identifiers = re.findall(r"\b[\w.]+\b", text) + for ident in identifiers: + # Remove table prefixes (e.g., "table.column" -> "column") + if "." in ident: + columns.add(ident.split(".")[-1]) + else: + columns.add(ident) + + # Filter out common SQL keywords and functions + sql_keywords = { + "SELECT", + "FROM", + "WHERE", + "AND", + "OR", + "NOT", + "IN", + "IS", + "NULL", + "AS", + "ON", + "JOIN", + "LEFT", + "RIGHT", + "INNER", + "OUTER", + "GROUP", + "BY", + "ORDER", + "HAVING", + "LIMIT", + "OFFSET", + "DISTINCT", + "COUNT", + "SUM", + "AVG", + "MIN", + "MAX", + "CASE", + "WHEN", + "THEN", + "ELSE", + "END", + "CAST", + "TRUE", + "FALSE", + } + columns = {c for c in columns if c.upper() not in sql_keywords} + + return columns + + def _find_metadata_for_columns( + self, config_name: str, columns: set[str] + ) -> list[tuple[str, list[str]]]: + """ + Find metadata configs that contain the specified columns. + + :param config_name: Data config name being queried + :param columns: Set of column names to search for + :return: List of tuples (metadata_config_name, join_keys) + + """ + relationships = self.get_metadata_relationships() + data_relationships = [r for r in relationships if r.data_config == config_name] + + metadata_matches = [] + for rel in data_relationships: + if rel.relationship_type == "embedded": + # Skip embedded metadata - columns are already in the data table + continue + + # Get metadata config schema + metadata_columns = self._get_columns_from_config(rel.metadata_config) + + # Check if any of the queried columns are in this metadata + if columns & metadata_columns: + if rel.join_keys: + metadata_matches.append((rel.metadata_config, rel.join_keys)) + else: + # Log warning if columns match but no join keys defined + self.logger.warning( + f"Columns {columns & metadata_columns} found in metadata " + f"config '{rel.metadata_config}' but no join_keys defined. " + f"Cannot automatically join. Please add join_keys to datacard." + ) + + return metadata_matches + + def _load_metadata_view( + self, metadata_config_name: str, refresh_cache: bool = False + ) -> str: + """ + Load metadata config into DuckDB and return the table name. 
+ + :param metadata_config_name: Metadata config to load + :param refresh_cache: Whether to refresh cache + :return: Table name in DuckDB + + """ + config = self.get_config(metadata_config_name) + if not config: + raise ValueError(f"Metadata config '{metadata_config_name}' not found") + + config_result = self._get_metadata_for_config( + config, force_refresh=refresh_cache + ) + if not config_result.get("success", False): + raise ValueError( + f"Failed to load metadata '{metadata_config_name}': " + f"{config_result.get('message')}" + ) + + # TODO: fix this type ignore + return config_result.get("table_name") # type: ignore + + def _build_join_sql( + self, + base_sql: str, + base_table: str, + metadata_joins: list[tuple[str, str, list[str]]], + ) -> str: + """ + Rewrite SQL to include metadata JOINs. + + :param base_sql: Original SQL query + :param base_table: Base table name + :param metadata_joins: List of (metadata_config, metadata_table, join_keys) + :return: Rewritten SQL with JOINs + + """ + if not metadata_joins: + return base_sql + + # Extract the FROM clause position + from_pattern = r"\bFROM\s+" + re.escape(base_table) + match = re.search(from_pattern, base_sql, re.IGNORECASE) + if not match: + # Can't find FROM clause, return original + self.logger.warning("Could not find FROM clause for automatic join") + return base_sql + + from_end = match.end() + + # Build JOIN clauses + join_clauses = [] + for metadata_config, metadata_table, join_keys in metadata_joins: + # Use USING clause to avoid duplicate join columns in result + # USING automatically deduplicates the join keys + join_keys_str = ", ".join(join_keys) + join_clause = f"\nLEFT JOIN {metadata_table} USING ({join_keys_str})" + join_clauses.append(join_clause) + + # Insert JOINs after FROM clause + sql_before = base_sql[:from_end] + sql_after = base_sql[from_end:] + + # Check if there's already a WHERE/GROUP BY/etc after FROM + # We need to insert JOINs before those + insert_keywords = ["WHERE", "GROUP BY", "ORDER BY", "HAVING", "LIMIT"] + insert_position = len(sql_after) + + for keyword in insert_keywords: + match = re.search(r"\b" + keyword + r"\b", sql_after, re.IGNORECASE) + if match and match.start() < insert_position: + insert_position = match.start() + + final_sql = ( + sql_before + + "".join(join_clauses) + + " " + + sql_after[:insert_position].strip() + + " " + + sql_after[insert_position:] + ) + + return final_sql.strip() diff --git a/tfbpapi/datainfo/datacard.py b/tfbpapi/datainfo/datacard.py index 5bcbd42..2d8373e 100644 --- a/tfbpapi/datainfo/datacard.py +++ b/tfbpapi/datainfo/datacard.py @@ -1,7 +1,7 @@ """DataCard class for easy exploration of HuggingFace dataset metadata.""" import logging -from typing import Any, Dict, List, Optional, Set, Union +from typing import Any from pydantic import ValidationError @@ -82,8 +82,10 @@ def _load_and_validate_card(self) -> None: if "dtype" in field_path and error_type == "string_type": error_details.append( - f"Field '{field_path}': Expected a simple data type string (like 'string', 'int64', 'float64') " - f"but got a complex structure. This might be a categorical field with class labels. " + f"Field '{field_path}': Expected a simple data type " + "string (like 'string', 'int64', 'float64') " + f"but got a complex structure. This might be a " + "categorical field with class labels. 
" f"Actual value: {input_value}" ) else: @@ -184,11 +186,13 @@ def _extract_field_values(self, config: DatasetConfig, field_name: str) -> set[s values.update(partition_values) # For embedded metadata fields, we would need to query the actual data - # This is a placeholder - in practice, you might use the HF datasets server API + # This is a placeholder - in practice, you might use the HF + # datasets server API if config.metadata_fields and field_name in config.metadata_fields: # Placeholder for actual data extraction self.logger.debug( - f"Would extract embedded metadata for {field_name} in {config.config_name}" + "Would extract embedded metadata for " + f"{field_name} in {config.config_name}" ) except Exception as e: @@ -244,21 +248,36 @@ def get_metadata_relationships( meta_config.applies_to and data_config.config_name in meta_config.applies_to ): + # Infer join keys from column intersection + data_columns = { + feature.name for feature in data_config.dataset_info.features + } + meta_columns = { + feature.name for feature in meta_config.dataset_info.features + } + common_columns = data_columns & meta_columns + + # Use common columns as join keys (sorted for consistency) + join_keys = sorted(list(common_columns)) if common_columns else None + relationships.append( MetadataRelationship( data_config=data_config.config_name, metadata_config=meta_config.config_name, relationship_type="explicit", + join_keys=join_keys, ) ) - # Check for embedded metadata (always runs regardless of explicit relationships) + # Check for embedded metadata (always runs regardless of + # explicit relationships) if data_config.metadata_fields: relationships.append( MetadataRelationship( data_config=data_config.config_name, metadata_config=f"{data_config.config_name}_embedded", relationship_type="embedded", + join_keys=None, # Embedded metadata doesn't need joins ) ) diff --git a/tfbpapi/datainfo/models.py b/tfbpapi/datainfo/models.py index b152248..84209f8 100644 --- a/tfbpapi/datainfo/models.py +++ b/tfbpapi/datainfo/models.py @@ -1,7 +1,6 @@ """Pydantic models for dataset card validation.""" from enum import Enum -from typing import Any, Dict, List, Optional, Set, Union from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -27,12 +26,14 @@ class FeatureInfo(BaseModel): name: str = Field(..., description="Column name in the data") dtype: str | dict[str, ClassLabelType] = Field( ..., - description="Data type (string, int64, float64, etc.) 
or categorical class labels", + description="Data type (string, int64, float64, etc.)" + " or categorical class labels", ) description: str = Field(..., description="Detailed description of the field") role: str | None = Field( default=None, - description="Semantic role of the feature (e.g., 'target_identifier', 'regulator_identifier', 'quantitative_measure')", + description="Semantic role of the feature (e.g., 'target_identifier'," + "'regulator_identifier', 'quantitative_measure')", ) @field_validator("dtype", mode="before") @@ -50,15 +51,18 @@ def validate_dtype(cls, v): return {"class_label": ClassLabelType(**class_label_data)} else: raise ValueError( - f"Invalid class_label structure: expected dict with 'names' key, got {class_label_data}" + "Invalid class_label structure: expected dict with " + f"'names' key, got {class_label_data}" ) else: raise ValueError( - f"Unknown dtype structure: expected 'class_label' key in dict, got keys: {list(v.keys())}" + "Unknown dtype structure: expected 'class_label' " + f"key in dict, got keys: {list(v.keys())}" ) else: raise ValueError( - f"dtype must be a string or dict with class_label info, got {type(v)}: {v}" + "dtype must be a string or dict with " + f"class_label info, got {type(v)}: {v}" ) def get_dtype_summary(self) -> str: @@ -252,3 +256,7 @@ class MetadataRelationship(BaseModel): relationship_type: str = Field( ..., description="Type of relationship (explicit, embedded)" ) + join_keys: list[str] | None = Field( + default=None, + description="Column names to join on (from data config to metadata config)", + ) diff --git a/tfbpapi/tests/test_HfQueryAPI.py b/tfbpapi/tests/test_HfQueryAPI.py index 02ad783..ea61203 100644 --- a/tfbpapi/tests/test_HfQueryAPI.py +++ b/tfbpapi/tests/test_HfQueryAPI.py @@ -1,7 +1,7 @@ """Comprehensive tests for HfQueryAPI class.""" import logging -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import duckdb import pandas as pd @@ -182,31 +182,33 @@ def test_validate_metadata_fields_success(self, mock_api): def test_validate_metadata_fields_invalid(self, mock_api): """Test _validate_metadata_fields with invalid fields.""" - metadata_df = pd.DataFrame({"time": [15, 30], "mechanism": ["ZEV", "ZREV"]}) - - with patch.object(mock_api, "get_metadata", return_value=metadata_df): - with pytest.raises(InvalidFilterFieldError) as exc_info: - mock_api._validate_metadata_fields( - "test_config", ["invalid_field", "time"] - ) - - error = exc_info.value - assert "invalid_field" in error.invalid_fields - assert "time" not in error.invalid_fields - assert "time" in error.available_fields - assert error.config_name == "test_config" + # Mock the new schema-based approach + with patch.object( + mock_api, "_get_columns_from_config", return_value={"time", "mechanism"} + ): + with patch.object(mock_api, "get_metadata_relationships", return_value=[]): + with pytest.raises(InvalidFilterFieldError) as exc_info: + mock_api._validate_metadata_fields( + "test_config", ["invalid_field", "time"] + ) + + error = exc_info.value + assert "invalid_field" in error.invalid_fields + assert "time" not in error.invalid_fields + assert "time" in error.available_fields + assert error.config_name == "test_config" def test_validate_metadata_fields_empty_metadata(self, mock_api): """Test _validate_metadata_fields with empty metadata.""" - empty_df = pd.DataFrame() - - with patch.object(mock_api, "get_metadata", return_value=empty_df): - with pytest.raises(InvalidFilterFieldError) as exc_info: - 
mock_api._validate_metadata_fields("test_config", ["any_field"]) - - error = exc_info.value - assert error.invalid_fields == ["any_field"] - assert error.available_fields == [] + # Mock the new schema-based approach with no columns + with patch.object(mock_api, "_get_columns_from_config", return_value=set()): + with patch.object(mock_api, "get_metadata_relationships", return_value=[]): + with pytest.raises(InvalidFilterFieldError) as exc_info: + mock_api._validate_metadata_fields("test_config", ["any_field"]) + + error = exc_info.value + assert error.invalid_fields == ["any_field"] + assert error.available_fields == [] def test_validate_metadata_fields_empty_list(self, mock_api): """Test _validate_metadata_fields with empty field list.""" diff --git a/tfbpapi/tests/test_filter_with_metadata_joins.py b/tfbpapi/tests/test_filter_with_metadata_joins.py new file mode 100644 index 0000000..2d66b5d --- /dev/null +++ b/tfbpapi/tests/test_filter_with_metadata_joins.py @@ -0,0 +1,268 @@ +"""Tests for filters with automatic metadata joins.""" + +from unittest.mock import MagicMock, Mock + +import pytest + +from tfbpapi.datainfo.models import ( + DataFileInfo, + DatasetConfig, + DatasetInfo, + DatasetType, + FeatureInfo, + MetadataRelationship, +) +from tfbpapi.errors import InvalidFilterFieldError +from tfbpapi.HfQueryAPI import HfQueryAPI + + +@pytest.fixture +def mock_data_config(): + """Create a mock data configuration.""" + return DatasetConfig( + config_name="annotated_features", + description="Test binding data", + dataset_type=DatasetType.ANNOTATED_FEATURES, + data_files=[DataFileInfo(path="data.parquet")], + dataset_info=DatasetInfo( + features=[ + FeatureInfo( + name="sample_id", dtype="string", description="Sample identifier" + ), + FeatureInfo( + name="gene_id", dtype="string", description="Gene identifier" + ), + FeatureInfo( + name="expression_value", + dtype="float64", + description="Expression value", + ), + ] + ), + ) + + +@pytest.fixture +def mock_metadata_config(): + """Create a mock metadata configuration.""" + return DatasetConfig( + config_name="sample_metadata", + description="Test sample metadata", + dataset_type=DatasetType.METADATA, + applies_to=["annotated_features"], + data_files=[DataFileInfo(path="metadata.parquet")], + dataset_info=DatasetInfo( + features=[ + FeatureInfo( + name="sample_id", dtype="string", description="Sample identifier" + ), + FeatureInfo( + name="data_usable", dtype="string", description="Data quality flag" + ), + FeatureInfo(name="cell_type", dtype="string", description="Cell type"), + ] + ), + ) + + +class TestFilterValidationWithMetadata: + """Test that filter validation includes metadata columns.""" + + def test_filter_on_metadata_field_validates( + self, mock_data_config, mock_metadata_config + ): + """Test that filters on metadata fields pass validation.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.logger = MagicMock() + api._table_filters = {} + + # Mock get_config + def get_config_side_effect(config_name): + if config_name == "annotated_features": + return mock_data_config + elif config_name == "sample_metadata": + return mock_metadata_config + return None + + api.get_config = Mock(side_effect=get_config_side_effect) # type: ignore + + # Mock get_metadata_relationships with inferred join key + api.get_metadata_relationships = Mock( # type: ignore + return_value=[ + MetadataRelationship( + data_config="annotated_features", + metadata_config="sample_metadata", + relationship_type="explicit", + join_keys=["sample_id"], # Inferred from column 
intersection + ) + ] + ) + + # This should NOT raise - data_usable is in the metadata + api._validate_metadata_fields("annotated_features", ["data_usable"]) + + def test_filter_on_base_field_validates( + self, mock_data_config, mock_metadata_config + ): + """Test that filters on base config fields still work.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.logger = MagicMock() + api._table_filters = {} + + # Mock get_config + def get_config_side_effect(config_name): + if config_name == "annotated_features": + return mock_data_config + elif config_name == "sample_metadata": + return mock_metadata_config + return None + + api.get_config = Mock(side_effect=get_config_side_effect) # type: ignore + + # Mock get_metadata_relationships with inferred join key + api.get_metadata_relationships = Mock( # type: ignore + return_value=[ + MetadataRelationship( + data_config="annotated_features", + metadata_config="sample_metadata", + relationship_type="explicit", + join_keys=["sample_id"], # Inferred from column intersection + ) + ] + ) + + # This should NOT raise - gene_id is in the base config + api._validate_metadata_fields("annotated_features", ["gene_id"]) + + def test_filter_on_invalid_field_fails( + self, mock_data_config, mock_metadata_config + ): + """Test that filters on non-existent fields still fail.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.logger = MagicMock() + api._table_filters = {} + + # Mock get_config + def get_config_side_effect(config_name): + if config_name == "annotated_features": + return mock_data_config + elif config_name == "sample_metadata": + return mock_metadata_config + return None + + api.get_config = Mock(side_effect=get_config_side_effect) # type: ignore + + # Mock get_metadata_relationships with inferred join key + api.get_metadata_relationships = Mock( # type: ignore + return_value=[ + MetadataRelationship( + data_config="annotated_features", + metadata_config="sample_metadata", + relationship_type="explicit", + join_keys=["sample_id"], # Inferred from column intersection + ) + ] + ) + + # This SHOULD raise - nonexistent_field is nowhere + with pytest.raises(InvalidFilterFieldError) as exc_info: + api._validate_metadata_fields("annotated_features", ["nonexistent_field"]) + + assert "nonexistent_field" in str(exc_info.value) + + def test_filter_validation_includes_both_sources( + self, mock_data_config, mock_metadata_config + ): + """Test that validation includes fields from both base and metadata.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.logger = MagicMock() + api._table_filters = {} + + # Mock get_config + def get_config_side_effect(config_name): + if config_name == "annotated_features": + return mock_data_config + elif config_name == "sample_metadata": + return mock_metadata_config + return None + + api.get_config = Mock(side_effect=get_config_side_effect) # type: ignore + + # Mock get_metadata_relationships with inferred join key + api.get_metadata_relationships = Mock( # type: ignore + return_value=[ + MetadataRelationship( + data_config="annotated_features", + metadata_config="sample_metadata", + relationship_type="explicit", + join_keys=["sample_id"], # Inferred from column intersection + ) + ] + ) + + # Mix of base and metadata fields should all validate + api._validate_metadata_fields( + "annotated_features", ["gene_id", "data_usable", "cell_type"] + ) + + +class TestFilterAutoJoinTrigger: + """Test that filters trigger automatic metadata joins.""" + + def test_stored_filter_triggers_join(self, mock_data_config, mock_metadata_config): + """Test 
that stored filters are analyzed for metadata columns.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.logger = MagicMock() + api._table_filters = {"annotated_features": "data_usable = 'pass'"} + + # Mock get_config + def get_config_side_effect(config_name): + if config_name == "annotated_features": + return mock_data_config + elif config_name == "sample_metadata": + return mock_metadata_config + return None + + api.get_config = Mock(side_effect=get_config_side_effect) # type: ignore + + # Mock get_metadata_relationships with inferred join key + api.get_metadata_relationships = Mock( # type: ignore + return_value=[ + MetadataRelationship( + data_config="annotated_features", + metadata_config="sample_metadata", + relationship_type="explicit", + join_keys=["sample_id"], # Inferred from column intersection + ) + ] + ) + + # When extracting columns from a simple query, the filter should also be checked + # Simulating the query flow + sql = "SELECT * FROM annotated_features" + + # Extract from query + referenced_columns = api._extract_column_references(sql) + + # Extract from filter + if "annotated_features" in api._table_filters: + filter_sql = api._table_filters["annotated_features"] + filter_columns = api._extract_column_references(filter_sql) + referenced_columns.update(filter_columns) + + # Should include data_usable from the filter + assert "data_usable" in referenced_columns + + # Check that metadata would be found + base_columns = api._get_columns_from_config("annotated_features") + missing_columns = referenced_columns - base_columns + + assert "data_usable" in missing_columns + + # Verify it would find the metadata + metadata_matches = api._find_metadata_for_columns( + "annotated_features", missing_columns + ) + + assert len(metadata_matches) == 1 + assert metadata_matches[0][0] == "sample_metadata" diff --git a/tfbpapi/tests/test_metadata_joins.py b/tfbpapi/tests/test_metadata_joins.py new file mode 100644 index 0000000..47269ec --- /dev/null +++ b/tfbpapi/tests/test_metadata_joins.py @@ -0,0 +1,346 @@ +"""Tests for automatic metadata join functionality.""" + +from unittest.mock import MagicMock, Mock + +import pytest + +from tfbpapi.datainfo.models import ( + DataFileInfo, + DatasetConfig, + DatasetInfo, + DatasetType, + FeatureInfo, + MetadataRelationship, +) +from tfbpapi.HfQueryAPI import HfQueryAPI + + +@pytest.fixture +def mock_data_config(): + """Create a mock data configuration.""" + return DatasetConfig( + config_name="binding_data", + description="Test binding data", + dataset_type=DatasetType.ANNOTATED_FEATURES, + data_files=[DataFileInfo(path="data.parquet")], + dataset_info=DatasetInfo( + features=[ + FeatureInfo( + name="sample_id", dtype="string", description="Sample identifier" + ), + FeatureInfo( + name="gene_id", dtype="string", description="Gene identifier" + ), + FeatureInfo( + name="binding_score", + dtype="float64", + description="Binding score value", + ), + ] + ), + ) + + +@pytest.fixture +def mock_metadata_config(): + """Create a mock metadata configuration.""" + return DatasetConfig( + config_name="experiment_metadata", + description="Test experiment metadata", + dataset_type=DatasetType.METADATA, + applies_to=["binding_data"], + data_files=[DataFileInfo(path="metadata.parquet")], + dataset_info=DatasetInfo( + features=[ + FeatureInfo( + name="sample_id", dtype="string", description="Sample identifier" + ), + FeatureInfo(name="cell_type", dtype="string", description="Cell type"), + FeatureInfo( + name="treatment", dtype="string", description="Treatment 
condition" + ), + ] + ), + ) + + +@pytest.fixture +def mock_metadata_config_composite_key(): + """Create a mock metadata configuration with multiple common columns.""" + return DatasetConfig( + config_name="sample_metadata", + description="Test sample metadata with composite key", + dataset_type=DatasetType.METADATA, + applies_to=["binding_data"], + data_files=[DataFileInfo(path="sample_metadata.parquet")], + dataset_info=DatasetInfo( + features=[ + FeatureInfo( + name="sample_id", dtype="string", description="Sample identifier" + ), + FeatureInfo( + name="gene_id", dtype="string", description="Gene identifier" + ), + FeatureInfo( + name="replicate", dtype="int64", description="Replicate number" + ), + ] + ), + ) + + +class TestMetadataRelationshipsWithInferredJoinKeys: + """Test that join keys are automatically inferred from column intersection.""" + + @pytest.fixture(autouse=True) + def patch_load(self): + """Patch the _load_and_validate_card method for all tests in this class.""" + from unittest.mock import patch + + with patch("tfbpapi.datainfo.datacard.DataCard._load_and_validate_card"): + yield + + def test_relationship_infers_single_join_key( + self, mock_data_config, mock_metadata_config + ): + """Test that single common column is inferred as join key.""" + from tfbpapi.datainfo.datacard import DataCard + from tfbpapi.datainfo.models import DatasetCard + + # Mock dataset card with both configs + mock_card = DatasetCard(configs=[mock_data_config, mock_metadata_config]) + + datacard = DataCard("test/repo") + datacard._dataset_card = mock_card + relationships = datacard.get_metadata_relationships() + + # Should have one explicit relationship + explicit_rels = [r for r in relationships if r.relationship_type == "explicit"] + assert len(explicit_rels) == 1 + + rel = explicit_rels[0] + assert rel.data_config == "binding_data" + assert rel.metadata_config == "experiment_metadata" + # sample_id is the only common column + assert rel.join_keys == ["sample_id"] + + def test_relationship_infers_composite_keys( + self, mock_data_config, mock_metadata_config_composite_key + ): + """Test that multiple common columns are inferred as composite join keys.""" + from tfbpapi.datainfo.datacard import DataCard + from tfbpapi.datainfo.models import DatasetCard + + mock_card = DatasetCard( + configs=[mock_data_config, mock_metadata_config_composite_key] + ) + + datacard = DataCard("test/repo") + datacard._dataset_card = mock_card + relationships = datacard.get_metadata_relationships() + + explicit_rels = [r for r in relationships if r.relationship_type == "explicit"] + assert len(explicit_rels) == 1 + # Both sample_id and gene_id are common + assert set(explicit_rels[0].join_keys) == { # type: ignore + "gene_id", + "sample_id", + } + + def test_relationship_no_common_columns(self, mock_data_config): + """Test that no join keys are inferred when there are no common columns.""" + from tfbpapi.datainfo.datacard import DataCard + from tfbpapi.datainfo.models import DatasetCard + + # Create metadata with no common columns + metadata_no_overlap = DatasetConfig( + config_name="unrelated_metadata", + description="Metadata with no common columns", + dataset_type=DatasetType.METADATA, + applies_to=["binding_data"], + data_files=[DataFileInfo(path="metadata.parquet")], + dataset_info=DatasetInfo( + features=[ + FeatureInfo( + name="unrelated_id", dtype="string", description="Unrelated ID" + ), + FeatureInfo( + name="some_value", dtype="float64", description="Some value" + ), + ] + ), + ) + + mock_card = 
DatasetCard(configs=[mock_data_config, metadata_no_overlap]) + + datacard = DataCard("test/repo") + datacard._dataset_card = mock_card + relationships = datacard.get_metadata_relationships() + + explicit_rels = [r for r in relationships if r.relationship_type == "explicit"] + assert len(explicit_rels) == 1 + # No common columns, so no join keys + assert explicit_rels[0].join_keys is None + + +class TestColumnExtraction: + """Test SQL column extraction functionality.""" + + def test_extract_simple_select(self): + """Test extracting columns from simple SELECT query.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.logger = MagicMock() + sql = "SELECT sample_id, cell_type FROM table WHERE cell_type = 'K562'" + columns = api._extract_column_references(sql) + assert "sample_id" in columns + assert "cell_type" in columns + + def test_extract_with_where_clause(self): + """Test extracting columns from WHERE clauses.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.logger = MagicMock() + sql = "SELECT * FROM table WHERE cell_type = 'K562' AND treatment = 'drug'" + columns = api._extract_column_references(sql) + assert "cell_type" in columns + assert "treatment" in columns + + def test_extract_filters_sql_keywords(self): + """Test that SQL keywords are filtered out.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.logger = MagicMock() + sql = "SELECT * FROM table WHERE col1 = 'value' AND col2 IS NOT NULL" + columns = api._extract_column_references(sql) + assert "SELECT" not in columns + assert "FROM" not in columns + assert "WHERE" not in columns + assert "AND" not in columns + assert "IS" not in columns + assert "NOT" not in columns + assert "NULL" not in columns + + def test_extract_ignores_string_literals(self): + """Test that string literals are ignored.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.logger = MagicMock() + sql = "SELECT col FROM table WHERE status = 'active_user'" + columns = api._extract_column_references(sql) + assert "col" in columns + assert "status" in columns + # 'active_user' should not be extracted as a column + assert "active_user" not in columns + + +class TestAutomaticMetadataJoins: + """Test automatic metadata joining in queries.""" + + def test_find_metadata_for_columns(self, mock_data_config, mock_metadata_config): + """Test finding metadata configs that contain specific columns.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.logger = MagicMock() + + # Mock get_config + def get_config_side_effect(config_name): + if config_name == "binding_data": + return mock_data_config + elif config_name == "experiment_metadata": + return mock_metadata_config + return None + + api.get_config = Mock(side_effect=get_config_side_effect) # type: ignore + + # Mock get_metadata_relationships with inferred join keys + api.get_metadata_relationships = Mock( # type: ignore + return_value=[ + MetadataRelationship( + data_config="binding_data", + metadata_config="experiment_metadata", + relationship_type="explicit", + join_keys=["sample_id"], # Inferred from column intersection + ) + ] + ) + + # Test finding metadata for cell_type column + columns = {"cell_type"} + results = api._find_metadata_for_columns("binding_data", columns) + + assert len(results) == 1 + assert results[0][0] == "experiment_metadata" + assert results[0][1] == ["sample_id"] + + def test_build_join_sql_single_key(self): + """Test SQL rewriting with single join key.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.logger = MagicMock() + + base_sql = "SELECT * FROM metadata_binding_data WHERE cell_type = 'K562'" + 
metadata_joins = [ + ( + "experiment_metadata", + "metadata_experiment_metadata", + ["sample_id"], + ) + ] + + result = api._build_join_sql(base_sql, "metadata_binding_data", metadata_joins) + + assert "LEFT JOIN metadata_experiment_metadata" in result + assert "USING (sample_id)" in result + + def test_build_join_sql_composite_key(self): + """Test SQL rewriting with composite join keys.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.logger = MagicMock() + + base_sql = "SELECT * FROM metadata_binding_data WHERE replicate = 1" + metadata_joins = [ + ( + "sample_metadata", + "metadata_sample_metadata", + ["gene_id", "sample_id"], # Alphabetically sorted + ) + ] + + result = api._build_join_sql(base_sql, "metadata_binding_data", metadata_joins) + + assert "LEFT JOIN metadata_sample_metadata" in result + assert "USING (gene_id, sample_id)" in result + + def test_auto_join_disabled(self): + """Test that auto_join_metadata=False disables automatic joins.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.logger = MagicMock() + api._extract_column_references = MagicMock() # type: ignore + + # Verify the parameter exists + assert hasattr(HfQueryAPI.query, "__code__") + params = HfQueryAPI.query.__code__.co_varnames + assert "auto_join_metadata" in params + + +class TestGetColumnsFromConfig: + """Test _get_columns_from_config helper method.""" + + def test_get_columns_from_data_config(self, mock_data_config): + """Test extracting columns from data config.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.get_config = Mock(return_value=mock_data_config) # type: ignore + + columns = api._get_columns_from_config("binding_data") + assert columns == {"sample_id", "gene_id", "binding_score"} + + def test_get_columns_from_metadata_config(self, mock_metadata_config): + """Test extracting columns from metadata config.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.get_config = Mock(return_value=mock_metadata_config) # type: ignore + + columns = api._get_columns_from_config("experiment_metadata") + assert columns == {"sample_id", "cell_type", "treatment"} + + def test_get_columns_nonexistent_config(self): + """Test getting columns from non-existent config returns empty set.""" + api = HfQueryAPI.__new__(HfQueryAPI) + api.get_config = Mock(return_value=None) # type: ignore + + columns = api._get_columns_from_config("nonexistent") + assert columns == set()
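+
+
+# Supplementary sketch (an addition, not part of the original suite): covers
+# the fallback in _build_join_sql, which logs a warning and returns the SQL
+# unchanged when the base table's FROM clause cannot be located.
+def test_build_join_sql_missing_from_clause():
+    """Test that SQL without a matching FROM clause is returned unchanged."""
+    api = HfQueryAPI.__new__(HfQueryAPI)
+    api.logger = MagicMock()
+
+    sql = "SELECT 1"
+    result = api._build_join_sql(
+        sql, "metadata_binding_data", [("meta", "meta_table", ["sample_id"])]
+    )
+
+    assert result == sql
+    api.logger.warning.assert_called_once()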