Commit 7662a6a · Revert portg
Parent: 6951e54

app.py CHANGED
@@ -559,6 +559,60 @@ class RealtimeSpeakerDiarization:
         except Exception as e:
             print(f"Error feeding audio data: {e}")
 
+    def process_audio_chunk(self, audio_data, sample_rate=16000):
+        """Process audio chunk from FastRTC input"""
+        if not self.is_running or self.recorder is None:
+            return
+
+        try:
+            # Convert float audio to int16 for the recorder
+            if audio_data.dtype == np.float32 or audio_data.dtype == np.float64:
+                if np.max(np.abs(audio_data)) <= 1.0:
+                    # Float audio is normalized to [-1, 1], convert to int16
+                    audio_int16 = (audio_data * 32767).astype(np.int16)
+                else:
+                    # Audio is already in higher range
+                    audio_int16 = audio_data.astype(np.int16)
+            else:
+                audio_int16 = audio_data
+
+            # Ensure correct shape (1, N) for the recorder
+            if len(audio_int16.shape) == 1:
+                audio_int16 = np.expand_dims(audio_int16, 0)
+
+            # Resample if needed
+            if sample_rate != SAMPLE_RATE:
+                audio_int16 = self._resample_audio(audio_int16, sample_rate, SAMPLE_RATE)
+
+            # Convert to bytes for feeding to recorder
+            audio_bytes = audio_int16.tobytes()
+
+            # Feed to recorder
+            self.feed_audio_data(audio_bytes)
+
+        except Exception as e:
+            print(f"Error processing audio chunk: {e}")
+
+    def _resample_audio(self, audio, orig_sr, target_sr):
+        """Resample audio to target sample rate"""
+        try:
+            import scipy.signal
+
+            # Get the resampling ratio
+            ratio = target_sr / orig_sr
+
+            # Calculate the new length
+            new_length = int(len(audio[0]) * ratio)
+
+            # Resample the audio
+            resampled = scipy.signal.resample(audio[0], new_length)
+
+            # Return in the same shape format
+            return np.expand_dims(resampled, 0)
+        except Exception as e:
+            print(f"Error resampling audio: {e}")
+            return audio
+
 
 # FastRTC Audio Handler for Real-time Diarization
 
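A note on the resampling path added above, separate from the commit itself: scipy.signal.resample is FFT-based, which can ring on short, non-periodic speech buffers. For a fixed-rate conversion such as 48 kHz microphone audio down to a 16 kHz recorder, a polyphase filter is a common alternative. A minimal sketch under that assumption; resample_int16 is an illustrative name, not a helper from app.py:

# Sketch only, not part of this commit: polyphase alternative to the
# FFT-based scipy.signal.resample used in _resample_audio above.
from math import gcd

import numpy as np
from scipy.signal import resample_poly

def resample_int16(audio_int16, orig_sr=48000, target_sr=16000):
    """Resample a (1, N) int16 buffer with a polyphase filter."""
    g = gcd(target_sr, orig_sr)
    up, down = target_sr // g, orig_sr // g              # 1 and 3 for 48 kHz -> 16 kHz
    resampled = resample_poly(audio_int16[0].astype(np.float32), up, down)
    return np.expand_dims(np.clip(resampled, -32768, 32767).astype(np.int16), 0)
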
@@ -638,6 +692,16 @@ class DiarizationHandler(AsyncStreamHandler):
                 )
             except Exception as e:
                 print(f"Error in async audio processing: {e}")
+
+    async def start_up(self) -> None:
+        """Initialize any resources when the stream starts"""
+        print("FastRTC stream started")
+        self.is_processing = True
+
+    async def shutdown(self) -> None:
+        """Clean up any resources when the stream ends"""
+        print("FastRTC stream shutting down")
+        self.is_processing = False
 
 
 # Global instances
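The start_up and shutdown hooks above only toggle the handler's lifecycle state; the code that actually moves frames from FastRTC into the diarization system sits outside this hunk. As a rough illustration of how a receive callback could hand incoming audio to the new process_audio_chunk method, assuming frames arrive as (sample_rate, ndarray) pairs; everything except process_audio_chunk is an illustrative name, not code from app.py:

# Illustrative sketch, not the commit's DiarizationHandler implementation.
import numpy as np

class ForwardingHandlerSketch:
    def __init__(self, diarization_system):
        self.diarization_system = diarization_system
        self.is_processing = False

    async def receive(self, frame):
        if not self.is_processing:
            return
        sample_rate, audio = frame                # assumed (int, np.ndarray) frame layout
        audio = np.asarray(audio).squeeze()       # process_audio_chunk accepts 1-D or (1, N)
        self.diarization_system.process_audio_chunk(audio, sample_rate=sample_rate)
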
@@ -1083,8 +1147,19 @@ def create_app(diarization_sys=None):
     # Mount Gradio on FastAPI
     app = gr.mount_gradio_app(fastapi_app, gradio_interface, path="/")
 
-    # Setup FastRTC stream
-
+    # Setup FastRTC stream
+    if diarization_system is not None:
+        # Initialize the system if not already done
+        if not hasattr(diarization_system, 'encoder') or diarization_system.encoder is None:
+            diarization_system.initialize_models()
+
+        # Create audio handler if needed
+        global audio_handler
+        if audio_handler is None:
+            audio_handler = DiarizationHandler(diarization_system)
+
+        # Setup and mount the FastRTC stream
+        setup_fastrtc_stream(app)
 
     return app, gradio_interface
 
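setup_fastrtc_stream(app) is defined elsewhere in app.py and is not part of this diff. One plausible shape for it, assuming fastrtc's Stream API with the global audio_handler mounted on the FastAPI app; this is a sketch of what such a helper could look like, not the file's actual implementation:

# Hypothetical sketch of setup_fastrtc_stream; the real function lives
# elsewhere in app.py and may differ.
from fastrtc import Stream

def setup_fastrtc_stream(app):
    if audio_handler is None:
        return None
    stream = Stream(
        handler=audio_handler,       # the DiarizationHandler created in create_app
        modality="audio",
        mode="send-receive",
    )
    stream.mount(app)                # expose the WebRTC endpoints on the FastAPI app
    return stream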