Update ultravox_processing.py
Browse files- ultravox_processing.py +28 -0
ultravox_processing.py
CHANGED
|
@@ -4,6 +4,8 @@ import numpy as np
|
|
| 4 |
import torch
|
| 5 |
import transformers
|
| 6 |
|
|
|
|
|
|
|
| 7 |
|
| 8 |
class UltravoxProcessor(transformers.ProcessorMixin):
|
| 9 |
"""
|
|
@@ -59,6 +61,29 @@ class UltravoxProcessor(transformers.ProcessorMixin):
|
|
| 59 |
|
| 60 |
super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
def __call__(
|
| 63 |
self,
|
| 64 |
text: Optional[str] = None,
|
|
@@ -178,3 +203,6 @@ class UltravoxProcessor(transformers.ProcessorMixin):
|
|
| 178 |
tokenizer_input_names = self.tokenizer.model_input_names
|
| 179 |
audio_processor_input_names = self.audio_processor.model_input_names
|
| 180 |
return list(set(tokenizer_input_names + audio_processor_input_names))
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import torch
|
| 5 |
import transformers
|
| 6 |
|
| 7 |
+
from .ultravox_config import UltravoxConfig
|
| 8 |
+
|
| 9 |
|
| 10 |
class UltravoxProcessor(transformers.ProcessorMixin):
|
| 11 |
"""
|
|
|
|
| 61 |
|
| 62 |
super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
|
| 63 |
|
| 64 |
+
@classmethod
|
| 65 |
+
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
| 66 |
+
config: UltravoxConfig = transformers.AutoConfig.from_pretrained(
|
| 67 |
+
pretrained_model_name_or_path, **kwargs
|
| 68 |
+
)
|
| 69 |
+
audio_processor = transformers.AutoProcessor.from_pretrained(
|
| 70 |
+
config.audio_model_id
|
| 71 |
+
or config.audio_config._name_or_path
|
| 72 |
+
or "facebook/wav2vec2-base-960h"
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 76 |
+
pretrained_model_name_or_path, **kwargs
|
| 77 |
+
)
|
| 78 |
+
tokenizer.padding_side = "left"
|
| 79 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 80 |
+
|
| 81 |
+
return cls(
|
| 82 |
+
audio_processor=audio_processor,
|
| 83 |
+
tokenizer=tokenizer,
|
| 84 |
+
stack_factor=config.stack_factor,
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
def __call__(
|
| 88 |
self,
|
| 89 |
text: Optional[str] = None,
|
|
|
|
| 203 |
tokenizer_input_names = self.tokenizer.model_input_names
|
| 204 |
audio_processor_input_names = self.audio_processor.model_input_names
|
| 205 |
return list(set(tokenizer_input_names + audio_processor_input_names))
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)
|