Paper: Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup (arXiv:2101.06983)
This is a sentence-transformers model finetuned from google-bert/bert-base-multilingual-cased on the BKAI Vietnamese Legal Documents retrieval dataset. It maps sentences and paragraphs to a 768-dimensional dense vector space and can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more. On the BKAI legal document retrieval test split it achieves an NDCG@10 of 0.60389.
SentenceTransformer(
(0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel
(1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)
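For reference, the mean-pooling step above simply averages the token embeddings of the underlying BertModel, ignoring padding tokens. The following is a rough sketch of that step using the plain transformers API; the `mean_pool` helper is written here purely for illustration and is not part of either library:

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-multilingual-cased")
bert = AutoModel.from_pretrained("google-bert/bert-base-multilingual-cased")

def mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Zero out padding positions, then average over the sequence dimension
    mask = attention_mask.unsqueeze(-1).float()
    return (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

batch = tokenizer(["Tội xúc phạm danh dự?"], padding=True, truncation=True,
                  max_length=512, return_tensors="pt")
with torch.no_grad():
    token_embeddings = bert(**batch).last_hidden_state   # [1, seq_len, 768]
embedding = mean_pool(token_embeddings, batch["attention_mask"])  # [1, 768]
```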
pip install -U sentence-transformers
from sentence_transformers import SentenceTransformer
# Download the fine-tuned model from the 🤗 Hub and run inference
model = SentenceTransformer("YuITC/bert-base-multilingual-cased-finetuned-VNLegalDocs")
sentences = [
    'Tội xúc phạm danh dự?',          # "The offence of insulting someone's honour?"
    'Quyền lợi của người lao động?',  # "What are employees' rights?"
    'Thủ tục đăng ký kết hôn?',       # "Marriage registration procedure?"
]
embeddings = model.encode(sentences)
print(embeddings.shape) # [3, 768]
# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings, embeddings)
print(similarities.shape) # [3, 3]
You can finetune this model or its base model (google-bert/bert-base-multilingual-cased) on your own dataset.
pip install sentence-transformers datasets pandas tqdm
Prepare your training data as pairs in two columns, text_0 (the query) and text_1 (a relevant passage):

import pandas as pd
df = pd.DataFrame([
{"text_0": "What is civil procedure?", "text_1": "Civil procedure governs how legal cases are processed."},
{"text_0": "Define contract law", "text_1": "Contract law deals with agreements between parties."},
# …
])
df.to_parquet("data/train.parquet", index=False)
MODEL_ID = "YuITC/bert-base-multilingual-cased-finetuned-VNLegalDocs"  # Hub repo id of the fine-tuned model
MODEL_NAME = "bert-base-multilingual-cased"                            # base model to fine-tune
CACHE_DIR = "./cache"
OUTPUT_DIR = "./output"
MAX_SEQ_LEN = 512
EPOCHS = 5
LR = 3e-5
BATCH_SIZE = 128
DEVICE = "cuda" # or "cpu"
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
# Load the base transformer and a mean-pooling head
embedding_model = Transformer(MODEL_NAME, max_seq_length=MAX_SEQ_LEN, cache_dir=CACHE_DIR)
pooling_model = Pooling(embedding_model.get_word_embedding_dimension(),
pooling_mode_mean_tokens=True)
model = SentenceTransformer(modules=[embedding_model, pooling_model],
device=DEVICE, cache_folder=CACHE_DIR)
# In-batch-negatives contrastive loss (cached/GradCache variant)
loss = CachedMultipleNegativesRankingLoss(model=model)
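CachedMultipleNegativesRankingLoss is the gradient-cached variant of the in-batch-negatives ranking loss described in the cited paper (Gao et al., 2021): the effective contrastive batch stays large while activations are recomputed in small chunks, bounding peak GPU memory. If memory is tight, the chunk size can be lowered via `mini_batch_size`; the value below is only an example, not the setting used for this model:

```python
# Optional: process the batch in chunks of 16; the loss still uses all
# BATCH_SIZE in-batch negatives, only peak memory is reduced.
loss = CachedMultipleNegativesRankingLoss(model=model, mini_batch_size=16)
```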
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers
from datasets import Dataset
import pandas as pd
# Load your training DataFrame
df_train = pd.read_parquet("data/train.parquet")
train_ds = Dataset.from_pandas(df_train)
# Training arguments (matching the hyperparameters reported below)
args = SentenceTransformerTrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    learning_rate=LR,
    warmup_ratio=0.1,
    fp16=True,
    logging_steps=100,
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # avoid duplicate texts within a batch of in-batch negatives
)
trainer = SentenceTransformerTrainer(model=model, args=args, train_dataset=train_ds, loss=loss)
# Start fine-tuning
trainer.train()
model = SentenceTransformer(OUTPUT_DIR, device=DEVICE)
embeddings = model.encode(["Your query here"], convert_to_tensor=True)
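For retrieval, you can rank a corpus of legal passages against a query by cosine similarity. A minimal sketch using sentence_transformers.util.semantic_search; the corpus texts below are made-up placeholders, not samples from the dataset:

```python
from sentence_transformers import util

corpus = [
    "Người lao động có quyền nghỉ hằng năm hưởng nguyên lương.",          # employees are entitled to paid annual leave
    "Thủ tục đăng ký kết hôn được thực hiện tại Ủy ban nhân dân cấp xã.",  # marriage registration is done at the commune People's Committee
]
corpus_embeddings = model.encode(corpus, convert_to_tensor=True)
query_embeddings = model.encode(["Quyền lợi của người lao động?"], convert_to_tensor=True)

# Top-k hits for the query: a list of {"corpus_id": ..., "score": ...} dicts
hits = util.semantic_search(query_embeddings, corpus_embeddings, top_k=2)[0]
for hit in hits:
    print(corpus[hit["corpus_id"]], hit["score"])
```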
Size: 99,580 training samples (after preprocessing)
Columns: text_0 and text_1
Approximate statistics based on the first 1000 samples:
| Column | Type | Min tokens | Mean tokens | Max tokens |
|---|---|---|---|---|
| text_0 | string | 8 | 25.64 | 58 |
| text_1 | string | 13 | 278.08 | 512 |
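These counts can be reproduced approximately with the model's own tokenizer. A rough sketch, assuming `df_train` is the DataFrame loaded from data/train.parquet and `model` and `MAX_SEQ_LEN` are defined as above:

```python
# Approximate token statistics over the first 1000 training pairs
sample = df_train.head(1000)
for col in ["text_0", "text_1"]:
    n_tokens = [len(model.tokenizer.encode(text, truncation=True, max_length=MAX_SEQ_LEN))
                for text in sample[col]]
    print(col, min(n_tokens), sum(n_tokens) / len(n_tokens), max(n_tokens))
```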
Training hyperparameters:
- per_device_train_batch_size: 128
- learning_rate: 3e-05
- num_train_epochs: 5
- warmup_ratio: 0.1
- fp16: True
- batch_sampler: no_duplicates

Training loss by step:

| Step | Training Loss |
|---|---|
| 100 | 1.8827 |
| 200 | 0.4428 |
| 400 | 0.2856 |
| 600 | 0.2241 |
| 800 | 0.1894 |
| 1000 | 0.1432 |
| 1200 | 0.1311 |
| 1400 | 0.1227 |
| 1600 | 0.1028 |
| 1800 | 0.0850 |
| 2000 | 0.0800 |
| 2200 | 0.0802 |
| 2400 | 0.0633 |
| 2600 | 0.0612 |
| 2800 | 0.0566 |
| 3000 | 0.0548 |
| 3200 | 0.0479 |
| 3400 | 0.0440 |
| 3600 | 0.0444 |
| 3800 | 0.0461 |
To evaluate the model on the BKAI data with MTEB, define a custom task by subclassing the AbsTaskRetrieval class:

from mteb import MTEB
from mteb.abstasks import AbsTaskRetrieval
from mteb.abstasks.TaskMetadata import TaskMetadata

class BKAILegalDocRetrievalTask(AbsTaskRetrieval):
# Metadata definition used by MTEB benchmark
metadata = TaskMetadata(name='BKAILegalDocRetrieval',
description='',
reference='https://github.com/embeddings-benchmark/mteb/blob/main/docs/adding_a_dataset.md',
type='Retrieval',
category='s2p',
modalities=['text'],
eval_splits=['test'],
eval_langs=['vi'],
main_score='ndcg_at_10',
other_scores=['recall_at_10', 'precision_at_10', 'map'],
dataset={
'path' : 'data',
'revision': 'd4c5a8ba10ae71224752c727094ac4c46947fa29',
},
date=('2012-01-01', '2020-01-01'),
form='Written',
domains=['Academic', 'Non-fiction'],
task_subtypes=['Scientific Reranking'],
license='cc-by-nc-4.0',
annotations_creators='derived',
dialect=[],
text_creation='found',
bibtex_citation=''
)
    data_loaded = True  # Mark the data as already loaded so MTEB does not call load_data()
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.corpus = {}
self.queries = {}
self.relevant_docs = {}
        # NOTE: assumes a module-level `data` dict of DataFrames with
        # 'corpus', 'train' and 'test' entries (see the sketch after the evaluation snippet)
        shared_corpus = {}
        for _, row in data['corpus'].iterrows():
shared_corpus[f"c{row['cid']}"] = {
'text': row['text'],
'_id' : row['cid']
}
for split in ['train', 'test']:
self.corpus[split] = shared_corpus
self.queries[split] = {}
self.relevant_docs[split] = {}
for split in ['train', 'test']:
for _, row in data[split].iterrows():
qid, cids = row['qid'], row['cid']
qid_str = f'q{qid}'
cids_str = [f'c{cid}' for cid in cids]
self.queries[split][qid_str] = row['question']
if qid_str not in self.relevant_docs[split]:
self.relevant_docs[split][qid_str] = {}
for cid_str in cids_str:
self.relevant_docs[split][qid_str][cid_str] = 1
self.data_loaded = True
fine_tuned_model = SentenceTransformer(OUTPUT_DIR, device=DEVICE)
custom_task = BKAILegalDocRetrievalTask()
evaluation = MTEB(tasks=[custom_task])
evaluation.run(fine_tuned_model, batch_size=BATCH_SIZE)
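Note that BKAILegalDocRetrievalTask reads a module-level `data` dict that is not shown above. A minimal sketch of the shape it expects, with illustrative file names (the corpus needs cid and text columns; each split needs qid, question, and a cid column holding the list of relevant corpus ids):

```python
import pandas as pd

# Illustrative paths; point these at wherever the BKAI splits are stored locally.
data = {
    "corpus": pd.read_parquet("data/corpus.parquet"),       # columns: cid, text
    "train":  pd.read_parquet("data/train_split.parquet"),  # columns: qid, question, cid (list of relevant cids)
    "test":   pd.read_parquet("data/test_split.parquet"),   # same columns as train
}
```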
@inproceedings{reimers-2019-sentence-bert,
title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
author = "Reimers, Nils and Gurevych, Iryna",
booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
month = "11",
year = "2019",
publisher = "Association for Computational Linguistics",
url = "https://arxiv.org/abs/1908.10084",
}
@misc{gao2021scaling,
title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
author={Luyu Gao and Yunyi Zhang and Jiawei Han and Jamie Callan},
year={2021},
eprint={2101.06983},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
Base model: google-bert/bert-base-multilingual-cased