This project implements a Continuous Bag of Words (CBOW) model using JAX. The primary goal is to learn word embeddings by predicting a center word based on its surrounding context words.
This implementation features:
- Input: Bag-of-Words (BoW) averaged one-hot vectors.
- Network: A feed-forward neural network with one hidden layer and ReLU activation.
- Output: A full Softmax output layer.
- Optimization: Manual Stochastic Gradient Descent (SGD) with Cross-Entropy loss.
- Differentiation: Automatic differentiation using `jax.grad`.
Note: This model is inspired by the original Word2Vec CBOW but is not identical. Real Word2Vec uses embedding lookups (indexing) rather than BoW vectors and often uses linear hidden layers. This version is simplified to illustrate core concepts in JAX.
```
├── dataprocessing.py   # Preprocessing and dataset builder
├── cbow.py             # JAX model and training loop
├── data.txt            # Input corpus
└── README.md           # Documentation
```
The preprocessing script prepares the raw text for the neural network through the following steps:
- Load & Clean: Reads text data and removes non-alphanumeric characters.
- Tokenize: Converts text to lowercase and splits into tokens.
- Vocabulary: Builds a sorted list of unique words.
- Sliding Window: Generates `(context, center)` pairs by moving a window over the text.
- Vectorization: Converts context and center words into one-hot vectors.
- Feature Engineering: Averages the context one-hot vectors to create a single input vector.
Resulting Data (see the sketch below):
- `x_train`: Averaged context vectors.
- `y_train`: One-hot encoded center word vectors.
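A minimal sketch of the sliding-window and vectorization steps, assuming the text is already tokenized; the function name `build_dataset`, the `word2idx` mapping, and the window size are illustrative assumptions rather than the exact API of dataprocessing.py:

```python
import numpy as np

def build_dataset(tokens, word2idx, window=2):
    """Sketch: build (x_train, y_train) from a list of tokens.

    For each center position, the `window` words on each side are
    one-hot encoded and averaged into a single BoW input vector.
    """
    V = len(word2idx)
    xs, ys = [], []
    for i in range(window, len(tokens) - window):
        context = tokens[i - window:i] + tokens[i + 1:i + window + 1]
        center = tokens[i]

        # Average the one-hot vectors of the context words.
        x = np.zeros(V, dtype=np.float32)
        for w in context:
            x[word2idx[w]] += 1.0
        x /= len(context)

        # One-hot encode the center word.
        y = np.zeros(V, dtype=np.float32)
        y[word2idx[center]] = 1.0

        xs.append(x)
        ys.append(y)
    return np.stack(xs), np.stack(ys)
```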
Assuming a sample vocabulary size of $V = 276$ and $N$ generated (context, center) pairs:

| Variable | Shape | Meaning |
|---|---|---|
| `x_train` | (N, V) | BoW context vectors (Inputs) |
| `y_train` | (N, V) | One-hot center word labels (Targets) |
| `V` | scalar | Vocabulary size |
The model is a simple feed-forward neural network structured as follows:
Input vector x → Linear Projection → ReLU → Linear Projection → Softmax
The hidden layer size, which is also the embedding dimension, is set in cbow.py and denoted $H$ below:
| Parameter | Shape | Description |
|---|---|---|
| `W1` | (V, H) | Input to hidden layer weights (The Embeddings) |
| `b1` | (H,) | Hidden layer bias |
| `W2` | (H, V) | Hidden to output layer weights |
| `b2` | (V,) | Output layer bias |
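As a rough illustration of these shapes, the parameters could be initialized as a JAX pytree (a plain dictionary); the hidden size `H = 64`, the scale factor, and the function name `init_params` are assumptions for this sketch:

```python
import jax
import jax.numpy as jnp

def init_params(key, V, H=64, scale=0.01):
    """Sketch: create the four parameters listed above as a dict."""
    k1, k2 = jax.random.split(key)
    return {
        "W1": scale * jax.random.normal(k1, (V, H)),  # the embeddings
        "b1": jnp.zeros(H),
        "W2": scale * jax.random.normal(k2, (H, V)),
        "b2": jnp.zeros(V),
    }

params = init_params(jax.random.PRNGKey(0), V=276)
```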
The forward propagation calculates the probability distribution over the vocabulary:

$$h = \mathrm{ReLU}(x W_1 + b_1)$$
$$\text{probs} = \mathrm{softmax}(h W_2 + b_2)$$

The output `probs` represents the likelihood of each word in the vocabulary being the center word.
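The same computation expressed in JAX; the function name `forward` and the use of `jax.nn.relu` / `jax.nn.softmax` are assumptions about how cbow.py might be written:

```python
import jax.numpy as jnp
from jax.nn import relu, softmax

def forward(params, x):
    """Sketch: map an input x of shape (V,) to probs of shape (V,)."""
    h = relu(x @ params["W1"] + params["b1"])   # hidden representation, shape (H,)
    logits = h @ params["W2"] + params["b2"]    # unnormalized scores, shape (V,)
    return softmax(logits)                      # probability distribution over the vocabulary
```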
We use Cross-Entropy Loss to measure the difference between the predicted probability distribution and the actual one-hot vector:

$$L = -\sum_{i=1}^{V} y_i \log(p_i)$$

where $y_i$ is the one-hot target for word $i$ and $p_i$ is the predicted probability of word $i$.
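A sketch of the loss in JAX, reusing the `forward` function sketched above; the small epsilon inside the log is an assumption added for numerical stability:

```python
import jax.numpy as jnp

def cross_entropy_loss(params, x, y):
    """Sketch: cross-entropy between predicted probs and the one-hot target y."""
    probs = forward(params, x)                  # see the forward-pass sketch above
    return -jnp.sum(y * jnp.log(probs + 1e-9))  # epsilon avoids log(0)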
The training loop iterates through epochs, performing:
- Forward Pass: Compute predictions.
- Loss Calculation: Evaluate error.
- Gradient Computation: Use `jax.grad` to find derivatives w.r.t. the parameters.
- Update: Adjust parameters using basic SGD (see the sketch after this list).
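A minimal sketch of this loop, building on the `cross_entropy_loss` sketch above; the learning rate, epoch count, and per-example updates are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1   # assumed hyperparameter
NUM_EPOCHS = 10       # assumed hyperparameter

def sgd_step(params, x, y, lr=LEARNING_RATE):
    """One manual SGD update on a single (x, y) pair."""
    grads = jax.grad(cross_entropy_loss)(params, x, y)  # derivatives w.r.t. params
    return {k: params[k] - lr * grads[k] for k in params}

for epoch in range(NUM_EPOCHS):
    # Per-example SGD over the whole training set.
    for x, y in zip(x_train, y_train):
        params = sgd_step(params, x, y)

    # Report the mean loss in the same format as the example output below.
    mean_loss = jnp.mean(jnp.array(
        [cross_entropy_loss(params, x, y) for x, y in zip(x_train, y_train)]))
    print(f"Epoch {epoch + 1}, Loss = {mean_loss:.4f}")
```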
Example Output:
```
Epoch 1, Loss = 5.6204
Epoch 2, Loss = 5.6199
...
```
Note: A starting loss around $\log(V) = \log(276) \approx 5.62$ is expected, indicating uniform probability distribution at initialization.
Ensure the following packages are installed:
pip install "jax[cpu]"
pip install nltkEnsure data.txt contains your corpus, then run:
python cbow.pyThe model learns three key things:
- Word Embeddings: Stored in matrix $W_1$.
- Semantic Similarity: Words with similar contexts will have similar vectors in $W_1$ (see the sketch below).
- Prediction: How to predict a missing word given its neighbors.
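To inspect semantic similarity, one option is cosine similarity between rows of $W_1$; this helper is only a sketch and assumes a word-to-index mapping like the `Word2Idx` used below:

```python
import jax.numpy as jnp

def cosine_similarity(params, word2idx, a, b):
    """Sketch: cosine similarity between the learned embeddings of words a and b."""
    va = params["W1"][word2idx[a]]
    vb = params["W1"][word2idx[b]]
    return jnp.dot(va, vb) / (jnp.linalg.norm(va) * jnp.linalg.norm(vb))
```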
To access a specific word's embedding after training:
embedding = params["W1"][Word2Idx[word]]| Feature | This Implementation | Original Word2Vec CBOW |
|---|---|---|
| Input Representation | Averaged one-hot vectors | Embedding Lookup Table |
| Hidden Activation | ReLU | Linear (Identity) |
| Output Layer | Full Softmax | Hierarchical Softmax or Negative Sampling |
| Performance | Educational / Slower | Highly Optimized / Fast |
To improve performance or capability, consider implementing:
- Optimizers: Replace manual SGD with Adam or RMSProp (via Optax).
- Batching: Implement mini-batch training for stability.
- JIT Compilation: Use `@jax.jit` to speed up the training step (a combined sketch with Optax follows this list).
- Visualization: Use t-SNE or PCA to visualize the learned embeddings in 2D space.
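As a combined illustration of the Optax and JIT suggestions, here is a hedged sketch of a compiled training step that reuses the `cross_entropy_loss` sketch above; it assumes Optax is installed (`pip install optax`) and uses an arbitrary learning rate:

```python
import jax
import optax

optimizer = optax.adam(learning_rate=1e-3)   # assumed learning rate
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    """Compiled step: compute loss and gradients, then apply an Adam update."""
    loss, grads = jax.value_and_grad(cross_entropy_loss)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```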