This repository contains the implementation of the Contextual Relational Graph Attention Network (ContextRGAT) model for learning contextual word embeddings by leveraging syntactic dependencies and multi-relational graph structures. The project includes data processing, training, testing, and prediction scripts.
In this work, we propose a novel approach for learning contextual word embeddings by integrating three distinct types of graphs:
- Global Co-occurrence Graph: Captures word co-occurrence statistics.
- Syntactic Graph: Represents sentence-level syntactic dependencies.
- Similarity Graph: Edge weights reflect cosine similarity between pre-trained FastText embeddings.
We construct a unified multi-relational graph that encompasses these three types of graphs. A Relational Graph Attention Network (RGAT) is then employed to perform relation-specific linear transformations and attention-based message passing. During training, a masked language modeling style loss is used, encouraging the model to reconstruct the original embeddings for masked words while integrating contextual information from diverse relational signals.
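As an illustration of how the three graph types above might be assembled into a single multi-relational edge list, the sketch below builds co-occurrence, dependency, and similarity edges for one sentence. It is a minimal sketch under stated assumptions (spaCy for parsing, the `fasttext` package for `cc.en.300.bin` vectors, a fixed window, and a cosine threshold), not the repository's `data_processing_brown.py`; in particular, the co-occurrence graph described above is corpus-level, whereas the sketch uses a per-sentence window purely for brevity.

```python
# Minimal sketch of building the three relation types for one sentence.
# Assumptions (not taken from the repository): spaCy for dependency parsing,
# the fasttext package for cc.en.300.bin vectors, a fixed co-occurrence window,
# and a cosine threshold for similarity edges.
import itertools

import fasttext
import numpy as np
import spacy

nlp = spacy.load("en_core_web_sm")                        # assumed parser
ft = fasttext.load_model("utils/fasttext/cc.en.300.bin")  # path from the setup steps

COOC, SYNTAX, SIM = 0, 1, 2  # relation ids for the unified multi-relational graph


def cosine(u, v):
    return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v) + 1e-8))


def build_edges(sentence, window=2, sim_threshold=0.5):
    doc = nlp(sentence)
    tokens = [tok.text for tok in doc]
    edges = []  # (source index, target index, relation id)

    # 1) Co-occurrence edges: tokens within a sliding window of each other.
    for i, j in itertools.combinations(range(len(tokens)), 2):
        if j - i <= window:
            edges.append((i, j, COOC))

    # 2) Syntactic edges: one edge per dependency arc (head -> dependent).
    for tok in doc:
        if tok.head.i != tok.i:  # skip the root's self-loop
            edges.append((tok.head.i, tok.i, SYNTAX))

    # 3) Similarity edges: pairs whose FastText vectors are close in cosine space.
    vecs = [ft.get_word_vector(t) for t in tokens]
    for i, j in itertools.combinations(range(len(tokens)), 2):
        if cosine(vecs[i], vecs[j]) >= sim_threshold:
            edges.append((i, j, SIM))

    return tokens, edges


tokens, edges = build_edges("The bank approved the loan quickly .")
```

The relation-specific attention and the masked reconstruction objective can likewise be sketched, here with PyTorch Geometric's `RGATConv` standing in for the relational attention layers; the actual `ContextRGAT.py` may differ in depth, dimensions, masking strategy, and attention formulation.

```python
# Minimal sketch of relation-aware attention plus a masked-reconstruction loss.
# Assumptions (not taken from the repository): PyTorch Geometric's RGATConv as
# the relational attention layer, FastText vectors as both input features and
# reconstruction targets, zero-vector masking, and a 15% mask rate.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import RGATConv


class ContextRGATSketch(nn.Module):
    def __init__(self, dim=300, hidden=256, num_relations=3, heads=4):
        super().__init__()
        self.conv1 = RGATConv(dim, hidden, num_relations=num_relations,
                              heads=heads, concat=False)
        self.conv2 = RGATConv(hidden, hidden, num_relations=num_relations,
                              heads=heads, concat=False)
        self.decoder = nn.Linear(hidden, dim)  # project back to the embedding space

    def forward(self, x, edge_index, edge_type):
        h = torch.relu(self.conv1(x, edge_index, edge_type))
        h = self.conv2(h, edge_index, edge_type)
        return self.decoder(h)


def masked_reconstruction_loss(model, x, edge_index, edge_type, mask_rate=0.15):
    """Mask random word nodes and reconstruct their original embeddings."""
    mask = torch.rand(x.size(0), device=x.device) < mask_rate
    if not mask.any():                      # make sure at least one node is masked
        mask[torch.randint(0, x.size(0), (1,))] = True
    x_masked = x.clone()
    x_masked[mask] = 0.0                    # simple zero-vector "mask token"
    out = model(x_masked, edge_index, edge_type)
    return F.mse_loss(out[mask], x[mask])   # loss only on masked positions
```

Training then reduces to sampling graphs, masking a fraction of word nodes, and minimizing this loss, so the model learns to infer a word's embedding from its co-occurrence, syntactic, and similarity neighbours.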
Our approach demonstrates competitive performance on the Word-in-Context (WiC) dataset, achieving accuracy comparable to ELMo despite being trained on a significantly smaller corpus. Additionally, we evaluate our embeddings through semantic clustering, showing that they effectively capture contextual distinctions between word senses. Our findings highlight the potential of multi-relational graph structures for contextual embedding learning and their promise as more efficient alternatives to transformer-based models.
.
├── utils/
│ ├── helper_functions.py
│ ├── ContextRGAT.py
├── pipelines/
│ ├── data_processing/
│ │ ├── data_processing_brown.py
│ ├── train/
│ │ ├── brown/
│ │ │ ├── training_brown.py
│ │ │ ├── resume_training_brown.py
│ ├── test/
│ │ ├── test_brown_wic.py
├── objects/
│ ├── models/
│ ├── graphs/
│ ├── dictionaries/
├── parameters/
│ ├── parameters.yaml
├── requirements.txt
├── README.md
- Clone the repository: `git clone <repository_url>`, then `cd <repository_directory>`
- Install the required dependencies: `pip install -r requirements.txt`
- Download `cc.en.300.bin` from FastText and save it in the `utils/fasttext/` directory.
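To confirm the binary is in place and readable before running the preprocessing step, a quick check with the `fasttext` Python package (an assumption about which loader the pipeline uses) looks like this:

```python
# Sanity-check the downloaded FastText binary (path assumed from the step above).
import fasttext

ft = fasttext.load_model("utils/fasttext/cc.en.300.bin")
print(ft.get_dimension())              # expect 300 for cc.en.300.bin
print(ft.get_word_vector("bank")[:5])  # first few components of one word vector
```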
- Preprocess the data: `python pipelines/data_processing/data_processing_brown.py`
- Train the model: `python pipelines/train/brown/training_brown.py`
- Resume training from a checkpoint (if needed): `python pipelines/train/brown/resume_training_brown.py`
- Test the model and generate predictions: `python pipelines/test/test_brown_wic.py`
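The WiC benchmark asks whether a target word carries the same sense in two different sentences. A common way to score this from contextual embeddings, and a plausible (but not verified) reading of what `test_brown_wic.py` computes, is to compare the target word's two contextual vectors with a cosine threshold tuned on the dev split:

```python
# Hypothetical WiC decision rule: the two vectors are the target word's node
# embeddings produced by the trained model for the two sentences; the exact
# rule used by the repository's test script may differ.
import numpy as np


def wic_same_sense(vec_a: np.ndarray, vec_b: np.ndarray, threshold: float = 0.6) -> bool:
    sim = float(np.dot(vec_a, vec_b) /
                (np.linalg.norm(vec_a) * np.linalg.norm(vec_b) + 1e-8))
    return sim >= threshold  # threshold typically tuned on the WiC dev set
```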
- Ensure that the paths in the scripts are correctly set to point to the appropriate directories and files.
- The `parameters.yaml` file contains configuration parameters such as the number of epochs, batch size, learning rate, and checkpoint interval. Modify these parameters as needed.
- The `objects/` directory is used to store models, graphs, and other artifacts generated during the data processing, training, and testing steps.
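Before launching a run, it can help to print the configuration that will actually be used; the snippet below only loads and displays `parameters.yaml`, since the exact key names are defined in that file rather than documented here:

```python
# Print the training configuration (epochs, batch size, learning rate,
# checkpoint interval, ...) before starting a run.
import yaml  # PyYAML

with open("parameters/parameters.yaml") as f:
    params = yaml.safe_load(f)

for key, value in params.items():
    print(f"{key}: {value}")
```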
By following these guidelines, you should be able to reproduce the results and further experiment with the model and data.