Update modeling_sdlm.py
modeling_sdlm.py  CHANGED  +100 -75
@@ -82,7 +82,8 @@ from .attn_mask_utils import (
     update_causal_mask_for_one_gen_window_2d,
     create_block_diff_mask_by_pe_1d,
     create_block_diff_mask_by_pe_4d,
-    find_pred_pos_from_input_ids
+    find_pred_pos_from_input_ids,
+    update_causal_mask_with_pad_non_visible_2d_for_ssd_cache
 )
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
@@ -1148,6 +1149,8 @@ class Qwen2Model(Qwen2PreTrainedModel):
         self.causal_attn = getattr(config, 'causal_attn', False)
         self.text_mask_token_id = getattr(config, 'text_mask_token_id', 151666)
 
+        self.decoding_with_ssd_cache = False
+
         # print(f'{self.block_size=} {self.causal_attn=} {self.training=} {self.text_mask_token_id=}\n')
 
 
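The new `decoding_with_ssd_cache` flag defaults to False and is only consulted on the sampling path further down (the `use_cache and self.decoding_with_ssd_cache` branch), so a caller has to flip it before running inference to get the self-speculative-decoding cache behaviour. A minimal usage sketch, assuming the flag sits on the inner `Qwen2Model` (reachable as `model.model` on the causal-LM wrapper) and using a placeholder checkpoint id; the exact sampling entry point for SDLM is not shown in this diff, so only a plain forward pass is illustrated:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint id; SDLM ships custom modeling code, hence trust_remote_code.
ckpt = "path/to/sdlm-checkpoint"
model = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True)
tok = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)

# Opt in to the self-speculative-decoding (SSD) cache path; leaving it at the
# default False keeps the confidence-decoding mask update instead.
model.model.decoding_with_ssd_cache = True

inputs = tok("An example prompt", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, use_cache=True)  # sampling wiring itself is model-specific
print(out.logits.shape)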
@@ -1227,80 +1230,17 @@ class Qwen2Model(Qwen2PreTrainedModel):
         device = input_ids.device if input_ids is not None else inputs_embeds.device
         x0_len = find_prefix_seq_length_by_pe(position_ids).to(device=device)
 
-        if self.
-            attention_mask, _ = create_block_diff_mask_by_pe_4d(
-                block_size=self.block_size,
-                x0_len_list=x0_len,
-                position_ids=position_ids,
-                causal_attn=self.causal_attn
-            )
-
-        elif self._attn_implementation == "flash_attention_2":
-            # # 2d mask is passed through the layers
-            # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-
-            # TODO Update to Flex Attention.
-            block_diff_mask_func = partial(
-                create_block_diff_mask_by_pe_1d,
-                block_size=self.block_size,
-                x0_len_list=x0_len,
-                position_ids_list=position_ids,
-                causal_attn=self.causal_attn
-            )
-
-            attention_mask = create_block_mask(
-                block_diff_mask_func,
-                B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length, device=device
-            )
-
-        else:
-            if not self.training:
-                # for sampling, set attn = eager
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    inputs_embeds,
-                    past_key_values_length,
-                    sliding_window=self.config.sliding_window,
-                )
-
-                if use_cache:
-                    update_mask_func = partial(
-                        update_causal_mask_for_one_gen_window_2d,
-                        block_size=self.block_size,
-                        use_cache=use_cache,
-                        causal_attn=self.causal_attn
-                    )
-                else:
-                    update_mask_func = partial(
-                        update_causal_mask_with_pad_non_visible_2d,
-                        block_size=self.block_size,
-                        text_mask_token_id=self.text_mask_token_id,
-                        causal_attn=self.causal_attn
-                    )
-
-                if attention_mask is not None and len(attention_mask.shape) == 4:
-                    new_attention_mask = []
-                    for b in range(attention_mask.shape[0]):
-                        new_attention_mask.append(
-                            update_mask_func(
-                                input_ids[b],
-                                attention_mask[b][0]
-                            ).unsqueeze(0)
-                        )
-                    attention_mask = torch.stack(new_attention_mask, dim=0)
-
-            else:
-                # for training
+        if self.training:
+            if (self._attn_implementation == "sdpa" and not output_attentions) or self._attn_implementation == "eager":
+                # output_attentions=True can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                #     attention_mask,
+                #     (batch_size, seq_length),
+                #     inputs_embeds,
+                #     past_key_values_length,
+                # )
+
                 attention_mask, _ = create_block_diff_mask_by_pe_4d(
                     block_size=self.block_size,
                     x0_len_list=x0_len,
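In training mode the sdpa/eager path now builds the attention mask densely with `create_block_diff_mask_by_pe_4d`, driven by the per-sample prefix lengths (`x0_len`) and the position ids. That helper lives in `attn_mask_utils` and is not part of this diff; the sketch below is only a simplified guess at the kind of pattern it encodes (a causal prefix plus fixed-size generation blocks that see the prefix, earlier blocks, and themselves), with `block_size`, `x0_len` and `causal_attn` mirroring the call site rather than its real semantics:

import torch

def toy_block_diff_mask(seq_len: int, x0_len: int, block_size: int, causal_attn: bool = False) -> torch.Tensor:
    """Illustrative sketch only: True = attend, False = masked."""
    idx = torch.arange(seq_len)
    is_prefix = idx < x0_len
    # Prefix tokens get block id 0; generation tokens are grouped into blocks of `block_size`.
    block_id = torch.where(is_prefix, torch.zeros_like(idx), (idx - x0_len) // block_size + 1)

    q_block, k_block = block_id[:, None], block_id[None, :]
    causal = idx[:, None] >= idx[None, :]

    # Prefix rows: plain causal attention.
    prefix_rows = is_prefix[:, None] & causal
    # Generation rows: earlier blocks (and the prefix) are visible, plus full
    # attention inside the own block unless causal_attn forces causal order.
    same_block_vis = causal if causal_attn else torch.ones_like(causal)
    gen_rows = ~is_prefix[:, None] & ((k_block < q_block) | ((k_block == q_block) & same_block_vis))
    return prefix_rows | gen_rows

mask = toy_block_diff_mask(seq_len=12, x0_len=4, block_size=4)
print(mask.int())

For x0_len=4 and block_size=4 the printout shows a causal 4x4 corner for the prefix and, below it, rows that see the whole prefix, all earlier blocks, and their own 4-token block in full.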
@@ -1308,6 +1248,90 @@ class Qwen2Model(Qwen2PreTrainedModel):
                     causal_attn=self.causal_attn
                 )
 
+            elif self._attn_implementation == "flash_attention_2":
+                # # 2d mask is passed through the layers
+                # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+
+                # TODO Update to Flex Attention.
+                block_diff_mask_func = partial(
+                    create_block_diff_mask_by_pe_1d,
+                    block_size=self.block_size,
+                    x0_len_list=x0_len,
+                    position_ids_list=position_ids,
+                    causal_attn=self.causal_attn
+                )
+
+                attention_mask = create_block_mask(
+                    block_diff_mask_func,
+                    B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length, device=device
+                )
+
+            else:
+                raise NotImplementedError
+        else:
+            assert self._attn_implementation in ['sdpa', 'eager']
+            # for sampling, set attn = eager or sdpa
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window=self.config.sliding_window,
+            )
+
+            if use_cache and self.decoding_with_ssd_cache:
+                # Sampling with Self-Speculative Decoding
+                update_mask_func = partial(
+                    update_causal_mask_with_pad_non_visible_2d_for_ssd_cache,
+                    block_size=self.block_size,
+                    use_cache=use_cache,
+                    causal_attn=self.causal_attn
+                )
+            elif use_cache:
+                # Sampling with Confidence Decoding:
+                # only the generation window is set to full attention.
+                # The last token of the previous window serves two roles and is therefore duplicated: (1) causal attention for KV recomputation, during which it is masked from the subsequent decoding window; (2) it also acts as the first token of the next decoding window.
+                update_mask_func = partial(
+                    update_causal_mask_for_one_gen_window_2d,
+                    block_size=self.block_size,
+                    use_cache=use_cache,
+                    causal_attn=self.causal_attn
+                )
+            else:
+                update_mask_func = partial(
+                    update_causal_mask_with_pad_non_visible_2d,
+                    block_size=self.block_size,
+                    text_mask_token_id=self.text_mask_token_id,
+                    causal_attn=self.causal_attn
+                )
+
+            if attention_mask is not None and len(attention_mask.shape) == 4:
+                new_attention_mask = []
+                for b in range(attention_mask.shape[0]):
+                    new_attention_mask.append(
+                        update_mask_func(
+                            input_ids[b],
+                            attention_mask[b][0]
+                        ).unsqueeze(0)
+                    )
+                attention_mask = torch.stack(new_attention_mask, dim=0)
+
+            # if True:
+            #     # For debug logging...
+            #     print(f'inference attention {self._attn_implementation} {self.training=} {use_cache=}')
+            #     print(f'{attention_mask.shape=}')
+            #     import numpy as np
+
+            #     for b in range(batch_size):
+            #         causal_mask_2d = attention_mask[b][0].tolist()
+            #         binary_mask = np.where(np.array(causal_mask_2d)==0.0, 0, 1)
+            #         print(f'{b=}, mask is\n')
+
+            #         for ix in range(binary_mask.shape[0]):
+            #             print(' '.join(str(_) for _ in binary_mask[ix]))
+            #         print('\n\n')
+
+
         hidden_states = inputs_embeds
 
         # decoder layers
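Under flash_attention_2 the training path now routes through FlexAttention: the 1-D predicate from `create_block_diff_mask_by_pe_1d` is bound with `partial` and handed to `create_block_mask`, which turns it into a block-sparse mask over the full sequence (the result keeps the name `attention_mask` and is presumably consumed by a FlexAttention call inside the layers). A self-contained sketch of that wiring with a stand-in predicate, since the real one is in `attn_mask_utils` and takes the prefix lengths and position ids shown above; assumes PyTorch 2.5+ with `torch.nn.attention.flex_attention` available:

from functools import partial

import torch
from torch.nn.attention.flex_attention import create_block_mask

def stand_in_mask_mod(b, h, q_idx, kv_idx, block_size, x0_len):
    # Illustrative predicate only: causal over the prefix, block-causal with
    # full attention inside each `block_size`-token generation block after it.
    q_blk = torch.clamp(q_idx - x0_len, min=0) // block_size
    kv_blk = torch.clamp(kv_idx - x0_len, min=0) // block_size
    prefix_part = (kv_idx < x0_len) & (kv_idx <= q_idx)
    block_part = (kv_idx >= x0_len) & (q_idx >= x0_len) & (kv_blk <= q_blk)
    return prefix_part | block_part

seq_length = 256
mask_mod = partial(stand_in_mask_mod, block_size=4, x0_len=32)
# device="cpu" for the sketch; the model code passes the input device instead.
block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length, device="cpu")
print(block_mask)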
@@ -1465,6 +1489,7 @@ class SDLMQwen2ForCausalLM(Qwen2PreTrainedModel):
         logits = logits.float()
 
         loss = None
+        pos_loss_list = None
         if labels is not None:
 
             # Shift so that tokens < n predict n
@@ -1506,7 +1531,7 @@ class SDLMQwen2ForCausalLM(Qwen2PreTrainedModel):
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
 
-        if self.training:
+        if self.training and pos_loss_list is not None:
             return CausalLMOutputWithPast(
                 loss=loss,
                 logits=logits,
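`pos_loss_list` is only initialized and gated on in these hunks; the code that fills it sits outside the diff. As a rough illustration of what per-position loss bookkeeping for block-wise prediction could look like (an assumption about intent, not this repository's implementation; `per_offset_losses`, its arguments and the block layout are hypothetical), one unreduced cross-entropy pass can be averaged per offset inside each `block_size`-token block:

import torch
import torch.nn.functional as F

def per_offset_losses(shift_logits, shift_labels, block_size, ignore_index=-100):
    # shift_logits: (batch, seq, vocab); shift_labels: (batch, seq).
    # Returns one mean loss per offset 0 .. block_size-1 within a block.
    token_loss = F.cross_entropy(
        shift_logits.transpose(1, 2), shift_labels,
        ignore_index=ignore_index, reduction="none",
    )  # (batch, seq)
    valid = shift_labels.ne(ignore_index)
    offsets = torch.arange(shift_labels.size(1), device=shift_labels.device) % block_size
    pos_loss_list = []
    for o in range(block_size):
        sel = valid & (offsets == o)
        pos_loss_list.append(
            token_loss[sel].mean() if sel.any() else torch.tensor(0.0, device=token_loss.device)
        )
    return pos_loss_list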