diff --git a/agentic_eda/jupyterlab_extension_backend/.gitignore b/agentic_eda/jupyterlab_extension_backend/.gitignore new file mode 100644 index 000000000..082bc2d6e --- /dev/null +++ b/agentic_eda/jupyterlab_extension_backend/.gitignore @@ -0,0 +1,17 @@ +# OS files +.DS_Store + +# Python cache/build artifacts +__pycache__/ +*.py[cod] +*.pyo +*.pyd + +# Secrets and local environment files +.env +*.env +config/.env +*.secret +*secret* +*.key +*.pem diff --git a/agentic_eda/jupyterlab_extension_backend/README.md b/agentic_eda/jupyterlab_extension_backend/README.md new file mode 100644 index 000000000..d3a9b1185 --- /dev/null +++ b/agentic_eda/jupyterlab_extension_backend/README.md @@ -0,0 +1,19 @@ +# JupyterLab Extension Backend + +Run the backend entrypoint from this directory: + +```bash +cd /path/to/tutorials1/agentic_eda/jupyterlab_extension_backend +python -m src.main \ + --mode integrity \ + --path datasets/T1_slice.csv +``` + +If you run from a different directory, set `PYTHONPATH`: + +```bash +PYTHONPATH=/path/to/tutorials1/agentic_eda/jupyterlab_extension_backend \ +python -m src.main \ + --mode integrity \ + --path /path/to/tutorials1/agentic_eda/jupyterlab_extension_backend/datasets/T1_slice.csv +``` diff --git a/agentic_eda/jupyterlab_extension_backend/config/config.py b/agentic_eda/jupyterlab_extension_backend/config/config.py new file mode 100644 index 000000000..56a61fabe --- /dev/null +++ b/agentic_eda/jupyterlab_extension_backend/config/config.py @@ -0,0 +1,128 @@ +""" +Import as: + +import config.config as cconf +""" + +import dataclasses +import functools +import os + +import dotenv +import langchain_anthropic +import langchain_google_genai +import langchain_openai +import pydantic + +dataclass = dataclasses.dataclass +lru_cache = functools.lru_cache +ChatOpenAI = langchain_openai.ChatOpenAI +ChatAnthropic = langchain_anthropic.ChatAnthropic
@lru_cache(maxsize=1)
def get_chat_model(*, model: str | None = None) -> object:
    """
    Build the configured chat model client.

    The provider comes from `get_settings()`. The environment variables
    required by that provider are checked with `_need` before the client is
    constructed. Temperature, timeout, and retry count from the settings are
    applied uniformly to every provider.

    :param model: optional model override; the configured model is used when
        this is `None`
    :return: langchain chat model client
    :raises RuntimeError: when a required environment variable is missing
    :raises ValueError: when the configured provider is not supported
    """
    settings = get_settings()
    model_name = settings.model if model is None else model
    provider = settings.provider
    # Options every provider must honor, kept in one place so that no branch
    # can forget one of them.
    common_kwargs = {
        "temperature": settings.temperature,
        "timeout": settings.timeout,
        "max_retries": settings.max_retries,
    }
    if provider == "openai":
        # The key is not passed explicitly; only verify that it is set.
        _need("OPENAI_API_KEY")
        chat_model = ChatOpenAI(model=model_name, **common_kwargs)
    elif provider == "openai_compatible":
        base_url = _need("OPENAI_COMPAT_BASE_URL")
        api_key = _need("OPENAI_COMPAT_API_KEY")
        chat_model = ChatOpenAI(
            model=model_name,
            base_url=base_url,
            api_key=SecretStr(api_key),
            **common_kwargs,
        )
    elif provider == "azure_openai_v1":
        azure_base = _need("AZURE_OPENAI_BASE_URL")
        azure_key = SecretStr(_need("AZURE_OPENAI_API_KEY"))
        chat_model = ChatOpenAI(
            model=model_name,
            base_url=azure_base,
            api_key=azure_key,
            **common_kwargs,
        )
    elif provider == "anthropic":
        _need("ANTHROPIC_API_KEY")
        chat_model = ChatAnthropic(
            model_name=model_name,
            stop=None,
            **common_kwargs,
        )
    elif provider in ("google", "gemini", "google_genai"):
        _need("GOOGLE_API_KEY")
        # This branch used to pass only the temperature, so the configured
        # timeout and retry count were silently ignored for Google models.
        chat_model = ChatGoogleGenerativeAI(model=model_name, **common_kwargs)
    else:
        raise ValueError(f"Unsupported provider='{provider}'")
    return chat_model
00:00,380.047790527343,5.31133604049682,416.328907824861,259.994903564453 +01 01 2018 00:10,453.76919555664,5.67216682434082,519.917511061494,268.64111328125 +01 01 2018 00:20,306.376586914062,5.21603679656982,390.900015810951,272.564788818359 +01 01 2018 00:30,419.645904541015,5.65967416763305,516.127568975674,271.258087158203 +01 01 2018 00:40,380.650695800781,5.57794094085693,491.702971953588,265.674285888671 +01 01 2018 00:50,402.391998291015,5.60405206680297,499.436385024805,264.57861328125 +01 01 2018 01:00,447.605712890625,5.79300785064697,557.372363290225,266.163604736328 +01 01 2018 01:10,387.2421875,5.30604982376098,414.898178826186,257.949493408203 +01 01 2018 01:20,463.651214599609,5.58462905883789,493.677652137077,253.480697631835 +01 01 2018 01:30,439.725708007812,5.52322816848754,475.706782818068,258.72378540039 +01 01 2018 01:40,498.181701660156,5.72411584854125,535.841397042263,251.850997924804 +01 01 2018 01:50,526.816223144531,5.93419885635375,603.014076510633,265.504699707031 +01 01 2018 02:00,710.587280273437,6.54741382598876,824.662513585882,274.23291015625 +01 01 2018 02:10,655.194274902343,6.19974613189697,693.472641075637,266.733184814453 +01 01 2018 02:20,754.762512207031,6.50538301467895,808.098138482693,266.76040649414 +01 01 2018 02:30,790.173278808593,6.63411617279052,859.459020788565,270.493194580078 +01 01 2018 02:40,742.985290527343,6.37891292572021,759.434536596592,266.593292236328 +01 01 2018 02:50,748.229614257812,6.4466528892517,785.28100987646,265.571807861328 +01 01 2018 03:00,736.647827148437,6.41508293151855,773.172863451736,261.15869140625 +01 01 2018 03:10,787.246215820312,6.43753099441528,781.7712157188,257.56021118164 +01 01 2018 03:20,722.864074707031,6.22002410888671,700.764699868076,255.926498413085 +01 01 2018 03:30,935.033386230468,6.89802598953247,970.736626881787,250.012893676757 +01 01 2018 03:40,1220.60900878906,7.60971117019653,1315.04892785216,255.985702514648 +01 01 2018 
03:50,1053.77197265625,7.28835582733154,1151.26574355584,255.444595336914 +01 01 2018 04:00,1493.80798339843,7.94310188293457,1497.58372354361,256.407409667968 +01 01 2018 04:10,1724.48803710937,8.37616157531738,1752.19966204818,252.41259765625 +01 01 2018 04:20,1636.93505859375,8.23695755004882,1668.47070685152,247.979400634765 +01 01 2018 04:30,1385.48803710937,7.87959098815917,1461.81579081391,238.609603881835 +01 01 2018 04:40,1098.93200683593,7.10137605667114,1062.28503444311,245.095596313476 +01 01 2018 04:50,1021.4580078125,6.95530700683593,995.995854606612,245.410202026367 +01 01 2018 05:00,1164.89294433593,7.09829807281494,1060.85971215544,235.227905273437 +01 01 2018 05:10,1073.33203125,6.95363092422485,995.250960801046,242.872695922851 +01 01 2018 05:20,1165.30798339843,7.24957799911499,1132.4168612641,244.835693359375 +01 01 2018 05:30,1177.98999023437,7.29469108581542,1154.36530469206,242.48159790039 +01 01 2018 05:40,1170.53601074218,7.37636995315551,1194.8430985043,247.97720336914 +01 01 2018 05:50,1145.53601074218,7.44855403900146,1231.43070603717,249.682998657226 +01 01 2018 06:00,1114.02697753906,7.2392520904541,1127.43320551345,248.401000976562 +01 01 2018 06:10,1153.18505859375,7.32921123504638,1171.35504358957,244.621704101562 +01 01 2018 06:20,1125.3310546875,7.13970518112182,1080.13908466205,244.631805419921 +01 01 2018 06:30,1228.73205566406,7.47422885894775,1244.63353439737,245.785995483398 +01 01 2018 06:40,1021.79302978515,7.03317403793334,1030.99268581181,248.652206420898 +01 01 2018 06:50,957.378173828125,6.88645505905151,965.683334443832,244.611694335937 +01 01 2018 07:00,909.887817382812,6.88782119750976,966.279104864065,235.84829711914 +01 01 2018 07:10,1000.95397949218,7.21643209457397,1116.4718990154,232.842697143554 +01 01 2018 07:20,1024.47802734375,7.0685977935791,1047.17023059277,229.933197021484 +01 01 2018 07:30,1009.53399658203,6.93829584121704,988.451940715539,230.13670349121 +01 01 2018 
07:40,899.492980957031,6.53668785095214,820.416658585943,234.933807373046 +01 01 2018 07:50,725.110107421875,6.18062496185302,686.636942163399,232.837905883789 +01 01 2018 08:00,585.259399414062,5.81682586669921,564.927659543473,240.328796386718 +01 01 2018 08:10,443.913909912109,5.45015096664428,454.773587146918,238.12629699707 +01 01 2018 08:20,565.253784179687,5.81814908981323,565.349093224668,235.80029296875 +01 01 2018 08:30,644.037780761718,6.13027286529541,668.823569309414,224.958694458007 +01 01 2018 08:40,712.058898925781,6.34707784652709,747.460673422601,216.803894042968 +01 01 2018 08:50,737.394775390625,6.34743690490722,747.595109122642,205.785293579101 +01 01 2018 09:00,725.868103027343,6.19436883926391,691.546334303948,199.848495483398 +01 01 2018 09:10,408.997406005859,4.97719812393188,330.417630427964,207.997802734375 +01 01 2018 09:20,628.436828613281,5.95911121368408,611.283836510667,210.954895019531 +01 01 2018 09:30,716.1005859375,6.21137619018554,697.649474372052,215.69400024414 +01 01 2018 09:40,711.49560546875,6.11145305633544,662.235163012206,220.84260559082 +01 01 2018 09:50,838.151916503906,6.45632219314575,789.011422412419,237.065307617187 +01 01 2018 10:00,881.062072753906,6.66665792465209,872.739625855708,235.667495727539 +01 01 2018 10:10,663.703125,6.16287899017333,680.327891653483,229.329696655273 +01 01 2018 10:20,578.261596679687,6.01316785812377,628.442560754699,234.900604248046 +01 01 2018 10:30,465.620086669921,5.56120300292968,486.779567601972,230.422805786132 +01 01 2018 10:40,311.050903320312,4.96073198318481,326.411025380213,229.537506103515 +01 01 2018 10:50,230.05549621582,4.60387516021728,244.31624421611,231.79849243164 +01 01 2018 11:00,233.990600585937,4.55453395843505,233.632780531927,234.105606079101 +01 01 2018 11:10,175.592193603515,4.26362895965576,173.573663122312,228.776702880859 +01 01 2018 11:20,118.133102416992,3.89413905143737,108.571221110423,227.938995361328 +01 01 2018 
11:30,142.202499389648,4.03876113891601,130.229989593698,224.46499633789 +01 01 2018 11:40,212.566192626953,4.50565099716186,223.196784083793,224.950500488281 +01 01 2018 11:50,222.610000610351,4.54339790344238,231.242507343633,229.12759399414 +01 01 2018 12:00,194.181198120117,4.32376098632812,185.598479588255,227.039993286132 +01 01 2018 12:10,82.6407470703125,3.63443708419799,68.5028197987886,230.31460571289 +01 01 2018 12:20,75.8952178955078,3.70551204681396,78.3961653540173,233.953292846679 +01 01 2018 12:30,41.9472389221191,3.25396800041198,29.2869556318446,233.06590270996 +01 01 2018 12:40,118.534599304199,3.77513694763183,88.8713653309387,227.753494262695 +01 01 2018 12:50,250.755905151367,4.69350099563598,264.119257409418,229.896606445312 +01 01 2018 13:00,346.86441040039,5.00293922424316,336.721998240131,235.279495239257 +01 01 2018 13:10,416.417907714843,5.36474990844726,430.92108895689,235.585296630859 +01 01 2018 13:20,331.941497802734,5.01618194580078,339.984940156412,229.942901611328 +01 01 2018 13:30,583.479919433593,5.97040796279907,615.05563084927,235.69529724121 +01 01 2018 13:40,776.552673339843,6.6555209159851,868.180844867276,241.457397460937 +01 01 2018 13:50,752.726379394531,6.60090398788452,846.029409522117,242.782104492187 +01 01 2018 14:00,589.073120117187,5.98137807846069,618.731442665699,234.984405517578 +01 01 2018 14:10,1109.12805175781,7.42459392547607,1219.19978672882,235.14729309082 +01 01 2018 14:20,1482.4599609375,8.18645191192626,1638.50890923271,238.479095458984 +01 01 2018 14:30,1523.43005371093,8.27493000030517,1691.1470390233,237.033203125 +01 01 2018 14:40,1572.17004394531,8.44920253753662,1796.76309010091,238.332397460937 +01 01 2018 14:50,1698.93994140625,8.5759744644165,1875.04719734159,235.641403198242 +01 01 2018 15:00,1616.84594726562,8.28225994110107,1695.53877696245,236.461395263671 +01 01 2018 15:10,1796.82397460937,8.73455238342285,1974.47580025242,234.354797363281 +01 01 2018 
def _score_parse(dt: pd.Series) -> float:
    """
    Rate how plausible a parsed datetime column looks.

    The score is a weighted sum of three components: the fraction of values
    that parsed (weight 0.65), whether the observed range stays within
    1990-01-01..2035-01-01 UTC (weight 0.15), and the share of consecutive
    non-null values that do not move backwards in time (weight 0.20).

    :param dt: candidate datetime series
    :return: score where larger means better; -1.0 when nothing parsed
    """
    parsed = pd.to_datetime(dt, errors="coerce", utc=True)
    valid_mask = parsed.notna()
    if not valid_mask.any():
        # Nothing usable (this also covers an empty series).
        return -1.0
    lower_bound = pd.Timestamp("1990-01-01", tz="UTC")
    upper_bound = pd.Timestamp("2035-01-01", tz="UTC")
    within_bounds = parsed.min() >= lower_bound and parsed.max() <= upper_bound
    range_component = 1.0 if within_bounds else 0.7
    valid = parsed[valid_mask]
    if len(valid) < 3:
        # Too few points to say anything about ordering.
        order_component = 0.0
    else:
        went_backwards = valid.diff() < pd.Timedelta(0)
        order_component = 1.0 - float(went_backwards.mean())
    coverage = float(valid_mask.mean())
    return float(
        coverage * 0.65 + range_component * 0.15 + order_component * 0.20
    )
def run_formatting_agent(state: DateFormatterState) -> dict:
    """
    Let a tool-calling LLM agent choose the datetime parser for one column.

    The agent is given the `_parse_with_candidates` and `extract_head` tools
    and must answer in the `DateFormatterOutput` schema.

    :param state: formatter graph state; `path` and `time_col` are read
    :return: dict with the proposed `candidates` and the `winner_formatter`
    """
    instruction_lines = [
        "Use tools to convert the provided time column into a correct datetime "
        "format.",
        "1. Use extract_head to inspect the temporal column and propose parse "
        "candidates.",
        "2. Call _parse_with_candidates with those candidates.",
        "3. Return all candidates and the winning formatter.",
    ]
    formatter_agent = lagents.create_agent(
        model=cconf.get_chat_model(model="gpt-4.1"),
        tools=[_parse_with_candidates, tinptool.extract_head],
        system_prompt="\n".join(instruction_lines),
        response_format=DateFormatterOutput,
    )
    request = lmessages.HumanMessage(
        content=(
            f"The dataset path is {state['path']} and the time "
            f"column name is {state['time_col']}"
        )
    )
    response = formatter_agent.invoke({"messages": [request]})
    # Convert the pydantic result into plain dicts for the graph state.
    result = response["structured_response"].model_dump()
    return {key: result[key] for key in ("candidates", "winner_formatter")}
def run_date_formatter(path: str) -> dict:
    """
    Execute datetime formatter graph and parse the selected time column.

    The winning formatter is chosen by scoring parses done with
    `errors="coerce"`, so it may legitimately leave some rows unparsed. The
    final parse here therefore coerces too, instead of raising after the
    whole graph has already run, and the share of rows that parsed is
    reported so that failures stay visible.

    :param path: dataset path
    :return: output including selected formatter, parsed dtype, and the
        fraction of rows that parsed
    """
    out: DateFormatterState = graph.invoke({"path": path})  # type: ignore[assignment]
    dataset = tinptool.load_dataset(pathlib.Path(path))
    # Unset options are dropped so that pandas falls back to its defaults.
    format_args = {
        key: val
        for key, val in out["winner_formatter"].items()
        if val is not None
    }
    parsed_time = pd.to_datetime(
        dataset[out["time_col"]],
        errors="coerce",
        **format_args,
    )
    # The mean of an empty series is NaN; report 0.0 for an empty column.
    parsed_fraction = (
        float(parsed_time.notna().mean()) if len(parsed_time) else 0.0
    )
    payload = {
        "time_col": out["time_col"],
        "winner_formatter": out["winner_formatter"],
        "parsed_dtype": str(parsed_time.dtype),
        "parsed_fraction": parsed_fraction,
    }
    _LOG.info("Date formatter output: %s", payload)
    return payload
def header_classification_agent(state: InputState) -> dict:
    """
    Ask an LLM agent to sort the dataset columns into three groups.

    The agent is given the `extract_head` and `extract_metadata` tools and
    must answer in the `LLMOutput` schema.

    :param state: input graph state; only `path` is read
    :return: dict with `temporal_cols`, `numeric_val_cols`, and
        `categorical_val_cols`
    """
    instructions = (
        "You are a header classifier agent. Use tools to identify temporal "
        "columns and classify the remaining value columns as numeric or "
        "categorical. Output JSON with keys temporal_cols, "
        "numeric_val_cols, and categorical_val_cols."
    )
    classifier = lagents.create_agent(
        model=cconf.get_chat_model(model="gpt-4.1"),
        tools=[tinptool.extract_head, tinptool.extract_metadata],
        system_prompt=instructions,
        response_format=LLMOutput,
    )
    request = lmessages.HumanMessage(
        content=f"The dataset is in {state['path']}"
    )
    response = classifier.invoke({"messages": [request]})
    return response["structured_response"].model_dump()
def run_input_handler(path: str | pathlib.Path) -> dict:
    """
    Build and run the input-check graph for one dataset.

    Flow: header analysis first; when `has_header` reports True the column
    classification agent runs, otherwise the error node runs. Both branches
    end the graph.

    :param path: path to dataset
    :return: final graph output
    """
    workflow = lgraph.StateGraph(InputState)
    node_table = (
        ("header_analysis", tinptool.analyze_header),
        ("header_classification_agent", header_classification_agent),
        ("error", error_node),
    )
    for node_name, node_fn in node_table:
        workflow.add_node(node_name, node_fn)
    workflow.add_edge(lgraph.START, "header_analysis")
    routes = {True: "header_classification_agent", False: "error"}
    workflow.add_conditional_edges("header_analysis", has_header, routes)
    for terminal_node in ("error", "header_classification_agent"):
        workflow.add_edge(terminal_node, lgraph.END)
    runner = workflow.compile()
    # Every state key gets a neutral starting value; the path is stored as text.
    start_state: InputState = {
        "path": str(path),
        "done": [],
        "has_header": True,
        "has_missing_values": False,
        "error": "",
        "info": "",
        "cols": [],
        "temporal_cols": [],
        "numeric_val_cols": [],
        "categorical_val_cols": [],
    }
    final_state = runner.invoke(start_state)
    _LOG.info("Input handler output: %s", final_state)
    return final_state
def call_date_formatter(state: IntegrityState) -> dict:
    """
    Select the time column and its datetime formatter.

    When the caller already supplied `time_col`, that column is kept and only
    the formatting agent runs for it. Previously the formatter graph always
    ran its own column discovery and overwrote the caller's choice, so the
    `time_col` override had no effect. Without an override the full datetime
    formatter graph runs and also picks the time column.

    :param state: integrity graph state
    :return: selected time column and formatter
    """
    time_col = state.get("time_col")
    if time_col:
        # Honor the explicit override: skip column discovery and find a
        # formatter for exactly this column.
        formatter_state = {"path": state["path"], "time_col": time_col}
        agent_out = sfordat.run_formatting_agent(
            formatter_state  # type: ignore[arg-type]
        )
        return {
            "time_col": time_col,
            "winner_formatter": agent_out["winner_formatter"],
        }
    out: sfordat.DateFormatterState = sfordat.graph.invoke(  # type: ignore
        {"path": state["path"]}
    )
    payload = {
        "time_col": out["time_col"],
        "winner_formatter": out["winner_formatter"],
    }
    return payload
int(timestamp.isna().sum()) + summary["min_time"] = ( + None if timestamp.dropna().empty else str(timestamp.dropna().min()) + ) + summary["max_time"] = ( + None if timestamp.dropna().empty else str(timestamp.dropna().max()) + ) + duplicate_timestamps = int(timestamp.dropna().duplicated().sum()) + summary["duplicate_timestamps"] = duplicate_timestamps + if duplicate_timestamps > 0: + issues.append( + {"type": "duplicate_timestamps", "count": duplicate_timestamps} + ) + entity_col = state.get("entity_col") + if entity_col is not None and entity_col in dataset.columns: + summary["n_entities"] = int(dataset[entity_col].nunique(dropna=True)) + tmp = dataset[[entity_col]].copy() + tmp["_ts"] = timestamp + duplicate_pairs = int( + tmp.dropna(subset=[entity_col, "_ts"]) + .duplicated(subset=[entity_col, "_ts"]) + .sum() + ) + summary["duplicate_entity_timestamp_pairs"] = duplicate_pairs + if duplicate_pairs > 0: + issues.append( + { + "type": "duplicate_entity_timestamp_pairs", + "count": duplicate_pairs, + } + ) + else: + summary["duplicate_entity_timestamp_pairs"] = None + numeric_cols = [col for col in state.get("numeric_cols") or []] + numeric_cols = [col for col in numeric_cols if col in dataset.columns] + nonnegative_cols = [col for col in state.get("nonnegative_cols") or []] + negative_report: dict = {} + for col in nonnegative_cols: + if col not in dataset.columns: + continue + series = pd.to_numeric(dataset[col], errors="coerce") + n_negative = int((series < 0).sum(skipna=True)) + if n_negative > 0: + negative_report[col] = n_negative + summary["negatives_in_nonnegative_cols"] = negative_report + if negative_report: + issues.append({"type": "negative_values", "details": negative_report}) + jump_mult = float(state.get("jump_mult") or 20.0) + jumps: dict = {} + if numeric_cols: + selected_cols = [time_col] + if entity_col is not None and entity_col in dataset.columns: + selected_cols.append(entity_col) + selected_cols.extend(numeric_cols) + tmp = 
def integrity_llm_summary(state: IntegrityState) -> dict:
    """
    Ask the LLM judge to condense the integrity report into a verdict.

    The agent has no tools and must answer in the `IntegrityJudgeOutput`
    schema.

    :param state: integrity graph state; its `report` entry is sent to the
        model
    :return: dict with the judge's `summary` text and `flag` decision
    """
    judge_prompt = (
        "You are an integrity judge. Decide if the dataset can proceed. "
        "Return JSON with keys summary and flag. Set flag to yes only when "
        "there are no meaningful integrity issues."
    )
    judge = lagents.create_agent(
        model=cconf.get_chat_model(model="gpt-4.1"),
        tools=[],
        system_prompt=judge_prompt,
        response_format=IntegrityJudgeOutput,
    )
    report_message = lmessages.HumanMessage(
        content=f"Here is the integrity report: {state['report']}"
    )
    response = judge.invoke({"messages": [report_message]})
    verdict = response["structured_response"].model_dump()
    return {field: verdict[field] for field in ("summary", "flag")}
def _parse_args() -> argparse.Namespace:
    """
    Build the backend command-line parser and parse the process arguments.

    `--mode` and `--path` are mandatory; `--time_col` and `--entity_col`
    default to `None` when omitted.

    :return: namespace with `mode`, `path`, `time_col` and `entity_col`
    """
    cli = argparse.ArgumentParser()
    # One (flag, keyword options) pair per CLI option, registered in order.
    option_specs = (
        (
            "--mode",
            {
                "required": True,
                "choices": ["input", "format", "integrity"],
                "help": "Pipeline stage to execute.",
            },
        ),
        (
            "--path",
            {
                "required": True,
                "help": "Path to dataset file.",
            },
        ),
        (
            "--time_col",
            {
                "default": None,
                "help": "Optional time column override for integrity mode.",
            },
        ),
        (
            "--entity_col",
            {
                "default": None,
                "help": "Optional entity column for integrity mode.",
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def analyze_header(state: dict) -> dict:
    """
    Validate dataset headers.

    The dataset at `state["path"]` is loaded with `load_dataset()` and its
    column labels are checked in order; the first problem found is reported.

    A header is rejected when:
    - the labels are exactly the integers `0..n-1`, i.e., no names at all
    - a label is `None`, blank, or a pandas placeholder such as `Unnamed: 0`
    - a label does not start with a letter or an underscore

    :param state: graph state containing dataset path
    :return: `{"has_header": True, "dataset": <dataframe>}` when the header is
        valid, otherwise `{"has_header": False, "error": <message>}`
    """
    path = pathlib.Path(str(state["path"]))
    dataset = load_dataset(path)
    cols = list(dataset.columns)
    error = ""
    if all(isinstance(col, int) for col in cols) and cols == list(
        range(len(cols))
    ):
        error = "No column names."
    else:
        for col in cols:
            col_name = "" if col is None else str(col).strip()
            # `pd.read_csv()` never yields an empty label: a blank header cell
            # comes back as `Unnamed: <position>`. Treat that placeholder as a
            # missing name, otherwise this check cannot fire for CSV input.
            if col_name == "" or re.fullmatch(r"Unnamed: \d+", col_name):
                error = "One or more column names missing."
                break
            # A leading digit is rejected by `_VALID_HEADER_START_RE` as well,
            # so no separate digit check is needed.
            if not _VALID_HEADER_START_RE.match(col_name):
                error = (
                    "One or more column names start with invalid characters."
                )
                break
    if error:
        result = {"has_header": False, "error": error}
    else:
        result = {"has_header": True, "dataset": dataset}
    return result