CBOW Word Embedding Model in JAX

Overview

This project implements a Continuous Bag of Words (CBOW) model using JAX. The primary goal is to learn word embeddings by predicting a center word based on its surrounding context words.

This implementation features:

  • Input: Bag-of-Words (BoW) averaged one-hot vectors.
  • Network: A feed-forward neural network with one hidden layer and ReLU activation.
  • Output: A full Softmax output layer.
  • Optimization: Manual Stochastic Gradient Descent (SGD) with Cross-Entropy loss.
  • Differentiation: Automatic differentiation using jax.grad.

Note: This model is inspired by the original Word2Vec CBOW but is not identical. Real Word2Vec uses embedding lookups (indexing) rather than BoW vectors, and a linear (identity) projection rather than a ReLU hidden layer. This version is simplified to illustrate the core concepts in JAX.


Project Structure

```
├── dataprocessing.py     # Preprocessing and dataset builder
├── cbow.py               # JAX model and training loop
├── data.txt              # Input corpus
├── README.md             # Documentation
```

1. Preprocessing Pipeline

The preprocessing script prepares the raw text for the neural network through the following steps (a sketch of the full pipeline appears at the end of this section):

  1. Load & Clean: Reads text data and removes non-alphanumeric characters.
  2. Tokenize: Converts text to lowercase and splits into tokens.
  3. Vocabulary: Builds a sorted list of unique words.
  4. Sliding Window: Generates (context, center) pairs by moving a window over the text.
  5. Vectorization: Converts context and center words into one-hot vectors.
  6. Feature Engineering: Averages the context one-hot vectors to create a single input vector.

Resulting Data:

  • x_train: Averaged context vectors.
  • y_train: One-hot encoded center word vectors.
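A minimal sketch of this pipeline is shown below. The function and variable names (`build_dataset`, `window_size`, `word2idx`) are illustrative and may differ from those in `dataprocessing.py`:

```python
import re
import numpy as np

def build_dataset(path="data.txt", window_size=2):
    # 1. Load & clean: keep only alphanumeric characters and whitespace
    with open(path) as f:
        text = re.sub(r"[^a-zA-Z0-9\s]", " ", f.read())

    # 2. Tokenize: lowercase and split into tokens
    tokens = text.lower().split()

    # 3. Vocabulary: sorted list of unique words
    vocab = sorted(set(tokens))
    word2idx = {w: i for i, w in enumerate(vocab)}
    V = len(vocab)

    x_train, y_train = [], []
    # 4. Sliding window: generate (context, center) pairs
    for i in range(window_size, len(tokens) - window_size):
        center = tokens[i]
        context = tokens[i - window_size:i] + tokens[i + 1:i + window_size + 1]

        # 5-6. Vectorization + feature engineering: average the context one-hots
        x = np.zeros(V)
        for w in context:
            x[word2idx[w]] += 1.0
        x /= len(context)

        y = np.zeros(V)
        y[word2idx[center]] = 1.0

        x_train.append(x)
        y_train.append(y)

    return np.array(x_train), np.array(y_train), word2idx
```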

2. Dataset Shapes

Assuming a sample vocabulary size $V=276$ and $N=524$ training examples:

| Variable | Shape | Meaning |
|----------|-------|---------|
| x_train | $(524, 276)$ | BoW context vectors (Inputs) |
| y_train | $(524, 276)$ | One-hot center word labels (Targets) |
| $V$ | $276$ | Vocabulary size |

3. Model Architecture

The model is a simple feed-forward neural network structured as follows:

Input vector x $\to$ Linear Projection $\to$ ReLU $\to$ Linear Projection $\to$ Softmax

Parameter Shapes

The hidden layer size is set to $50$.

| Parameter | Shape | Description |
|-----------|-------|-------------|
| $W_1$ | $(276, 50)$ | Input-to-hidden layer weights (the embeddings) |
| $b_1$ | $(50,)$ | Hidden layer bias |
| $W_2$ | $(50, 276)$ | Hidden-to-output layer weights |
| $b_2$ | $(276,)$ | Output layer bias |
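One way to initialize these parameters in JAX is sketched below; the scaling factor and random-key handling are assumptions and may differ from `cbow.py`:

```python
import jax
import jax.numpy as jnp

def init_params(key, V, hidden=50):
    # Small random weights, zero biases
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (V, hidden)) * 0.01,  # the embeddings
        "b1": jnp.zeros(hidden),
        "W2": jax.random.normal(k2, (hidden, V)) * 0.01,
        "b2": jnp.zeros(V),
    }

params = init_params(jax.random.PRNGKey(0), V=276)
```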

4. Forward Pass

The forward propagation calculates the probability distribution over the vocabulary:

$$ \begin{aligned} Z_1 &= x \cdot W_1 + b_1 \\ A_1 &= \text{ReLU}(Z_1) \\ Z_2 &= A_1 \cdot W_2 + b_2 \\ \text{probs} &= \text{Softmax}(Z_2) \end{aligned} $$

The output probs represents the likelihood of each word in the vocabulary being the center word.
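These equations translate almost line for line into JAX. The sketch below assumes the parameter dictionary from the previous section and uses `jax.nn.relu` and `jax.nn.softmax`:

```python
import jax.numpy as jnp
from jax.nn import relu, softmax

def forward(params, x):
    z1 = x @ params["W1"] + params["b1"]   # linear projection into the hidden space
    a1 = relu(z1)                          # ReLU activation
    z2 = a1 @ params["W2"] + params["b2"]  # project back to vocabulary size
    return softmax(z2, axis=-1)            # probability over the vocabulary
```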


5. Loss Function

We use Cross-Entropy Loss to measure the difference between the predicted probability distribution and the actual one-hot vector:

$$ L = -\sum_{i=1}^{V} y_i \cdot \log(\hat{y}_i) $$

Where $y$ is the true one-hot vector and $\hat{y}$ (probs) is the predicted distribution.
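A corresponding loss function in JAX might look like the sketch below (averaged over the batch, with a small epsilon added inside the log for numerical safety; both are assumptions, not necessarily what `cbow.py` does):

```python
import jax.numpy as jnp

def cross_entropy_loss(params, x, y):
    probs = forward(params, x)  # predicted distribution over the vocabulary
    # Mean over examples of  -sum_i y_i * log(y_hat_i)
    return -jnp.mean(jnp.sum(y * jnp.log(probs + 1e-9), axis=-1))
```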


6. Training Loop

The training loop iterates over epochs, performing the following steps (a minimal sketch follows the list):

  1. Forward Pass: Compute predictions.
  2. Loss Calculation: Evaluate error.
  3. Gradient Computation: Use jax.grad to compute the derivatives of the loss with respect to the parameters.
  4. Update: Adjust parameters using basic SGD.
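A minimal version of this loop, assuming the `cross_entropy_loss` and `params` from the earlier sketches (`num_epochs` and `learning_rate` are placeholders):

```python
from jax import grad

def sgd_update(params, grads, lr):
    # Basic SGD: theta <- theta - lr * dL/dtheta
    return {k: params[k] - lr * grads[k] for k in params}

loss_grad = grad(cross_entropy_loss)  # gradients w.r.t. the first argument (params)

for epoch in range(num_epochs):
    loss = cross_entropy_loss(params, x_train, y_train)  # 1-2: forward pass + loss
    grads = loss_grad(params, x_train, y_train)          # 3: gradients via jax.grad
    params = sgd_update(params, grads, learning_rate)    # 4: SGD update
    print(f"Epoch {epoch + 1}, Loss = {loss:.4f}")
```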

Example Output:

```
Epoch 1, Loss = 5.6204
Epoch 2, Loss = 5.6199
...
```

Note: A starting loss of around $\log(V) = \log(276) \approx 5.62$ is expected, since the randomly initialized model assigns a roughly uniform probability to every word in the vocabulary.


7. Running the Model

Prerequisites

Ensure the following packages are installed:

```bash
pip install "jax[cpu]"
pip install nltk
```

Execution

Ensure data.txt contains your corpus, then run:

```bash
python cbow.py
```

8. What the Model Learns

The model learns three key things:

  1. Word Embeddings: Stored in matrix $W_1$.
  2. Semantic Similarity: Words with similar contexts will have similar vectors in $W_1$.
  3. Prediction: How to predict a missing word given its neighbors.

To access a specific word's embedding after training:

```python
embedding = params["W1"][Word2Idx[word]]
```
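With the embeddings in hand, cosine similarity gives a quick check of which words the model considers related. The sketch below is illustrative; `word2idx` stands for the vocabulary mapping built during preprocessing (called `Word2Idx` in the snippet above), and `most_similar` is not part of the project code:

```python
import jax.numpy as jnp

def most_similar(word, params, word2idx, top_k=5):
    idx2word = {i: w for w, i in word2idx.items()}
    E = params["W1"]                       # (V, hidden) embedding matrix
    v = E[word2idx[word]]
    # Cosine similarity between `word` and every word in the vocabulary
    sims = (E @ v) / (jnp.linalg.norm(E, axis=1) * jnp.linalg.norm(v) + 1e-9)
    best = jnp.argsort(-sims)[1:top_k + 1]  # skip the word itself
    return [(idx2word[int(i)], float(sims[i])) for i in best]
```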

9. Differences from Real Word2Vec CBOW

| Feature | This Implementation | Original Word2Vec CBOW |
|---------|---------------------|------------------------|
| Input representation | Averaged one-hot vectors | Embedding lookup table |
| Hidden activation | ReLU | Linear (identity) |
| Output layer | Full Softmax | Hierarchical Softmax or Negative Sampling |
| Performance | Educational / slower | Highly optimized / fast |

10. Possible Extensions

To improve performance or capability, consider implementing:

  • Optimizers: Replace manual SGD with Adam or RMSProp (via Optax).
  • Batching: Implement mini-batch training for stability.
  • JIT Compilation: Use @jax.jit to speed up the training step (a minimal sketch follows this list).
  • Visualization: Use t-SNE or PCA to visualize the learned embeddings in 2D space.
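For example, JIT-compiling the update step is usually a small change. The sketch below assumes the `cross_entropy_loss` from earlier; the learning rate is a placeholder, and Optax would replace the manual update entirely:

```python
import jax

@jax.jit
def train_step(params, x, y, lr=0.1):
    # Compute loss and gradients in one pass, then apply a plain SGD update
    loss, grads = jax.value_and_grad(cross_entropy_loss)(params, x, y)
    params = {k: params[k] - lr * grads[k] for k in params}
    return params, loss
```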
