3 changes: 2 additions & 1 deletion src/workrb/models/__init__.py
@@ -3,11 +3,12 @@
"""

from workrb.models.base import ModelInterface
from workrb.models.bi_encoder import BiEncoderModel, JobBERTModel
from workrb.models.bi_encoder import BiEncoderModel, ConTeXTMatchModel, JobBERTModel
from workrb.models.classification_model import RndESCOClassificationModel

__all__ = [
"BiEncoderModel",
"ConTeXTMatchModel",
"JobBERTModel",
"ModelInterface",
"RndESCOClassificationModel",
151 changes: 150 additions & 1 deletion src/workrb/models/bi_encoder.py
@@ -1,8 +1,11 @@
"""BiEncoder model wrapper for WorkRB."""
"""BiEncoder model wrapper for WorkRB, along with some instances of the BiEncoder model."""

from typing import Any

import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import batch_to_device
from torch.nn.utils.rnn import pad_sequence
from tqdm.auto import tqdm

from workrb.models.base import ModelInterface
@@ -25,10 +28,12 @@ def __init__(self, model_name: str = "all-MiniLM-L6-v2", **kwargs):

@property
def name(self) -> str:
"""Return the model name."""
return f"BiEncoder-{self.base_model_name.split('/')[-1]}"

@property
def description(self) -> str:
"""Return the model description."""
return "BiEncoder model using sentence-transformers for ranking and classification tasks."

def _compute_rankings(
@@ -127,10 +132,12 @@ def __init__(self, model_name: str = "TechWolf/JobBERT-v2", **kwargs):

@property
def name(self) -> str:
"""Return the model name."""
return self.base_model_name.split("/")[-1]

@property
def description(self) -> str:
"""Return the model description."""
return (
"Job-Normalization BiEncoder from Techwolf: https://huggingface.co/TechWolf/JobBERT-v2"
)
@@ -269,3 +276,145 @@ def citation(self) -> str | None:
year = {{2021}},
}
"""


@register_model()
class ConTeXTMatchModel(ModelInterface):
"""BiEncoder model using sentence-transformers."""

_NEAR_ZERO_THRESHOLD = 1e-9

def __init__(
self,
model_name: str = "TechWolf/ConTeXT-Skill-Extraction-base",
temperature: float = 1.0,
**kwargs,
):
self.base_model_name = model_name
self.model = SentenceTransformer(model_name)
self.temperature = temperature
# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
self.model.to(device)
self.model.eval()

@staticmethod
def _context_match_score(
token_embeddings: torch.Tensor, target_embeddings: torch.Tensor, temperature: float = 1.0
) -> torch.Tensor:
"""Compute attention-weighted similarity between token embeddings and target embeddings."""
# token_embeddings: (B1, L, D), target_embeddings: (B2, D)
dot_scores = (token_embeddings @ target_embeddings.T).transpose(1, 2) # (B1, B2, L)
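# Padded token positions are zero vectors (from pad_sequence), so their dot
# products are (near) zero; mask them to -inf so softmax gives them zero weight.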
dot_scores[dot_scores.abs() < ConTeXTMatchModel._NEAR_ZERO_THRESHOLD] = float("-inf")
weights = torch.softmax(dot_scores / temperature, dim=2) # (B1, B2, L)

norm_tokens = torch.nn.functional.normalize(token_embeddings, p=2, dim=2) # (B1, L, D)
norm_targets = torch.nn.functional.normalize(target_embeddings, p=2, dim=1) # (B2, D)
sim_scores = (norm_tokens @ norm_targets.T).transpose(
1, 2
)  # (B1, L, B2) -> (B1, B2, L)

return (weights * sim_scores).sum(dim=2) # (B1, B2)

@property
def name(self) -> str:
"""Return the model name."""
return self.base_model_name.split("/")[-1]

@property
def description(self) -> str:
"""Return the model description."""
return "ConTeXT-Skill-Extraction-base from Techwolf: https://huggingface.co/TechWolf/ConTeXT-Skill-Extraction-base"

@staticmethod
def encode_batch(contextmatch_model, texts, mean: bool = False) -> torch.Tensor:
"""Encode tokens pof the texts ConTeXT-Skill-Extraction-base model."""
args: dict[str, Any] = {
"normalize_embeddings": False,
"convert_to_tensor": True,
}
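# Unless mean pooling is requested, ask sentence-transformers for per-token
# embeddings, which are returned as one variable-length tensor per input text.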
if not mean:
args["output_value"] = "token_embeddings"
return contextmatch_model.encode(texts, **args)

@staticmethod
def encode(
contextmatch_model, texts, batch_size: int = 128, mean: bool = False
) -> torch.Tensor:
"""Encode using the branch of the ConTeXT-Skill-Extraction-base model."""
# For token embeddings, process in batches and handle variable lengths
all_token_embeddings = []
for i in tqdm(range(0, len(texts), batch_size)):
batch = texts[i : i + batch_size]
batch_token_embs = ConTeXTMatchModel.encode_batch(contextmatch_model, batch, mean=mean)
all_token_embeddings.extend(batch_token_embs)

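# Zero-pad the per-text token embeddings to a common length; the near-zero
# masking in _context_match_score ignores these padded positions downstream.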
token_embeddings = pad_sequence(all_token_embeddings, batch_first=True)
return token_embeddings

@torch.no_grad()
def _compute_rankings(
self,
queries: list[str],
targets: list[str],
query_input_type: ModelInputType,
target_input_type: ModelInputType,
) -> torch.Tensor:
"""Compute ranking scores using attention-weighted similarity."""
query_token_embeddings = ConTeXTMatchModel.encode(self.model, queries)
target_token_embeddings_mean = ConTeXTMatchModel.encode(self.model, targets, mean=True)
return self._context_match_score(
query_token_embeddings, target_token_embeddings_mean, temperature=self.temperature
)

@torch.no_grad()
def _compute_classification(
self,
texts: list[str],
targets: list[str],
input_type: ModelInputType,
target_input_type: ModelInputType | None = None,
) -> torch.Tensor:
"""Compute classification scores by ranking texts against target labels.

Args:
texts: List of input texts to classify
targets: List of target class labels (as text)
input_type: Type of input (e.g., JOB_TITLE)
target_input_type: Type of target (e.g., SKILL_NAME). If None, uses input_type.

Returns:
Tensor of shape (n_texts, n_classes) with similarity scores.
"""
if target_input_type is None:
target_input_type = input_type

# Use ranking mechanism to compute similarity between texts and class labels
return self._compute_rankings(
queries=texts,
targets=targets,
query_input_type=input_type,
target_input_type=target_input_type,
)

@property
def classification_label_space(self) -> list[str] | None:
"""ConTeXT-Match models do not have classification heads."""
return None

@property
def citation(self) -> str | None:
"""ConTeXT-Match model citations."""
return """
@ARTICLE{contextmatch_2025,
author={Decorte, Jens-Joris and van Hautte, Jeroen and Develder, Chris and Demeester, Thomas},
journal={IEEE Access},
title={Efficient Text Encoders for Labor Market Analysis},
year={2025},
volume={13},
number={},
pages={133596-133608},
keywords={Taxonomy;Contrastive learning;Training;Annotations;Benchmark testing;Training data;Large language models;Computational efficiency;Accuracy;Terminology;Labor market analysis;text encoders;skill extraction;job title normalization},
doi={10.1109/ACCESS.2025.3589147}}
"""
167 changes: 167 additions & 0 deletions tests/test_contextmatch_model.py
@@ -0,0 +1,167 @@
import pytest # noqa: D100
import torch

from workrb.models.bi_encoder import ConTeXTMatchModel
from workrb.tasks import TechSkillExtractRanking
from workrb.tasks.abstract.base import DatasetSplit, Language
from workrb.types import ModelInputType


class TestConTeXTMatchModelLoading:
"""Test that ConTeXTMatchModel can be correctly loaded and initialized."""

def test_model_initialization_default(self):
"""Test model initialization with default parameters."""
model = ConTeXTMatchModel()
assert model is not None
assert model.base_model_name == "TechWolf/ConTeXT-Skill-Extraction-base"
assert model.temperature == 1.0
assert model.model is not None
# Model should be in eval mode (not training)
assert not model.model.training

def test_model_initialization_custom_params(self):
"""Test model initialization with custom parameters."""
custom_model_name = "TechWolf/ConTeXT-Skill-Extraction-base"
custom_temperature = 0.5
model = ConTeXTMatchModel(model_name=custom_model_name, temperature=custom_temperature)
assert model.base_model_name == custom_model_name
assert model.temperature == custom_temperature

def test_model_properties(self):
"""Test model name and description properties."""
model = ConTeXTMatchModel()
name = model.name
description = model.description
citation = model.citation

assert isinstance(name, str)
assert len(name) > 0
assert "ConTeXT" in name or "Skill" in name

assert isinstance(description, str)
assert len(description) > 0

assert citation is not None
assert isinstance(citation, str)
assert "contextmatch" in citation.lower() or "ConTeXT" in citation

def test_model_classification_label_space(self):
"""Test that classification_label_space returns None."""
model = ConTeXTMatchModel()
assert model.classification_label_space is None


class TestConTeXTMatchModelUsage:
"""Test that ConTeXTMatchModel can be used for ranking and classification."""

def test_compute_rankings_basic(self):
"""Test basic ranking computation."""
model = ConTeXTMatchModel()
queries = ["software engineer", "data scientist"]
targets = ["Python programming", "machine learning", "statistics"]

scores = model._compute_rankings(
queries=queries,
targets=targets,
query_input_type=ModelInputType.JOB_TITLE,
target_input_type=ModelInputType.SKILL_NAME,
)

# Check output shape: (n_queries, n_targets)
assert scores.shape == (len(queries), len(targets))
assert isinstance(scores, torch.Tensor)

# Scores should be finite
assert torch.isfinite(scores).all()

def test_compute_classification_basic(self):
"""Test basic classification computation."""
model = ConTeXTMatchModel()
texts = ["software engineer", "data scientist"]
targets = ["Python programming", "machine learning", "statistics"]

scores = model._compute_classification(
texts=texts,
targets=targets,
input_type=ModelInputType.JOB_TITLE,
target_input_type=ModelInputType.SKILL_NAME,
)

# Check output shape: (n_texts, n_targets)
assert scores.shape == (len(texts), len(targets))
assert isinstance(scores, torch.Tensor)

# Scores should be finite
assert torch.isfinite(scores).all()

def test_compute_classification_default_target_type(self):
"""Test classification with default target_input_type."""
model = ConTeXTMatchModel()
texts = ["software engineer", "data scientist"]
targets = ["Python programming", "machine learning"]

scores = model._compute_classification(
texts=texts,
targets=targets,
input_type=ModelInputType.JOB_TITLE,
)

assert scores.shape == (len(texts), len(targets))
assert torch.isfinite(scores).all()


class TestConTeXTMatchModelTechSkillExtraction:
"""Test ConTeXTMatchModel performance on TECH skill extraction test set."""

def test_tech_skill_extraction_benchmark_metrics(self):
"""
Test that ConTeXTMatchModel achieves results close to paper-reported metrics.

Metrics reported in the paper on the TECH skill extraction test set:
- Mean Reciprocal Rank (MRR): 0.632
- R-Precision@1 (RP@1): 50.99%
- R-Precision@5 (RP@5): 63.98%
- R-Precision@10 (RP@10): 73.99%
"""
# Initialize model and task
model = ConTeXTMatchModel()
task = TechSkillExtractRanking(split=DatasetSplit.TEST, languages=[Language.EN])

# Evaluate model on the task with the metrics from the paper
metrics = ["mrr", "rp@1", "rp@5", "rp@10"]
results = task.evaluate(model=model, metrics=metrics, language=Language.EN)

# Paper-reported values (RP metrics are percentages, convert to decimals)
expected_mrr = 0.632
expected_rp1 = 50.99 / 100.0 # Convert percentage to decimal
expected_rp5 = 63.98 / 100.0
expected_rp10 = 73.99 / 100.0

# Allow some tolerance for differences in evaluation setup and hardware
mrr_tolerance = 0.05
rp_tolerance = 0.05

# Check MRR
actual_mrr = results["mrr"]
assert actual_mrr == pytest.approx(expected_mrr, abs=mrr_tolerance), (
f"MRR: expected {expected_mrr:.3f}, got {actual_mrr:.3f}"
)

# Check RP@1
actual_rp1 = results["rp@1"]
assert actual_rp1 == pytest.approx(expected_rp1, abs=rp_tolerance), (
f"RP@1: expected {expected_rp1:.3f}, got {actual_rp1:.3f}"
)

# Check RP@5
actual_rp5 = results["rp@5"]
assert actual_rp5 == pytest.approx(expected_rp5, abs=rp_tolerance), (
f"RP@5: expected {expected_rp5:.3f}, got {actual_rp5:.3f}"
)

# Check RP@10
actual_rp10 = results["rp@10"]
assert actual_rp10 == pytest.approx(expected_rp10, abs=rp_tolerance), (
f"RP@10: expected {expected_rp10:.3f}, got {actual_rp10:.3f}"
)