Update generate.py
generate.py (+12, -10)
@@ -45,15 +45,12 @@ def custom_generate(
 ):
     if input_ids is None or input_ids.nelement() == 0:
         # If input_ids is None or an empty tensor, create a default input tensor
-        input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]])
-        attention_mask = torch.ones_like(input_ids)
+        input_ids = torch.LongTensor([[self.tokenizer.bos_token_id]]).to(self.device)
+        attention_mask = torch.ones_like(input_ids).to(self.device)
 
     device = input_ids.device
     with torch.no_grad():
         batch_size = input_ids.shape[0]
-        if max_new_tokens is None:
-            raise ValueError("max_new_tokens must be provided.")
-
         finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=device)
         generated_token_ids = torch.full((batch_size, max_new_tokens), self.tokenizer.pad_token_id, dtype=torch.long, device=device)
 
@@ -156,10 +153,10 @@ def generate(
     torch_dtype=torch.bfloat16,
     **model_kwargs,
 ):
-    # Set default value for max_new_tokens if not provided
-    if max_new_tokens is None:
-        max_new_tokens = 128  # Set a reasonable default value
 
+    if max_new_tokens is None:
+        max_new_tokens = 128
+
     # Set model attributes
     self.max_thoughts = n_ahead + n_ahead_talk + 1
     self.merged_talk_heads = merged_talk_heads
@@ -186,11 +183,16 @@ def generate(
     if isinstance(input_ids, str):
         input_ids = self.tokenizer.encode(input_ids, return_tensors='pt')
 
+    # Move input_ids and attention_mask to the same device as the model
+    input_ids = input_ids.to(self.device)
+    if attention_mask is not None:
+        attention_mask = attention_mask.to(self.device)
+
     generated_token_ids = custom_generate(
         self,
-        input_ids=input_ids,
+        input_ids=input_ids,
         attention_mask=attention_mask,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=max_new_tokens,
         min_length=min_length,
         do_sample=do_sample,
         early_stopping=early_stopping,
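The net effect of this commit is that generate() no longer raises when max_new_tokens is omitted (it falls back to 128) and that inputs built on the CPU are moved to the model's device before custom_generate() runs. A minimal, hypothetical usage sketch follows; the checkpoint id, the assumption that generate() is bound as the model's generation entry point, and the decoding call are illustrative and not part of this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint id; any model wired up to this generate.py would do.
model_name = "your-org/your-model"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

# Inputs can stay on the CPU: the patched generate() moves input_ids and
# attention_mask to the model's device before calling custom_generate().
inputs = tokenizer("What is 2 + 2?", return_tensors="pt")

# max_new_tokens may now be omitted; it defaults to 128 instead of raising.
output_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))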