---
license: apache-2.0
language:
- en
tags:
- information-retrieval
- LLM
- Embedding
- text-retrieval
- disaster-management

task_categories:
- text-retrieval
library_name: transformers
dataset_tags:
  - DMIR01/DMRetriever_MTT
---

This model was trained using the approach described in [DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management](https://www.arxiv.org/abs/2510.15087).
The associated GitHub repository is available [here](https://github.com/KaiYin97/DMRETRIEVER).
This model has 7.6B parameters.

## 🧠 Model Overview

**DMRetriever-7.6B** has the following features:

- Model Type: Text Embedding
- Supported Languages: English
- Number of Parameters: 7.6B
- Embedding Dimension: 4096

For more details, including model training, benchmark evaluation, and inference performance, please refer to our [paper](https://www.arxiv.org/abs/2510.15087) and [GitHub repository](https://github.com/KaiYin97/DMRETRIEVER).
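
To put the numbers above in context, here is a rough back-of-the-envelope sketch of memory needs. It assumes half-precision weights (as the usage snippet selects on CUDA) and float32 storage for the saved embeddings. Both function names are illustrative and not part of the release, and the estimate ignores activations and framework overhead.

```python
GIB = 1024 ** 3

def weight_bytes(num_params: int, bytes_per_param: int = 2) -> int:
    """Approximate size of the model weights (2 bytes per parameter = float16)."""
    return num_params * bytes_per_param

def embedding_index_bytes(num_vectors: int, dim: int = 4096, bytes_per_value: int = 4) -> int:
    """Approximate size of a dense index of `num_vectors` embeddings (float32 assumed)."""
    return num_vectors * dim * bytes_per_value

# Nominal 7.6B parameters and the 4096-dimensional embeddings listed above.
print(f"weights (fp16):        {weight_bytes(7_600_000_000) / GIB:.2f} GiB")
print(f"1M embeddings (fp32):  {embedding_index_bytes(1_000_000) / GIB:.2f} GiB")
```

If the index size is a concern, storing embeddings in float16 halves the second figure; whether that affects retrieval quality for this model is not something this card reports.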

## 📦 DMRetriever Series Model List

| **Model** | **Description** | **Backbone** | **Backbone Type** | **Hidden Size** | **#Layers** |
|:--|:--|:--|:--|:--:|:--:|
| [DMRetriever-33M](https://huggingface.co/DMIR01/DMRetriever-33M) | Base 33M variant | MiniLM | Encoder-only | 384 | 12 |
| [DMRetriever-33M-PT](https://huggingface.co/DMIR01/DMRetriever-33M-PT) | Pre-trained version of 33M | MiniLM | Encoder-only | 384 | 12 |
| [DMRetriever-109M](https://huggingface.co/DMIR01/DMRetriever-109M) | Base 109M variant | BERT-base-uncased | Encoder-only | 768 | 12 |
| [DMRetriever-109M-PT](https://huggingface.co/DMIR01/DMRetriever-109M-PT) | Pre-trained version of 109M | BERT-base-uncased | Encoder-only | 768 | 12 |
| [DMRetriever-335M](https://huggingface.co/DMIR01/DMRetriever-335M) | Base 335M variant | BERT-large-uncased-WWM | Encoder-only | 1024 | 24 |
| [DMRetriever-335M-PT](https://huggingface.co/DMIR01/DMRetriever-335M-PT) | Pre-trained version of 335M | BERT-large-uncased-WWM | Encoder-only | 1024 | 24 |
| [DMRetriever-596M](https://huggingface.co/DMIR01/DMRetriever-596M) | Base 596M variant | Qwen3-0.6B | Decoder-only | 1024 | 28 |
| [DMRetriever-596M-PT](https://huggingface.co/DMIR01/DMRetriever-596M-PT) | Pre-trained version of 596M | Qwen3-0.6B | Decoder-only | 1024 | 28 |
| [DMRetriever-4B](https://huggingface.co/DMIR01/DMRetriever-4B) | Base 4B variant | Qwen3-4B | Decoder-only | 2560 | 36 |
| [DMRetriever-4B-PT](https://huggingface.co/DMIR01/DMRetriever-4B-PT) | Pre-trained version of 4B | Qwen3-4B | Decoder-only | 2560 | 36 |
| [DMRetriever-7.6B](https://huggingface.co/DMIR01/DMRetriever-7.6B) | Base 7.6B variant | Qwen3-8B | Decoder-only | 4096 | 36 |
| [DMRetriever-7.6B-PT](https://huggingface.co/DMIR01/DMRetriever-7.6B-PT) | Pre-trained version of 7.6B | Qwen3-8B | Decoder-only | 4096 | 36 |


## 🚀 Usage

Using HuggingFace Transformers:

```python
# pip install torch transformers
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from bidirectional_qwen3 import Qwen3BiModel  # custom bidirectional backbone, shipped in the model checkpoint folder

MODEL_ID = "DMIR01/DMRetriever-7.6B"

# Device & dtype
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# --- Tokenizer (needs remote code for custom modules) ---
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    use_fast=False,
)
# Ensure pad token and right padding (matches training)
if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token", None) is not None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# --- Bidirectional encoder (non-autoregressive; for retrieval/embedding) ---
model = Qwen3BiModel.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    trust_remote_code=True,
).to(device).eval()

# --- Mean pooling over valid tokens ---
def mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # [B, L, 1]
    summed = (last_hidden_state * mask).sum(dim=1)                   # [B, H]
    counts = mask.sum(dim=1).clamp(min=1.0)                          # [B, 1]; counts are whole numbers, and 1e-9 would underflow to 0 in float16
    return summed / counts

# --- Batch encoder: returns L2-normalized embeddings ---
def encode_texts(texts, batch_size=32, max_length=512):
    vecs = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        with torch.no_grad():
            inputs = tokenizer(
                batch,
                max_length=max_length,
                truncation=True,
                padding=True,
                return_tensors="pt",
            ).to(device)
            hidden = model(**inputs).last_hidden_state
            emb = mean_pool(hidden, inputs["attention_mask"])
            emb = F.normalize(emb, p=2, dim=1)  # cosine-ready
            vecs.append(emb.cpu())
    return torch.cat(vecs, dim=0) if vecs else torch.empty(0, model.config.hidden_size)

# --- Task instructions (apply to queries only) ---
TASK2PREFIX = {
    "FactCheck": "Given the claim, retrieve most relevant document that supports or refutes the claim",
    "NLI":       "Given the premise, retrieve most relevant hypothesis that is entailed by the premise",
    "QA":        "Given the question, retrieve most relevant passage that best answers the question",
    "QAdoc":     "Given the question, retrieve the most relevant document that answers the question",
    "STS":       "Given the sentence, retrieve the sentence with the same meaning",
    "Twitter":   "Given the user query, retrieve the most relevant Twitter text that meets the request",
}

def apply_task_prefix(queries, task: str):
    """Add instruction to queries; corpus texts remain unchanged."""
    prefix = TASK2PREFIX.get(task, "")
    if prefix:
        return [f"{prefix}: {q.strip()}" for q in queries]
    return [q.strip() for q in queries]

# ========================= Usage =========================
# Queries need task instruction
task = "QA"
queries_raw = [
    "Who wrote The Little Prince?",
    "What is the capital of France?",
]
queries = apply_task_prefix(queries_raw, task)

# Corpus: no instruction
corpus_passages = [
    "The Little Prince is a novella by Antoine de Saint-Exupéry, first published in 1943.",
    "Paris is the capital and most populous city of France.",
    "Transformers are neural architectures that rely on attention mechanisms.",
]

# Encode
query_emb  = encode_texts(queries,         batch_size=32, max_length=512)  # [Q, H]
corpus_emb = encode_texts(corpus_passages, batch_size=32, max_length=512)  # [D, H]
print("Query embeddings:",  tuple(query_emb.shape))
print("Corpus embeddings:", tuple(corpus_emb.shape))

# Retrieval demo: cosine similarity via dot product (embeddings are normalized)
scores = query_emb @ corpus_emb.T  # [Q, D]
topk = scores.topk(k=min(3, corpus_emb.size(0)), dim=1)

for i, q in enumerate(queries_raw):
    print(f"\nQuery[{i}] {q}")
    for rank, (score, idx) in enumerate(zip(topk.values[i].tolist(), topk.indices[i].tolist()), start=1):
        print(f"  Top{rank}: doc#{idx} | score={score:.4f} | text={corpus_passages[idx]}")
```
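
The pooling and scoring steps in the snippet above can be hard to follow in tensor form. The following is a minimal, torch-free sketch of the same arithmetic on toy values: masked mean pooling, L2 normalization, and a dot product that then equals cosine similarity. The helper names and the toy vectors are made up for illustration and are not part of the released code.

```python
import math

def masked_mean(token_vectors, attention_mask):
    """Average only the token vectors whose mask entry is 1 (padding is skipped)."""
    dim = len(token_vectors[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(token_vectors, attention_mask):
        if keep:
            count += 1
            for j, value in enumerate(vec):
                total[j] += value
    count = max(count, 1)  # guard against an all-padding row
    return [t / count for t in total]

def l2_normalize(vec):
    """Scale a vector to unit length; a zero vector is returned unchanged."""
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Toy "hidden states": two real tokens followed by one padding token.
tokens = [[1.0, 0.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]

pooled = masked_mean(tokens, mask)  # the padding row does not contribute
unit = l2_normalize(pooled)
print(pooled)
print(round(dot(unit, unit), 6))    # a unit vector has cosine similarity 1 with itself
```

Because every embedding is scaled to unit length, the matrix product `query_emb @ corpus_emb.T` in the main snippet yields cosine similarities directly, with no further division by norms.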

## ⚠️ Notice

1. The **backbone** used in DMRetriever is **Bidirectional Qwen3**, not the standard Qwen3.  
   Please ensure that the `bidirectional_qwen3` module (included in the released model checkpoint folder) is correctly placed inside your model directory.

2. Make sure that your **transformers** library version is **> 4.51.0** to avoid the error:  
   `KeyError: 'qwen3'`.
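
Both points can be checked before loading the model. The sketch below is one possible pre-flight check, not an official utility: the helper names are hypothetical, the version comparison only looks at leading numeric components, and the assumption that the module ships as either `bidirectional_qwen3.py` or a `bidirectional_qwen3/` package should be verified against the actual checkpoint folder.

```python
import sys
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path

def parse_version(text: str) -> tuple:
    """Turn '4.51.3' or '4.52.0.dev0' into a tuple of its leading integers."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)

def is_newer_than(installed: str, floor: str = "4.51.0") -> bool:
    """Strict comparison, following the '> 4.51.0' wording of the notice."""
    return parse_version(installed) > parse_version(floor)

def has_bidirectional_module(model_dir: str, name: str = "bidirectional_qwen3") -> bool:
    """Assumed layout: a single .py file or a package directory inside model_dir."""
    root = Path(model_dir)
    return (root / f"{name}.py").is_file() or (root / name / "__init__.py").is_file()

def make_importable(model_dir: str) -> None:
    """Put the model directory on sys.path so `import bidirectional_qwen3` can resolve."""
    if model_dir not in sys.path:
        sys.path.insert(0, model_dir)

try:
    installed = version("transformers")
    print(f"transformers {installed}: newer than 4.51.0 = {is_newer_than(installed)}")
except PackageNotFoundError:
    print("transformers is not installed")
```

A typical use would be to call `has_bidirectional_module(local_dir)` on the downloaded checkpoint folder, then `make_importable(local_dir)` before running the import in the usage snippet.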


## 🧾 Citation
If you find this repository helpful, please consider citing the corresponding paper:
```bibtex
@article{yin2025dmretriever,
  title={DMRetriever: A Family of Models for Improved Text Retrieval in Disaster Management},
  author={Yin, Kai and Dong, Xiangjue and Liu, Chengkai and Lin, Allen and Shi, Lingfeng and Mostafavi, Ali and Caverlee, James},
  journal={arXiv preprint arXiv:2510.15087},
  year={2025}
}
```