lll2343 committed
Commit d7ac46a · verified · 1 Parent(s): 29fddbf

Update modeling_sdlm.py

Files changed (1)
  1. modeling_sdlm.py +100 -75
modeling_sdlm.py CHANGED
@@ -82,7 +82,8 @@ from .attn_mask_utils import (
     update_causal_mask_for_one_gen_window_2d,
     create_block_diff_mask_by_pe_1d,
     create_block_diff_mask_by_pe_4d,
-    find_pred_pos_from_input_ids
+    find_pred_pos_from_input_ids,
+    update_causal_mask_with_pad_non_visible_2d_for_ssd_cache
 )
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
@@ -1148,6 +1149,8 @@ class Qwen2Model(Qwen2PreTrainedModel):
         self.causal_attn = getattr(config, 'causal_attn', False)
         self.text_mask_token_id = getattr(config, 'text_mask_token_id', 151666)
 
+        self.decoding_with_ssd_cache = False
+
         # print(f'{self.block_size=} {self.causal_attn=} {self.training=} {self.text_mask_token_id=}\n')
 
 
@@ -1227,80 +1230,17 @@ class Qwen2Model(Qwen2PreTrainedModel):
         device = input_ids.device if input_ids is not None else inputs_embeds.device
         x0_len = find_prefix_seq_length_by_pe(position_ids).to(device=device)
 
-        if self._attn_implementation == "sdpa" and not output_attentions:
-            # output_attentions=True can not be supported when using SDPA, and we fall back on
-            # the manual implementation that requires a 4D causal mask in all cases.
-            # attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-            #     attention_mask,
-            #     (batch_size, seq_length),
-            #     inputs_embeds,
-            #     past_key_values_length,
-            # )
-
-            attention_mask, _ = create_block_diff_mask_by_pe_4d(
-                block_size=self.block_size,
-                x0_len_list=x0_len,
-                position_ids=position_ids,
-                causal_attn=self.causal_attn
-            )
-
-        elif self._attn_implementation == "flash_attention_2":
-            # # 2d mask is passed through the layers
-            # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-
-            # TODO Update to Flex Attention.
-            block_diff_mask_func = partial(
-                create_block_diff_mask_by_pe_1d,
-                block_size=self.block_size,
-                x0_len_list=x0_len,
-                position_ids_list=position_ids,
-                causal_attn=self.causal_attn
-            )
-
-            attention_mask = create_block_mask(
-                block_diff_mask_func,
-                B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length, device=device
-            )
-
-        else:
-            if not self.training:
-                # for sampling, set attn = eager
-                attention_mask = _prepare_4d_causal_attention_mask(
-                    attention_mask,
-                    (batch_size, seq_length),
-                    inputs_embeds,
-                    past_key_values_length,
-                    sliding_window=self.config.sliding_window,
-                )
-
-                if use_cache:
-                    update_mask_func = partial(
-                        update_causal_mask_for_one_gen_window_2d,
-                        block_size=self.block_size,
-                        use_cache=use_cache,
-                        causal_attn=self.causal_attn
-                    )
-                else:
-                    update_mask_func = partial(
-                        update_causal_mask_with_pad_non_visible_2d,
-                        block_size=self.block_size,
-                        text_mask_token_id=self.text_mask_token_id,
-                        causal_attn=self.causal_attn
-                    )
-
-                if attention_mask is not None and len(attention_mask.shape) == 4:
-                    new_attention_mask = []
-                    for b in range(attention_mask.shape[0]):
-                        new_attention_mask.append(
-                            update_mask_func(
-                                input_ids[b],
-                                attention_mask[b][0]
-                            ).unsqueeze(0)
-                        )
-                    attention_mask = torch.stack(new_attention_mask, dim=0)
-
-            else:
-                # for training
+        if self.training:
+            if (self._attn_implementation == "sdpa" and not output_attentions) or self._attn_implementation == "eager":
+                # output_attentions=True can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                # attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                #     attention_mask,
+                #     (batch_size, seq_length),
+                #     inputs_embeds,
+                #     past_key_values_length,
+                # )
+
                 attention_mask, _ = create_block_diff_mask_by_pe_4d(
                     block_size=self.block_size,
                     x0_len_list=x0_len,
@@ -1308,6 +1248,90 @@ class Qwen2Model(Qwen2PreTrainedModel):
                     causal_attn=self.causal_attn
                 )
 
+            elif self._attn_implementation == "flash_attention_2":
+                # # 2d mask is passed through the layers
+                # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+
+                # TODO Update to Flex Attention.
+                block_diff_mask_func = partial(
+                    create_block_diff_mask_by_pe_1d,
+                    block_size=self.block_size,
+                    x0_len_list=x0_len,
+                    position_ids_list=position_ids,
+                    causal_attn=self.causal_attn
+                )
+
+                attention_mask = create_block_mask(
+                    block_diff_mask_func,
+                    B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length, device=device
+                )
+
+            else:
+                raise NotImplementedError
+        else:
+            assert self._attn_implementation in ['sdpa', 'eager']
+            # for sampling, set attn = eager or sdpa
+            attention_mask = _prepare_4d_causal_attention_mask(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
+                sliding_window=self.config.sliding_window,
+            )
+
+            if use_cache and self.decoding_with_ssd_cache:
+                # Sampling with Self-Speculative Decoding
+                update_mask_func = partial(
+                    update_causal_mask_with_pad_non_visible_2d_for_ssd_cache,
+                    block_size=self.block_size,
+                    use_cache=use_cache,
+                    causal_attn=self.causal_attn
+                )
+            elif use_cache:
+                # Sampling with Confidence Decoding:
+                # only the generation window is set to full attention.
+                # The last token of the previous window serves two roles and is therefore duplicated: (1) it attends causally for KV recomputation, during which it is masked from the subsequent decoding window; (2) it also acts as the first token of the next decoding window.
+                update_mask_func = partial(
+                    update_causal_mask_for_one_gen_window_2d,
+                    block_size=self.block_size,
+                    use_cache=use_cache,
+                    causal_attn=self.causal_attn
+                )
+            else:
+                update_mask_func = partial(
+                    update_causal_mask_with_pad_non_visible_2d,
+                    block_size=self.block_size,
+                    text_mask_token_id=self.text_mask_token_id,
+                    causal_attn=self.causal_attn
+                )
+
+            if attention_mask is not None and len(attention_mask.shape) == 4:
+                new_attention_mask = []
+                for b in range(attention_mask.shape[0]):
+                    new_attention_mask.append(
+                        update_mask_func(
+                            input_ids[b],
+                            attention_mask[b][0]
+                        ).unsqueeze(0)
+                    )
+                attention_mask = torch.stack(new_attention_mask, dim=0)
+
+            # if True:
+            #     # For debug logging...
+            #     print(f'inference attention {self._attn_implementation} {self.training=} {use_cache=}')
+            #     print(f'{attention_mask.shape=}')
+            #     import numpy as np
+
+            #     for b in range(batch_size):
+            #         causal_mask_2d = attention_mask[b][0].tolist()
+            #         binary_mask = np.where(np.array(causal_mask_2d)==0.0, 0, 1)
+            #         print(f'{b=}, mask is\n')
+
+            #         for ix in range(binary_mask.shape[0]):
+            #             print(' '.join(str(_) for _ in binary_mask[ix]))
+            #         print('\n\n')
+
+
         hidden_states = inputs_embeds
 
         # decoder layers
@@ -1465,6 +1489,7 @@ class SDLMQwen2ForCausalLM(Qwen2PreTrainedModel):
         logits = logits.float()
 
         loss = None
+        pos_loss_list = None
         if labels is not None:
 
             # Shift so that tokens < n predict n
@@ -1506,7 +1531,7 @@ class SDLMQwen2ForCausalLM(Qwen2PreTrainedModel):
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
 
-        if self.training:
+        if self.training and pos_loss_list is not None:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
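
Note on the new sampling path: every non-training branch now picks a single mask-update function (SSD-cache decoding, confidence decoding, or the padding-aware fallback) and applies it to each batch element's 2D mask before the results are restacked into a 4D mask. The standalone sketch below illustrates only that per-sample dispatch pattern; toy_update_mask is a hypothetical stand-in for the repository's update_causal_mask_* helpers, not their actual logic.

import torch

def toy_update_mask(input_ids_1d, mask_2d, block_size=4):
    # Hypothetical stand-in for the update_causal_mask_* helpers: it only
    # opens full attention inside the trailing block_size x block_size
    # generation window (input_ids_1d is accepted to mirror the call site
    # but is unused here).
    mask = mask_2d.clone()
    mask[-block_size:, -block_size:] = 0.0   # 0.0 = visible, -inf = masked
    return mask

batch_size, seq_length = 2, 8
input_ids = torch.randint(0, 100, (batch_size, seq_length))

# Additive 4D mask of shape (batch, 1, q_len, kv_len), causal to start with.
causal = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool()
base = torch.zeros(seq_length, seq_length).masked_fill(causal, float("-inf"))
attention_mask = base.expand(batch_size, 1, seq_length, seq_length).clone()

# Chosen once by the use_cache / decoding_with_ssd_cache branches above.
update_mask_func = toy_update_mask

if attention_mask is not None and len(attention_mask.shape) == 4:
    new_attention_mask = []
    for b in range(attention_mask.shape[0]):
        new_attention_mask.append(
            update_mask_func(input_ids[b], attention_mask[b][0]).unsqueeze(0)
        )
    attention_mask = torch.stack(new_attention_mask, dim=0)

print(attention_mask.shape)  # torch.Size([2, 1, 8, 8])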