codelion committed on
Commit
575face
·
verified ·
1 Parent(s): 40643a7

Upload folder using huggingface_hub

Files changed (5)
  1. README.md +109 -199
  2. config.json +8 -8
  3. model.safetensors +2 -2
  4. modeling_dhara.py +335 -136
  5. tokenizer.json +2 -16
README.md CHANGED
@@ -2,253 +2,163 @@
2
  license: apache-2.0
3
  language:
4
  - en
 
5
  tags:
6
- - text-generation
7
  - diffusion
8
- - language-model
9
- - causal-lm
 
 
 
10
  datasets:
11
- - HuggingFaceFW/fineweb-edu
12
- - allenai/dolma
13
- - mlfoundations/dclm-baseline-1.0
14
- model-index:
15
- - name: dhara-70m
16
- results:
17
- - task:
18
- type: text-generation
19
- dataset:
20
- name: HellaSwag
21
- type: hellaswag
22
- metrics:
23
- - name: Accuracy
24
- type: accuracy
25
- value: 25.58
26
- - task:
27
- type: text-generation
28
- dataset:
29
- name: PIQA
30
- type: piqa
31
- metrics:
32
- - name: Accuracy
33
- type: accuracy
34
- value: 51.58
35
- - task:
36
- type: text-generation
37
- dataset:
38
- name: WinoGrande
39
- type: winogrande
40
- metrics:
41
- - name: Accuracy
42
- type: accuracy
43
- value: 49.64
44
- - task:
45
- type: text-generation
46
- dataset:
47
- name: ARC-Challenge
48
- type: arc_challenge
49
- metrics:
50
- - name: Accuracy
51
- type: accuracy
52
- value: 24.83
53
- - task:
54
- type: text-generation
55
- dataset:
56
- name: MMLU
57
- type: mmlu
58
- metrics:
59
- - name: Accuracy
60
- type: accuracy
61
- value: 23.85
62
- - task:
63
- type: text-generation
64
- dataset:
65
- name: TruthfulQA
66
- type: truthfulqa_mc2
67
- metrics:
68
- - name: Accuracy
69
- type: accuracy
70
- value: 47.50
71
  ---
72
 
73
- # Dhara-70M
74
 
75
- A 70M parameter diffusion language model optimized for high-throughput text generation with superior factuality.
76
-
77
- ## Table of Contents
78
- - [Model Description](#model-description)
79
- - [Training Data](#training-data)
80
- - [Training Details](#training-details)
81
- - [Benchmark Results](#benchmark-results)
82
- - [Usage](#usage)
83
- - [Key Insights](#key-insights)
84
- - [Limitations](#limitations)
85
- - [Citation](#citation)
86
 
87
  ## Model Description
88
 
89
- Dhara-70M is a novel diffusion language model that achieves:
90
- - **3.8x higher throughput** than autoregressive models
91
- - **Best-in-class factuality** on TruthfulQA (47.50%)
92
- - **10x training efficiency** via WSD (Warmup-Stable-Decay) conversion
93
 
94
- ### Architecture
95
 
96
- | Specification | Value |
97
- |--------------|-------|
98
- | **Parameters** | 71.34M |
99
- | **Layers** | 32 |
100
- | **Hidden Size** | 384 |
101
- | **FF Dimension** | 1024 |
102
- | **Attention Heads** | 8 |
103
- | **KV Heads** | 4 (GQA) |
104
- | **Context Length** | 2048 tokens |
105
- | **Position Encoding** | RoPE |
106
- | **Normalization** | RMSNorm |
107
- | **Special Layers** | Canon (depthwise causal convolutions) |
108
- | **Generation Type** | Diffusion (parallel token generation) |
109
-
110
- ## Training Data
111
-
112
- Dhara was trained in two stages:
113
-
114
- **Stage 1: AR Pretraining (1B tokens)**
115
- - 40% FinePDFs (400M tokens)
116
- - 30% DCLM Baseline (300M tokens)
117
- - 30% FineWeb-Edu (300M tokens)
118
-
119
- **Stage 2: WSD Conversion (100M tokens)**
120
- - Progressive block size warmup (1→4→32→64→1024)
121
- - MDLM diffusion objective
122
 
123
- ## Training Details
124
 
125
  | Parameter | Value |
126
  |-----------|-------|
127
- | **AR Training Tokens** | 1 billion |
128
- | **WSD Conversion Tokens** | 100 million |
129
- | **Batch Size** | 128 effective (8 × 16 gradient accumulation) |
130
- | **Learning Rate** | 5e-4 (AR) / 5e-5 (WSD) |
131
- | **Optimizer** | AdamW |
132
- | **Schedule** | Cosine decay with 2% warmup |
133
- | **Precision** | BF16 |
134
- | **Hardware** | Single NVIDIA A40 GPU |
135
- | **Total Training Time** | ~20 hours |
136
-
137
- ## Benchmark Results
138
-
139
- | Benchmark | Dhara-70M | GPT-2-70M | vs GPT-2 |
140
- |-----------|-----------|-----------|----------|
141
- | HellaSwag (0-shot) | 25.58% | 26.46% | -0.88% |
142
- | PIQA (0-shot) | 51.58% | 58.05% | -6.47% |
143
- | WinoGrande (0-shot) | 49.64% | 52.64% | -3.00% |
144
- | ARC-Challenge (0-shot) | **24.83%** | 22.27% | **+2.56%** |
145
- | MMLU (5-shot) | 23.85% | 25.77% | -1.92% |
146
- | TruthfulQA (0-shot) | **47.50%** | 45.83% | **+1.67%** |
147
- | GSM8K (5-shot) | 0.00% | 1.21% | -1.21% |
148
- | **Average** | **31.85%** | **33.18%** | -1.33% |
149
-
150
- ### Inference Performance
151
-
152
- | Metric | Dhara-70M | GPT-2-70M | Advantage |
153
- |--------|-----------|-----------|-----------|
154
- | Time to First Token | 35.5 ms | ~25 ms | 1.4x slower |
155
- | Throughput | 183.5 tok/s | ~48 tok/s | **3.8x faster** |
156
- | Peak Memory | 0.24 GB | 0.15 GB | 1.6x higher |
157
 
158
  ## Usage
159
 
 
 
 
 
 
 
 
 
160
  ```python
 
161
  from transformers import AutoTokenizer, AutoModelForCausalLM
162
 
163
  # Load model and tokenizer
 
 
 
 
 
164
  tokenizer = AutoTokenizer.from_pretrained("codelion/dhara-70m")
165
- model = AutoModelForCausalLM.from_pretrained("codelion/dhara-70m", trust_remote_code=True)
166
 
167
- # Generate text using diffusion sampling
168
- inputs = tokenizer("The future of AI is", return_tensors="pt")
 
 
 
 
 
 
169
  outputs = model.generate(
170
- **inputs,
171
- max_new_tokens=40, # Generate 40 new tokens
172
- num_diffusion_steps=10, # Diffusion denoising steps (higher = better quality)
173
- do_sample=True,
174
  temperature=0.8,
175
- top_p=0.9
 
 
 
176
  )
 
177
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
178
  ```
179
 
180
- ### Batch Generation (High Throughput)
181
-
182
- ```python
183
- # For batch generation, use larger batch sizes
184
- prompts = [
185
- "The future of AI is",
186
- "In recent years, machine learning has",
187
- "The most important discovery in physics was",
188
- "Climate change affects our planet by"
189
- ]
190
-
191
- inputs = tokenizer(prompts, return_tensors="pt", padding=True)
192
- outputs = model.generate(
193
- **inputs,
194
- max_length=100,
195
- do_sample=True,
196
- temperature=0.7,
197
- num_diffusion_steps=10 # Fewer steps = faster generation
198
- )
199
 
200
- for i, output in enumerate(outputs):
201
- print(f"Prompt {i+1}: {tokenizer.decode(output, skip_special_tokens=True)}")
202
- ```
 
 
 
 
 
 
203
 
204
- ## Key Insights
205
 
206
- 1. **Throughput vs Accuracy Trade-off**: Dhara trades 1.33% average accuracy for 3.8x higher throughput, making it ideal for batch processing tasks.
207
 
208
- 2. **Superior Factuality**: Dhara excels on TruthfulQA (+1.67% vs GPT-2), suggesting diffusion models may reduce hallucinations through bidirectional context.
209
 
210
- 3. **Reasoning Advantage**: ARC-Challenge +2.56% indicates strong performance on reasoning tasks.
211
 
212
- 4. **WSD Efficiency**: Converting an AR model to diffusion via WSD uses 10x fewer tokens than training from scratch with equivalent quality.
213
 
214
- 5. **Canon Layers Help**: The depthwise causal convolutions (Canon layers) improve factuality and reasoning with only 0.13% parameter overhead.
 
 
 
 
215
 
216
- ## When to Use Dhara
217
 
218
- **Choose Dhara when:**
219
- - Batch generation throughput matters
220
- - Factual accuracy is critical
221
- - You have an existing AR checkpoint to convert
222
 
223
- **Choose AR models when:**
224
- - Interactive latency is critical
225
- - Sequential reasoning is important (math, coding)
226
- - Memory is constrained
 
227
 
228
  ## Limitations
229
 
230
- - Lower performance on sequential reasoning tasks (GSM8K: 0.00%)
231
- - Higher memory usage due to bidirectional attention
232
- - Slightly higher time-to-first-token latency
233
- - Best suited for batch rather than interactive use cases
234
 
235
  ## Citation
236
 
 
 
237
  ```bibtex
238
- @article{sharma2025optimal,
239
- title={The Optimal Architecture for Small Language Models},
240
- author={Sharma, Asankhaya},
241
- year={2025},
242
- url={https://huggingface.co/blog/codelion/optimal-model-architecture}
 
243
  }
244
  ```
245
 
246
- ## Related Work
247
-
248
- - [The Optimal Architecture for Small Language Models](https://huggingface.co/blog/codelion/optimal-model-architecture) - Blog post describing this work
249
- - [The 1 Billion Token Challenge: Optimal Dataset Mixing](https://huggingface.co/blog/codelion/optimal-dataset-mixing) - Our previous work on optimal pretraining data
250
- - [GPT-2-70M](https://huggingface.co/codelion/gpt-2-70m) - Our previous model from optimal pretraining experiments
251
-
252
- ## Contact
253
 
254
- For questions or feedback, please open a discussion on the [Hugging Face discussions page](https://huggingface.co/codelion/dhara-70m/discussions).
 
2
  license: apache-2.0
3
  language:
4
  - en
5
+ library_name: transformers
6
  tags:
 
7
  - diffusion
8
+ - masked-language-model
9
+ - text-generation
10
+ - pytorch
11
+ - transformers
12
+ pipeline_tag: text-generation
13
  datasets:
14
+ - codelion/pre-training-dataset-samples
 
15
  ---
16
 
17
+ # Dhara-70M: Diffusion Language Model
18
 
19
+ Dhara is a 70M parameter diffusion language model that combines masked diffusion training with Canon layers for improved local context understanding.
 
 
 
 
 
 
 
 
 
 
20
 
21
  ## Model Description
22
 
23
+ Dhara was created by converting a pre-trained autoregressive (AR) LLM to a diffusion model using **Warmup-Stable-Decay (WSD)** training. This approach preserves the language understanding capabilities of the original AR model while enabling bidirectional attention and parallel token generation.
 
 
 
24
 
25
+ ### Key Features
26
 
27
+ - **Bidirectional Attention**: Unlike causal LLMs, Dhara uses full bidirectional attention during generation
28
+ - **Canon Layers**: Incorporates causal depthwise convolutions at positions A (before attention) and C (before MLP) for local context mixing
29
+ - **WSD Conversion**: Trained with 100M tokens to convert the AR checkpoint to a diffusion model while preserving its language capabilities
30
+ - **Custom Generate Method**: Includes a specialized `generate()` method that iteratively denoises masked positions into text
 
31
 
32
+ ### Architecture
33
 
34
  | Parameter | Value |
35
  |-----------|-------|
36
+ | Parameters | 70M |
37
+ | Hidden Size | 384 |
38
+ | Layers | 32 |
39
+ | Attention Heads | 8 |
40
+ | KV Heads | 4 (GQA) |
41
+ | Intermediate Size | 1024 |
42
+ | Vocabulary | 50,304 |
43
+ | Context Length | 1024 |
44
+ | Canon Kernel | 4 |
45
+ | Canon Positions | A, C |
46
+
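+ To confirm these values against the published `config.json`, you can load the custom configuration directly. A minimal check (assuming `trust_remote_code=True` is acceptable in your environment; field names follow `config.json` in this repository):
+
+ ```python
+ from transformers import AutoConfig
+
+ # Loads DharaConfig via the auto_map entry in config.json
+ config = AutoConfig.from_pretrained("codelion/dhara-70m", trust_remote_code=True)
+
+ print(config.hidden_size)               # 384
+ print(config.num_hidden_layers)         # 32
+ print(config.num_attention_heads)       # 8
+ print(config.num_key_value_heads)       # 4 (GQA)
+ print(config.intermediate_size)         # 1024
+ print(config.vocab_size)                # 50304
+ print(config.max_position_embeddings)   # 1024
+ print(config.canon_kernel, config.canon_set)  # 4 "AC"
+ ```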
47
+ ## Evaluation Results
48
+
49
+ | Benchmark | Score |
50
+ |-----------|-------|
51
+ | HellaSwag | 29.42 |
52
+ | ARC-Easy | 43.35 |
53
+ | ARC-Challenge | 24.15 |
54
+ | PIQA | 61.48 |
55
+ | Winogrande | 50.75 |
56
+ | OpenBookQA | 19.75 |
57
+ | **Average** | **38.15** |
58
+
59
+ *Self-reported evaluation results across 6 standard benchmarks.*
 
 
 
 
 
 
60
 
61
  ## Usage
62
 
63
+ ### Installation
64
+
65
+ ```bash
66
+ pip install transformers torch
67
+ ```
68
+
69
+ ### Quick Start
70
+
71
  ```python
72
+ import torch
73
  from transformers import AutoTokenizer, AutoModelForCausalLM
74
 
75
  # Load model and tokenizer
76
+ model = AutoModelForCausalLM.from_pretrained(
77
+ "codelion/dhara-70m",
78
+ trust_remote_code=True,
79
+ torch_dtype=torch.bfloat16
80
+ )
81
  tokenizer = AutoTokenizer.from_pretrained("codelion/dhara-70m")
 
82
 
83
+ # Move to GPU if available
84
+ device = "cuda" if torch.cuda.is_available() else "cpu"
85
+ model = model.to(device)
86
+
87
+ # Generate text
88
+ prompt = "The future of artificial intelligence"
89
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
90
+
91
  outputs = model.generate(
92
+ inputs.input_ids,
93
+ max_new_tokens=50,
 
 
94
  temperature=0.8,
95
+ top_p=0.9,
96
+ top_k=50,
97
+ repetition_penalty=1.2,
98
+ do_sample=True
99
  )
100
+
101
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
102
  ```
103
 
104
+ ### Generation Parameters
 
105
 
106
+ | Parameter | Default | Description |
107
+ |-----------|---------|-------------|
108
+ | `max_new_tokens` | 50 | Number of tokens to generate |
109
+ | `temperature` | 1.0 | Sampling temperature (higher = more random) |
110
+ | `top_p` | 0.9 | Nucleus sampling threshold |
111
+ | `top_k` | 50 | Top-k sampling threshold |
112
+ | `repetition_penalty` | 1.2 | Penalty for token repetition |
113
+ | `do_sample` | True | Whether to sample or use greedy decoding |
114
+ | `num_diffusion_steps` | 10 | Diffusion refinement steps (for future use) |
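+
+ As a rough illustration, reusing `model`, `tokenizer`, and `inputs` from the Quick Start above (and assuming the custom `generate()` accepts the parameters listed in this table), greedy decoding with more denoising refinement might look like:
+
+ ```python
+ outputs = model.generate(
+     inputs.input_ids,
+     max_new_tokens=64,
+     num_diffusion_steps=20,  # more refinement iterations: slower, potentially higher quality
+     do_sample=False,         # greedy decoding instead of sampling
+ )
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```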
115
 
116
+ ## Training Details
117
 
118
+ ### Dataset
119
 
120
+ Dhara was trained on a curated 1B token sample from the [Pre-training Dataset Samples](https://huggingface.co/collections/codelion/pre-training-dataset-samples) collection.
121
 
122
+ ### WSD (Warmup-Stable-Decay) Conversion
123
 
124
+ Dhara was converted from an autoregressive checkpoint using the WSD training schedule:
125
 
126
+ - **Base Model**: LLaMA-style AR model with Canon layers
127
+ - **Total Training Tokens**: 1B tokens (AR) + 100M tokens (diffusion conversion)
128
+ - **WSD Warmup Phase**: 20M tokens
129
+ - **WSD Stable Phase**: 80M tokens
130
+ - **Training Objective**: Masked Diffusion Modeling (MDM)
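+
+ For intuition, the masked diffusion objective works roughly as sketched below. This is a simplified, standalone version of the masking and loss weighting implemented in `modeling_dhara.py` (`add_noise_to_tokens` and `compute_diffusion_loss`), not the exact training loop:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def mdm_corrupt(input_ids, t, mask_token_id, eps=0.001):
+     """Mask each token with probability p = (1 - eps) * t + eps, where t in [0, 1] is the noise level."""
+     batch_size, seq_len = input_ids.shape
+     p_mask = ((1 - eps) * t + eps).unsqueeze(-1).expand(batch_size, seq_len)
+     corruption_mask = torch.rand(batch_size, seq_len, device=input_ids.device) < p_mask
+     noisy_ids = torch.where(corruption_mask, torch.full_like(input_ids, mask_token_id), input_ids)
+     return noisy_ids, corruption_mask, p_mask
+
+ def mdm_loss(logits, labels, corruption_mask, p_mask):
+     """Cross-entropy on masked positions only, importance-weighted by 1 / p_mask."""
+     vocab_size = logits.size(-1)
+     per_token = F.cross_entropy(
+         logits.view(-1, vocab_size), labels.view(-1), reduction="none"
+     ).view_as(p_mask)
+     weighted = per_token[corruption_mask] / p_mask[corruption_mask]
+     return weighted.sum() / labels.numel()  # normalize over all positions
+ ```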
131
 
132
+ ### Canon Layers
133
 
134
+ Canon layers are causal depthwise convolutions that provide local context mixing with O(n) complexity, following "Physics of Language Models: Part 4.1" by Zeyuan Allen-Zhu. In Dhara they are configured as follows (a minimal sketch follows the list):
 
 
 
135
 
136
+ - **Position A**: Applied after input LayerNorm, before attention
137
+ - **Position C**: Applied after post-attention LayerNorm, before MLP
138
+ - **Kernel Size**: 4 tokens
139
+ - **Residual Connection**: Enabled
140
+ - **Activation**: None (as recommended for transformers)
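+
+ A minimal, standalone sketch of this operation (mirroring the `CanonLayer` implementation in `modeling_dhara.py`; the class name here is illustrative):
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ class CausalDepthwiseConv(nn.Module):
+     """Kernel-4 depthwise Conv1d with left padding, so position i only mixes positions i-3..i."""
+     def __init__(self, hidden_size: int, kernel_size: int = 4):
+         super().__init__()
+         self.conv = nn.Conv1d(
+             hidden_size, hidden_size,
+             kernel_size=kernel_size,
+             padding=kernel_size - 1,  # symmetric pad; the right side is trimmed below to stay causal
+             groups=hidden_size,       # depthwise: one filter per channel
+             bias=False,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # x: [batch, seq_len, hidden] -> Conv1d expects [batch, hidden, seq_len]
+         seq_len = x.shape[1]
+         out = self.conv(x.transpose(1, 2))[:, :, :seq_len]  # drop right padding
+         return x + out.transpose(1, 2)  # residual connection, no activation
+ ```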
141
 
142
  ## Limitations
143
 
144
+ - This is a research model and may generate inaccurate or inappropriate content
145
+ - Performance may vary on tasks requiring long-range dependencies
146
+ - The model was trained on a limited dataset and may have knowledge gaps
 
147
 
148
  ## Citation
149
 
150
+ If you use this model, please cite:
151
+
152
  ```bibtex
153
+ @misc{dhara2024,
154
+ title={Dhara: Diffusion Language Model with Canon Layers},
155
+ author={CodeLion},
156
+ year={2024},
157
+ publisher={HuggingFace},
158
+ url={https://huggingface.co/codelion/dhara-70m}
159
  }
160
  ```
161
 
162
+ ## License
 
 
 
 
 
 
163
 
164
+ Apache 2.0
config.json CHANGED
@@ -2,12 +2,12 @@
2
  "architectures": [
3
  "DharaForMaskedDiffusion"
4
  ],
 
5
  "auto_map": {
6
  "AutoConfig": "modeling_dhara.DharaConfig",
7
  "AutoModel": "modeling_dhara.DharaForMaskedDiffusion",
8
  "AutoModelForCausalLM": "modeling_dhara.DharaForMaskedDiffusion"
9
  },
10
- "attention_dropout": 0.0,
11
  "bos_token_id": 1,
12
  "canon_activation": false,
13
  "canon_bias": false,
@@ -15,26 +15,26 @@
15
  "canon_residual": true,
16
  "canon_set": "AC",
17
  "eos_token_id": 2,
18
- "head_dim": 64,
19
  "hidden_act": "silu",
20
  "hidden_size": 384,
21
  "initializer_range": 0.02,
22
  "intermediate_size": 1024,
23
  "mask_epsilon": 0.001,
24
  "mask_token_id": 50256,
25
- "max_position_embeddings": 2048,
26
  "model_type": "dhara",
27
- "num_attention_heads": 6,
28
  "num_diffusion_steps": 1000,
29
  "num_hidden_layers": 32,
30
- "num_key_value_heads": 6,
31
  "pad_token_id": 0,
32
- "rms_norm_eps": 1e-05,
33
  "rope_theta": 10000.0,
34
- "torch_dtype": "float32",
35
  "transformers_version": "4.55.2",
36
  "use_cache": false,
37
  "use_flash_attention": false,
38
  "use_xformers": false,
39
- "vocab_size": 50257
40
  }
 
2
  "architectures": [
3
  "DharaForMaskedDiffusion"
4
  ],
5
+ "attention_dropout": 0.0,
6
  "auto_map": {
7
  "AutoConfig": "modeling_dhara.DharaConfig",
8
  "AutoModel": "modeling_dhara.DharaForMaskedDiffusion",
9
  "AutoModelForCausalLM": "modeling_dhara.DharaForMaskedDiffusion"
10
  },
 
11
  "bos_token_id": 1,
12
  "canon_activation": false,
13
  "canon_bias": false,
 
15
  "canon_residual": true,
16
  "canon_set": "AC",
17
  "eos_token_id": 2,
18
+ "head_dim": 48,
19
  "hidden_act": "silu",
20
  "hidden_size": 384,
21
  "initializer_range": 0.02,
22
  "intermediate_size": 1024,
23
  "mask_epsilon": 0.001,
24
  "mask_token_id": 50256,
25
+ "max_position_embeddings": 1024,
26
  "model_type": "dhara",
27
+ "num_attention_heads": 8,
28
  "num_diffusion_steps": 1000,
29
  "num_hidden_layers": 32,
30
+ "num_key_value_heads": 4,
31
  "pad_token_id": 0,
32
+ "rms_norm_eps": 1e-06,
33
  "rope_theta": 10000.0,
34
+ "torch_dtype": "bfloat16",
35
  "transformers_version": "4.55.2",
36
  "use_cache": false,
37
  "use_flash_attention": false,
38
  "use_xformers": false,
39
+ "vocab_size": 50304
40
  }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:138820db3e8e59ed037924f14f9739ca9667e406465fc236fa9765691386f5fc
3
- size 304219496
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e67749b3a03df19a3a36cbf3994997df649b5c2044cbf581a22e5eb8f473975f
3
+ size 142728304
modeling_dhara.py CHANGED
@@ -1,24 +1,29 @@
1
  #!/usr/bin/env python3
2
  """
3
- Dhara: Diffusion Language Model
4
 
5
- A diffusion-based language model that combines:
6
- 1. Masked diffusion training (MDM) with bidirectional attention
7
- 2. Canon layers for local context mixing via causal depthwise convolutions
8
- 3. High-throughput parallel token generation
9
 
10
- Usage:
11
- from transformers import AutoModel, AutoTokenizer
12
- model = AutoModel.from_pretrained("codelion/dhara-70m", trust_remote_code=True)
13
- tokenizer = AutoTokenizer.from_pretrained("codelion/dhara-70m")
 
 
 
 
14
  """
15
 
16
  import math
17
- from typing import Optional, Tuple, Union
 
18
 
19
  import torch
20
  import torch.nn as nn
21
  import torch.nn.functional as F
 
22
 
23
  from transformers import PreTrainedModel
24
  from transformers.generation import GenerationMixin
@@ -36,12 +41,18 @@ try:
36
  except ImportError:
37
  FLASH_ATTN_AVAILABLE = False
38
 
 
 
 
 
 
 
39
 
40
  class DharaConfig(PretrainedConfig):
41
  """
42
  Configuration for Dhara model.
43
 
44
- Dhara is a diffusion language model with Canon layers for local context mixing.
45
  """
46
 
47
  model_type = "dhara"
@@ -49,33 +60,33 @@ class DharaConfig(PretrainedConfig):
49
  def __init__(
50
  self,
51
  # Core architecture
52
- vocab_size: int = 50257,
53
  hidden_size: int = 384,
54
  num_hidden_layers: int = 32,
55
- num_attention_heads: int = 6,
56
- num_key_value_heads: int = 6,
57
  intermediate_size: int = 1024,
58
  head_dim: int = None,
59
  max_position_embeddings: int = 2048,
60
 
61
  # Model specifics
62
  hidden_act: str = "silu",
63
- rms_norm_eps: float = 1e-5,
64
  rope_theta: float = 10000.0,
65
  initializer_range: float = 0.02,
66
  tie_word_embeddings: bool = True,
67
  attention_dropout: float = 0.0,
68
 
69
  # Canon layer parameters
70
- canon_set: str = "AC",
71
- canon_kernel: int = 4,
72
- canon_residual: bool = True,
73
- canon_activation: bool = False,
74
  canon_bias: bool = False,
75
 
76
  # Diffusion specific
77
- mask_token_id: int = 50256,
78
- mask_epsilon: float = 0.001,
79
  num_diffusion_steps: int = 1000,
80
 
81
  # Special tokens
@@ -85,7 +96,7 @@ class DharaConfig(PretrainedConfig):
85
 
86
  # Performance flags
87
  use_cache: bool = False,
88
- use_flash_attention: bool = False,
89
  use_xformers: bool = False,
90
 
91
  **kwargs
@@ -98,6 +109,7 @@ class DharaConfig(PretrainedConfig):
98
  **kwargs
99
  )
100
 
 
101
  self.vocab_size = vocab_size
102
  self.hidden_size = hidden_size
103
  self.num_hidden_layers = num_hidden_layers
@@ -107,29 +119,44 @@ class DharaConfig(PretrainedConfig):
107
  self.head_dim = head_dim or (hidden_size // num_attention_heads)
108
  self.max_position_embeddings = max_position_embeddings
109
 
 
110
  self.hidden_act = hidden_act
111
  self.rms_norm_eps = rms_norm_eps
112
  self.rope_theta = rope_theta
113
  self.initializer_range = initializer_range
 
114
  self.attention_dropout = attention_dropout
115
 
 
116
  self.canon_set = canon_set
117
  self.canon_kernel = canon_kernel
118
  self.canon_residual = canon_residual
119
  self.canon_activation = canon_activation
120
  self.canon_bias = canon_bias
121
 
122
- self.mask_token_id = mask_token_id
 
123
  self.mask_epsilon = mask_epsilon
124
  self.num_diffusion_steps = num_diffusion_steps
125
 
 
 
 
 
 
 
126
  self.use_cache = use_cache
127
  self.use_flash_attention = use_flash_attention
128
  self.use_xformers = use_xformers
129
 
130
 
131
  class CanonLayer(nn.Module):
132
- """Causal 1D depthwise convolution for local context mixing."""
 
 
 
 
 
133
 
134
  def __init__(
135
  self,
@@ -145,29 +172,49 @@ class CanonLayer(nn.Module):
145
  self.use_residual = use_residual
146
  self.use_activation = use_activation
147
 
 
148
  self.conv = nn.Conv1d(
149
  in_channels=hidden_size,
150
  out_channels=hidden_size,
151
  kernel_size=kernel_size,
152
- padding=kernel_size - 1,
153
- groups=hidden_size,
154
  bias=use_bias,
155
  )
156
 
 
157
  nn.init.normal_(self.conv.weight, mean=0.0, std=0.02)
158
  if use_bias:
159
  nn.init.zeros_(self.conv.bias)
160
 
161
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 
 
 
 
162
  batch_size, seq_len, hidden_size = hidden_states.shape
 
 
163
  x = hidden_states.transpose(1, 2)
 
 
164
  out = self.conv(x)
 
165
  out = out[:, :, :seq_len]
 
 
166
  if self.use_activation:
167
  out = F.silu(out)
 
 
168
  out = out.transpose(1, 2)
 
 
169
  if self.use_residual:
170
  out = hidden_states + out
 
171
  return out
172
 
173
 
@@ -206,6 +253,7 @@ class RotaryEmbedding(nn.Module):
206
  def _set_cos_sin_cache(self, seq_len, device, dtype):
207
  self.max_seq_len_cached = seq_len
208
  t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
 
209
  freqs = torch.outer(t, self.inv_freq)
210
  emb = torch.cat((freqs, freqs), dim=-1)
211
  self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
@@ -214,6 +262,7 @@ class RotaryEmbedding(nn.Module):
214
  def forward(self, x, seq_len=None):
215
  if seq_len > self.max_seq_len_cached:
216
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
 
217
  return (
218
  self.cos_cached[:seq_len].to(dtype=x.dtype),
219
  self.sin_cached[:seq_len].to(dtype=x.dtype),
@@ -221,14 +270,17 @@ class RotaryEmbedding(nn.Module):
221
 
222
 
223
  def rotate_half(x):
 
224
  x1 = x[..., : x.shape[-1] // 2]
225
  x2 = x[..., x.shape[-1] // 2 :]
226
  return torch.cat((-x2, x1), dim=-1)
227
 
228
 
229
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
 
230
  cos = cos[position_ids].unsqueeze(unsqueeze_dim)
231
  sin = sin[position_ids].unsqueeze(unsqueeze_dim)
 
232
  cos = cos.to(q.dtype)
233
  sin = sin.to(q.dtype)
234
  q_embed = (q * cos) + (rotate_half(q) * sin)
@@ -241,12 +293,14 @@ class DharaMLP(nn.Module):
241
 
242
  def __init__(self, config):
243
  super().__init__()
 
244
  self.hidden_size = config.hidden_size
245
  self.intermediate_size = config.intermediate_size
246
 
247
  self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
248
  self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
249
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
 
250
  self.act_fn = nn.SiLU()
251
 
252
  def forward(self, x):
@@ -254,6 +308,7 @@ class DharaMLP(nn.Module):
254
 
255
 
256
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
257
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
258
  if n_rep == 1:
259
  return hidden_states
@@ -262,7 +317,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
262
 
263
 
264
  class DharaAttention(nn.Module):
265
- """Multi-Head Bidirectional Attention with GQA support"""
266
 
267
  def __init__(self, config: DharaConfig, layer_idx: Optional[int] = None):
268
  super().__init__()
@@ -277,7 +332,13 @@ class DharaAttention(nn.Module):
277
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
278
  self.max_position_embeddings = config.max_position_embeddings
279
  self.rope_theta = config.rope_theta
280
- self.is_causal = False # Bidirectional for diffusion
 
 
 
 
 
 
281
 
282
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
283
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
@@ -311,6 +372,12 @@ class DharaAttention(nn.Module):
311
 
312
  kv_seq_len = key_states.shape[-2]
313
  if past_key_value is not None:
 
 
 
 
 
 
314
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
315
 
316
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
@@ -323,6 +390,7 @@ class DharaAttention(nn.Module):
323
  key_states = repeat_kv(key_states, self.num_key_value_groups)
324
  value_states = repeat_kv(value_states, self.num_key_value_groups)
325
 
 
326
  if FLASH_ATTN_AVAILABLE and self.config.use_flash_attention:
327
  query_states = query_states.transpose(1, 2).contiguous()
328
  key_states = key_states.transpose(1, 2).contiguous()
@@ -334,17 +402,26 @@ class DharaAttention(nn.Module):
334
  value_states = value_states.to(torch.bfloat16)
335
 
336
  attn_output = flash_attn_func(
337
- query_states, key_states, value_states,
338
- dropout_p=0.0, causal=False,
 
 
 
339
  )
 
340
  attn_output = attn_output.view(bsz, q_len, self.hidden_size)
 
341
  else:
 
342
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
343
  if attention_mask is not None:
344
  attn_weights = attn_weights + attention_mask
 
345
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
346
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
347
  attn_output = torch.matmul(attn_weights, value_states)
 
348
  attn_output = attn_output.transpose(1, 2).contiguous()
349
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
350
 
@@ -357,15 +434,23 @@ class DharaAttention(nn.Module):
357
 
358
 
359
  class DharaDecoderLayer(nn.Module):
360
- """Dhara decoder layer with Canon layers"""
 
 
 
 
 
 
361
 
362
  def __init__(self, config: DharaConfig, layer_idx: int):
363
  super().__init__()
364
  self.hidden_size = config.hidden_size
365
  self.config = config
366
 
 
367
  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
368
 
 
369
  self.canon_a = None
370
  if "A" in config.canon_set:
371
  self.canon_a = CanonLayer(
@@ -376,9 +461,13 @@ class DharaDecoderLayer(nn.Module):
376
  use_bias=config.canon_bias,
377
  )
378
 
 
379
  self.self_attn = DharaAttention(config=config, layer_idx=layer_idx)
 
 
380
  self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
381
 
 
382
  self.canon_c = None
383
  if "C" in config.canon_set:
384
  self.canon_c = CanonLayer(
@@ -389,6 +478,7 @@ class DharaDecoderLayer(nn.Module):
389
  use_bias=config.canon_bias,
390
  )
391
 
 
392
  self.mlp = DharaMLP(config)
393
 
394
  def forward(
@@ -401,11 +491,15 @@ class DharaDecoderLayer(nn.Module):
401
  use_cache: Optional[bool] = False,
402
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
403
  residual = hidden_states
 
 
404
  hidden_states = self.input_layernorm(hidden_states)
405
 
 
406
  if self.canon_a is not None:
407
  hidden_states = self.canon_a(hidden_states)
408
 
 
409
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
410
  hidden_states=hidden_states,
411
  attention_mask=attention_mask,
@@ -416,9 +510,11 @@ class DharaDecoderLayer(nn.Module):
416
  )
417
  hidden_states = residual + hidden_states
418
 
 
419
  residual = hidden_states
420
  hidden_states = self.post_attention_layernorm(hidden_states)
421
 
 
422
  if self.canon_c is not None:
423
  hidden_states = self.canon_c(hidden_states)
424
 
@@ -426,8 +522,10 @@ class DharaDecoderLayer(nn.Module):
426
  hidden_states = residual + hidden_states
427
 
428
  outputs = (hidden_states,)
 
429
  if output_attentions:
430
  outputs += (self_attn_weights,)
 
431
  if use_cache:
432
  outputs += (present_key_value,)
433
 
@@ -456,7 +554,9 @@ class DharaPreTrainedModel(PreTrainedModel):
456
 
457
 
458
  class DharaModel(DharaPreTrainedModel):
459
- """Dhara base model with bidirectional attention and Canon layers."""
 
 
460
 
461
  def __init__(self, config: DharaConfig):
462
  super().__init__(config)
@@ -467,6 +567,7 @@ class DharaModel(DharaPreTrainedModel):
467
  self.layers = nn.ModuleList(
468
  [DharaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
469
  )
 
470
  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
471
  self.gradient_checkpointing = False
472
 
@@ -495,12 +596,14 @@ class DharaModel(DharaPreTrainedModel):
495
  return_dict: Optional[bool] = None,
496
  ) -> Union[Tuple, BaseModelOutputWithPast]:
497
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
498
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
499
  use_cache = use_cache if use_cache is not None else self.config.use_cache
500
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
501
 
502
  if input_ids is not None and inputs_embeds is not None:
503
- raise ValueError("You cannot specify both input_ids and inputs_embeds")
504
  elif input_ids is not None:
505
  batch_size, seq_length = input_ids.shape[:2]
506
  elif inputs_embeds is not None:
@@ -508,8 +611,12 @@ class DharaModel(DharaPreTrainedModel):
508
  else:
509
  raise ValueError("You have to specify either input_ids or inputs_embeds")
510
 
511
- if self.gradient_checkpointing and self.training and use_cache:
512
- use_cache = False
 
 
 
 
513
 
514
  past_key_values_length = 0
515
  if use_cache:
@@ -531,8 +638,10 @@ class DharaModel(DharaPreTrainedModel):
531
  if self._use_flash_attention_2:
532
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
533
  else:
 
534
  if attention_mask is not None:
535
  if attention_mask.dim() == 2:
 
536
  attention_mask_4d = attention_mask[:, None, None, :].expand(
537
  batch_size, 1, seq_length, seq_length
538
  ).to(dtype=inputs_embeds.dtype)
@@ -541,8 +650,13 @@ class DharaModel(DharaPreTrainedModel):
541
  torch.tensor(float('-inf'), dtype=inputs_embeds.dtype, device=attention_mask_4d.device),
542
  torch.tensor(0.0, dtype=inputs_embeds.dtype, device=attention_mask_4d.device)
543
  )
 
 
 
 
544
 
545
  hidden_states = inputs_embeds
 
546
  all_hidden_states = () if output_hidden_states else None
547
  all_self_attns = () if output_attentions else None
548
  next_decoder_cache = None
@@ -554,8 +668,12 @@ class DharaModel(DharaPreTrainedModel):
554
  if self.gradient_checkpointing and self.training:
555
  layer_outputs = self._gradient_checkpointing_func(
556
  decoder_layer.__call__,
557
- hidden_states, attention_mask, position_ids,
558
- past_key_values, output_attentions, use_cache,
 
 
 
 
559
  )
560
  else:
561
  layer_outputs = decoder_layer(
@@ -571,6 +689,7 @@ class DharaModel(DharaPreTrainedModel):
571
 
572
  if use_cache:
573
  next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 
574
  if output_attentions:
575
  all_self_attns += (layer_outputs[1],)
576
 
@@ -594,23 +713,36 @@ class DharaModel(DharaPreTrainedModel):
594
  )
595
 
596
  def add_noise_to_tokens(self, input_ids: torch.LongTensor, t: torch.FloatTensor, eps: float = None):
597
- """MDM-style masking: Replace tokens with [MASK] based on noise level t."""
 
 
 
 
 
 
 
 
 
 
598
  batch_size, seq_len = input_ids.shape
599
  device = input_ids.device
600
 
601
  if eps is None:
602
  eps = getattr(self.config, 'mask_epsilon', 0.001)
603
  p_mask = (1 - eps) * t + eps
 
604
  p_mask = p_mask.unsqueeze(-1).expand(batch_size, seq_len)
605
 
606
  corruption_mask = torch.rand(batch_size, seq_len, device=device) < p_mask
607
- noisy_input_ids = torch.where(corruption_mask, self.mask_token_id, input_ids)
 
 
608
 
609
  return noisy_input_ids, corruption_mask, p_mask
610
 
611
 
612
  class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
613
- """Dhara Model with Masked Diffusion head for training and inference"""
614
  _tied_weights_keys = ["lm_head.weight"]
615
 
616
  def __init__(self, config):
@@ -636,6 +768,9 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
636
  def set_output_embeddings(self, new_embeddings):
637
  self.lm_head = new_embeddings
638
 
 
 
 
639
  def get_decoder(self):
640
  return self.model
641
 
@@ -655,7 +790,9 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
655
  p_mask: Optional[torch.Tensor] = None,
656
  ) -> Union[Tuple, MaskedLMOutput]:
657
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
658
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
 
659
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
660
 
661
  outputs = self.model(
@@ -693,9 +830,13 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
693
  )
694
 
695
  def compute_diffusion_loss(self, logits, labels, corruption_mask=None, p_mask=None):
696
- """MDM loss with p_mask importance weighting."""
 
 
697
  if corruption_mask is None or p_mask is None:
698
- raise ValueError("MDM requires both corruption_mask and p_mask for loss computation.")
 
 
699
 
700
  loss = F.cross_entropy(
701
  logits.view(-1, self.config.vocab_size),
@@ -706,6 +847,7 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
706
 
707
  masked_losses = loss[corruption_mask]
708
  masked_p_mask = p_mask[corruption_mask]
 
709
  weighted_losses = masked_losses / masked_p_mask
710
 
711
  total_positions = labels.shape[0] * labels.shape[1]
@@ -728,11 +870,15 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
728
  max_cache_length = None
729
 
730
  if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
731
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
732
  elif past_length < input_ids.shape[1]:
733
  input_ids = input_ids[:, past_length:]
734
 
735
- if max_cache_length is not None and attention_mask is not None and cache_length + input_ids.shape[1] > max_cache_length:
 
 
 
 
736
  attention_mask = attention_mask[:, -max_cache_length:]
737
 
738
  position_ids = kwargs.get("position_ids", None)
@@ -740,21 +886,32 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
740
  position_ids = attention_mask.long().cumsum(-1) - 1
741
  position_ids.masked_fill_(attention_mask == 0, 1)
742
  if past_key_values:
743
- position_ids = position_ids[:, -input_ids.shape[1]:]
744
 
745
  if inputs_embeds is not None and past_key_values is None:
746
  model_inputs = {"inputs_embeds": inputs_embeds}
747
  else:
748
  model_inputs = {"input_ids": input_ids}
749
 
750
- model_inputs.update({
751
- "position_ids": position_ids,
752
- "past_key_values": past_key_values,
753
- "use_cache": kwargs.get("use_cache"),
754
- "attention_mask": attention_mask,
755
- })
 
 
756
  return model_inputs
757
 
 
 
 
 
 
 
 
 
 
758
  @torch.no_grad()
759
  def generate(
760
  self,
@@ -764,27 +921,32 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
764
  num_diffusion_steps: int = 10,
765
  temperature: float = 1.0,
766
  top_p: float = 0.9,
 
767
  do_sample: bool = True,
768
  pad_token_id: Optional[int] = None,
769
  eos_token_id: Optional[int] = None,
 
770
  **kwargs
771
  ) -> torch.LongTensor:
772
  """
773
- Generate text using masked diffusion sampling.
774
 
775
- This method performs iterative denoising: starting from fully masked tokens,
776
- it progressively unmasks positions based on model confidence.
 
777
 
778
  Args:
779
  input_ids: Input prompt token IDs [batch_size, prompt_len]
780
  max_length: Maximum total sequence length (prompt + generation)
781
  max_new_tokens: Number of new tokens to generate (alternative to max_length)
782
- num_diffusion_steps: Number of denoising iterations (more = higher quality, slower)
783
  temperature: Sampling temperature (higher = more random)
784
  top_p: Nucleus sampling threshold
 
785
  do_sample: Whether to sample or take argmax
786
  pad_token_id: Token ID for padding
787
  eos_token_id: Token ID for end of sequence
 
788
 
789
  Returns:
790
  Generated token IDs including the prompt
@@ -816,97 +978,134 @@ class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
816
  if eos_token_id is None:
817
  eos_token_id = self.config.eos_token_id if hasattr(self.config, 'eos_token_id') else 2
818
 
819
- # Initialize: prompt + masked tokens for generation
820
- total_len = prompt_len + gen_len
821
- tokens = torch.full((batch_size, total_len), mask_token_id, dtype=torch.long, device=device)
822
- tokens[:, :prompt_len] = input_ids
823
-
824
- # Track which positions are masked (need generation)
825
- is_masked = torch.ones(batch_size, total_len, dtype=torch.bool, device=device)
826
- is_masked[:, :prompt_len] = False # Prompt is not masked
827
-
828
- # Number of tokens to unmask per step
829
- tokens_per_step = max(1, gen_len // num_diffusion_steps)
830
 
831
- # Iterative denoising
832
- for step in range(num_diffusion_steps):
833
- # Forward pass to get logits
834
- outputs = self(input_ids=tokens)
 
 
 
 
 
 
 
 
 
 
 
 
835
  logits = outputs.logits # [batch, seq_len, vocab]
836
 
837
- # Only consider masked positions
838
- masked_positions = is_masked.clone()
839
 
840
- if not masked_positions.any():
841
- break # All tokens have been generated
 
 
 
 
842
 
843
  # Apply temperature
844
- if temperature != 1.0:
845
- logits = logits / temperature
846
-
847
- # Get probabilities
848
- probs = F.softmax(logits, dim=-1)
849
-
850
- # Calculate confidence (max prob) for each position
851
- confidence, _ = probs.max(dim=-1) # [batch, seq_len]
852
-
853
- # Mask out already-generated positions from confidence calculation
854
- confidence = confidence.masked_fill(~masked_positions, -float('inf'))
855
-
856
- # Determine how many tokens to unmask this step
857
- remaining_masked = masked_positions.sum(dim=1) # [batch]
858
-
859
- # For the last step, unmask everything remaining
860
- if step == num_diffusion_steps - 1:
861
- num_to_unmask = remaining_masked
 
 
 
 
 
 
 
 
 
862
  else:
863
- num_to_unmask = torch.minimum(
864
- torch.tensor(tokens_per_step, device=device).expand(batch_size),
865
- remaining_masked
866
- )
867
-
868
- # For each batch item, unmask the highest confidence positions
869
- for b in range(batch_size):
870
- if num_to_unmask[b] == 0:
871
- continue
872
-
873
- # Get confidence scores for this batch item
874
- conf_b = confidence[b] # [seq_len]
875
-
876
- # Get top-k positions with highest confidence
877
- k = int(num_to_unmask[b].item())
878
- _, top_indices = conf_b.topk(k)
879
-
880
- # Sample or argmax for these positions
881
- for idx in top_indices:
882
- pos_logits = logits[b, idx] # [vocab]
883
-
884
- if do_sample and temperature > 0:
885
- # Top-p (nucleus) sampling
886
- sorted_logits, sorted_indices = torch.sort(pos_logits, descending=True)
887
- sorted_probs = F.softmax(sorted_logits, dim=-1)
888
- cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
889
-
890
- # Remove tokens with cumulative probability above top_p
891
- sorted_indices_to_remove = cumsum_probs > top_p
892
- sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
893
- sorted_indices_to_remove[0] = False
894
 
895
- sorted_logits[sorted_indices_to_remove] = float('-inf')
896
- probs_filtered = F.softmax(sorted_logits, dim=-1)
897
 
898
- # Sample
899
- sampled_idx = torch.multinomial(probs_filtered, 1)
900
- token_id = sorted_indices[sampled_idx]
901
- else:
902
- # Greedy (argmax)
903
- token_id = pos_logits.argmax()
904
 
905
- tokens[b, idx] = token_id
906
- is_masked[b, idx] = False
 
907
 
908
- return tokens
909
 
910
  def save_pretrained(self, save_directory, **kwargs):
 
911
  kwargs['safe_serialization'] = kwargs.get('safe_serialization', True)
912
  return super().save_pretrained(save_directory, **kwargs)
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ Dhara: Diffusion LLM with Canon Layers
4
 
5
+ Combines:
6
+ 1. Dhara's masked diffusion training (bidirectional attention, high throughput)
7
+ 2. Canon layers (local context mixing via causal depthwise convolutions)
 
8
 
9
+ Canon layers from "Physics of Language Models: Part 4.1" by Zeyuan Allen-Zhu:
10
+ - Position A: After input LayerNorm, before attention
11
+ - Position C: After post-attention LayerNorm, before MLP
12
+ - kernel_size=4, residual=True, activation=False (default)
13
+
14
+ Expected benefits:
15
+ - ~280-290 tok/s throughput (Dhara's parallel generation)
16
+ - +0.25-0.5% accuracy improvement (Canon's local context mixing)
17
  """
18
 
19
  import math
20
+ import warnings
21
+ from typing import Optional, Tuple, Union, List
22
 
23
  import torch
24
  import torch.nn as nn
25
  import torch.nn.functional as F
26
+ from torch.nn import CrossEntropyLoss
27
 
28
  from transformers import PreTrainedModel
29
  from transformers.generation import GenerationMixin
 
41
  except ImportError:
42
  FLASH_ATTN_AVAILABLE = False
43
 
44
+ try:
45
+ import xformers.ops as xops
46
+ XFORMERS_AVAILABLE = True
47
+ except ImportError:
48
+ XFORMERS_AVAILABLE = False
49
+
50
 
51
  class DharaConfig(PretrainedConfig):
52
  """
53
  Configuration for Dhara model.
54
 
55
+ Combines Dhara diffusion config with Canon layer parameters.
56
  """
57
 
58
  model_type = "dhara"
 
60
  def __init__(
61
  self,
62
  # Core architecture
63
+ vocab_size: int = 50304,
64
  hidden_size: int = 384,
65
  num_hidden_layers: int = 32,
66
+ num_attention_heads: int = 8,
67
+ num_key_value_heads: int = 4,
68
  intermediate_size: int = 1024,
69
  head_dim: int = None,
70
  max_position_embeddings: int = 2048,
71
 
72
  # Model specifics
73
  hidden_act: str = "silu",
74
+ rms_norm_eps: float = 1e-6,
75
  rope_theta: float = 10000.0,
76
  initializer_range: float = 0.02,
77
  tie_word_embeddings: bool = True,
78
  attention_dropout: float = 0.0,
79
 
80
  # Canon layer parameters
81
+ canon_set: str = "AC", # Positions: A (before attn), C (before MLP)
82
+ canon_kernel: int = 4, # Kernel size (2-4)
83
+ canon_residual: bool = True, # Highly recommended
84
+ canon_activation: bool = False, # NOT recommended for transformers
85
  canon_bias: bool = False,
86
 
87
  # Diffusion specific
88
+ mask_token_id: int = None, # Will be set from tokenizer
89
+ mask_epsilon: float = 0.001, # Minimum mask probability
90
  num_diffusion_steps: int = 1000,
91
 
92
  # Special tokens
 
96
 
97
  # Performance flags
98
  use_cache: bool = False,
99
+ use_flash_attention: bool = True,
100
  use_xformers: bool = False,
101
 
102
  **kwargs
 
109
  **kwargs
110
  )
111
 
112
+ # Core architecture
113
  self.vocab_size = vocab_size
114
  self.hidden_size = hidden_size
115
  self.num_hidden_layers = num_hidden_layers
 
119
  self.head_dim = head_dim or (hidden_size // num_attention_heads)
120
  self.max_position_embeddings = max_position_embeddings
121
 
122
+ # Model specifics
123
  self.hidden_act = hidden_act
124
  self.rms_norm_eps = rms_norm_eps
125
  self.rope_theta = rope_theta
126
  self.initializer_range = initializer_range
127
+ self.tie_word_embeddings = tie_word_embeddings
128
  self.attention_dropout = attention_dropout
129
 
130
+ # Canon parameters
131
  self.canon_set = canon_set
132
  self.canon_kernel = canon_kernel
133
  self.canon_residual = canon_residual
134
  self.canon_activation = canon_activation
135
  self.canon_bias = canon_bias
136
 
137
+ # Diffusion specific
138
+ self.mask_token_id = mask_token_id if mask_token_id is not None else (vocab_size - 1)
139
  self.mask_epsilon = mask_epsilon
140
  self.num_diffusion_steps = num_diffusion_steps
141
 
142
+ # Special tokens
143
+ self.bos_token_id = bos_token_id
144
+ self.eos_token_id = eos_token_id
145
+ self.pad_token_id = pad_token_id
146
+
147
+ # Performance
148
  self.use_cache = use_cache
149
  self.use_flash_attention = use_flash_attention
150
  self.use_xformers = use_xformers
151
 
152
 
153
  class CanonLayer(nn.Module):
154
+ """
155
+ Canon Layer: Causal 1D depthwise convolution for local context mixing.
156
+
157
+ From "Physics of Language Models: Part 4.1" by Zeyuan Allen-Zhu.
158
+ Captures local sequential dependencies with O(n) complexity.
159
+ """
160
 
161
  def __init__(
162
  self,
 
172
  self.use_residual = use_residual
173
  self.use_activation = use_activation
174
 
175
+ # Depthwise causal convolution
176
  self.conv = nn.Conv1d(
177
  in_channels=hidden_size,
178
  out_channels=hidden_size,
179
  kernel_size=kernel_size,
180
+ padding=kernel_size - 1, # Causal (left-pad)
181
+ groups=hidden_size, # Depthwise
182
  bias=use_bias,
183
  )
184
 
185
+ # Initialize for stability
186
  nn.init.normal_(self.conv.weight, mean=0.0, std=0.02)
187
  if use_bias:
188
  nn.init.zeros_(self.conv.bias)
189
 
190
  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
191
+ """
192
+ Args:
193
+ hidden_states: [batch_size, seq_len, hidden_size]
194
+ Returns:
195
+ output: [batch_size, seq_len, hidden_size]
196
+ """
197
  batch_size, seq_len, hidden_size = hidden_states.shape
198
+
199
+ # Transpose for Conv1d: [B, H, L]
200
  x = hidden_states.transpose(1, 2)
201
+
202
+ # Apply conv with causal padding
203
  out = self.conv(x)
204
+ # Remove right padding to make it causal
205
  out = out[:, :, :seq_len]
206
+
207
+ # Optional activation
208
  if self.use_activation:
209
  out = F.silu(out)
210
+
211
+ # Transpose back: [B, L, H]
212
  out = out.transpose(1, 2)
213
+
214
+ # Residual connection
215
  if self.use_residual:
216
  out = hidden_states + out
217
+
218
  return out
219
 
220
 
 
253
  def _set_cos_sin_cache(self, seq_len, device, dtype):
254
  self.max_seq_len_cached = seq_len
255
  t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
256
+
257
  freqs = torch.outer(t, self.inv_freq)
258
  emb = torch.cat((freqs, freqs), dim=-1)
259
  self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
 
262
  def forward(self, x, seq_len=None):
263
  if seq_len > self.max_seq_len_cached:
264
  self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
265
+
266
  return (
267
  self.cos_cached[:seq_len].to(dtype=x.dtype),
268
  self.sin_cached[:seq_len].to(dtype=x.dtype),
 
270
 
271
 
272
  def rotate_half(x):
273
+ """Rotates half the hidden dims of the input."""
274
  x1 = x[..., : x.shape[-1] // 2]
275
  x2 = x[..., x.shape[-1] // 2 :]
276
  return torch.cat((-x2, x1), dim=-1)
277
 
278
 
279
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
280
+ """Applies Rotary Position Embedding to query and key tensors."""
281
  cos = cos[position_ids].unsqueeze(unsqueeze_dim)
282
  sin = sin[position_ids].unsqueeze(unsqueeze_dim)
283
+ # Cast to input dtype for consistency
284
  cos = cos.to(q.dtype)
285
  sin = sin.to(q.dtype)
286
  q_embed = (q * cos) + (rotate_half(q) * sin)
 
293
 
294
  def __init__(self, config):
295
  super().__init__()
296
+ self.config = config
297
  self.hidden_size = config.hidden_size
298
  self.intermediate_size = config.intermediate_size
299
 
300
  self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
301
  self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
302
  self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
303
+
304
  self.act_fn = nn.SiLU()
305
 
306
  def forward(self, x):
 
308
 
309
 
310
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
311
+ """Repeat KV heads for GQA."""
312
  batch, num_key_value_heads, slen, head_dim = hidden_states.shape
313
  if n_rep == 1:
314
  return hidden_states
 
317
 
318
 
319
  class DharaAttention(nn.Module):
320
+ """Multi-Head Bidirectional Attention with GQA support (for diffusion)"""
321
 
322
  def __init__(self, config: DharaConfig, layer_idx: Optional[int] = None):
323
  super().__init__()
 
332
  self.num_key_value_groups = self.num_heads // self.num_key_value_heads
333
  self.max_position_embeddings = config.max_position_embeddings
334
  self.rope_theta = config.rope_theta
335
+ self.is_causal = False # CRITICAL: Dhara uses bidirectional attention
336
+
337
+ if (self.head_dim * self.num_heads) != self.hidden_size:
338
+ raise ValueError(
339
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
340
+ f" and `num_heads`: {self.num_heads})."
341
+ )
342
 
343
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
344
  self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
 
372
 
373
  kv_seq_len = key_states.shape[-2]
374
  if past_key_value is not None:
375
+ if self.layer_idx is None:
376
+ raise ValueError(
377
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
378
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
379
+ "with a layer index."
380
+ )
381
  kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
382
 
383
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
390
  key_states = repeat_kv(key_states, self.num_key_value_groups)
391
  value_states = repeat_kv(value_states, self.num_key_value_groups)
392
 
393
+ # Flash Attention for bidirectional
394
  if FLASH_ATTN_AVAILABLE and self.config.use_flash_attention:
395
  query_states = query_states.transpose(1, 2).contiguous()
396
  key_states = key_states.transpose(1, 2).contiguous()
 
402
  value_states = value_states.to(torch.bfloat16)
403
 
404
  attn_output = flash_attn_func(
405
+ query_states,
406
+ key_states,
407
+ value_states,
408
+ dropout_p=0.0,
409
+ causal=False, # Bidirectional for diffusion
410
  )
411
+
412
  attn_output = attn_output.view(bsz, q_len, self.hidden_size)
413
+
414
  else:
415
+ # Standard attention
416
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
417
+
418
  if attention_mask is not None:
419
  attn_weights = attn_weights + attention_mask
420
+
421
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
422
  attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
423
  attn_output = torch.matmul(attn_weights, value_states)
424
+
425
  attn_output = attn_output.transpose(1, 2).contiguous()
426
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
427
 
 
434
 
435
 
436
  class DharaDecoderLayer(nn.Module):
437
+ """
438
+ Dhara decoder layer with Canon layers at positions A and C.
439
+
440
+ Flow:
441
+ x -> LayerNorm -> [CanonA] -> Attention -> + residual
442
+ x -> LayerNorm -> [CanonC] -> MLP -> + residual
443
+ """
444
 
445
  def __init__(self, config: DharaConfig, layer_idx: int):
446
  super().__init__()
447
  self.hidden_size = config.hidden_size
448
  self.config = config
449
 
450
+ # Pre-attention norm
451
  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
452
 
453
+ # Canon-A: before attention
454
  self.canon_a = None
455
  if "A" in config.canon_set:
456
  self.canon_a = CanonLayer(
 
461
  use_bias=config.canon_bias,
462
  )
463
 
464
+ # Attention
465
  self.self_attn = DharaAttention(config=config, layer_idx=layer_idx)
466
+
467
+ # Post-attention norm
468
  self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
469
 
470
+ # Canon-C: before MLP
471
  self.canon_c = None
472
  if "C" in config.canon_set:
473
  self.canon_c = CanonLayer(
 
478
  use_bias=config.canon_bias,
479
  )
480
 
481
+ # MLP
482
  self.mlp = DharaMLP(config)
483
 
484
  def forward(
 
491
  use_cache: Optional[bool] = False,
492
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
493
  residual = hidden_states
494
+
495
+ # Pre-attention layernorm
496
  hidden_states = self.input_layernorm(hidden_states)
497
 
498
+ # Canon-A (before attention)
499
  if self.canon_a is not None:
500
  hidden_states = self.canon_a(hidden_states)
501
 
502
+ # Self Attention (bidirectional)
503
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
504
  hidden_states=hidden_states,
505
  attention_mask=attention_mask,
 
510
  )
511
  hidden_states = residual + hidden_states
512
 
513
+ # MLP block
514
  residual = hidden_states
515
  hidden_states = self.post_attention_layernorm(hidden_states)
516
 
517
+ # Canon-C (before MLP)
518
  if self.canon_c is not None:
519
  hidden_states = self.canon_c(hidden_states)
520
 
 
522
  hidden_states = residual + hidden_states
523
 
524
  outputs = (hidden_states,)
525
+
526
  if output_attentions:
527
  outputs += (self_attn_weights,)
528
+
529
  if use_cache:
530
  outputs += (present_key_value,)
531
 
 
554
 
555
 
556
  class DharaModel(DharaPreTrainedModel):
557
+ """
558
+ Dhara base model with bidirectional attention and Canon layers.
559
+ """
560
 
561
  def __init__(self, config: DharaConfig):
562
  super().__init__(config)
 
567
  self.layers = nn.ModuleList(
568
  [DharaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
569
  )
570
+
571
  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
572
  self.gradient_checkpointing = False
573
 
 
596
  return_dict: Optional[bool] = None,
597
  ) -> Union[Tuple, BaseModelOutputWithPast]:
598
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
599
+ output_hidden_states = (
600
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
601
+ )
602
  use_cache = use_cache if use_cache is not None else self.config.use_cache
603
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
604
 
605
  if input_ids is not None and inputs_embeds is not None:
606
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
607
  elif input_ids is not None:
608
  batch_size, seq_length = input_ids.shape[:2]
609
  elif inputs_embeds is not None:
 
611
  else:
612
  raise ValueError("You have to specify either input_ids or inputs_embeds")
613
 
614
+ if self.gradient_checkpointing and self.training:
615
+ if use_cache:
616
+ logger.warning_once(
617
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
618
+ )
619
+ use_cache = False
620
 
621
  past_key_values_length = 0
622
  if use_cache:
 
638
  if self._use_flash_attention_2:
639
  attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
640
  else:
641
+ # Bidirectional attention mask (not causal)
642
  if attention_mask is not None:
643
  if attention_mask.dim() == 2:
644
+ batch_size, seq_length = attention_mask.shape
645
  attention_mask_4d = attention_mask[:, None, None, :].expand(
646
  batch_size, 1, seq_length, seq_length
647
  ).to(dtype=inputs_embeds.dtype)
 
650
  torch.tensor(float('-inf'), dtype=inputs_embeds.dtype, device=attention_mask_4d.device),
651
  torch.tensor(0.0, dtype=inputs_embeds.dtype, device=attention_mask_4d.device)
652
  )
653
+ else:
654
+ attention_mask = attention_mask
655
+ else:
656
+ attention_mask = None
657
 
658
  hidden_states = inputs_embeds
659
+
660
  all_hidden_states = () if output_hidden_states else None
661
  all_self_attns = () if output_attentions else None
662
  next_decoder_cache = None
 
668
  if self.gradient_checkpointing and self.training:
669
  layer_outputs = self._gradient_checkpointing_func(
670
  decoder_layer.__call__,
671
+ hidden_states,
672
+ attention_mask,
673
+ position_ids,
674
+ past_key_values,
675
+ output_attentions,
676
+ use_cache,
677
  )
678
  else:
679
  layer_outputs = decoder_layer(
 

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
            if output_attentions:
                all_self_attns += (layer_outputs[1],)
 
        )

    def add_noise_to_tokens(self, input_ids: torch.LongTensor, t: torch.FloatTensor, eps: float = None):
+        """
+        MDM-style masking: Replace tokens with [MASK] based on noise level t.
+
+        Args:
+            input_ids: Input token IDs [batch_size, seq_len]
+            t: Noise level in [0, 1] [batch_size]
+            eps: Minimum mask probability
+
+        Returns:
+            Tuple of (noisy_input_ids, corruption_mask, p_mask)
+        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        if eps is None:
            eps = getattr(self.config, 'mask_epsilon', 0.001)
        p_mask = (1 - eps) * t + eps
+
        p_mask = p_mask.unsqueeze(-1).expand(batch_size, seq_len)

        corruption_mask = torch.rand(batch_size, seq_len, device=device) < p_mask
+
+        mask_token_id = self.mask_token_id
+        noisy_input_ids = torch.where(corruption_mask, mask_token_id, input_ids)

        return noisy_input_ids, corruption_mask, p_mask
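For intuition, the schedule above interpolates the per-token masking probability linearly between eps and 1, so t directly controls how much of the sequence is corrupted. A minimal standalone sketch with made-up noise levels:

import torch

t = torch.tensor([0.25, 0.75])   # per-example noise levels
eps = 0.001
p_mask = (1 - eps) * t + eps     # ~[0.251, 0.750]
# Roughly 25% of the first sequence and 75% of the second are replaced by [MASK];
# corruption_mask records exactly which positions were corrupted.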
 

class DharaForMaskedDiffusion(DharaPreTrainedModel, GenerationMixin):
+    """Dhara Model with Masked Diffusion head for training"""
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
 
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

+    def set_decoder(self, decoder):
+        self.model = decoder
+
    def get_decoder(self):
        return self.model


        p_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
 
        )

    def compute_diffusion_loss(self, logits, labels, corruption_mask=None, p_mask=None):
+        """
+        MDM loss with p_mask importance weighting.
+        """
        if corruption_mask is None or p_mask is None:
+            raise ValueError(
+                "MDM requires both corruption_mask and p_mask for loss computation."
+            )

        loss = F.cross_entropy(
            logits.view(-1, self.config.vocab_size),

        masked_losses = loss[corruption_mask]
        masked_p_mask = p_mask[corruption_mask]
+
+        weighted_losses = masked_losses / masked_p_mask

        total_positions = labels.shape[0] * labels.shape[1]
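The weighting above is the usual masked-diffusion estimator: each masked position's cross-entropy is divided by its masking probability, then normalized over all positions. A toy numeric sketch, assuming the method finishes by summing the weighted terms and dividing by total_positions (that final line is outside this hunk):

import torch

ce = torch.tensor([2.0, 1.0])    # cross-entropy at two masked positions
p = torch.tensor([0.50, 0.25])   # their masking probabilities
weighted = ce / p                # [4.0, 4.0]
loss = weighted.sum() / 8        # with batch_size * seq_len = 2 * 4, gives 1.0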
 
                max_cache_length = None

            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
 
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
        return model_inputs

+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past
+
    @torch.no_grad()
    def generate(
        self,
 
        num_diffusion_steps: int = 10,
        temperature: float = 1.0,
        top_p: float = 0.9,
+        top_k: int = 50,
        do_sample: bool = True,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
+        repetition_penalty: float = 1.2,
        **kwargs
    ) -> torch.LongTensor:
        """
+        Generate text using autoregressive sampling with the diffusion model.

+        Since this model was converted from AR to diffusion via WSD training,
+        we generate tokens one at a time left-to-right, using the model's
+        next-token predictions at each position.

        Args:
            input_ids: Input prompt token IDs [batch_size, prompt_len]
            max_length: Maximum total sequence length (prompt + generation)
            max_new_tokens: Number of new tokens to generate (alternative to max_length)
+            num_diffusion_steps: Number of refinement iterations per token (higher = better quality)
            temperature: Sampling temperature (higher = more random)
            top_p: Nucleus sampling threshold
+            top_k: Top-k sampling threshold
            do_sample: Whether to sample or take argmax
            pad_token_id: Token ID for padding
            eos_token_id: Token ID for end of sequence
+            repetition_penalty: Penalty for repeating tokens (>1 = less repetition)

        Returns:
            Generated token IDs including the prompt
 
        if eos_token_id is None:
            eos_token_id = self.config.eos_token_id if hasattr(self.config, 'eos_token_id') else 2

+        # Start with the prompt
+        generated = input_ids.clone()

+        # Track generated tokens for repetition penalty
+        generated_set = set()
+        for i in range(prompt_len):
+            for b in range(batch_size):
+                generated_set.add(input_ids[b, i].item())
+
+        # Generate tokens one at a time (autoregressive style)
+        for pos in range(gen_len):
+            # Add a mask token at the next position
+            current_seq = torch.cat([
+                generated,
+                torch.full((batch_size, 1), mask_token_id, dtype=torch.long, device=device)
+            ], dim=1)
+
+            # Get model predictions
+            outputs = self(input_ids=current_seq)
            logits = outputs.logits  # [batch, seq_len, vocab]

+            # Get logits for the last (masked) position
+            next_token_logits = logits[:, -1, :]  # [batch, vocab]

+            # Apply repetition penalty
+            if repetition_penalty != 1.0:
+                for b in range(batch_size):
+                    for prev_token in generated_set:
+                        if prev_token < next_token_logits.shape[1]:
+                            next_token_logits[b, prev_token] /= repetition_penalty

            # Apply temperature
+            if temperature != 1.0 and temperature > 0:
+                next_token_logits = next_token_logits / temperature
+
+            if do_sample and temperature > 0:
+                # Apply top-k filtering
+                if top_k > 0:
+                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
+                    next_token_logits[indices_to_remove] = float('-inf')
+
+                # Apply top-p (nucleus) filtering
+                if top_p < 1.0:
+                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+                    # Remove tokens with cumulative probability above threshold
+                    sorted_indices_to_remove = cumulative_probs > top_p
+                    # Shift the indices to the right to keep the first token above threshold
+                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                    sorted_indices_to_remove[..., 0] = False
+
+                    # Scatter sorted indices to original indexing
+                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                    next_token_logits[indices_to_remove] = float('-inf')
+
+                # Sample from the filtered distribution
+                probs = F.softmax(next_token_logits, dim=-1)
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
            else:
+                # Greedy decoding
+                next_tokens = next_token_logits.argmax(dim=-1)

+            # Add to generated sequence
+            generated = torch.cat([generated, next_tokens.unsqueeze(-1)], dim=1)

+            # Update generated set for repetition penalty
+            for b in range(batch_size):
+                generated_set.add(next_tokens[b].item())

+            # Check for EOS
+            if eos_token_id is not None and (next_tokens == eos_token_id).all():
+                break

+        return generated
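To make the nucleus step concrete, here is a small standalone walk-through with made-up logits; it mirrors the filtering in the sampling branch above (the logits are already sorted descending, so the sort/scatter steps are omitted):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
probs = F.softmax(logits, dim=-1)            # ~[0.61, 0.22, 0.14, 0.03]
cum = torch.cumsum(probs, dim=-1)            # ~[0.61, 0.83, 0.97, 1.00]
remove = cum > 0.9                           # [False, False, True, True]
remove[..., 1:] = remove[..., :-1].clone()   # shift right: keep the first token past 0.9
remove[..., 0] = False                       # -> [False, False, False, True]; only the 3% tail token is dropped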

    def save_pretrained(self, save_directory, **kwargs):
+        """Override to save in SafeTensors format by default"""
        kwargs['safe_serialization'] = kwargs.get('safe_serialization', True)
        return super().save_pretrained(save_directory, **kwargs)
+
+
+def count_parameters(model):
+    """Count total and Canon-specific parameters."""
+    total = sum(p.numel() for p in model.parameters())
+    canon = sum(p.numel() for n, p in model.named_parameters() if 'canon' in n.lower())
+    return total, canon
+
+
+if __name__ == "__main__":
+    # Quick test
+    print("Testing Dhara model creation...")
+
+    config = DharaConfig(
+        vocab_size=50304,
+        hidden_size=384,
+        num_hidden_layers=32,
+        num_attention_heads=8,
+        num_key_value_heads=4,
+        intermediate_size=1024,
+        canon_set="AC",
+        canon_kernel=4,
+        canon_residual=True,
+    )
+
+    model = DharaForMaskedDiffusion(config)
+
+    total, canon = count_parameters(model)
+    print(f"Model created successfully!")
+    print(f"Total params: {total:,} ({total/1e6:.2f}M)")
+    print(f"Canon params: {canon:,} ({100*canon/total:.3f}%)")
+    print(f"Base Dhara would be: {total - canon:,}")
+
+    # Test forward pass
+    batch_size, seq_len = 2, 64
+    input_ids = torch.randint(0, 50304, (batch_size, seq_len))
+
+    # Test with diffusion noise
+    t = torch.rand(batch_size)
+    noisy_ids, corruption_mask, p_mask = model.add_noise_to_tokens(input_ids, t)
+
+    with torch.no_grad():
+        outputs = model(
+            input_ids=noisy_ids,
+            labels=input_ids,
+            corruption_mask=corruption_mask,
+            p_mask=p_mask,
+        )
+
+    print(f"Forward pass: loss={outputs.loss.item():.4f}")
+    print("Ready for training!")
tokenizer.json CHANGED
@@ -1,21 +1,7 @@
{
  "version": "1.0",
-  "truncation": {
-    "direction": "Right",
-    "max_length": 2048,
-    "strategy": "LongestFirst",
-    "stride": 0
-  },
-  "padding": {
-    "strategy": {
-      "Fixed": 2048
-    },
-    "direction": "Right",
-    "pad_to_multiple_of": null,
-    "pad_id": 50256,
-    "pad_type_id": 0,
-    "pad_token": "<|endoftext|>"
-  },
+  "truncation": null,
+  "padding": null,
  "added_tokens": [
    {
      "id": 50256,