| from huggingface_hub import InferenceClient |
| from dotenv import load_dotenv |
| import configparser |
| import os |
|
|
|
|
class LLMManager:
    """Wrap a Hugging Face ``InferenceClient`` plus the model list in config.ini.

    Model ids, display names and prompt templates are read from the ``[LLM]``,
    ``[LLM_Map]`` and ``[Prompt_map]`` sections of ``config.ini``. The three
    sections are treated as parallel lists and matched by position.
    Generation parameters and the default system prompt come from the
    ``settings`` object passed to the constructor.
    """

    class _KeepPlaceholders(dict):
        """``str.format_map`` mapping that leaves unknown ``{name}`` fields as-is."""

        def __missing__(self, key):
            return "{" + key + "}"

    def __init__(self, settings):
        """Load credentials and config, then select the default model.

        Args:
            settings: object exposing ``defaultLLM`` (used as an index into the
                ``[LLM]`` values; presumably an int — TODO confirm), the
                generation parameters read by ``get_text``/``get_query_terms``
                and the ``system_prompt`` read by ``get_prompt``.
        """
        # load_dotenv() does not raise when no .env file exists. It returns
        # False instead (also when the file defines no variables), so check the
        # result rather than wrapping the call in a bare ``except``.
        if not load_dotenv():
            print("No .env file")

        self.client = InferenceClient(token=os.environ.get('HF_TOKEN'))

        self.config = configparser.ConfigParser()
        self.config.read("config.ini")

        self.set = settings
        self.defaultLLM = self.set.defaultLLM

        self.listLLM = self.get_llm()
        self.listLLMMap = self.get_llm_map()

        self.currentLLM = self.listLLM[self.defaultLLM]

    def selectLLM(self, llm):
        """Make the model whose ``[LLM_Map]`` name is *llm* the current model.

        Raises:
            ValueError: if *llm* is not one of the ``[LLM_Map]`` values.
        """
        llm_index = self.listLLMMap.index(llm)
        self.currentLLM = self.listLLM[llm_index]
        # The original message lacked the f-prefix and printed "{llm}" literally.
        print(f"Selected {llm} LLM")

    def _read_section(self, section):
        """Return all values of config *section* in file order, or [] if it is absent."""
        if section not in self.config:
            return []
        return [self.config.get(section, key) for key in self.config[section]]

    def get_llm(self):
        """Return the model ids listed in the ``[LLM]`` section."""
        return self._read_section('LLM')

    def get_llm_prompts(self):
        """Return the prompt templates listed in the ``[Prompt_map]`` section."""
        return self._read_section('Prompt_map')

    def get_llm_map(self):
        """Return the model display names listed in the ``[LLM_Map]`` section."""
        return self._read_section('LLM_Map')

    def _generate(self, question, **generate_kwargs):
        """Run a non-streaming text-generation request against the current model."""
        return self.client.text_generation(
            model=self.currentLLM,
            prompt=question,
            **generate_kwargs,
            stream=False,
            details=False,
            return_full_text=False,
        )

    def get_text(self, question):
        """Generate a completion for *question* using the main sampling settings.

        Reads ``temperature``, ``max_new_token``, ``top_p`` and
        ``repetition_penalty`` from the settings and always passes ``seed=42``.
        """
        print("temp={temp}".format(temp=self.set.temperature))
        print("Repetition={rep}".format(rep=self.set.repetition_penalty))
        return self._generate(
            question,
            temperature=self.set.temperature,
            max_new_tokens=self.set.max_new_token,
            top_p=self.set.top_p,
            repetition_penalty=self.set.repetition_penalty,
            do_sample=True,
            seed=42,
        )

    def get_query_terms(self, question):
        """Generate a completion for *question* using the ``RAG_*`` sampling settings.

        Unlike ``get_text``, no seed is passed.
        """
        return self._generate(
            question,
            temperature=self.set.RAG_temperature,
            max_new_tokens=self.set.RAG_max_new_token,
            top_p=self.set.RAG_top_p,
            repetition_penalty=self.set.RAG_repetition_penalty,
            do_sample=True,
        )

    def get_prompt(self, user_input, rag_contex, chat_history, system_prompt=None):
        """Return the formatted prompt for the current LLM.

        Formatting happens in two stages:

        1. ``{sys_prompt}`` in the model's ``[Prompt_map]`` template is replaced
           by the system prompt. Any other ``{name}`` fields are left intact.
        2. ``{context}``, ``{history}`` and ``{question}`` are then filled in the
           result. This covers placeholders that came from the template and
           placeholders that came from the system prompt.

        If the current model has no template, ``"{sys_prompt}"`` is used instead.

        Args:
            user_input: text substituted for ``{question}``.
            rag_contex: text substituted for ``{context}``.
            chat_history: text substituted for ``{history}``.
            system_prompt: overrides ``settings.system_prompt`` when not None.
        """
        if system_prompt is None:
            system_prompt = self.set.system_prompt
        else:
            print("System prompt set to : \n {sys_prompt}".format(sys_prompt=system_prompt))

        prompts = self.get_llm_prompts()
        try:
            template = prompts[self.listLLM.index(self.currentLLM)]
        except (ValueError, IndexError):
            # ValueError: current model not in [LLM]; IndexError: no matching template.
            print("Warning prompt map for {llm} has not been defined".format(llm=self.currentLLM))
            template = "{sys_prompt}"

        # Plain .format(sys_prompt=...) would raise KeyError on a template that
        # also contains {question}/{context}/{history}. The original code then
        # silently discarded the template. Keep such fields for stage two instead.
        prompt = template.format_map(self._KeepPlaceholders(sys_prompt=system_prompt))

        print("Prompt={pro}".format(pro=prompt))
        return prompt.format(context=rag_contex, history=chat_history, question=user_input)
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |