fedebotu committed on
Commit
83dd916
·
1 Parent(s): a9d97d7

[Feat] support for merging generation config with generate kwargs

Browse files
Files changed (1) hide show
  1. generation_utils.py +11 -16
generation_utils.py CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  """
2
  RND1 Generation Utilities.
3
 
@@ -6,12 +11,12 @@ including the main GenerationMixin class that integrates with HuggingFace.
6
  """
7
 
8
  import torch
9
- import torch.nn as nn
10
  from typing import Optional, Union, Dict, Any
11
  from transformers import GenerationMixin as HFGenerationMixin
12
  from transformers.generation import GenerationConfig
13
 
14
- from .sampling import diffusion_sample, apply_top_k_filtering, apply_top_p_filtering
 
15
 
16
 
17
  class RND1GenerationMixin(HFGenerationMixin):
@@ -41,12 +46,12 @@ class RND1GenerationMixin(HFGenerationMixin):
41
 
42
  Args:
43
  inputs: Input token IDs to use as prefix (standard HF parameter)
44
- generation_config: Generation configuration object
45
  prefix_ids: Alternative to inputs for infilling tasks
46
  suffix_ids: Optional suffix for infilling tasks
47
  infill_length: Length of infill region (for infilling)
48
  return_dict_in_generate: Whether to return GenerateDecoderOnlyOutput
49
- **kwargs: Additional arguments (accepted for compatibility)
50
 
51
  Returns:
52
  Generated token IDs or GenerateDecoderOnlyOutput
@@ -56,7 +61,7 @@ class RND1GenerationMixin(HFGenerationMixin):
56
  model_kwargs = kwargs.copy()
57
  else:
58
  # Only prepare config from kwargs if no config was provided
59
- gen_config, model_kwargs = self._prepare_generation_config(None, **kwargs)
60
 
61
  device = next(self.parameters()).device
62
 
@@ -71,7 +76,7 @@ class RND1GenerationMixin(HFGenerationMixin):
71
  suffix_ids = suffix_ids.to(device)
72
 
73
  eos_token_id = gen_config.eos_token_id or getattr(self.config, "eos_token_id", 151645)
74
- pad_token_id = gen_config.pad_token_id or getattr(self.config, "pad_token_id", None)
75
  bos_token_id = gen_config.bos_token_id or getattr(self.config, "bos_token_id", None)
76
  mask_token_id = getattr(gen_config, "mask_token_id", getattr(self.config, "mask_token_id", 151669))
77
 
@@ -101,12 +106,6 @@ class RND1GenerationMixin(HFGenerationMixin):
101
  greedy = getattr(gen_config, "greedy",
102
  not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True)
103
 
104
- generator = model_kwargs.get("generator", None)
105
- if generator is None:
106
- seed = getattr(gen_config, 'seed', None)
107
- if seed is not None:
108
- generator = torch.Generator(device=device)
109
- generator.manual_seed(seed)
110
 
111
  with torch.inference_mode():
112
  sequences = diffusion_sample(
@@ -125,7 +124,6 @@ class RND1GenerationMixin(HFGenerationMixin):
125
  pad_token_id=pad_token_id,
126
  bos_token_id=bos_token_id,
127
  device=device,
128
- generator=generator,
129
  visualizer=model_kwargs.get("visualizer", None), # Optional visualizer from kwargs
130
  )
131
 
@@ -142,7 +140,6 @@ class RND1GenerationMixin(HFGenerationMixin):
142
  generation_config: Optional[GenerationConfig] = None,
143
  suffix_ids: Optional[torch.LongTensor] = None,
144
  infill_length: Optional[int] = None,
145
- generator: Optional[torch.Generator] = None,
146
  **kwargs,
147
  ) -> torch.LongTensor:
148
  """
@@ -157,7 +154,6 @@ class RND1GenerationMixin(HFGenerationMixin):
157
  generation_config: Generation configuration object
158
  suffix_ids: Optional suffix token IDs
159
  infill_length: Length of infill region
160
- generator: Random generator for reproducibility
161
  **kwargs: Additional arguments for backward compatibility
162
 
163
  Returns:
@@ -171,7 +167,6 @@ class RND1GenerationMixin(HFGenerationMixin):
171
  generation_config=generation_config,
172
  suffix_ids=suffix_ids,
173
  infill_length=infill_length,
174
- generator=generator,
175
  visualizer=visualizer,
176
  return_dict_in_generate=False,
177
  **kwargs,
 
1
+ # Copyright 2025 Radical Numerics Inc.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0, found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
  """
7
  RND1 Generation Utilities.
8
 
 
11
  """
12
 
13
  import torch
 
14
  from typing import Optional, Union, Dict, Any
15
  from transformers import GenerationMixin as HFGenerationMixin
16
  from transformers.generation import GenerationConfig
17
 
18
+ from .generation_config import RND1GenerationConfig
19
+ from .sampling import diffusion_sample
20
 
21
 
22
  class RND1GenerationMixin(HFGenerationMixin):
 
46
 
47
  Args:
48
  inputs: Input token IDs to use as prefix (standard HF parameter)
49
+ generation_config: Generation configuration object. Default is RND1GenerationConfig.
50
  prefix_ids: Alternative to inputs for infilling tasks
51
  suffix_ids: Optional suffix for infilling tasks
52
  infill_length: Length of infill region (for infilling)
53
  return_dict_in_generate: Whether to return GenerateDecoderOnlyOutput
54
+ **kwargs: Additional arguments (accepted for compatibility). These will be passed to the config constructor.
55
 
56
  Returns:
57
  Generated token IDs or GenerateDecoderOnlyOutput
 
61
  model_kwargs = kwargs.copy()
62
  else:
63
  # Only prepare config from kwargs if no config was provided
64
+ gen_config, model_kwargs = self._prepare_generation_config(RND1GenerationConfig(), **kwargs)
65
 
66
  device = next(self.parameters()).device
67
 
 
76
  suffix_ids = suffix_ids.to(device)
77
 
78
  eos_token_id = gen_config.eos_token_id or getattr(self.config, "eos_token_id", 151645)
79
+ pad_token_id = gen_config.pad_token_id or getattr(self.config, "pad_token_id", 151643)
80
  bos_token_id = gen_config.bos_token_id or getattr(self.config, "bos_token_id", None)
81
  mask_token_id = getattr(gen_config, "mask_token_id", getattr(self.config, "mask_token_id", 151669))
82
 
 
106
  greedy = getattr(gen_config, "greedy",
107
  not bool(gen_config.do_sample) if hasattr(gen_config, "do_sample") else True)
108
 
 
 
 
 
 
 
109
 
110
  with torch.inference_mode():
111
  sequences = diffusion_sample(
 
124
  pad_token_id=pad_token_id,
125
  bos_token_id=bos_token_id,
126
  device=device,
 
127
  visualizer=model_kwargs.get("visualizer", None), # Optional visualizer from kwargs
128
  )
129
 
 
140
  generation_config: Optional[GenerationConfig] = None,
141
  suffix_ids: Optional[torch.LongTensor] = None,
142
  infill_length: Optional[int] = None,
 
143
  **kwargs,
144
  ) -> torch.LongTensor:
145
  """
 
154
  generation_config: Generation configuration object
155
  suffix_ids: Optional suffix token IDs
156
  infill_length: Length of infill region
 
157
  **kwargs: Additional arguments for backward compatibility
158
 
159
  Returns:
 
167
  generation_config=generation_config,
168
  suffix_ids=suffix_ids,
169
  infill_length=infill_length,
 
170
  visualizer=visualizer,
171
  return_dict_in_generate=False,
172
  **kwargs,