A comprehensive system for analyzing and comparing narrative similarity across multiple dimensions. This project implements graph-based story representations, semantic similarity metrics, and Graph Neural Networks (GNNs) to solve the narrative similarity task.
- Overview
- Project Structure
- Installation
- Core Components
- Usage
- Baselines
- GNN Module
- Data Format
- Configuration
- Results
- Links
StoryNet tackles the challenging problem of measuring narrative similarity between stories. The system:
- Extracts structured information from narrative text using LLMs
- Constructs multi-dimensional story graphs representing characters, events, causality, emotions, and themes
- Compares stories using both structural graph metrics and semantic similarity
- Learns optimal weighting schemes for semantic dimensions via supervised learning
- Trains Graph Neural Networks to learn narrative similarity in an end-to-end fashion
The project implements multiple approaches:
- Structural Baseline: Graph topology comparison (node/edge overlap, degree distributions, centrality)
- Semantic Baseline: Multi-dimensional semantic similarity with learned weights
- LLM Baseline: Direct prompting of language models for similarity judgments
- GNN Models: Graph Convolutional Networks (GCN) and Graph Attention Networks (GAT) for learned representations
StoryNet/
├── main.py # Main execution pipeline
├── story.py # Story class and graph construction
├── semantic.py # Semantic similarity computation
├── semantic_tables.py # Semantic similarity evaluation tables
├── train_semantic_weights.py # Weight learning for semantic dimensions
├── weight_learning.py # Weight learner implementation
├── gemini.py # Gemini API wrapper
├── results.py # Results aggregation utility
├── evaluate.sh # Evaluation script
│
├── baselines/
│ ├── structural_baseline.py # Graph structure comparison
│ ├── llm_baseline.py # LLM-based similarity
│ └── llm_baseline_prompt.py # Prompts for LLM baseline
│
├── GNN/
│ ├── train.py # GNN training script
│ ├── trainer.py # Training loop implementation
│ ├── models.py # GCN and GAT architectures
│ ├── graph_preprocessor.py # Graph preprocessing and featurization
│ ├── feature_tokenizer.py # Feature encoding with embeddings
│ ├── triplet_dataset.py # Dataset for triplet learning
│ ├── split_manager.py # Train/val/test split management
│ ├── config.py # Configuration and hyperparameters
│ ├── validate_splits.py # Split validation utility
│ ├── regenerate_splits.py # Split regeneration utility
│ └── gnn.sh # Training job script
│
├── data/
│ ├── compiled_track_a.jsonl # Main dataset
│ ├── dev_data/ # Development data
│ ├── sample_data/ # Sample stories
│ └── synthetic/ # Synthetic test data
│
├── prompts/
│ ├── stage1.py # Character extraction prompts
│ ├── stage2.py # Event and causality prompts
│ └── stage3.py # Emotional and thematic prompts
│
├── graphs/ # Generated story graphs (idx_*)
├── results/ # Evaluation results and logs
├── weights/ # Learned semantic weights
├── submissions/ # Course submissions
└── requirements.txt # Python dependencies
- Python 3.8+
- CUDA-capable GPU (optional, for GNN training)
Install dependencies:

pip install -r requirements.txt

HuggingFace CLI login (optional, for LLAMA models):

# Create HuggingFace account and verify email
# Accept LLAMA agreement: https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct
huggingface-cli login  # Paste token, answer Y to both questions
Purpose: Converts raw narrative text into structured graph representations.
Class: Story
- Attributes:
  - text: Raw story text
  - characters: List of character entities with roles, attributes, agency
  - relationships: Character relationships with types, evolution, power dynamics
  - events: Story events with types, participants, emotional tones, narrative functions
  - causal_chains: Cause-effect relationships between events
  - parallel_threads: Concurrent narrative threads
  - emotional_trajectories: Character emotional arcs
  - thematic_elements: Abstract themes and their manifestations
  - graph: NetworkX MultiDiGraph representation
Three-Stage Pipeline:
- Stage 1: Extract characters and relationships
- Stage 2: Identify events and causal chains
- Stage 3: Analyze emotional arcs and themes
Methods:
- analyze(): Run all three extraction stages
- build_graph(): Construct NetworkX graph from extracted information
- visualize_graph(): Generate interactive HTML visualization
- save(filepath): Serialize Story object to pickle
- load(filepath): Deserialize Story from pickle
Example:
from story import Story
from gemini import GeminiInference
llm = GeminiInference(api_key="YOUR_KEY")
story = Story(text, llm, index=1, label="anchor")
story.analyze()
story.build_graph()
story.visualize_graph()
story.save("story.pkl")

Purpose: Multi-dimensional semantic comparison of story pairs.
Seven Similarity Dimensions:
- Character Semantic: Role overlap, attribute similarity, agency distribution
- Thematic: Theme alignment, abstraction level matching
- Emotional Arc: Emotional progression, state distribution, change patterns
- Causal Structure: Causal relationship types, strengths, temporal gaps
- Narrative Function: Story beats, narrative structure similarity
- Relationship Dynamics: Relationship types, evolution, power dynamics
- Event Semantic: Event types, emotional tones
Class: StorySemanticAnalyzer
- Uses sentence-transformers (all-mpnet-base-v2) for embedding-based comparison
- Each dimension returns a SimilarityScore with:
  - score: Overall dimension score (0-1)
  - details: Component-level scores for training
Function: analyze_story_pair(story1, story2, embedding_model, weights_dict)
- Returns: (similarities_dict, overall_score)
- Supports custom weights via weights_dict with structure:

{
  "thematic": 0.25,
  "character_semantic": 0.20,
  ...
  "intra": {
    "character_semantic": {"role": 0.4, "attribute": 0.4, "agency": 0.2},
    ...
  }
}
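As a hedged illustration of how the overall score could follow from per-dimension scores and top-level weights (the function name `combine_scores` and the sample values are invented here, not taken from semantic.py):

```python
# Illustrative sketch, not the project's actual code: a weighted average of
# per-dimension similarity scores using the top-level weights_dict entries.

def combine_scores(dimension_scores, weights):
    """Weighted average of dimension scores; weights assumed to sum to 1."""
    return sum(weights[dim] * score for dim, score in dimension_scores.items())

scores = {"thematic": 0.85, "character_semantic": 0.72, "emotional_arc": 0.68}
weights = {"thematic": 0.5, "character_semantic": 0.3, "emotional_arc": 0.2}
overall = combine_scores(scores, weights)  # 0.5*0.85 + 0.3*0.72 + 0.2*0.68
print(round(overall, 3))
```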
Purpose: Learn optimal weights for semantic dimensions from labeled data.
Key Features:
- Component-level learning: Learns weights for 18 fine-grained components across 7 dimensions
- Prior-aware updates: Integrates default heuristic weights in logit space
- Feature filtering: Removes low-variance and low-activity components
- Fast validation blending: Efficiently searches for optimal interpolation between default and learned weights
- Sample weighting: Focuses learning on hard examples (small default margins)
Algorithm:
- Build 18-dimensional component feature vectors (Xp, Xn) for positive/negative pairs
- Standardize features based on training distribution
- Filter components by variance and non-zero activity
- Train LogisticRegression or custom GD on margins (Xp - Xn)
- Apply delta-logit update:
w_new = softmax(logit(w_default) + γ·w_learned)
- Blend search: Find optimal α in w_blend = (1-α)·w_default + α·w_new
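The delta-logit update and blend search can be sketched in plain Python. This is a hedged illustration, not the script's actual code: `gamma` plays the role of --delta_scale, and reading "logit" as log-odds is an assumption; train_semantic_weights.py is authoritative.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a plain list.
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def delta_logit_update(w_default, w_learned, gamma=1.0):
    # Shift default weights in logit space by a scaled learned delta,
    # then renormalize onto the simplex via softmax.
    logits = [math.log(w / (1 - w)) + gamma * d
              for w, d in zip(w_default, w_learned)]
    return softmax(logits)

def best_blend(w_default, w_new, score_fn, grid=201):
    # Grid search over alpha in [0, 1] for
    # w_blend = (1 - alpha) * w_default + alpha * w_new.
    best_alpha, best_score = 0.0, float("-inf")
    for i in range(grid):
        a = i / (grid - 1)
        w = [(1 - a) * wd + a * wn for wd, wn in zip(w_default, w_new)]
        s = score_fn(w)
        if s > best_score:
            best_alpha, best_score = a, s
    return best_alpha, best_score
```

With `--blend_grid 201`, the search above evaluates alpha at steps of 0.005.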
Command:
python train_semantic_weights.py \
--data_path data/compiled_track_a.jsonl \
--linear_backend logreg \
--C 100.0 \
--max_iter_lr 5000 \
--std_threshold 0.0 \
--nz_threshold 0.0 \
--test_frac 0.10 \
--seed 42 \
--blend_grid 201 \
--delta_logit 1 \
--delta_scale 1.0 \
--save_json weights/learned_semantic_weights.json \
--save_blended_json weights/learned_semantic_weights_blend.json

Arguments:
- --data_path: Path to JSONL dataset
- --linear_backend: Training method (gd or logreg)
- --C: Inverse L2 regularization for LogisticRegression (default: 1.0)
- --max_iter_lr: Max iterations for LogisticRegression (default: 1000)
- --std_threshold: Drop components with std below this (default: 0.05)
- --nz_threshold: Drop components with non-zero rate below this (default: 0.60)
- --test_frac: Validation set fraction (default: 0.10)
- --seed: Random seed (default: 42)
- --blend_grid: Number of alpha values to try in blending (default: 21)
- --delta_logit: Use delta-logit update (0 or 1, default: 1)
- --delta_scale: Scale of delta step in logit space (default: 1.0)
- --save_json: Output path for learned weights
- --save_blended_json: Output path for blended weights
Purpose: Core weight learning algorithms with simplex projection.
Class: WeightLearner
- Learns non-negative normalized weights (probability simplex)
- Supports triplet and pairwise learning objectives
- Euclidean projection onto simplex after each gradient step
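The Euclidean projection onto the probability simplex has a standard sort-and-threshold form; the sketch below illustrates it (it is not copied from weight_learning.py, which may differ in details):

```python
def project_simplex(v):
    """Project v onto {w : w >= 0, sum(w) = 1} in Euclidean norm.

    Standard algorithm: sort descending, find the largest index rho whose
    running threshold keeps the coordinate positive, then clip at theta.
    """
    u = sorted(v, reverse=True)
    css = 0.0
    theta = 0.0
    for i, ui in enumerate(u, start=1):
        css += ui
        t = (css - 1.0) / i
        if ui - t > 0:
            theta = t
    return [max(x - theta, 0.0) for x in v]
```

After each gradient step, applying this projection keeps the learned weights non-negative and normalized.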
Methods:
- fit_triplets(feats_anchor_pos, feats_anchor_neg, config): Learn from triplet supervision
- fit_pairs(feats_pairs, labels, config): Learn from binary labels
- save(path) / load(path): Persistence
TrainConfig:
- lr: Learning rate (default: 0.1)
- steps: Number of gradient updates (default: 2000)
- reg: L2 regularization (default: 1e-4)
- seed: Random seed (default: 42)
- verbose_every: Print frequency (default: 200)
Purpose: Wrapper for Google Gemini API.
Class: GeminiInference
- Model: gemini-2.5-flash-lite (default)
- System instruction: Enforces JSON-only output, filters inappropriate content
- Parameters:
- temperature: 0.3 (default)
- max_output_tokens: 2048
- thinking_budget: 0 (no chain-of-thought)
Example:
from gemini import GeminiInference
llm = GeminiInference(api_key="YOUR_KEY")
response = llm.generate(prompt, temperature=0.3, max_output_tokens=2048)

Purpose: End-to-end story analysis and comparison pipeline.
Modes:
- Single sample analysis: --index N
- Batch processing: --batch [batch_num] [total_batches]
- Comparison only: --comp_only --index N
- Full evaluation: --comp_all
- Test on 10% sample: --test_only
Command Examples:
# Analyze single sample (runs LLM inference, builds graphs, compares)
python main.py \
--index 5 \
--data_path data/compiled_track_a.jsonl \
--API_KEY "YOUR_GEMINI_KEY"
# Compare existing cached graphs (no LLM calls)
python main.py \
--comp_only \
--index 100 \
--data_path data/compiled_track_a.jsonl \
--weights_path weights/learned_semantic_weights.json
# Evaluate on all cached graphs
python main.py \
--comp_all \
--data_path data/compiled_track_a.jsonl \
--weights_path weights/learned_semantic_weights.json \
--idx_offset 0
# Test on 10% random sample (fast evaluation)
python main.py \
--test_only \
--data_path data/compiled_track_a.jsonl \
--weights_path weights/learned_semantic_weights.json
# Batch processing (parallel execution)
# Run batch 3 out of 22 total batches
python main.py \
--batch 3 22 \
--data_path data/compiled_track_a.jsonl \
--API_KEY "YOUR_GEMINI_KEY"Arguments:
- --index: Sample index to analyze (default: -1 for batch mode)
- --comp_only: Only compare existing graphs without re-analysis
- --comp_all: Evaluate on ALL cached indices
- --test_only: Evaluate on a 10% random sample
- --batch: Batch number and total batches for parallel execution
- --data_path: Path to JSONL dataset (default: data/trimmed_track_a.jsonl)
- --API_KEY: Gemini API key (default: hardcoded key)
- --weights_path: Path to learned weights JSON (default: learned_semantic_weights.json)
- --idx_offset: Offset for cached story indices (default: 0)
Output:
- Graphs saved to: graphs/idx_{index}/
- Results saved to: graphs/idx_{index}/results.txt
- Batch results: results/batch_{n}_of_{total}_results.txt
- Logs: results/logs/batch_{n}_of_{total}.log
Results Format:
Structure Comparison Results:
Comparison Node Edge InDeg OutDeg Centrality Overall
--------------------------------------------------------------------------------
Text A 0.7500 0.6000 0.8200 0.7800 0.6500 0.7100
Text B 0.6000 0.5000 0.7000 0.6500 0.5500 0.6000
SUCCESS
Semantic Comparison Results:
--- Text A vs Anchor ---
THEMATIC : 0.850
CHARACTER_SEMANTIC : 0.720
EMOTIONAL_ARC : 0.680
CAUSAL_STRUCTURE : 0.750
NARRATIVE_FUNCTION : 0.810
RELATIONSHIP_DYNAMICS : 0.690
EVENT_SEMANTIC : 0.730
OVERALL SIMILARITY : 0.756
--- Text B vs Anchor ---
...
Function: compare_story_graphs(G1, G2)
Metrics:
- Node overlap: Jaccard similarity of node sets
- Edge overlap: Jaccard similarity of (u, v, type) edge triples
- In-degree similarity: Distribution comparison
- Out-degree similarity: Distribution comparison
- Centrality similarity: Betweenness centrality comparison
Weighted combination:
similarity = 0.25·node_score + 0.25·edge_score + 0.2·indegree + 0.2·outdegree + 0.1·centrality
Returns: Dict with individual scores and overall similarity
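A hedged sketch of the core of this comparison, using plain sets rather than NetworkX so it stands alone (the actual implementation lives in baselines/structural_baseline.py; the degree and centrality scores are passed in here rather than computed, since those come from distribution comparisons):

```python
def jaccard(a, b):
    # Jaccard similarity of two sets; empty-vs-empty counts as identical.
    a, b = set(a), set(b)
    return len(a & b) / len(a | b) if a | b else 1.0

def structural_similarity(nodes1, edges1, nodes2, edges2,
                          indeg=1.0, outdeg=1.0, centrality=1.0):
    # edges are (u, v, type) triples, matching the edge-overlap metric above.
    node_score = jaccard(nodes1, nodes2)
    edge_score = jaccard(edges1, edges2)
    return (0.25 * node_score + 0.25 * edge_score
            + 0.2 * indeg + 0.2 * outdeg + 0.1 * centrality)

n1, e1 = {"A", "B", "C"}, {("A", "B", "knows"), ("B", "C", "causes")}
n2, e2 = {"A", "B", "D"}, {("A", "B", "knows")}
print(structural_similarity(n1, e1, n2, e2))
```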
Purpose: Direct LLM-based similarity judgment.
Approach:
- Presents anchor and two candidate stories to LLM
- Asks which candidate is more similar
- Returns: "A" or "B"
Command:
python baselines/llm_baseline.py \
--index 10 \
--data_path data/compiled_track_a.jsonl \
--API_KEY "YOUR_KEY"Arguments:
--index: Sample to evaluate (default: -1 for all)--data_path: Path to dataset--API_KEY: Gemini API key (required)
Two GNN models for learning narrative similarity:
GCN (Graph Convolutional Network):
- Input: 768D semantic embeddings
- Hidden: 512D → 128D embedding
- Dropout: 0.3

GAT (Graph Attention Network):
- Input: 768D semantic embeddings
- Hidden: 128D × 8 heads → 128D embedding
- Dropout: 0.3
Class: FeatureTokenizer
Purpose: Convert graph features to 768D semantic embeddings using sentence-transformers.
Features encoded:
- Node types: character, event, theme, thread, tension
- Character attributes: functional_roles, key_attributes, character_type, agency_level
- Event attributes: event_type, narrative_function, emotional_tone, time_marker
- Relationship attributes: relationship_type, relationship_evolution, power_dynamic
- Causal attributes: relationship_type, strength, temporal_gap
- Edge types: All 10 edge types
Method: build_dynamic_vocabularies(graphs)
- Collects all unique categorical values from graphs
- Builds mapping from value → embedding
- Handles missing values with zero embeddings
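A minimal sketch of the vocabulary-collection step (the function name, the list-of-dicts graph representation, and the placeholder embedding are assumptions for illustration; the real tokenizer embeds values with sentence-transformers):

```python
def build_vocabularies(graphs, feature_names):
    """Collect unique categorical values per feature across all graphs,
    then map each value to an embedding vector."""
    vocab = {name: set() for name in feature_names}
    for g in graphs:                      # g: list of node-attribute dicts
        for node_attrs in g:
            for name in feature_names:
                value = node_attrs.get(name)
                if value is not None:
                    vocab[name].add(value)

    def embed(value):
        # Placeholder embedding; the real implementation uses a 768D
        # sentence-transformers encoding. Missing values get zeros.
        return [0.0] * 4 if value is None else [float(len(str(value)))] * 4

    return {name: {v: embed(v) for v in values}
            for name, values in vocab.items()}
```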
Purpose: Train GNN models on triplet supervision.
Command:
python GNN/train.py \
--model both \
--loss triplet \
--epochs 50 \
--batch_size 12 \
--lr 0.002 \
--device cuda

Arguments:
- --data: Path to JSONL dataset (default: from config)
- --model: Model to train (gcn, gat, or both)
- --loss: Loss function (triplet, infonce, or multisimilarity)
- --epochs: Number of training epochs (default: 50)
- --batch_size: Batch size (default: 12)
- --lr: Learning rate (default: 0.002)
- --device: Device (cuda or cpu)
- --seed: Random seed (default: 42)
- --show-config: Print configuration and exit
Output:
- Models: /scratch/akr/models/{RUN_ID}/best_{model}_model.pt
- Checkpoints: /scratch/akr/models/{RUN_ID}/checkpoint_{model}_epoch_{N}.pt
- Results: /scratch/akr/results/{RUN_ID}/{model}_results.json
- Logs: /scratch/akr/logs/{RUN_ID}/{model}_training.log
Training features:
- Triplet loss with margin (default: 0.5)
- Early stopping (patience: 1000 epochs)
- Periodic checkpointing (every 50 epochs)
- Learning rate scheduling (ReduceLROnPlateau)
- 70/15/15 train/val/test split
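The triplet objective above can be written down compactly. This is a generic margin-based triplet loss on embedding vectors, shown for illustration (the trainer presumably uses the PyTorch equivalent on batched tensors):

```python
import math

def euclidean(a, b):
    # Euclidean distance between two embedding vectors (plain lists).
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def triplet_loss(anchor, positive, negative, margin=0.5):
    # Zero loss once the negative is at least `margin` farther from the
    # anchor than the positive; otherwise penalize the violation.
    return max(euclidean(anchor, positive)
               - euclidean(anchor, negative) + margin, 0.0)
```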
Key settings:
TRAINING_CONFIG = {
"learning_rate": 2e-3,
"weight_decay": 1e-4,
"margin": 0.5,
"batch_size": 12,
"max_epochs": 50,
"early_stopping_patience": 1000,
}
DATA_CONFIG = {
"train_ratio": 0.7,
"val_ratio": 0.15,
"test_ratio": 0.15,
"random_seed": 42,
}

Path structure (with run-specific directories):
/scratch/akr/
├── models/{RUN_ID}/
│ ├── best_gcn_model.pt
│ ├── best_gat_model.pt
│ └── checkpoint_*.pt
├── results/{RUN_ID}/
│ ├── gcn_results.json
│ ├── gat_results.json
│ ├── combined_results.json
│ └── run_metadata.json
└── logs/{RUN_ID}/
└── *.log
Split Management: GNN/split_manager.py
- Maintains consistent train/val/test splits
- Saves splits to /home2/akr/StoryNet/data_split/
- Validates split integrity
Validation: GNN/validate_splits.py
python GNN/validate_splits.py

Checks for:
- No overlap between splits
- Correct ratios
- All indices in range
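The split generation and checks can be sketched as follows (function names here are illustrative, not the scripts' actual API; ratios match DATA_CONFIG):

```python
import random

def make_splits(n, seed=42, ratios=(0.7, 0.15, 0.15)):
    # Shuffle deterministically with the given seed, then slice by ratio.
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_train = round(ratios[0] * n)
    n_val = round(ratios[1] * n)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]

def validate(train, val, test, n):
    # No overlap between splits, and every index in range is covered.
    s = [set(train), set(val), set(test)]
    assert not (s[0] & s[1] or s[0] & s[2] or s[1] & s[2]), "splits overlap"
    assert s[0] | s[1] | s[2] == set(range(n)), "indices missing or out of range"
```

Using the same seed reproduces the same split, which is what keeps train/val/test consistent across runs.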
Regeneration: GNN/regenerate_splits.py
python GNN/regenerate_splits.py --seed 42

Creates new splits with the specified seed.
Each line contains one training triplet:
{
"anchor_text": "Story A text...",
"text_a": "Similar story text...",
"text_b": "Less similar story text...",
"text_a_is_closer": true
}

After running analysis, graphs are cached:
graphs/idx_{N}/
├── anchor.pkl # Anchor story graph
├── text_a.pkl # Text A story graph
├── text_b.pkl # Text B story graph
├── anchor.html # Visualization
├── text_a.html
├── text_b.html
└── results.txt # Comparison results
Each .pkl file contains a serialized Story object with:
- All extracted information (characters, events, emotions, themes)
- NetworkX MultiDiGraph representation
- Metadata (analysis time, index, label)
{
"thematic": 0.25,
"character_semantic": 0.20,
"emotional_arc": 0.20,
"causal_structure": 0.15,
"narrative_function": 0.10,
"relationship_dynamics": 0.05,
"event_semantic": 0.05,
"intra": {
"character_semantic": {
"role": 0.4,
"attribute": 0.4,
"agency": 0.2
},
"thematic": {
"theme": 0.8,
"abstraction": 0.2
},
...
}
}

- Top-level weights: Distribution across 7 dimensions (sum to 1)
- Intra weights: Distribution within each dimension (sum to 1 per dimension)
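A quick sanity check for these constraints (a small helper written for this README, not part of the project's code):

```python
def check_weights(weights, tol=1e-6):
    """Assert that top-level weights and each intra block sum to 1."""
    top = {k: v for k, v in weights.items() if k != "intra"}
    assert abs(sum(top.values()) - 1.0) < tol, "top-level weights must sum to 1"
    for dim, sub in weights.get("intra", {}).items():
        assert abs(sum(sub.values()) - 1.0) < tol, \
            f"intra weights for {dim} must sum to 1"
```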
- Official Task: https://narrative-similarity-task.github.io/
- Initial Submission: https://www.overleaf.com/project/68acb675e43413f7dc707ab4
- Mid-term Submission: https://www.overleaf.com/3855277839svnnmqwsqpkr#3b3a22
- Pipeline Diagram: https://app.diagrams.net/#G1dRz4ljG-y7SvDPPAaYjm6-r0dW37t3lk%23%7B%22pageId%22%3A%22MWvHMk7ffjXTAU3lkiu7%22%7D