Spaces:
Sleeping
Sleeping
from typing import Any, Callable, Literal, Optional

import openai
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import (
    AIMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
    load_prompt,
)
from langchain_core.runnables import Runnable, RunnableLambda
from langchain_core.runnables import RunnablePassthrough
from pydantic import BaseModel

from src.common.paths import PROMPTS_PATH
from src.common.schema import DatasetSchema
from src.generate.llms import LLM_NAME_TO_CLASS, LLMName
class GenerationAnswer(BaseModel):
    """Final output of a generation chain.

    Attributes:
        answer: The structured answer produced by the LLM. It is typed as
            ``Any``, so no particular model class is enforced here.
        context: Extra, strategy-specific data carried alongside the answer
            (for example, intermediate reasoning text); empty when a strategy
            has nothing to add.
    """

    answer: Any
    # Pydantic copies non-hashable defaults for each instance, so this empty
    # dict is not shared between models.
    context: dict[str, Any] = dict()
def build_singleturn_chain(
    answer_class: type[BaseModel],
    llm_class: LLMName = "ollama",
    llm_args: Optional[dict[str, Any]] = None,
    structured_output_method: Literal[
        "function_calling", "json_mode", "json_schema"
    ] = "json_schema",
) -> Runnable:
    """Build a single-prompt generation chain that returns a structured answer.

    The chain renders the human message loaded from ``singleturn.yaml`` with
    the input dict, asks the LLM for an ``answer_class`` result, and wraps it
    in a ``GenerationAnswer`` with an empty ``context``.

    Args:
        answer_class: Pydantic model the LLM output is parsed into.
        llm_class: Key into ``LLM_NAME_TO_CLASS`` selecting the chat-model class.
        llm_args: Keyword arguments for the chat-model constructor. ``None``
            (the default) means ``{"model": "gemma3:4b", "top_k": 1,
            "top_p": 1, "temperature": 0.0}``.
        structured_output_method: Forwarded as ``method`` to
            ``with_structured_output``.

    Returns:
        A runnable that takes a dict of prompt variables and yields a
        ``GenerationAnswer``. The whole chain is retried when
        ``openai.PermissionDeniedError`` is raised.
    """
    if llm_args is None:
        # Built per call instead of as a default argument, so the dict is
        # never shared between invocations (mutable-default pitfall).
        llm_args = {
            "model": "gemma3:4b",
            "top_k": 1,
            "top_p": 1,
            "temperature": 0.0,
        }

    llm = LLM_NAME_TO_CLASS[llm_class](**llm_args).with_structured_output(
        answer_class,
        method=structured_output_method,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            HumanMessagePromptTemplate(
                prompt=load_prompt(PROMPTS_PATH / "singleturn.yaml")
            )
        ]
    )

    def _to_generation_answer(state: dict[str, Any]) -> GenerationAnswer:
        """Wrap the structured answer; this strategy has no extra context."""
        return GenerationAnswer(answer=state["answer"], context={})

    # ``assign`` adds the LLM result under "answer" next to the original inputs.
    chain = RunnablePassthrough.assign(answer=prompt | llm) | RunnableLambda(
        _to_generation_answer
    )
    return chain.with_retry(retry_if_exception_type=(openai.PermissionDeniedError,))
def build_thinking_chain(
    answer_class: type[BaseModel],
    llm_class: LLMName = "ollama",
    think_llm_args: Optional[dict[str, Any]] = None,
    answer_llm_args: Optional[dict[str, Any]] = None,
    structured_output_method: Literal[
        "function_calling", "json_mode", "json_schema"
    ] = "json_schema",
) -> Runnable:
    """Build a two-step "think, then answer" generation chain.

    Step 1 sends the system prompt from ``simple_think_system.yaml`` plus the
    task text to the thinking model and keeps its free-text reply. Step 2
    replays that conversation, appends the reply as an AI message followed by
    the closing instruction from ``simple_think_end.yaml``, and asks the
    answering model for a structured ``answer_class`` result.

    Args:
        answer_class: Pydantic model the final answer is parsed into.
        llm_class: Key into ``LLM_NAME_TO_CLASS``; used for both models.
        think_llm_args: Constructor kwargs for the thinking model. ``None``
            (the default) means ``{"model": "gemma3:4b", "top_k": 1,
            "top_p": 1, "temperature": 0.0}``.
        answer_llm_args: Constructor kwargs for the answering model, with the
            same ``None`` default as ``think_llm_args``.
        structured_output_method: Forwarded as ``method`` to
            ``with_structured_output`` on the answering model.

    Returns:
        A runnable that takes a dict of prompt variables and yields a
        ``GenerationAnswer`` whose ``context`` holds the thinking model's
        reply under ``"think_answer"``. The whole chain (both steps) is
        retried when ``openai.PermissionDeniedError`` is raised.
    """
    default_llm_args: dict[str, Any] = {
        "model": "gemma3:4b",
        "top_k": 1,
        "top_p": 1,
        "temperature": 0.0,
    }
    # Defaults are built per call rather than as default arguments, so no
    # dict is shared between invocations (mutable-default pitfall).
    if think_llm_args is None:
        think_llm_args = dict(default_llm_args)
    if answer_llm_args is None:
        answer_llm_args = dict(default_llm_args)

    think_llm = LLM_NAME_TO_CLASS[llm_class](**think_llm_args)
    think_prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate(
                prompt=load_prompt(PROMPTS_PATH / "simple_think_system.yaml")
            ),
            # f"{{{x}}}" renders as "{x}": a template variable named after
            # DatasetSchema.task_text, filled from the input dict.
            HumanMessagePromptTemplate.from_template(f"{{{DatasetSchema.task_text}}}"),
        ]
    )
    think_chain = think_prompt | think_llm | StrOutputParser()

    # Replay the thinking conversation, add the model's reasoning as its own
    # turn, then ask for the final answer.
    answer_prompt = ChatPromptTemplate.from_messages(
        think_prompt.messages
        + [
            AIMessagePromptTemplate.from_template("{think_answer}"),
            HumanMessagePromptTemplate(
                prompt=load_prompt(PROMPTS_PATH / "simple_think_end.yaml")
            ),
        ]
    )
    answer_llm = LLM_NAME_TO_CLASS[llm_class](**answer_llm_args).with_structured_output(
        answer_class,
        method=structured_output_method,
    )

    def _to_generation_answer(state: dict[str, Any]) -> GenerationAnswer:
        """Wrap the structured answer, keeping the reasoning as context."""
        return GenerationAnswer(
            answer=state["answer"],
            context={"think_answer": state["think_answer"]},
        )

    # Each ``assign`` keeps the original inputs in the state, so the answer
    # prompt can re-render the task text alongside ``think_answer``.
    chain = (
        RunnablePassthrough.assign(think_answer=think_chain)
        | RunnablePassthrough.assign(answer=answer_prompt | answer_llm)
        | RunnableLambda(_to_generation_answer)
    )
    return chain.with_retry(retry_if_exception_type=(openai.PermissionDeniedError,))
# Names of the available generation strategies.
GeneratorName = Literal["singleturn", "thinking"]

# Registry mapping each generator name to the factory that builds its chain.
# Both factories take the answer model class as their first argument; all of
# their other parameters have defaults, so a factory can be called with the
# answer class alone.
GENERATORS_NAME_TO_FACTORY: dict[
    GeneratorName, Callable[[type[BaseModel]], Runnable]
] = {
    "singleturn": build_singleturn_chain,
    "thinking": build_thinking_chain,
}