From 4667e7d66ff7e430a27073914886c563009e7a36 Mon Sep 17 00:00:00 2001 From: Mikhail Salnikov <2613180+MihailSalnikov@users.noreply.github.com> Date: Sat, 29 Nov 2025 17:02:37 +0300 Subject: [PATCH 01/10] dded seq2seq support for new fixed versions of transformers; Fixed sequence experiments for reqranking --- Dockerfile | 31 +- experiments/subgraphs_reranking/ranking.ipynb | 1110 +++++++++++------ .../subgraphs_reranking/ranking_data_utils.py | 14 +- .../sequence/train_sequence_ranker.py | 33 +- kbqa/seq2seq/train.py | 52 +- kbqa/seq2seq/utils.py | 57 + mintaka_evaluate.py | 28 +- mkqa_test.json | 42 + mkqa_train.json | 370 ++++++ requirements.txt | 79 +- seq2seq.py | 40 +- 11 files changed, 1378 insertions(+), 478 deletions(-) create mode 100644 mkqa_test.json create mode 100644 mkqa_train.json diff --git a/Dockerfile b/Dockerfile index 5a317b8..a3d441a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,25 +1,18 @@ -FROM huggingface/transformers-pytorch-gpu:4.29.2 +FROM pytorch/pytorch:2.7.1-cuda11.8-cudnn9-runtime -RUN apt update && \ - apt install -y git htop g++ && \ - update-alternatives --install /usr/bin/gcc gcc /usr/bin/g++ 10 +RUN apt-get update && \ + apt-get install -y git htop g++ build-essential && \ + rm -rf /var/lib/apt/lists/* -COPY ./requirements.txt / -RUN pip3 install --upgrade pip && \ - pip3 install -r /requirements.txt - -RUN git clone --branch fixing_prefix_allowed_tokens_fn https://github.com/MihailSalnikov/fairseq && \ - cd /fairseq && \ - pip3 install --editable ./ && \ - cd / && \ - echo "export PYTHONPATH=/fairseq/" >> ~/.bashrc - -RUN git clone https://github.com/facebookresearch/KILT.git && \ - pip3 install ./KILT +ENV PYTHONUNBUFFERED=1 +ENV PIP_DISABLE_PIP_VERSION_CHECK=1 -RUN git clone https://github.com/MihailSalnikov/GENRE.git && \ - pip3 install ./GENRE +COPY ./requirements.txt / +RUN pip install --upgrade pip && \ + pip install -r /requirements.txt COPY ./ /workspace/kbqa -RUN pip3 install -e /workspace/kbqa +RUN pip install -e 
/workspace/kbqa + +WORKDIR /workspace/kbqa diff --git a/experiments/subgraphs_reranking/ranking.ipynb b/experiments/subgraphs_reranking/ranking.ipynb index d1f0265..2521a27 100644 --- a/experiments/subgraphs_reranking/ranking.ipynb +++ b/experiments/subgraphs_reranking/ranking.ipynb @@ -14,20 +14,7 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2024-05-06 09:42:09.512299: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", - "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "2024-05-06 09:42:09.778598: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "2024-05-06 09:42:10.536839: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", - "2024-05-06 09:42:10.536975: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64\n", - "2024-05-06 09:42:10.536984: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. 
If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" - ] - } - ], + "outputs": [], "source": [ "import pandas as pd\n", "from datasets import load_dataset, Dataset\n", @@ -55,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -70,24 +57,22 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/usr/local/lib/python3.8/dist-packages/datasets/load.py:1486: FutureWarning: The repository for AmazonScience/mintaka contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/AmazonScience/mintaka\n", - "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n", - "Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n", - " warnings.warn(\n" + "Using the latest cached version of the dataset since AmazonScience/mintaka couldn't be found on the Hugging Face Hub\n", + "Found the latest cached dataset configuration 'default' at /home/jovyan/.cache/huggingface/datasets/AmazonScience___mintaka/default/0.0.0/fe3f1235e31b01dc9cce913086f0cb6ed0d9b82e (last modified on Fri Oct 24 13:02:21 2025).\n" ] } ], "source": [ "# os.environ['HF_DATASETS_CACHE'] = '/workspace/storage/misc/huggingface'\n", "\n", - "ds_type = \"t5xlssm\" # 't5largessm', 't5xlssm', 'mistral' or 'mixtral'\n", + "ds_type = \"t5largessm\" # 't5largessm', 't5xlssm', 'mistral' or 'mixtral'\n", "\n", "hf_cache_dir = \"/workspace/storage/misc/huggingface\"\n", "mintaka_dataset_path = \"AmazonScience/mintaka\"\n", @@ -105,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -131,6 +116,7 @@ " \n", " target\n", " 
target_out_of_vocab\n", + " id_y\n", " answerEntity\n", " questionEntity\n", " groundTruthAnswerEntity\n", @@ -138,7 +124,6 @@ " graph\n", " correct\n", " t5_sequence\n", - " gap_sequence\n", " ...\n", " gap_sequence_embedding\n", " t5_sequence_embedding\n", @@ -155,8 +140,8 @@ " \n", " \n", " count\n", - " 4000.000000\n", - " 4000.000000\n", + " 4000.00000\n", + " 4000.00000\n", " 4000.000000\n", " 4000.000000\n", " 4000.000000\n", @@ -175,60 +160,60 @@ " 4000.000000\n", " 4000.000000\n", " 4000.000000\n", - " 4000.000000\n", + " 4000.00000\n", " \n", " \n", " mean\n", - " 4.286750\n", - " 4.286750\n", - " 3.894250\n", - " 3.894250\n", - " 3.894250\n", - " 3.894250\n", - " 3.894250\n", - " 3.894250\n", - " 3.894250\n", - " 3.894250\n", + " 9.06900\n", + " 9.06900\n", + " 8.283500\n", + " 8.283500\n", + " 8.283500\n", + " 8.283500\n", + " 8.283500\n", + " 8.283500\n", + " 8.283500\n", + " 8.283500\n", " ...\n", - " 3.894250\n", - " 3.894250\n", - " 3.894250\n", - " 3.894250\n", - " 3.894250\n", - " 3.894250\n", - " 3.894250\n", - " 3.894250\n", - " 3.894250\n", - " 4.286750\n", + " 8.283500\n", + " 8.283500\n", + " 8.283500\n", + " 8.283500\n", + " 8.283500\n", + " 8.283500\n", + " 8.283500\n", + " 8.283500\n", + " 8.283500\n", + " 9.06900\n", " \n", " \n", " std\n", - " 3.132487\n", - " 3.132487\n", - " 3.554114\n", - " 3.554114\n", - " 3.554114\n", - " 3.554114\n", - " 3.554114\n", - " 3.554114\n", - " 3.554114\n", - " 3.554114\n", + " 6.60001\n", + " 6.60001\n", + " 7.458052\n", + " 7.458052\n", + " 7.458052\n", + " 7.458052\n", + " 7.458052\n", + " 7.458052\n", + " 7.458052\n", + " 7.458052\n", " ...\n", - " 3.554114\n", - " 3.554114\n", - " 3.554114\n", - " 3.554114\n", - " 3.554114\n", - " 3.554114\n", - " 3.554114\n", - " 3.554114\n", - " 3.554114\n", - " 3.132487\n", + " 7.458052\n", + " 7.458052\n", + " 7.458052\n", + " 7.458052\n", + " 7.458052\n", + " 7.458052\n", + " 7.458052\n", + " 7.458052\n", + " 7.458052\n", + " 6.60001\n", " \n", " \n", " min\n", 
- " 1.000000\n", - " 1.000000\n", + " 2.00000\n", + " 2.00000\n", " 0.000000\n", " 0.000000\n", " 0.000000\n", @@ -247,12 +232,12 @@ " 0.000000\n", " 0.000000\n", " 0.000000\n", - " 1.000000\n", + " 2.00000\n", " \n", " \n", " 25%\n", - " 1.000000\n", - " 1.000000\n", + " 2.00000\n", + " 2.00000\n", " 0.000000\n", " 0.000000\n", " 0.000000\n", @@ -271,160 +256,160 @@ " 0.000000\n", " 0.000000\n", " 0.000000\n", - " 1.000000\n", + " 2.00000\n", " \n", " \n", " 50%\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", + " 10.00000\n", + " 10.00000\n", + " 10.000000\n", + " 10.000000\n", + " 10.000000\n", + " 10.000000\n", + " 10.000000\n", + " 10.000000\n", + " 10.000000\n", + " 10.000000\n", " ...\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", - " 4.000000\n", + " 10.000000\n", + " 10.000000\n", + " 10.000000\n", + " 10.000000\n", + " 10.000000\n", + " 10.000000\n", + " 10.000000\n", + " 10.000000\n", + " 10.000000\n", + " 10.00000\n", " \n", " \n", " 75%\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", + " 14.00000\n", + " 14.00000\n", + " 14.000000\n", + " 14.000000\n", + " 14.000000\n", + " 14.000000\n", + " 14.000000\n", + " 14.000000\n", + " 14.000000\n", + " 14.000000\n", " ...\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", - " 7.000000\n", + " 14.000000\n", + " 14.000000\n", + " 14.000000\n", + " 14.000000\n", + " 14.000000\n", + " 14.000000\n", + " 14.000000\n", + " 14.000000\n", + " 14.000000\n", + " 14.00000\n", " \n", " \n", " max\n", - " 19.000000\n", - " 19.000000\n", - " 19.000000\n", - " 
19.000000\n", - " 19.000000\n", - " 19.000000\n", - " 19.000000\n", - " 19.000000\n", - " 19.000000\n", - " 19.000000\n", + " 38.00000\n", + " 38.00000\n", + " 38.000000\n", + " 38.000000\n", + " 38.000000\n", + " 38.000000\n", + " 38.000000\n", + " 38.000000\n", + " 38.000000\n", + " 38.000000\n", " ...\n", - " 19.000000\n", - " 19.000000\n", - " 19.000000\n", - " 19.000000\n", - " 19.000000\n", - " 19.000000\n", - " 19.000000\n", - " 19.000000\n", - " 19.000000\n", - " 19.000000\n", + " 38.000000\n", + " 38.000000\n", + " 38.000000\n", + " 38.000000\n", + " 38.000000\n", + " 38.000000\n", + " 38.000000\n", + " 38.000000\n", + " 38.000000\n", + " 38.00000\n", " \n", " \n", "\n", - "

8 rows × 31 columns

\n", + "

8 rows × 32 columns

\n", "" ], "text/plain": [ - " target target_out_of_vocab answerEntity questionEntity \\\n", - "count 4000.000000 4000.000000 4000.000000 4000.000000 \n", - "mean 4.286750 4.286750 3.894250 3.894250 \n", - "std 3.132487 3.132487 3.554114 3.554114 \n", - "min 1.000000 1.000000 0.000000 0.000000 \n", - "25% 1.000000 1.000000 0.000000 0.000000 \n", - "50% 4.000000 4.000000 4.000000 4.000000 \n", - "75% 7.000000 7.000000 7.000000 7.000000 \n", - "max 19.000000 19.000000 19.000000 19.000000 \n", + " target target_out_of_vocab id_y answerEntity \\\n", + "count 4000.00000 4000.00000 4000.000000 4000.000000 \n", + "mean 9.06900 9.06900 8.283500 8.283500 \n", + "std 6.60001 6.60001 7.458052 7.458052 \n", + "min 2.00000 2.00000 0.000000 0.000000 \n", + "25% 2.00000 2.00000 0.000000 0.000000 \n", + "50% 10.00000 10.00000 10.000000 10.000000 \n", + "75% 14.00000 14.00000 14.000000 14.000000 \n", + "max 38.00000 38.00000 38.000000 38.000000 \n", "\n", - " groundTruthAnswerEntity complexityType graph correct \\\n", - "count 4000.000000 4000.000000 4000.000000 4000.000000 \n", - "mean 3.894250 3.894250 3.894250 3.894250 \n", - "std 3.554114 3.554114 3.554114 3.554114 \n", - "min 0.000000 0.000000 0.000000 0.000000 \n", - "25% 0.000000 0.000000 0.000000 0.000000 \n", - "50% 4.000000 4.000000 4.000000 4.000000 \n", - "75% 7.000000 7.000000 7.000000 7.000000 \n", - "max 19.000000 19.000000 19.000000 19.000000 \n", + " questionEntity groundTruthAnswerEntity complexityType graph \\\n", + "count 4000.000000 4000.000000 4000.000000 4000.000000 \n", + "mean 8.283500 8.283500 8.283500 8.283500 \n", + "std 7.458052 7.458052 7.458052 7.458052 \n", + "min 0.000000 0.000000 0.000000 0.000000 \n", + "25% 0.000000 0.000000 0.000000 0.000000 \n", + "50% 10.000000 10.000000 10.000000 10.000000 \n", + "75% 14.000000 14.000000 14.000000 14.000000 \n", + "max 38.000000 38.000000 38.000000 38.000000 \n", "\n", - " t5_sequence gap_sequence ... 
gap_sequence_embedding \\\n", - "count 4000.000000 4000.000000 ... 4000.000000 \n", - "mean 3.894250 3.894250 ... 3.894250 \n", - "std 3.554114 3.554114 ... 3.554114 \n", - "min 0.000000 0.000000 ... 0.000000 \n", - "25% 0.000000 0.000000 ... 0.000000 \n", - "50% 4.000000 4.000000 ... 4.000000 \n", - "75% 7.000000 7.000000 ... 7.000000 \n", - "max 19.000000 19.000000 ... 19.000000 \n", + " correct t5_sequence ... gap_sequence_embedding \\\n", + "count 4000.000000 4000.000000 ... 4000.000000 \n", + "mean 8.283500 8.283500 ... 8.283500 \n", + "std 7.458052 7.458052 ... 7.458052 \n", + "min 0.000000 0.000000 ... 0.000000 \n", + "25% 0.000000 0.000000 ... 0.000000 \n", + "50% 10.000000 10.000000 ... 10.000000 \n", + "75% 14.000000 14.000000 ... 14.000000 \n", + "max 38.000000 38.000000 ... 38.000000 \n", "\n", " t5_sequence_embedding question_answer_embedding \\\n", "count 4000.000000 4000.000000 \n", - "mean 3.894250 3.894250 \n", - "std 3.554114 3.554114 \n", + "mean 8.283500 8.283500 \n", + "std 7.458052 7.458052 \n", "min 0.000000 0.000000 \n", "25% 0.000000 0.000000 \n", - "50% 4.000000 4.000000 \n", - "75% 7.000000 7.000000 \n", - "max 19.000000 19.000000 \n", + "50% 10.000000 10.000000 \n", + "75% 14.000000 14.000000 \n", + "max 38.000000 38.000000 \n", "\n", " highlighted_determ_sequence no_highlighted_determ_sequence \\\n", "count 4000.000000 4000.000000 \n", - "mean 3.894250 3.894250 \n", - "std 3.554114 3.554114 \n", + "mean 8.283500 8.283500 \n", + "std 7.458052 7.458052 \n", "min 0.000000 0.000000 \n", "25% 0.000000 0.000000 \n", - "50% 4.000000 4.000000 \n", - "75% 7.000000 7.000000 \n", - "max 19.000000 19.000000 \n", + "50% 10.000000 10.000000 \n", + "75% 14.000000 14.000000 \n", + "max 38.000000 38.000000 \n", "\n", " highlighted_t5_sequence no_highlighted_t5_sequence \\\n", "count 4000.000000 4000.000000 \n", - "mean 3.894250 3.894250 \n", - "std 3.554114 3.554114 \n", + "mean 8.283500 8.283500 \n", + "std 7.458052 7.458052 \n", "min 0.000000 0.000000 
\n", "25% 0.000000 0.000000 \n", - "50% 4.000000 4.000000 \n", - "75% 7.000000 7.000000 \n", - "max 19.000000 19.000000 \n", + "50% 10.000000 10.000000 \n", + "75% 14.000000 14.000000 \n", + "max 38.000000 38.000000 \n", "\n", " highlighted_gap_sequence no_highlighted_gap_sequence model_answers \n", - "count 4000.000000 4000.000000 4000.000000 \n", - "mean 3.894250 3.894250 4.286750 \n", - "std 3.554114 3.554114 3.132487 \n", - "min 0.000000 0.000000 1.000000 \n", - "25% 0.000000 0.000000 1.000000 \n", - "50% 4.000000 4.000000 4.000000 \n", - "75% 7.000000 7.000000 7.000000 \n", - "max 19.000000 19.000000 19.000000 \n", + "count 4000.000000 4000.000000 4000.00000 \n", + "mean 8.283500 8.283500 9.06900 \n", + "std 7.458052 7.458052 6.60001 \n", + "min 0.000000 0.000000 2.00000 \n", + "25% 0.000000 0.000000 2.00000 \n", + "50% 10.000000 10.000000 10.00000 \n", + "75% 14.000000 14.000000 14.00000 \n", + "max 38.000000 38.000000 38.00000 \n", "\n", - "[8 rows x 31 columns]" + "[8 rows x 32 columns]" ] }, - "execution_count": 13, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -440,7 +425,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -464,12 +449,12 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "results_path = Path(\n", - " f\"/workspace/storage/misc/subgraphs_reranking_runs/reranking_model_results/{ds_type}/\"\n", + " f\"/home/jovyan/kbqa/reranking_model_results/{ds_type}/\"\n", ")\n", "results_path.mkdir(parents=True, exist_ok=True)\n", "\n", @@ -484,7 +469,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -503,142 +488,419 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 10, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 1 
logreg_text_reranking_seq2seq completed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " train_df[self.graph_features] = scaler.fit_transform(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 1 logreg_graph_reranking_seq2seq completed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " train_df[self.graph_features] = scaler.fit_transform(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 1 logreg_text_graph_reranking_seq2seq completed\n", + "RUN 1 logreg_g2t_determ_reranking_seq2seq completed\n", + "RUN 1 logreg_g2t_t5_reranking_seq2seq completed\n", + "RUN 1 logreg_g2t_gap_reranking_seq2seq completed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " 
train_df[self.graph_features] = scaler.fit_transform(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 1 logreg_text_g2t_determ_graph_reranking_seq2seq completed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " train_df[self.graph_features] = scaler.fit_transform(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 1 logreg_text_g2t_t5_graph_reranking_seq2seq completed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " train_df[self.graph_features] = scaler.fit_transform(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 1 logreg_text_g2t_gap_graph_reranking_seq2seq completed\n", + "RUN 2 logreg_text_reranking_seq2seq completed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: 
https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " train_df[self.graph_features] = scaler.fit_transform(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 2 logreg_graph_reranking_seq2seq completed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " train_df[self.graph_features] = scaler.fit_transform(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 2 logreg_text_graph_reranking_seq2seq completed\n", + "RUN 2 logreg_g2t_determ_reranking_seq2seq completed\n", + "RUN 2 logreg_g2t_t5_reranking_seq2seq completed\n", + "RUN 2 logreg_g2t_gap_reranking_seq2seq completed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " train_df[self.graph_features] = scaler.fit_transform(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 2 logreg_text_g2t_determ_graph_reranking_seq2seq completed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", + "A value is trying to be set on a 
copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " train_df[self.graph_features] = scaler.fit_transform(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 2 logreg_text_g2t_t5_graph_reranking_seq2seq completed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " train_df[self.graph_features] = scaler.fit_transform(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 2 logreg_text_g2t_gap_graph_reranking_seq2seq completed\n", + "RUN 3 logreg_text_reranking_seq2seq completed\n" + ] + }, { "name": "stderr", "output_type": "stream", "text": [ - "/workspace/kbqa/experiments/subgraphs_reranking/ranking_model.py:169: SettingWithCopyWarning: \n", + "/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " train_df[self.graph_features] = scaler.fit_transform(\n", - "/workspace/kbqa/experiments/subgraphs_reranking/ranking_model.py:169: SettingWithCopyWarning: \n", + " train_df[self.graph_features] = scaler.fit_transform(\n" + ] + }, + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "RUN 3 logreg_graph_reranking_seq2seq completed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " train_df[self.graph_features] = scaler.fit_transform(\n", - "/workspace/kbqa/experiments/subgraphs_reranking/ranking_model.py:169: SettingWithCopyWarning: \n", + " train_df[self.graph_features] = scaler.fit_transform(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 3 logreg_text_graph_reranking_seq2seq completed\n", + "RUN 3 logreg_g2t_determ_reranking_seq2seq completed\n", + "RUN 3 logreg_g2t_t5_reranking_seq2seq completed\n", + "RUN 3 logreg_g2t_gap_reranking_seq2seq completed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " train_df[self.graph_features] = scaler.fit_transform(\n", - "/workspace/kbqa/experiments/subgraphs_reranking/ranking_model.py:169: SettingWithCopyWarning: \n", + " train_df[self.graph_features] = scaler.fit_transform(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 3 logreg_text_g2t_determ_graph_reranking_seq2seq completed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " train_df[self.graph_features] = scaler.fit_transform(\n", - "/workspace/kbqa/experiments/subgraphs_reranking/ranking_model.py:169: SettingWithCopyWarning: \n", + " train_df[self.graph_features] = scaler.fit_transform(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 3 logreg_text_g2t_t5_graph_reranking_seq2seq completed\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/jovyan/kbqa/experiments/subgraphs_reranking/ranking_model.py:174: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " train_df[self.graph_features] = scaler.fit_transform(\n" ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RUN 3 logreg_text_g2t_gap_graph_reranking_seq2seq completed\n" + ] } ], "source": [ - "logreg_ranker = LogisticRegressionRanker(sequence_features=features_map[\"text\"])\n", - "logreg_ranker.fit(train_df, n_jobs=8)\n", - "with open(\n", - " results_path / f\"logreg_text_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", + "for run in [1, 2, 3]:\n", + " random.seed(42 + run)\n", + " np.random.seed(42 + run)\n", + " torch.manual_seed(42 + run)\n", + " torch.cuda.manual_seed_all(42 + run)\n", + " set_seed(42 + run)\n", + " \n", + " 
logreg_ranker = LogisticRegressionRanker(sequence_features=features_map[\"text\"])\n", + " logreg_ranker.fit(train_df, n_jobs=16)\n", + " with open(\n", + " results_path / f\"logreg_text_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} logreg_text_reranking_seq2seq completed\")\n", "\n", - "scaler = preprocessing.MinMaxScaler()\n", - "logreg_ranker = LogisticRegressionRanker(graph_features=features_map[\"graph\"])\n", - "logreg_ranker.fit(train_df, n_jobs=8, scaler=scaler)\n", - "with open(\n", - " results_path / f\"logreg_graph_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", + " scaler = preprocessing.MinMaxScaler()\n", + " logreg_ranker = LogisticRegressionRanker(graph_features=features_map[\"graph\"])\n", + " logreg_ranker.fit(train_df, n_jobs=16, scaler=scaler)\n", + " with open(\n", + " results_path / f\"logreg_graph_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} logreg_graph_reranking_seq2seq completed\")\n", "\n", + " logreg_ranker = LogisticRegressionRanker(\n", + " sequence_features=features_map[\"text\"], graph_features=features_map[\"graph\"]\n", + " )\n", + " logreg_ranker.fit(train_df, n_jobs=16, scaler=scaler)\n", + " with open(\n", + " results_path / f\"logreg_text_graph_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} logreg_text_graph_reranking_seq2seq completed\")\n", "\n", - "logreg_ranker = LogisticRegressionRanker(\n", - " sequence_features=features_map[\"text\"], 
graph_features=features_map[\"graph\"]\n", - ")\n", - "logreg_ranker.fit(train_df, n_jobs=8, scaler=scaler)\n", - "with open(\n", - " results_path / f\"logreg_text_graph_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", - "\n", - "\n", - "logreg_ranker = LogisticRegressionRanker(sequence_features=features_map[\"g2t_determ\"])\n", - "logreg_ranker.fit(train_df, n_jobs=8)\n", - "with open(\n", - " results_path / f\"logreg_g2t_determ_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", + " logreg_ranker = LogisticRegressionRanker(sequence_features=features_map[\"g2t_determ\"])\n", + " logreg_ranker.fit(train_df, n_jobs=16)\n", + " with open(\n", + " results_path / f\"logreg_g2t_determ_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} logreg_g2t_determ_reranking_seq2seq completed\")\n", "\n", + " logreg_ranker = LogisticRegressionRanker(sequence_features=features_map[\"g2t_t5\"])\n", + " logreg_ranker.fit(train_df, n_jobs=16)\n", + " with open(\n", + " results_path / f\"logreg_g2t_t5_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} logreg_g2t_t5_reranking_seq2seq completed\")\n", "\n", - "logreg_ranker = LogisticRegressionRanker(sequence_features=features_map[\"g2t_t5\"])\n", - "logreg_ranker.fit(train_df, n_jobs=8)\n", - "with open(\n", - " results_path / f\"logreg_g2t_t5_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + 
\"\\n\")\n", - "\n", - "\n", - "logreg_ranker = LogisticRegressionRanker(sequence_features=features_map[\"g2t_gap\"])\n", - "logreg_ranker.fit(train_df, n_jobs=8)\n", - "with open(\n", - " results_path / f\"logreg_g2t_gap_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", - "\n", - "\n", - "logreg_ranker = LogisticRegressionRanker(\n", - " sequence_features=features_map[\"text\"] + features_map[\"g2t_determ\"],\n", - " graph_features=features_map[\"graph\"],\n", - ")\n", - "logreg_ranker.fit(train_df, n_jobs=8, scaler=scaler)\n", - "with open(\n", - " results_path\n", - " / f\"logreg_text_g2t_determ_graph_reranking_seq2seq_{ds_type}_results.jsonl\",\n", - " \"w\",\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", + " logreg_ranker = LogisticRegressionRanker(sequence_features=features_map[\"g2t_gap\"])\n", + " logreg_ranker.fit(train_df, n_jobs=16)\n", + " with open(\n", + " results_path / f\"logreg_g2t_gap_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} logreg_g2t_gap_reranking_seq2seq completed\")\n", "\n", + " logreg_ranker = LogisticRegressionRanker(\n", + " sequence_features=features_map[\"text\"] + features_map[\"g2t_determ\"],\n", + " graph_features=features_map[\"graph\"],\n", + " )\n", + " logreg_ranker.fit(train_df, n_jobs=16, scaler=scaler)\n", + " with open(\n", + " results_path\n", + " / f\"logreg_text_g2t_determ_graph_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\",\n", + " \"w\",\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} logreg_text_g2t_determ_graph_reranking_seq2seq completed\")\n", "\n", - "logreg_ranker = 
LogisticRegressionRanker(\n", - " sequence_features=features_map[\"text\"] + features_map[\"g2t_t5\"],\n", - " graph_features=features_map[\"graph\"],\n", - ")\n", - "logreg_ranker.fit(train_df, n_jobs=8, scaler=scaler)\n", - "with open(\n", - " results_path\n", - " / f\"logreg_text_g2t_t5_graph_reranking_seq2seq_{ds_type}_results.jsonl\",\n", - " \"w\",\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", + " logreg_ranker = LogisticRegressionRanker(\n", + " sequence_features=features_map[\"text\"] + features_map[\"g2t_t5\"],\n", + " graph_features=features_map[\"graph\"],\n", + " )\n", + " logreg_ranker.fit(train_df, n_jobs=16, scaler=scaler)\n", + " with open(\n", + " results_path\n", + " / f\"logreg_text_g2t_t5_graph_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\",\n", + " \"w\",\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} logreg_text_g2t_t5_graph_reranking_seq2seq completed\")\n", "\n", - "logreg_ranker = LogisticRegressionRanker(\n", - " sequence_features=features_map[\"text\"] + features_map[\"g2t_gap\"],\n", - " graph_features=features_map[\"graph\"],\n", - ")\n", - "logreg_ranker.fit(train_df, n_jobs=8, scaler=scaler)\n", - "with open(\n", - " results_path\n", - " / f\"logreg_text_g2t_gap_graph_reranking_seq2seq_{ds_type}_results.jsonl\",\n", - " \"w\",\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")" + " logreg_ranker = LogisticRegressionRanker(\n", + " sequence_features=features_map[\"text\"] + features_map[\"g2t_gap\"],\n", + " graph_features=features_map[\"graph\"],\n", + " )\n", + " logreg_ranker.fit(train_df, n_jobs=16, scaler=scaler)\n", + " with open(\n", + " results_path\n", + " / f\"logreg_text_g2t_gap_graph_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\",\n", + " \"w\",\n", + " ) as f:\n", + " for result in 
logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} logreg_text_g2t_gap_graph_reranking_seq2seq completed\")" ] }, { @@ -650,106 +912,113 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ - "logreg_ranker = LinearRegressionRanker(sequence_features=features_map[\"text\"])\n", - "logreg_ranker.fit(train_df, n_jobs=8)\n", - "with open(\n", - " results_path / f\"linreg_text_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", + "for run in [1, 2, 3]:\n", + " random.seed(42 + run)\n", + " np.random.seed(42 + run)\n", + " torch.manual_seed(42 + run)\n", + " torch.cuda.manual_seed_all(42 + run)\n", + " set_seed(42 + run)\n", + " \n", + " logreg_ranker = LinearRegressionRanker(sequence_features=features_map[\"text\"])\n", + " logreg_ranker.fit(train_df, n_jobs=8)\n", + " with open(\n", + " results_path / f\"linreg_text_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", "\n", "\n", - "logreg_ranker = LinearRegressionRanker(graph_features=features_map[\"graph\"])\n", - "logreg_ranker.fit(train_df, n_jobs=8)\n", - "with open(\n", - " results_path / f\"linreg_graph_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", + " logreg_ranker = LinearRegressionRanker(graph_features=features_map[\"graph\"])\n", + " logreg_ranker.fit(train_df, n_jobs=8)\n", + " with open(\n", + " results_path / f\"linreg_graph_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", "\n", "\n", - 
"logreg_ranker = LinearRegressionRanker(\n", - " sequence_features=features_map[\"text\"], graph_features=features_map[\"graph\"]\n", - ")\n", - "logreg_ranker.fit(train_df, n_jobs=8)\n", - "with open(\n", - " results_path / f\"linreg_text_graph_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", + " logreg_ranker = LinearRegressionRanker(\n", + " sequence_features=features_map[\"text\"], graph_features=features_map[\"graph\"]\n", + " )\n", + " logreg_ranker.fit(train_df, n_jobs=8)\n", + " with open(\n", + " results_path / f\"linreg_text_graph_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", "\n", "\n", - "logreg_ranker = LinearRegressionRanker(sequence_features=features_map[\"g2t_determ\"])\n", - "logreg_ranker.fit(train_df, n_jobs=8)\n", - "with open(\n", - " results_path / f\"linreg_g2t_determ_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", + " logreg_ranker = LinearRegressionRanker(sequence_features=features_map[\"g2t_determ\"])\n", + " logreg_ranker.fit(train_df, n_jobs=8)\n", + " with open(\n", + " results_path / f\"linreg_g2t_determ_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", "\n", "\n", - "logreg_ranker = LinearRegressionRanker(sequence_features=features_map[\"g2t_t5\"])\n", - "logreg_ranker.fit(train_df, n_jobs=8)\n", - "with open(\n", - " results_path / f\"linreg_g2t_t5_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", + " logreg_ranker = 
LinearRegressionRanker(sequence_features=features_map[\"g2t_t5\"])\n", + " logreg_ranker.fit(train_df, n_jobs=8)\n", + " with open(\n", + " results_path / f\"linreg_g2t_t5_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", "\n", "\n", - "logreg_ranker = LinearRegressionRanker(sequence_features=features_map[\"g2t_gap\"])\n", - "logreg_ranker.fit(train_df, n_jobs=8)\n", - "with open(\n", - " results_path / f\"linreg_g2t_gap_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", + " logreg_ranker = LinearRegressionRanker(sequence_features=features_map[\"g2t_gap\"])\n", + " logreg_ranker.fit(train_df, n_jobs=8)\n", + " with open(\n", + " results_path / f\"linreg_g2t_gap_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", "\n", "\n", - "logreg_ranker = LinearRegressionRanker(\n", - " sequence_features=features_map[\"text\"] + features_map[\"g2t_determ\"],\n", - " graph_features=features_map[\"graph\"],\n", - ")\n", - "logreg_ranker.fit(train_df, n_jobs=8)\n", - "with open(\n", - " results_path\n", - " / f\"linreg_text_g2t_determ_graph_reranking_seq2seq_{ds_type}_results.jsonl\",\n", - " \"w\",\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", + " logreg_ranker = LinearRegressionRanker(\n", + " sequence_features=features_map[\"text\"] + features_map[\"g2t_determ\"],\n", + " graph_features=features_map[\"graph\"],\n", + " )\n", + " logreg_ranker.fit(train_df, n_jobs=8)\n", + " with open(\n", + " results_path\n", + " / f\"linreg_text_g2t_determ_graph_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\",\n", + " \"w\",\n", + " ) as f:\n", + " 
for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", "\n", "\n", - "logreg_ranker = LinearRegressionRanker(\n", - " sequence_features=features_map[\"text\"] + features_map[\"g2t_t5\"],\n", - " graph_features=features_map[\"graph\"],\n", - ")\n", - "logreg_ranker.fit(train_df, n_jobs=8)\n", - "with open(\n", - " results_path\n", - " / f\"linreg_text_g2t_t5_graph_reranking_seq2seq_{ds_type}_results.jsonl\",\n", - " \"w\",\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", + " logreg_ranker = LinearRegressionRanker(\n", + " sequence_features=features_map[\"text\"] + features_map[\"g2t_t5\"],\n", + " graph_features=features_map[\"graph\"],\n", + " )\n", + " logreg_ranker.fit(train_df, n_jobs=8)\n", + " with open(\n", + " results_path\n", + " / f\"linreg_text_g2t_t5_graph_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\",\n", + " \"w\",\n", + " ) as f:\n", + " for result in logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", "\n", "\n", - "logreg_ranker = LinearRegressionRanker(\n", - " sequence_features=features_map[\"text\"] + features_map[\"g2t_gap\"],\n", - " graph_features=features_map[\"graph\"],\n", - ")\n", - "logreg_ranker.fit(train_df, n_jobs=8)\n", - "with open(\n", - " results_path\n", - " / f\"linreg_text_g2t_gap_graph_reranking_seq2seq_{ds_type}_results.jsonl\",\n", - " \"w\",\n", - ") as f:\n", - " for result in logreg_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")" + " logreg_ranker = LinearRegressionRanker(\n", + " sequence_features=features_map[\"text\"] + features_map[\"g2t_gap\"],\n", + " graph_features=features_map[\"graph\"],\n", + " )\n", + " logreg_ranker.fit(train_df, n_jobs=8)\n", + " with open(\n", + " results_path\n", + " / f\"linreg_text_g2t_gap_graph_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\",\n", + " \"w\",\n", + " ) as f:\n", + " for result in 
logreg_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")" ] }, { @@ -833,7 +1102,7 @@ "source": [ "device = torch.device(\"cuda\")\n", "\n", - "model_path = \"/workspace/storage/misc/subgraphs_reranking_results/question_answer/T5-xl-ssm/question_answer_nocherries_fixed_train/outputs/checkpoint-best\"\n", + "model_path = f\"/home/jovyan/kbqa/subgraphs_reranking_runs/sequence/question_answer/KGQASubgraphsRanking/run_1/{ds_type}/outputs/checkpoint-best\"\n", "mpnet_ranker = MPNetRanker(\"question_answer\", model_path, device)\n", "with open(\n", " results_path / f\"mpnet_text_only_determ_reranking_seq2seq_{ds_type}_results.jsonl\",\n", @@ -842,7 +1111,7 @@ " for result in mpnet_ranker.rerank(test_df):\n", " f.write(json.dumps(result) + \"\\n\")\n", "\n", - "model_path = f\"/mnt/storage/QA_System_Project/subgraphs_reranking_runs/determ/T5-{ds_type}-ssm/no_cherries_hl_false_determ/outputs/checkpoint-best\"\n", + "model_path = f\"/home/jovyan/kbqa/subgraphs_reranking_runs/sequence/no_highlighted_determ_sequence/KGQASubgraphsRanking/run_1/{ds_type}/outputs/checkpoint-best\"\n", "mpnet_ranker = MPNetRanker(\"no_highlighted_determ_sequence\", model_path, device)\n", "with open(\n", " results_path / f\"mpnet_no_hl_g2t_determ_reranking_seq2seq_{ds_type}_results.jsonl\",\n", @@ -851,7 +1120,7 @@ " for result in mpnet_ranker.rerank(test_df):\n", " f.write(json.dumps(result) + \"\\n\")\n", "\n", - "model_path = f\"/mnt/storage/QA_System_Project/subgraphs_reranking_runs/determ/T5-{ds_type}-ssm/no_cherries_hl_true_determ/outputs/checkpoint-best/\"\n", + "model_path = f\"/home/jovyan/kbqa/subgraphs_reranking_runs/sequence/highlighted_determ_sequence/KGQASubgraphsRanking/run_1/{ds_type}/outputs/checkpoint-best\"\n", "mpnet_ranker = MPNetRanker(\"highlighted_determ_sequence\", model_path, device)\n", "with open(\n", " results_path / f\"mpnet_hl_g2t_determ_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", @@ -860,7 +1129,7 @@ " f.write(json.dumps(result) + 
\"\\n\")\n", "\n", "\n", - "model_path = f\"/mnt/storage/QA_System_Project/subgraphs_reranking_runs/g2t/T5-{ds_type}-ssm/no_cherries_fixed_train_g2t_hl_true_large/outputs/checkpoint-best\"\n", + "model_path = f\"/home/jovyan/kbqa/subgraphs_reranking_runs/sequence/highlighted_t5_sequence/KGQASubgraphsRanking/run_1/{ds_type}/outputs/checkpoint-best\"\n", "mpnet_ranker = MPNetRanker(\"highlighted_t5_sequence\", model_path, device)\n", "with open(\n", " results_path / f\"mpnet_hl_g2t_t5_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", @@ -868,7 +1137,7 @@ " for result in mpnet_ranker.rerank(test_df):\n", " f.write(json.dumps(result) + \"\\n\")\n", "\n", - "model_path = f\"/mnt/storage/QA_System_Project/subgraphs_reranking_runs/g2t/T5-{ds_type}-ssm/no_cherries_fixed_train_g2t_hl_false_large/outputs/checkpoint-best\"\n", + "model_path = f\"/home/jovyan/kbqa/subgraphs_reranking_runs/sequence/no_highlighted_t5_sequence/KGQASubgraphsRanking/run_1/{ds_type}/outputs/checkpoint-best\"\n", "mpnet_ranker = MPNetRanker(\"no_highlighted_t5_sequence\", model_path, device)\n", "with open(\n", " results_path / f\"mpnet_no_hl_g2t_t5_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", @@ -877,7 +1146,7 @@ " f.write(json.dumps(result) + \"\\n\")\n", "\n", "\n", - "model_path = f\"/mnt/storage/QA_System_Project/subgraphs_reranking_runs/gap/T5-{ds_type}-ssm/no_cherries_fixed_train_hl_true_gap_large/outputs/checkpoint-best\"\n", + "model_path = f\"/home/jovyan/kbqa/subgraphs_reranking_runs/sequence/highlighted_gap_sequence/KGQASubgraphsRanking/run_1/{ds_type}/outputs/checkpoint-best\"\n", "mpnet_ranker = MPNetRanker(\"highlighted_gap_sequence\", model_path, device)\n", "with open(\n", " results_path / f\"mpnet_hl_g2t_gap_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", @@ -885,7 +1154,7 @@ " for result in mpnet_ranker.rerank(test_df):\n", " f.write(json.dumps(result) + \"\\n\")\n", "\n", - "model_path = 
f\"/mnt/storage/QA_System_Project/subgraphs_reranking_runs/gap/T5-{ds_type}-ssm/no_cherries_fixed_train_hl_false_gap_large/outputs/checkpoint-best\"\n", + "model_path = f\"/home/jovyan/kbqa/subgraphs_reranking_runs/sequence/no_highlighted_gap_sequence/KGQASubgraphsRanking/run_1/{ds_type}/outputs/checkpoint-best\"\n", "mpnet_ranker = MPNetRanker(\"no_highlighted_gap_sequence\", model_path, device)\n", "with open(\n", " results_path / f\"mpnet_no_hl_g2t_gap_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", @@ -1273,6 +1542,95 @@ " f.write(json.dumps(result) + \"\\n\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Semantic Ranking\n" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4125dc4b5b634492ab2783e892694def", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/4000 [00:00 pd.DataFrame: """merge mintaka, vanilla LLM outputs and subgraph datasets""" + mintaka_df =mintaka_ds.to_pandas() outputs_df = pd.merge( - mintaka_ds.to_pandas(), + mintaka_df[mintaka_df['lang'] == 'en'], outputs_ds.to_pandas(), on="question", how="left", ) + if "id_x" in outputs_df.columns: + outputs_df.rename(columns={"id_x": "id"}, inplace=True) + merged_df = pd.merge( outputs_df[["id"] + list(outputs_ds.features.keys())], features_ds.to_pandas(), - on=["id", "question"], + on=["question"], how="left", ) return merged_df @@ -37,8 +41,14 @@ def prepare_data( mintaka_ds: Dataset, outputs_ds: Dataset, features_ds: Dataset ) -> pd.DataFrame: """merge mintaka, vanilla LLM outputs and subgraph datasets""" + dataframe = merge_datasets(mintaka_ds, outputs_ds, features_ds) dataframe = compile_seq2seq_outputs_to_model_answers_column(dataframe) + + if "id_x" in dataframe.columns: + dataframe.rename(columns={"id_x": "id"}, inplace=True) + dataframe = dataframe.loc[:, ~dataframe.columns.duplicated()] + return dataframe diff 
--git a/experiments/subgraphs_reranking/sequence/train_sequence_ranker.py b/experiments/subgraphs_reranking/sequence/train_sequence_ranker.py index e4fffe9..1763c1d 100644 --- a/experiments/subgraphs_reranking/sequence/train_sequence_ranker.py +++ b/experiments/subgraphs_reranking/sequence/train_sequence_ranker.py @@ -48,11 +48,18 @@ default="s-nlp/KGQASubgraphsRanking", help="Path to train sequence data file (HF)", ) +parse.add_argument( + "--ds_type", + type=str, + default="t5largessm", + choices=["t5largessm", "t5xlssm", "mistral", "mixtral"], + help="Dataset type to use", +) parse.add_argument( "--output_path", type=str, - default="/workspace/storage/misc/subgraphs_reranking_results", + default="./subgraphs_reranking_runs/sequence/", ) parse.add_argument( @@ -71,7 +78,7 @@ parse.add_argument( "--wandb_on", - default=True, + default=False, type=lambda x: (str(x).lower() == "true"), help="Using WanDB or not (True/False)", ) @@ -79,19 +86,19 @@ parse.add_argument( "--per_device_train_batch_size", type=int, - default=32, + default=16, ) parse.add_argument( "--per_device_eval_batch_size", type=int, - default=32, + default=64, ) parse.add_argument( "--num_train_epochs", type=int, - default=6, + default=2, ) parse.add_argument( "--do_highlighting", @@ -131,7 +138,7 @@ def get_labels(self): labels.append(int(i["labels"].cpu().detach().numpy())) return labels - def _get_train_sampler(self) -> torch.utils.data.Sampler: + def _get_train_sampler(self, dataset) -> torch.utils.data.Sampler: """create our custom sampler""" labels = self.get_labels() return self.create_sampler(labels) @@ -190,11 +197,7 @@ def __len__(self): if args.wandb_on: os.environ["WANDB_NAME"] = args.run_name - model_folder = args.data_path.split("_")[-1] # either large or xl - output_path = f"{args.output_path}/{args.sequence_type}/{model_folder}" - Path(output_path).mkdir(parents=True, exist_ok=True) - - subgraphs_dataset = load_dataset(args.data_path) + subgraphs_dataset = load_dataset(args.data_path, 
data_dir=f"{args.ds_type}_subgraphs") train_df = subgraphs_dataset["train"].to_pandas() val_df = subgraphs_dataset["validation"].to_pandas() @@ -217,6 +220,10 @@ def __len__(self): else: SEQ_TYPE = f"{HL_TYPE}_{args.sequence_type}_sequence" + model_folder = args.data_path.split("_")[-1] # either large or xl + output_path = Path(args.output_path) / SEQ_TYPE / model_folder + output_path.mkdir(parents=True, exist_ok=True) + train_dataset = SequenceDataset(train_df, tokenizer, SEQ_TYPE) val_dataset = SequenceDataset(val_df, tokenizer, SEQ_TYPE) @@ -234,7 +241,7 @@ def __len__(self): greater_is_better=True, logging_steps=500, save_steps=500, - evaluation_strategy="steps", + eval_strategy="steps", report_to="wandb" if args.wandb_on else "tensorboard", ) @@ -248,7 +255,7 @@ def __len__(self): trainer.train() checkpoint_best_path = ( - Path(output_path) / args.run_name / "outputs" / "checkpoint-best" + output_path / args.run_name / f"{args.ds_type}" / "outputs" / "checkpoint-best" ) model.save_pretrained(checkpoint_best_path) tokenizer.save_pretrained(checkpoint_best_path) diff --git a/kbqa/seq2seq/train.py b/kbqa/seq2seq/train.py index 3981864..186e4af 100644 --- a/kbqa/seq2seq/train.py +++ b/kbqa/seq2seq/train.py @@ -1,3 +1,4 @@ +import os import datasets from .redirect_trainer import Seq2SeqWikidataRedirectsTrainer from ..wikidata.wikidata_redirects import WikidataRedirectsCache @@ -27,7 +28,7 @@ def train( per_device_eval_batch_size: int = 1, warmup_steps: int = 500, weight_decay: float = 0.01, - evaluation_strategy: str = "steps", + eval_strategy: str = "steps", eval_steps: int = 500, logging_steps: int = 500, gradient_accumulation_steps: int = 8, @@ -57,13 +58,13 @@ def train( per_device_eval_batch_size (int, optional): eval batch size per device. Defaults to 1. warmup_steps (int, optional): warmup steps for traning. Defaults to 500. weight_decay (float, optional): weight decay for traning. Defaults to 0.01. 
- evaluation_strategy (str, optional): + eval_strategy (str, optional): "no": No evaluation is done during training; "steps": Evaluation is done (and logged) every eval_steps; "epoch": Evaluation is done at the end of each epoch; Defaults to 'steps'. eval_steps (int, optional): - Number of update steps between two evaluations if evaluation_strategy="steps". + Number of update steps between two evaluations if eval_strategy="steps". Will default to the same value as logging_steps if not set. Defaults to 500. logging_steps (int, optional): @@ -80,26 +81,31 @@ def train( Returns: Seq2SeqTrainer: Trained after traning and validation """ - training_args = Seq2SeqTrainingArguments( - run_name=run_name, - report_to=report_to, - output_dir=output_dir, - num_train_epochs=num_train_epochs, - max_steps=max_steps, - per_device_train_batch_size=per_device_train_batch_size, - per_device_eval_batch_size=per_device_eval_batch_size, - warmup_steps=warmup_steps, - weight_decay=weight_decay, - logging_dir=logging_dir, - evaluation_strategy=evaluation_strategy, - eval_steps=eval_steps, - save_steps=eval_steps, - save_strategy="steps", - save_total_limit=save_total_limit, - logging_steps=logging_steps, - load_best_model_at_end=True, - gradient_accumulation_steps=gradient_accumulation_steps, - ) + training_args_dict = { + "run_name": run_name, + "report_to": report_to, + "output_dir": output_dir, + "num_train_epochs": num_train_epochs, + "max_steps": max_steps, + "per_device_train_batch_size": per_device_train_batch_size, + "per_device_eval_batch_size": per_device_eval_batch_size, + "warmup_steps": warmup_steps, + "weight_decay": weight_decay, + "logging_dir": logging_dir, + "eval_strategy": eval_strategy, + "eval_steps": eval_steps, + "save_steps": eval_steps, + "save_strategy": "steps", + "save_total_limit": save_total_limit, + "logging_steps": logging_steps, + "load_best_model_at_end": True, + "gradient_accumulation_steps": gradient_accumulation_steps, + } + + if "LOCAL_RANK" not in 
os.environ and "RANK" not in os.environ: + training_args_dict["local_rank"] = -1 + + training_args = Seq2SeqTrainingArguments(**training_args_dict) if trainer_mode == "default": trainer = Seq2SeqTrainer( diff --git a/kbqa/seq2seq/utils.py b/kbqa/seq2seq/utils.py index bc41726..dce01b3 100644 --- a/kbqa/seq2seq/utils.py +++ b/kbqa/seq2seq/utils.py @@ -284,6 +284,63 @@ def load_mintaka_seq2seq_dataset( return dataset +def load_mkqa_seq2seq_dataset( + train_json_path: str, + test_json_path: str, + tokenizer: PreTrainedTokenizer, + split: str = None, + use_convert_to_features: bool = True, +): + """load_mkqa_seq2seq_dataset - helper for loading MKQA dataset for seq2seq + + Args: + train_json_path (str): Path to mkqa_train.json file + test_json_path (str): Path to mkqa_test.json file + tokenizer (PreTrainedTokenizer): Tokenizer of seq2seq model + split (str, optional): Load only train/test split if passed, else load all. Defaults to None. + use_convert_to_features (bool, optional): Converting dataset to features for seq2seq training/evaluation pipeline. Defaults to True. 
+ + Returns: + datasets.arrow_dataset.Dataset or datasets.DatasetDict: Prepared dataset for seq2seq + """ + data_files = {"train": train_json_path, "test": test_json_path} + + if split is None: + dataset = datasets.load_dataset("json", data_files=data_files) + else: + split_map = {"train": train_json_path, "test": test_json_path} + if split not in split_map: + raise ValueError(f"split must be 'train' or 'test', got {split}") + dataset = datasets.load_dataset("json", data_files={split: split_map[split]}, split=split) + + if use_convert_to_features is True: + if isinstance(dataset, datasets.DatasetDict): + dataset = dataset.map( + lambda batch: convert_to_features( + batch, tokenizer, label_feature_name="answerText" + ), + batched=True, + ) + for split_name in dataset: + dataset[split_name].set_format( + type="torch", + columns=["input_ids", "labels", "attention_mask"], + ) + else: + dataset = dataset.map( + lambda batch: convert_to_features( + batch, tokenizer, label_feature_name="answerText" + ), + batched=True, + ) + dataset.set_format( + type="torch", + columns=["input_ids", "labels", "attention_mask"], + ) + + return dataset + + def hf_model_name_mormolize(model_name: str) -> str: """hf_model_name_mormolize - return normolized model name for storing to directory Example: facebook/bart-large -> facebook_bart-large diff --git a/mintaka_evaluate.py b/mintaka_evaluate.py index 1c946a9..6f2ec53 100644 --- a/mintaka_evaluate.py +++ b/mintaka_evaluate.py @@ -94,7 +94,7 @@ class EvalMintaka: """EvalMintaka Evaluation class for Mintaka ranked predictions""" def __init__(self): - mintaka_ds = load_dataset("AmazonScience/mintaka") + mintaka_ds = load_dataset("AmazonScience/mintaka", revision="refs/convert/parquet", data_dir=f"en") self.dataset = { "train": mintaka_ds["train"].to_pandas(), "validation": mintaka_ds["validation"].to_pandas(), @@ -158,17 +158,27 @@ def evaluate(self, predictions, split: str = "test", top_n: int = 10): """ _df = self.dataset[split] - is_correct 
= [] - for prediction in tqdm(predictions, desc="Process predictions.."): + import concurrent.futures + + def process_prediction(prediction): question_idx = prediction["QuestionID"] mintaka_record = _df[_df["id"] == question_idx].iloc[0] - is_answer_correct_results = [] - for _, answer in enumerate(prediction["RankedAnswers"]): - is_answer_correct_results.append( - self.is_answer_correct(mintaka_record, answer) - ) + is_answer_correct_results = [ + self.is_answer_correct(mintaka_record, answer) + for answer in prediction["RankedAnswers"] + ] + return is_answer_correct_results - is_correct.append(is_answer_correct_results) + is_correct = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: + results = list( + tqdm( + executor.map(process_prediction, predictions), + total=len(predictions), + desc="Process predictions.." + ) + ) + is_correct.extend(results) is_correct_df = pd.DataFrame(is_correct) is_correct_df["id"] = [p["QuestionID"] for p in predictions] diff --git a/mkqa_test.json b/mkqa_test.json new file mode 100644 index 0000000..1ecc45f --- /dev/null +++ b/mkqa_test.json @@ -0,0 +1,42 @@ +[ + { + "id": 6, + "lang": "en", + "question": "who sings i hear you knocking but you can't come in", + "answerText": "Dave Edmunds", + "answerEntity": [ + { + "name": "Q545186", + "label": "Dave Edmunds" + } + ], + "questionEntity": [ + { + "name": "Q3147103", + "entityType": "entity", + "label": "I Hear You Knocking", + "mention": "i hear you knocking" + } + ] + }, + { + "id": 11, + "lang": "en", + "question": "who sang the song oh what a night", + "answerText": "The Four Seasons", + "answerEntity": [ + { + "name": "Q687785", + "label": "The Four Seasons" + } + ], + "questionEntity": [ + { + "name": "Q5248905", + "entityType": "entity", + "label": "December, 1963 (Oh, What a Night)", + "mention": "oh what a night" + } + ] + } +] \ No newline at end of file diff --git a/mkqa_train.json b/mkqa_train.json new file mode 100644 index 0000000..ded02bf --- 
/dev/null +++ b/mkqa_train.json @@ -0,0 +1,370 @@ +[ + { + "id": 34, + "lang": "en", + "question": "who sings i'll make a man out of you", + "answerText": "Donny Osmond", + "answerEntity": [ + { + "name": "Q386053", + "label": "Donny Osmond" + }, + { + "name": "Q5296743", + "label": "Donny Osmond" + } + ], + "questionEntity": [ + { + "name": "Q5965889", + "entityType": "entity", + "label": "I'll Make a Man Out of You", + "mention": "i'll make a man out of you" + }, + { + "name": "Q386053", + "entityType": "entity", + "label": "Donny Osmond", + "mention": "i'll make a man out of you" + } + ] + }, + { + "id": 36, + "lang": "en", + "question": "where is the world cup going to be held", + "answerText": "Qatar", + "answerEntity": [ + { + "name": "Q846", + "label": "Qatar" + } + ], + "questionEntity": [ + { + "name": "Q19317", + "entityType": "entity", + "label": "FIFA World Cup", + "mention": "world cup" + } + ] + }, + { + "id": 37, + "lang": "en", + "question": "what are the names of the tweenies characters", + "answerText": "Judy", + "answerEntity": [ + { + "name": "Q55615419", + "label": "Judy" + } + ], + "questionEntity": [ + { + "name": "Q1569073", + "entityType": "entity", + "label": "Tweenies", + "mention": "tweenies characters" + } + ] + }, + { + "id": 38, + "lang": "en", + "question": "who plays black panther in the movie black panther", + "answerText": "Chadwick Boseman", + "answerEntity": [ + { + "name": "Q5066520", + "label": "Chadwick Boseman" + } + ], + "questionEntity": [ + { + "name": "Q23780734", + "entityType": "entity", + "label": "Black Panther", + "mention": "the movie black panther" + } + ] + }, + { + "id": 43, + "lang": "en", + "question": "what is the active ingredient in all fda approved hand sanitizers", + "answerText": "ethanol", + "answerEntity": [ + { + "name": "Q153", + "label": "ethanol" + } + ], + "questionEntity": [ + { + "name": "Q204711", + "entityType": "entity", + "label": "Food and Drug Administration", + "mention": "fda" + }, + { + 
"name": "Q520181", + "entityType": "entity", + "label": "hand sanitizer", + "mention": "hand sanitizers" + } + ] + }, + { + "id": 46, + "lang": "en", + "question": "who starred in the movie bridge over the river kwai", + "answerText": "William Holden", + "answerEntity": [ + { + "name": "Q95002", + "label": "William Holden" + }, + { + "name": "Q103894", + "label": "Alec Guinness" + }, + { + "name": "Q26118", + "label": "Jack Hawkins" + } + ], + "questionEntity": [ + { + "name": "Q188718", + "entityType": "entity", + "label": "The Bridge on the River Kwai", + "mention": "bridge over the river kwai" + } + ] + }, + { + "id": 47, + "lang": "en", + "question": "what were the names of the knights of the round table", + "answerText": "Gareth", + "answerEntity": [ + { + "name": "Q1413003", + "label": "Gareth" + }, + { + "name": "Q831462", + "label": "Galahad" + }, + { + "name": "Q728510", + "label": "Percival" + }, + { + "name": "Q215681", + "label": "Lancelot" + }, + { + "name": "Q831685", + "label": "Gawain" + } + ], + "questionEntity": [ + { + "name": "Q1644266", + "entityType": "entity", + "label": "Knights of the Round Table", + "mention": "knights of the round table" + } + ] + }, + { + "id": 48, + "lang": "en", + "question": "who is the host of america has talent", + "answerText": "Simon Cowell", + "answerEntity": [ + { + "name": "Q162629", + "label": "Simon Cowell" + }, + { + "name": "Q232646", + "label": "Julianne Hough" + }, + { + "name": "Q1190974", + "label": "Howie Mandel" + }, + { + "name": "Q231648", + "label": "Gabrielle Union" + }, + { + "name": "Q271464", + "label": "Terry Crews" + } + ], + "questionEntity": [ + { + "name": "Q947873", + "entityType": "entity", + "label": "television presenter", + "mention": "host" + }, + { + "name": "Q467561", + "entityType": "entity", + "label": "America's Got Talent", + "mention": "america has talent" + } + ] + }, + { + "id": 56, + "lang": "en", + "question": "who played snow white in once upon a time", + "answerText": 
"Ginnifer Goodwin", + "answerEntity": [ + { + "name": "Q109522", + "label": "Ginnifer Goodwin" + } + ], + "questionEntity": [ + { + "name": "Q11831", + "entityType": "entity", + "label": "Snow White", + "mention": "snow white" + } + ] + }, + { + "id": 57, + "lang": "en", + "question": "who was the oldest man who ever lived", + "answerText": "Jeanne Calment", + "answerEntity": [ + { + "name": "Q182260", + "label": "Jeanne Calment" + }, + { + "name": "Q550149", + "label": "Jiroemon Kimura" + } + ], + "questionEntity": [ + { + "name": "Q550149", + "entityType": "entity", + "label": "Jiroemon Kimura", + "mention": "the oldest man who ever lived" + }, + { + "name": "Q6581097", + "entityType": "entity", + "label": "male", + "mention": "man" + } + ] + }, + { + "id": 64, + "lang": "en", + "question": "actors in the movie gone in 60 seconds", + "answerText": "Angelina Jolie", + "answerEntity": [ + { + "name": "Q13909", + "label": "Angelina Jolie" + }, + { + "name": "Q1320560", + "label": "William Lee Scott" + }, + { + "name": "Q314673", + "label": "Scott Caan" + }, + { + "name": "Q224081", + "label": "Giovanni Ribisi" + }, + { + "name": "Q42869", + "label": "Nicolas Cage" + }, + { + "name": "Q171736", + "label": "Robert Duvall" + } + ], + "questionEntity": [ + { + "name": "Q33999", + "entityType": "entity", + "label": "actor", + "mention": "actors" + } + ] + }, + { + "id": 71, + "lang": "en", + "question": "who is the captain of west indies in cricket", + "answerText": "Jason Holder", + "answerEntity": [ + { + "name": "Q6162716", + "label": "Jason Holder" + }, + { + "name": "Q18684975", + "label": "Carlos Brathwaite" + } + ], + "questionEntity": [ + { + "name": "Q912881", + "entityType": "entity", + "label": "West Indies cricket team", + "mention": "west indies" + }, + { + "name": "Q5375", + "entityType": "entity", + "label": "cricket", + "mention": "cricket" + } + ] + }, + { + "id": 72, + "lang": "en", + "question": "what is the largest lake in new zealand", + 
"answerText": "Lake Taupo", + "answerEntity": [ + { + "name": "Q199903", + "label": "Lake Taupo" + } + ], + "questionEntity": [ + { + "name": "Q199903", + "entityType": "entity", + "label": "Lake Taupō", + "mention": "largest lake" + }, + { + "name": "Q664", + "entityType": "entity", + "label": "New Zealand", + "mention": "new zealand" + } + ] + } +] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f4941d8..6658a46 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,40 +1,53 @@ -torch==2.0.1 -transformers==4.30.2 -peft @ git+https://github.com/huggingface/peft -bitsandbytes -datasets -pandas==1.4.3 -pylint==2.14.4 +torch>=2.7.1 +transformers==4.55.0 +peft>=0.6.0 +pandas>=1.5.0 +pylint>=2.15.0 SPARQLWrapper==2.0.0 -sacrebleu -sentencepiece -pywikibot +sacrebleu>=2.3.0 +sentencepiece>=0.1.99 +pywikibot>=8.0.0 mwparserfromhell>=0.5.0 -evaluate==0.2.2 +evaluate>=0.4.0 pywikidata -networkx -matplotlib -unidecode -requests -bs4 -marisa_trie -jupyterlab -scikit-learn -tensorflow<=2.10 -spacy -mlflow -gradio -python-telegram-bot -pydash -jsonlines -ujson -sentence_transformers +networkx>=3.0 +matplotlib>=3.6.0 +unidecode>=1.3.0 +requests>=2.28.0 +beautifulsoup4>=4.11.0 +marisa-trie>=0.7.8 +jupyterlab>=4.0.0 +scikit-learn>=1.2.0 +spacy<3.8.0,>=3.4.0 +mlflow>=2.8.0 +gradio>=4.0.0 +python-telegram-bot>=20.0 +pydash>=5.1.0 +jsonlines>=3.1.0 +ujson>=5.8.0 +sentence-transformers>=2.2.0 captum==0.6.0 torch-model-archiver==0.7.0 torchserve==0.7.0 black==22.3.0 -xgboost -seaborn -gudhi -giotto-tda -wandb +xgboost>=2.0.0 +seaborn>=0.12.0 +gudhi>=3.7.0 +giotto-tda>=0.6.0 +wandb>=0.15.0 +datasets==3.1.0 +tqdm>=4.65.0 +graphviz>=0.20.0 +nltk>=3.8.0 +python-igraph>=0.10.0 +catboost>=1.2.0 +hydra-core>=1.3.0 +omegaconf>=2.3.0 +joblib>=1.3.0 +dvc>=3.0.0 +altair>=5.0.0 +numpy>=1.24.0 +packaging>=21.0 +filelock>=3.9.0 +huggingface-hub>=0.16.0 +tensorboardx diff --git a/seq2seq.py b/seq2seq.py index 0d38a85..0c7b390 100644 --- a/seq2seq.py +++ 
b/seq2seq.py @@ -15,6 +15,7 @@ load_kbqa_seq2seq_dataset, load_mintaka_seq2seq_dataset, load_lcquad2_seq2seq_dataset, + load_mkqa_seq2seq_dataset, load_model_and_tokenizer_by_name, ) from kbqa.utils.train_eval import get_best_checkpoint_path @@ -110,13 +111,13 @@ ) parser.add_argument( "--num_beams", - default=200, + default=30, help="Numbers of beams for Beam search (only for eval mode)", type=int, ) parser.add_argument( "--num_return_sequences", - default=200, + default=30, help=( "Numbers of return sequencese from Beam search (only for eval mode)." " Must be less or equal to num_beams" @@ -125,7 +126,7 @@ ) parser.add_argument( "--num_beam_groups", - default=20, + default=3, help=( "Number of groups to divide num_beams into in order to ensure diversity " "among different groups of beams (only for eval mode). " @@ -190,6 +191,20 @@ def train(args, model_dir, logging_dir): split="test", ) + elif args.dataset_name == "mkqa": + train_json_path = Path("mkqa_train.json") + test_json_path = Path("mkqa_test.json") + if not train_json_path.exists() and args.dataset_cache_dir: + train_json_path = Path(args.dataset_cache_dir) / "mkqa_train.json" + if not test_json_path.exists() and args.dataset_cache_dir: + test_json_path = Path(args.dataset_cache_dir) / "mkqa_test.json" + dataset = load_mkqa_seq2seq_dataset( + str(train_json_path), + str(test_json_path), + tokenizer, + ) + dataset["validation"] = dataset["test"] + else: dataset = load_kbqa_seq2seq_dataset( args.dataset_name, @@ -271,6 +286,25 @@ def evaluate(args, model_dir, normolized_model_name): f"Lcquad2.0 Eval: Dataset loaded, label_feature_name={label_feature_name}" ) + elif args.dataset_name == "mkqa": + train_json_path = Path("mkqa_train.json") + test_json_path = Path("mkqa_test.json") + if not train_json_path.exists() and args.dataset_cache_dir: + train_json_path = Path(args.dataset_cache_dir) / "mkqa_train.json" + if not test_json_path.exists() and args.dataset_cache_dir: + test_json_path = 
Path(args.dataset_cache_dir) / "mkqa_test.json" + split = args.dataset_evaluation_split if args.dataset_evaluation_split else "test" + dataset = load_mkqa_seq2seq_dataset( + str(train_json_path), + str(test_json_path), + tokenizer, + split=split, + ) + label_feature_name = "answerText" + logger.info( + f"Eval: MKQA Dataset loaded, label_feature_name={label_feature_name}" + ) + else: dataset = load_kbqa_seq2seq_dataset( args.dataset_name, From 759917e6d0d30ccaeb178a1638babb5d1904dc8c Mon Sep 17 00:00:00 2001 From: "d.yarosh" Date: Sun, 30 Nov 2025 10:55:05 +0300 Subject: [PATCH 02/10] Added hf mkqa dataset --- kbqa/utils/train_eval.py | 6 +++++- seq2seq.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/kbqa/utils/train_eval.py b/kbqa/utils/train_eval.py index 708a394..e441ca8 100644 --- a/kbqa/utils/train_eval.py +++ b/kbqa/utils/train_eval.py @@ -25,6 +25,10 @@ def get_best_checkpoint_path(path_to_checkpoints: str) -> str: last_checkpint_path = pathes[-1] with open(last_checkpint_path / "trainer_state.json", "r") as file_handler: train_state = json.load(file_handler) - best_model_checkpoint = Path(train_state["best_model_checkpoint"]).name + if train_state["best_model_checkpoint"] is None: + best_model_checkpoint = last_checkpint_path.name + else: + best_model_checkpoint = Path(train_state["best_model_checkpoint"]).name + print("Used checkpoint: ", best_model_checkpoint) return path_to_checkpoints / best_model_checkpoint diff --git a/seq2seq.py b/seq2seq.py index 0c7b390..cd6ae6d 100644 --- a/seq2seq.py +++ b/seq2seq.py @@ -191,6 +191,13 @@ def train(args, model_dir, logging_dir): split="test", ) + elif args.dataset_name == "mkqa-hf": + dataset = load_mintaka_seq2seq_dataset( + 'Dms12/mkqa_mintaka_format_with_question_entities', + args.dataset_config_name, + tokenizer, + ) + elif args.dataset_name == "mkqa": train_json_path = Path("mkqa_train.json") test_json_path = Path("mkqa_test.json") @@ -286,6 +293,18 @@ def 
evaluate(args, model_dir, normolized_model_name): f"Lcquad2.0 Eval: Dataset loaded, label_feature_name={label_feature_name}" ) + elif args.dataset_name == "mkqa-hf": + dataset = load_mintaka_seq2seq_dataset( + 'Dms12/mkqa_mintaka_format_with_question_entities', + args.dataset_config_name, + tokenizer, + split=args.dataset_evaluation_split, + ) + label_feature_name = "answerText" + logger.info( + f"Eval: MKQA Dataset loaded, label_feature_name={label_feature_name}" + ) + elif args.dataset_name == "mkqa": train_json_path = Path("mkqa_train.json") test_json_path = Path("mkqa_test.json") From 674a3c737ab7291c04d4fab1740f23c2804d292c Mon Sep 17 00:00:00 2001 From: "d.yarosh" Date: Sun, 30 Nov 2025 20:57:08 +0300 Subject: [PATCH 03/10] Added triples extraction for mkqa dataset --- .../mkqa_subgraphs_prepairing.py | 119 ++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 experiments/subgraphs_datasets_prepare_input_data/mkqa_subgraphs_prepairing.py diff --git a/experiments/subgraphs_datasets_prepare_input_data/mkqa_subgraphs_prepairing.py b/experiments/subgraphs_datasets_prepare_input_data/mkqa_subgraphs_prepairing.py new file mode 100644 index 0000000..b52fdca --- /dev/null +++ b/experiments/subgraphs_datasets_prepare_input_data/mkqa_subgraphs_prepairing.py @@ -0,0 +1,119 @@ +import pandas as pd +from pywikidata import Entity +from tqdm.auto import tqdm +import ujson +import datasets +from wd_api import get_wd_search_results +from multiprocessing import Pool, cpu_count + +model_name = 't5-large-ssm' +predictions_path = f'../../{model_name}-res/google_{model_name}/evaluation/version_0/results.csv' + + +def label_to_entity(label: str, top_k: int = 3) -> list: + """label_to_entity method to linking label to WikiData entity ID + by using elasticsearch Wikimedia public API + Supported only English language (en) + + Parameters + ---------- + label : str + label of entity to search + top_k : int, optional + top K results from WikiData, by default 3 + + 
Returns + ------- + list[str] | None + list of entity IDs or None if not found + """ + try: + elastic_results = get_wd_search_results(label, top_k, language='en')[:top_k] + except: + elastic_results = [] + + try: + elastic_results.extend( + get_wd_search_results(label.replace("\"", "").replace("\'", "").strip(), top_k, language='en')[:top_k] + ) + except: + return None + + return list(dict.fromkeys(elastic_results).keys())[:top_k] + + +def data_to_subgraphs(df): + for _, row in tqdm(df.iterrows(), total=df.index.size): + # if row['complexityType'] not in ['count', 'yesno']: + question_entity_ids = [e['name'] for e in row['questionEntity'] if e['entityType'] == 'entity'] + + for candidate_label in dict.fromkeys(row['model_answers']).keys(): + for candidate_entity_id in label_to_entity(candidate_label): + candidate_entity = Entity(candidate_entity_id) + yield { + 'id': row['id'], + 'question': row['question'], + 'generatedAnswer': [candidate_label], + 'answerEntity': [candidate_entity.idx], + 'answerEntityLabel': [candidate_entity.label], + 'questionEntity': question_entity_ids, + 'groundTruthAnswerEntity': [e['name'] for e in row['answerEntity']] + } + + +def process_row(row): + results = [] + print(f'Start: {row['id']}') + # print("HERE!") + question_entity_ids = [e['name'] for e in row['questionEntity'] if e['entityType'] == 'entity'] + for candidate_label in dict.fromkeys(row['model_answers']).keys(): + for candidate_entity_id in label_to_entity(candidate_label): + candidate_entity = Entity(candidate_entity_id) + results.append({ + 'id': row['id'], + 'question': row['question'], + 'generatedAnswer': [candidate_label], + 'answerEntity': [candidate_entity.idx], + 'answerEntityLabel': [candidate_entity.label], + 'questionEntity': question_entity_ids, + 'groundTruthAnswerEntity': [e['name'] for e in row['answerEntity']] + }) + + print(f'End: {row['id']}') + return results + + +def eval_df(df): + num_processes = cpu_count() + # Convert DataFrame to list of 
dictionaries for processing + rows = df.to_dict('records') + # print(rows) + # rows = rows[:num_processes] + + # Create pool and process rows + with Pool(processes=num_processes) as pool: + results = pool.map(process_row, rows) + + # Convert results back to DataFrame + results = [item for sublist in results for item in sublist] + return results + + +if __name__ == '__main__': + test_predictions = pd.read_csv(predictions_path) + ds = datasets.load_dataset("Dms12/mkqa_mintaka_format_with_question_entities") + + answer_columns = [col for col in test_predictions.columns if col.startswith('answer_')] + test_predictions['model_answers'] = test_predictions[answer_columns].values.tolist() + test_predictions = test_predictions.drop(columns=answer_columns) + + test_df = pd.merge( + test_predictions, + ds['test'].to_pandas(), + on=['question'], + ) + + results = eval_df(test_df) + with open(f'../../{model_name}_test.jsonl', 'w') as f: + for data_line in results: + f.write(ujson.dumps(data_line) + '\n') From 0f3256d11b842cbc1f46ddb43a6da7d75057eb51 Mon Sep 17 00:00:00 2001 From: "d.yarosh" Date: Tue, 2 Dec 2025 11:45:49 +0300 Subject: [PATCH 04/10] Fixed scripts for graphs prepairing --- .../mkqa_subgraphs_prepairing.py | 96 +++++++++++++++---- .../mining_subgraphs_dataset_processes.py | 87 ++++++++++------- 2 files changed, 134 insertions(+), 49 deletions(-) diff --git a/experiments/subgraphs_datasets_prepare_input_data/mkqa_subgraphs_prepairing.py b/experiments/subgraphs_datasets_prepare_input_data/mkqa_subgraphs_prepairing.py index b52fdca..6df2837 100644 --- a/experiments/subgraphs_datasets_prepare_input_data/mkqa_subgraphs_prepairing.py +++ b/experiments/subgraphs_datasets_prepare_input_data/mkqa_subgraphs_prepairing.py @@ -1,3 +1,9 @@ +import time +from collections import deque +from functools import wraps +from threading import Lock +from time import sleep + import pandas as pd from pywikidata import Entity from tqdm.auto import tqdm @@ -6,10 +12,47 @@ from wd_api 
import get_wd_search_results from multiprocessing import Pool, cpu_count -model_name = 't5-large-ssm' -predictions_path = f'../../{model_name}-res/google_{model_name}/evaluation/version_0/results.csv' +model_name = 't5-xl-ssm-nq' +type = 'train' +predictions_path = f'../../{model_name}-res/google_{model_name}/evaluation/version_0_{type}/results.csv' + + +def rate_limit(max_calls=15, period=60): + def decorator(func): + # Store state in closure variables + calls = deque() # Store timestamps of recent calls + lock = Lock() # Thread safety lock + + @wraps(func) + def wrapper(*args, **kwargs): + with lock: + current_time = time.time() + + # Remove timestamps older than the period + while calls and calls[0] <= current_time - period: + calls.popleft() + + # Check if we've exceeded the rate limit + if len(calls) >= max_calls: + oldest = calls[0] + wait_time = oldest + period - current_time + if wait_time > 0: + print(f"Rate limit, sleep {wait_time}") + time.sleep(wait_time) + # After sleeping, update current_time and clean old calls again + current_time = time.time() + while calls and calls[0] <= current_time - period: + calls.popleft() + + # Record this call and execute the function + calls.append(current_time) + + return func(*args, **kwargs) + return wrapper + return decorator +@rate_limit(max_calls=30, period=60) def label_to_entity(label: str, top_k: int = 3) -> list: """label_to_entity method to linking label to WikiData entity ID by using elasticsearch Wikimedia public API @@ -27,17 +70,35 @@ def label_to_entity(label: str, top_k: int = 3) -> list: list[str] | None list of entity IDs or None if not found """ - try: - elastic_results = get_wd_search_results(label, top_k, language='en')[:top_k] - except: - elastic_results = [] - - try: - elastic_results.extend( - get_wd_search_results(label.replace("\"", "").replace("\'", "").strip(), top_k, language='en')[:top_k] - ) - except: - return None + retry = True + while retry: + try: + elastic_results = 
get_wd_search_results(label, top_k, language='en')[:top_k] + except Exception as e: + print(f"First e: {e}") + if '429' in str(e): + # print(f"Retry first for: {e}") + sleep(1001) + else: + retry = False + elastic_results = [] + else: + retry = False + + retry = True + while retry: + try: + elastic_results.extend( + get_wd_search_results(label.replace("\"", "").replace("\'", "").strip(), top_k, language='en')[:top_k] + ) + except Exception as e: + print(f"Second e: {e}") + if '429' in str(e): + sleep(1001) + else: + retry = False + else: + retry = False return list(dict.fromkeys(elastic_results).keys())[:top_k] @@ -63,7 +124,7 @@ def data_to_subgraphs(df): def process_row(row): results = [] - print(f'Start: {row['id']}') + print(f"Start: {row['id']}") # print("HERE!") question_entity_ids = [e['name'] for e in row['questionEntity'] if e['entityType'] == 'entity'] for candidate_label in dict.fromkeys(row['model_answers']).keys(): @@ -79,12 +140,13 @@ def process_row(row): 'groundTruthAnswerEntity': [e['name'] for e in row['answerEntity']] }) - print(f'End: {row['id']}') + print(f"End: {row['id']}") return results def eval_df(df): num_processes = cpu_count() + print("Run with processes:", num_processes) # Convert DataFrame to list of dictionaries for processing rows = df.to_dict('records') # print(rows) @@ -109,11 +171,11 @@ def eval_df(df): test_df = pd.merge( test_predictions, - ds['test'].to_pandas(), + ds[f'{type}'].to_pandas(), on=['question'], ) results = eval_df(test_df) - with open(f'../../{model_name}_test.jsonl', 'w') as f: + with open(f'../../{model_name}_{type}.jsonl', 'w') as f: for data_line in results: f.write(ujson.dumps(data_line) + '\n') diff --git a/subgraphs_dataset_creation/mining_subgraphs_dataset_processes.py b/subgraphs_dataset_creation/mining_subgraphs_dataset_processes.py index 5136806..4ff8334 100644 --- a/subgraphs_dataset_creation/mining_subgraphs_dataset_processes.py +++ b/subgraphs_dataset_creation/mining_subgraphs_dataset_processes.py 
@@ -153,7 +153,7 @@ def igraph_to_nx(subgraph: ig.Graph): return nx_subgraph -def write_from_queue(save_jsonl_path: str, results_q: JoinableQueue): +def write_from_queue(save_jsonl_path: str, results_q: JoinableQueue, n_jobs: int): """given a queue, write the queue to the save_jsonl_path file Args: @@ -161,15 +161,21 @@ def write_from_queue(save_jsonl_path: str, results_q: JoinableQueue): results_q (JoinableQueue): result queue (to write our results from) """ with open(save_jsonl_path, "a+", encoding="utf-8") as file_handler: - while True: + finished_workers = 0 + while finished_workers < n_jobs: try: json_obj = results_q.get() - file_handler.write(json_obj + "\n") + if json_obj == "END!": + finished_workers += 1 + else: + file_handler.write(json_obj + "\n") except QueueEmpty: continue else: results_q.task_done() + print("Finished writing queue") + def read_wd_graph(wd_graph_path: str) -> ig.Graph: """given the path, parse the triples and build @@ -209,39 +215,50 @@ def find_subgraph_and_transform_to_json( f"[{now()}]{proc_worker_header}[{os.getpid()}] Current process memory (Gb)", psutil.Process(os.getpid()).memory_info().rss / (1024.0**3), ) - while True: + is_working = True + while is_working: try: task_line = task_q.get() - start_time = time.time() - data = ujson.loads(task_line) - try: - subgraph = extract_subgraph( - wd_graph, data["answerEntity"], data["questionEntity"] - ) - except ValueError as value_err: - with open("ErrorsLog.jsonl", "a+", encoding="utf-8") as file: - data["error"] = str(value_err) - file.write(ujson.dumps(data) + "\n") - continue - except Exception as general_exception: # pylint: disable=broad-except - print(str(general_exception)) - time.sleep(60) - subgraph = extract_subgraph( - wd_graph, data["answerEntity"], data["questionEntity"] - ) - - nx_subgraph = igraph_to_nx(subgraph) - data["graph"] = nx.node_link_data(nx_subgraph) - - results_q.put(ujson.dumps(data)) + if task_line == "END!": + results_q.put(task_line) + is_working = False 
+ else: + start_time = time.time() + data = ujson.loads(task_line) + try: + subgraph = extract_subgraph( + wd_graph, data["answerEntity"], data["questionEntity"] + ) + except ValueError as value_err: + with open("ErrorsLog.jsonl", "a+", encoding="utf-8") as file: + data["error"] = str(value_err) + file.write(ujson.dumps(data) + "\n") + continue + except Exception as general_exception: # pylint: disable=broad-except + print(str(general_exception)) + time.sleep(60) + subgraph = extract_subgraph( + wd_graph, data["answerEntity"], data["questionEntity"] + ) + + nx_subgraph = igraph_to_nx(subgraph) + data["graph"] = nx.node_link_data(nx_subgraph) + + results_q.put(ujson.dumps(data)) except QueueEmpty: continue else: task_queue.task_done() - print( - f"[{now()}]{proc_worker_header}[{os.getpid()}] \ - SSP task completed ({time.time() - start_time}s)" - ) + if task_line == "END!": + print( + f"[{now()}]{proc_worker_header}[{os.getpid()}] \ + Received End of tasks. Send the same to writer!" + ) + else: + print( + f"[{now()}]{proc_worker_header}[{os.getpid()}] \ + SSP task completed ({time.time() - start_time}s)" + ) if __name__ == "__main__": @@ -253,6 +270,8 @@ def find_subgraph_and_transform_to_json( proc_worker_header = f"{BColors.OKGREEN}[Process Worker]{BColors.ENDC}" print(f"[{now()}]] Start loading WD Graph") parsed_wd_graph = read_wd_graph(args.igraph_wikidata_path) + # parsed_wd_graph = None + print( f"[{now()}]]{BColors.OKGREEN} \ WD Graph loaded{BColors.ENDC}" @@ -263,7 +282,7 @@ def find_subgraph_and_transform_to_json( task_queue = JoinableQueue(maxsize=queue_max_size) writing_thread = Process( target=write_from_queue, - args=[args.save_jsonl_path, results_queue], + args=[args.save_jsonl_path, results_queue, args.n_jobs], daemon=True, ) writing_thread.start() @@ -276,7 +295,8 @@ def find_subgraph_and_transform_to_json( daemon=True, ) p.start() - time.sleep(180) + time.sleep(30) + # time.sleep(1) with open( args.subgraphs_dataset_prepared_entities_jsonl_path, 
"r", encoding="utf-8" @@ -296,6 +316,9 @@ def find_subgraph_and_transform_to_json( {results_queue.qsize():4d}; task_queue size: {task_queue.qsize():4d}" ) + for _ in range(args.n_jobs): + task_queue.put("END!") + print(f"[{now()}]{BColors.HEADER}[Main Thread]{BColors.ENDC} All tasks sent") task_queue.join() results_queue.join() From a39eab7b775acda8162f28d90371910dde59f15b Mon Sep 17 00:00:00 2001 From: Mikhail Salnikov <2613180+MihailSalnikov@users.noreply.github.com> Date: Tue, 2 Dec 2025 19:30:32 +0300 Subject: [PATCH 05/10] Update processing features based on subgraphs --- .../graph_features_preparation.py | 338 ++++++++++++++---- mistral_mixtral.py | 58 ++- 2 files changed, 316 insertions(+), 80 deletions(-) diff --git a/experiments/subgraphs_reranking/graph_features_preparation.py b/experiments/subgraphs_reranking/graph_features_preparation.py index 414a128..ba4c033 100644 --- a/experiments/subgraphs_reranking/graph_features_preparation.py +++ b/experiments/subgraphs_reranking/graph_features_preparation.py @@ -1,5 +1,7 @@ """ prepare the graph features dataset from scratch""" import argparse +import json +import os from ast import literal_eval import yaml from datasets import load_dataset, Dataset, DatasetDict @@ -16,51 +18,72 @@ parse = argparse.ArgumentParser() parse.add_argument( "--subgraphs_dataset_path", - default="s-nlp/Mintaka_Subgraphs_T5_xl_ssm", + default=None, type=str, - help="Path for the subgraphs dataset (HF)", + help="Path for the subgraphs dataset (HF). 
Required if not using JSON files.", +) + +parse.add_argument( + "--subgraphs_train_path", + type=str, + default=None, + help="Path to train JSON/JSONL file (optional)", +) + +parse.add_argument( + "--subgraphs_val_path", + type=str, + default=None, + help="Path to validation JSON/JSONL file (optional)", +) + +parse.add_argument( + "--subgraphs_test_path", + type=str, + default=None, + help="Path to test JSON/JSONL file (optional)", ) parse.add_argument( "--g2t_t5_train_path", type=str, - default="/workspace/storage/misc/train_results_mintaka_T5XL.yaml", + default=None, help="Path to g2t train yaml file", ) parse.add_argument( "--g2t_t5_test_path", type=str, - default="/workspace/storage/misc/test_results_mintaka_T5Large.yaml", + default=None, help="Path to g2t test yaml file", ) parse.add_argument( "--g2t_t5_val_path", type=str, - default="/workspace/storage/misc/val_results_t5_xl.yaml", - help="Path to g2t test yaml file", + default=None, + help="Path to g2t val yaml file", ) parse.add_argument( "--g2t_gap_train_path", type=str, - default="/workspace/storage/misc/gap_train_mintaka_large_predictions.txt", - help="Path to g2t train yaml file", + default=None, + help="Path to g2t gap train txt file", ) parse.add_argument( "--g2t_gap_test_path", type=str, - default="/workspace/storage/misc/gap_test_mintaka_large_predictions.txt", - help="Path to g2t test yaml file", + default=None, + help="Path to g2t gap test txt file", ) parse.add_argument( "--g2t_gap_val_path", type=str, - default="/workspace/storage/misc/gap_val_mintaka_t5_xl_filtered_predictions.txt", - help="Path to g2t test yaml file", + default=None, + help="Path to g2t gap val txt file", ) parse.add_argument( @@ -77,6 +100,59 @@ help="path to upload to HuggingFace", ) +parse.add_argument( + "--g2t_types", + type=str, + nargs="+", + default=["determ", "t5", "gap"], + choices=["determ", "t5", "gap"], + help="G2T types to process: determ (G2T Deterministic), t5 (G2T T5), gap (G2T GAP). 
" + "When any G2T type is selected, G2T Deterministic is always included.", +) + +parse.add_argument( + "--subset_name", + type=str, + default="mkqa_t5large", + help="Name for the subset when pushing to HuggingFace. Subset will be named 'name_subgraphs'.", +) + + +def load_json_dataset(json_path): + """load dataset from JSON/JSONL file""" + data = [] + with open(json_path, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + data.append(json.loads(line)) + + df = pd.DataFrame(data) + + # Compute 'correct' field if missing + if "correct" not in df.columns: + if "groundTruthAnswerEntity" in df.columns and "answerEntity" in df.columns: + def compute_correct(row): + answer_entity = row["answerEntity"] + if isinstance(answer_entity, str): + answer_entity = [answer_entity] + elif not isinstance(answer_entity, list): + answer_entity = [str(answer_entity)] + ground_truth = row["groundTruthAnswerEntity"] + if isinstance(ground_truth, str): + ground_truth = [ground_truth] + elif not isinstance(ground_truth, list): + ground_truth = [str(ground_truth)] + return 1.0 if any(str(ans) in ground_truth for ans in answer_entity) else 0.0 + df["correct"] = df.apply(compute_correct, axis=1) + else: + df["correct"] = 0.0 + + # Convert graph to string if it's a dict + if "graph" in df.columns and df["graph"].dtype == "object": + df["graph"] = df["graph"].apply(lambda x: json.dumps(x) if isinstance(x, dict) else x) + + return df + def get_g2t_seqs(g2t_path): """proccess the g2t yaml file and return list of g2t seqs""" @@ -98,13 +174,18 @@ def get_gap_seqs(gap_path): return gap_seqs -def add_new_seqs(g2t_path, gap_path, dataframe): - """get the new g2t and gap seqs and add to df""" - g2t_seqs = get_g2t_seqs(g2t_path) - gap_seqs = get_gap_seqs(gap_path) - - dataframe["g2t_sequence"] = g2t_seqs - dataframe["gap_sequence"] = gap_seqs +def add_new_seqs(g2t_t5_path, g2t_gap_path, dataframe, g2t_types): + """get the new g2t and gap seqs and add to df based on selected g2t_types""" 
+ if "t5" in g2t_types: + if g2t_t5_path is None: + raise ValueError("G2T T5 path is required when processing G2T T5 sequences") + g2t_seqs = get_g2t_seqs(g2t_t5_path) + dataframe["g2t_sequence"] = g2t_seqs + if "gap" in g2t_types: + if g2t_gap_path is None: + raise ValueError("G2T GAP path is required when processing G2T GAP sequences") + gap_seqs = get_gap_seqs(g2t_gap_path) + dataframe["gap_sequence"] = gap_seqs return dataframe @@ -196,12 +277,19 @@ def arr_to_str(arr): return ",".join(str(a) for a in arr) -def try_literal_eval(strng): - """str representation to object""" - try: - return literal_eval(strng) - except ValueError: - return strng +def parse_graph(graph_data): + """parse graph data from dict, JSON string, or Python literal string""" + if isinstance(graph_data, dict): + return graph_data + if isinstance(graph_data, str): + try: + return json.loads(graph_data) + except (json.JSONDecodeError, ValueError): + try: + return literal_eval(graph_data) + except (ValueError, SyntaxError): + return graph_data + return graph_data def find_candidate_note(graph): @@ -213,23 +301,30 @@ def find_candidate_note(graph): raise ValueError("Cannot find answer candidate entity") -def get_features(dataframe, model, device): - """get the graph features for the df""" +def get_features(dataframe, model, device, g2t_types): + """get the graph features for the df based on selected g2t_types""" + # Identify fields to preserve from original dataframe + fields_to_preserve = ["answerEntity", "groundTruthAnswerEntity", "questionEntity", "graph", "correct"] + available_fields = [f for f in fields_to_preserve if f in dataframe.columns] + dict_list = [] - for _, row in tqdm(dataframe.iterrows()): + for _, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc="Processing rows"): # convert from json dict to networkx graph - graph_obj = json_graph.node_link_graph(try_literal_eval(row["graph"])) - graph_node_names = get_node_names(graph_obj) - g2t_seq = row["g2t_sequence"] - gap_seq 
= row["gap_sequence"] + graph_data = parse_graph(row["graph"]) + graph_obj = json_graph.node_link_graph(graph_data, edges="links") + graph_node_names_no_highlight = get_node_names(graph_obj, highlight=False) + graph_node_names_highlight = get_node_names(graph_obj, highlight=True) # skip if we have no answer candidates in our graph try: ans_cand_id = find_candidate_note(graph_obj) + # Get answerEntity - handle both list and single value + answer_entity = row.get("answerEntity", "") + if isinstance(answer_entity, list) and len(answer_entity) > 0: + answer_entity = answer_entity[0] ques_ans = ( - f"{row['question']} ; {find_label(graph_obj, row['answerEntity'])}" + f"{row['question']} ; {find_label(graph_obj, answer_entity)}" ) - determ_seq = graph_to_sequence(graph_obj, graph_node_names) # build the features curr_dict = { @@ -247,26 +342,66 @@ def get_features(dataframe, model, device): "katz_centrality": nx.katz_centrality(graph_obj)[ans_cand_id], "page_rank": nx.pagerank(graph_obj)[ans_cand_id], "avg_ssp_length": get_distance_ans_cand(graph_obj, ans_cand_id), - "determ_sequence": determ_seq, - "gap_sequence": gap_seq, - "g2t_sequence": g2t_seq, - # embedding data - "determ_sequence_embedding": arr_to_str( - model.encode(determ_seq, device=device, convert_to_numpy=True) - ), - "gap_sequence_embedding": arr_to_str( - model.encode(gap_seq, device=device, convert_to_numpy=True) - ), - "g2t_sequence_embedding": arr_to_str( - model.encode(g2t_seq, device=device, convert_to_numpy=True) - ), - "question_answer_embedding": arr_to_str( - model.encode(ques_ans, device=device, convert_to_numpy=True) - ), - "tfidf_vector": np.array([]), - # label - "correct": float(row["correct"]), } + + # Always preserve original fields from JSON (answerEntity, groundTruthAnswerEntity, questionEntity, graph, correct) + for field in available_fields: + if field in row: + if field == "graph": + # Preserve graph in original format (should be string after load_json_dataset) + graph_original = 
row["graph"] + if isinstance(graph_original, dict): + # If still a dict, convert to JSON string + curr_dict["graph"] = json.dumps(graph_original) + elif isinstance(graph_original, str): + # Already a string, preserve as-is + curr_dict["graph"] = graph_original + else: + # Fallback: convert to string + curr_dict["graph"] = str(graph_original) + elif field == "correct": + curr_dict["correct"] = float(row["correct"]) + else: + # Preserve answerEntity, groundTruthAnswerEntity, questionEntity as-is + curr_dict[field] = row[field] + + # Process G2T Deterministic - generate both highlighted and no_highlighted versions + if "determ" in g2t_types: + # No highlighted version + no_highlighted_determ_seq = graph_to_sequence(graph_obj, graph_node_names_no_highlight) + curr_dict["no_highlighted_determ_sequence"] = no_highlighted_determ_seq + curr_dict["no_highlighted_determ_sequence_embedding"] = arr_to_str( + model.encode(no_highlighted_determ_seq, device=device, convert_to_numpy=True) + ) + + # Highlighted version + highlighted_determ_seq = graph_to_sequence(graph_obj, graph_node_names_highlight) + curr_dict["highlighted_determ_sequence"] = highlighted_determ_seq + curr_dict["highlighted_determ_sequence_embedding"] = arr_to_str( + model.encode(highlighted_determ_seq, device=device, convert_to_numpy=True) + ) + + # Process G2T T5 + if "t5" in g2t_types: + g2t_seq = row["g2t_sequence"] + curr_dict["g2t_sequence"] = g2t_seq + curr_dict["g2t_sequence_embedding"] = arr_to_str( + model.encode(g2t_seq, device=device, convert_to_numpy=True) + ) + + # Process G2T GAP + if "gap" in g2t_types: + gap_seq = row["gap_sequence"] + curr_dict["gap_sequence"] = gap_seq + curr_dict["gap_sequence_embedding"] = arr_to_str( + model.encode(gap_seq, device=device, convert_to_numpy=True) + ) + + # Always include question_answer embedding + curr_dict["question_answer_embedding"] = arr_to_str( + model.encode(ques_ans, device=device, convert_to_numpy=True) + ) + except: # pylint: disable=bare-except 
continue dict_list.append(curr_dict) @@ -277,30 +412,93 @@ def get_features(dataframe, model, device): if __name__ == "__main__": args = parse.parse_args() - - subgraphs_dataset = load_dataset( - args.subgraphs_dataset_path, cache_dir="/workspace/storage/misc/huggingface" - ) - train_df = subgraphs_dataset["train"].to_pandas() - val_df = subgraphs_dataset["validation"].to_pandas() - test_df = subgraphs_dataset["test"].to_pandas() + g2t_types = args.g2t_types + + # Always include determ when any G2T type is selected + if "determ" not in g2t_types and len(g2t_types) > 0: + g2t_types = ["determ"] + g2t_types + + # Load dataset from JSON files or HuggingFace + train_df = None + val_df = None + test_df = None + + if args.subgraphs_train_path or args.subgraphs_val_path or args.subgraphs_test_path: + # Load from JSON files + if args.subgraphs_train_path: + if not os.path.exists(args.subgraphs_train_path): + raise ValueError(f"Train JSON file not found: {args.subgraphs_train_path}") + train_df = load_json_dataset(args.subgraphs_train_path) + + if args.subgraphs_val_path: + if not os.path.exists(args.subgraphs_val_path): + raise ValueError(f"Validation JSON file not found: {args.subgraphs_val_path}") + val_df = load_json_dataset(args.subgraphs_val_path) + + if args.subgraphs_test_path: + if not os.path.exists(args.subgraphs_test_path): + raise ValueError(f"Test JSON file not found: {args.subgraphs_test_path}") + test_df = load_json_dataset(args.subgraphs_test_path) + elif args.subgraphs_dataset_path: + # Load from HuggingFace + subgraphs_dataset = load_dataset( + args.subgraphs_dataset_path, cache_dir="/workspace/storage/misc/huggingface" + ) + train_df = subgraphs_dataset["train"].to_pandas() + val_df = subgraphs_dataset["validation"].to_pandas() + test_df = subgraphs_dataset["test"].to_pandas() + else: + raise ValueError( + "Either --subgraphs_dataset_path (HF) or at least one of " + "--subgraphs_train_path/--subgraphs_val_path/--subgraphs_test_path (JSON) must be 
provided" + ) # adding the new g2t sequences to subgraph dataset - train_df = add_new_seqs(args.g2t_t5_train_path, args.g2t_gap_train_path, train_df) - test_df = add_new_seqs(args.g2t_t5_test_path, args.g2t_gap_test_path, test_df) - val_df = add_new_seqs(args.g2t_t5_val_path, args.g2t_gap_val_path, val_df) + if train_df is not None: + train_df = add_new_seqs( + args.g2t_t5_train_path, args.g2t_gap_train_path, train_df, g2t_types + ) + if test_df is not None: + test_df = add_new_seqs( + args.g2t_t5_test_path, args.g2t_gap_test_path, test_df, g2t_types + ) + if val_df is not None: + val_df = add_new_seqs( + args.g2t_t5_val_path, args.g2t_gap_val_path, val_df, g2t_types + ) # get all features and add to df smodel = SentenceTransformer("all-mpnet-base-v2") curr_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - processed_train_df = get_features(train_df, smodel, curr_device) - processed_test_df = get_features(test_df, smodel, curr_device) - processed_val_df = get_features(val_df, smodel, curr_device) + + processed_train_df = None + processed_test_df = None + processed_val_df = None + + if train_df is not None: + processed_train_df = get_features(train_df, smodel, curr_device, g2t_types) + if test_df is not None: + processed_test_df = get_features(test_df, smodel, curr_device, g2t_types) + if val_df is not None: + processed_val_df = get_features(val_df, smodel, curr_device, g2t_types) # upload to HF if args.upload_dataset: ds = DatasetDict() - ds["train"] = Dataset.from_pandas(processed_train_df) - ds["validation"] = Dataset.from_pandas(processed_val_df) - ds["test"] = Dataset.from_pandas(processed_test_df) - ds.push_to_hub(args.hf_path) + if processed_train_df is not None: + ds["train"] = Dataset.from_pandas(processed_train_df) + if processed_val_df is not None: + ds["validation"] = Dataset.from_pandas(processed_val_df) + if processed_test_df is not None: + ds["test"] = Dataset.from_pandas(processed_test_df) + if len(ds) > 0: + subset_name = 
f"{args.subset_name}_subgraphs" + # Push dataset with config_name to create a subset/configuration + # Can be loaded later with: load_dataset(hf_path, subset_name) + try: + ds.push_to_hub(args.hf_path, config_name=subset_name) + except (TypeError, ValueError): + # If config_name is not supported with DatasetDict, push normally + # Note: subset organization may need to be handled differently + ds.push_to_hub(args.hf_path) + print(f"Note: Pushed dataset without config_name. Subset name '{subset_name}' is for reference only.") diff --git a/mistral_mixtral.py b/mistral_mixtral.py index 0fa9edc..dba6c6d 100644 --- a/mistral_mixtral.py +++ b/mistral_mixtral.py @@ -1,6 +1,7 @@ """script for mistral and mixtral""" from pathlib import Path import pickle +import json from tqdm import tqdm from argparse import ArgumentParser import random @@ -39,6 +40,11 @@ help="trained model path for mixtral or mistral", default="mistralai/Mixtral-8x7B-Instruct-v0.1", ) +parser.add_argument( + "--model_checkpoint_path", + default=None, + help="Direct path to LoRA adapter checkpoint directory (for eval mode). 
If provided, will use this instead of constructing path from output_dir", +) parser.add_argument( "--mode", default="train", @@ -57,6 +63,14 @@ help="file path for the generated answer candidates", ) parser.add_argument("--evaluation_split", default="test") +parser.add_argument("--dataset_name", default="AmazonScience/mintaka") +parser.add_argument("--dataset_config_name", default="en") +parser.add_argument( + "--num_beams", + default=30, + type=int, + help="Numbers of beams for Beam search (only for eval mode)", +) # prompt to feed mistral/mixtral # pylint: disable=line-too-long @@ -180,7 +194,7 @@ def train(args, dataset): training_args = TrainingArguments( num_train_epochs=3, output_dir=args.output_dir, - evaluation_strategy="steps", + eval_strategy="steps", eval_steps=10, save_steps=10, save_total_limit=3, @@ -227,8 +241,17 @@ def evaluate(args, dataset): trust_remote_code=True, ) + if args.model_checkpoint_path: + checkpoint_path = args.model_checkpoint_path + if Path(checkpoint_path).is_dir(): + checkpoint_path = get_best_checkpoint_path(checkpoint_path) or checkpoint_path + output_dir = Path(checkpoint_path).parent + else: + checkpoint_path = get_best_checkpoint_path(args.output_dir) + output_dir = Path(args.output_dir) + model = PeftModel.from_pretrained( - model, get_best_checkpoint_path(args.output_dir), torch_dtype=torch.float16 + model, checkpoint_path, torch_dtype=torch.float16 ) model.eval() tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True) @@ -236,11 +259,15 @@ def evaluate(args, dataset): eval_split = args.evaluation_split prompts = create_prompts(dataset, eval_split) - with open( - Path(args.output_dir) / f"{args.model_name}_{eval_split}_answer_candidates", - "rb", - ) as file: - mistral_answer = pickle.load(file) + + answer_candidates_path = output_dir / f"{Path(args.model_name).name}_{eval_split}_answer_candidates.pkl" + answer_candidates_path.parent.mkdir(parents=True, exist_ok=True) + + if 
answer_candidates_path.exists(): + with open(answer_candidates_path, "rb") as file: + mistral_answer = pickle.load(file) + else: + mistral_answer = [] # filling the pkl file with the generated answers for index in tqdm(range(len(mistral_answer), len(dataset[eval_split]))): @@ -255,15 +282,26 @@ def evaluate(args, dataset): } ] - with open(args.file_name, "wb") as file: + with open(answer_candidates_path, "wb") as file: pickle.dump(mistral_answer, file) + + # Save final results as JSON + answer_candidates_json_path = output_dir / f"{Path(args.model_name).name}_{eval_split}_answer_candidates.json" + with open(answer_candidates_json_path, "w", encoding="utf-8") as file: + json.dump(mistral_answer, file, indent=2, ensure_ascii=False) if __name__ == "__main__": - huggingface_hub.login() + # huggingface_hub.login() args = parser.parse_args() - ds = datasets.load_dataset("AmazonScience/mintaka") + if args.dataset_name == "mkqa-hf": + ds = datasets.load_dataset( + 'Dms12/mkqa_mintaka_format_with_question_entities', + args.dataset_config_name, + ) + else: + ds = datasets.load_dataset(args.dataset_name, args.dataset_config_name) if args.mode == "train": train(args, ds) From 70efd75fb8fcefb64e2dc8d13a765556fa398086 Mon Sep 17 00:00:00 2001 From: Mikhail Salnikov <2613180+MihailSalnikov@users.noreply.github.com> Date: Fri, 5 Dec 2025 00:04:57 +0300 Subject: [PATCH 06/10] Update Catboost training loop for reranking for MKQA-HF data ready --- .../subgraphs_reranking/ranking_data_utils.py | 25 +++++ .../subgraphs_reranking/ranking_model.py | 103 ++++++++++++++++-- 2 files changed, 121 insertions(+), 7 deletions(-) diff --git a/experiments/subgraphs_reranking/ranking_data_utils.py b/experiments/subgraphs_reranking/ranking_data_utils.py index 3cfb648..1afa612 100644 --- a/experiments/subgraphs_reranking/ranking_data_utils.py +++ b/experiments/subgraphs_reranking/ranking_data_utils.py @@ -52,6 +52,31 @@ def prepare_data( return dataframe +def parse_embedding_string(embedding_str): + 
"""Parse comma-separated embedding string to numpy array, replacing NaN/Inf with 0.0""" + if isinstance(embedding_str, (list, np.ndarray)): + arr = np.array(embedding_str, dtype=np.float32) + elif isinstance(embedding_str, str): + try: + arr = np.array([float(x) for x in embedding_str.split(",")], dtype=np.float32) + except (ValueError, AttributeError): + arr = np.array([0.0], dtype=np.float32) + else: + arr = np.array([0.0], dtype=np.float32) + + arr = np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0) + return arr + + +def convert_embedding_columns_to_arrays(dataframe: pd.DataFrame, embedding_columns: list) -> pd.DataFrame: + """Convert embedding string columns to numpy arrays, handling NaN/Inf values""" + dataframe = dataframe.copy() + for col in embedding_columns: + if col in dataframe.columns: + dataframe[col] = dataframe[col].apply(parse_embedding_string) + return dataframe + + def df_to_features_array(dataframe: pd.DataFrame) -> np.ndarray: """convert from df to arr representation""" features_array = [] diff --git a/experiments/subgraphs_reranking/ranking_model.py b/experiments/subgraphs_reranking/ranking_model.py index 9206954..00601c9 100644 --- a/experiments/subgraphs_reranking/ranking_model.py +++ b/experiments/subgraphs_reranking/ranking_model.py @@ -13,8 +13,9 @@ AutoModelForSequenceClassification, AutoTokenizer, ) -from catboost import CatBoostRegressor -from ranking_data_utils import df_to_features_array +from catboost import CatBoostRegressor, Pool +from sklearn import preprocessing, utils +from ranking_data_utils import df_to_features_array, convert_embedding_columns_to_arrays class RankedAnswer(TypedDict): @@ -362,7 +363,7 @@ class CatboostRanker(RankerBase): def __init__( self, - model_path, + model_path: Optional[str] = None, sequence_features: Optional[list] = None, graph_features: Optional[list] = None, scaler_path: Optional[str] = None, @@ -391,10 +392,92 @@ def __init__( except Exception as exception: # pylint: disable=broad-except 
print(f"Failed to load fitted scaler: {exception}") - def fit(self, train_df: DataFrame, **kwargs) -> None: - raise NotImplementedError( - "No fit function for CatBoost. Model should be trained already." + def fit( + self, + train_df: DataFrame, + val_df: Optional[DataFrame] = None, + model_save_path: Optional[str] = None, + scaler_save_path: Optional[str] = None, + early_stopping_rounds: int = 300, + **kwargs, + ) -> None: + """fit CatBoost model on train_df""" + train_df = train_df.dropna(subset=["graph"]).copy() + train_df = train_df.sample(frac=0.999).reset_index(drop=True) + if val_df is not None: + val_df = val_df.dropna(subset=["graph"]).copy() + if len(val_df) == 0: + val_df = None + + embedding_features = [] + if self.sequence_features: + embedding_features = self.sequence_features.copy() + + if self.graph_features: + scaler = preprocessing.MinMaxScaler() + train_df[self.graph_features] = scaler.fit_transform( + train_df[self.graph_features] + ) + self.fitted_scaler = scaler + if scaler_save_path: + joblib.dump(scaler, scaler_save_path) + if val_df is not None and len(val_df) > 0: + val_df[self.graph_features] = scaler.transform( + val_df[self.graph_features] + ) + + train_df = convert_embedding_columns_to_arrays(train_df, embedding_features) + if val_df is not None and len(val_df) > 0: + val_df = convert_embedding_columns_to_arrays(val_df, embedding_features) + + X_train = train_df[self.features_to_use] + y_train = train_df["correct"].astype(float).tolist() + + train_classes = np.unique(y_train) + train_weights = utils.compute_class_weight( + class_weight="balanced", classes=train_classes, y=y_train ) + train_class_weights = np.array(y_train) + train_class_weights[train_class_weights == 0] = train_weights[0] + train_class_weights[train_class_weights == 1] = train_weights[1] + + + + learn_pool = Pool( + X_train, + y_train, + feature_names=list(X_train), + embedding_features=embedding_features if embedding_features else None, + weight=train_class_weights, + 
) + + val_pool = None + if val_df is not None and len(val_df) > 0: + X_val = val_df[self.features_to_use] + y_val = val_df["correct"].astype(float).tolist() + val_pool = Pool( + X_val, + y_val, + feature_names=list(X_val), + embedding_features=embedding_features if embedding_features else None, + ) + + # params = { + # "learning_rate": list(np.linspace(0.03, 0.3, 5)), + # "depth": [4, 8, 10], + # } + model = CatBoostRegressor() + # grid_search_result = model.grid_search(params, learn_pool) + + self.model = CatBoostRegressor( + depth=4, + early_stopping_rounds=early_stopping_rounds, + eval_metric="RMSE", + ) + self.model.fit(learn_pool, eval_set=val_pool, verbose=200) + + if model_save_path: + self.model.save_model(model_save_path) def rerank(self, test_df: DataFrame) -> List[RankedAnswersDict]: """given test_df, rerank using the trained model and output the @@ -402,10 +485,16 @@ def rerank(self, test_df: DataFrame) -> List[RankedAnswersDict]: if self.model is None: raise NotFittedError("This ranker model is not fitted yet.") - if self.fitted_scaler: # fit graph features if we have a scaler + test_df = test_df.copy() + if self.fitted_scaler: # scale graph features if we have a scaler test_df[self.graph_features] = self.fitted_scaler.transform( test_df[self.graph_features] ) + + embedding_features = [] + if self.sequence_features: + embedding_features = self.sequence_features.copy() + test_df = convert_embedding_columns_to_arrays(test_df, embedding_features) results = [] groups = test_df.groupby("id") From 97bcd9a3145755a3b0fb0518d9b9f1b8463251fc Mon Sep 17 00:00:00 2001 From: Mikhail Salnikov <2613180+MihailSalnikov@users.noreply.github.com> Date: Fri, 5 Dec 2025 00:07:00 +0300 Subject: [PATCH 07/10] Bugfix --- experiments/subgraphs_reranking/ranking_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/experiments/subgraphs_reranking/ranking_model.py b/experiments/subgraphs_reranking/ranking_model.py index 00601c9..99f906d 100644 --- 
a/experiments/subgraphs_reranking/ranking_model.py +++ b/experiments/subgraphs_reranking/ranking_model.py @@ -127,7 +127,9 @@ def _sort_answers_group_by_scores( for score, answer_entity_id in zip(sorted_scores, sorted_ranked_answers): ranked_answers.append( RankedAnswer( - AnswerEntityID=answer_entity_id, AnswerString=None, Score=score + AnswerEntityID=str(answer_entity_id) if answer_entity_id is not None else None, + AnswerString=None, + Score=float(score) ) ) return ranked_answers From ef5c35d4c6bc89ba099d0de1c6bdee11f15d0676 Mon Sep 17 00:00:00 2001 From: Mikhail Salnikov <2613180+MihailSalnikov@users.noreply.github.com> Date: Wed, 10 Dec 2025 12:06:04 +0300 Subject: [PATCH 08/10] Update evaluators for MKQA --- kbqa/seq2seq/utils.py | 42 ++++--- mintaka_evaluate.py | 50 +++++--- mkqa_evaluate.py | 266 ++++++++++++++++++++++++++++++++++++++++++ seq2seq.py | 27 ++++- 4 files changed, 352 insertions(+), 33 deletions(-) create mode 100644 mkqa_evaluate.py diff --git a/kbqa/seq2seq/utils.py b/kbqa/seq2seq/utils.py index dce01b3..7cf6702 100644 --- a/kbqa/seq2seq/utils.py +++ b/kbqa/seq2seq/utils.py @@ -369,26 +369,40 @@ def get_model_logging_dirs(save_dir, model_name, run_name=None): def dump_eval( - results_df: pd.DataFrame, report: dict, args: Namespace, normolized_model_name: str + results_df: pd.DataFrame, + report: dict, + args: Namespace, + normolized_model_name: str, + output_dir: Path = None, + split_suffix: str = None, ): - eval_report_dir = Path(args.save_dir) - if args.run_name is not None: - eval_report_dir = eval_report_dir / args.run_name - eval_report_dir = ( - eval_report_dir - / ( - normolized_model_name.name - if isinstance(normolized_model_name, Path) - else str(normolized_model_name) + if output_dir is not None: + eval_report_dir = Path(output_dir) / "evaluation" + else: + eval_report_dir = Path(args.save_dir) + if args.run_name is not None: + eval_report_dir = eval_report_dir / args.run_name + eval_report_dir = ( + eval_report_dir + / ( + 
normolized_model_name.name + if isinstance(normolized_model_name, Path) + else str(normolized_model_name) + ) + / "evaluation" ) - / "evaluation" - ) number_of_versions = len(list(eval_report_dir.glob("version_*"))) - eval_report_dir = eval_report_dir / f"version_{number_of_versions}" + version_name = f"version_{number_of_versions}" + if split_suffix: + version_name = f"{version_name}_{split_suffix}" + eval_report_dir = eval_report_dir / version_name eval_report_dir.mkdir(parents=True, exist_ok=True) - results_df.to_csv(eval_report_dir / "results.csv", index=False) + results_filename = "results.csv" + if split_suffix: + results_filename = f"results_{split_suffix}.csv" + results_df.to_csv(eval_report_dir / results_filename, index=False) with open(eval_report_dir / "report.json", "w", encoding=None) as file_handler: json.dump(report, file_handler) with open(eval_report_dir / "args.json", "w", encoding=None) as file_handler: diff --git a/mintaka_evaluate.py b/mintaka_evaluate.py index 6f2ec53..fd0abf5 100644 --- a/mintaka_evaluate.py +++ b/mintaka_evaluate.py @@ -1,6 +1,7 @@ """ Parsing the jsonl reranking prediction file to gather reranking results (top@n)""" from argparse import ArgumentParser, RawTextHelpFormatter import json +import os from tqdm.auto import tqdm import pandas as pd from datasets import load_dataset @@ -52,6 +53,12 @@ help="Mintaka dataset split.\ntest by default", ) +parser.add_argument( + "--force", + action="store_true", + help="Force evaluation even if output file already exists", +) + def label_to_entity(label: str, top_k: int = 1) -> list: """label_to_entity method to linking label to WikiData entity ID @@ -210,30 +217,39 @@ def process_prediction(prediction): def _calculate_hits(self, is_correct_df: pd.DataFrame, top_n: int = 10) -> dict: hits = {} + numeric_cols = [col for col in is_correct_df.columns if isinstance(col, int)] for top in range(1, top_n + 1): - hits[f"Hit@{top}"] = ( - is_correct_df[list(range(top))].apply(any, axis=1).mean() 
- ) + cols_to_use = [col for col in range(top) if col in numeric_cols] + if not cols_to_use: + hits[f"Hit@{top}"] = 0.0 + else: + hits[f"Hit@{top}"] = ( + is_correct_df[cols_to_use].apply(any, axis=1).mean() + ) return hits if __name__ == "__main__": args = parser.parse_args() - with open(args.predictions_path, "r", encoding="utf-8") as f: - reranking_predictions = [json.loads(line) for line in f.readlines()] - - eval_mintaka = EvalMintaka() - reranking_results = eval_mintaka.evaluate(reranking_predictions, args.split, 5) - - # save reranking results OUTPUT_DIR = "/".join(args.predictions_path.split("/")[:-1]) run_name = args.predictions_path.split("/")[-1] output_path = f"{OUTPUT_DIR}/reranking_result_{run_name}.txt" - with open(output_path, "w+", encoding="utf-8") as file_output: - file_output.write("Hit scores: \n") - for key, val in reranking_results.items(): - file_output.write(f"{key}") - for hitkey, hitval in val.items(): - file_output.write(f"\t{hitkey:6} = {hitval:.6f}") - file_output.write("\n") + + if os.path.exists(output_path) and not args.force: + print(f"Output file already exists: {output_path}") + print("Skipping evaluation. 
Use --force to re-evaluate.") + else: + with open(args.predictions_path, "r", encoding="utf-8") as f: + reranking_predictions = [json.loads(line) for line in f.readlines()] + + eval_mintaka = EvalMintaka() + reranking_results = eval_mintaka.evaluate(reranking_predictions, args.split, 5) + + with open(output_path, "w+", encoding="utf-8") as file_output: + file_output.write("Hit scores: \n") + for key, val in reranking_results.items(): + file_output.write(f"{key}") + for hitkey, hitval in val.items(): + file_output.write(f"\t{hitkey:6} = {hitval:.6f}") + file_output.write("\n") diff --git a/mkqa_evaluate.py b/mkqa_evaluate.py new file mode 100644 index 0000000..27c2790 --- /dev/null +++ b/mkqa_evaluate.py @@ -0,0 +1,266 @@ +""" Parsing the jsonl reranking prediction file to gather reranking results (top@n)""" +from argparse import ArgumentParser, RawTextHelpFormatter +import ast +import json +import os +from tqdm.auto import tqdm +import pandas as pd +from datasets import load_dataset +from pywikidata.utils import get_wd_search_results + + +DESCRIPTION = """Evaluation script for MKQA ranked predictions + +Evaluate ranked predictions. If AnswerEntityID not provided, +try to link AnswerString to Entity and compare with GT. +""" + +EXAMPLE_OF_DATA_FORMAT = """ +Example of data format in predictions_path: + { + "QuestionID": "ID1", + "RankedAnswers": [ + { + "AnswerEntityID": null, + "AnswerString": "String of prediction", + "Score": null + }, + { + "AnswerEntityID": "Q90", + "AnswerString": "Paris", + "Score": 0.99 + }, + ... 
+ ] + }, +""" + +parser = ArgumentParser( + description=DESCRIPTION, + formatter_class=RawTextHelpFormatter, +) + +# pylint: disable=line-too-long +parser.add_argument( + "--predictions_path", + help="Path to JSONL file with predictions" + EXAMPLE_OF_DATA_FORMAT, + default="/workspace/storage/misc/subgraphs_reranking_runs/reranking_model_results/t5_large_ssm/mpnet_highlighted_t5_sequence_reranking_seq2seq_large_2_results.jsonl", +) + +parser.add_argument( + "--split", + default="test", + type=str, + help="MKQA dataset split.\ntest by default", +) + +parser.add_argument( + "--force", + action="store_true", + help="Force evaluation even if output file already exists", +) + + +def label_to_entity(label: str, top_k: int = 1) -> list: + """label_to_entity method to linking label to WikiData entity ID + by using elasticsearch Wikimedia public API + Supported only English language (en) + + Parameters + ---------- + label : str + label of entity to search + top_k : int, optional + top K results from WikiData, by default 1 + + Returns + ------- + list[str] | None + list of entity IDs or None if not found + """ + try: + elastic_results = get_wd_search_results(label, top_k, language="en")[:top_k] + except: # pylint: disable=bare-except + elastic_results = [] + + try: + elastic_results.extend( + get_wd_search_results( + label.replace('"', "").replace("'", "").strip(), top_k, language="en" + )[:top_k] + ) + except: # pylint: disable=bare-except + return [None] + + if len(elastic_results) == 0: + return [None] + + return list(dict.fromkeys(elastic_results).keys())[:top_k] + + +class EvalMKQA: + """EvalMKQA Evaluation class for MKQA ranked predictions""" + + def __init__(self): + mkqa_ds = load_dataset("Dms12/mkqa_mintaka_format_with_question_entities") + self.dataset = { + "train": mkqa_ds["train"].to_pandas(), + "validation": mkqa_ds["validation"].to_pandas(), + "test": mkqa_ds["test"].to_pandas(), + } + + # Extract Entities Names (Ids) from dataset records + for _, df in 
self.dataset.items(): + df["answerEntityNames"] = df["answerEntity"].apply( + self._get_list_of_entity_ids + ) + + def _get_list_of_entity_ids(self, answer_entities): + return [e["name"] for e in answer_entities] + + def is_answer_correct(self, mkqa_record: pd.Series, answer: dict) -> bool: + """to check whether an answer is correct or not + + Args: + mkqa_record (pd.Series): row in the MKQA dataset + answer (dict): answer dict; comprising of the answer entity and/or answer str + + Returns: + bool: correct or not + """ + answer_entity_id = answer.get("AnswerEntityID") + + # Parse AnswerEntityID if it's a string representation of a list + if answer_entity_id is not None and isinstance(answer_entity_id, str): + if answer_entity_id.startswith("[") and answer_entity_id.endswith("]"): + try: + parsed = ast.literal_eval(answer_entity_id) + if isinstance(parsed, list) and len(parsed) > 0: + answer_entity_id = parsed[0] + else: + answer_entity_id = None + except (ValueError, SyntaxError): + answer_entity_id = None + + if answer_entity_id is None: + if answer.get("AnswerString") is not None: + answer_entity_id = label_to_entity(answer["AnswerString"])[0] + else: + answer_entity_id = None + + if ( + answer_entity_id is None + and mkqa_record["answerText"] is not None + and answer.get("AnswerString") is not None + ): + return answer["AnswerString"] == mkqa_record["answerText"] + + if answer_entity_id is None: + return False + + return answer_entity_id in mkqa_record["answerEntityNames"] + + def evaluate(self, predictions, split: str = "test", top_n: int = 10): + """evaluate _summary_ + + Parameters + ---------- + predictions : List[Dict] + Predictions in the following format: + [ + { + "QuestionID": "ID1", + "RankedAnswers": [ + { + "AnswerEntityID": None, + "AnswerString": "String of prediction", + "Score": None + }, + ... + ] + }, + ... 
+ ] + """ + _df = self.dataset[split] + + import concurrent.futures + + def process_prediction(prediction): + question_idx = int(prediction["QuestionID"]) + matching_records = _df[_df["id"] == question_idx] + if len(matching_records) == 0: + raise ValueError(f"QuestionID {question_idx} not found in dataset") + mkqa_record = matching_records.iloc[0] + is_answer_correct_results = [ + self.is_answer_correct(mkqa_record, answer) + for answer in prediction["RankedAnswers"] + ] + return is_answer_correct_results + + is_correct = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: + results = list( + tqdm( + executor.map(process_prediction, predictions), + total=len(predictions), + desc="Process predictions.." + ) + ) + is_correct.extend(results) + + is_correct_df = pd.DataFrame(is_correct) + is_correct_df["id"] = [int(p["QuestionID"]) for p in predictions] + is_correct_df = _df.merge(is_correct_df, on="id") + + if len(set(is_correct_df["id"]).symmetric_difference(_df["id"])) != 0: + print( + "WARNING: Not all questions have predictions, " + "the results will be calculated only for the provided predictions " + "without taking into account the unworthy ones." 
+ ) + + # Format metrics based on is_correct matrix + results = { + "FULL Dataset": self._calculate_hits(is_correct_df, top_n), + } + return results + + def _calculate_hits(self, is_correct_df: pd.DataFrame, top_n: int = 10) -> dict: + hits = {} + numeric_cols = [col for col in is_correct_df.columns if isinstance(col, int)] + for top in range(1, top_n + 1): + cols_to_use = [col for col in range(top) if col in numeric_cols] + if not cols_to_use: + hits[f"Hit@{top}"] = 0.0 + else: + hits[f"Hit@{top}"] = ( + is_correct_df[cols_to_use].apply(any, axis=1).mean() + ) + return hits + + +if __name__ == "__main__": + args = parser.parse_args() + + OUTPUT_DIR = "/".join(args.predictions_path.split("/")[:-1]) + run_name = args.predictions_path.split("/")[-1] + output_path = f"{OUTPUT_DIR}/reranking_result_{run_name}.txt" + + if os.path.exists(output_path) and not args.force: + print(f"Output file already exists: {output_path}") + print("Skipping evaluation. Use --force to re-evaluate.") + else: + with open(args.predictions_path, "r", encoding="utf-8") as f: + reranking_predictions = [json.loads(line) for line in f.readlines()] + + eval_mkqa = EvalMKQA() + reranking_results = eval_mkqa.evaluate(reranking_predictions, args.split, 5) + + with open(output_path, "w+", encoding="utf-8") as file_output: + file_output.write("Hit scores:\n") + for key, val in reranking_results.items(): + file_output.write(f"{key}") + for hitkey, hitval in sorted(val.items(), key=lambda x: int(x[0].split("@")[1])): + file_output.write(f"\t{hitkey:6} = {hitval:.6f}") + file_output.write("\n") diff --git a/seq2seq.py b/seq2seq.py index cd6ae6d..8d5230d 100644 --- a/seq2seq.py +++ b/seq2seq.py @@ -164,6 +164,11 @@ help="Using Wikidata redirects for augmenting train dataset. Do not use with Seq2SeqWikidataRedirectsTrainer", type=lambda x: (str(x).lower() == "true"), ) +parser.add_argument( + "--model_checkpoint_path", + default=None, + help="Direct path to model checkpoint directory (for eval mode). 
If provided, will use this instead of constructing path from save_dir/model_name/run_name", +) def train(args, model_dir, logging_dir): @@ -264,11 +269,22 @@ def train(args, model_dir, logging_dir): def evaluate(args, model_dir, normolized_model_name): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + output_dir = None + if args.model_checkpoint_path: + checkpoint_path = args.model_checkpoint_path + if Path(checkpoint_path).is_dir(): + checkpoint_path = get_best_checkpoint_path(checkpoint_path) or checkpoint_path + output_dir = Path(checkpoint_path).parent + else: + checkpoint_path = get_best_checkpoint_path(model_dir) + model, tokenizer = load_model_and_tokenizer_by_name( - args.model_name, get_best_checkpoint_path(model_dir) + args.model_name, checkpoint_path ) model = model.to(device) + split_suffix = None if args.dataset_name == "AmazonScience/mintaka": dataset = load_mintaka_seq2seq_dataset( args.dataset_name, @@ -277,6 +293,7 @@ def evaluate(args, model_dir, normolized_model_name): split=args.dataset_evaluation_split, ) label_feature_name = "answerText" + split_suffix = args.dataset_evaluation_split logger.info( f"Eval: MINTAKA Dataset loaded, label_feature_name={label_feature_name}" ) @@ -289,6 +306,7 @@ def evaluate(args, model_dir, normolized_model_name): split="test", ) label_feature_name = "Label" + split_suffix = "test" logger.info( f"Lcquad2.0 Eval: Dataset loaded, label_feature_name={label_feature_name}" ) @@ -301,6 +319,7 @@ def evaluate(args, model_dir, normolized_model_name): split=args.dataset_evaluation_split, ) label_feature_name = "answerText" + split_suffix = args.dataset_evaluation_split logger.info( f"Eval: MKQA Dataset loaded, label_feature_name={label_feature_name}" ) @@ -320,6 +339,7 @@ def evaluate(args, model_dir, normolized_model_name): split=split, ) label_feature_name = "answerText" + split_suffix = split logger.info( f"Eval: MKQA Dataset loaded, label_feature_name={label_feature_name}" ) @@ -334,6 +354,7 @@ 
def evaluate(args, model_dir, normolized_model_name): apply_redirects_augmentation=args.apply_redirects_augmentation, ) label_feature_name = "object" + split_suffix = args.dataset_evaluation_split logger.info(f"Eval: Dataset loaded, label_feature_name={label_feature_name}") results_df, report = make_report( @@ -350,7 +371,9 @@ def evaluate(args, model_dir, normolized_model_name): label_feature_name=label_feature_name, ) - eval_report_dir = dump_eval(results_df, report, args, normolized_model_name) + eval_report_dir = dump_eval( + results_df, report, args, normolized_model_name, output_dir=output_dir, split_suffix=split_suffix + ) if args.mlflow_experiment_name is not None: mlflow.log_metrics(report) mlflow.log_artifacts(eval_report_dir, "report") From 34945542b6a5fe5518b9bf493a05f4a49919d7b5 Mon Sep 17 00:00:00 2001 From: Mikhail Salnikov <2613180+MihailSalnikov@users.noreply.github.com> Date: Wed, 10 Dec 2025 12:08:31 +0300 Subject: [PATCH 09/10] Ranking CatBoost, Baselines, MPNet updates --- .../mkqa_subgraphs_prepairing.py | 96 +- .../subgraphs_reranking/aggregate_results.py | 356 ++++ experiments/subgraphs_reranking/plot_main.py | 451 +++++ .../subgraphs_reranking/rankgpt/__init__.py | 12 + .../rankgpt/aggregate_results.py | 205 +++ .../subgraphs_reranking/rankgpt/data_utils.py | 126 ++ .../subgraphs_reranking/rankgpt/predict.py | 156 ++ .../rankgpt/prompt_builder.py | 162 ++ .../rankgpt/rankgpt_ranker.py | 262 +++ experiments/subgraphs_reranking/ranking.ipynb | 1613 ++++++++++++----- .../subgraphs_reranking/ranking_baselines.py | 499 +++++ .../subgraphs_reranking/ranking_catboost.py | 289 +++ .../subgraphs_reranking/ranking_data_utils.py | 24 +- .../subgraphs_reranking/ranking_model.py | 31 +- .../subgraphs_reranking/ranking_mpnet.py | 145 ++ .../sequence/train_sequence_ranker.py | 49 +- .../subgraphs_reranking/upload_outputs.py | 88 + .../mining_subgraphs_dataset_processes.py | 87 +- 18 files changed, 4037 insertions(+), 614 deletions(-) create mode 100644 
experiments/subgraphs_reranking/aggregate_results.py create mode 100644 experiments/subgraphs_reranking/plot_main.py create mode 100644 experiments/subgraphs_reranking/rankgpt/__init__.py create mode 100644 experiments/subgraphs_reranking/rankgpt/aggregate_results.py create mode 100644 experiments/subgraphs_reranking/rankgpt/data_utils.py create mode 100755 experiments/subgraphs_reranking/rankgpt/predict.py create mode 100644 experiments/subgraphs_reranking/rankgpt/prompt_builder.py create mode 100644 experiments/subgraphs_reranking/rankgpt/rankgpt_ranker.py create mode 100644 experiments/subgraphs_reranking/ranking_baselines.py create mode 100644 experiments/subgraphs_reranking/ranking_catboost.py create mode 100644 experiments/subgraphs_reranking/ranking_mpnet.py create mode 100644 experiments/subgraphs_reranking/upload_outputs.py diff --git a/experiments/subgraphs_datasets_prepare_input_data/mkqa_subgraphs_prepairing.py b/experiments/subgraphs_datasets_prepare_input_data/mkqa_subgraphs_prepairing.py index 6df2837..b52fdca 100644 --- a/experiments/subgraphs_datasets_prepare_input_data/mkqa_subgraphs_prepairing.py +++ b/experiments/subgraphs_datasets_prepare_input_data/mkqa_subgraphs_prepairing.py @@ -1,9 +1,3 @@ -import time -from collections import deque -from functools import wraps -from threading import Lock -from time import sleep - import pandas as pd from pywikidata import Entity from tqdm.auto import tqdm @@ -12,47 +6,10 @@ from wd_api import get_wd_search_results from multiprocessing import Pool, cpu_count -model_name = 't5-xl-ssm-nq' -type = 'train' -predictions_path = f'../../{model_name}-res/google_{model_name}/evaluation/version_0_{type}/results.csv' - - -def rate_limit(max_calls=15, period=60): - def decorator(func): - # Store state in closure variables - calls = deque() # Store timestamps of recent calls - lock = Lock() # Thread safety lock - - @wraps(func) - def wrapper(*args, **kwargs): - with lock: - current_time = time.time() - - # Remove 
timestamps older than the period - while calls and calls[0] <= current_time - period: - calls.popleft() - - # Check if we've exceeded the rate limit - if len(calls) >= max_calls: - oldest = calls[0] - wait_time = oldest + period - current_time - if wait_time > 0: - print(f"Rate limit, sleep {wait_time}") - time.sleep(wait_time) - # After sleeping, update current_time and clean old calls again - current_time = time.time() - while calls and calls[0] <= current_time - period: - calls.popleft() - - # Record this call and execute the function - calls.append(current_time) - - return func(*args, **kwargs) - return wrapper - return decorator +model_name = 't5-large-ssm' +predictions_path = f'../../{model_name}-res/google_{model_name}/evaluation/version_0/results.csv' -@rate_limit(max_calls=30, period=60) def label_to_entity(label: str, top_k: int = 3) -> list: """label_to_entity method to linking label to WikiData entity ID by using elasticsearch Wikimedia public API @@ -70,35 +27,17 @@ def label_to_entity(label: str, top_k: int = 3) -> list: list[str] | None list of entity IDs or None if not found """ - retry = True - while retry: - try: - elastic_results = get_wd_search_results(label, top_k, language='en')[:top_k] - except Exception as e: - print(f"First e: {e}") - if '429' in str(e): - # print(f"Retry first for: {e}") - sleep(1001) - else: - retry = False - elastic_results = [] - else: - retry = False - - retry = True - while retry: - try: - elastic_results.extend( - get_wd_search_results(label.replace("\"", "").replace("\'", "").strip(), top_k, language='en')[:top_k] - ) - except Exception as e: - print(f"Second e: {e}") - if '429' in str(e): - sleep(1001) - else: - retry = False - else: - retry = False + try: + elastic_results = get_wd_search_results(label, top_k, language='en')[:top_k] + except: + elastic_results = [] + + try: + elastic_results.extend( + get_wd_search_results(label.replace("\"", "").replace("\'", "").strip(), top_k, language='en')[:top_k] + ) + 
+ except Exception: + return None return list(dict.fromkeys(elastic_results).keys())[:top_k] @@ -124,7 +63,7 @@ def data_to_subgraphs(df): def process_row(row): results = [] - print(f"Start: {row['id']}") + print(f'Start: {row["id"]}') # print("HERE!") question_entity_ids = [e['name'] for e in row['questionEntity'] if e['entityType'] == 'entity'] for candidate_label in dict.fromkeys(row['model_answers']).keys(): @@ -140,13 +79,12 @@ def process_row(row): 'groundTruthAnswerEntity': [e['name'] for e in row['answerEntity']] }) - print(f"End: {row['id']}") + print(f'End: {row["id"]}') return results def eval_df(df): num_processes = cpu_count() - print("Run with processes:", num_processes) # Convert DataFrame to list of dictionaries for processing rows = df.to_dict('records') # print(rows) @@ -171,11 +109,11 @@ test_df = pd.merge( test_predictions, - ds[f'{type}'].to_pandas(), + ds['test'].to_pandas(), on=['question'], ) results = eval_df(test_df) - with open(f'../../{model_name}_{type}.jsonl', 'w') as f: + with open(f'../../{model_name}_test.jsonl', 'w') as f: for data_line in results: f.write(ujson.dumps(data_line) + '\n') diff --git a/experiments/subgraphs_reranking/aggregate_results.py b/experiments/subgraphs_reranking/aggregate_results.py new file mode 100644 index 0000000..c25d3ce --- /dev/null +++ b/experiments/subgraphs_reranking/aggregate_results.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python3 +"""Aggregate evaluation results across multiple runs and calculate mean/std statistics""" + +import json +import re +from pathlib import Path +from collections import defaultdict +import numpy as np +import subprocess +import sys +from argparse import ArgumentParser + +parser = ArgumentParser(description="Aggregate evaluation results across multiple runs") +parser.add_argument( + "--ds_type", + type=str, + default=None, + help="Filter results by dataset type (e.g., t5largessm, t5xlssm, mistral, mixtral). If not specified, processes all dataset types."
+) +parser.add_argument( + "--dataset", + type=str, + default="mintaka", + choices=["mintaka", "mkqa-hf"], + help="Dataset to use: mintaka or mkqa-hf." +) + +MINTAKA_EVAL_SCRIPT = Path(__file__).parent.parent.parent / "mintaka_evaluate.py" +MKQA_EVAL_SCRIPT = Path(__file__).parent.parent.parent / "mkqa_evaluate.py" + + +def get_results_base_dir(dataset: str) -> Path: + """Get the results base directory based on dataset type""" + if dataset == "mkqa-hf": + return Path("./reranking_model_results/mkqa-hf") + return Path("./reranking_model_results/mintaka") + + +def parse_evaluation_file(eval_file: Path) -> dict: + """Parse evaluation result file and extract metrics""" + results = {} + with open(eval_file, "r", encoding="utf-8") as f: + lines = f.readlines() + + for line in lines[1:]: # Skip "Hit scores:" header + line = line.strip() + if not line: + continue + + parts = line.split("\t") + if not parts: + continue + + category = parts[0].strip() + metrics = {} + + for part in parts[1:]: + match = re.match(r"Hit@(\d+)\s*=\s*([\d.]+)", part) + if match: + n = int(match.group(1)) + value = float(match.group(2)) + metrics[f"Hit@{n}"] = value + + if metrics: + results[category] = metrics + + return results + + +def run_evaluation(result_file: Path, dataset: str) -> Path: + """Run evaluation script on a result file if evaluation doesn't exist""" + eval_file = result_file.parent / f"reranking_result_{result_file.name}.txt" + + if not eval_file.exists(): + print(f"Running evaluation for {result_file.name}...") + eval_script = MKQA_EVAL_SCRIPT if dataset == "mkqa-hf" else MINTAKA_EVAL_SCRIPT + cmd = [ + sys.executable, + str(eval_script), + "--predictions_path", + str(result_file), + "--split", + "test" + ] + subprocess.run(cmd, check=True) + + return eval_file + + +def extract_config(result_file: Path) -> tuple: + """Extract (model_type, feature_combo, ds_type, run) from result filename + + Handles two patterns: + 1. 
With run: {model_type}_{feature_combo}_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl + 2. Without run (baselines): {model_type}_{feature_combo}_reranking_seq2seq_{ds_type}_results.jsonl + """ + name = result_file.name + # Pattern 1: With run number (multi-run experiments) + match = re.match( + r"([^_]+)_(.+)_reranking_seq2seq_run_(\d+)_(.+?)_results\.jsonl", + name + ) + if match: + model_type = match.group(1) + feature_combo = match.group(2) + run = int(match.group(3)) + ds_type = match.group(4) + return (model_type, feature_combo, ds_type, run) + + # Pattern 2: Without run number (baseline methods, single run) + # Handle cases like: NO_reranking_seq2seq_{ds_type}_results.jsonl + # or: full_random_reranking_seq2seq_{ds_type}_results.jsonl + # or: logreg_text_reranking_seq2seq_{ds_type}_results.jsonl + match = re.match( + r"(.+?)_reranking_seq2seq_(.+?)_results\.jsonl", + name + ) + if match: + prefix = match.group(1) # Everything before "_reranking_seq2seq_" + ds_type = match.group(2) + + # Known model types that have feature combos (split on first underscore) + models_with_features = {"logreg", "linreg", "mpnet", "catboost"} + + # Try to split into model_type and feature_combo + parts = prefix.split("_", 1) + if len(parts) == 2 and parts[0] in models_with_features: + # Has feature combo: e.g., "logreg_text" -> model_type="logreg", feature_combo="text" + model_type = parts[0] + feature_combo = parts[1] + else: + # No feature combo or unknown model: e.g., "NO", "full_random", "semantic_mpnet" + # The whole prefix is the model_type + model_type = prefix + feature_combo = "" + + # Use run=0 for baselines to distinguish from multi-run experiments + return (model_type, feature_combo, ds_type, 0) + + return (None, None, None, None) + + +def aggregate_results(ds_type_filter: str = None, dataset: str = "mintaka"): + """Main function to aggregate all evaluation results + + Parameters + ---------- + ds_type_filter : str, optional + If provided, only process results 
for this dataset type + dataset : str + Dataset type: "mintaka" or "mkqa-hf" + """ + results_base_dir = get_results_base_dir(dataset) + + if not results_base_dir.exists(): + print(f"Warning: Results directory {results_base_dir} does not exist") + return {} + + result_files = [] + for ds_dir in results_base_dir.iterdir(): + if ds_dir.is_dir(): + if ds_type_filter is not None and ds_dir.name != ds_type_filter: + continue + # Find files with run numbers (multi-run experiments) + result_files.extend( + ds_dir.glob("*_reranking_seq2seq_run_*_results.jsonl") + ) + # Find files without run numbers (baseline methods) + # Exclude files that already matched the pattern above + baseline_files = [ + f for f in ds_dir.glob("*_reranking_seq2seq_*_results.jsonl") + if "_run_" not in f.name + ] + result_files.extend(baseline_files) + + result_files = sorted(result_files) + + config_results = defaultdict(lambda: defaultdict(list)) + + print(f"Found {len(result_files)} result files" + (f" for ds_type={ds_type_filter}" if ds_type_filter else "")) + + for result_file in result_files: + model_type, feature_combo, ds_type, run = extract_config(result_file) + if model_type is None: + print(f"Warning: Could not parse config from {result_file.name}") + continue + + if ds_type_filter is not None and ds_type != ds_type_filter: + continue + + eval_file = run_evaluation(result_file, dataset) + + try: + results = parse_evaluation_file(eval_file) + config_key = (model_type, feature_combo, ds_type) + config_results[config_key]["runs"].append((run, results)) + except Exception as e: + print(f"Error processing {result_file.name}: {e}") + continue + + aggregated = {} + for (model_type, feature_combo, ds_type), data in config_results.items(): + runs_data = sorted(data["runs"], key=lambda x: x[0]) + runs = [r[1] for r in runs_data] + num_runs = len(runs) + + # Check if this is a baseline (run=0) or multi-run experiment + is_baseline = any(r[0] == 0 for r in runs_data) + + if not is_baseline and 
num_runs != 3: + print(f"Warning: Expected 3 runs for {model_type}/{feature_combo}/{ds_type}, found {num_runs}") + + config_key = f"{model_type}_{feature_combo}_{ds_type}" + aggregated[config_key] = { + "metadata": { + "model_type": model_type, + "feature_combo": feature_combo, + "ds_type": ds_type, + "num_runs": num_runs, + "is_baseline": is_baseline + }, + "results": {} + } + + categories = set() + for run in runs: + categories.update(run.keys()) + + for category in sorted(categories): + hit_metrics = defaultdict(list) + for run in runs: + if category in run: + for hit_key, value in run[category].items(): + hit_metrics[hit_key].append(value) + + category_stats = {} + for hit_key in sorted(hit_metrics.keys(), key=lambda x: int(x.split("@")[1])): + values = hit_metrics[hit_key] + if values: + if is_baseline and num_runs == 1: + # For baselines (single run), mean = value, std = 0.0 + category_stats[f"{hit_key}_mean"] = values[0] + category_stats[f"{hit_key}_std"] = 0.0 + else: + # For multi-run experiments, calculate mean and std + category_stats[f"{hit_key}_mean"] = np.mean(values) + category_stats[f"{hit_key}_std"] = np.std(values) + + if category_stats: + aggregated[config_key]["results"][category] = category_stats + + return aggregated + + +def print_results(aggregated: dict, ds_type_filter: str = None, dataset: str = "mintaka"): + """Print aggregated results in a readable format + + Parameters + ---------- + aggregated : dict + Aggregated results dictionary + ds_type_filter : str, optional + Dataset type filter used, for output filename + dataset : str + Dataset type: "mintaka" or "mkqa-hf" + """ + results_base_dir = get_results_base_dir(dataset) + + if ds_type_filter: + output_file = results_base_dir / f"aggregated_results_mean_std_{ds_type_filter}.txt" + else: + output_file = results_base_dir / "aggregated_results_mean_std.txt" + + with open(output_file, "w", encoding="utf-8") as f: + f.write("Aggregated Results (Mean ± Std across runs)\n") + f.write("=" * 80 
+ "\n\n") + + for config_key in sorted(aggregated.keys()): + config_data = aggregated[config_key] + metadata = config_data["metadata"] + results = config_data["results"] + + model_type = metadata["model_type"] + feature_combo = metadata["feature_combo"] + ds_type = metadata["ds_type"] + num_runs = metadata["num_runs"] + is_baseline = metadata.get("is_baseline", False) + + run_info = f"runs={num_runs}" if not is_baseline else "baseline (single run)" + f.write(f"\nConfiguration: model={model_type}, features={feature_combo}, ds_type={ds_type}, {run_info}\n") + f.write("-" * 80 + "\n") + + for category in sorted(results.keys()): + f.write(f"\n{category}:\n") + metrics = results[category] + + hit_n_values = defaultdict(dict) + for key, value in metrics.items(): + if "_mean" in key: + hit_n = key.replace("_mean", "") + hit_n_values[hit_n]["mean"] = value + elif "_std" in key: + hit_n = key.replace("_std", "") + hit_n_values[hit_n]["std"] = value + + for hit_n in sorted(hit_n_values.keys(), key=lambda x: int(x.split("@")[1])): + mean = hit_n_values[hit_n]["mean"] + std = hit_n_values[hit_n]["std"] + f.write(f" {hit_n:8} = {mean:.4f} ± {std:.4f}\n") + + f.write("\n") + + # Calculate and print upper bound (maximum Hit@K values across all configurations) + f.write("\n" + "=" * 80 + "\n") + f.write("Upper Bound (Maximum Hit@K across all configurations)\n") + f.write("=" * 80 + "\n\n") + + # Collect all categories and Hit@K metrics (Hit@1 to Hit@30) + all_categories = set() + all_hit_metrics = defaultdict(lambda: defaultdict(list)) + + for config_key, config_data in aggregated.items(): + results = config_data["results"] + for category, metrics in results.items(): + all_categories.add(category) + for key, value in metrics.items(): + if "_mean" in key: + hit_n = key.replace("_mean", "") + # Only include Hit@1 to Hit@30 + hit_k = int(hit_n.split("@")[1]) + if 1 <= hit_k <= 30: + all_hit_metrics[category][hit_n].append(value) + + # Print maximum values for each category and Hit@K 
+ for category in sorted(all_categories): + f.write(f"{category}:\n") + hit_metrics = all_hit_metrics[category] + + for hit_n in sorted(hit_metrics.keys(), key=lambda x: int(x.split("@")[1])): + max_value = max(hit_metrics[hit_n]) + f.write(f" {hit_n:8} = {max_value:.4f}\n") + + f.write("\n") + + print(f"\nResults saved to {output_file}") + + with open(output_file, "r", encoding="utf-8") as f: + print(f.read()) + + +if __name__ == "__main__": + args = parser.parse_args() + aggregated = aggregate_results(args.ds_type, args.dataset) + print(f"Aggregated results: {aggregated}") + print_results(aggregated, args.ds_type, args.dataset) + diff --git a/experiments/subgraphs_reranking/plot_main.py b/experiments/subgraphs_reranking/plot_main.py new file mode 100644 index 0000000..9c4e6d1 --- /dev/null +++ b/experiments/subgraphs_reranking/plot_main.py @@ -0,0 +1,451 @@ +""" +Plot Hit@1 results for different reranking approaches across multiple candidate sources. +""" + +import re +import numpy as np +import matplotlib.pyplot as plt +from matplotlib import rcParams + +# Configure matplotlib for publication quality +rcParams['font.family'] = 'serif' +rcParams['font.size'] = 9 +rcParams['axes.labelsize'] = 10 +rcParams['axes.titlesize'] = 10 +rcParams['xtick.labelsize'] = 9 +rcParams['ytick.labelsize'] = 9 +rcParams['legend.fontsize'] = 8 +rcParams['figure.titlesize'] = 11 + +# File paths +files = { + 'T5-Large-SSM': 'tabs/t5largessm_all_results_hit_at_n.tex', + 'T5-XL-SSM': 'tabs/t5xlssm_all_results_hit_at_n.tex', + 'Mixtral': 'tabs/mixtral_all_results_hit_at_n.tex', + 'Mistral': 'tabs/mistral_all_results_hit_at_n.tex' +} + +def parse_value(value_str): + """Parse values like '0.2213' or '$0.2442\\pm0.0004$'""" + value_str = value_str.strip().replace('$', '') + + if '\\pm' in value_str: + parts = value_str.split('\\pm') + mean = float(parts[0]) + error = float(parts[1]) + return mean, error + else: + try: + return float(value_str), None + except: + return None, None + +def 
parse_latex_table(filename): + """Parse LaTeX table and extract Hit@1 values with errors.""" + with open(filename, 'r') as f: + content = f.read() + + results = {} + lines = content.split('\n') + current_model = None + + for line in lines: + # Skip control lines + if any(x in line for x in ['\\toprule', '\\midrule', '\\bottomrule', '\\cmidrule', + 'textbf{Reranking Model}', '\\caption', '\\label', + '\\begin{', '\\end{', '\\setlength', '\\fontsize']): + continue + + # Check if it's a data line + if '&' in line and '\\\\' in line: + parts = [p.strip() for p in line.split('&')] + + if len(parts) >= 3: + model = parts[0] + features = parts[1] if len(parts) > 1 else '' + hit1 = parts[2] if len(parts) > 2 else '' + + # Handle multirow + if model and '\\multirow' not in model: + if model != '': + current_model = model + + if '\\multirow' in model: + match = re.search(r'\\multirow\{[^}]+\}\{[^}]+\}\{([^}]+)\}', model) + if match: + current_model = match.group(1) + model = '' + + if model == '' and current_model: + model = current_model + + # Parse Hit@1 value + mean, error = parse_value(hit1) + + if mean is not None: + if features: + # Normalize "Det.Lin." to "Det. Lin." for consistency + features = features.replace('Det.Lin.', 'Det. Lin.') + key = f"{model} - {features}" + else: + key = model + + results[key] = {'mean': mean, 'error': error} + + return results + +# Parse all data +all_data = {} +for model_name, filename in files.items(): + all_data[model_name] = parse_latex_table(filename) + +# Define comprehensive method selection - prioritize most important ones +# Note: Some methods may not exist for all models (will show as missing/0) +selected_approaches = [ + 'Without reranking', + 'Majority Vote', + # 'Semantic reranking - Text', + 'RankGPT (Mistral) - Text', + 'RankGPT (Mistral) - Text + G2T (Det. Lin.)', + 'RankGPT (Qwen2.5 7B) - Text', + 'RankGPT (Qwen2.5 7B) - Text + G2T (Det. 
Lin.)', + 'Linear Regression - Text', + 'Linear Regression - Text + Graph', + 'Linear Regression - Text + Graph + G2T (Det. Lin.)', + 'Linear Regression - Text + Graph + G2T (T5)', + 'Linear Regression - Text + Graph + G2T (GAP)', + 'Linear Regression - Text + G2T (Det. Lin.)', + 'Linear Regression - Text + G2T (T5)', + 'Linear Regression - Text + G2T (GAP)', + 'Logistic Regression - Text', + 'Logistic Regression - Text + Graph', + 'Logistic Regression - Text + Graph + G2T (Det. Lin.)', + 'Logistic Regression - Text + Graph + G2T (T5)', + 'Logistic Regression - Text + Graph + G2T (GAP)', + 'Logistic Regression - Text + G2T (Det. Lin.)', + 'Logistic Regression - Text + G2T (T5)', + 'Logistic Regression - Text + G2T (GAP)', + 'CatBoost - Text', + 'CatBoost - Text + Graph', + 'CatBoost - Text + Graph + G2T (Det. Lin.)', + 'CatBoost - Text + Graph + G2T (T5)', + 'CatBoost - Text + Graph + G2T (GAP)', + 'CatBoost - Text + G2T (Det. Lin.)', + 'CatBoost - Text + G2T (T5)', + 'CatBoost - Text + G2T (GAP)', + 'MPNet - Text', + 'MPNet - Text + G2T (Det. Lin.)', + 'MPNet - Text + G2T (T5)', + 'MPNet - Text + G2T (GAP)', +] + +def shorten_method_name(name): + """Shorten method names for better display.""" + replacements = { + 'Linear Regression': 'Lin. Reg.', + 'Logistic Regression': 'Log. 
Reg.', + 'RankGPT (Mistral)': 'RankGPT (Mistral)', + 'RankGPT (Qwen2.5 7B)': 'RankGPT (Qwen)', + } + result = name + for full, short in replacements.items(): + result = result.replace(full, short) + return result + +def get_method_group(method_name): + """Determine which group a method belongs to for spacing.""" + if method_name in ['Without reranking', 'Majority Vote', 'Semantic reranking - Text']: + return 'baseline' + elif method_name.startswith('RankGPT'): + return 'rankgpt' + elif method_name.startswith('Linear Regression'): + return 'linear' + elif method_name.startswith('Logistic Regression'): + return 'logistic' + elif method_name.startswith('CatBoost'): + return 'catboost' + elif method_name.startswith('MPNet'): + return 'mpnet' + else: + return 'other' + +# High contrast earthy colors palette +colors = { + 'Without reranking': '#6B5344', + 'Majority Vote': '#E74C3C', + 'Semantic reranking - Text': '#A0522D', + 'RankGPT (Mistral) - Text': '#8B4513', + 'RankGPT (Mistral) - Text + G2T (Det. Lin.)': '#7A3513', + 'RankGPT (Qwen2.5 7B) - Text': '#9B5523', + 'RankGPT (Qwen2.5 7B) - Text + G2T (Det. Lin.)': '#8A4523', + 'Linear Regression - Text': '#BC8F8F', + 'Linear Regression - Text + Graph': '#A0826D', + 'Linear Regression - Text + Graph + G2T (Det. Lin.)': '#8B7355', + 'Linear Regression - Text + Graph + G2T (T5)': '#7A6345', + 'Linear Regression - Text + Graph + G2T (GAP)': '#6B5535', + 'Linear Regression - Text + G2T (Det. Lin.)': '#9A8365', + 'Linear Regression - Text + G2T (T5)': '#897355', + 'Linear Regression - Text + G2T (GAP)': '#786345', + 'Logistic Regression - Text': '#DEB887', + 'Logistic Regression - Text + Graph': '#D2B48C', + 'Logistic Regression - Text + Graph + G2T (Det. Lin.)': '#C8A882', + 'Logistic Regression - Text + Graph + G2T (T5)': '#BE9C78', + 'Logistic Regression - Text + Graph + G2T (GAP)': '#B4906E', + 'Logistic Regression - Text + G2T (Det. 
Lin.)': '#D8B892', + 'Logistic Regression - Text + G2T (T5)': '#CEA882', + 'Logistic Regression - Text + G2T (GAP)': '#C49872', + 'CatBoost - Text': '#DAA520', + 'CatBoost - Text + Graph': '#B8860B', + 'CatBoost - Text + Graph + G2T (Det. Lin.)': '#9B7708', + 'CatBoost - Text + Graph + G2T (T5)': '#8B6806', + 'CatBoost - Text + Graph + G2T (GAP)': '#7B5905', + 'CatBoost - Text + G2T (Det. Lin.)': '#856A07', + 'CatBoost - Text + G2T (T5)': '#755A06', + 'CatBoost - Text + G2T (GAP)': '#654A05', + 'MPNet - Text': '#CD853F', + 'MPNet - Text + G2T (Det. Lin.)': '#A0522D', + 'MPNet - Text + G2T (T5)': '#8F451D', + 'MPNet - Text + G2T (GAP)': '#7E380D', +} + +def plot_results(files_dict, selected_approaches_list, colors_dict, y_axis_limits_dict, + title, output_prefix): + """Plot Hit@1 results for given files and methods.""" + # Parse all data + all_data = {} + for model_name, filename in files_dict.items(): + all_data[model_name] = parse_latex_table(filename) + + # Extract data for plotting + models = list(files_dict.keys()) + + # Prepare data arrays for each model + data_by_model = {} + for j, model in enumerate(models): + means = [] + errors = [] + labels = [] + colors_list = [] + + for approach in selected_approaches_list: + if approach in all_data[model]: + means.append(all_data[model][approach]['mean']) + err = all_data[model][approach]['error'] + errors.append(err if err is not None else 0.0) + labels.append(shorten_method_name(approach)) + colors_list.append(colors_dict.get(approach, '#808080')) + + data_by_model[model] = { + 'means': np.array(means), + 'errors': np.array(errors), + 'labels': labels, + 'colors': colors_list + } + + # Create subplot figure - 1 column, n_models rows + n_models = len(models) + n_cols = 1 + n_rows = n_models + fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 3.5 * n_rows)) + if n_models == 1: + axes = [axes] + elif not isinstance(axes, np.ndarray): + axes = [axes] + else: + axes = axes.flatten() + + # Plot each model in a 
separate subplot + for idx, model in enumerate(models): + ax = axes[idx] + + means = data_by_model[model]['means'] + errors = data_by_model[model]['errors'] + labels = data_by_model[model]['labels'] + colors_list = data_by_model[model]['colors'] + + n_methods = len(means) + + # Group methods and calculate positions with spacing + # First, collect which approaches are present and their groups + present_approaches = [] + for approach in selected_approaches_list: + if approach in all_data[model]: + present_approaches.append(approach) + + # Group consecutive methods of the same category + groups = [] + current_group = None + group_start = 0 + + for i, approach in enumerate(present_approaches): + group = get_method_group(approach) + if current_group is None: + current_group = group + elif group != current_group: + groups.append((current_group, group_start, i - group_start)) + current_group = group + group_start = i + if current_group is not None: + groups.append((current_group, group_start, len(present_approaches) - group_start)) + + # Calculate x positions with spacing between groups + x = [] + group_spacing = 0.8 + bar_spacing_within_group = 0.85 # 50% less than 1.0 + current_pos = 0 + + for group_name, start_idx, group_size in groups: + for i in range(group_size): + x.append(current_pos) + current_pos += bar_spacing_within_group + current_pos += group_spacing + + x = np.array(x) + width = 0.65 + + # Plot bars + bars = ax.bar(x, means, width, + color=colors_list, + edgecolor='black', + linewidth=0.8, + alpha=0.9) + + # Add error bars + ax.errorbar(x, means, yerr=errors, + fmt='none', + ecolor='black', + elinewidth=1.2, + capsize=3.5, + capthick=1.2, + alpha=0.7, + zorder=2) + + # Add labels below bars at 35 degree angle + ax.set_xticks(x) + ax.set_xticklabels(labels, + rotation=35, ha='right', fontsize=8) + + # Focused Y-axis scaling for each model + if len(means) > 0: + if model in y_axis_limits_dict and y_axis_limits_dict[model] is not None: + y_min, y_max = 
y_axis_limits_dict[model] + else: + data_range = means.max() - means.min() + y_min = means.min() - data_range * 0.05 + y_max = means.max() + data_range * 0.05 + y_min = max(0, y_min) + ax.set_ylim(y_min, y_max) + + # Customize subplot + ax.set_ylabel('Hit@1 Score', fontweight='bold', fontsize=9) + ax.set_title(f'{model}', fontweight='bold', fontsize=11) + ax.grid(axis='y', alpha=0.4, linestyle='--', linewidth=0.7, color='gray') + ax.set_axisbelow(True) + + # Hide unused subplots + for idx in range(n_models, len(axes)): + axes[idx].set_visible(False) + + # Add overall title + # fig.suptitle(title, fontsize=13, fontweight='bold', y=0.995) + + # Adjust layout + plt.tight_layout(rect=[0, 0, 1, 1]) + + # Save figure + plt.savefig(f'{output_prefix}.pdf', format='pdf', dpi=300, bbox_inches='tight') + plt.savefig(f'{output_prefix}.png', format='png', dpi=300, bbox_inches='tight') + print(f"✓ Figures saved: {output_prefix}.pdf and {output_prefix}.png") + + # Report what was plotted + for model in models: + print(f" {model}: {len(data_by_model[model]['means'])} methods") + + return data_by_model + +# Manual Y-axis limits for focused scaling (optional, None for auto) +y_axis_limits = { + 'T5-Large-SSM': None, + 'T5-XL-SSM': None, + 'Mixtral': None, + 'Mistral': None, +} + +# Plot Mintaka dataset +data_by_model = plot_results( + files, + selected_approaches, + colors, + y_axis_limits, + 'Hit@1 Performance: Reranking Approaches Across Candidate Sources on Mintaka Dataset', + 'hit_at_1_comparison' +) + +for model in list(files.keys()): + means = data_by_model[model]['means'] + errors = data_by_model[model]['errors'] + labels = data_by_model[model]['labels'] + print(f"Statistics for model: {model}") + for method, mean, std in zip(labels, means, errors): + print(f" {method}: mean={mean:.4f}, std={std:.4f}") + +# MKQA dataset files +mkqa_files = { + 'T5-Large-SSM': 'tabs/mkqa_t5largessm_all_results_hit_at_n.tex', + 'T5-XL-SSM': 'tabs/mkqa_t5xlssm_all_results_hit_at_n.tex', +} + +# 
MKQA selected approaches (subset of methods available in MKQA) +mkqa_selected_approaches = [ + 'Without reranking', + 'Majority Vote', + # 'Semantic reranking - Text', + 'RankGPT (Mistral) - Text', + 'RankGPT (Mistral) - Text + G2T (Det. Lin.)', + 'RankGPT (Qwen2.5 7B) - Text', + 'RankGPT (Qwen2.5 7B) - Text + G2T (Det. Lin.)', + 'Linear Regression - Text', + 'Linear Regression - Graph', + 'Linear Regression - Text + Graph', + 'Linear Regression - G2T (Det. Lin.)', + 'Linear Regression - Text + G2T (Det. Lin.)', + 'Logistic Regression - Text', + 'Logistic Regression - Graph', + 'Logistic Regression - Text + Graph', + 'Logistic Regression - G2T (Det. Lin.)', + 'Logistic Regression - Text + G2T (Det. Lin.)', + 'CatBoost - Text', + 'CatBoost - Graph', + 'CatBoost - Text + Graph', + 'CatBoost - G2T (Det. Lin.)', + 'CatBoost - Text + G2T (Det. Lin.)', + 'MPNet - Text', + 'MPNet - Text + G2T (Det. Lin.)', +] + +# Add colors for MKQA-specific methods +mkqa_colors = colors.copy() +mkqa_colors.update({ + 'Linear Regression - Graph': '#B5A08D', + 'Logistic Regression - Graph': '#E8C8A2', + 'CatBoost - Graph': '#C9A00B', +}) + +# MKQA Y-axis limits +mkqa_y_axis_limits = { + 'T5-Large-SSM': None, + 'T5-XL-SSM': None, +} + +# Plot MKQA dataset +print("\n" + "="*60) +print("Plotting MKQA dataset results...") +print("="*60) +mkqa_data_by_model = plot_results( + mkqa_files, + mkqa_selected_approaches, + mkqa_colors, + mkqa_y_axis_limits, + 'Hit@1 Performance: Reranking Approaches Across Candidate Sources on MKQA Dataset', + 'hit_at_1_comparison_mkqa' +) \ No newline at end of file diff --git a/experiments/subgraphs_reranking/rankgpt/__init__.py b/experiments/subgraphs_reranking/rankgpt/__init__.py new file mode 100644 index 0000000..57e9430 --- /dev/null +++ b/experiments/subgraphs_reranking/rankgpt/__init__.py @@ -0,0 +1,12 @@ +"""RankGPT implementation for KGQA subgraph ranking""" + +from .rankgpt_ranker import RankGPTRanker +from .data_utils import prepare_rankgpt_data +from 
.prompt_builder import create_ranking_prompt, parse_ranking_output + +__all__ = [ + "RankGPTRanker", + "prepare_rankgpt_data", + "create_ranking_prompt", + "parse_ranking_output" +] diff --git a/experiments/subgraphs_reranking/rankgpt/aggregate_results.py b/experiments/subgraphs_reranking/rankgpt/aggregate_results.py new file mode 100644 index 0000000..15d23a3 --- /dev/null +++ b/experiments/subgraphs_reranking/rankgpt/aggregate_results.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python3 +"""Aggregate evaluation results across multiple runs and calculate mean/std statistics""" + +import json +import re +from pathlib import Path +from collections import defaultdict +import numpy as np +import subprocess +import sys + +OUTPUTS_DIR = Path(__file__).parent / "outputs" +MINTAKA_EVAL_SCRIPT = Path(__file__).parent.parent.parent.parent / "mintaka_evaluate.py" + + +def parse_evaluation_file(eval_file: Path) -> dict: + """Parse evaluation result file and extract metrics""" + results = {} + with open(eval_file, "r", encoding="utf-8") as f: + lines = f.readlines() + + current_category = None + for line in lines[1:]: # Skip "Hit scores:" header + line = line.strip() + if not line: + continue + + # Extract category name (everything before first tab) + parts = line.split("\t") + if not parts: + continue + + category = parts[0].strip() + metrics = {} + + # Extract Hit@N metrics + for part in parts[1:]: + match = re.match(r"Hit@(\d+)\s*=\s*([\d.]+)", part) + if match: + n = int(match.group(1)) + value = float(match.group(2)) + metrics[f"Hit@{n}"] = value + + if metrics: + results[category] = metrics + + return results + + +def run_evaluation(result_file: Path) -> Path: + """Run mintaka_evaluate.py on a result file if evaluation doesn't exist""" + eval_file = OUTPUTS_DIR / f"reranking_result_{result_file.name}.txt" + + if not eval_file.exists(): + print(f"Running evaluation for {result_file.name}...") + cmd = [ + sys.executable, + str(MINTAKA_EVAL_SCRIPT), + "--predictions_path", + 
str(result_file), + "--split", + "test" + ] + subprocess.run(cmd, check=True) + + return eval_file + + +def extract_config(result_file: Path) -> tuple: + """Extract (ds_type, variant) from result filename""" + name = result_file.name + # Format: results_r{run}_{ds_type}_rankedby_gptoss20b{_variant}.json + # ds_type can be: mistral, mixtral, t5largessm, t5xlssm + match = re.match(r"results_r\d+_([^_]+)_rankedby_gptoss20b(?:_(.+))?\.json", name) + if match: + ds_type = match.group(1) + variant = match.group(2) if match.group(2) else "default" + return (ds_type, variant) + return (None, None) + + +def aggregate_results(): + """Main function to aggregate all evaluation results""" + result_files = sorted(OUTPUTS_DIR.glob("results_r*_*_rankedby_gptoss20b*.json")) + + # Group results by configuration + config_results = defaultdict(lambda: defaultdict(list)) + + print(f"Found {len(result_files)} result files") + + for result_file in result_files: + ds_type, variant = extract_config(result_file) + if ds_type is None: + print(f"Warning: Could not parse config from {result_file.name}") + continue + + # Run evaluation if needed + eval_file = run_evaluation(result_file) + + # Parse evaluation results + try: + results = parse_evaluation_file(eval_file) + config_results[(ds_type, variant)]["runs"].append(results) + except Exception as e: + print(f"Error processing {result_file.name}: {e}") + continue + + # Calculate statistics + aggregated = {} + for (ds_type, variant), data in config_results.items(): + runs = data["runs"] + num_runs = len(runs) + if num_runs != 3: + print(f"Warning: Expected 3 runs for {ds_type}/{variant}, found {num_runs}") + + # Aggregate metrics across runs + config_key = f"{ds_type}_{variant}" + aggregated[config_key] = { + "metadata": { + "ds_type": ds_type, + "variant": variant, + "num_runs": num_runs + }, + "results": {} + } + + # Get all categories from first run + categories = set() + for run in runs: + categories.update(run.keys()) + + for category in 
sorted(categories): + # Collect values for each Hit@N across runs + hit_metrics = defaultdict(list) + for run in runs: + if category in run: + for hit_key, value in run[category].items(): + hit_metrics[hit_key].append(value) + + # Calculate mean and std + category_stats = {} + for hit_key in sorted(hit_metrics.keys(), key=lambda x: int(x.split("@")[1])): + values = hit_metrics[hit_key] + if values: + category_stats[f"{hit_key}_mean"] = np.mean(values) + category_stats[f"{hit_key}_std"] = np.std(values) + + if category_stats: + aggregated[config_key]["results"][category] = category_stats + + return aggregated + + +def print_results(aggregated: dict): + """Print aggregated results in a readable format""" + output_file = OUTPUTS_DIR / "aggregated_results_rankedby_gptoss20b_mean_std.txt" + + with open(output_file, "w", encoding="utf-8") as f: + f.write("Aggregated Results (Mean ± Std across runs)\n") + f.write("=" * 80 + "\n\n") + + for config_key in sorted(aggregated.keys()): + config_data = aggregated[config_key] + metadata = config_data["metadata"] + results = config_data["results"] + + ds_type = metadata["ds_type"] + variant = metadata["variant"] + num_runs = metadata["num_runs"] + + f.write(f"\nConfiguration: ds_type={ds_type}, type={variant}, runs={num_runs}\n") + f.write("-" * 80 + "\n") + + for category in sorted(results.keys()): + f.write(f"\n{category}:\n") + metrics = results[category] + + # Group by Hit@N + hit_n_values = defaultdict(dict) + for key, value in metrics.items(): + if "_mean" in key: + hit_n = key.replace("_mean", "") + hit_n_values[hit_n]["mean"] = value + elif "_std" in key: + hit_n = key.replace("_std", "") + hit_n_values[hit_n]["std"] = value + + for hit_n in sorted(hit_n_values.keys(), key=lambda x: int(x.split("@")[1])): + mean = hit_n_values[hit_n]["mean"] + std = hit_n_values[hit_n]["std"] + f.write(f" {hit_n:8} = {mean:.4f} ± {std:.4f}\n") + + f.write("\n") + + print(f"\nResults saved to {output_file}") + + # Also print to console + 
with open(output_file, "r", encoding="utf-8") as f: + print(f.read()) + + +if __name__ == "__main__": + aggregated = aggregate_results() + print_results(aggregated) + diff --git a/experiments/subgraphs_reranking/rankgpt/data_utils.py b/experiments/subgraphs_reranking/rankgpt/data_utils.py new file mode 100644 index 0000000..290f01e --- /dev/null +++ b/experiments/subgraphs_reranking/rankgpt/data_utils.py @@ -0,0 +1,126 @@ +"""Data preparation utilities for RankGPT ranking""" + +import pandas as pd +from typing import List, Dict, Any, Optional + + +def prepare_rankgpt_data(test_df: pd.DataFrame) -> List[Dict[str, Any]]: + """ + Convert test DataFrame to RankGPT input format. + + Args: + test_df: DataFrame with columns ['id', 'question', 'model_answers', 'answerEntity'] + + Returns: + List of dictionaries with question_id, question, and unique_answers + """ + rankgpt_data = [] + + for question_id, group in test_df.groupby("id"): + question = group["question"].iloc[0] + + # Extract unique answers - deduplicate based on answer string or entity ID + unique_answers = [] + seen_answers = set() + + # Get model answers (list of answer strings) + model_answers = group["model_answers"].iloc[0] + + # Get answer entities if available + answer_entities = group["answerEntity"].tolist() if "answerEntity" in group.columns else [None] * len(model_answers) + + for i, (answer_str, entity_id) in enumerate(zip(model_answers, answer_entities)): + # Create unique identifier for deduplication + if entity_id and entity_id != "None": + unique_key = f"entity:{entity_id}" + else: + unique_key = f"string:{answer_str}" + + if unique_key not in seen_answers: + seen_answers.add(unique_key) + unique_answers.append({ + "index": i, + "answer_string": answer_str, + "entity_id": entity_id if entity_id and entity_id != "None" else None + }) + + rankgpt_data.append({ + "question_id": question_id, + "question": question, + "unique_answers": unique_answers + }) + + return rankgpt_data + + +def 
extract_unique_answers_from_group(group: pd.DataFrame, graph_sequence_feature: Optional[str] = None) -> List[Dict[str, Any]]: + """ + Extract unique answers from a single question group. + + Args: + group: DataFrame group for a single question + graph_sequence_feature: Optional feature name for graph sequence (e.g., 'highlighted_determ_sequence') + + Returns: + List of unique answer dictionaries with optional graph sequences + """ + unique_answers = [] + seen_answers = set() + + # Get model answers (list of answer strings) - this is the source of truth for what to rank + model_answers = group["model_answers"].iloc[0] + + # Iterate through group rows - each row represents one question/answer candidate pair + for row_idx, (idx, row) in enumerate(group.iterrows()): + # Get answer entity from this row + entity_id = row.get("answerEntity") if "answerEntity" in row else None + if entity_id and pd.notna(entity_id): + entity_id = str(entity_id) + if entity_id == "None": + entity_id = None + else: + entity_id = None + + # Get answer string from this row + # Try question_answer column first, then fallback to model_answers by index + answer_str = None + if "question_answer" in row and pd.notna(row.get("question_answer")): + qa = str(row["question_answer"]) + if ";" in qa: + answer_str = qa.split(";")[-1].strip() + elif row_idx < len(model_answers): + answer_str = model_answers[row_idx] + + if not answer_str: + continue + + # Create unique identifier for deduplication + if entity_id and entity_id != "None": + unique_key = f"entity:{entity_id}" + else: + unique_key = f"string:{answer_str}" + + if unique_key not in seen_answers: + seen_answers.add(unique_key) + + # Get graph sequence directly from this row if feature is specified + graph_sequence = None + if graph_sequence_feature and graph_sequence_feature in row: + graph_sequence = row.get(graph_sequence_feature) + if graph_sequence and pd.notna(graph_sequence): + graph_sequence = str(graph_sequence) + else: + graph_sequence = 
None + + answer_dict = { + "index": row_idx, + "answer_string": answer_str, + "entity_id": entity_id + } + + if graph_sequence: + answer_dict["graph_sequence"] = graph_sequence + + unique_answers.append(answer_dict) + + return unique_answers diff --git a/experiments/subgraphs_reranking/rankgpt/predict.py b/experiments/subgraphs_reranking/rankgpt/predict.py new file mode 100755 index 0000000..4b65a5b --- /dev/null +++ b/experiments/subgraphs_reranking/rankgpt/predict.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +"""CLI script for running RankGPT predictions on KGQA data""" + +import argparse +import json +import os +import sys +from pathlib import Path + +# Add parent directory to path to import ranking_data_utils +sys.path.append(str(Path(__file__).parent.parent)) + +from datasets import load_dataset +from ranking_data_utils import prepare_data +from rankgpt_ranker import RankGPTRanker + + +def main(): + parser = argparse.ArgumentParser(description="Run RankGPT ranking on KGQA data") + + # Required arguments + parser.add_argument("--model_name", required=True, help="Model name for ranking (e.g., meta-llama/Llama-2-7b-chat-hf)") + parser.add_argument("--ds_type", required=True, choices=["t5largessm", "t5xlssm", "mistral", "mixtral"], + help="Dataset type to use") + parser.add_argument("--dataset", type=str, default="mintaka", choices=["mintaka", "mkqa-hf"], + help="Dataset to use: mintaka or mkqa-hf.") + parser.add_argument("--output_path", required=True, help="Path to save results JSONL file") + + # Optional arguments + parser.add_argument("--window_size", type=int, default=20, help="Window size for sliding window (default: 20)") + parser.add_argument("--step_size", type=int, default=10, help="Step size for sliding window (default: 10)") + parser.add_argument("--api_base", help="API base URL (overrides OPENAI_BASE_URL env var)") + parser.add_argument("--api_key", help="API key (overrides OPENAI_API_KEY env var)") + parser.add_argument("--max_retries", type=int, 
default=3, help="Maximum retries for API calls (default: 3)") + parser.add_argument("--retry_delay", type=float, default=1.0, help="Delay between retries in seconds (default: 1.0)") + parser.add_argument("--split", default="test", choices=["train", "validation", "test"], + help="Dataset split to use (default: test)") + parser.add_argument("--graph_sequence_feature", choices=["highlighted_determ_sequence", "no_highlighted_determ_sequence"], + help="Optional graph sequence feature to include in prompts (default: None)") + + args = parser.parse_args() + + # Validate API configuration + api_base = args.api_base or os.getenv("OPENAI_BASE_URL") + api_key = args.api_key or os.getenv("OPENAI_API_KEY") + + if not api_base: + print("Error: API base URL must be provided via --api_base or OPENAI_BASE_URL environment variable") + sys.exit(1) + if not api_key: + print("Error: API key must be provided via --api_key or OPENAI_API_KEY environment variable") + sys.exit(1) + + print(f"Using API base: {api_base}") + print(f"Using model: {args.model_name}") + print(f"Using dataset: {args.dataset}") + print(f"Using dataset type: {args.ds_type}") + print(f"Using split: {args.split}") + print(f"Window size: {args.window_size}") + print(f"Step size: {args.step_size}") + print(f"Graph sequence feature: {args.graph_sequence_feature or 'None'}") + + # Load datasets + print("Loading datasets...") + try: + if args.dataset == "mintaka": + base_dataset_path = "AmazonScience/mintaka" + kgqa_ds_path = "s-nlp/KGQASubgraphsRanking" + features_data_dir = f"{args.ds_type}_subgraphs" + outputs_data_dir = f"{args.ds_type}_outputs" + elif args.dataset == "mkqa-hf": + base_dataset_path = "Dms12/mkqa_mintaka_format_with_question_entities" + kgqa_ds_path = "s-nlp/MKQASubgraphsRanking" + features_data_dir = f"mkqa_{args.ds_type}_subgraphs" + outputs_data_dir = f"mkqa_{args.ds_type}_outputs" + + # Load subgraph features + features_ds = load_dataset( + kgqa_ds_path, + data_dir=features_data_dir, + ) + + # Load 
model outputs + outputs_ds = load_dataset( + kgqa_ds_path, + data_dir=outputs_data_dir, + ) + + # Load base dataset + if args.dataset == "mintaka": + base_ds = load_dataset(base_dataset_path, revision="refs/convert/parquet", data_dir=f"en") + else: + base_ds = load_dataset(base_dataset_path) + print("Datasets loaded successfully") + + except Exception as e: + print(f"Error loading datasets: {e}") + sys.exit(1) + + # Prepare data + print("Preparing data...") + test_df = prepare_data( + base_ds[args.split], + outputs_ds[args.split], + features_ds[args.split] + ) + print(f"Prepared {len(test_df)} rows of data") + print(test_df.head()) + + # Initialize RankGPT ranker + print("Initializing RankGPT ranker...") + try: + ranker = RankGPTRanker( + api_base=api_base, + api_key=api_key, + model_name=args.model_name, + window_size=args.window_size, + step_size=args.step_size, + max_retries=args.max_retries, + retry_delay=args.retry_delay, + graph_sequence_feature=args.graph_sequence_feature + ) + print("RankGPT ranker initialized successfully") + + except Exception as e: + print(f"Error initializing ranker: {e}") + sys.exit(1) + + # Run ranking + print("Running ranking...") + results = ranker.rerank(test_df) + print(f"Ranking completed for {len(results)} questions") + + + # Save results + print(f"Saving results to {args.output_path}...") + try: + # Create output directory if it doesn't exist + output_path = Path(args.output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, "w", encoding="utf-8") as f: + for result in results: + f.write(json.dumps(result) + "\n") + + print(f"Results saved successfully to {args.output_path}") + + except Exception as e: + print(f"Error saving results: {e}") + sys.exit(1) + + print("RankGPT ranking completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/experiments/subgraphs_reranking/rankgpt/prompt_builder.py b/experiments/subgraphs_reranking/rankgpt/prompt_builder.py new file mode 
100644 index 0000000..dc4a17b --- /dev/null +++ b/experiments/subgraphs_reranking/rankgpt/prompt_builder.py @@ -0,0 +1,162 @@ +"""Prompt generation and parsing utilities for RankGPT""" + +import re +from typing import List, Dict, Any, Optional + + +def create_ranking_prompt(question: str, answers: List[Dict[str, Any]], window_start: int = 1) -> str: + """ + Generate the instructional permutation prompt for RankGPT. + + Args: + question: The question to rank answers for + answers: List of answer dictionaries with 'answer_string' and optionally 'entity_id' and 'graph_sequence' + window_start: Starting number for answer indexing (default 1) + + Returns: + Formatted prompt string + """ + num_answers = len(answers) + + # Check if any answer has a graph sequence + has_graph_sequences = any("graph_sequence" in answer and answer.get("graph_sequence") for answer in answers) + + # Create the main prompt + if has_graph_sequences: + prompt = f"""I will provide you with a question and {num_answers} candidate answers. Each answer is associated with a knowledge graph subgraph sequence that shows the reasoning path. Rank the answers by their relevance to the question, considering both the answer text and the graph structure that supports it. The most relevant answer should be ranked first, and the least relevant answer should be ranked last. + +Question: {question} + +Candidate Answers with Graph Context: +""" + else: + prompt = f"""I will provide you with a question and {num_answers} candidate answers. Rank the answers by their relevance to the question, from most relevant to least relevant. 
+ +Question: {question} + +Candidate Answers: +""" + + # Add numbered answer candidates with graph sequences if available + for i, answer in enumerate(answers, start=window_start): + answer_text = answer["answer_string"] + graph_sequence = answer.get("graph_sequence") + + if graph_sequence and has_graph_sequences: + prompt += f"[{i}] Answer: {answer_text}\n" + prompt += f" Graph Sequence: {graph_sequence}\n" + else: + prompt += f"[{i}] {answer_text}\n" + + prompt += f""" +Please rank the {num_answers} answers above. The most relevant answer should be ranked first, and the least relevant answer should be ranked last. Please give the ranking results in the format [x] > [y] > [z] > ... where x, y, z are the numbers of the answers in order of relevance. Don't include any other text in your response, your response will be parsed automatically. + +Ranking:""" + + return prompt + + +def parse_ranking_output(llm_response: str, num_answers: int) -> Optional[List[int]]: + """ + Parse LLM response to extract ranking. 
+ + Args: + llm_response: Raw response from the LLM + num_answers: Number of answers that were ranked + + Returns: + List of indices in ranked order, or None if parsing failed + """ + if not llm_response: + return None + + # Clean the response + response = llm_response.strip() + + # Try to find ranking pattern like [1] > [2] > [3] or [3] > [1] > [2] + ranking_patterns = [ + r'\[(\d+)\]\s*>\s*\[(\d+)\](?:\s*>\s*\[(\d+)\])*', # [1] > [2] > [3] format + r'(\d+)\s*>\s*(\d+)(?:\s*>\s*(\d+))*', # 1 > 2 > 3 format + ] + + for pattern in ranking_patterns: + matches = re.findall(pattern, response) + if matches: + # Extract all numbers from the first match + numbers = [] + for match in matches[0]: + if match: + numbers.append(int(match)) + + # Validate that we have the right number of answers + if len(numbers) == num_answers and all(1 <= num <= num_answers for num in numbers): + return numbers + + # Fallback: try to extract any sequence of numbers + number_pattern = r'\b(\d+)\b' + numbers = re.findall(number_pattern, response) + if numbers: + numbers = [int(n) for n in numbers] + # Filter to valid answer indices + valid_numbers = [n for n in numbers if 1 <= n <= num_answers] + if len(valid_numbers) == num_answers: + return valid_numbers + + return None + + +def create_sliding_window_prompt(question: str, answers: List[Dict[str, Any]], + window_start: int, window_end: int) -> str: + """ + Create prompt for a sliding window of answers. + + Args: + question: The question to rank answers for + answers: List of answer dictionaries + window_start: Starting index (1-based) + window_end: Ending index (1-based, inclusive) + + Returns: + Formatted prompt string for the window + """ + window_answers = answers[window_start-1:window_end] + return create_ranking_prompt(question, window_answers, window_start) + + +def extract_ranking_from_response(response: str, expected_count: int) -> Optional[List[int]]: + """ + Extract ranking from LLM response with better error handling. 
+ + Args: + response: LLM response text + expected_count: Expected number of ranked items + + Returns: + List of ranked indices or None if parsing failed + """ + if not response: + return None + + # Try the main parsing function first + ranking = parse_ranking_output(response, expected_count) + if ranking: + return ranking + + # Fallback strategies + fallback_patterns = [ + r'(\d+)(?:\s*,\s*(\d+))*', # 1, 2, 3 format + r'(\d+)(?:\s+(\d+))*', # 1 2 3 format + ] + + for pattern in fallback_patterns: + matches = re.findall(pattern, response) + if matches: + numbers = [] + for match in matches[0]: + if match: + numbers.append(int(match)) + + if len(numbers) == expected_count and all(1 <= num <= expected_count for num in numbers): + return numbers + + return None diff --git a/experiments/subgraphs_reranking/rankgpt/rankgpt_ranker.py b/experiments/subgraphs_reranking/rankgpt/rankgpt_ranker.py new file mode 100644 index 0000000..af9dd40 --- /dev/null +++ b/experiments/subgraphs_reranking/rankgpt/rankgpt_ranker.py @@ -0,0 +1,262 @@ +"""RankGPT ranker implementation for KGQA subgraph ranking""" + +import os +import time +import numpy as np +from typing import List, Dict, Any, Optional +import pandas as pd +from openai import OpenAI +from tqdm.auto import tqdm +from concurrent.futures import ThreadPoolExecutor, as_completed + +import sys +from pathlib import Path + +# Add parent directory to path for imports +sys.path.append(str(Path(__file__).parent.parent)) + +from ranking_model import Ranker, RankedAnswer, RankedAnswersDict +from data_utils import extract_unique_answers_from_group +from prompt_builder import create_ranking_prompt, parse_ranking_output + + +class RankGPTRanker(Ranker): + """RankGPT ranker using OpenAI-compatible API for question-answer ranking""" + + def __init__(self, api_base: Optional[str] = None, api_key: Optional[str] = None, + model_name: str = "gpt-3.5-turbo", window_size: int = 20, + step_size: int = 10, max_retries: int = 3, retry_delay: float = 
1.0, + max_workers: int = 8, graph_sequence_feature: Optional[str] = None): + """ + Initialize RankGPT ranker. + + Args: + api_base: API base URL (defaults to OPENAI_BASE_URL env var) + api_key: API key (defaults to OPENAI_API_KEY env var) + model_name: Model name to use for ranking + window_size: Size of sliding window for ranking + step_size: Step size for sliding window + max_retries: Maximum number of retries for API calls + retry_delay: Delay between retries in seconds + max_workers: Maximum number of worker threads + graph_sequence_feature: Optional feature name for graph sequence + ('highlighted_determ_sequence' or 'no_highlighted_determ_sequence') + """ + self.api_base = api_base or os.getenv("OPENAI_BASE_URL") + self.api_key = api_key or os.getenv("OPENAI_API_KEY", "None") + self.model_name = model_name + self.window_size = window_size + self.step_size = step_size + self.max_retries = max_retries + self.retry_delay = retry_delay + self.max_workers = max_workers + self.graph_sequence_feature = graph_sequence_feature + + if not self.api_base: + raise ValueError("API base URL must be provided via api_base parameter or OPENAI_BASE_URL environment variable") + if not self.api_key: + raise ValueError("API key must be provided via api_key parameter or OPENAI_API_KEY environment variable") + + if graph_sequence_feature and graph_sequence_feature not in ["highlighted_determ_sequence", "no_highlighted_determ_sequence"]: + raise ValueError(f"graph_sequence_feature must be one of: 'highlighted_determ_sequence', 'no_highlighted_determ_sequence', or None") + + # Initialize OpenAI client + self.client = OpenAI( + api_key=self.api_key, + base_url=self.api_base + ) + + def fit(self, train_df: pd.DataFrame) -> None: + """No-op for zero-shot ranking (no training required)""" + pass + + def _process_group(self, question_id, group) -> RankedAnswersDict: + # Extract unique answers for this question + unique_answers = extract_unique_answers_from_group(group, 
self.graph_sequence_feature) + + if not unique_answers: + # No answers to rank, use original order + answers = list(dict.fromkeys(group["model_answers"].iloc[0]).keys()) + ranked_answers = self._model_answers_to_ranked_answers(answers) + elif len(unique_answers) <= self.window_size: + # Single-pass ranking + ranked_answers = self._rank_single_pass(group["question"].iloc[0], unique_answers) + else: + # Sliding window ranking + ranked_answers = self._sliding_window_rank(group["question"].iloc[0], unique_answers) + + return RankedAnswersDict( + QuestionID=question_id, + RankedAnswers=ranked_answers + ) + + def rerank(self, test_df: pd.DataFrame) -> List[RankedAnswersDict]: + """ + Rerank answers using RankGPT approach. + + Args: + test_df: DataFrame with test data + + Returns: + List of ranked answers in the required format + """ + results = [] + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + futures = [executor.submit(self._process_group, question_id, group) for question_id, group in test_df.groupby("id")] + for future in tqdm(as_completed(futures), total=len(futures), desc="Ranking questions"): + results.append(future.result()) + + return results + + def _rank_single_pass(self, question: str, answers: List[Dict[str, Any]]) -> List[RankedAnswer]: + """ + Rank all answers in a single API call. 
+ + Args: + question: The question to rank answers for + answers: List of unique answer dictionaries + + Returns: + List of ranked answers + """ + prompt = create_ranking_prompt(question, answers) + + try: + response = self._call_openai_api(prompt) + ranking = parse_ranking_output(response, len(answers)) + + if ranking is None: + # Fallback to original order + return self._answers_to_ranked_answers(answers) + + # Apply ranking + ranked_answers = [] + for rank_idx in ranking: + if 1 <= rank_idx <= len(answers): + answer = answers[rank_idx - 1] # Convert to 0-based index + ranked_answers.append( + RankedAnswer( + AnswerEntityID=answer.get("entity_id"), + AnswerString=answer["answer_string"], + Score=None + ) + ) + + return ranked_answers + + except Exception as e: + print(f"Error in single-pass ranking: {e}") + return self._answers_to_ranked_answers(answers) + + def _sliding_window_rank(self, question: str, all_answers: List[Dict[str, Any]]) -> List[RankedAnswer]: + """ + Rank answers using sliding window approach. 
+ + Args: + question: The question to rank answers for + all_answers: List of all unique answer dictionaries + + Returns: + List of ranked answers + """ + num_answers = len(all_answers) + scores = np.zeros(num_answers) + + # Create sliding windows + for start_idx in range(0, num_answers, self.step_size): + end_idx = min(start_idx + self.window_size, num_answers) + window_answers = all_answers[start_idx:end_idx] + + if len(window_answers) < 2: + continue + + prompt = create_ranking_prompt(question, window_answers, start_idx + 1) + + try: + response = self._call_openai_api(prompt) + ranking = parse_ranking_output(response, len(window_answers)) + + + if ranking is not None: + # Apply position-based scoring (position i gets score 1/i) + for rank_pos, rank_idx in enumerate(ranking): + if 1 <= rank_idx <= len(window_answers): + global_idx = start_idx + rank_idx - 1 + if 0 <= global_idx < num_answers: + scores[global_idx] += 1.0 / (rank_pos + 1) + + except Exception as e: + print(f"Error in sliding window ranking for window {start_idx}-{end_idx}: {e}") + continue + + # Sort by scores (higher is better) + sorted_indices = np.argsort(scores)[::-1] + + # Create ranked answers + ranked_answers = [] + for idx in sorted_indices: + answer = all_answers[idx] + ranked_answers.append( + RankedAnswer( + AnswerEntityID=answer.get("entity_id"), + AnswerString=answer["answer_string"], + Score=float(scores[idx]) if scores[idx] > 0 else None + ) + ) + + return ranked_answers + + def _call_openai_api(self, prompt: str) -> str: + """ + Make API call to OpenAI-compatible endpoint. 
+ + Args: + prompt: The prompt to send + + Returns: + Response text from the API + """ + for attempt in range(self.max_retries): + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=[ + {"role": "user", "content": prompt} + ], + temperature=0.0, # Deterministic output + max_tokens=1000 + ) + + return response.choices[0].message.content.strip() + + except Exception as e: + if attempt < self.max_retries - 1: + print(f"API call failed (attempt {attempt + 1}/{self.max_retries}): {e}") + time.sleep(self.retry_delay * (2 ** attempt)) # Exponential backoff + else: + raise e + + raise Exception(f"API call failed after {self.max_retries} attempts") + + def _answers_to_ranked_answers(self, answers: List[Dict[str, Any]]) -> List[RankedAnswer]: + """Convert answer dictionaries to RankedAnswer format""" + return [ + RankedAnswer( + AnswerEntityID=answer.get("entity_id"), + AnswerString=answer["answer_string"], + Score=None + ) + for answer in answers + ] + + def _model_answers_to_ranked_answers(self, model_answers: List[str]) -> List[RankedAnswer]: + """Convert model answers to RankedAnswer format (fallback)""" + return [ + RankedAnswer( + AnswerEntityID=None, + AnswerString=answer, + Score=None + ) + for answer in model_answers + ] diff --git a/experiments/subgraphs_reranking/ranking.ipynb b/experiments/subgraphs_reranking/ranking.ipynb index 2521a27..6b6e540 100644 --- a/experiments/subgraphs_reranking/ranking.ipynb +++ b/experiments/subgraphs_reranking/ranking.ipynb @@ -15,6 +15,16 @@ "execution_count": 2, "metadata": {}, "outputs": [], + "source": [ + "import os\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], "source": [ "import pandas as pd\n", "from datasets import load_dataset, Dataset\n", @@ -42,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -57,18 
+67,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using the latest cached version of the dataset since AmazonScience/mintaka couldn't be found on the Hugging Face Hub\n", - "Found the latest cached dataset configuration 'default' at /home/jovyan/.cache/huggingface/datasets/AmazonScience___mintaka/default/0.0.0/fe3f1235e31b01dc9cce913086f0cb6ed0d9b82e (last modified on Fri Oct 24 13:02:21 2025).\n" - ] - } - ], + "outputs": [], "source": [ "# os.environ['HF_DATASETS_CACHE'] = '/workspace/storage/misc/huggingface'\n", "\n", @@ -90,7 +91,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -140,8 +141,8 @@ " \n", " \n", " count\n", - " 4000.00000\n", - " 4000.00000\n", + " 4000.000000\n", + " 4000.000000\n", " 4000.000000\n", " 4000.000000\n", " 4000.000000\n", @@ -160,60 +161,60 @@ " 4000.000000\n", " 4000.000000\n", " 4000.000000\n", - " 4000.00000\n", + " 4000.000000\n", " \n", " \n", " mean\n", - " 9.06900\n", - " 9.06900\n", - " 8.283500\n", - " 8.283500\n", - " 8.283500\n", - " 8.283500\n", - " 8.283500\n", - " 8.283500\n", - " 8.283500\n", - " 8.283500\n", + " 4.534500\n", + " 4.534500\n", + " 4.141750\n", + " 4.141750\n", + " 4.141750\n", + " 4.141750\n", + " 4.141750\n", + " 4.141750\n", + " 4.141750\n", + " 4.141750\n", " ...\n", - " 8.283500\n", - " 8.283500\n", - " 8.283500\n", - " 8.283500\n", - " 8.283500\n", - " 8.283500\n", - " 8.283500\n", - " 8.283500\n", - " 8.283500\n", - " 9.06900\n", + " 4.534500\n", + " 4.534500\n", + " 4.534500\n", + " 4.141750\n", + " 4.141750\n", + " 4.141750\n", + " 4.141750\n", + " 4.141750\n", + " 4.141750\n", + " 4.534500\n", " \n", " \n", " std\n", - " 6.60001\n", - " 6.60001\n", - " 7.458052\n", - " 7.458052\n", - " 7.458052\n", - " 7.458052\n", - " 7.458052\n", - " 7.458052\n", - " 7.458052\n", - " 7.458052\n", + " 3.300005\n", + " 
3.300005\n", + " 3.729026\n", + " 3.729026\n", + " 3.729026\n", + " 3.729026\n", + " 3.729026\n", + " 3.729026\n", + " 3.729026\n", + " 3.729026\n", " ...\n", - " 7.458052\n", - " 7.458052\n", - " 7.458052\n", - " 7.458052\n", - " 7.458052\n", - " 7.458052\n", - " 7.458052\n", - " 7.458052\n", - " 7.458052\n", - " 6.60001\n", + " 3.300005\n", + " 3.300005\n", + " 3.300005\n", + " 3.729026\n", + " 3.729026\n", + " 3.729026\n", + " 3.729026\n", + " 3.729026\n", + " 3.729026\n", + " 3.300005\n", " \n", " \n", " min\n", - " 2.00000\n", - " 2.00000\n", + " 1.000000\n", + " 1.000000\n", " 0.000000\n", " 0.000000\n", " 0.000000\n", @@ -223,21 +224,21 @@ " 0.000000\n", " 0.000000\n", " ...\n", + " 1.000000\n", + " 1.000000\n", + " 1.000000\n", " 0.000000\n", " 0.000000\n", " 0.000000\n", " 0.000000\n", " 0.000000\n", " 0.000000\n", - " 0.000000\n", - " 0.000000\n", - " 0.000000\n", - " 2.00000\n", + " 1.000000\n", " \n", " \n", " 25%\n", - " 2.00000\n", - " 2.00000\n", + " 1.000000\n", + " 1.000000\n", " 0.000000\n", " 0.000000\n", " 0.000000\n", @@ -247,88 +248,88 @@ " 0.000000\n", " 0.000000\n", " ...\n", + " 1.000000\n", + " 1.000000\n", + " 1.000000\n", " 0.000000\n", " 0.000000\n", " 0.000000\n", " 0.000000\n", " 0.000000\n", " 0.000000\n", - " 0.000000\n", - " 0.000000\n", - " 0.000000\n", - " 2.00000\n", + " 1.000000\n", " \n", " \n", " 50%\n", - " 10.00000\n", - " 10.00000\n", - " 10.000000\n", - " 10.000000\n", - " 10.000000\n", - " 10.000000\n", - " 10.000000\n", - " 10.000000\n", - " 10.000000\n", - " 10.000000\n", + " 5.000000\n", + " 5.000000\n", + " 5.000000\n", + " 5.000000\n", + " 5.000000\n", + " 5.000000\n", + " 5.000000\n", + " 5.000000\n", + " 5.000000\n", + " 5.000000\n", " ...\n", - " 10.000000\n", - " 10.000000\n", - " 10.000000\n", - " 10.000000\n", - " 10.000000\n", - " 10.000000\n", - " 10.000000\n", - " 10.000000\n", - " 10.000000\n", - " 10.00000\n", + " 5.000000\n", + " 5.000000\n", + " 5.000000\n", + " 5.000000\n", + " 5.000000\n", + " 
5.000000\n", + " 5.000000\n", + " 5.000000\n", + " 5.000000\n", + " 5.000000\n", " \n", " \n", " 75%\n", - " 14.00000\n", - " 14.00000\n", - " 14.000000\n", - " 14.000000\n", - " 14.000000\n", - " 14.000000\n", - " 14.000000\n", - " 14.000000\n", - " 14.000000\n", - " 14.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", " ...\n", - " 14.000000\n", - " 14.000000\n", - " 14.000000\n", - " 14.000000\n", - " 14.000000\n", - " 14.000000\n", - " 14.000000\n", - " 14.000000\n", - " 14.000000\n", - " 14.00000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", + " 7.000000\n", " \n", " \n", " max\n", - " 38.00000\n", - " 38.00000\n", - " 38.000000\n", - " 38.000000\n", - " 38.000000\n", - " 38.000000\n", - " 38.000000\n", - " 38.000000\n", - " 38.000000\n", - " 38.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", " ...\n", - " 38.000000\n", - " 38.000000\n", - " 38.000000\n", - " 38.000000\n", - " 38.000000\n", - " 38.000000\n", - " 38.000000\n", - " 38.000000\n", - " 38.000000\n", - " 38.00000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", + " 19.000000\n", " \n", " \n", "\n", @@ -336,80 +337,80 @@ "" ], "text/plain": [ - " target target_out_of_vocab id_y answerEntity \\\n", - "count 4000.00000 4000.00000 4000.000000 4000.000000 \n", - "mean 9.06900 9.06900 8.283500 8.283500 \n", - "std 6.60001 6.60001 7.458052 7.458052 \n", - "min 2.00000 2.00000 0.000000 0.000000 \n", - "25% 2.00000 2.00000 0.000000 0.000000 \n", - "50% 10.00000 10.00000 10.000000 10.000000 \n", - "75% 14.00000 
14.00000 14.000000 14.000000 \n", - "max 38.00000 38.00000 38.000000 38.000000 \n", + " target target_out_of_vocab id_y answerEntity \\\n", + "count 4000.000000 4000.000000 4000.000000 4000.000000 \n", + "mean 4.534500 4.534500 4.141750 4.141750 \n", + "std 3.300005 3.300005 3.729026 3.729026 \n", + "min 1.000000 1.000000 0.000000 0.000000 \n", + "25% 1.000000 1.000000 0.000000 0.000000 \n", + "50% 5.000000 5.000000 5.000000 5.000000 \n", + "75% 7.000000 7.000000 7.000000 7.000000 \n", + "max 19.000000 19.000000 19.000000 19.000000 \n", "\n", " questionEntity groundTruthAnswerEntity complexityType graph \\\n", "count 4000.000000 4000.000000 4000.000000 4000.000000 \n", - "mean 8.283500 8.283500 8.283500 8.283500 \n", - "std 7.458052 7.458052 7.458052 7.458052 \n", + "mean 4.141750 4.141750 4.141750 4.141750 \n", + "std 3.729026 3.729026 3.729026 3.729026 \n", "min 0.000000 0.000000 0.000000 0.000000 \n", "25% 0.000000 0.000000 0.000000 0.000000 \n", - "50% 10.000000 10.000000 10.000000 10.000000 \n", - "75% 14.000000 14.000000 14.000000 14.000000 \n", - "max 38.000000 38.000000 38.000000 38.000000 \n", + "50% 5.000000 5.000000 5.000000 5.000000 \n", + "75% 7.000000 7.000000 7.000000 7.000000 \n", + "max 19.000000 19.000000 19.000000 19.000000 \n", "\n", " correct t5_sequence ... gap_sequence_embedding \\\n", "count 4000.000000 4000.000000 ... 4000.000000 \n", - "mean 8.283500 8.283500 ... 8.283500 \n", - "std 7.458052 7.458052 ... 7.458052 \n", - "min 0.000000 0.000000 ... 0.000000 \n", - "25% 0.000000 0.000000 ... 0.000000 \n", - "50% 10.000000 10.000000 ... 10.000000 \n", - "75% 14.000000 14.000000 ... 14.000000 \n", - "max 38.000000 38.000000 ... 38.000000 \n", + "mean 4.141750 4.141750 ... 4.534500 \n", + "std 3.729026 3.729026 ... 3.300005 \n", + "min 0.000000 0.000000 ... 1.000000 \n", + "25% 0.000000 0.000000 ... 1.000000 \n", + "50% 5.000000 5.000000 ... 5.000000 \n", + "75% 7.000000 7.000000 ... 7.000000 \n", + "max 19.000000 19.000000 ... 
19.000000 \n", "\n", " t5_sequence_embedding question_answer_embedding \\\n", "count 4000.000000 4000.000000 \n", - "mean 8.283500 8.283500 \n", - "std 7.458052 7.458052 \n", - "min 0.000000 0.000000 \n", - "25% 0.000000 0.000000 \n", - "50% 10.000000 10.000000 \n", - "75% 14.000000 14.000000 \n", - "max 38.000000 38.000000 \n", + "mean 4.534500 4.534500 \n", + "std 3.300005 3.300005 \n", + "min 1.000000 1.000000 \n", + "25% 1.000000 1.000000 \n", + "50% 5.000000 5.000000 \n", + "75% 7.000000 7.000000 \n", + "max 19.000000 19.000000 \n", "\n", " highlighted_determ_sequence no_highlighted_determ_sequence \\\n", "count 4000.000000 4000.000000 \n", - "mean 8.283500 8.283500 \n", - "std 7.458052 7.458052 \n", + "mean 4.141750 4.141750 \n", + "std 3.729026 3.729026 \n", "min 0.000000 0.000000 \n", "25% 0.000000 0.000000 \n", - "50% 10.000000 10.000000 \n", - "75% 14.000000 14.000000 \n", - "max 38.000000 38.000000 \n", + "50% 5.000000 5.000000 \n", + "75% 7.000000 7.000000 \n", + "max 19.000000 19.000000 \n", "\n", " highlighted_t5_sequence no_highlighted_t5_sequence \\\n", "count 4000.000000 4000.000000 \n", - "mean 8.283500 8.283500 \n", - "std 7.458052 7.458052 \n", + "mean 4.141750 4.141750 \n", + "std 3.729026 3.729026 \n", "min 0.000000 0.000000 \n", "25% 0.000000 0.000000 \n", - "50% 10.000000 10.000000 \n", - "75% 14.000000 14.000000 \n", - "max 38.000000 38.000000 \n", + "50% 5.000000 5.000000 \n", + "75% 7.000000 7.000000 \n", + "max 19.000000 19.000000 \n", "\n", " highlighted_gap_sequence no_highlighted_gap_sequence model_answers \n", - "count 4000.000000 4000.000000 4000.00000 \n", - "mean 8.283500 8.283500 9.06900 \n", - "std 7.458052 7.458052 6.60001 \n", - "min 0.000000 0.000000 2.00000 \n", - "25% 0.000000 0.000000 2.00000 \n", - "50% 10.000000 10.000000 10.00000 \n", - "75% 14.000000 14.000000 14.00000 \n", - "max 38.000000 38.000000 38.00000 \n", + "count 4000.000000 4000.000000 4000.000000 \n", + "mean 4.141750 4.141750 4.534500 \n", + "std 3.729026 
3.729026 3.300005 \n", + "min 0.000000 0.000000 1.000000 \n", + "25% 0.000000 0.000000 1.000000 \n", + "50% 5.000000 5.000000 5.000000 \n", + "75% 7.000000 7.000000 7.000000 \n", + "max 19.000000 19.000000 19.000000 \n", "\n", "[8 rows x 32 columns]" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -425,7 +426,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -449,12 +450,12 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "results_path = Path(\n", - " f\"/home/jovyan/kbqa/reranking_model_results/{ds_type}/\"\n", + " f\"./reranking_model_results/{ds_type}/\"\n", ")\n", "results_path.mkdir(parents=True, exist_ok=True)\n", "\n", @@ -469,7 +470,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -1030,7 +1031,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -1051,21 +1052,21 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Try load model...\n", + "Trying to load the model...\n", "Model Loaded.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "579f2caacc9d4f04894e484d9cd40161", + "model_id": "914bfd27cd96498bb85954ede47dda5f", "version_major": 2, "version_minor": 0 }, @@ -1080,14 +1081,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "Try load model...\n", + "Trying to load the model...\n", "Model Loaded.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d93143f2fad34243a88cbc9bd3cf2f86", + "model_id": "59a327fd193146eba06daf3ebb0b7dd1", "version_major": 2, "version_minor": 0 }, @@ -1097,103 +1098,7 @@ }, "metadata": {}, "output_type": 
"display_data" - } - ], - "source": [ - "device = torch.device(\"cuda\")\n", - "\n", - "model_path = f\"/home/jovyan/kbqa/subgraphs_reranking_runs/sequence/question_answer/KGQASubgraphsRanking/run_1/{ds_type}/outputs/checkpoint-best\"\n", - "mpnet_ranker = MPNetRanker(\"question_answer\", model_path, device)\n", - "with open(\n", - " results_path / f\"mpnet_text_only_determ_reranking_seq2seq_{ds_type}_results.jsonl\",\n", - " \"w+\",\n", - ") as f:\n", - " for result in mpnet_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", - "\n", - "model_path = f\"/home/jovyan/kbqa/subgraphs_reranking_runs/sequence/no_highlighted_determ_sequence/KGQASubgraphsRanking/run_1/{ds_type}/outputs/checkpoint-best\"\n", - "mpnet_ranker = MPNetRanker(\"no_highlighted_determ_sequence\", model_path, device)\n", - "with open(\n", - " results_path / f\"mpnet_no_hl_g2t_determ_reranking_seq2seq_{ds_type}_results.jsonl\",\n", - " \"w\",\n", - ") as f:\n", - " for result in mpnet_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", - "\n", - "model_path = f\"/home/jovyan/kbqa/subgraphs_reranking_runs/sequence/highlighted_determ_sequence/KGQASubgraphsRanking/run_1/{ds_type}/outputs/checkpoint-best\"\n", - "mpnet_ranker = MPNetRanker(\"highlighted_determ_sequence\", model_path, device)\n", - "with open(\n", - " results_path / f\"mpnet_hl_g2t_determ_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in mpnet_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", - "\n", - "\n", - "model_path = f\"/home/jovyan/kbqa/subgraphs_reranking_runs/sequence/highlighted_t5_sequence/KGQASubgraphsRanking/run_1/{ds_type}/outputs/checkpoint-best\"\n", - "mpnet_ranker = MPNetRanker(\"highlighted_t5_sequence\", model_path, device)\n", - "with open(\n", - " results_path / f\"mpnet_hl_g2t_t5_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in mpnet_ranker.rerank(test_df):\n", - " 
f.write(json.dumps(result) + \"\\n\")\n", - "\n", - "model_path = f\"/home/jovyan/kbqa/subgraphs_reranking_runs/sequence/no_highlighted_t5_sequence/KGQASubgraphsRanking/run_1/{ds_type}/outputs/checkpoint-best\"\n", - "mpnet_ranker = MPNetRanker(\"no_highlighted_t5_sequence\", model_path, device)\n", - "with open(\n", - " results_path / f\"mpnet_no_hl_g2t_t5_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in mpnet_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", - "\n", - "\n", - "model_path = f\"/home/jovyan/kbqa/subgraphs_reranking_runs/sequence/highlighted_gap_sequence/KGQASubgraphsRanking/run_1/{ds_type}/outputs/checkpoint-best\"\n", - "mpnet_ranker = MPNetRanker(\"highlighted_gap_sequence\", model_path, device)\n", - "with open(\n", - " results_path / f\"mpnet_hl_g2t_gap_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in mpnet_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")\n", - "\n", - "model_path = f\"/home/jovyan/kbqa/subgraphs_reranking_runs/sequence/no_highlighted_gap_sequence/KGQASubgraphsRanking/run_1/{ds_type}/outputs/checkpoint-best\"\n", - "mpnet_ranker = MPNetRanker(\"no_highlighted_gap_sequence\", model_path, device)\n", - "with open(\n", - " results_path / f\"mpnet_no_hl_g2t_gap_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in mpnet_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Catboost" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "catboost_dir = (\n", - " f\"/workspace/storage/misc/features_reranking/catboost/unified_reranking/{ds_type}\"\n", - ")\n", - "model_weights = f\"{catboost_dir}/text/best_model\"\n", - "catboost_ranker = CatboostRanker(model_weights, graph_features=features_map[\"text\"])\n", - "\n", - "with 
open(\n", - " results_path / f\"catboost_graph_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in catboost_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ + }, { "name": "stdout", "output_type": "stream", @@ -1205,7 +1110,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "387842ac132b49e7aac60bb326e69e57", + "model_id": "d51f6aa6615446b295f7b210edd2d5e0", "version_major": 2, "version_minor": 0 }, @@ -1215,30 +1120,7 @@ }, "metadata": {}, "output_type": "display_data" - } - ], - "source": [ - "catboost_dir = (\n", - " f\"/workspace/storage/misc/features_reranking/catboost/unified_reranking/{ds_type}\"\n", - ")\n", - "model_weights = f\"{catboost_dir}/graph/best_model\"\n", - "fitted_scaler_path = f\"{catboost_dir}/graph/fitted_scaler.bz2\"\n", - "catboost_ranker = CatboostRanker(\n", - " model_weights, graph_features=features_map[\"graph\"], scaler_path=fitted_scaler_path\n", - ")\n", - "\n", - "with open(\n", - " results_path / f\"catboost_graph_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in catboost_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [ + }, { "name": "stdout", "output_type": "stream", @@ -1250,7 +1132,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "177b5f610de44f18a349afb57dae8211", + "model_id": "7abf017c18b54638bdf6dacb7df19d44", "version_major": 2, "version_minor": 0 }, @@ -1260,30 +1142,7 @@ }, "metadata": {}, "output_type": "display_data" - } - ], - "source": [ - "model_weights = f\"{catboost_dir}/text_graph/best_model\"\n", - "fitted_scaler_path = f\"{catboost_dir}/text_graph/fitted_scaler.bz2\"\n", - "catboost_ranker = CatboostRanker(\n", - " model_weights,\n", - " 
sequence_features=features_map[\"text\"],\n", - " graph_features=features_map[\"graph\"],\n", - " scaler_path=fitted_scaler_path,\n", - ")\n", - "\n", - "with open(\n", - " results_path / f\"catboost_text_graph_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in catboost_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ + }, { "name": "stdout", "output_type": "stream", @@ -1295,7 +1154,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "54f8eb0a67344258ba0cb2db752d88c3", + "model_id": "ceefb3c7217a4686aaf4d3f491ff4cc7", "version_major": 2, "version_minor": 0 }, @@ -1305,26 +1164,7 @@ }, "metadata": {}, "output_type": "display_data" - } - ], - "source": [ - "model_weights = f\"{catboost_dir}/g2t_determ/best_model\"\n", - "catboost_ranker = CatboostRanker(\n", - " model_weights, sequence_features=features_map[\"g2t_determ\"]\n", - ")\n", - "\n", - "with open(\n", - " results_path / f\"catboost_g2t_determ_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in catboost_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ + }, { "name": "stdout", "output_type": "stream", @@ -1336,7 +1176,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2b6b1baa38794664ab53885e9b68283a", + "model_id": "9ee75de72e194953a6158978c4624b06", "version_major": 2, "version_minor": 0 }, @@ -1346,26 +1186,7 @@ }, "metadata": {}, "output_type": "display_data" - } - ], - "source": [ - "model_weights = f\"{catboost_dir}/g2t_t5/best_model\"\n", - "catboost_ranker = CatboostRanker(\n", - " model_weights, sequence_features=features_map[\"g2t_t5\"]\n", - ")\n", - "\n", - "with open(\n", - " results_path / 
f\"catboost_g2t_t5_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in catboost_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ + }, { "name": "stdout", "output_type": "stream", @@ -1377,7 +1198,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d83ca03af9ce4861baa615cc04c64ff0", + "model_id": "8d6d9a720cb5445291a53af5641702ef", "version_major": 2, "version_minor": 0 }, @@ -1387,26 +1208,7 @@ }, "metadata": {}, "output_type": "display_data" - } - ], - "source": [ - "model_weights = f\"{catboost_dir}/g2t_gap/best_model\"\n", - "catboost_ranker = CatboostRanker(\n", - " model_weights, sequence_features=features_map[\"g2t_gap\"]\n", - ")\n", - "\n", - "with open(\n", - " results_path / f\"catboost_g2t_gap_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", - ") as f:\n", - " for result in catboost_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ + }, { "name": "stdout", "output_type": "stream", @@ -1418,7 +1220,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "141c985e38f940269b5657720a0eaa70", + "model_id": "1afcfc0f4e484713b8a5056ca9a13bcc", "version_major": 2, "version_minor": 0 }, @@ -1428,31 +1230,7 @@ }, "metadata": {}, "output_type": "display_data" - } - ], - "source": [ - "model_weights = f\"{catboost_dir}/text_graph_g2t_determ/best_model\"\n", - "fitted_scaler_path = f\"{catboost_dir}/text_graph_g2t_determ/fitted_scaler.bz2\"\n", - "catboost_ranker = CatboostRanker(\n", - " model_weights,\n", - " sequence_features=features_map[\"text\"] + features_map[\"g2t_determ\"],\n", - " graph_features=features_map[\"graph\"],\n", - " scaler_path=fitted_scaler_path,\n", - ")\n", - "with open(\n", - " results_path\n", - " / 
f\"catboost_text_graph_g2t_determ_reranking_seq2seq_{ds_type}_results.jsonl\",\n", - " \"w\",\n", - ") as f:\n", - " for result in catboost_ranker.rerank(test_df):\n", - " f.write(json.dumps(result) + \"\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ + }, { "name": "stdout", "output_type": "stream", @@ -1464,7 +1242,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "db3a397f28564a6fa6ae54a6bc3c92b0", + "model_id": "7ab18abf7f7a403fb09ced238ab10d17", "version_major": 2, "version_minor": 0 }, @@ -1474,21 +1252,872 @@ }, "metadata": {}, "output_type": "display_data" - } - ], - "source": [ - "model_weights = f\"{catboost_dir}/text_graph_g2t_t5/best_model\"\n", - "fitted_scaler_path = f\"{catboost_dir}/text_graph_g2t_t5/fitted_scaler.bz2\"\n", - "catboost_ranker = CatboostRanker(\n", - " model_weights,\n", - " sequence_features=features_map[\"text\"] + features_map[\"g2t_t5\"],\n", - " graph_features=features_map[\"graph\"],\n", - " scaler_path=fitted_scaler_path,\n", - ")\n", - "\n", - "with open(\n", - " results_path\n", - " / f\"catboost_text_graph_g2t_t5_reranking_seq2seq_{ds_type}_results.jsonl\",\n", + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trying to load the model...\n", + "Model Loaded.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0faa3a30cb2e4bf3bc2385448dad6e23", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/4000 [00:00 \u001b[39m\u001b[32m15\u001b[39m \u001b[43mcatboost_ranker\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_df\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_df\u001b[49m\u001b[43m=\u001b[49m\u001b[43mvalid_df\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel_save_path\u001b[49m\u001b[43m=\u001b[49m\u001b[43mmodel_weights\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[43muse_gpu\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevices\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43m0\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 16\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(\n\u001b[32m 17\u001b[39m results_path / \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mcatboost_text_reranking_seq2seq_run_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrun\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m_\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mds_type\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m_results.jsonl\u001b[39m\u001b[33m\"\u001b[39m, \u001b[33m\"\u001b[39m\u001b[33mw\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 18\u001b[39m ) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[32m 19\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m result \u001b[38;5;129;01min\u001b[39;00m catboost_ranker.rerank(test_df):\n", + "\u001b[36mFile \u001b[39m\u001b[32m/workspace/kbqa/experiments/subgraphs_reranking/ranking_model.py:496\u001b[39m, in \u001b[36mCatboostRanker.fit\u001b[39m\u001b[34m(self, train_df, val_df, model_save_path, scaler_save_path, early_stopping_rounds, use_gpu, devices, **kwargs)\u001b[39m\n\u001b[32m 489\u001b[39m grid_search_params = {\n\u001b[32m 490\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mlearning_rate\u001b[39m\u001b[33m\"\u001b[39m: \u001b[38;5;28mlist\u001b[39m(np.linspace(\u001b[32m0.03\u001b[39m, \u001b[32m0.3\u001b[39m, \u001b[32m5\u001b[39m)),\n\u001b[32m 491\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mdepth\u001b[39m\u001b[33m\"\u001b[39m: [\u001b[32m4\u001b[39m, \u001b[32m6\u001b[39m, \u001b[32m8\u001b[39m, \u001b[32m10\u001b[39m],\n\u001b[32m 492\u001b[39m \u001b[33m\"\u001b[39m\u001b[33miterations\u001b[39m\u001b[33m\"\u001b[39m: [\u001b[32m2000\u001b[39m, \u001b[32m3000\u001b[39m, \u001b[32m4000\u001b[39m],\n\u001b[32m 493\u001b[39m }\n\u001b[32m 
495\u001b[39m model = CatBoostRegressor(**base_params)\n\u001b[32m--> \u001b[39m\u001b[32m496\u001b[39m grid_search_result = \u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mgrid_search\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgrid_search_params\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlearn_pool\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 498\u001b[39m final_params = {\n\u001b[32m 499\u001b[39m \u001b[33m\"\u001b[39m\u001b[33miterations\u001b[39m\u001b[33m\"\u001b[39m: grid_search_result[\u001b[33m\"\u001b[39m\u001b[33mparams\u001b[39m\u001b[33m\"\u001b[39m][\u001b[33m\"\u001b[39m\u001b[33miterations\u001b[39m\u001b[33m\"\u001b[39m],\n\u001b[32m 500\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mlearning_rate\u001b[39m\u001b[33m\"\u001b[39m: grid_search_result[\u001b[33m\"\u001b[39m\u001b[33mparams\u001b[39m\u001b[33m\"\u001b[39m][\u001b[33m\"\u001b[39m\u001b[33mlearning_rate\u001b[39m\u001b[33m\"\u001b[39m],\n\u001b[32m 501\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mdepth\u001b[39m\u001b[33m\"\u001b[39m: grid_search_result[\u001b[33m\"\u001b[39m\u001b[33mparams\u001b[39m\u001b[33m\"\u001b[39m][\u001b[33m\"\u001b[39m\u001b[33mdepth\u001b[39m\u001b[33m\"\u001b[39m],\n\u001b[32m 502\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mearly_stopping_rounds\u001b[39m\u001b[33m\"\u001b[39m: early_stopping_rounds,\n\u001b[32m 503\u001b[39m }\n\u001b[32m 504\u001b[39m final_params.update(base_params)\n", + "\u001b[36mFile \u001b[39m\u001b[32m/opt/conda/lib/python3.11/site-packages/catboost/core.py:4211\u001b[39m, in \u001b[36mCatBoost.grid_search\u001b[39m\u001b[34m(self, param_grid, X, y, cv, partition_random_seed, calc_cv_statistics, search_by_train_test_split, refit, shuffle, stratified, train_size, verbose, plot, plot_file, log_cout, log_cerr)\u001b[39m\n\u001b[32m 4208\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(grid[key], Iterable):\n\u001b[32m 4209\u001b[39m 
\u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[33m'\u001b[39m\u001b[33mParameter grid value is not iterable (key=\u001b[39m\u001b[38;5;132;01m{!r}\u001b[39;00m\u001b[33m, value=\u001b[39m\u001b[38;5;132;01m{!r}\u001b[39;00m\u001b[33m)\u001b[39m\u001b[33m'\u001b[39m.format(key, grid[key]))\n\u001b[32m-> \u001b[39m\u001b[32m4211\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_tune_hyperparams\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 4212\u001b[39m \u001b[43m \u001b[49m\u001b[43mparam_grid\u001b[49m\u001b[43m=\u001b[49m\u001b[43mparam_grid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mX\u001b[49m\u001b[43m=\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m=\u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcv\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_iter\u001b[49m\u001b[43m=\u001b[49m\u001b[43m-\u001b[49m\u001b[32;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 4213\u001b[39m \u001b[43m \u001b[49m\u001b[43mpartition_random_seed\u001b[49m\u001b[43m=\u001b[49m\u001b[43mpartition_random_seed\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcalc_cv_statistics\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcalc_cv_statistics\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 4214\u001b[39m \u001b[43m \u001b[49m\u001b[43msearch_by_train_test_split\u001b[49m\u001b[43m=\u001b[49m\u001b[43msearch_by_train_test_split\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrefit\u001b[49m\u001b[43m=\u001b[49m\u001b[43mrefit\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshuffle\u001b[49m\u001b[43m=\u001b[49m\u001b[43mshuffle\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 4215\u001b[39m \u001b[43m 
\u001b[49m\u001b[43mstratified\u001b[49m\u001b[43m=\u001b[49m\u001b[43mstratified\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_size\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtrain_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[43m=\u001b[49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mplot\u001b[49m\u001b[43m=\u001b[49m\u001b[43mplot\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mplot_file\u001b[49m\u001b[43m=\u001b[49m\u001b[43mplot_file\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 4216\u001b[39m \u001b[43m \u001b[49m\u001b[43mlog_cout\u001b[49m\u001b[43m=\u001b[49m\u001b[43mlog_cout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlog_cerr\u001b[49m\u001b[43m=\u001b[49m\u001b[43mlog_cerr\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 4217\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m/opt/conda/lib/python3.11/site-packages/catboost/core.py:4100\u001b[39m, in \u001b[36mCatBoost._tune_hyperparams\u001b[39m\u001b[34m(self, param_grid, X, y, cv, n_iter, partition_random_seed, calc_cv_statistics, search_by_train_test_split, refit, shuffle, stratified, train_size, verbose, plot, plot_file, log_cout, log_cerr)\u001b[39m\n\u001b[32m 4097\u001b[39m stratified = \u001b[38;5;28misinstance\u001b[39m(loss_function, STRING_TYPES) \u001b[38;5;129;01mand\u001b[39;00m is_cv_stratified_objective(loss_function)\n\u001b[32m 4099\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m plot_wrapper(plot, plot_file, \u001b[33m'\u001b[39m\u001b[33mHyperparameters search plot\u001b[39m\u001b[33m'\u001b[39m, [_get_train_dir(params)]):\n\u001b[32m-> \u001b[39m\u001b[32m4100\u001b[39m cv_result = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_object\u001b[49m\u001b[43m.\u001b[49m\u001b[43m_tune_hyperparams\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 4101\u001b[39m \u001b[43m 
\u001b[49m\u001b[43mparam_grid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_params\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mtrain_pool\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_iter\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 4102\u001b[39m \u001b[43m \u001b[49m\u001b[43mfold_count\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpartition_random_seed\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshuffle\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstratified\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 4103\u001b[39m \u001b[43m \u001b[49m\u001b[43msearch_by_train_test_split\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcalc_cv_statistics\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcustom_folds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\n\u001b[32m 4104\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 4106\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m refit:\n\u001b[32m 4107\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m.is_fitted()\n", + "\u001b[36mFile \u001b[39m\u001b[32m_catboost.pyx:5524\u001b[39m, in \u001b[36m_catboost._CatBoost._tune_hyperparams\u001b[39m\u001b[34m()\u001b[39m\n", + "\u001b[36mFile \u001b[39m\u001b[32m_catboost.pyx:5562\u001b[39m, in \u001b[36m_catboost._CatBoost._tune_hyperparams\u001b[39m\u001b[34m()\u001b[39m\n", + "\u001b[31mCatBoostError\u001b[39m: catboost/cuda/methods/oblivious_tree_structure_searcher.cpp:261: Error: something went wrong, best split is NaN with scoreinf" + ] + } + ], + "source": [ + "catboost_dir = (\n", + " 
f\"/mnt/storage/QA_System_Project/kbqa_reranking_experiments_runs/catboost/{ds_type}\"\n", + ")\n", + "\n", + "for run in [1, 2, 3]:\n", + " random.seed(42 + run)\n", + " np.random.seed(42 + run)\n", + " torch.manual_seed(42 + run)\n", + " torch.cuda.manual_seed_all(42 + run)\n", + " set_seed(42 + run)\n", + " \n", + " Path(f\"{catboost_dir}/text/run_{run}\").mkdir(parents=True, exist_ok=True)\n", + " model_weights = f\"{catboost_dir}/text/run_{run}/best_model\"\n", + " catboost_ranker = CatboostRanker(sequence_features=features_map[\"text\"])\n", + " catboost_ranker.fit(train_df, val_df=valid_df, model_save_path=model_weights)\n", + " with open(\n", + " results_path / f\"catboost_text_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in catboost_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} catboost_text_reranking_seq2seq completed\")\n", + "\n", + " Path(f\"{catboost_dir}/graph/run_{run}\").mkdir(parents=True, exist_ok=True)\n", + " model_weights = f\"{catboost_dir}/graph/run_{run}/best_model\"\n", + " fitted_scaler_path = f\"{catboost_dir}/graph/run_{run}/fitted_scaler.bz2\"\n", + " catboost_ranker = CatboostRanker(graph_features=features_map[\"graph\"])\n", + " catboost_ranker.fit(\n", + " train_df,\n", + " val_df=valid_df,\n", + " model_save_path=model_weights,\n", + " scaler_save_path=fitted_scaler_path,\n", + " )\n", + " with open(\n", + " results_path / f\"catboost_graph_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in catboost_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} catboost_graph_reranking_seq2seq completed\")\n", + "\n", + " Path(f\"{catboost_dir}/text_graph/run_{run}\").mkdir(parents=True, exist_ok=True)\n", + " model_weights = f\"{catboost_dir}/text_graph/run_{run}/best_model\"\n", + " fitted_scaler_path = 
f\"{catboost_dir}/text_graph/run_{run}/fitted_scaler.bz2\"\n", + " catboost_ranker = CatboostRanker(\n", + " sequence_features=features_map[\"text\"],\n", + " graph_features=features_map[\"graph\"],\n", + " )\n", + " catboost_ranker.fit(\n", + " train_df,\n", + " val_df=valid_df,\n", + " model_save_path=model_weights,\n", + " scaler_save_path=fitted_scaler_path,\n", + " )\n", + " with open(\n", + " results_path / f\"catboost_text_graph_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in catboost_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} catboost_text_graph_reranking_seq2seq completed\")\n", + "\n", + " Path(f\"{catboost_dir}/g2t_determ/run_{run}\").mkdir(parents=True, exist_ok=True)\n", + " model_weights = f\"{catboost_dir}/g2t_determ/run_{run}/best_model\"\n", + " catboost_ranker = CatboostRanker(sequence_features=features_map[\"g2t_determ\"])\n", + " catboost_ranker.fit(train_df, val_df=valid_df, model_save_path=model_weights)\n", + " with open(\n", + " results_path / f\"catboost_g2t_determ_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in catboost_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} catboost_g2t_determ_reranking_seq2seq completed\")\n", + "\n", + " Path(f\"{catboost_dir}/g2t_t5/run_{run}\").mkdir(parents=True, exist_ok=True)\n", + " model_weights = f\"{catboost_dir}/g2t_t5/run_{run}/best_model\"\n", + " catboost_ranker = CatboostRanker(sequence_features=features_map[\"g2t_t5\"])\n", + " catboost_ranker.fit(train_df, val_df=valid_df, model_save_path=model_weights)\n", + " with open(\n", + " results_path / f\"catboost_g2t_t5_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in catboost_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} 
catboost_g2t_t5_reranking_seq2seq completed\")\n", + "\n", + " Path(f\"{catboost_dir}/g2t_gap/run_{run}\").mkdir(parents=True, exist_ok=True)\n", + " model_weights = f\"{catboost_dir}/g2t_gap/run_{run}/best_model\"\n", + " catboost_ranker = CatboostRanker(sequence_features=features_map[\"g2t_gap\"])\n", + " catboost_ranker.fit(train_df, val_df=valid_df, model_save_path=model_weights)\n", + " with open(\n", + " results_path / f\"catboost_g2t_gap_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\", \"w\"\n", + " ) as f:\n", + " for result in catboost_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} catboost_g2t_gap_reranking_seq2seq completed\")\n", + "\n", + " Path(f\"{catboost_dir}/text_graph_g2t_determ/run_{run}\").mkdir(parents=True, exist_ok=True)\n", + " model_weights = f\"{catboost_dir}/text_graph_g2t_determ/run_{run}/best_model\"\n", + " fitted_scaler_path = f\"{catboost_dir}/text_graph_g2t_determ/run_{run}/fitted_scaler.bz2\"\n", + " catboost_ranker = CatboostRanker(\n", + " sequence_features=features_map[\"text\"] + features_map[\"g2t_determ\"],\n", + " graph_features=features_map[\"graph\"],\n", + " )\n", + " catboost_ranker.fit(\n", + " train_df,\n", + " val_df=valid_df,\n", + " model_save_path=model_weights,\n", + " scaler_save_path=fitted_scaler_path,\n", + " )\n", + " with open(\n", + " results_path\n", + " / f\"catboost_text_graph_g2t_determ_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\",\n", + " \"w\",\n", + " ) as f:\n", + " for result in catboost_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} catboost_text_graph_g2t_determ_reranking_seq2seq completed\")\n", + "\n", + " Path(f\"{catboost_dir}/text_graph_g2t_t5/run_{run}\").mkdir(parents=True, exist_ok=True)\n", + " model_weights = f\"{catboost_dir}/text_graph_g2t_t5/run_{run}/best_model\"\n", + " fitted_scaler_path = f\"{catboost_dir}/text_graph_g2t_t5/run_{run}/fitted_scaler.bz2\"\n", + " 
catboost_ranker = CatboostRanker(\n", + " sequence_features=features_map[\"text\"] + features_map[\"g2t_t5\"],\n", + " graph_features=features_map[\"graph\"],\n", + " )\n", + " catboost_ranker.fit(\n", + " train_df,\n", + " val_df=valid_df,\n", + " model_save_path=model_weights,\n", + " scaler_save_path=fitted_scaler_path,\n", + " )\n", + " with open(\n", + " results_path\n", + " / f\"catboost_text_graph_g2t_t5_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\",\n", + " \"w\",\n", + " ) as f:\n", + " for result in catboost_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} catboost_text_graph_g2t_t5_reranking_seq2seq completed\")\n", + "\n", + " Path(f\"{catboost_dir}/text_graph_g2t_gap/run_{run}\").mkdir(parents=True, exist_ok=True)\n", + " model_weights = f\"{catboost_dir}/text_graph_g2t_gap/run_{run}/best_model\"\n", + " fitted_scaler_path = f\"{catboost_dir}/text_graph_g2t_gap/run_{run}/fitted_scaler.bz2\"\n", + " catboost_ranker = CatboostRanker(\n", + " sequence_features=features_map[\"text\"] + features_map[\"g2t_gap\"],\n", + " graph_features=features_map[\"graph\"],\n", + " )\n", + " catboost_ranker.fit(\n", + " train_df,\n", + " val_df=valid_df,\n", + " model_save_path=model_weights,\n", + " scaler_save_path=fitted_scaler_path,\n", + " )\n", + " with open(\n", + " results_path\n", + " / f\"catboost_text_graph_g2t_gap_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl\",\n", + " \"w\",\n", + " ) as f:\n", + " for result in catboost_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")\n", + " print(f\"RUN {run} catboost_text_graph_g2t_gap_reranking_seq2seq completed\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "catboost_dir = (\n", + " f\"/workspace/storage/misc/features_reranking/catboost/unified_reranking/{ds_type}\"\n", + ")\n", + "Path(catboost_dir).mkdir(parents=True, exist_ok=True)\n", + "model_weights = 
f\"{catboost_dir}/text/best_model\"\n", + "catboost_ranker = CatboostRanker(sequence_features=features_map[\"text\"])\n", + "catboost_ranker.fit(train_df, val_df=valid_df, model_save_path=model_weights)\n", + "\n", + "with open(\n", + " results_path / f\"catboost_text_reranking_seq2seq_{ds_type}_results.jsonl\", \"w\"\n", + ") as f:\n", + " for result in catboost_ranker.rerank(test_df):\n", + " f.write(json.dumps(result) + \"\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Trying to load the model...\n", + "Model Loaded.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "387842ac132b49e7aac60bb326e69e57", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/4000 [00:00 pd.DataFrame: """Convert embedding string columns to numpy arrays, handling NaN/Inf values""" + if dataframe.empty: + return dataframe dataframe = dataframe.copy() for col in embedding_columns: if col in dataframe.columns: diff --git a/experiments/subgraphs_reranking/ranking_model.py b/experiments/subgraphs_reranking/ranking_model.py index 99f906d..107469b 100644 --- a/experiments/subgraphs_reranking/ranking_model.py +++ b/experiments/subgraphs_reranking/ranking_model.py @@ -15,7 +15,7 @@ ) from catboost import CatBoostRegressor, Pool from sklearn import preprocessing, utils -from ranking_data_utils import df_to_features_array, convert_embedding_columns_to_arrays +from ranking_data_utils import df_to_features_array class RankedAnswer(TypedDict): @@ -69,7 +69,7 @@ def rerank(self, test_df: DataFrame) -> List[RankedAnswersDict]: ] results.append( RankedAnswersDict( - QuestionID=row["id"], + QuestionID=str(row["id"]), RankedAnswers=ranked_answers, ) ) @@ -100,7 +100,7 @@ def rerank(self, test_df: DataFrame) -> List[RankedAnswersDict]: ] results.append( RankedAnswersDict( - QuestionID=row["id"], + 
QuestionID=str(row["id"]), RankedAnswers=ranked_answers, ) ) @@ -127,9 +127,7 @@ def _sort_answers_group_by_scores( for score, answer_entity_id in zip(sorted_scores, sorted_ranked_answers): ranked_answers.append( RankedAnswer( - AnswerEntityID=str(answer_entity_id) if answer_entity_id is not None else None, - AnswerString=None, - Score=float(score) + AnswerEntityID=str(answer_entity_id), AnswerString=None, Score=float(score) ) ) return ranked_answers @@ -214,7 +212,7 @@ def rerank(self, test_df: DataFrame) -> List[RankedAnswersDict]: results.append( RankedAnswersDict( - QuestionID=question_id, + QuestionID=str(question_id), RankedAnswers=ranked_answers, ) ) @@ -284,7 +282,7 @@ def rerank(self, test_df: DataFrame) -> List[RankedAnswersDict]: results.append( RankedAnswersDict( - QuestionID=question_id, + QuestionID=str(question_id), RankedAnswers=ranked_answers, ) ) @@ -353,7 +351,7 @@ def rerank(self, test_df: DataFrame) -> List[RankedAnswersDict]: results.append( RankedAnswersDict( - QuestionID=question_id, + QuestionID=str(question_id), RankedAnswers=ranked_answers, ) ) @@ -405,7 +403,7 @@ def fit( ) -> None: """fit CatBoost model on train_df""" train_df = train_df.dropna(subset=["graph"]).copy() - train_df = train_df.sample(frac=0.999).reset_index(drop=True) + train_df = train_df.sample(frac=0.9999).reset_index(drop=True) if val_df is not None: val_df = val_df.dropna(subset=["graph"]).copy() if len(val_df) == 0: @@ -428,10 +426,6 @@ def fit( val_df[self.graph_features] ) - train_df = convert_embedding_columns_to_arrays(train_df, embedding_features) - if val_df is not None and len(val_df) > 0: - val_df = convert_embedding_columns_to_arrays(val_df, embedding_features) - X_train = train_df[self.features_to_use] y_train = train_df["correct"].astype(float).tolist() @@ -443,8 +437,6 @@ def fit( train_class_weights[train_class_weights == 0] = train_weights[0] train_class_weights[train_class_weights == 1] = train_weights[1] - - learn_pool = Pool( X_train, y_train, @@ 
-492,11 +484,6 @@ def rerank(self, test_df: DataFrame) -> List[RankedAnswersDict]: test_df[self.graph_features] = self.fitted_scaler.transform( test_df[self.graph_features] ) - - embedding_features = [] - if self.sequence_features: - embedding_features = self.sequence_features.copy() - test_df = convert_embedding_columns_to_arrays(test_df, embedding_features) results = [] groups = test_df.groupby("id") @@ -512,7 +499,7 @@ def rerank(self, test_df: DataFrame) -> List[RankedAnswersDict]: results.append( RankedAnswersDict( - QuestionID=question_id, + QuestionID=str(question_id), RankedAnswers=ranked_answers, ) ) diff --git a/experiments/subgraphs_reranking/ranking_mpnet.py b/experiments/subgraphs_reranking/ranking_mpnet.py new file mode 100644 index 0000000..93add1b --- /dev/null +++ b/experiments/subgraphs_reranking/ranking_mpnet.py @@ -0,0 +1,145 @@ +import random +import json + +import numpy as np +from transformers import set_seed +import torch + +from pathlib import Path +from datasets import load_dataset + +from ranking_model import MPNetRanker +from ranking_data_utils import prepare_data + +random.seed(42) +np.random.seed(42) +torch.manual_seed(42) +torch.cuda.manual_seed_all(42) +set_seed(42) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +# os.environ['HF_DATASETS_CACHE'] = '/workspace/storage/misc/huggingface' + +import argparse + +hf_cache_dir = "/workspace/storage/misc/huggingface" + +parser = argparse.ArgumentParser(description="Subgraphs Reranking") +parser.add_argument( + "--ds_type", + type=str, + default="t5xlssm", + choices=["t5largessm", "t5xlssm", "mistral", "mixtral"], + help="Type of dataset/features to use." +) +parser.add_argument( + "--dataset", + type=str, + default="mintaka", + choices=["mintaka", "mkqa-hf"], + help="Dataset to use: mintaka or mkqa-hf." +) +parser.add_argument( + "--force", + action="store_true", + default=False, + help="Force reranking even if result files exist." 
+) +args = parser.parse_args() + +ds_type = args.ds_type # 't5largessm', 't5xlssm', 'mistral' or 'mixtral' +dataset_name = args.dataset # 'mintaka' or 'mkqa-hf' +force = args.force + +if dataset_name == "mintaka": + base_dataset_path = "AmazonScience/mintaka" + kgqa_ds_path = "s-nlp/KGQASubgraphsRanking" + features_data_dir = f"{ds_type}_subgraphs" + outputs_data_dir = f"{ds_type}_outputs" +elif dataset_name == "mkqa-hf": + base_dataset_path = "Dms12/mkqa_mintaka_format_with_question_entities" + kgqa_ds_path = "s-nlp/MKQASubgraphsRanking" + features_data_dir = f"mkqa_{ds_type}_subgraphs" + outputs_data_dir = f"mkqa_{ds_type}_outputs" + +features_ds = load_dataset( + kgqa_ds_path, data_dir=features_data_dir, cache_dir=hf_cache_dir +) +outputs_ds = load_dataset( + kgqa_ds_path, data_dir=outputs_data_dir, cache_dir=hf_cache_dir +) +base_ds = load_dataset(base_dataset_path) + +train_df = prepare_data(base_ds["train"], outputs_ds["train"], features_ds["train"]) +# valid_df = prepare_data( +# base_ds["validation"], outputs_ds["validation"], features_ds["validation"] +# ) +valid_df = None +test_df = prepare_data(base_ds["test"], outputs_ds["test"], features_ds["test"]) +test_df.groupby(["id", "question"]).count().describe() + +results_path = Path( + f"./reranking_model_results/{dataset_name}/{ds_type}/" +) +results_path.mkdir(parents=True, exist_ok=True) + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +base_path = "/mnt/storage/QA_System_Project/kbqa_reranking_experiments_runs/sequence/" + +seq_type_configs = { + "question_answer": "text_only_determ", + "no_highlighted_determ_sequence": "no_hl_g2t_determ", + "highlighted_determ_sequence": "hl_g2t_determ", + "no_highlighted_t5_sequence": "no_hl_g2t_t5", + "highlighted_t5_sequence": "hl_g2t_t5", + "no_highlighted_gap_sequence": "no_hl_g2t_gap", + "highlighted_gap_sequence": "hl_g2t_gap", +} + +if dataset_name == "mintaka": + sequence_types = [ + "question_answer", + "no_highlighted_determ_sequence", 
+ "highlighted_determ_sequence", + "no_highlighted_t5_sequence", + "highlighted_t5_sequence", + "no_highlighted_gap_sequence", + "highlighted_gap_sequence", + ] +elif dataset_name == "mkqa-hf": + sequence_types = [ + "question_answer", + "no_highlighted_determ_sequence", + ] + +for run in [1, 2, 3]: + random.seed(42 + run) + np.random.seed(42 + run) + torch.manual_seed(42 + run) + torch.cuda.manual_seed_all(42 + run) + set_seed(42 + run) + + for seq_type in sequence_types: + result_suffix = seq_type_configs[seq_type] + result_file = results_path / f"mpnet_{result_suffix}_reranking_seq2seq_run_{run}_{ds_type}_results.jsonl" + + if not force and result_file.exists(): + print(f"RUN {run} mpnet_{result_suffix}_reranking_seq2seq skipped (file exists)") + continue + + model_path = Path(base_path) / dataset_name / seq_type / ds_type / f"run_{run}" / "outputs" / "checkpoint-best" + + if not model_path.exists(): + print(f"RUN {run} mpnet_{result_suffix}_reranking_seq2seq skipped (model not found: {model_path})") + continue + + try: + mpnet_ranker = MPNetRanker(seq_type, str(model_path), device) + with open(result_file, "w") as f: + for result in mpnet_ranker.rerank(test_df): + f.write(json.dumps(result) + "\n") + print(f"RUN {run} mpnet_{result_suffix}_reranking_seq2seq completed") + except Exception as e: + print(f"RUN {run} mpnet_{result_suffix}_reranking_seq2seq failed: {e}") + continue diff --git a/experiments/subgraphs_reranking/sequence/train_sequence_ranker.py b/experiments/subgraphs_reranking/sequence/train_sequence_ranker.py index 1763c1d..e1056a4 100644 --- a/experiments/subgraphs_reranking/sequence/train_sequence_ranker.py +++ b/experiments/subgraphs_reranking/sequence/train_sequence_ranker.py @@ -19,10 +19,6 @@ from datasets import load_dataset -torch.manual_seed(8) -random.seed(8) -np.random.seed(8) - METRIC_CLASSIFIER = evaluate.combine( [ @@ -48,6 +44,13 @@ default="s-nlp/KGQASubgraphsRanking", help="Path to train sequence data file (HF)", ) 
+parse.add_argument( + "--dataset", + type=str, + default="mintaka", + choices=["mintaka", "mkqa-hf"], + help="Dataset to use: mintaka or mkqa-hf.", +) parse.add_argument( "--ds_type", type=str, @@ -59,7 +62,7 @@ parse.add_argument( "--output_path", type=str, - default="./subgraphs_reranking_runs/sequence/", + default="/mnt/storage/QA_System_Project/kbqa_reranking_experiments_runs/sequence/", ) parse.add_argument( @@ -197,9 +200,23 @@ def __len__(self): if args.wandb_on: os.environ["WANDB_NAME"] = args.run_name - subgraphs_dataset = load_dataset(args.data_path, data_dir=f"{args.ds_type}_subgraphs") + hf_cache_dir = "/workspace/storage/misc/huggingface" + + if args.dataset == "mintaka": + kgqa_ds_path = "s-nlp/KGQASubgraphsRanking" + features_data_dir = f"{args.ds_type}_subgraphs" + elif args.dataset == "mkqa-hf": + kgqa_ds_path = "s-nlp/MKQASubgraphsRanking" + features_data_dir = f"mkqa_{args.ds_type}_subgraphs" + + subgraphs_dataset = load_dataset(kgqa_ds_path, data_dir=features_data_dir, cache_dir=hf_cache_dir) train_df = subgraphs_dataset["train"].to_pandas() - val_df = subgraphs_dataset["validation"].to_pandas() + + if args.dataset == "mkqa-hf": + val_df = train_df.head(100) + train_df = train_df.iloc[100:].reset_index(drop=True) + else: + val_df = subgraphs_dataset["validation"].to_pandas() tokenizer = AutoTokenizer.from_pretrained(args.model_name) model = AutoModelForSequenceClassification.from_pretrained( @@ -220,22 +237,26 @@ def __len__(self): else: SEQ_TYPE = f"{HL_TYPE}_{args.sequence_type}_sequence" - model_folder = args.data_path.split("_")[-1] # either large or xl - output_path = Path(args.output_path) / SEQ_TYPE / model_folder + output_path = Path(args.output_path) / args.dataset / SEQ_TYPE / args.ds_type output_path.mkdir(parents=True, exist_ok=True) train_dataset = SequenceDataset(train_df, tokenizer, SEQ_TYPE) val_dataset = SequenceDataset(val_df, tokenizer, SEQ_TYPE) + output_dir = output_path / args.run_name / "outputs" + 
output_dir.mkdir(parents=True, exist_ok=True) + logging_dir = output_path / args.run_name / "logs" + logging_dir.mkdir(parents=True, exist_ok=True) + training_args = TrainingArguments( - output_dir=Path(output_path) / args.run_name / "outputs", - save_total_limit=1, + output_dir=str(output_dir), + save_total_limit=2, num_train_epochs=args.num_train_epochs, per_device_train_batch_size=args.per_device_train_batch_size, per_device_eval_batch_size=args.per_device_eval_batch_size, warmup_steps=500, weight_decay=0.01, - logging_dir=Path(output_path) / args.run_name / "logs", + logging_dir=str(logging_dir), load_best_model_at_end=True, metric_for_best_model="balanced_accuracy", greater_is_better=True, @@ -252,10 +273,10 @@ def __len__(self): eval_dataset=val_dataset, compute_metrics=lambda x: compute_metrics(x, args.classification_threshold), ) - trainer.train() + trainer.train(resume_from_checkpoint=None) checkpoint_best_path = ( - output_path / args.run_name / f"{args.ds_type}" / "outputs" / "checkpoint-best" + output_path / args.run_name / "outputs" / "checkpoint-best" ) model.save_pretrained(checkpoint_best_path) tokenizer.save_pretrained(checkpoint_best_path) diff --git a/experiments/subgraphs_reranking/upload_outputs.py b/experiments/subgraphs_reranking/upload_outputs.py new file mode 100644 index 0000000..126cf24 --- /dev/null +++ b/experiments/subgraphs_reranking/upload_outputs.py @@ -0,0 +1,88 @@ +"""Upload CSV results to HuggingFace as a subset""" +import argparse +import os +import pandas as pd +from datasets import Dataset, DatasetDict + +parse = argparse.ArgumentParser() +parse.add_argument( + "--outputs_train_path", + type=str, + default=None, + help="Path to train CSV file", +) + +parse.add_argument( + "--outputs_val_path", + type=str, + default=None, + help="Path to validation CSV file", +) + +parse.add_argument( + "--outputs_test_path", + type=str, + default=None, + help="Path to test CSV file", +) + +parse.add_argument( + "--hf_path", + type=str, + 
default="s-nlp/MKQASubgraphsRanking", + help="Path to upload to HuggingFace", +) + +parse.add_argument( + "--subset_name", + type=str, + default="mkqa_t5largessm", + help="Name for the subset when pushing to HuggingFace. Subset will be named 'name_outputs'.", +) + + +if __name__ == "__main__": + args = parse.parse_args() + + # Load CSV files + train_df = None + val_df = None + test_df = None + + if args.outputs_train_path: + if not os.path.exists(args.outputs_train_path): + raise ValueError(f"Train CSV file not found: {args.outputs_train_path}") + train_df = pd.read_csv(args.outputs_train_path) + + if args.outputs_val_path: + if not os.path.exists(args.outputs_val_path): + raise ValueError(f"Validation CSV file not found: {args.outputs_val_path}") + val_df = pd.read_csv(args.outputs_val_path) + + if args.outputs_test_path: + if not os.path.exists(args.outputs_test_path): + raise ValueError(f"Test CSV file not found: {args.outputs_test_path}") + test_df = pd.read_csv(args.outputs_test_path) + + if train_df is None and val_df is None and test_df is None: + raise ValueError( + "At least one of --outputs_train_path/--outputs_val_path/--outputs_test_path must be provided" + ) + + # Upload to HF + ds = DatasetDict() + if train_df is not None: + ds["train"] = Dataset.from_pandas(train_df) + if val_df is not None: + ds["validation"] = Dataset.from_pandas(val_df) + if test_df is not None: + ds["test"] = Dataset.from_pandas(test_df) + + if len(ds) > 0: + subset_name = f"{args.subset_name}_outputs" + try: + ds.push_to_hub(args.hf_path, config_name=subset_name) + except (TypeError, ValueError): + ds.push_to_hub(args.hf_path) + print(f"Note: Pushed dataset without config_name. 
Subset name '{subset_name}' is for reference only.") + diff --git a/subgraphs_dataset_creation/mining_subgraphs_dataset_processes.py b/subgraphs_dataset_creation/mining_subgraphs_dataset_processes.py index 4ff8334..5136806 100644 --- a/subgraphs_dataset_creation/mining_subgraphs_dataset_processes.py +++ b/subgraphs_dataset_creation/mining_subgraphs_dataset_processes.py @@ -153,7 +153,7 @@ def igraph_to_nx(subgraph: ig.Graph): return nx_subgraph -def write_from_queue(save_jsonl_path: str, results_q: JoinableQueue, n_jobs: int): +def write_from_queue(save_jsonl_path: str, results_q: JoinableQueue): """given a queue, write the queue to the save_jsonl_path file Args: @@ -161,21 +161,15 @@ def write_from_queue(save_jsonl_path: str, results_q: JoinableQueue, n_jobs: int results_q (JoinableQueue): result queue (to write our results from) """ with open(save_jsonl_path, "a+", encoding="utf-8") as file_handler: - finished_workers = 0 - while finished_workers < n_jobs: + while True: try: json_obj = results_q.get() - if json_obj == "END!": - finished_workers += 1 - else: - file_handler.write(json_obj + "\n") + file_handler.write(json_obj + "\n") except QueueEmpty: continue else: results_q.task_done() - print("Finished writing queue") - def read_wd_graph(wd_graph_path: str) -> ig.Graph: """given the path, parse the triples and build @@ -215,50 +209,39 @@ def find_subgraph_and_transform_to_json( f"[{now()}]{proc_worker_header}[{os.getpid()}] Current process memory (Gb)", psutil.Process(os.getpid()).memory_info().rss / (1024.0**3), ) - is_working = True - while is_working: + while True: try: task_line = task_q.get() - if task_line == "END!": - results_q.put(task_line) - is_working = False - else: - start_time = time.time() - data = ujson.loads(task_line) - try: - subgraph = extract_subgraph( - wd_graph, data["answerEntity"], data["questionEntity"] - ) - except ValueError as value_err: - with open("ErrorsLog.jsonl", "a+", encoding="utf-8") as file: - data["error"] = str(value_err) 
- file.write(ujson.dumps(data) + "\n") - continue - except Exception as general_exception: # pylint: disable=broad-except - print(str(general_exception)) - time.sleep(60) - subgraph = extract_subgraph( - wd_graph, data["answerEntity"], data["questionEntity"] - ) - - nx_subgraph = igraph_to_nx(subgraph) - data["graph"] = nx.node_link_data(nx_subgraph) - - results_q.put(ujson.dumps(data)) + start_time = time.time() + data = ujson.loads(task_line) + try: + subgraph = extract_subgraph( + wd_graph, data["answerEntity"], data["questionEntity"] + ) + except ValueError as value_err: + with open("ErrorsLog.jsonl", "a+", encoding="utf-8") as file: + data["error"] = str(value_err) + file.write(ujson.dumps(data) + "\n") + continue + except Exception as general_exception: # pylint: disable=broad-except + print(str(general_exception)) + time.sleep(60) + subgraph = extract_subgraph( + wd_graph, data["answerEntity"], data["questionEntity"] + ) + + nx_subgraph = igraph_to_nx(subgraph) + data["graph"] = nx.node_link_data(nx_subgraph) + + results_q.put(ujson.dumps(data)) except QueueEmpty: continue else: task_queue.task_done() - if task_line == "END!": - print( - f"[{now()}]{proc_worker_header}[{os.getpid()}] \ - Received End of tasks. Send the same to writer!" 
- ) - else: - print( - f"[{now()}]{proc_worker_header}[{os.getpid()}] \ - SSP task completed ({time.time() - start_time}s)" - ) + print( + f"[{now()}]{proc_worker_header}[{os.getpid()}] \ + SSP task completed ({time.time() - start_time}s)" + ) if __name__ == "__main__": @@ -270,8 +253,6 @@ def find_subgraph_and_transform_to_json( proc_worker_header = f"{BColors.OKGREEN}[Process Worker]{BColors.ENDC}" print(f"[{now()}]] Start loading WD Graph") parsed_wd_graph = read_wd_graph(args.igraph_wikidata_path) - # parsed_wd_graph = None - print( f"[{now()}]]{BColors.OKGREEN} \ WD Graph loaded{BColors.ENDC}" @@ -282,7 +263,7 @@ def find_subgraph_and_transform_to_json( task_queue = JoinableQueue(maxsize=queue_max_size) writing_thread = Process( target=write_from_queue, - args=[args.save_jsonl_path, results_queue, args.n_jobs], + args=[args.save_jsonl_path, results_queue], daemon=True, ) writing_thread.start() @@ -295,8 +276,7 @@ def find_subgraph_and_transform_to_json( daemon=True, ) p.start() - time.sleep(30) - # time.sleep(1) + time.sleep(180) with open( args.subgraphs_dataset_prepared_entities_jsonl_path, "r", encoding="utf-8" @@ -316,9 +296,6 @@ def find_subgraph_and_transform_to_json( {results_queue.qsize():4d}; task_queue size: {task_queue.qsize():4d}" ) - for _ in range(args.n_jobs): - task_queue.put("END!") - print(f"[{now()}]{BColors.HEADER}[Main Thread]{BColors.ENDC} All tasks sent") task_queue.join() results_queue.join() From 309f500c5af4e962901014d67f41aef7c852ab0c Mon Sep 17 00:00:00 2001 From: Mikhail Salnikov <2613180+MihailSalnikov@users.noreply.github.com> Date: Wed, 10 Dec 2025 12:20:07 +0300 Subject: [PATCH 10/10] Update README --- README.md | 107 +++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 94 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 2123608..08326de 100644 --- a/README.md +++ b/README.md @@ -22,11 +22,12 @@ Our KGQA pipeline is a novel framework which enhances Large Language Models' per 
![big_pipline](assets/big_pipe.png) -Our KGQA pipelines includes generating answer candidates, entity linking for question entities, subgraphs generation, feature extractors for subgraphs, and various ranking models. All experiments for these papers were based on [Mintaka](https://www.google.com/search?q=mintaka+amazon) - a complex factoid question answering dataset. +Our KGQA pipeline includes generating answer candidates, entity linking for question entities, subgraphs generation, feature extractors for subgraphs, and various ranking models. The pipeline leverages Wikidata as the Knowledge Graph and extracts subgraphs by calculating shortest paths between entities. Experiments were conducted on [Mintaka](https://huggingface.co/datasets/AmazonScience/mintaka) and [MKQA](https://github.com/google-research-datasets/mkqa) - complex factoid question answering datasets. ### 📝 Quick Links - [📄 Knowledge Graph Question Answering - KGQA📑](#knowledge-graph-question-answering) - [🛣KGQA Overview](#kgqa-overview) + - [💻Hardware Requirements](#hardware-requirements) - [🔨Answer Candidates Generation](#answer-candidates-generation) - [🔧Entity Linking](#entity-linking) - [🛠️Subgraphs Extraction](#subgraphs-extraction) @@ -50,7 +51,20 @@ python3 kbqa/mistral_mixtral.py --mode train_eval --model_name mistralai/Mistral ``` The generated candidates for the T5-like models and Mixtral/Mistral will be `.csv` and `.json` format, respectively. -In both `seq2seq.py` and `mistral_mixtral.py`, you can other useful arguments, which includes tracking, training parameters, finetuning parameters, path for checkpoints etc. These arguments are detailed within the files themselves. Lastly, if you prefer to use our prepared finetuned models and generated candidates, we have uplodaed them to [HuggingFace](https://huggingface.co/datasets/s-nlp/KGQASubgraphsRanking). 
+In both `seq2seq.py` and `mistral_mixtral.py`, you can other useful arguments, which includes tracking, training parameters, finetuning parameters, path for checkpoints etc. These arguments are detailed within the files themselves. + +**Supported Datasets:** The `seq2seq.py` script supports multiple datasets including: +- `AmazonScience/mintaka` (default): The Mintaka dataset +- `mkqa-hf`: MKQA dataset in Mintaka format from `Dms12/mkqa_mintaka_format_with_question_entities` +- `mkqa`: Local MKQA dataset files (`mkqa_train.json` and `mkqa_test.json`) +- `s-nlp/lc_quad2`: LC-QuAD 2.0 dataset + +To use a specific dataset, set the `--dataset_name` argument accordingly. For example: +```bash +python3 seq2seq.py --mode train_eval --dataset_name mkqa-hf --model_name t5-large +``` + +Lastly, if you prefer to use our prepared finetuned models and generated candidates, we have uplodaed them to [HuggingFace](https://huggingface.co/datasets/s-nlp/KGQASubgraphsRanking). ### Entity Linking ![entity_linking](assets/entity_linking.png) In both of our papers, we decided to use the golden question entities provided by the Mintaka dataset. The scope of our research were solely on the novelty of the subgraphs and the efficacy of different ranking methods. @@ -58,9 +72,9 @@ In both of our papers, we decided to use the golden question entities provided b ### Subgraphs Extraction ![subgraphs_pipe](assets/subgraphs_pipe.png) -You can either 1) use our prepared dataset at [HuggingFace](https://huggingface.co/datasets/s-nlp/KGQASubgraphsRanking) or 2) fetch your own dataset. **Please do note, if you'd like to fetch your own subgraphs dataset, the task is very computationally expensive on the CPU**. The extraction protocal can be divided into 2 steps. - - parsing the Wikidata dump to build our Wikidata graph via iGraph. - - load our Igraph representation of Wikidata and generate the subgraph dataset. 
+The subgraph extraction process extracts subgraphs related to entity candidates from question-and-answer sets by calculating shortest paths between entities in the Wikidata Knowledge Graph. You can either 1) use our prepared dataset at [HuggingFace](https://huggingface.co/datasets/s-nlp/KGQASubgraphsRanking) for Mintaka or 2) extract your own dataset. **⚠️ WARNING: Subgraph extraction is very computationally expensive and memory-intensive (requires 60-80GB RAM per parallel process)**. The extraction protocol can be divided into 2 steps: + - Parsing the Wikidata dump to build our Wikidata graph via iGraph. + - Loading our iGraph representation of Wikidata and generating the subgraph dataset. All subgraphs extraction codes can be found in `kbqa/subgraphs_dataset_creation/`. #### Parsing Wikidata Dump @@ -144,14 +158,17 @@ python3 kbqa/experiments/subgraphs_reranking/graph_features_preparation.py --sub The output file will be a `.csv` file of the same format as the published [finalised HuggingFace dataset](https://huggingface.co/datasets/s-nlp/KGQASubgraphsRanking). **Please pay attention that one would need to repeat the "[Building the Subgraphs](#building-the-subgraphs)" and "[Subgraphs Feature Extraction](#subgraphs-feature-extraction)" sections for train, val, test for T5-large-ssm, T5-xl-ssm, Mistral, and Mixtral**. The [finalised HuggingFace dataset](https://huggingface.co/datasets/s-nlp/KGQASubgraphsRanking) already combined all data splits and LLMs into one total-packaged dataset. ### Ranking Answer Candidates Using Subgraphs -Using the [finalised dataset](https://huggingface.co/datasets/s-nlp/KGQASubgraphsRanking), we devised the following rankers: -- **Graph Transformer**: leveraging the raw subgraphs by itselves. -- **Regression-based**: Logistic and Linear Regression with graph features and MPNet embeddings of text and G2T features. -- **Gradient Boosting**: Catboost with graph features and MPNet embeddings of text and G2T features. 
-- **Sequence Ranker**: MPNet with G2T features.
-
-After training/fitting, all tuned rankers will generate the list of re-ranked answer candidates with the same skeleton, outlined in `/kbqa/experiments/subgraphs_reranking/ranking_model.py` (**beside Graphormer**). This list of re-ranked answer candidates (in `jsonl` format) is then evaluated with Hits@N metrics with `kbqa/mintaka_evaluate.py`
-Evaluating
+Using the [finalised dataset](https://huggingface.co/datasets/s-nlp/KGQASubgraphsRanking), we devised the following reranking methods to select the most probable answers from candidate lists:
+
+- **Regression-based**: Logistic and Linear Regression models using graph features and MPNet embeddings of text and G2T features.
+- **Gradient Boosting (CatBoost)**: Gradient boosting models with graph features and MPNet embeddings of text and G2T features.
+- **Sequence Ranker (MPNet)**: Semantic similarity-based ranking using MPNet embeddings of G2T features.
+- **RankGPT**: Zero-shot LLM-based reranking using instructional permutation generation (supports both Mintaka and MKQA-hf datasets).
+
+These methods utilize various features extracted from the mined subgraphs, including graph structural features, text embeddings, and graph-to-text (G2T) sequence embeddings.
+
+After training/fitting, all tuned rankers will generate the list of re-ranked answer candidates with the same skeleton, outlined in `/kbqa/experiments/subgraphs_reranking/ranking_model.py` (**except Graphormer**). This list of re-ranked answer candidates (in `jsonl` format) is then evaluated with Hits@N metrics with `kbqa/mintaka_evaluate.py`.
+
 #### Training \& Generating the Re-ranked Answers
 **Graphormer:** As Graphormer was introduced in the original paper, it is the only ranker that was **not updated** to work with `kbqa/experiments/subgraphs_reranking/ranking_model.py` and `kbqa/mintaka_evaluate.py`.
We are still working to refractor the code to the unified ranking pipeline, introduced in the extended paper. With that in mind, you can train the Graphormer model with: ```bash @@ -184,6 +201,44 @@ For the sequence ranker code, there are several available arguments, which can b After training sequence ranker on the desired answer candidate LLM subgraph dataset and sequence, please load the path of the tuned model in `ranking.ipynb` to evaluate. Please pay attention to the parameters of `MPNetRanker()` in `ranking.ipynb` (the different sequence used must be pass in accordingly). It is important to note that the tuned model will generate and rank answer candidates to produce a ranking `.jsonl` file. +**RankGPT:** RankGPT is a zero-shot LLM-based reranking approach that uses instructional permutation generation. It requires no training and works with OpenAI-compatible APIs (including vLLM). To use RankGPT: + +```bash +cd experiments/subgraphs_reranking/rankgpt +python3 predict.py \ + --model_name meta-llama/Llama-2-7b-chat-hf \ + --dataset mintaka \ + --ds_type t5xlssm \ + --output_path /path/to/output.jsonl \ + --window_size 20 \ + --step_size 10 +``` + +For MKQA-hf dataset: +```bash +python3 predict.py \ + --model_name meta-llama/Llama-2-7b-chat-hf \ + --dataset mkqa-hf \ + --ds_type t5xlssm \ + --output_path /path/to/output.jsonl +``` + +Key arguments: +- `--model_name`: LLM model name for ranking (e.g., `meta-llama/Llama-2-7b-chat-hf`, `mistralai/Mistral-7B-Instruct-v0.2`) +- `--dataset`: Dataset to use (`mintaka` or `mkqa-hf`) +- `--ds_type`: Answer candidate LLM type (`t5largessm`, `t5xlssm`, `mistral`, `mixtral`) +- `--window_size`: Window size for sliding window ranking (default: 20) +- `--step_size`: Step size for sliding window (default: 10) +- `--graph_sequence_feature`: Optional graph sequence feature (`highlighted_determ_sequence` or `no_highlighted_determ_sequence`) + +The ranker requires API configuration via environment variables: +```bash +export 
OPENAI_BASE_URL="http://localhost:8000/v1" # Your vLLM or OpenAI endpoint +export OPENAI_API_KEY="your-api-key" # Can be "EMPTY" for local vLLM +``` + +RankGPT automatically handles answer deduplication and uses sliding window strategy for large answer sets. The output format is compatible with `mintaka_evaluate.py`. + #### Hits@N Evaluation After producing the new list of re-ranked answer candidates, you can evaluate this `.jsonl` file by running: ```bash @@ -191,6 +246,32 @@ python3 kbqa/mintaka_evaluate.py --predictions_path path_to_jsonl_prediction_fil ``` Running the above code will produce the final evaluation of our ranker. The evaluation includes Hits@1-5 for the entire Mintaka dataset and each of the question type (intersection, comparative, generic, etc.). +### Hardware Requirements + +The hardware requirements vary significantly depending on which components of the pipeline you plan to use: + +#### Minimum Requirements (Using Pre-computed Datasets) +- **CPU**: Multi-core processor (4+ cores recommended) +- **RAM**: 32GB minimum, 120GB recommended +- **GPU**: Optional, but recommended for training and inference + - For T5-base/large: 12GB VRAM + - For T5-3B/XL: 24GB VRAM + - For Mistral/Mixtral: 80GB VRAM +- **Storage**: 100GB+ free space for datasets and models + +#### Subgraph Extraction Requirements +**⚠️ WARNING: Subgraph extraction is computationally expensive and memory-intensive.** + +- **CPU**: High-performance multi-core processor (32+ cores recommended for parallel processing) +- **RAM**: **60-80GB per parallel process** (critical requirement) + - The `--n_jobs` parameter controls parallelism + - Example: With `--n_jobs=4`, you need 240-320GB total RAM + - Consider using fewer jobs if RAM is limited +- **Storage**: 2000GB+ free space for Wikidata dumps and parsed graph representations +- **Time**: Parsing Wikidata dump can take several days depending on hardware + +**Recommendation**: Use our pre-computed datasets from 
[HuggingFace](https://huggingface.co/datasets/s-nlp/KGQASubgraphsRanking) instead of extracting subgraphs yourself unless you have access to high-memory compute infrastructure. + ### Miscallaneous #### Build \& Run KGQA Docker Environment We have prepared a Docker environment for all experiments outlined above. Please run: