Upload KotobaWhisperPipeline

- config.json +1 -1
- generation_config.json +1 -1
- kotoba_whisper.py +4 -23
config.json CHANGED
@@ -54,7 +54,7 @@
   "pad_token_id": 50256,
   "scale_embedding": false,
   "torch_dtype": "float32",
-  "transformers_version": "4.
+  "transformers_version": "4.41.0.dev0",
   "use_cache": true,
   "use_weighted_layer_sum": false,
   "vocab_size": 51866
generation_config.json CHANGED
@@ -261,5 +261,5 @@
     "transcribe": 50360,
     "translate": 50359
   },
-  "transformers_version": "4.
+  "transformers_version": "4.41.0.dev0"
 }
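Both JSON edits make the same metadata fix: the recorded `transformers_version` (the old value is truncated in this view) is bumped to the full development build string `"4.41.0.dev0"`, marking the transformers version the files were re-saved with. As a minimal sketch of why this field matters (the compatibility check below is an illustrative assumption, not code from this repository), a consumer could compare it against the locally installed library:

```python
# Illustrative sketch, not part of this repo: warn when the local
# transformers install is older than the version recorded in config.json.
import json

import transformers
from packaging import version

with open("config.json") as f:
    recorded = json.load(f)["transformers_version"]  # "4.41.0.dev0" after this commit

if version.parse(transformers.__version__) < version.parse(recorded):
    print(
        f"local transformers {transformers.__version__} predates {recorded}; "
        "the custom pipeline may rely on newer generate() behavior"
    )
```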
kotoba_whisper.py CHANGED
@@ -249,6 +249,8 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
         encoder = self.model.get_encoder()
         # Consume values so we can let extra information flow freely through
         # the pipeline (important for `partial` in microphone)
+        if type(return_timestamps) is not bool:
+            raise ValueError("return_timestamps should be bool")
         if "input_features" in model_inputs:
             inputs = model_inputs.pop("input_features")
         elif "input_values" in model_inputs:
@@ -260,18 +262,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
         )
 
         # custom processing for Whisper timestamps and word-level timestamps
-        if return_timestamps and self.type == "seq2seq_whisper":
-            generate_kwargs["return_timestamps"] = return_timestamps
-            if return_timestamps == "word":
-                generate_kwargs["return_token_timestamps"] = True
-                generate_kwargs["return_segments"] = True
-
-                if stride is not None:
-                    if isinstance(stride, tuple):
-                        generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
-                    else:
-                        generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]
-
+        generate_kwargs["return_timestamps"] = True
         if inputs.shape[-1] > self.feature_extractor.nb_max_frames:
             generate_kwargs["input_features"] = inputs
         else:
@@ -279,17 +270,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
 
         tokens = self.model.generate(attention_mask=attention_mask, **generate_kwargs)
         # whisper longform generation stores timestamps in "segments"
-        if return_timestamps == "word" and self.type == "seq2seq_whisper":
-            if "segments" not in tokens:
-                out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
-            else:
-                token_timestamps = [
-                    torch.cat([segment["token_timestamps"] for segment in segment_list])
-                    for segment_list in tokens["segments"]
-                ]
-                out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
-        else:
-            out = {"tokens": tokens}
+        out = {"tokens": tokens}
         if self.type == "seq2seq_whisper":
             if stride is not None:
                 out["stride"] = stride
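For context on the simplified `out = {"tokens": tokens}` return: with `return_timestamps=True`, Whisper's `generate()` interleaves timestamp tokens such as `<|0.00|>` into the sequence, so timing information still reaches the decoding step downstream. A minimal sketch of decoding those markers with a stock checkpoint (the model id and silent dummy features are assumptions for illustration, not part of this commit):

```python
# Sketch: decoding timestamp tokens produced by generate(return_timestamps=True).
# Uses a stock Whisper checkpoint and silent dummy features for illustration.
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

model_id = "openai/whisper-tiny"  # placeholder; any Whisper checkpoint works
processor = WhisperProcessor.from_pretrained(model_id)
model = WhisperForConditionalGeneration.from_pretrained(model_id)

# (batch, n_mels, frames) log-mel features; zeros stand in for real audio
features = torch.zeros(1, model.config.num_mel_bins, 3000)
tokens = model.generate(features, return_timestamps=True)

# decode_with_timestamps keeps the <|t.tt|> markers in the decoded text
print(processor.batch_decode(tokens, decode_with_timestamps=True)[0])
```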