---
license: apache-2.0
language:
- en
tags:
- information-retrieval
- LLM
- Embedding
- text-retrieval
- disaster-management
task_categories:
- text-retrieval
library_name: transformers
dataset_tags:
- DMIR01/DMRetriever_MTT
---
This model was trained using the approach described in [DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management](https://www.arxiv.org/abs/2510.15087).
The associated GitHub repository is available [here](https://github.com/KaiYin97/DMRETRIEVER).
This model has 7.6B parameters.
## 🧠 Model Overview
**DMRetriever-7.6B** has the following features:
- Model Type: Text Embedding
- Supported Languages: English
- Number of Parameters: 7.6B
- Embedding Dimension: 4096
For more details, including model training, benchmark evaluation, and inference performance, please refer to our [paper](https://www.arxiv.org/abs/2510.15087) and [GitHub repository](https://github.com/KaiYin97/DMRETRIEVER).
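As a quick sanity check, the embedding dimension listed above can be read directly from the checkpoint config. This is a minimal sketch using the standard `AutoConfig` API, assuming the checkpoint exposes a Qwen3-style `hidden_size` field:

```python
from transformers import AutoConfig

# Minimal sketch: inspect the checkpoint config without loading weights.
cfg = AutoConfig.from_pretrained("DMIR01/DMRetriever-7.6B", trust_remote_code=True)
print(cfg.hidden_size)  # expected: 4096
```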
## 📦 DMRetriever Series Model List
| **Model** | **Description** | **Backbone** | **Backbone Type** | **Hidden Size** | **#Layers** |
|:--|:--|:--|:--|:--:|:--:|
| [DMRetriever-33M](https://huggingface.co/DMIR01/DMRetriever-33M) | Base 33M variant | MiniLM | Encoder-only | 384 | 12 |
| [DMRetriever-33M-PT](https://huggingface.co/DMIR01/DMRetriever-33M-PT) | Pre-trained version of 33M | MiniLM | Encoder-only | 384 | 12 |
| [DMRetriever-109M](https://huggingface.co/DMIR01/DMRetriever-109M) | Base 109M variant | BERT-base-uncased | Encoder-only | 768 | 12 |
| [DMRetriever-109M-PT](https://huggingface.co/DMIR01/DMRetriever-109M-PT) | Pre-trained version of 109M | BERT-base-uncased | Encoder-only | 768 | 12 |
| [DMRetriever-335M](https://huggingface.co/DMIR01/DMRetriever-335M) | Base 335M variant | BERT-large-uncased-WWM | Encoder-only | 1024 | 24 |
| [DMRetriever-335M-PT](https://huggingface.co/DMIR01/DMRetriever-335M-PT) | Pre-trained version of 335M | BERT-large-uncased-WWM | Encoder-only | 1024 | 24 |
| [DMRetriever-596M](https://huggingface.co/DMIR01/DMRetriever-596M) | Base 596M variant | Qwen3-0.6B | Decoder-only | 1024 | 28 |
| [DMRetriever-596M-PT](https://huggingface.co/DMIR01/DMRetriever-596M-PT) | Pre-trained version of 596M | Qwen3-0.6B | Decoder-only | 1024 | 28 |
| [DMRetriever-4B](https://huggingface.co/DMIR01/DMRetriever-4B) | Base 4B variant | Qwen3-4B | Decoder-only | 2560 | 36 |
| [DMRetriever-4B-PT](https://huggingface.co/DMIR01/DMRetriever-4B-PT) | Pre-trained version of 4B | Qwen3-4B | Decoder-only | 2560 | 36 |
| [DMRetriever-7.6B](https://huggingface.co/DMIR01/DMRetriever-7.6B) | Base 7.6B variant | Qwen3-8B | Decoder-only | 4096 | 36 |
| [DMRetriever-7.6B-PT](https://huggingface.co/DMIR01/DMRetriever-7.6B-PT) | Pre-trained version of 7.6B | Qwen3-8B | Decoder-only | 4096 | 36 |
## 🚀 Usage
Using Hugging Face Transformers:
```python
# pip install torch transformers
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from bidirectional_qwen3 import Qwen3BiModel  # custom bidirectional backbone

MODEL_ID = "DMIR01/DMRetriever-7.6B"

# Device & dtype
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# --- Tokenizer (needs remote code for custom modules) ---
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    use_fast=False,
)

# Ensure pad token and right padding (matches training)
if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token", None) is not None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# --- Bidirectional encoder (non-autoregressive; for retrieval/embedding) ---
model = Qwen3BiModel.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    trust_remote_code=True,
).to(device).eval()

# --- Mean pooling over valid tokens ---
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # [B, L, 1]
    summed = (last_hidden_state * mask).sum(dim=1)                   # [B, H]
    counts = mask.sum(dim=1).clamp(min=1e-9)                         # [B, 1]
    return summed / counts

# --- Batch encoder: returns L2-normalized embeddings ---
def encode_texts(texts, batch_size=32, max_length=512):
    vecs = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        with torch.no_grad():
            inputs = tokenizer(
                batch,
                max_length=max_length,
                truncation=True,
                padding=True,
                return_tensors="pt",
            ).to(device)
            hidden = model(**inputs).last_hidden_state
            emb = mean_pool(hidden, inputs["attention_mask"])
            emb = F.normalize(emb, p=2, dim=1)  # cosine-ready
            vecs.append(emb.cpu())
    return torch.cat(vecs, dim=0) if vecs else torch.empty(0, model.config.hidden_size)

# --- Task instructions (apply to queries only) ---
TASK2PREFIX = {
    "FactCheck": "Given the claim, retrieve most relevant document that supports or refutes the claim",
    "NLI": "Given the premise, retrieve most relevant hypothesis that is entailed by the premise",
    "QA": "Given the question, retrieve most relevant passage that best answers the question",
    "QAdoc": "Given the question, retrieve the most relevant document that answers the question",
    "STS": "Given the sentence, retrieve the sentence with the same meaning",
    "Twitter": "Given the user query, retrieve the most relevant Twitter text that meets the request",
}

def apply_task_prefix(queries, task: str):
    """Add instruction to queries; corpus texts remain unchanged."""
    prefix = TASK2PREFIX.get(task, "")
    if prefix:
        return [f"{prefix}: {q.strip()}" for q in queries]
    return [q.strip() for q in queries]

# ========================= Usage =========================
# Queries need task instruction
task = "QA"
queries_raw = [
    "Who wrote The Little Prince?",
    "What is the capital of France?",
]
queries = apply_task_prefix(queries_raw, task)

# Corpus: no instruction
corpus_passages = [
    "The Little Prince is a novella by Antoine de Saint-Exupéry, first published in 1943.",
    "Paris is the capital and most populous city of France.",
    "Transformers are neural architectures that rely on attention mechanisms.",
]

# Encode
query_emb = encode_texts(queries, batch_size=32, max_length=512)           # [Q, H]
corpus_emb = encode_texts(corpus_passages, batch_size=32, max_length=512)  # [D, H]
print("Query embeddings:", tuple(query_emb.shape))
print("Corpus embeddings:", tuple(corpus_emb.shape))

# Retrieval demo: cosine similarity via dot product (embeddings are normalized)
scores = query_emb @ corpus_emb.T  # [Q, D]
topk = scores.topk(k=min(3, corpus_emb.size(0)), dim=1)
for i, q in enumerate(queries_raw):
    print(f"\nQuery[{i}] {q}")
    for rank, (score, idx) in enumerate(zip(topk.values[i].tolist(), topk.indices[i].tolist()), start=1):
        print(f"  Top{rank}: doc#{idx} | score={score:.4f} | text={corpus_passages[idx]}")
```
## ⚠️ Notice
1. The **backbone** used in DMRetriever is **Bidirectional Qwen3**, not the standard Qwen3.
Please ensure that the `bidirectional_qwen3` module (included in the released model checkpoint folder) is correctly placed inside your local model directory; the sketch below shows one way to fetch the full folder.
2. Make sure that your **transformers** library version is **> 4.51.0**; older versions fail with
`KeyError: 'qwen3'` (a quick version check is included in the sketch below).
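The following is a minimal sketch covering both points, assuming `huggingface_hub` and `packaging` are installed (both are dependencies of recent `transformers` releases). `snapshot_download` pulls the complete checkpoint folder, so `bidirectional_qwen3.py` ends up next to the weights:

```python
# Minimal sketch (not part of the official repo): verify the transformers
# version, then download the full checkpoint folder so that the custom
# bidirectional_qwen3.py module sits next to the weights.
import transformers
from huggingface_hub import snapshot_download
from packaging import version

assert version.parse(transformers.__version__) > version.parse("4.51.0"), (
    f"transformers {transformers.__version__} is too old; "
    "versions <= 4.51.0 fail with KeyError: 'qwen3'"
)

local_dir = snapshot_download("DMIR01/DMRetriever-7.6B")
print("Checkpoint downloaded to:", local_dir)  # contains bidirectional_qwen3.py
```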
## 🧾 Citation
If you find this repository helpful, please consider citing the corresponding paper. Thanks!
```bibtex
@article{yin2025dmretriever,
title={DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management},
author={Yin, Kai and Dong, Xiangjue and Liu, Chengkai and Lin, Allen and Shi, Lingfeng and Mostafavi, Ali and Caverlee, James},
journal={arXiv preprint arXiv:2510.15087},
year={2025}
}
```