Nexari-Research committed
Commit 2f7603d · verified · 1 Parent(s): 59903c2

Update chat_model.py

Files changed (1)
  1. chat_model.py (+26 −13)
chat_model.py CHANGED
@@ -1,4 +1,6 @@
-import os, asyncio, logging
+import os
+import logging
+import asyncio
 from huggingface_hub import hf_hub_download
 from llama_cpp import Llama
 
@@ -9,19 +11,30 @@ model = None
 REPO_ID = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
 FILENAME = "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf"
 
-def load_model():
+def download_if_needed(local_dir: str):
+    os.makedirs(local_dir, exist_ok=True)
+    local_path = os.path.join(local_dir, FILENAME)
+    if os.path.exists(local_path):
+        logger.info(f"Chat model already present: {local_path}")
+        return local_path
+    logger.info(f"Downloading chat model to {local_dir} ...")
+    path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, local_dir=local_dir)
+    logger.info(f"Downloaded chat: {path}")
+    return path
+
+def load_model(local_dir: str = None):
     global model
-    os.makedirs(BASE_DIR, exist_ok=True)
-    path = hf_hub_download(REPO_ID, FILENAME, local_dir=BASE_DIR)
-    model = Llama(
-        model_path=path,
-        n_ctx=4096,
-        n_threads=os.cpu_count(),
-        n_batch=256,
-        verbose=False
-    )
-    logger.info("Chat model ready")
-    return model
+    if not local_dir:
+        local_dir = BASE_DIR
+    try:
+        model_path = download_if_needed(local_dir)
+        model = Llama(model_path=model_path, n_ctx=2048, verbose=False)
+        logger.info("Chat model loaded")
+        return model
+    except Exception as e:
+        logger.exception(f"Failed to load chat model: {e}")
+        model = None
+        raise
 
 async def load_model_async():
     return await asyncio.to_thread(load_model)
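
For reference, a minimal usage sketch of the refactored loader (not part of the commit; the driver script, prompt text, and token budget are illustrative assumptions, and it presumes chat_model.py is importable with BASE_DIR and logger defined above the hunks shown):

    # Hypothetical driver script: load the model off the event loop, then run one completion.
    import asyncio

    from chat_model import load_model_async

    async def main():
        # First run downloads the .gguf into BASE_DIR; later runs reuse the local file.
        llm = await load_model_async()
        # llama_cpp.Llama instances are callable and return an OpenAI-style completion dict.
        out = llm("Q: What does GGUF stand for? A:", max_tokens=32)
        print(out["choices"][0]["text"])

    if __name__ == "__main__":
        asyncio.run(main())

Because load_model_async delegates to asyncio.to_thread, the blocking download and Llama initialization run in a worker thread, so the caller's event loop stays responsive while the model loads.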