Update README.md
README.md
CHANGED
|
@@ -1,3 +1,176 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
|
| 1 |
+
---
license: apache-2.0
language:
- en
tags:
- Retrieval
- LLM
- Embedding
---
|
| 10 |
+
|
| 11 |
+
This model is trained through the approach described in [DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management](https://www.arxiv.org/abs/2510.15087).
|
| 12 |
+
The associated GitHub repository is available [here](https://github.com/KaiYin97/DMRETRIEVER).
|
| 13 |
+
This model has 7.6B parameters.
|
| 14 |
+
|
| 15 |
+
## 🧠 Model Overview
|
| 16 |
+
|
| 17 |
+
**DMRetriever-7.6B** has the following features:
|
| 18 |
+
|
| 19 |
+
- Model Type: Text Embedding
|
| 20 |
+
- Supported Languages: English
|
| 21 |
+
- Number of Parameters: 7.6B
|
| 22 |
+
- Embedding Dimension: 4096
|
| 23 |
+
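As a rough sizing aid, the figures above can be turned into memory estimates. The sketch below is back-of-the-envelope arithmetic only: it assumes the weights dominate memory use and that embeddings are stored as a dense float32 matrix, and the helper names (`weight_memory_gb`, `index_memory_gb`) and the one-million-document corpus are hypothetical, not part of the released code.

```python
# Back-of-the-envelope memory estimates (1 GB = 1e9 bytes).
# Assumption: activations, buffers, and index overhead are ignored.
BYTES_PER_DTYPE = {"float32": 4, "float16": 2}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Approximate size of the model weights in gigabytes."""
    return num_params * BYTES_PER_DTYPE[dtype] / 1e9


def index_memory_gb(num_docs: int, dim: int, dtype: str = "float32") -> float:
    """Approximate size of a dense [num_docs, dim] embedding matrix in gigabytes."""
    return num_docs * dim * BYTES_PER_DTYPE[dtype] / 1e9


# 7.6B parameters, as listed in the feature list above
print(f"fp16 weights: {weight_memory_gb(7.6e9, 'float16'):.1f} GB")  # 15.2 GB
print(f"fp32 weights: {weight_memory_gb(7.6e9, 'float32'):.1f} GB")  # 30.4 GB

# Hypothetical corpus of one million passages at embedding dimension 4096
print(f"fp32 index:   {index_memory_gb(1_000_000, 4096):.2f} GB")    # 16.38 GB
```

Actual usage will be higher once activations and framework overhead are included, so treat these numbers as lower bounds.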
|
| 24 |
+
For more details, including model training, benchmark evaluation, and inference performance, please refer to our [paper](https://www.arxiv.org/abs/2510.15087) and [GitHub repository](https://github.com/KaiYin97/DMRETRIEVER).
|
| 25 |
+
|
| 26 |
+
## 📦 DMRetriever Series Model List
|
| 27 |
+
|
| 28 |
+
| **Model** | **Description** | **Backbone** | **Backbone Type** | **Hidden Size** | **#Layers** |
|:--|:--|:--|:--|:--:|:--:|
| [DMRetriever-33M](https://huggingface.co/DMIR01/DMRetriever-33M) | Base 33M variant | MiniLM | Encoder-only | 384 | 12 |
| [DMRetriever-33M-PT](https://huggingface.co/DMIR01/DMRetriever-33M-PT) | Pre-trained version of 33M | MiniLM | Encoder-only | 384 | 12 |
| [DMRetriever-109M](https://huggingface.co/DMIR01/DMRetriever-109M) | Base 109M variant | BERT-base-uncased | Encoder-only | 768 | 12 |
| [DMRetriever-109M-PT](https://huggingface.co/DMIR01/DMRetriever-109M-PT) | Pre-trained version of 109M | BERT-base-uncased | Encoder-only | 768 | 12 |
| [DMRetriever-335M](https://huggingface.co/DMIR01/DMRetriever-335M) | Base 335M variant | BERT-large-uncased-WWM | Encoder-only | 1024 | 24 |
| [DMRetriever-335M-PT](https://huggingface.co/DMIR01/DMRetriever-335M-PT) | Pre-trained version of 335M | BERT-large-uncased-WWM | Encoder-only | 1024 | 24 |
| [DMRetriever-596M](https://huggingface.co/DMIR01/DMRetriever-596M) | Base 596M variant | Qwen3-0.6B | Decoder-only | 1024 | 28 |
| [DMRetriever-596M-PT](https://huggingface.co/DMIR01/DMRetriever-596M-PT) | Pre-trained version of 596M | Qwen3-0.6B | Decoder-only | 1024 | 28 |
| [DMRetriever-4B](https://huggingface.co/DMIR01/DMRetriever-4B) | Base 4B variant | Qwen3-4B | Decoder-only | 2560 | 36 |
| [DMRetriever-4B-PT](https://huggingface.co/DMIR01/DMRetriever-4B-PT) | Pre-trained version of 4B | Qwen3-4B | Decoder-only | 2560 | 36 |
| [DMRetriever-7.6B](https://huggingface.co/DMIR01/DMRetriever-7.6B) | Base 7.6B variant | Qwen3-8B | Decoder-only | 4096 | 36 |
| [DMRetriever-7.6B-PT](https://huggingface.co/DMIR01/DMRetriever-7.6B-PT) | Pre-trained version of 7.6B | Qwen3-8B | Decoder-only | 4096 | 36 |
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
## 🚀 Usage
|
| 45 |
+
|
| 46 |
+
Using HuggingFace Transformers:
|
| 47 |
+
|
| 48 |
+
```python
# pip install torch transformers
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from bidirectional_qwen3 import Qwen3BiModel  # custom bidirectional backbone

MODEL_ID = "DMIR01/DMRetriever-7.6B"

# Device & dtype
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# --- Tokenizer (needs remote code for custom modules) ---
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    use_fast=False,
)
# Ensure pad token and right padding (matches training)
if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token", None) is not None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# --- Bidirectional encoder (non-autoregressive; for retrieval/embedding) ---
model = Qwen3BiModel.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    trust_remote_code=True,
).to(device).eval()

# --- Mean pooling over valid tokens ---
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # [B, L, 1]
    summed = (last_hidden_state * mask).sum(dim=1)                   # [B, H]
    counts = mask.sum(dim=1).clamp(min=1e-9)                         # [B, 1]
    return summed / counts

# --- Batch encoder: returns L2-normalized embeddings ---
def encode_texts(texts, batch_size=32, max_length=512):
    vecs = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        with torch.no_grad():
            inputs = tokenizer(
                batch,
                max_length=max_length,
                truncation=True,
                padding=True,
                return_tensors="pt",
            ).to(device)
            hidden = model(**inputs).last_hidden_state
            emb = mean_pool(hidden, inputs["attention_mask"])
            emb = F.normalize(emb, p=2, dim=1)  # cosine-ready
        vecs.append(emb.cpu())
    return torch.cat(vecs, dim=0) if vecs else torch.empty(0, model.config.hidden_size)

# --- Task instructions (apply to queries only) ---
TASK2PREFIX = {
    "FactCheck": "Given the claim, retrieve most relevant document that supports or refutes the claim",
    "NLI": "Given the premise, retrieve most relevant hypothesis that is entailed by the premise",
    "QA": "Given the question, retrieve most relevant passage that best answers the question",
    "QAdoc": "Given the question, retrieve the most relevant document that answers the question",
    "STS": "Given the sentence, retrieve the sentence with the same meaning",
    "Twitter": "Given the user query, retrieve the most relevant Twitter text that meets the request",
}

def apply_task_prefix(queries, task: str):
    """Add instruction to queries; corpus texts remain unchanged."""
    prefix = TASK2PREFIX.get(task, "")
    if prefix:
        return [f"{prefix}: {q.strip()}" for q in queries]
    return [q.strip() for q in queries]

# ========================= Usage =========================
# Queries need task instruction
task = "QA"
queries_raw = [
    "Who wrote The Little Prince?",
    "What is the capital of France?",
]
queries = apply_task_prefix(queries_raw, task)

# Corpus: no instruction
corpus_passages = [
    "The Little Prince is a novella by Antoine de Saint-Exupéry, first published in 1943.",
    "Paris is the capital and most populous city of France.",
    "Transformers are neural architectures that rely on attention mechanisms.",
]

# Encode
query_emb = encode_texts(queries, batch_size=32, max_length=512)           # [Q, H]
corpus_emb = encode_texts(corpus_passages, batch_size=32, max_length=512)  # [D, H]
print("Query embeddings:", tuple(query_emb.shape))
print("Corpus embeddings:", tuple(corpus_emb.shape))

# Retrieval demo: cosine similarity via dot product (embeddings are normalized)
scores = query_emb @ corpus_emb.T  # [Q, D]
topk = scores.topk(k=min(3, corpus_emb.size(0)), dim=1)

for i, q in enumerate(queries_raw):
    print(f"\nQuery[{i}] {q}")
    for rank, (score, idx) in enumerate(zip(topk.values[i].tolist(), topk.indices[i].tolist()), start=1):
        print(f"  Top{rank}: doc#{idx} | score={score:.4f} | text={corpus_passages[idx]}")
```
|
| 155 |
+
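The retrieval step in the usage code scores passages with a plain dot product because the embeddings are L2-normalized. For unit vectors, the dot product equals the cosine similarity. The sketch below illustrates that identity in pure Python with made-up 2-D vectors; the helper names and toy values are for illustration only and do not come from the model.

```python
import math


def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(v):
    """Scale a vector to unit L2 norm (zero vectors are returned unchanged)."""
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v] if norm > 0 else list(v)


def cosine(a, b):
    """Cosine similarity computed directly from unnormalized vectors."""
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


# Toy vectors standing in for a query embedding and three passage embeddings
query = [3.0, 4.0]
docs = [[4.0, 3.0], [-3.0, -4.0], [6.0, 8.0]]

# Normalize once, then score with a plain dot product
q = l2_normalize(query)
scores = [dot(q, l2_normalize(d)) for d in docs]

# The dot product of normalized vectors matches the cosine of the raw vectors
for d, s in zip(docs, scores):
    assert abs(s - cosine(query, d)) < 1e-9

# Rank passages from most to least similar
ranking = sorted(range(len(docs)), key=lambda i: scores[i], reverse=True)
print(ranking)  # [2, 0, 1]
```

Normalizing once at encoding time means a search needs only a matrix multiplication, which is what `query_emb @ corpus_emb.T` does in the usage code.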
|
| 156 |
+
## ⚠️ Notice
|
| 157 |
+
|
| 158 |
+
1. The **backbone** used in DMRetriever is **Bidirectional Qwen3**, not the standard Qwen3.
|
| 159 |
+
Please ensure that the `bidirectional_qwen3` module (included in the released model checkpoint folder) is correctly placed inside your model directory.
|
| 160 |
+
|
| 161 |
+
2. Make sure that your **transformers** library version is **> 4.51.0** to avoid the error:
|
| 162 |
+
`KeyError: 'qwen3'`.
|
| 163 |
+
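Both points in the notice can be checked before loading the model. The following is a minimal sketch, not part of the released code: the helper names are hypothetical, the strict `>` comparison mirrors the notice's wording, and it assumes the `bidirectional_qwen3` module ships either as a single `bidirectional_qwen3.py` file or as a package directory of that name inside the model folder.

```python
import sys
from importlib import metadata
from pathlib import Path

MIN_TRANSFORMERS = "4.51.0"  # threshold quoted in the notice


def parse_version(v: str) -> tuple:
    """Turn '4.51.3' or '4.52.0.dev0' into a tuple of its leading integers."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component, e.g. 'dev0'
        parts.append(int(digits))
    return tuple(parts)


def find_bidirectional_module(model_dir: str):
    """Return the path of bidirectional_qwen3 (file or package) in model_dir, or None."""
    root = Path(model_dir)
    for candidate in (root / "bidirectional_qwen3.py", root / "bidirectional_qwen3"):
        if candidate.exists():
            return candidate
    return None


def preflight(model_dir: str) -> list:
    """Collect human-readable problems; an empty list means both checks passed."""
    problems = []
    try:
        installed = metadata.version("transformers")
        if not parse_version(installed) > parse_version(MIN_TRANSFORMERS):
            problems.append(f"transformers {installed} is not newer than {MIN_TRANSFORMERS}")
    except metadata.PackageNotFoundError:
        problems.append("transformers is not installed")

    if find_bidirectional_module(model_dir) is None:
        problems.append(f"bidirectional_qwen3 not found in {model_dir}")
    elif model_dir not in sys.path:
        # Side effect: make `from bidirectional_qwen3 import ...` resolvable
        sys.path.insert(0, model_dir)
    return problems
```

Calling `preflight("path/to/your/model/dir")` before the import in the usage code returns a list of issues to fix. The path is a placeholder for wherever the checkpoint was downloaded.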
|
| 164 |
+
|
| 165 |
+
## 🧾 Citation
|
| 166 |
+
If you find this repository helpful, please kindly consider citing the corresponding paper. Thanks!
|
| 167 |
+
```
@article{yin2025dmretriever,
  title={DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management},
  author={Yin, Kai and Dong, Xiangjue and Liu, Chengkai and Lin, Allen and Shi, Lingfeng and Mostafavi, Ali and Caverlee, James},
  journal={arXiv preprint arXiv:2510.15087},
  year={2025}
}
```
|
| 175 |
+
|
| 176 |
+
|