Most tutorials will tell you to use bi-encoders for semantic search. They're fast, they scale, they work. And for 80% of use cases, that's the right answer.
But here's what they won't tell you: bi-encoders make a fundamental tradeoff. They encode each sentence independently, which means they miss the subtle interactions between words across sentences. When you need to catch those interactions—when accuracy actually matters more than millisecond latency—you need cross-encoders.
This isn't about building the fastest system. It's about building the right system for problems where getting the answer wrong has consequences.
What Cross-Encoders Actually Do (And Why It Matters)
Think about how you judge if two sentences mean the same thing. You don't read sentence A, form an opinion, then read sentence B and form another opinion, then compare your two opinions. You read them together, constantly cross-referencing: "Oh, 'man' here corresponds to 'person' there. 'Playing' matches 'performing'. 'Guitar' and 'instrument'—close enough in this context."
That's what cross-encoders do. They take both sentences as a single input and let the transformer's attention mechanism draw connections between every word in sentence A and every word in sentence B.
Bi-encoders can't do this. They process each sentence separately, create an embedding, then measure geometric similarity. It's like forming an opinion about two people by looking at their passport photos side by side instead of watching them interact.
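To make the difference concrete, here's a small illustrative comparison using off-the-shelf checkpoints from the sentence-transformers hub (all-MiniLM-L6-v2 and cross-encoder/stsb-roberta-base are just common pretrained choices, not anything this post requires):

from sentence_transformers import SentenceTransformer, CrossEncoder, util

s1 = "A man is playing a guitar."
s2 = "A person is playing a musical instrument."

# Bi-encoder: encode each sentence independently, then compare the vectors
bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
emb1, emb2 = bi_encoder.encode([s1, s2], convert_to_tensor=True)
print("bi-encoder cosine similarity:", util.cos_sim(emb1, emb2).item())

# Cross-encoder: score the pair jointly in a single forward pass
cross_encoder = CrossEncoder("cross-encoder/stsb-roberta-base")
print("cross-encoder score:", cross_encoder.predict([(s1, s2)])[0])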
The result? Cross-encoders are significantly more accurate for tasks where nuance matters: duplicate detection, answer ranking, semantic similarity scoring.
The cost? You can't pre-compute embeddings. Every pair needs a fresh forward pass. For a million documents, that's a million forward passes per query instead of one. This is why everyone uses bi-encoders for retrieval and cross-encoders for re-ranking the top candidates.
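Here's a minimal sketch of that two-stage pattern over a tiny in-memory corpus, using the bi-encoder from the previous snippet for retrieval and a cross-encoder trained for query-passage relevance (cross-encoder/ms-marco-MiniLM-L-6-v2, which comes up again later) for re-ranking. Its scores are raw logits; higher simply means more relevant.

from sentence_transformers import SentenceTransformer, CrossEncoder, util

corpus = [
    "Paris is the capital city of France.",
    "The Eiffel Tower was completed in 1889.",
    "London is the capital of the UK.",
    "France is known for its wine and cheese.",
]
query = "What is the capital of France?"

# Stage 1: cheap retrieval with a bi-encoder (corpus embeddings can be pre-computed once)
bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
corpus_emb = bi_encoder.encode(corpus, convert_to_tensor=True)
query_emb = bi_encoder.encode(query, convert_to_tensor=True)
hits = util.semantic_search(query_emb, corpus_emb, top_k=3)[0]

# Stage 2: precise re-ranking of the top hits with a cross-encoder
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
pairs = [(query, corpus[hit["corpus_id"]]) for hit in hits]
rerank_scores = cross_encoder.predict(pairs)

for (q, doc), score in sorted(zip(pairs, rerank_scores), key=lambda x: x[1], reverse=True):
    print(f"{score:.3f}  {doc}")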
When You Actually Need This
Before you invest time fine-tuning a cross-encoder, make sure you're solving the right problem. Here's when it makes sense:
You have a re-ranking problem. You've already narrowed down candidates using BM25 or a bi-encoder. Now you need to score the top 10-100 results with high precision. Classic use cases: search engines, question-answering systems, recommendation re-rankers.
Accuracy directly impacts your outcome. If you're building a duplicate detection system for support tickets and false positives waste hours of human time, the accuracy gain is worth the compute cost. If you're just doing approximate clustering, probably not.
You have labeled training data. Fine-tuning requires pairs of sentences with ground-truth similarity scores or classification labels. If you're starting from scratch, collecting this data is your main bottleneck, not the model training.
Your domain has specific similarity patterns. Legal documents, medical records, customer support conversations—these have domain-specific ways of expressing similarity that generic models miss. Fine-tuning teaches the model your domain's equivalences.
The Real Constraints (India Context)
Let's talk about what this actually costs to run, because most tutorials skip this part.
Training Hardware:
- Fine-tuning BERT-base (110M parameters) on STS-B: needs ~8GB VRAM, runs on a single T4 GPU
- Training time: 1-2 hours for 3 epochs on STS-B (5,749 training pairs)
- Cost on Colab: the free tier provides a T4 with session limits; Colab Pro (~₹800/month) buys longer, more reliable access
- Cost on AWS: p3.2xlarge at ~₹600/hour, so ₹600-1,200 for full training
- Local option: RTX 3060 (12GB) works fine, costs ~₹35,000 one-time
This is actually accessible compared to training large language models. You can fine-tune a cross-encoder on a decent gaming GPU.
Inference Reality:
- Cross-encoder processes ~100-200 sentence pairs per second on CPU (varies by sentence length)
- On GPU: ~1,000-2,000 pairs per second
- For re-ranking 100 candidates per query: roughly 0.5-1 second on CPU, 50-100ms on GPU
If you're building for India where GPU access isn't a given, cross-encoders are actually reasonable. The compute requirements are modest enough that CPU inference works for many real-world loads.
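If you want to check these numbers on your own hardware before committing, a quick-and-dirty benchmark looks like this (the model name and pair count are arbitrary; any cross-encoder checkpoint will do):

import time
from sentence_transformers import CrossEncoder

model = CrossEncoder("cross-encoder/stsb-roberta-base", device="cpu")
pairs = [("A man is playing a guitar.", "A person is playing an instrument.")] * 200

model.predict(pairs[:8])  # warm-up pass so the first-call overhead doesn't skew the timing

start = time.perf_counter()
model.predict(pairs, batch_size=32)
elapsed = time.perf_counter() - start
print(f"{len(pairs) / elapsed:.0f} pairs/second on CPU")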
The Tutorial: Fine-Tuning for Semantic Similarity
We'll fine-tune BERT on the STS Benchmark dataset, where sentence pairs are scored 0-5 on semantic similarity. This teaches the model to predict how similar two sentences are, which you can then use for ranking, duplicate detection, or similarity search.
Hardware Check
Before starting, verify you have:
- Python 3.8+
- 8GB RAM minimum (16GB recommended)
- GPU with 8GB+ VRAM (optional, but training runs several times faster than on CPU)
- 5GB disk space for model and dataset
If you're running this on CPU, it'll work but take 3-4x longer to train. For learning and small datasets, that's fine.
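A quick way to confirm what PyTorch can actually see (run this after installing the dependencies in the next step):

import torch

if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    print("VRAM:", round(torch.cuda.get_device_properties(0).total_memory / 1e9, 1), "GB")
else:
    print("No GPU detected - training will fall back to CPU")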
Install Dependencies
pip install torch transformers sentence-transformers datasets accelerate
Note: sentence-transformers is the key library here. It wraps Hugging Face transformers with convenient APIs for training and evaluating cross-encoders and bi-encoders.
Load and Inspect the Dataset
from datasets import load_dataset

# STS-B: pairs of sentences with similarity scores 0-5
dataset = load_dataset("stsb_multi_mt", "en")
print(dataset)

# Look at a few examples
for i, example in enumerate(dataset["train"]):
    if i >= 3:
        break
    print(f"Sentence 1: {example['sentence1']}")
    print(f"Sentence 2: {example['sentence2']}")
    print(f"Similarity: {example['similarity_score']}/5.0")
    print()
The STS-B dataset has 5,749 training pairs and 1,379 test pairs. Each pair has two sentences and a human-annotated similarity score from 0 (completely different) to 5 (equivalent meaning).
Example pair:
- Sentence 1: "A man is playing a guitar."
- Sentence 2: "A person is playing a musical instrument."
- Score: 3.8/5.0
This is what we're teaching the model to predict.
Prepare Training Data
Cross-encoders need input in (sentence1, sentence2, label) format. The InputExample class from sentence-transformers handles this:
from sentence_transformers import InputExample
from torch.utils.data import DataLoader

# Convert to InputExample format.
# CrossEncoder with num_labels=1 applies a sigmoid to its output,
# so scale the 0-5 scores down to the 0-1 range.
train_examples = [
    InputExample(
        texts=[row["sentence1"], row["sentence2"]],
        label=float(row["similarity_score"]) / 5.0
    )
    for row in dataset["train"]
]
dev_examples = [
    InputExample(
        texts=[row["sentence1"], row["sentence2"]],
        label=float(row["similarity_score"]) / 5.0
    )
    for row in dataset["test"]
]

# Create DataLoader with batch size 16
# (use smaller batches if you're memory-constrained)
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
Why this format? Cross-encoders concatenate both sentences and feed them to BERT as: [CLS] sentence1 [SEP] sentence2 [SEP]. The model learns to output a single score for the pair; with num_labels=1, sentence-transformers passes that score through a sigmoid, which is why the labels are scaled to the 0-1 range above.
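If you want to see what that concatenation looks like, here's a small illustrative snippet (not part of the training pipeline) that encodes one pair with the plain Hugging Face tokenizer; the token_type_ids mark which segment each token belongs to:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Passing two texts produces the [CLS] sentence1 [SEP] sentence2 [SEP] layout
encoded = tokenizer(
    "A man is playing a guitar.",
    "A person is playing a musical instrument."
)

print(tokenizer.convert_ids_to_tokens(encoded["input_ids"]))
print(encoded["token_type_ids"])  # 0 for the first segment, 1 for the second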
Initialize the Model
from sentence_transformers import CrossEncoder

MODEL_NAME = "bert-base-uncased"  # 110M parameters, English only
OUTPUT_DIR = "./cross-encoder-stsb"

# num_labels=1 means regression (predicting a single continuous score)
# For classification tasks, set num_labels to the number of classes
model = CrossEncoder(MODEL_NAME, num_labels=1)
The model starts from BERT's pretrained weights, but the classification head (the final layer that outputs the score) is randomly initialized. Fine-tuning updates the whole network; the head is simply the part that learns from scratch how to turn BERT's pair representation into a similarity score.
Set Up Evaluation
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator

# Evaluator computes Pearson and Spearman correlation
# between predicted and actual similarity scores
evaluator = CECorrelationEvaluator.from_input_examples(
    dev_examples, name="sts-dev"
)
Correlation is the right metric here because we care about ranking pairs correctly, not predicting exact scores. A correlation of 0.85 means the model's rankings closely match human judgments.
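Under the hood the evaluator computes something very close to the following sketch (using scipy, which comes along with sentence-transformers' dependencies); you can run it yourself after training to double-check the reported numbers:

from scipy.stats import pearsonr, spearmanr

# Score the dev pairs and compare against the gold labels
pairs = [ex.texts for ex in dev_examples]
gold = [ex.label for ex in dev_examples]
preds = model.predict(pairs)

print(f"Pearson:  {pearsonr(gold, preds)[0]:.3f}")
print(f"Spearman: {spearmanr(gold, preds)[0]:.3f}")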
Training
EPOCHS = 3
WARMUP_STEPS = 100   # Linear warmup for the learning rate
EVAL_STEPS = 500     # Evaluate every 500 steps

model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=EPOCHS,
    evaluation_steps=EVAL_STEPS,
    warmup_steps=WARMUP_STEPS,
    output_path=OUTPUT_DIR,
    show_progress_bar=True
)
What's happening during training:
- Model sees batches of sentence pairs with similarity scores
- Predicts scores using current weights
- Computes the loss between predictions and labels (binary cross-entropy over the sigmoid output, the sentence-transformers default for num_labels=1, which is why the scores were scaled to 0-1)
- Updates weights to reduce that error
- Every 500 steps, evaluates on dev set and saves if it's the best model so far
On a T4 GPU, this takes about 90 minutes. On CPU, expect 4-6 hours.
Training Output (What to Expect)
Epoch 1/3
{'loss': -20.15, 'learning_rate': 1.18e-05, 'epoch': 1.39}
{'eval_pearson': 0.451, 'eval_spearman': 0.477, 'epoch': 1.39}
Epoch 2/3
{'loss': -32.75, 'learning_rate': 1.59e-06, 'epoch': 2.78}
{'eval_pearson': 0.550, 'eval_spearman': 0.549, 'epoch': 2.78}
Epoch 3/3
{'train_loss': -27.04, 'epoch': 3.0}
Two things to watch in logs like these. First, the eval correlation should climb steadily; here Pearson improves from 0.45 to 0.55 as the model learns to match human similarity judgments. Second, the loss should stay positive: negative values like the ones above are what binary cross-entropy produces when raw 0-5 scores are fed in without the /5.0 normalization, and that mistake also caps how far the correlation can climb.
For reference, state-of-the-art cross-encoders on STS-B reach a Pearson correlation around 0.87-0.90, and bert-base trained on properly scaled labels typically lands in the mid 0.80s. With hyperparameter tuning (batch size, learning rate, warmup) you can push it higher still.
Load the Trained Model
# Reload from checkpoint for inference
model = CrossEncoder(OUTPUT_DIR)
The model is saved with all weights, config, and tokenizer. You can share this folder or deploy it directly.
Inference: Pairwise Similarity
# Test on new sentence pairs
test_pairs = [
    ("A man is playing a guitar.", "A person is playing a guitar."),
    ("A dog is running in the park.", "A cat is sleeping on the couch."),
    ("Python is a programming language.", "Python is a type of snake.")
]

scores = model.predict(test_pairs)

for (s1, s2), score in zip(test_pairs, scores):
    print(f"{score:.3f} | {s1} ↔ {s2}")
Output:
1.000 | A man is playing a guitar. ↔ A person is playing a guitar.
0.176 | A dog is running in the park. ↔ A cat is sleeping on the couch.
0.412 | Python is a programming language. ↔ Python is a type of snake.
The model correctly identifies the first pair as near-identical, the second as unrelated, and the third as somewhat similar (same word, different contexts).
Inference: Ranking Candidates
This is the real-world use case. You have a query and multiple candidate answers. Score each candidate against the query and rank by score.
query = "What is the capital of France?"candidates = [ "Paris is the capital city of France.", "London is the capital of the UK.", "France is known for its wine and cheese.", "The capital of France is Paris, founded in the 3rd century BC."]# Create pairs: (query, candidate)pairs = [(query, cand) for cand in candidates]scores = model.predict(pairs)# Sort by score descendingranked = sorted( zip(candidates, scores), key=lambda x: x[1], reverse=True)print("\nRanked Results:")for rank, (cand, score) in enumerate(ranked, 1): print(f"{rank}. [{score:.3f}] {cand}")
Output:
Ranked Results:
1. [1.000] Paris is the capital city of France.
2. [0.987] The capital of France is Paris, founded in the 3rd century BC.
3. [0.832] London is the capital of the UK.
4. [0.543] France is known for its wine and cheese.
The model correctly ranks the two direct answers highest, even though the lower-ranked candidates also mention France or a capital city.
Complete Working Code
Here's everything in one script you can run:
"""Fine-tune a BERT-based cross-encoder for semantic similarityDataset: STS Benchmark (English)Task: Predict similarity scores 0-5 for sentence pairs"""import torchfrom datasets import load_datasetfrom torch.utils.data import DataLoaderfrom sentence_transformers import CrossEncoder, InputExamplefrom sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator# ConfigurationMODEL_NAME = "bert-base-uncased"OUTPUT_DIR = "./cross-encoder-stsb"BATCH_SIZE = 16EPOCHS = 3WARMUP_STEPS = 100EVAL_STEPS = 500SEED = 42torch.manual_seed(SEED)# ===== 1. LOAD DATA =====print("Loading STS Benchmark dataset...")dataset = load_dataset("stsb_multi_mt", "en")# Show dataset structureprint(f"Train: {len(dataset['train'])} pairs")print(f"Test: {len(dataset['test'])} pairs\n")# Sample examplesample = dataset["train"][0]print("Sample:")print(f" Sentence 1: {sample['sentence1']}")print(f" Sentence 2: {sample['sentence2']}")print(f" Similarity: {sample['similarity_score']}/5.0\n")# ===== 2. PREPARE TRAINING DATA =====train_examples = [ InputExample( texts=[row["sentence1"], row["sentence2"]], label=float(row["similarity_score"]) ) for row in dataset["train"]]dev_examples = [ InputExample( texts=[row["sentence1"], row["sentence2"]], label=float(row["similarity_score"]) ) for row in dataset["test"]]train_dataloader = DataLoader( train_examples, shuffle=True, batch_size=BATCH_SIZE)# ===== 3. INITIALIZE MODEL =====print(f"Initializing CrossEncoder with {MODEL_NAME}...")model = CrossEncoder(MODEL_NAME, num_labels=1)# ===== 4. SET UP EVALUATION =====evaluator = CECorrelationEvaluator.from_input_examples( dev_examples, name="sts-dev")# ===== 5. TRAIN =====print("\nStarting training...")print(f"Epochs: {EPOCHS}, Batch size: {BATCH_SIZE}")print(f"Total training steps: {len(train_dataloader) * EPOCHS}\n")model.fit( train_dataloader=train_dataloader, evaluator=evaluator, epochs=EPOCHS, evaluation_steps=EVAL_STEPS, warmup_steps=WARMUP_STEPS, output_path=OUTPUT_DIR, show_progress_bar=True)print(f"\nModel saved to: {OUTPUT_DIR}")# ===== 6. RELOAD TRAINED MODEL =====print("\nReloading trained model for inference...")model = CrossEncoder(OUTPUT_DIR)# ===== 7. INFERENCE DEMO =====print("\n" + "="*60)print("INFERENCE DEMO")print("="*60)# Pairwise similarityprint("\n1. Pairwise Similarity Scoring:")test_pairs = [ ("A man is playing a guitar.", "A person is playing a guitar."), ("A dog is running in the park.", "A cat is sleeping on the couch."), ("Python is a programming language.", "Python is a type of snake.")]scores = model.predict(test_pairs)for (s1, s2), score in zip(test_pairs, scores): print(f" {score:.3f} | {s1} ↔ {s2}")# Rankingprint("\n2. Candidate Ranking:")query = "What is the capital of France?"candidates = [ "Paris is the capital city of France.", "London is the capital of the UK.", "France is known for its wine and cheese.", "The capital of France is Paris, founded in the 3rd century BC."]pairs = [(query, cand) for cand in candidates]scores = model.predict(pairs)ranked = sorted( zip(candidates, scores), key=lambda x: x[1], reverse=True)print(f" Query: {query}\n")for rank, (cand, score) in enumerate(ranked, 1): print(f" {rank}. [{score:.3f}] {cand}")print("\n" + "="*60)print("Training and inference complete!")print("="*60)
What Can Go Wrong (And How to Fix It)
1. Out of Memory During Training
Error: RuntimeError: CUDA out of memory
Fix:
- Reduce batch size from 16 to 8 or 4
- Turn on mixed precision: model.fit(..., use_amp=True) reduces memory use and speeds up training on GPUs that support it
- Switch to a smaller model: distilbert-base-uncased instead of bert-base-uncased
- If your GPU has less than 8GB VRAM and that still isn't enough, train on CPU (slower but works)
2. Poor Correlation Scores
If your dev set correlation is below 0.3 after training:
- Check if you loaded the right dataset (English split for English model)
- Verify labels are scaled to 0-1 (divide STS-B's 0-5 scores by 5; unnormalized labels show up as negative losses and stalled correlation)
- Try more epochs (3 might not be enough for convergence)
- Increase warmup steps to 200-300 for better learning rate scheduling
- Start from a model already fine-tuned on a similar task, such as cross-encoder/stsb-roberta-base
3. Slow Inference on CPU
If scoring 100 pairs takes >10 seconds:
- Batch your predictions: model.predict(pairs, batch_size=32) processes multiple pairs in parallel
- Use a smaller model: DistilBERT is roughly 40% smaller and noticeably faster, with minimal accuracy loss
- Consider using a bi-encoder for initial retrieval, cross-encoder only for top-K re-ranking
- If deploying in production, use ONNX Runtime for 2-3x speedup on CPU
4. Model Predicts Same Score for Everything
This happens if training diverged or learning rate is too high:
- Check for NaN in loss values during training
- Reduce the learning rate: model.fit(..., optimizer_params={"lr": 1e-5}) (the default is 2e-5)
- Inspect predictions on the dev set: model.predict([ex.texts for ex in dev_examples[:10]])
- If every score sits at the same end of the 0-1 range, the model has collapsed; restart training with a lower learning rate
When NOT to Use Cross-Encoders
Cross-encoders are not the right tool if:
You need real-time retrieval from millions of documents. Cross-encoders can't pre-compute embeddings. For semantic search across large corpora, use bi-encoders for retrieval and cross-encoders for re-ranking the top 10-100 results.
Your task is all-pairs similarity. If you just need to cluster documents or find near-duplicates across an entire collection, where every item has to be compared against every other item, bi-encoders are faster and work fine. Cross-encoders shine when one side is a query and the other is a short list of candidates.
You don't have labeled training data. Pre-trained cross-encoders exist (cross-encoder/ms-marco-MiniLM-L-6-v2 for search, cross-encoder/nli-deberta-v3-base for entailment), but if you need domain-specific fine-tuning and don't have labels, you'll need to collect or generate them first.
Your accuracy requirements are loose. If you're building a "related posts" feature where approximate matches are fine, the extra complexity isn't worth it. Use a bi-encoder, call it a day.
Production Deployment Considerations
If you're actually putting this into production (not just experimenting), here's what matters:
Batch Inference: Always batch your predictions. Scoring 100 pairs one at a time is 10x slower than scoring them in a batch of 100. Use model.predict(pairs, batch_size=64).
Model Serving: For API deployment, wrap the model in FastAPI or Flask and use ONNX Runtime for inference. This gives you 2-3x speedup and better hardware utilization.
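As a rough sketch of the ONNX route using the optimum package (an extra dependency, not installed earlier): the saved cross-encoder folder is a standard Hugging Face sequence-classification model, so it can be exported directly. Treat the exact speedup as hardware-dependent, and note that the sigmoid here mirrors CrossEncoder's default activation for num_labels=1.

# pip install optimum[onnxruntime]
import numpy as np
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer

MODEL_DIR = "./cross-encoder-stsb"

# Export the trained model to ONNX and load it with ONNX Runtime
ort_model = ORTModelForSequenceClassification.from_pretrained(MODEL_DIR, export=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

pairs = [("A man is playing a guitar.", "A person is playing a guitar.")]
inputs = tokenizer(
    [p[0] for p in pairs], [p[1] for p in pairs],
    padding=True, truncation=True, return_tensors="pt"
)
logits = ort_model(**inputs).logits.numpy()   # shape: (batch, 1)
scores = 1 / (1 + np.exp(-logits))            # sigmoid over the raw logits
print(scores.ravel())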
Monitoring: Track inference latency (target: <100ms for batch of 10 pairs on CPU) and score distributions (watch for drift—if all scores start clustering, something's wrong).
Hardware Scaling: If you're processing >1,000 queries/second, you need GPU inference. An AWS g4dn.xlarge (T4 GPU) handles ~2,000 pairs/second at roughly ₹45-50/hour on demand (about ₹30,000-35,000/month if it runs around the clock), which is still usually cheaper than scaling out enough CPU instances to match that throughput.
Versioning: Save your training config, dataset version, and model checkpoints. When you retrain, you want to be able to compare old vs new performance. Use MLflow or Weights & Biases for experiment tracking.
Where to Go From Here
This tutorial covered semantic similarity, but cross-encoders work for any pairwise classification or ranking task:
Natural Language Inference (NLI): Classify sentence pairs as entailment, contradiction, or neutral. Use the SNLI or MultiNLI datasets with num_labels=3.
Duplicate Detection: Binary classification—are these two questions/tickets/documents duplicates? Use Quora Question Pairs dataset with num_labels=2.
Answer Ranking: Given a question and candidate answers, rank answers by relevance. Use MS MARCO or Natural Questions datasets.
Zero-Shot Classification: Frame classification as: "Does this text entail this label?" Use cross-encoder for few-shot learning when you have minimal training data.
The code structure is the same. Change the dataset, adjust num_labels, and retrain.
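For example, here's a rough sketch of the NLI variant. It assumes the snli dataset from the Hugging Face hub, subsamples the training split to keep the run quick, and swaps in CESoftmaxAccuracyEvaluator because the labels are classes rather than scores; the output folder name is just a placeholder.

from datasets import load_dataset
from torch.utils.data import DataLoader
from sentence_transformers import CrossEncoder, InputExample
from sentence_transformers.cross_encoder.evaluation import CESoftmaxAccuracyEvaluator

# SNLI labels: 0 = entailment, 1 = neutral, 2 = contradiction (-1 = unlabeled, skipped)
snli = load_dataset("snli")

train_examples = [
    InputExample(texts=[row["premise"], row["hypothesis"]], label=int(row["label"]))
    for row in snli["train"].select(range(50_000))   # subsample to keep the sketch quick
    if row["label"] != -1
]
dev_examples = [
    InputExample(texts=[row["premise"], row["hypothesis"]], label=int(row["label"]))
    for row in snli["validation"]
    if row["label"] != -1
]

model = CrossEncoder("bert-base-uncased", num_labels=3)   # 3-way classification head
evaluator = CESoftmaxAccuracyEvaluator.from_input_examples(dev_examples, name="snli-dev")

model.fit(
    train_dataloader=DataLoader(train_examples, shuffle=True, batch_size=16),
    evaluator=evaluator,
    epochs=1,
    warmup_steps=100,
    output_path="./cross-encoder-snli",   # placeholder output folder
)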
The Bottom Line
Cross-encoders are your accuracy tool when bi-encoders aren't good enough. They're slower, they don't scale to millions of documents, but they catch the subtle interactions that matter for high-stakes tasks.
Fine-tuning one takes 1-2 hours and a modest GPU. The resulting model is surprisingly robust and works well across related domains. If you're building search, Q&A, or duplicate detection where precision matters, this is worth learning.
The sentence-transformers library makes it almost trivial. The hard parts are getting good training data and understanding when to use cross-encoders vs simpler approaches. But once you've got those figured out, you have a powerful tool for pairwise scoring that actually works.