diff --git a/doc/output.rst b/doc/output.rst
index 989baefdc..0cbfa9bf6 100644
--- a/doc/output.rst
+++ b/doc/output.rst
@@ -254,6 +254,11 @@ to read data using DuckDB. These include:
   Analysis scripts (see :ref:`analysis_scripts`) receive a ``history_sql`` and
   ``config_sql`` that reads data from Parquet files with filters applied when
   run using :py:mod:`runscripts.analysis`.
+- :py:func:`~ecoli.library.parquet_emitter.quote_columns`: Enclose
+  raw column names in double quotes to handle special characters (e.g. spaces,
+  dashes, etc.) when constructing DuckDB SQL queries.
+- :py:func:`~ecoli.library.parquet_emitter.list_columns`: Get a list of all
+  output column names, optionally filtered by glob pattern.
 - :py:func:`~ecoli.library.parquet_emitter.union_by_name`: Modify SQL query
   from :py:func:`~ecoli.library.parquet_emitter.dataset_sql` to use DuckDB's
   `union_by_name `_.
diff --git a/ecoli/library/parquet_emitter.py b/ecoli/library/parquet_emitter.py
index 0eeafa2a7..5a1551d46 100644
--- a/ecoli/library/parquet_emitter.py
+++ b/ecoli/library/parquet_emitter.py
@@ -1,4 +1,5 @@
 import os
+import fnmatch
 from concurrent.futures import Future, ThreadPoolExecutor
 from typing import Any, Callable, cast, Mapping, Optional
 from urllib import parse
@@ -180,6 +181,43 @@ def dataset_sql(out_dir: str, experiment_ids: list[str]) -> tuple[str, str, str]
     return sql_queries[0], sql_queries[1], sql_queries[2]
 
 
+def list_columns(
+    conn: duckdb.DuckDBPyConnection, history_subquery: str, pattern: str | None = None
+) -> list[str]:
+    """
+    Return list of columns in DuckDB subquery containing sim output data.
+
+    Args:
+        conn: DuckDB connection
+        history_subquery: DuckDB query containing sim output data
+        pattern: Optional glob pattern to filter column names
+    """
+    columns = (
+        conn.sql(f"SELECT column_name FROM (DESCRIBE ({history_subquery}))")
+        .pl()["column_name"]
+        .to_list()
+    )
+    if pattern is not None:
+        columns = fnmatch.filter(columns, pattern)
+    return columns
+
+
+def quote_columns(columns: str | list[str]) -> str | list[str]:
+    """
+    Given one or more raw column names (not DuckDB expressions),
+    return the same column name(s) enclosed in
+    double quotes to handle special characters (spaces, dashes, etc.).
+
+    Args:
+        columns: One or more column names
+    """
+    if isinstance(columns, str):
+        # Escape existing double quotes by doubling them
+        escaped = columns.replace('"', '""')
+        return f'"{escaped}"'
+    return [cast(str, quote_columns(col)) for col in columns]
+
+
 def num_cells(conn: duckdb.DuckDBPyConnection, subquery: str) -> int:
     """
     Return cell count in DuckDB subquery containing ``experiment_id``,
@@ -524,10 +562,16 @@ def read_stacked_columns(
     also include the ``experiment_id``, ``variant``, ``lineage_seed``,
     ``generation``, ``agent_id``, and ``time`` columns.
 
-    .. warning:: If the column expressions in ``columns`` are not from
+    .. hint:: To get a full list of columns in the output data that you can
+        use in your ``columns`` SQL expressions, use :py:func:`~.list_columns`.
+
+    .. warning:: If your raw column names contain special characters and you
+        are not constructing your column expressions with
         :py:func:`~named_idx` or :py:func:`~ndidx_to_duckdb_expr`,
-        they may need to be enclosed in double quotes to handle
-        special characters (e.g. ``"col-with-hyphens"``).
+        the raw column names MUST be enclosed in double quotes
+        to handle special characters (e.g. ``'"space and-hyphens"'``,
+        ``"\"[brackets]\""``). Use :py:func:`~quote_columns` to quote
+        these columns before constructing SQL expressions with them.
 
     For example, to get the average total concentration of three bulk
     molecules with indices 100, 1000, and 10000 per cell::
diff --git a/ecoli/library/test_parquet_emitter.py b/ecoli/library/test_parquet_emitter.py
index 670306748..19c4d2753 100644
--- a/ecoli/library/test_parquet_emitter.py
+++ b/ecoli/library/test_parquet_emitter.py
@@ -21,6 +21,9 @@
     flatten_dict,
     union_pl_dtypes,
     ParquetEmitter,
+    quote_columns,
+    list_columns,
+    create_duckdb_conn,
 )
 
 
@@ -251,6 +254,154 @@ def test_union_pl_dtypes(self):
             pl.UInt32,
         ) == pl.List(pl.List(pl.List(pl.UInt32)))
 
+    def test_quote_columns(self):
+        """Test quote_columns handles special characters correctly."""
+        # Test single string with special characters
+        assert quote_columns("simple") == '"simple"'
+        assert quote_columns("with spaces") == '"with spaces"'
+        assert quote_columns("with-hyphens") == '"with-hyphens"'
+        assert quote_columns("with[brackets]") == '"with[brackets]"'
+        assert quote_columns("with/slashes") == '"with/slashes"'
+
+        # Test string with existing double quotes (should be escaped)
+        assert quote_columns('already"quoted') == '"already""quoted"'
+        assert quote_columns('"fully"quoted"') == '"""fully""quoted"""'
+
+        # Test list of strings
+        assert quote_columns(["col1", "col2", "col3"]) == [
+            '"col1"',
+            '"col2"',
+            '"col3"',
+        ]
+        assert quote_columns(["with spaces", "with-hyphens"]) == [
+            '"with spaces"',
+            '"with-hyphens"',
+        ]
+
+        # Test mixed special characters in list
+        assert quote_columns(["normal", "space here", "hyphen-here", 'quote"here']) == [
+            '"normal"',
+            '"space here"',
+            '"hyphen-here"',
+            '"quote""here"',
+        ]
+
+        # Test empty cases
+        assert quote_columns("") == '""'
+        assert quote_columns([]) == []
+
+        # Test that quoted columns actually work in DuckDB queries with weird column names
+        with tempfile.TemporaryDirectory() as tmp_path:
+            test_file = os.path.join(tmp_path, "weird_cols.parquet")
+            # Create test data with columns containing special characters
+            test_data = pl.DataFrame(
+                {
+                    "simple": [1, 2, 3],
+                    "with spaces": [4, 5, 6],
+                    "with-hyphens": [7, 8, 9],
+                    "with[brackets]": [10, 11, 12],
+                    "with/slashes": [13, 14, 15],
+                    'has"quote': [16, 17, 18],
+                    "dot.name": [19, 20, 21],
+                    "colon:name": [22, 23, 24],
+                }
+            )
+            test_data.write_parquet(test_file, statistics=False)
+
+            conn = create_duckdb_conn()
+
+            # Test selecting individual columns with special characters
+            for col in test_data.columns:
+                quoted_col = quote_columns(col)
+                result = conn.sql(f"SELECT {quoted_col} FROM '{test_file}'").pl()
+                assert result.shape == (3, 1)
+                assert result.columns[0] == col
+                expected_values = test_data[col].to_list()
+                assert result[col].to_list() == expected_values
+
+            # Test selecting multiple columns at once
+            weird_cols = ["with spaces", "with-hyphens", "with[brackets]", 'has"quote']
+            quoted_cols = ", ".join(quote_columns(weird_cols))
+            result = conn.sql(f"SELECT {quoted_cols} FROM '{test_file}'").pl()
+            assert result.shape == (3, 4)
+            for col in weird_cols:
+                assert col in result.columns
+                assert result[col].to_list() == test_data[col].to_list()
+
+            # Test that using WHERE clause works with quoted columns
+            quoted_space_col = quote_columns("with spaces")
+            result = conn.sql(
+                f"SELECT * FROM '{test_file}' WHERE {quoted_space_col} > 4"
+            ).pl()
+            assert result.shape == (2, 8)
+            assert result["with spaces"].to_list() == [5, 6]
+
+            # Test aggregation with quoted columns
+            quoted_bracket_col = quote_columns("with[brackets]")
quote_columns("with[brackets]") + result = conn.sql( + f"SELECT AVG({quoted_bracket_col}) as avg_val FROM '{test_file}'" + ).pl() + assert result["avg_val"][0] == 11.0 + + # Test ORDER BY with quoted columns + quoted_slash_col = quote_columns("with/slashes") + result = conn.sql( + f"SELECT {quoted_slash_col} FROM '{test_file}' ORDER BY {quoted_slash_col} DESC" + ).pl() + assert result["with/slashes"].to_list() == [15, 14, 13] + + def test_list_columns(self): + """Test list_columns retrieves column names correctly.""" + with tempfile.TemporaryDirectory() as tmp_path: + # Create test Parquet file with known columns + test_file = os.path.join(tmp_path, "test.parquet") + test_data = pl.DataFrame( + { + "col_a": [1, 2, 3], + "col_b": [4.0, 5.0, 6.0], + "listeners__mass__cell_mass": [7.0, 8.0, 9.0], + "listeners__mass__dry_mass": [10.0, 11.0, 12.0], + "listeners__growth__instantaneous_growth_rate": [0.1, 0.2, 0.3], + "bulk": [[1, 2], [3, 4], [5, 6]], + } + ) + test_data.write_parquet(test_file, statistics=False) + + conn = create_duckdb_conn() + subquery = f"SELECT * FROM '{test_file}'" + + # Test getting all columns + all_cols = list_columns(conn, subquery) + assert len(all_cols) == 6 + assert "col_a" in all_cols + assert "col_b" in all_cols + assert "listeners__mass__cell_mass" in all_cols + + # Test pattern matching with glob patterns + listener_cols = list_columns(conn, subquery, "listeners__*") + assert len(listener_cols) == 3 + assert all(col.startswith("listeners__") for col in listener_cols) + + # Test pattern matching for specific listener + mass_cols = list_columns(conn, subquery, "listeners__mass__*") + assert len(mass_cols) == 2 + assert "listeners__mass__cell_mass" in mass_cols + assert "listeners__mass__dry_mass" in mass_cols + + # Test pattern that matches nothing + no_match = list_columns(conn, subquery, "nonexistent__*") + assert len(no_match) == 0 + + # Test pattern with single character wildcard + col_pattern = list_columns(conn, subquery, "col_?") + assert len(col_pattern) == 2 + assert "col_a" in col_pattern + assert "col_b" in col_pattern + + # Test exact match pattern + exact = list_columns(conn, subquery, "bulk") + assert exact == ["bulk"] + def compare_nested(a: list, b: list) -> bool: """