Upload folder using huggingface_hub

Files changed:
- README.md (+109, -199)
- config.json (+8, -8)
- model.safetensors (+2, -2)
- modeling_dhara.py (+335, -136)
- tokenizer.json (+2, -16)
README.md CHANGED

@@ -2,253 +2,163 @@

**Old version (removed):**

license: apache-2.0
language:
- en
tags:
- text-generation
- diffusion
- language-model
datasets:
- allenai/dolma
- mlfoundations/dclm-baseline-1.0
model-index:
- name: dhara-70m
  results:
  - task:
      type: text-generation
    dataset:
      name: HellaSwag
      type: hellaswag
    metrics:
    - name: Accuracy
      type: accuracy
      value: 25.58
  - task:
      type: text-generation
    dataset:
      name: PIQA
      type: piqa
    metrics:
    - name: Accuracy
      type: accuracy
      value: 51.58
  - task:
      type: text-generation
    dataset:
      name: WinoGrande
      type: winogrande
    metrics:
    - name: Accuracy
      type: accuracy
      value: 49.64
  - task:
      type: text-generation
    dataset:
      name: ARC-Challenge
      type: arc_challenge
    metrics:
    - name: Accuracy
      type: accuracy
      value: 24.83
  - task:
      type: text-generation
    dataset:
      name: MMLU
      type: mmlu
    metrics:
    - name: Accuracy
      type: accuracy
      value: 23.85
  - task:
      type: text-generation
    dataset:
      name: TruthfulQA
      type: truthfulqa_mc2
    metrics:
    - name: Accuracy
      type: accuracy
      value: 47.50
---

# Dhara-70M

## Table of Contents
- [Model Description](#model-description)
- [Training Data](#training-data)
- [Training Details](#training-details)
- [Benchmark Results](#benchmark-results)
- [Usage](#usage)
- [Key Insights](#key-insights)
- [Limitations](#limitations)
- [Citation](#citation)

## Model Description

Dhara-…

- **3.8x higher throughput** than autoregressive models
- **Best-in-class factuality** on TruthfulQA (47.50%)
- **10x training efficiency** via WSD (Warmup-Stable-Decay) conversion

### …

| **Hidden Size** | 384 |
| **FF Dimension** | 1024 |
| **Attention Heads** | 8 |
| **KV Heads** | 4 (GQA) |
| **Context Length** | 2048 tokens |
| **Position Encoding** | RoPE |
| **Normalization** | RMSNorm |
| **Special Layers** | Canon (depthwise causal convolutions) |
| **Generation Type** | Diffusion (parallel token generation) |

## Training Data

Dhara was trained in two stages:

**Stage 1: AR Pretraining (1B tokens)**
- 40% FinePDFs (400M tokens)
- 30% DCLM Baseline (300M tokens)
- 30% FineWeb-Edu (300M tokens)

**Stage 2: WSD Conversion (100M tokens)**
- Progressive block size warmup (1→4→32→64→1024)
- MDLM diffusion objective

| Parameter | Value |
|-----------|-------|
| … | … |

| ARC-Challenge | … |
| **Average** | **…** |

| Metric | Dhara-70M | GPT-2-70M | Advantage |
|--------|-----------|-----------|-----------|
| Time to First Token | 35.5 ms | ~25 ms | 1.4x slower |
| Throughput | 183.5 tok/s | ~48 tok/s | **3.8x faster** |
| Peak Memory | 0.24 GB | 0.15 GB | 1.6x higher |

## Usage

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("codelion/dhara-70m")
model = AutoModelForCausalLM.from_pretrained("codelion/dhara-70m", trust_remote_code=True)

# …
outputs = model.generate(
    …,
    max_new_tokens=…,
    num_diffusion_steps=10,  # Diffusion denoising steps (higher = better quality)
    do_sample=True,
    temperature=0.8,
    top_p=0.9
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

### …

```python
# For batch generation, use larger batch sizes
prompts = [
    "The future of AI is",
    "In recent years, machine learning has",
    "The most important discovery in physics was",
    "Climate change affects our planet by"
]

inputs = tokenizer(prompts, return_tensors="pt", padding=True)
outputs = model.generate(
    **inputs,
    max_length=100,
    do_sample=True,
    temperature=0.7,
    num_diffusion_steps=10  # Fewer steps = faster generation
)
# …
```

## …

- Batch generation throughput matters
- Factual accuracy is critical
- You have an existing AR checkpoint to convert

## Limitations

- …
- Best suited for batch rather than interactive use cases

## Citation

```bibtex
@…{…,
  title={…},
  author={…},
  year={…},
  …
}
```

## …

- [The Optimal Architecture for Small Language Models](https://huggingface.co/blog/codelion/optimal-model-architecture) - Blog post describing this work
- [The 1 Billion Token Challenge: Optimal Dataset Mixing](https://huggingface.co/blog/codelion/optimal-dataset-mixing) - Our previous work on optimal pretraining data
- [GPT-2-70M](https://huggingface.co/codelion/gpt-2-70m) - Our previous model from optimal pretraining experiments

## Contact

…
**New version (added):**

license: apache-2.0
language:
- en
library_name: transformers
tags:
- diffusion
- masked-language-model
- text-generation
- pytorch
- transformers
pipeline_tag: text-generation
datasets:
- codelion/pre-training-dataset-samples
---

# Dhara-70M: Diffusion Language Model

Dhara is a 70M parameter diffusion language model that combines masked diffusion training with Canon layers for improved local context understanding.

## Model Description

Dhara was created by converting a pre-trained autoregressive (AR) LLM to a diffusion model using **Warmup-Stable-Decay (WSD)** training. This approach preserves the language understanding capabilities of the original AR model while enabling bidirectional attention and parallel token generation.

### Key Features

- **Bidirectional Attention**: Unlike causal LLMs, Dhara uses full bidirectional attention during generation
- **Canon Layers**: Incorporates causal depthwise convolutions at positions A (before attention) and C (before MLP) for local context mixing
- **WSD Conversion**: Trained with 100M tokens to convert the AR checkpoint to diffusion while preserving language capabilities
- **Custom Generate Method**: Includes a specialized `generate()` method for text generation

### Architecture

| Parameter | Value |
|-----------|-------|
| Parameters | 70M |
| Hidden Size | 384 |
| Layers | 32 |
| Attention Heads | 8 |
| KV Heads | 4 (GQA) |
| Intermediate Size | 1024 |
| Vocabulary | 50,304 |
| Context Length | 1024 |
| Canon Kernel | 4 |
| Canon Positions | A, C |

## Evaluation Results

| Benchmark | Score |
|-----------|-------|
| HellaSwag | 29.42 |
| ARC-Easy | 43.35 |
| ARC-Challenge | 24.15 |
| PIQA | 61.48 |
| Winogrande | 50.75 |
| OpenBookQA | 19.75 |
| **Average** | **38.15** |

*Self-reported evaluation results across 6 standard benchmarks.*

## Usage

### Installation

```bash
pip install transformers torch
```

### Quick Start

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "codelion/dhara-70m",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("codelion/dhara-70m")

# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Generate text
prompt = "The future of artificial intelligence"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

outputs = model.generate(
    inputs.input_ids,
    max_new_tokens=50,
    temperature=0.8,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.2,
    do_sample=True
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
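
The custom `generate()` also accepts batched prompts. A minimal sketch, assuming the Quick Start setup above; reusing the EOS token for padding is an assumption (the tokenizer may already define a pad token):

```python
# Batch generation sketch (assumes `model`, `tokenizer`, `device` from the Quick Start).
prompts = [
    "The future of AI is",
    "In recent years, machine learning has",
]

# Assumption: fall back to EOS as the padding token if none is set.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)

outputs = model.generate(
    inputs.input_ids,
    max_new_tokens=50,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)

for seq in outputs:
    print(tokenizer.decode(seq, skip_special_tokens=True))
```

Throughput gains from the parallel diffusion sampler tend to show up most clearly at larger batch sizes.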

### Generation Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| `max_new_tokens` | 50 | Number of tokens to generate |
| `temperature` | 1.0 | Sampling temperature (higher = more random) |
| `top_p` | 0.9 | Nucleus sampling threshold |
| `top_k` | 50 | Top-k sampling threshold |
| `repetition_penalty` | 1.2 | Penalty for token repetition |
| `do_sample` | True | Whether to sample or use greedy decoding |
| `num_diffusion_steps` | 10 | Diffusion refinement steps (for future use) |
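
`num_diffusion_steps` controls how many parallel refinement passes the sampler makes over the masked positions: every generated position starts as a mask token, and each step commits the most confident predictions. The sketch below is illustrative only; the function name and the greedy commit rule are simplifications, not the shipped `modeling_dhara.py` implementation, which additionally applies temperature and nucleus sampling.

```python
import torch

@torch.no_grad()
def sketch_diffusion_decode(model, input_ids, mask_token_id, gen_len=50, steps=10):
    """Illustrative confidence-based iterative unmasking (not the shipped generate())."""
    batch, prompt_len = input_ids.shape
    total_len = prompt_len + gen_len
    tokens = torch.full((batch, total_len), mask_token_id,
                        dtype=torch.long, device=input_ids.device)
    tokens[:, :prompt_len] = input_ids
    is_masked = torch.ones(batch, total_len, dtype=torch.bool, device=tokens.device)
    is_masked[:, :prompt_len] = False          # the prompt is never masked

    per_step = max(1, gen_len // steps)
    for _ in range(steps):
        logits = model(input_ids=tokens).logits            # [batch, total_len, vocab]
        probs = logits.softmax(dim=-1)
        confidence, prediction = probs.max(dim=-1)          # best guess per position
        confidence = confidence.masked_fill(~is_masked, -1.0)  # only masked slots compete
        for b in range(batch):
            k = min(per_step, int(is_masked[b].sum()))
            if k == 0:
                continue
            top = confidence[b].topk(k).indices             # most confident masked positions
            tokens[b, top] = prediction[b, top]
            is_masked[b, top] = False
    # (a final pass could fill any positions still masked when gen_len % steps != 0)
    return tokens
```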

## Training Details

### Dataset

Dhara was trained on a curated 1B token sample from the [Pre-training Dataset Samples](https://huggingface.co/collections/codelion/pre-training-dataset-samples) collection.

### WSD (Warmup-Stable-Decay) Conversion

Dhara was converted from an autoregressive checkpoint using the WSD training schedule (a schedule sketch follows the list below):

- **Base Model**: LLaMA-style AR model with Canon layers
- **Total Training Tokens**: 1B tokens (AR) + 100M tokens (diffusion conversion)
- **WSD Warmup Phase**: 20M tokens
- **WSD Stable Phase**: 80M tokens
- **Training Objective**: Masked Diffusion Modeling (MDM)
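
A minimal sketch of the warmup-stable-decay learning-rate multiplier implied by the token budgets above; the decay length and the linear shape are assumptions, since the card does not specify them.

```python
def wsd_lr_scale(tokens_seen: int,
                 warmup: int = 20_000_000,
                 stable: int = 80_000_000,
                 decay: int = 10_000_000) -> float:
    """Warmup-Stable-Decay multiplier on the base learning rate (illustrative)."""
    if tokens_seen < warmup:                      # linear warmup to the peak rate
        return tokens_seen / warmup
    if tokens_seen < warmup + stable:             # hold at the peak rate
        return 1.0
    into_decay = tokens_seen - (warmup + stable)  # assumed: linear decay to zero
    return max(0.0, 1.0 - into_decay / decay)
```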

### Canon Layers

Canon layers are causal depthwise convolutions that provide local context mixing with O(n) complexity, based on "Physics of Language Models: Part 4.1" by Zeyuan Allen-Zhu. In Dhara (a condensed sketch follows the list below):

- **Position A**: Applied after input LayerNorm, before attention
- **Position C**: Applied after post-attention LayerNorm, before MLP
- **Kernel Size**: 4 tokens
- **Residual Connection**: Enabled
- **Activation**: None (as recommended for transformers)
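
The layer itself is small. The sketch below mirrors the `CanonLayer` module in this repository's `modeling_dhara.py`: a depthwise `Conv1d` whose extra padding is trimmed on the right so position t only mixes in the previous few positions, plus a residual connection.

```python
import torch
import torch.nn as nn

class CanonSketch(nn.Module):
    """Condensed version of the repo's CanonLayer (kernel_size=4, residual, no activation)."""

    def __init__(self, hidden_size: int, kernel_size: int = 4):
        super().__init__()
        self.conv = nn.Conv1d(
            hidden_size, hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size - 1,   # pad, then trim, so the conv never looks ahead
            groups=hidden_size,        # depthwise: one filter per channel
            bias=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [batch, seq, hidden]
        seq_len = x.shape[1]
        out = self.conv(x.transpose(1, 2))[:, :, :seq_len]  # drop right-side padding (causal)
        return x + out.transpose(1, 2)                       # residual connection
```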

## Limitations

- This is a research model and may generate inaccurate or inappropriate content
- Performance may vary on tasks requiring long-range dependencies
- The model was trained on a limited dataset and may have knowledge gaps

## Citation

If you use this model, please cite:

```bibtex
@misc{dhara2024,
  title={Dhara: Diffusion Language Model with Canon Layers},
  author={CodeLion},
  year={2024},
  publisher={HuggingFace},
  url={https://huggingface.co/codelion/dhara-70m}
}
```

## License

Apache 2.0
config.json CHANGED

```diff
@@ -2,12 +2,12 @@
   "architectures": [
     "DharaForMaskedDiffusion"
   ],
+  "attention_dropout": 0.0,
   "auto_map": {
     "AutoConfig": "modeling_dhara.DharaConfig",
     "AutoModel": "modeling_dhara.DharaForMaskedDiffusion",
     "AutoModelForCausalLM": "modeling_dhara.DharaForMaskedDiffusion"
   },
-  "attention_dropout": 0.0,
   "bos_token_id": 1,
   "canon_activation": false,
   "canon_bias": false,
@@ -15,26 +15,26 @@
   "canon_residual": true,
   "canon_set": "AC",
   "eos_token_id": 2,
-  "head_dim": …,
+  "head_dim": 48,
   "hidden_act": "silu",
   "hidden_size": 384,
   "initializer_range": 0.02,
   "intermediate_size": 1024,
   "mask_epsilon": 0.001,
   "mask_token_id": 50256,
-  "max_position_embeddings": …,
+  "max_position_embeddings": 1024,
   "model_type": "dhara",
-  "num_attention_heads": …,
+  "num_attention_heads": 8,
   "num_diffusion_steps": 1000,
   "num_hidden_layers": 32,
-  "num_key_value_heads": …,
+  "num_key_value_heads": 4,
   "pad_token_id": 0,
-  "rms_norm_eps": 1e-…,
+  "rms_norm_eps": 1e-06,
   "rope_theta": 10000.0,
-  "torch_dtype": "…",
+  "torch_dtype": "bfloat16",
   "transformers_version": "4.55.2",
   "use_cache": false,
   "use_flash_attention": false,
   "use_xformers": false,
-  "vocab_size": …
+  "vocab_size": 50304
 }
```
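
The `auto_map` entries are what let the `transformers` Auto classes resolve the custom `DharaConfig` and `DharaForMaskedDiffusion` classes from `modeling_dhara.py` when `trust_remote_code=True` is passed. A quick sanity check (a sketch; exact behavior can vary with the installed `transformers` version):

```python
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("codelion/dhara-70m", trust_remote_code=True)
print(config.model_type, config.hidden_size, config.num_hidden_layers,
      config.num_key_value_heads)   # dhara 384 32 4 (values from the config above)

# Builds the architecture from the config alone (random weights, no .safetensors download).
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```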
model.safetensors CHANGED

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
-size …
+oid sha256:e67749b3a03df19a3a36cbf3994997df649b5c2044cbf581a22e5eb8f473975f
+size 142728304
```
modeling_dhara.py CHANGED

```diff
@@ -1,24 +1,29 @@
 #!/usr/bin/env python3
 """
-Dhara: Diffusion …
 
-…
-1. …
-2. Canon layers …
-3. High-throughput parallel token generation
 
-…
 """
 
 import math
-…
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from transformers import PreTrainedModel
 from transformers.generation import GenerationMixin

@@ -36,12 +41,18 @@ try:
 except ImportError:
     FLASH_ATTN_AVAILABLE = False
 
 
 class DharaConfig(PretrainedConfig):
     """
     Configuration for Dhara model.
 
-    Dhara …
     """
 
     model_type = "dhara"
```
|
@@ -49,33 +60,33 @@ class DharaConfig(PretrainedConfig):
|
|
| 49 |
def __init__(
|
| 50 |
self,
|
| 51 |
# Core architecture
|
| 52 |
-
vocab_size: int =
|
| 53 |
hidden_size: int = 384,
|
| 54 |
num_hidden_layers: int = 32,
|
| 55 |
-
num_attention_heads: int =
|
| 56 |
-
num_key_value_heads: int =
|
| 57 |
intermediate_size: int = 1024,
|
| 58 |
head_dim: int = None,
|
| 59 |
max_position_embeddings: int = 2048,
|
| 60 |
|
| 61 |
# Model specifics
|
| 62 |
hidden_act: str = "silu",
|
| 63 |
-
rms_norm_eps: float = 1e-
|
| 64 |
rope_theta: float = 10000.0,
|
| 65 |
initializer_range: float = 0.02,
|
| 66 |
tie_word_embeddings: bool = True,
|
| 67 |
attention_dropout: float = 0.0,
|
| 68 |
|
| 69 |
# Canon layer parameters
|
| 70 |
-
canon_set: str = "AC",
|
| 71 |
-
canon_kernel: int = 4,
|
| 72 |
-
canon_residual: bool = True,
|
| 73 |
-
canon_activation: bool = False,
|
| 74 |
canon_bias: bool = False,
|
| 75 |
|
| 76 |
# Diffusion specific
|
| 77 |
-
mask_token_id: int =
|
| 78 |
-
mask_epsilon: float = 0.001,
|
| 79 |
num_diffusion_steps: int = 1000,
|
| 80 |
|
| 81 |
# Special tokens
|
|
@@ -85,7 +96,7 @@ class DharaConfig(PretrainedConfig):
|
|
| 85 |
|
| 86 |
# Performance flags
|
| 87 |
use_cache: bool = False,
|
| 88 |
-
use_flash_attention: bool =
|
| 89 |
use_xformers: bool = False,
|
| 90 |
|
| 91 |
**kwargs
|
|
@@ -98,6 +109,7 @@ class DharaConfig(PretrainedConfig):
|
|
| 98 |
**kwargs
|
| 99 |
)
|
| 100 |
|
|
|
|
| 101 |
self.vocab_size = vocab_size
|
| 102 |
self.hidden_size = hidden_size
|
| 103 |
self.num_hidden_layers = num_hidden_layers
|
|
@@ -107,29 +119,44 @@ class DharaConfig(PretrainedConfig):
|
|
| 107 |
self.head_dim = head_dim or (hidden_size // num_attention_heads)
|
| 108 |
self.max_position_embeddings = max_position_embeddings
|
| 109 |
|
|
|
|
| 110 |
self.hidden_act = hidden_act
|
| 111 |
self.rms_norm_eps = rms_norm_eps
|
| 112 |
self.rope_theta = rope_theta
|
| 113 |
self.initializer_range = initializer_range
|
|
|
|
| 114 |
self.attention_dropout = attention_dropout
|
| 115 |
|
|
|
|
| 116 |
self.canon_set = canon_set
|
| 117 |
self.canon_kernel = canon_kernel
|
| 118 |
self.canon_residual = canon_residual
|
| 119 |
self.canon_activation = canon_activation
|
| 120 |
self.canon_bias = canon_bias
|
| 121 |
|
| 122 |
-
|
|
|
|
| 123 |
self.mask_epsilon = mask_epsilon
|
| 124 |
self.num_diffusion_steps = num_diffusion_steps
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
self.use_cache = use_cache
|
| 127 |
self.use_flash_attention = use_flash_attention
|
| 128 |
self.use_xformers = use_xformers
|
| 129 |
|
| 130 |
|
| 131 |
class CanonLayer(nn.Module):
|
| 132 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
def __init__(
|
| 135 |
self,
|
|
@@ -145,29 +172,49 @@ class CanonLayer(nn.Module):
|
|
| 145 |
self.use_residual = use_residual
|
| 146 |
self.use_activation = use_activation
|
| 147 |
|
|
|
|
| 148 |
self.conv = nn.Conv1d(
|
| 149 |
in_channels=hidden_size,
|
| 150 |
out_channels=hidden_size,
|
| 151 |
kernel_size=kernel_size,
|
| 152 |
-
padding=kernel_size - 1,
|
| 153 |
-
groups=hidden_size,
|
| 154 |
bias=use_bias,
|
| 155 |
)
|
| 156 |
|
|
|
|
| 157 |
nn.init.normal_(self.conv.weight, mean=0.0, std=0.02)
|
| 158 |
if use_bias:
|
| 159 |
nn.init.zeros_(self.conv.bias)
|
| 160 |
|
| 161 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
batch_size, seq_len, hidden_size = hidden_states.shape
|
|
|
|
|
|
|
| 163 |
x = hidden_states.transpose(1, 2)
|
|
|
|
|
|
|
| 164 |
out = self.conv(x)
|
|
|
|
| 165 |
out = out[:, :, :seq_len]
|
|
|
|
|
|
|
| 166 |
if self.use_activation:
|
| 167 |
out = F.silu(out)
|
|
|
|
|
|
|
| 168 |
out = out.transpose(1, 2)
|
|
|
|
|
|
|
| 169 |
if self.use_residual:
|
| 170 |
out = hidden_states + out
|
|
|
|
| 171 |
return out
|
| 172 |
|
| 173 |
|
|
@@ -206,6 +253,7 @@ class RotaryEmbedding(nn.Module):
|
|
| 206 |
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 207 |
self.max_seq_len_cached = seq_len
|
| 208 |
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
|
|
|
| 209 |
freqs = torch.outer(t, self.inv_freq)
|
| 210 |
emb = torch.cat((freqs, freqs), dim=-1)
|
| 211 |
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
|
@@ -214,6 +262,7 @@ class RotaryEmbedding(nn.Module):
|
|
| 214 |
def forward(self, x, seq_len=None):
|
| 215 |
if seq_len > self.max_seq_len_cached:
|
| 216 |
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
|
|
|
| 217 |
return (
|
| 218 |
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
| 219 |
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
|
@@ -221,14 +270,17 @@ class RotaryEmbedding(nn.Module):
|
|
| 221 |
|
| 222 |
|
| 223 |
def rotate_half(x):
|
|
|
|
| 224 |
x1 = x[..., : x.shape[-1] // 2]
|
| 225 |
x2 = x[..., x.shape[-1] // 2 :]
|
| 226 |
return torch.cat((-x2, x1), dim=-1)
|
| 227 |
|
| 228 |
|
| 229 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
|
|
|
| 230 |
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
| 231 |
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
|
|
|
| 232 |
cos = cos.to(q.dtype)
|
| 233 |
sin = sin.to(q.dtype)
|
| 234 |
q_embed = (q * cos) + (rotate_half(q) * sin)
|
|
@@ -241,12 +293,14 @@ class DharaMLP(nn.Module):
|
|
| 241 |
|
| 242 |
def __init__(self, config):
|
| 243 |
super().__init__()
|
|
|
|
| 244 |
self.hidden_size = config.hidden_size
|
| 245 |
self.intermediate_size = config.intermediate_size
|
| 246 |
|
| 247 |
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 248 |
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 249 |
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
|
|
|
| 250 |
self.act_fn = nn.SiLU()
|
| 251 |
|
| 252 |
def forward(self, x):
|
|
@@ -254,6 +308,7 @@ class DharaMLP(nn.Module):
|
|
| 254 |
|
| 255 |
|
| 256 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
|
|
| 257 |
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 258 |
if n_rep == 1:
|
| 259 |
return hidden_states
|
|
@@ -262,7 +317,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
| 262 |
|
| 263 |
|
| 264 |
class DharaAttention(nn.Module):
|
| 265 |
-
"""Multi-Head Bidirectional Attention with GQA support"""
|
| 266 |
|
| 267 |
def __init__(self, config: DharaConfig, layer_idx: Optional[int] = None):
|
| 268 |
super().__init__()
|
|
@@ -277,7 +332,13 @@ class DharaAttention(nn.Module):
|
|
| 277 |
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 278 |
self.max_position_embeddings = config.max_position_embeddings
|
| 279 |
self.rope_theta = config.rope_theta
|
| 280 |
-
self.is_causal = False #
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 281 |
|
| 282 |
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 283 |
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
|
@@ -311,6 +372,12 @@ class DharaAttention(nn.Module):
|
|
| 311 |
|
| 312 |
kv_seq_len = key_states.shape[-2]
|
| 313 |
if past_key_value is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 315 |
|
| 316 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
@@ -323,6 +390,7 @@ class DharaAttention(nn.Module):
|
|
| 323 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 324 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 325 |
|
|
|
|
| 326 |
if FLASH_ATTN_AVAILABLE and self.config.use_flash_attention:
|
| 327 |
query_states = query_states.transpose(1, 2).contiguous()
|
| 328 |
key_states = key_states.transpose(1, 2).contiguous()
|
|
@@ -334,17 +402,26 @@ class DharaAttention(nn.Module):
|
|
| 334 |
value_states = value_states.to(torch.bfloat16)
|
| 335 |
|
| 336 |
attn_output = flash_attn_func(
|
| 337 |
-
query_states,
|
| 338 |
-
|
|
|
|
|
|
|
|
|
|
| 339 |
)
|
|
|
|
| 340 |
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
|
|
|
| 341 |
else:
|
|
|
|
| 342 |
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
|
|
|
| 343 |
if attention_mask is not None:
|
| 344 |
attn_weights = attn_weights + attention_mask
|
|
|
|
| 345 |
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 346 |
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 347 |
attn_output = torch.matmul(attn_weights, value_states)
|
|
|
|
| 348 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 349 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 350 |
|
|
@@ -357,15 +434,23 @@ class DharaAttention(nn.Module):
|
|
| 357 |
|
| 358 |
|
| 359 |
class DharaDecoderLayer(nn.Module):
|
| 360 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
|
| 362 |
def __init__(self, config: DharaConfig, layer_idx: int):
|
| 363 |
super().__init__()
|
| 364 |
self.hidden_size = config.hidden_size
|
| 365 |
self.config = config
|
| 366 |
|
|
|
|
| 367 |
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 368 |
|
|
|
|
| 369 |
self.canon_a = None
|
| 370 |
if "A" in config.canon_set:
|
| 371 |
self.canon_a = CanonLayer(
|
|
@@ -376,9 +461,13 @@ class DharaDecoderLayer(nn.Module):
|
|
| 376 |
use_bias=config.canon_bias,
|
| 377 |
)
|
| 378 |
|
|
|
|
| 379 |
self.self_attn = DharaAttention(config=config, layer_idx=layer_idx)
|
|
|
|
|
|
|
| 380 |
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 381 |
|
|
|
|
| 382 |
self.canon_c = None
|
| 383 |
if "C" in config.canon_set:
|
| 384 |
self.canon_c = CanonLayer(
|
|
@@ -389,6 +478,7 @@ class DharaDecoderLayer(nn.Module):
|
|
| 389 |
use_bias=config.canon_bias,
|
| 390 |
)
|
| 391 |
|
|
|
|
| 392 |
self.mlp = DharaMLP(config)
|
| 393 |
|
| 394 |
def forward(
|
|
@@ -401,11 +491,15 @@ class DharaDecoderLayer(nn.Module):
|
|
| 401 |
use_cache: Optional[bool] = False,
|
| 402 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 403 |
residual = hidden_states
|
|
|
|
|
|
|
| 404 |
hidden_states = self.input_layernorm(hidden_states)
|
| 405 |
|
|
|
|
| 406 |
if self.canon_a is not None:
|
| 407 |
hidden_states = self.canon_a(hidden_states)
|
| 408 |
|
|
|
|
| 409 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 410 |
hidden_states=hidden_states,
|
| 411 |
attention_mask=attention_mask,
|
|
@@ -416,9 +510,11 @@ class DharaDecoderLayer(nn.Module):
|
|
| 416 |
)
|
| 417 |
hidden_states = residual + hidden_states
|
| 418 |
|
|
|
|
| 419 |
residual = hidden_states
|
| 420 |
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 421 |
|
|
|
|
| 422 |
if self.canon_c is not None:
|
| 423 |
hidden_states = self.canon_c(hidden_states)
|
| 424 |
|
|
@@ -426,8 +522,10 @@ class DharaDecoderLayer(nn.Module):
|
|
| 426 |
hidden_states = residual + hidden_states
|
| 427 |
|
| 428 |
outputs = (hidden_states,)
|
|
|
|
| 429 |
if output_attentions:
|
| 430 |
outputs += (self_attn_weights,)
|
|
|
|
| 431 |
if use_cache:
|
| 432 |
outputs += (present_key_value,)
|
| 433 |
|
|
@@ -456,7 +554,9 @@ class DharaPreTrainedModel(PreTrainedModel):
|
|
| 456 |
|
| 457 |
|
| 458 |
class DharaModel(DharaPreTrainedModel):
|
| 459 |
-
"""
|
|
|
|
|
|
|
| 460 |
|
| 461 |
def __init__(self, config: DharaConfig):
|
| 462 |
super().__init__(config)
|
|
@@ -467,6 +567,7 @@ class DharaModel(DharaPreTrainedModel):
|
|
| 467 |
self.layers = nn.ModuleList(
|
| 468 |
[DharaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
| 469 |
)
|
|
|
|
| 470 |
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 471 |
self.gradient_checkpointing = False
|
| 472 |
|
|
@@ -495,12 +596,14 @@ class DharaModel(DharaPreTrainedModel):
|
|
| 495 |
return_dict: Optional[bool] = None,
|
| 496 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 497 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 498 |
-
output_hidden_states =
|
|
|
|
|
|
|
| 499 |
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 500 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 501 |
|
| 502 |
if input_ids is not None and inputs_embeds is not None:
|
| 503 |
-
raise ValueError("You cannot specify both input_ids and inputs_embeds")
|
| 504 |
elif input_ids is not None:
|
| 505 |
batch_size, seq_length = input_ids.shape[:2]
|
| 506 |
elif inputs_embeds is not None:
|
|
@@ -508,8 +611,12 @@ class DharaModel(DharaPreTrainedModel):
|
|
| 508 |
else:
|
| 509 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 510 |
|
| 511 |
-
if self.gradient_checkpointing and self.training
|
| 512 |
-
use_cache
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
|
| 514 |
past_key_values_length = 0
|
| 515 |
if use_cache:
|
|
@@ -531,8 +638,10 @@ class DharaModel(DharaPreTrainedModel):
|
|
| 531 |
if self._use_flash_attention_2:
|
| 532 |
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
| 533 |
else:
|
|
|
|
| 534 |
if attention_mask is not None:
|
| 535 |
if attention_mask.dim() == 2:
|
|
|
|
| 536 |
attention_mask_4d = attention_mask[:, None, None, :].expand(
|
| 537 |
batch_size, 1, seq_length, seq_length
|
| 538 |
).to(dtype=inputs_embeds.dtype)
|
|
@@ -541,8 +650,13 @@ class DharaModel(DharaPreTrainedModel):
|
|
| 541 |
torch.tensor(float('-inf'), dtype=inputs_embeds.dtype, device=attention_mask_4d.device),
|
| 542 |
torch.tensor(0.0, dtype=inputs_embeds.dtype, device=attention_mask_4d.device)
|
| 543 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 544 |
|
| 545 |
hidden_states = inputs_embeds
|
|
|
|
| 546 |
all_hidden_states = () if output_hidden_states else None
|
| 547 |
all_self_attns = () if output_attentions else None
|
| 548 |
next_decoder_cache = None
|
|
@@ -554,8 +668,12 @@ class DharaModel(DharaPreTrainedModel):
|
|
| 554 |
if self.gradient_checkpointing and self.training:
|
| 555 |
layer_outputs = self._gradient_checkpointing_func(
|
| 556 |
decoder_layer.__call__,
|
| 557 |
-
hidden_states,
|
| 558 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 559 |
)
|
| 560 |
else:
|
| 561 |
layer_outputs = decoder_layer(
|
|
@@ -571,6 +689,7 @@ class DharaModel(DharaPreTrainedModel):
|
|
| 571 |
|
| 572 |
if use_cache:
|
| 573 |
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
|
|
|
| 574 |
if output_attentions:
|
| 575 |
all_self_attns += (layer_outputs[1],)
|
| 576 |
|
|
@@ -594,23 +713,36 @@ class DharaModel(DharaPreTrainedModel):
|
|
| 594 |
)
|
| 595 |
|
| 596 |
def add_noise_to_tokens(self, input_ids: torch.LongTensor, t: torch.FloatTensor, eps: float = None):
|
| 597 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 598 |
batch_size, seq_len = input_ids.shape
|
| 599 |
device = input_ids.device
|
| 600 |
|
| 601 |
if eps is None:
|
| 602 |
eps = getattr(self.config, 'mask_epsilon', 0.001)
|
| 603 |
p_mask = (1 - eps) * t + eps
|
|
|
|
| 604 |
p_mask = p_mask.unsqueeze(-1).expand(batch_size, seq_len)
|
| 605 |
|
| 606 |
corruption_mask = torch.rand(batch_size, seq_len, device=device) < p_mask
|
| 607 |
-
|
|
|
|
|
|
|
| 608 |
|
| 609 |
return noisy_input_ids, corruption_mask, p_mask
|
| 610 |
|
| 611 |
|
| 612 |
class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
|
| 613 |
-
"""Dhara Model with Masked Diffusion head for training
|
| 614 |
_tied_weights_keys = ["lm_head.weight"]
|
| 615 |
|
| 616 |
def __init__(self, config):
|
|
@@ -636,6 +768,9 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
|
|
| 636 |
def set_output_embeddings(self, new_embeddings):
|
| 637 |
self.lm_head = new_embeddings
|
| 638 |
|
|
|
|
|
|
|
|
|
|
| 639 |
def get_decoder(self):
|
| 640 |
return self.model
|
| 641 |
|
|
@@ -655,7 +790,9 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
|
|
| 655 |
p_mask: Optional[torch.Tensor] = None,
|
| 656 |
) -> Union[Tuple, MaskedLMOutput]:
|
| 657 |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 658 |
-
output_hidden_states =
|
|
|
|
|
|
|
| 659 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 660 |
|
| 661 |
outputs = self.model(
|
|
@@ -693,9 +830,13 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
|
|
| 693 |
)
|
| 694 |
|
| 695 |
def compute_diffusion_loss(self, logits, labels, corruption_mask=None, p_mask=None):
|
| 696 |
-
"""
|
|
|
|
|
|
|
| 697 |
if corruption_mask is None or p_mask is None:
|
| 698 |
-
raise ValueError(
|
|
|
|
|
|
|
| 699 |
|
| 700 |
loss = F.cross_entropy(
|
| 701 |
logits.view(-1, self.config.vocab_size),
|
|
@@ -706,6 +847,7 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
|
|
| 706 |
|
| 707 |
masked_losses = loss[corruption_mask]
|
| 708 |
masked_p_mask = p_mask[corruption_mask]
|
|
|
|
| 709 |
weighted_losses = masked_losses / masked_p_mask
|
| 710 |
|
| 711 |
total_positions = labels.shape[0] * labels.shape[1]
|
|
@@ -728,11 +870,15 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
|
|
| 728 |
max_cache_length = None
|
| 729 |
|
| 730 |
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
| 731 |
-
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
|
| 732 |
elif past_length < input_ids.shape[1]:
|
| 733 |
input_ids = input_ids[:, past_length:]
|
| 734 |
|
| 735 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 736 |
attention_mask = attention_mask[:, -max_cache_length:]
|
| 737 |
|
| 738 |
position_ids = kwargs.get("position_ids", None)
|
|
@@ -740,21 +886,32 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
|
|
| 740 |
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 741 |
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 742 |
if past_key_values:
|
| 743 |
-
position_ids = position_ids[:, -input_ids.shape[1]:]
|
| 744 |
|
| 745 |
if inputs_embeds is not None and past_key_values is None:
|
| 746 |
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 747 |
else:
|
| 748 |
model_inputs = {"input_ids": input_ids}
|
| 749 |
|
| 750 |
-
model_inputs.update(
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
|
|
|
|
|
|
| 756 |
return model_inputs
|
| 757 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 758 |
@torch.no_grad()
|
| 759 |
def generate(
|
| 760 |
self,
|
|
@@ -764,27 +921,32 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
|
|
| 764 |
num_diffusion_steps: int = 10,
|
| 765 |
temperature: float = 1.0,
|
| 766 |
top_p: float = 0.9,
|
|
|
|
| 767 |
do_sample: bool = True,
|
| 768 |
pad_token_id: Optional[int] = None,
|
| 769 |
eos_token_id: Optional[int] = None,
|
|
|
|
| 770 |
**kwargs
|
| 771 |
) -> torch.LongTensor:
|
| 772 |
"""
|
| 773 |
-
Generate text using
|
| 774 |
|
| 775 |
-
|
| 776 |
-
|
|
|
|
| 777 |
|
| 778 |
Args:
|
| 779 |
input_ids: Input prompt token IDs [batch_size, prompt_len]
|
| 780 |
max_length: Maximum total sequence length (prompt + generation)
|
| 781 |
max_new_tokens: Number of new tokens to generate (alternative to max_length)
|
| 782 |
-
num_diffusion_steps: Number of
|
| 783 |
temperature: Sampling temperature (higher = more random)
|
| 784 |
top_p: Nucleus sampling threshold
|
|
|
|
| 785 |
do_sample: Whether to sample or take argmax
|
| 786 |
pad_token_id: Token ID for padding
|
| 787 |
eos_token_id: Token ID for end of sequence
|
|
|
|
| 788 |
|
| 789 |
Returns:
|
| 790 |
Generated token IDs including the prompt
|
|
@@ -816,97 +978,134 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
|
|
| 816 |
if eos_token_id is None:
|
| 817 |
eos_token_id = self.config.eos_token_id if hasattr(self.config, 'eos_token_id') else 2
|
| 818 |
|
| 819 |
-
#
|
| 820 |
-
|
| 821 |
-
tokens = torch.full((batch_size, total_len), mask_token_id, dtype=torch.long, device=device)
|
| 822 |
-
tokens[:, :prompt_len] = input_ids
|
| 823 |
-
|
| 824 |
-
# Track which positions are masked (need generation)
|
| 825 |
-
is_masked = torch.ones(batch_size, total_len, dtype=torch.bool, device=device)
|
| 826 |
-
is_masked[:, :prompt_len] = False # Prompt is not masked
|
| 827 |
-
|
| 828 |
-
# Number of tokens to unmask per step
|
| 829 |
-
tokens_per_step = max(1, gen_len // num_diffusion_steps)
|
| 830 |
|
| 831 |
-
#
|
| 832 |
-
|
| 833 |
-
|
| 834 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 835 |
logits = outputs.logits # [batch, seq_len, vocab]
|
| 836 |
|
| 837 |
-
#
|
| 838 |
-
|
| 839 |
|
| 840 |
-
|
| 841 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 842 |
|
| 843 |
# Apply temperature
|
| 844 |
-
if temperature != 1.0:
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 862 |
else:
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
remaining_masked
|
| 866 |
-
)
|
| 867 |
-
|
| 868 |
-
# For each batch item, unmask the highest confidence positions
|
| 869 |
-
for b in range(batch_size):
|
| 870 |
-
if num_to_unmask[b] == 0:
|
| 871 |
-
continue
|
| 872 |
-
|
| 873 |
-
# Get confidence scores for this batch item
|
| 874 |
-
conf_b = confidence[b] # [seq_len]
|
| 875 |
-
|
| 876 |
-
# Get top-k positions with highest confidence
|
| 877 |
-
k = int(num_to_unmask[b].item())
|
| 878 |
-
_, top_indices = conf_b.topk(k)
|
| 879 |
-
|
| 880 |
-
# Sample or argmax for these positions
|
| 881 |
-
for idx in top_indices:
|
| 882 |
-
pos_logits = logits[b, idx] # [vocab]
|
| 883 |
-
|
| 884 |
-
if do_sample and temperature > 0:
|
| 885 |
-
# Top-p (nucleus) sampling
|
| 886 |
-
sorted_logits, sorted_indices = torch.sort(pos_logits, descending=True)
|
| 887 |
-
sorted_probs = F.softmax(sorted_logits, dim=-1)
|
| 888 |
-
cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
|
| 889 |
-
|
| 890 |
-
# Remove tokens with cumulative probability above top_p
|
| 891 |
-
sorted_indices_to_remove = cumsum_probs > top_p
|
| 892 |
-
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
|
| 893 |
-
sorted_indices_to_remove[0] = False
|
| 894 |
|
| 895 |
-
|
| 896 |
-
|
| 897 |
|
| 898 |
-
|
| 899 |
-
|
| 900 |
-
|
| 901 |
-
else:
|
| 902 |
-
# Greedy (argmax)
|
| 903 |
-
token_id = pos_logits.argmax()
|
| 904 |
|
| 905 |
-
|
| 906 |
-
|
|
|
|
| 907 |
|
| 908 |
-
return
|
| 909 |
|
| 910 |
def save_pretrained(self, save_directory, **kwargs):
|
|
|
|
| 911 |
kwargs['safe_serialization'] = kwargs.get('safe_serialization', True)
|
| 912 |
return super().save_pretrained(save_directory, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
Dhara: Diffusion LLM with Canon Layers
|
| 4 |
|
| 5 |
+
Combines:
|
| 6 |
+
1. Dhara's masked diffusion training (bidirectional attention, high throughput)
|
| 7 |
+
2. Canon layers (local context mixing via causal depthwise convolutions)
|
|
|
|
| 8 |
|
| 9 |
+
Canon layers from "Physics of Language Models: Part 4.1" by Zeyuan Allen-Zhu:
|
| 10 |
+
- Position A: After input LayerNorm, before attention
|
| 11 |
+
- Position C: After post-attention LayerNorm, before MLP
|
| 12 |
+
- kernel_size=4, residual=True, activation=False (default)
|
| 13 |
+
|
| 14 |
+
Expected benefits:
|
| 15 |
+
- ~280-290 tok/s throughput (Dhara's parallel generation)
|
| 16 |
+
- +0.25-0.5% accuracy improvement (Canon's local context mixing)
|
| 17 |
"""
|
| 18 |
|
| 19 |
import math
|
| 20 |
+
import warnings
|
| 21 |
+
from typing import Optional, Tuple, Union, List
|
| 22 |
|
| 23 |
import torch
|
| 24 |
import torch.nn as nn
|
| 25 |
import torch.nn.functional as F
|
| 26 |
+
from torch.nn import CrossEntropyLoss
|
| 27 |
|
| 28 |
from transformers import PreTrainedModel
|
| 29 |
from transformers.generation import GenerationMixin
|
|
|
|
| 41 |
except ImportError:
|
| 42 |
FLASH_ATTN_AVAILABLE = False
|
| 43 |
|
| 44 |
+
try:
|
| 45 |
+
import xformers.ops as xops
|
| 46 |
+
XFORMERS_AVAILABLE = True
|
| 47 |
+
except ImportError:
|
| 48 |
+
XFORMERS_AVAILABLE = False
|
| 49 |
+
|
| 50 |
|
| 51 |
class DharaConfig(PretrainedConfig):
|
| 52 |
"""
|
| 53 |
Configuration for Dhara model.
|
| 54 |
|
| 55 |
+
Combines Dhara diffusion config with Canon layer parameters.
|
| 56 |
"""
|
| 57 |
|
| 58 |
model_type = "dhara"
|
|
|
|
| 60 |
def __init__(
|
| 61 |
self,
|
| 62 |
# Core architecture
|
| 63 |
+
vocab_size: int = 50304,
|
| 64 |
hidden_size: int = 384,
|
| 65 |
num_hidden_layers: int = 32,
|
| 66 |
+
num_attention_heads: int = 8,
|
| 67 |
+
num_key_value_heads: int = 4,
|
| 68 |
intermediate_size: int = 1024,
|
| 69 |
head_dim: int = None,
|
| 70 |
max_position_embeddings: int = 2048,
|
| 71 |
|
| 72 |
# Model specifics
|
| 73 |
hidden_act: str = "silu",
|
| 74 |
+
rms_norm_eps: float = 1e-6,
|
| 75 |
rope_theta: float = 10000.0,
|
| 76 |
initializer_range: float = 0.02,
|
| 77 |
tie_word_embeddings: bool = True,
|
| 78 |
attention_dropout: float = 0.0,
|
| 79 |
|
| 80 |
# Canon layer parameters
|
| 81 |
+
canon_set: str = "AC", # Positions: A (before attn), C (before MLP)
|
| 82 |
+
canon_kernel: int = 4, # Kernel size (2-4)
|
| 83 |
+
canon_residual: bool = True, # Highly recommended
|
| 84 |
+
canon_activation: bool = False, # NOT recommended for transformers
|
| 85 |
canon_bias: bool = False,
|
| 86 |
|
| 87 |
# Diffusion specific
|
| 88 |
+
mask_token_id: int = None, # Will be set from tokenizer
|
| 89 |
+
mask_epsilon: float = 0.001, # Minimum mask probability
|
| 90 |
num_diffusion_steps: int = 1000,
|
| 91 |
|
| 92 |
# Special tokens
|
|
|
|
| 96 |
|
| 97 |
# Performance flags
|
| 98 |
use_cache: bool = False,
|
| 99 |
+
use_flash_attention: bool = True,
|
| 100 |
use_xformers: bool = False,
|
| 101 |
|
| 102 |
**kwargs
|
|
|
|
| 109 |
**kwargs
|
| 110 |
)
|
| 111 |
|
| 112 |
+
# Core architecture
|
| 113 |
self.vocab_size = vocab_size
|
| 114 |
self.hidden_size = hidden_size
|
| 115 |
self.num_hidden_layers = num_hidden_layers
|
|
|
|
| 119 |
self.head_dim = head_dim or (hidden_size // num_attention_heads)
|
| 120 |
self.max_position_embeddings = max_position_embeddings
|
| 121 |
|
| 122 |
+
# Model specifics
|
| 123 |
self.hidden_act = hidden_act
|
| 124 |
self.rms_norm_eps = rms_norm_eps
|
| 125 |
self.rope_theta = rope_theta
|
| 126 |
self.initializer_range = initializer_range
|
| 127 |
+
self.tie_word_embeddings = tie_word_embeddings
|
| 128 |
self.attention_dropout = attention_dropout
|
| 129 |
|
| 130 |
+
# Canon parameters
|
| 131 |
self.canon_set = canon_set
|
| 132 |
self.canon_kernel = canon_kernel
|
| 133 |
self.canon_residual = canon_residual
|
| 134 |
self.canon_activation = canon_activation
|
| 135 |
self.canon_bias = canon_bias
|
| 136 |
|
| 137 |
+
# Diffusion specific
|
| 138 |
+
self.mask_token_id = mask_token_id if mask_token_id is not None else (vocab_size - 1)
|
| 139 |
self.mask_epsilon = mask_epsilon
|
| 140 |
self.num_diffusion_steps = num_diffusion_steps
|
| 141 |
|
| 142 |
+
# Special tokens
|
| 143 |
+
self.bos_token_id = bos_token_id
|
| 144 |
+
self.eos_token_id = eos_token_id
|
| 145 |
+
self.pad_token_id = pad_token_id
|
| 146 |
+
|
| 147 |
+
# Performance
|
| 148 |
self.use_cache = use_cache
|
| 149 |
self.use_flash_attention = use_flash_attention
|
| 150 |
self.use_xformers = use_xformers
|
| 151 |
|
| 152 |
|
| 153 |
class CanonLayer(nn.Module):
|
| 154 |
+
"""
|
| 155 |
+
Canon Layer: Causal 1D depthwise convolution for local context mixing.
|
| 156 |
+
|
| 157 |
+
From "Physics of Language Models: Part 4.1" by Zeyuan Allen-Zhu.
|
| 158 |
+
Captures local sequential dependencies with O(n) complexity.
|
| 159 |
+
"""
|
| 160 |
|
| 161 |
def __init__(
|
| 162 |
self,
|
|
|
|
| 172 |
self.use_residual = use_residual
|
| 173 |
self.use_activation = use_activation
|
| 174 |
|
| 175 |
+
# Depthwise causal convolution
|
| 176 |
self.conv = nn.Conv1d(
|
| 177 |
in_channels=hidden_size,
|
| 178 |
out_channels=hidden_size,
|
| 179 |
kernel_size=kernel_size,
|
| 180 |
+
padding=kernel_size - 1, # Causal (left-pad)
|
| 181 |
+
groups=hidden_size, # Depthwise
|
| 182 |
bias=use_bias,
|
| 183 |
)
|
| 184 |
|
| 185 |
+
# Initialize for stability
|
| 186 |
nn.init.normal_(self.conv.weight, mean=0.0, std=0.02)
|
| 187 |
if use_bias:
|
| 188 |
nn.init.zeros_(self.conv.bias)
|
| 189 |
|
| 190 |
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 191 |
+
"""
|
| 192 |
+
Args:
|
| 193 |
+
hidden_states: [batch_size, seq_len, hidden_size]
|
| 194 |
+
Returns:
|
| 195 |
+
output: [batch_size, seq_len, hidden_size]
|
| 196 |
+
"""
|
| 197 |
batch_size, seq_len, hidden_size = hidden_states.shape
|
| 198 |
+
|
| 199 |
+
# Transpose for Conv1d: [B, H, L]
|
| 200 |
x = hidden_states.transpose(1, 2)
|
| 201 |
+
|
| 202 |
+
# Apply conv with causal padding
|
| 203 |
out = self.conv(x)
|
| 204 |
+
# Remove right padding to make it causal
|
| 205 |
out = out[:, :, :seq_len]
|
| 206 |
+
|
| 207 |
+
# Optional activation
|
| 208 |
if self.use_activation:
|
| 209 |
out = F.silu(out)
|
| 210 |
+
|
| 211 |
+
# Transpose back: [B, L, H]
|
| 212 |
out = out.transpose(1, 2)
|
| 213 |
+
|
| 214 |
+
# Residual connection
|
| 215 |
if self.use_residual:
|
| 216 |
out = hidden_states + out
|
| 217 |
+
|
| 218 |
return out
|
| 219 |
|
| 220 |
|
|
|
|
| 253 |
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
| 254 |
self.max_seq_len_cached = seq_len
|
| 255 |
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
| 256 |
+
|
| 257 |
freqs = torch.outer(t, self.inv_freq)
|
| 258 |
emb = torch.cat((freqs, freqs), dim=-1)
|
| 259 |
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
|
|
|
| 262 |
def forward(self, x, seq_len=None):
|
| 263 |
if seq_len > self.max_seq_len_cached:
|
| 264 |
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
| 265 |
+
|
| 266 |
return (
|
| 267 |
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
| 268 |
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
|
|
|
| 270 |
|
| 271 |
|
| 272 |
def rotate_half(x):
|
| 273 |
+
"""Rotates half the hidden dims of the input."""
|
| 274 |
x1 = x[..., : x.shape[-1] // 2]
|
| 275 |
x2 = x[..., x.shape[-1] // 2 :]
|
| 276 |
return torch.cat((-x2, x1), dim=-1)
|
| 277 |
|
| 278 |
|
| 279 |
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
| 280 |
+
"""Applies Rotary Position Embedding to query and key tensors."""
|
| 281 |
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
| 282 |
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
| 283 |
+
# Cast to input dtype for consistency
|
| 284 |
cos = cos.to(q.dtype)
|
| 285 |
sin = sin.to(q.dtype)
|
| 286 |
q_embed = (q * cos) + (rotate_half(q) * sin)
|
|
|
|
| 293 |
|
| 294 |
def __init__(self, config):
|
| 295 |
super().__init__()
|
| 296 |
+
self.config = config
|
| 297 |
self.hidden_size = config.hidden_size
|
| 298 |
self.intermediate_size = config.intermediate_size
|
| 299 |
|
| 300 |
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 301 |
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 302 |
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 303 |
+
|
| 304 |
self.act_fn = nn.SiLU()
|
| 305 |
|
| 306 |
def forward(self, x):
|
|
|
|
| 308 |
|
| 309 |
|
| 310 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 311 |
+
"""Repeat KV heads for GQA."""
|
| 312 |
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 313 |
if n_rep == 1:
|
| 314 |
return hidden_states
|
|
|
|
| 317 |
|
| 318 |
|
| 319 |
class DharaAttention(nn.Module):
|
| 320 |
+
"""Multi-Head Bidirectional Attention with GQA support (for diffusion)"""
|
| 321 |
|
| 322 |
def __init__(self, config: DharaConfig, layer_idx: Optional[int] = None):
|
| 323 |
super().__init__()
|
|
|
|
| 332 |
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 333 |
self.max_position_embeddings = config.max_position_embeddings
|
| 334 |
self.rope_theta = config.rope_theta
|
| 335 |
+
self.is_causal = False # CRITICAL: Dhara uses bidirectional attention
|
| 336 |
+
|
| 337 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
| 338 |
+
raise ValueError(
|
| 339 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
| 340 |
+
f" and `num_heads`: {self.num_heads})."
|
| 341 |
+
)
|
| 342 |
|
| 343 |
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 344 |
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
|
|
|
| 372 |
|
| 373 |
kv_seq_len = key_states.shape[-2]
|
| 374 |
if past_key_value is not None:
|
| 375 |
+
if self.layer_idx is None:
|
| 376 |
+
raise ValueError(
|
| 377 |
+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
| 378 |
+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
| 379 |
+
"with a layer index."
|
| 380 |
+
)
|
| 381 |
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 382 |
|
| 383 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
|
|
| 390 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 391 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 392 |
|
| 393 |
+
# Flash Attention for bidirectional
|
| 394 |
if FLASH_ATTN_AVAILABLE and self.config.use_flash_attention:
|
| 395 |
query_states = query_states.transpose(1, 2).contiguous()
|
| 396 |
key_states = key_states.transpose(1, 2).contiguous()
|
|
|
|
| 402 |
value_states = value_states.to(torch.bfloat16)
|
| 403 |
|
| 404 |
attn_output = flash_attn_func(
|
| 405 |
+
query_states,
|
| 406 |
+
key_states,
|
| 407 |
+
value_states,
|
| 408 |
+
dropout_p=0.0,
|
| 409 |
+
causal=False, # Bidirectional for diffusion
|
| 410 |
)
|
| 411 |
+
|
| 412 |
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
|
| 413 |
+
|
| 414 |
else:
|
| 415 |
+
# Standard attention
|
| 416 |
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 417 |
+
|
| 418 |
if attention_mask is not None:
|
| 419 |
attn_weights = attn_weights + attention_mask
|
| 420 |
+
|
| 421 |
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 422 |
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
| 423 |
attn_output = torch.matmul(attn_weights, value_states)
|
| 424 |
+
|
| 425 |
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 426 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 427 |
|
|
|
|
| 434 |
|
| 435 |
|
| 436 |
class DharaDecoderLayer(nn.Module):
|
| 437 |
+
"""
|
| 438 |
+
Dhara decoder layer with Canon layers at positions A and C.
|
| 439 |
+
|
| 440 |
+
Flow:
|
| 441 |
+
x -> LayerNorm -> [CanonA] -> Attention -> + residual
|
| 442 |
+
x -> LayerNorm -> [CanonC] -> MLP -> + residual
|
| 443 |
+
"""
|
| 444 |
|
| 445 |
def __init__(self, config: DharaConfig, layer_idx: int):
|
| 446 |
super().__init__()
|
| 447 |
self.hidden_size = config.hidden_size
|
| 448 |
self.config = config
|
| 449 |
|
| 450 |
+
# Pre-attention norm
|
| 451 |
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 452 |
|
| 453 |
+
# Canon-A: before attention
|
| 454 |
self.canon_a = None
|
| 455 |
if "A" in config.canon_set:
|
| 456 |
self.canon_a = CanonLayer(
|
|
|
|
| 461 |
use_bias=config.canon_bias,
|
| 462 |
)
|
| 463 |
|
| 464 |
+
# Attention
|
| 465 |
self.self_attn = DharaAttention(config=config, layer_idx=layer_idx)
|
| 466 |
+
|
| 467 |
+
# Post-attention norm
|
| 468 |
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 469 |
|
| 470 |
+
# Canon-C: before MLP
|
| 471 |
self.canon_c = None
|
| 472 |
if "C" in config.canon_set:
|
| 473 |
self.canon_c = CanonLayer(
|
|
|
|
| 478 |
use_bias=config.canon_bias,
|
| 479 |
)
|
| 480 |
|
| 481 |
+
# MLP
|
| 482 |
self.mlp = DharaMLP(config)
|
| 483 |
|
| 484 |
def forward(
|
|
|
|
| 491 |
use_cache: Optional[bool] = False,
|
| 492 |
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 493 |
residual = hidden_states
|
| 494 |
+
|
| 495 |
+
# Pre-attention layernorm
|
| 496 |
hidden_states = self.input_layernorm(hidden_states)
|
| 497 |
|
| 498 |
+
# Canon-A (before attention)
|
| 499 |
if self.canon_a is not None:
|
| 500 |
hidden_states = self.canon_a(hidden_states)
|
| 501 |
|
| 502 |
+
# Self Attention (bidirectional)
|
| 503 |
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 504 |
hidden_states=hidden_states,
|
| 505 |
attention_mask=attention_mask,
|
|
|
|
| 510 |
)
|
| 511 |
hidden_states = residual + hidden_states
|
| 512 |
|
| 513 |
+
# MLP block
|
| 514 |
residual = hidden_states
|
| 515 |
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 516 |
|
| 517 |
+
# Canon-C (before MLP)
|
| 518 |
if self.canon_c is not None:
|
| 519 |
hidden_states = self.canon_c(hidden_states)
|
| 520 |
|
|
|
|
| 522 |
hidden_states = residual + hidden_states
|
| 523 |
|
| 524 |
outputs = (hidden_states,)
|
| 525 |
+
|
| 526 |
if output_attentions:
|
| 527 |
outputs += (self_attn_weights,)
|
| 528 |
+
|
| 529 |
if use_cache:
|
| 530 |
outputs += (present_key_value,)
|
| 531 |
|
|
|
|
@@ ... @@


 class DharaModel(DharaPreTrainedModel):
+    """
+    Dhara base model with bidirectional attention and Canon layers.
+    """

     def __init__(self, config: DharaConfig):
         super().__init__(config)
@@ ... @@
         self.layers = nn.ModuleList(
             [DharaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
+
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
@@ ... @@
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
             batch_size, seq_length = input_ids.shape[:2]
         elif inputs_embeds is not None:
@@ ... @@
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False

         past_key_values_length = 0
         if use_cache:
@@ ... @@
         if self._use_flash_attention_2:
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
         else:
+            # Bidirectional attention mask (not causal)
             if attention_mask is not None:
                 if attention_mask.dim() == 2:
+                    batch_size, seq_length = attention_mask.shape
                     attention_mask_4d = attention_mask[:, None, None, :].expand(
                         batch_size, 1, seq_length, seq_length
                     ).to(dtype=inputs_embeds.dtype)
@@ ... @@
                         torch.tensor(float('-inf'), dtype=inputs_embeds.dtype, device=attention_mask_4d.device),
                         torch.tensor(0.0, dtype=inputs_embeds.dtype, device=attention_mask_4d.device)
                     )
+                else:
+                    attention_mask = attention_mask
+            else:
+                attention_mask = None

         hidden_states = inputs_embeds
+
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
@@ ... @@
             if self.gradient_checkpointing and self.training:
                 layer_outputs = self._gradient_checkpointing_func(
                     decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ ... @@

             if use_cache:
                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
@@ ... @@
         )

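(The non-flash branch of `forward` above expands a 2-D padding mask to `[batch, 1, seq, seq]` and, in lines the diff viewer has collapsed, turns it into the additive bias that `DharaAttention` adds to the raw scores via `attn_weights = attn_weights + attention_mask`. The standalone sketch below assumes the collapsed lines apply `torch.where` with exactly the two tensor arguments visible above; because attention is bidirectional, only padded key positions are blocked and there is no causal triangle.)

import torch

def to_additive_bias(padding_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Turn a [batch, seq] 1/0 padding mask into a [batch, 1, seq, seq] additive bias (sketch)."""
    bsz, seq_len = padding_mask.shape
    mask_4d = padding_mask[:, None, None, :].expand(bsz, 1, seq_len, seq_len).to(dtype)
    return torch.where(
        mask_4d == 0,
        torch.tensor(float("-inf"), dtype=dtype),  # padded keys: blocked for every query
        torch.tensor(0.0, dtype=dtype),            # real tokens: no penalty
    )

bias = to_additive_bias(torch.tensor([[1, 1, 1, 0]]))
print(bias[0, 0])  # last column is -inf, every other entry 0.0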
     def add_noise_to_tokens(self, input_ids: torch.LongTensor, t: torch.FloatTensor, eps: float = None):
+        """
+        MDM-style masking: Replace tokens with [MASK] based on noise level t.
+
+        Args:
+            input_ids: Input token IDs [batch_size, seq_len]
+            t: Noise level in [0, 1] [batch_size]
+            eps: Minimum mask probability
+
+        Returns:
+            Tuple of (noisy_input_ids, corruption_mask, p_mask)
+        """
         batch_size, seq_len = input_ids.shape
         device = input_ids.device

         if eps is None:
             eps = getattr(self.config, 'mask_epsilon', 0.001)
         p_mask = (1 - eps) * t + eps
+
         p_mask = p_mask.unsqueeze(-1).expand(batch_size, seq_len)

         corruption_mask = torch.rand(batch_size, seq_len, device=device) < p_mask
+
+        mask_token_id = self.mask_token_id
+        noisy_input_ids = torch.where(corruption_mask, mask_token_id, input_ids)

         return noisy_input_ids, corruption_mask, p_mask

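(`add_noise_to_tokens` draws one independent Bernoulli per position with probability `p_mask = (1 - eps) * t + eps`, so `t ≈ 0` corrupts almost nothing while `t = 1` masks essentially every token, and `eps` keeps the probability strictly positive. A quick numeric check of that schedule with the default `eps` shown above:)

# Expected masking rate for a few noise levels (eps = 0.001, the default in the code above).
eps = 0.001
for t in (0.0, 0.25, 0.5, 1.0):
    p_mask = (1 - eps) * t + eps
    print(f"t = {t:.2f} -> expected fraction of masked tokens ~ {p_mask:.3f}")
# roughly 0.001, 0.25, 0.50 and 1.0 respectively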
 class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
+    """Dhara Model with Masked Diffusion head for training"""
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
@@ ... @@
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

+    def set_decoder(self, decoder):
+        self.model = decoder
+
     def get_decoder(self):
         return self.model

@@ ... @@
         p_mask: Optional[torch.Tensor] = None,
     ) -> Union[Tuple, MaskedLMOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         outputs = self.model(
@@ ... @@
         )

     def compute_diffusion_loss(self, logits, labels, corruption_mask=None, p_mask=None):
+        """
+        MDM loss with p_mask importance weighting.
+        """
         if corruption_mask is None or p_mask is None:
+            raise ValueError(
+                "MDM requires both corruption_mask and p_mask for loss computation."
+            )

         loss = F.cross_entropy(
             logits.view(-1, self.config.vocab_size),
@@ ... @@

         masked_losses = loss[corruption_mask]
         masked_p_mask = p_mask[corruption_mask]
+
         weighted_losses = masked_losses / masked_p_mask

         total_positions = labels.shape[0] * labels.shape[1]
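(The loss divides each masked position's cross-entropy by that position's `p_mask` before averaging; the normalization over `total_positions` continues in the collapsed lines. Dividing by `p_mask` compensates for how often a position gets masked at a given noise level, so lightly-corrupted samples do not vanish from the objective. Below is a minimal standalone sketch of that weighting, independent of the model classes; the helper name and the exact reduction are assumptions, not the method above verbatim.)

import torch
import torch.nn.functional as F

def mdm_loss_sketch(logits, labels, corruption_mask, p_mask):
    """Importance-weighted masked-diffusion loss (sketch mirroring the shape of the method above)."""
    per_token = F.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1), reduction="none"
    ).view(labels.shape)
    # Only masked positions contribute, each up-weighted by 1 / p_mask.
    weighted = per_token[corruption_mask] / p_mask[corruption_mask]
    return weighted.sum() / labels.numel()

logits = torch.randn(2, 8, 50304)
labels = torch.randint(0, 50304, (2, 8))
t = torch.rand(2)
p_mask = ((1 - 0.001) * t + 0.001).unsqueeze(-1).expand(2, 8)
corruption_mask = torch.rand(2, 8) < p_mask
print(mdm_loss_sketch(logits, labels, corruption_mask, p_mask))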
@@ ... @@
             max_cache_length = None

         if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
         elif past_length < input_ids.shape[1]:
             input_ids = input_ids[:, past_length:]

+        if (
+            max_cache_length is not None
+            and attention_mask is not None
+            and cache_length + input_ids.shape[1] > max_cache_length
+        ):
             attention_mask = attention_mask[:, -max_cache_length:]

         position_ids = kwargs.get("position_ids", None)
@@ ... @@
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]

         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids}

+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
         return model_inputs

+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past
+
     @torch.no_grad()
     def generate(
         self,
@@ ... @@
         num_diffusion_steps: int = 10,
         temperature: float = 1.0,
         top_p: float = 0.9,
+        top_k: int = 50,
         do_sample: bool = True,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+        repetition_penalty: float = 1.2,
         **kwargs
     ) -> torch.LongTensor:
         """
+        Generate text using autoregressive sampling with the diffusion model.

+        Since this model was converted from AR to diffusion via WSD training,
+        we generate tokens one at a time left-to-right, using the model's
+        next-token predictions at each position.

         Args:
             input_ids: Input prompt token IDs [batch_size, prompt_len]
             max_length: Maximum total sequence length (prompt + generation)
             max_new_tokens: Number of new tokens to generate (alternative to max_length)
+            num_diffusion_steps: Number of refinement iterations per token (higher = better quality)
             temperature: Sampling temperature (higher = more random)
             top_p: Nucleus sampling threshold
+            top_k: Top-k sampling threshold
             do_sample: Whether to sample or take argmax
             pad_token_id: Token ID for padding
             eos_token_id: Token ID for end of sequence
+            repetition_penalty: Penalty for repeating tokens (>1 = less repetition)

         Returns:
             Generated token IDs including the prompt
@@ ... @@
         if eos_token_id is None:
             eos_token_id = self.config.eos_token_id if hasattr(self.config, 'eos_token_id') else 2

+        # Start with the prompt
+        generated = input_ids.clone()

+        # Track generated tokens for repetition penalty
+        generated_set = set()
+        for i in range(prompt_len):
+            for b in range(batch_size):
+                generated_set.add(input_ids[b, i].item())
+
+        # Generate tokens one at a time (autoregressive style)
+        for pos in range(gen_len):
+            # Add a mask token at the next position
+            current_seq = torch.cat([
+                generated,
+                torch.full((batch_size, 1), mask_token_id, dtype=torch.long, device=device)
+            ], dim=1)
+
+            # Get model predictions
+            outputs = self(input_ids=current_seq)
             logits = outputs.logits  # [batch, seq_len, vocab]

+            # Get logits for the last (masked) position
+            next_token_logits = logits[:, -1, :]  # [batch, vocab]

+            # Apply repetition penalty
+            if repetition_penalty != 1.0:
+                for b in range(batch_size):
+                    for prev_token in generated_set:
+                        if prev_token < next_token_logits.shape[1]:
+                            next_token_logits[b, prev_token] /= repetition_penalty

             # Apply temperature
+            if temperature != 1.0 and temperature > 0:
+                next_token_logits = next_token_logits / temperature
+
+            if do_sample and temperature > 0:
+                # Apply top-k filtering
+                if top_k > 0:
+                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
+                    next_token_logits[indices_to_remove] = float('-inf')
+
+                # Apply top-p (nucleus) filtering
+                if top_p < 1.0:
+                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+                    # Remove tokens with cumulative probability above threshold
+                    sorted_indices_to_remove = cumulative_probs > top_p
+                    # Shift the indices to the right to keep the first token above threshold
+                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                    sorted_indices_to_remove[..., 0] = False
+
+                    # Scatter sorted indices to original indexing
+                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                    next_token_logits[indices_to_remove] = float('-inf')
+
+                # Sample from the filtered distribution
+                probs = F.softmax(next_token_logits, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
             else:
+                # Greedy decoding
+                next_tokens = next_token_logits.argmax(dim=-1)

+            # Add to generated sequence
+            generated = torch.cat([generated, next_tokens.unsqueeze(-1)], dim=1)

+            # Update generated set for repetition penalty
+            for b in range(batch_size):
+                generated_set.add(next_tokens[b].item())

+            # Check for EOS
+            if eos_token_id is not None and (next_tokens == eos_token_id).all():
+                break

+        return generated

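(In effect, `generate()` is plain left-to-right sampling: append one [MASK] placeholder, run the bidirectional model over the whole sequence, sample the final position, and repeat. A minimal way to exercise it with random weights is sketched below; the config values are copied from the quick test at the bottom of this file, and the `from modeling_dhara import ...` line assumes you run the snippet next to this module rather than through `from_pretrained`.)

import torch
from modeling_dhara import DharaConfig, DharaForMaskedDiffusion  # assumes this file is importable locally

config = DharaConfig(
    vocab_size=50304, hidden_size=384, num_hidden_layers=32,
    num_attention_heads=8, num_key_value_heads=4, intermediate_size=1024,
    canon_set="AC", canon_kernel=4, canon_residual=True,
)
model = DharaForMaskedDiffusion(config).eval()

prompt = torch.randint(0, 50304, (1, 8))  # random "prompt" token ids
out = model.generate(prompt, max_new_tokens=16, temperature=0.8, top_k=50, top_p=0.9)
print(out.shape)  # [1, 8 + up to 16]; generation may stop early at eos_token_id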
     def save_pretrained(self, save_directory, **kwargs):
+        """Override to save in SafeTensors format by default"""
         kwargs['safe_serialization'] = kwargs.get('safe_serialization', True)
         return super().save_pretrained(save_directory, **kwargs)
+
+
+def count_parameters(model):
+    """Count total and Canon-specific parameters."""
+    total = sum(p.numel() for p in model.parameters())
+    canon = sum(p.numel() for n, p in model.named_parameters() if 'canon' in n.lower())
+    return total, canon
+
+
+if __name__ == "__main__":
+    # Quick test
+    print("Testing Dhara model creation...")
+
+    config = DharaConfig(
+        vocab_size=50304,
+        hidden_size=384,
+        num_hidden_layers=32,
+        num_attention_heads=8,
+        num_key_value_heads=4,
+        intermediate_size=1024,
+        canon_set="AC",
+        canon_kernel=4,
+        canon_residual=True,
+    )
+
+    model = DharaForMaskedDiffusion(config)
+
+    total, canon = count_parameters(model)
+    print(f"Model created successfully!")
+    print(f"Total params: {total:,} ({total/1e6:.2f}M)")
+    print(f"Canon params: {canon:,} ({100*canon/total:.3f}%)")
+    print(f"Base Dhara would be: {total - canon:,}")
+
+    # Test forward pass
+    batch_size, seq_len = 2, 64
+    input_ids = torch.randint(0, 50304, (batch_size, seq_len))
+
+    # Test with diffusion noise
+    t = torch.rand(batch_size)
+    noisy_ids, corruption_mask, p_mask = model.add_noise_to_tokens(input_ids, t)
+
+    with torch.no_grad():
+        outputs = model(
+            input_ids=noisy_ids,
+            labels=input_ids,
+            corruption_mask=corruption_mask,
+            p_mask=p_mask,
+        )
+
+    print(f"Forward pass: loss={outputs.loss.item():.4f}")
+    print("Ready for training!")
tokenizer.json
CHANGED
@@ -1,21 +1,7 @@
 {
   "version": "1.0",
-  "truncation": {
-
-    "max_length": 2048,
-    "strategy": "LongestFirst",
-    "stride": 0
-  },
-  "padding": {
-    "strategy": {
-      "Fixed": 2048
-    },
-    "direction": "Right",
-    "pad_to_multiple_of": null,
-    "pad_id": 50256,
-    "pad_type_id": 0,
-    "pad_token": "<|endoftext|>"
-  },
+  "truncation": null,
+  "padding": null,
   "added_tokens": [
     {
       "id": 50256,