hamza2923 committed on
Commit
3e22401
·
verified ·
1 Parent(s): fe9ee59

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +18 -3
main.py CHANGED
@@ -13,6 +13,8 @@ import logging
13
  import os
14
  import shutil
15
  from pathlib import Path
 
 
16
 
17
  app = FastAPI()
18
 
@@ -39,6 +41,7 @@ class TranscriptResponse(BaseModel):
39
  error: str | None
40
  processing_time: float
41
 
 
42
  def init_driver():
43
  options = Options()
44
  options.add_argument("--headless=new")
@@ -47,6 +50,11 @@ def init_driver():
47
  options.add_argument("--disable-gpu")
48
  options.add_argument("--disable-extensions")
49
 
 
 
 
 
 
50
  possible_chrome_paths = [
51
  "/usr/bin/google-chrome",
52
  "/usr/bin/google-chrome-stable",
@@ -75,22 +83,26 @@ def init_driver():
75
  chrome_version = driver.capabilities["browserVersion"]
76
  chromedriver_version = driver.capabilities["chrome"]["chromedriverVersion"].split()[0]
77
  logger.info(f"Chrome version: {chrome_version}, ChromeDriver version: {chromedriver_version}")
78
- return driver
79
  except Exception as e:
80
  logger.error(f"Driver initialization failed: {str(e)}")
 
 
 
81
  raise Exception(f"Driver initialization failed: {str(e)}")
82
 
83
  @app.post("/transcript", response_model=TranscriptResponse)
84
  async def get_transcript(request: VideoRequest):
85
  start_time = time.time()
86
  driver = None
 
87
 
88
  try:
89
  video_url = request.url
90
  if not ("youtube.com" in video_url or "youtu.be" in video_url):
91
  raise HTTPException(status_code=400, detail="Invalid YouTube URL")
92
 
93
- driver = init_driver()
94
  logger.info(f"Processing URL: {video_url}")
95
  driver.get(video_url)
96
 
@@ -162,6 +174,9 @@ async def get_transcript(request: VideoRequest):
162
  finally:
163
  if driver:
164
  driver.quit()
 
 
 
165
 
166
  @app.get("/health")
167
  def health_check():
@@ -180,4 +195,4 @@ async def root():
180
 
181
  if __name__ == "__main__":
182
  import uvicorn
183
- uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))
 
13
  import os
14
  import shutil
15
  from pathlib import Path
16
+ import tempfile
17
+ from tenacity import retry, stop_after_attempt, wait_fixed
18
 
19
  app = FastAPI()
20
 
 
41
  error: str | None
42
  processing_time: float
43
 
44
+ @retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
45
  def init_driver():
46
  options = Options()
47
  options.add_argument("--headless=new")
 
50
  options.add_argument("--disable-gpu")
51
  options.add_argument("--disable-extensions")
52
 
53
+ # Create a unique temporary user data directory
54
+ user_data_dir = tempfile.mkdtemp()
55
+ options.add_argument(f"--user-data-dir={user_data_dir}")
56
+ logger.info(f"Using temporary user data directory: {user_data_dir}")
57
+
58
  possible_chrome_paths = [
59
  "/usr/bin/google-chrome",
60
  "/usr/bin/google-chrome-stable",
 
83
  chrome_version = driver.capabilities["browserVersion"]
84
  chromedriver_version = driver.capabilities["chrome"]["chromedriverVersion"].split()[0]
85
  logger.info(f"Chrome version: {chrome_version}, ChromeDriver version: {chromedriver_version}")
86
+ return driver, user_data_dir
87
  except Exception as e:
88
  logger.error(f"Driver initialization failed: {str(e)}")
89
+ # Clean up the temporary directory in case of failure
90
+ if Path(user_data_dir).exists():
91
+ shutil.rmtree(user_data_dir, ignore_errors=True)
92
  raise Exception(f"Driver initialization failed: {str(e)}")
93
 
94
  @app.post("/transcript", response_model=TranscriptResponse)
95
  async def get_transcript(request: VideoRequest):
96
  start_time = time.time()
97
  driver = None
98
+ user_data_dir = None
99
 
100
  try:
101
  video_url = request.url
102
  if not ("youtube.com" in video_url or "youtu.be" in video_url):
103
  raise HTTPException(status_code=400, detail="Invalid YouTube URL")
104
 
105
+ driver, user_data_dir = init_driver()
106
  logger.info(f"Processing URL: {video_url}")
107
  driver.get(video_url)
108
 
 
174
  finally:
175
  if driver:
176
  driver.quit()
177
+ if user_data_dir and Path(user_data_dir).exists():
178
+ shutil.rmtree(user_data_dir, ignore_errors=True)
179
+ logger.info(f"Cleaned up temporary user data directory: {user_data_dir}")
180
 
181
  @app.get("/health")
182
  def health_check():
 
195
 
196
  if __name__ == "__main__":
197
  import uvicorn
198
+ uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)), workers=1)