# -*- coding: utf-8 -*- """ YOUR FOIA CHAT ASSISTANCE - Text-only chatbot (STT and TTS removed) Drop this file into your Hugging Face Space (replace existing app.py) or run locally. Notes: - Dark UI via custom CSS (works even if Gradio theme API differs) - Performance-focused: greedy generation, lower max_new_tokens, use_cache, no_grad, streaming - Keeps bitsandbytes / 4-bit logic intact when available """ import os import threading import gradio as gr import importlib import importlib.util import torch from huggingface_hub import login from transformers import ( AutoTokenizer, AutoModelForCausalLM, pipeline, TextIteratorStreamer, ) from peft import PeftModel, PeftConfig # -------------------- Configuration -------------------- ADAPTER_REPO_ID = "EYEDOL/FOIA" # adapter-only repo BASE_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit" # full base model referenced by adapter HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface") if HF_TOKEN: try: login(token=HF_TOKEN) print("Successfully logged into Hugging Face Hub!") except Exception as e: print("Warning: huggingface_hub.login() failed:", e) else: print("Warning: HF_TOKEN not found in env. Private repos may fail to load.") def is_package_installed(name: str) -> bool: """Return True if installed (distribution metadata present).""" try: import importlib.metadata as md try: md.distribution(name) return True except Exception: return False except Exception: try: importlib.import_module(name) return True except Exception: return False class WeeboAssistant: def __init__(self): # system prompt instructs the assistant to answer concisely in English self.SYSTEM_PROMPT = ( "You are an intelligent assistant. Answer questions briefly and accurately. " "Respond only in English. No long answers.\n" ) # generation defaults tuned for speed (adjust if you need different behavior) self.MAX_NEW_TOKENS = 256 # lowered from 512 for speed self.DO_SAMPLE = False # greedy = faster; set True if you want sampling self.NUM_BEAMS = 1 # keep 1 for greedy (increase >1 for beam search) self._init_models() def _init_models(self): print("Initializing models...") self.device = "cuda" if torch.cuda.is_available() else "cpu" self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32 print(f"Using device: {self.device}, torch_dtype: {self.torch_dtype}") BNB_AVAILABLE = is_package_installed("bitsandbytes") print("bitsandbytes available:", BNB_AVAILABLE) # load tokenizer (prefer base tokenizer) try: self.llm_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True) print("Loaded tokenizer from BASE_MODEL_ID") except Exception as e: print("Warning: could not load base tokenizer, falling back to adapter tokenizer. Error:", e) self.llm_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True) print("Loaded tokenizer from ADAPTER_REPO_ID") # ensure tokenizer has pad_token_id to avoid generation stalls if getattr(self.llm_tokenizer, "pad_token_id", None) is None: if getattr(self.llm_tokenizer, "eos_token_id", None) is not None: self.llm_tokenizer.pad_token_id = self.llm_tokenizer.eos_token_id else: # fallback to 0 to prevent crashes (not ideal but safe) self.llm_tokenizer.pad_token_id = 0 # decide device_map (never pass None) if torch.cuda.is_available(): device_map = "auto" else: device_map = {"": "cpu"} print("device_map being used for model load:", device_map) base_model_kwargs = dict( torch_dtype=self.torch_dtype, low_cpu_mem_usage=True, device_map=device_map, trust_remote_code=True, ) if BNB_AVAILABLE and torch.cuda.is_available(): base_model_kwargs["load_in_4bit"] = True print("Will attempt to load base model in 4-bit (bitsandbytes + CUDA detected).") else: print("bitsandbytes not usable or no CUDA: loading model normally (no 4-bit).") try: self.llm_model = AutoModelForCausalLM.from_pretrained( BASE_MODEL_ID, **base_model_kwargs, ) # ensure use_cache set for faster autoregressive generation try: self.llm_model.config.use_cache = True except Exception: pass print("Base model loaded from", BASE_MODEL_ID) except Exception as e: raise RuntimeError( "Failed to load base model. Ensure the base model ID is correct and HF_TOKEN has access if private. Error: " + str(e) ) # load and apply PEFT adapter try: try: peft_config = PeftConfig.from_pretrained(ADAPTER_REPO_ID) print("Loaded PEFT config from", ADAPTER_REPO_ID) except Exception: peft_config = None print("Warning: could not load PeftConfig; continuing to attempt adapter load.") peft_kwargs = dict( device_map=device_map, torch_dtype=self.torch_dtype, low_cpu_mem_usage=True, ) self.llm_model = PeftModel.from_pretrained( self.llm_model, ADAPTER_REPO_ID, **peft_kwargs, ) # ensure adapter-wrapped model also has use_cache try: self.llm_model.config.use_cache = True except Exception: pass print("PEFT adapter applied from", ADAPTER_REPO_ID) except Exception as e: raise RuntimeError( "Failed to load/apply PEFT adapter from adapter repo. Make sure adapter files are present and HF_TOKEN has access if private. Error: " + str(e) ) # optional non-streaming pipeline (useful for quick tests) try: device_index = 0 if torch.cuda.is_available() else -1 self.llm_pipeline = pipeline( "text-generation", model=self.llm_model, tokenizer=self.llm_tokenizer, device=device_index, model_kwargs={"torch_dtype": self.torch_dtype}, ) print("Created text-generation pipeline (non-streaming).") except Exception as e: print("Warning: could not create text-generation pipeline. Streaming generate will still work. Error:", e) self.llm_pipeline = None print("LLM base + adapter loaded successfully.") def get_llm_response(self, chat_history): # Build prompt (system + conversation) prompt_lines = [self.SYSTEM_PROMPT] for user_msg, assistant_msg in chat_history: if user_msg: prompt_lines.append("User: " + user_msg) if assistant_msg: prompt_lines.append("Assistant: " + assistant_msg) prompt_lines.append("Assistant: ") prompt = "\n".join(prompt_lines) # Tokenize inputs inputs = self.llm_tokenizer(prompt, return_tensors="pt", padding=False) try: model_device = next(self.llm_model.parameters()).device except StopIteration: model_device = torch.device("cpu") inputs = {k: v.to(model_device) for k, v in inputs.items()} # Use TextIteratorStreamer for streaming outputs to Gradio streamer = TextIteratorStreamer(self.llm_tokenizer, skip_prompt=True, skip_special_tokens=True) # Prefill generation kwargs optimized for speed input_len = inputs["input_ids"].shape[1] max_new = self.MAX_NEW_TOKENS max_length = input_len + max_new generation_kwargs = dict( input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask", None), max_length=max_length, # input_len + max_new max_new_tokens=max_new, # explicit do_sample=self.DO_SAMPLE, # greedy if False -> faster num_beams=self.NUM_BEAMS, # keep 1 for speed streamer=streamer, eos_token_id=getattr(self.llm_tokenizer, "eos_token_id", None), pad_token_id=getattr(self.llm_tokenizer, "pad_token_id", None), use_cache=True, early_stopping=True, ) # Run generate under no_grad to save memory and time def _generate_thread(): with torch.no_grad(): try: self.llm_model.generate(**generation_kwargs) except Exception as e: print("Generation error:", e) gen_thread = threading.Thread(target=_generate_thread, daemon=True) gen_thread.start() return streamer # create assistant instance (loads model once at startup) assistant = WeeboAssistant() # -------------------- Gradio pipeline functions -------------------- def t2t_pipeline(text_input, chat_history): chat_history = chat_history or [] chat_history.append((text_input, "")) # placeholder for assistant reply yield chat_history response_stream = assistant.get_llm_response(chat_history) llm_response_text = "" for text_chunk in response_stream: llm_response_text += text_chunk chat_history[-1] = (text_input, llm_response_text) yield chat_history def clear_textbox(): return gr.Textbox.update(value="") # -------------------- MODIFIED: Modern Dark UI CSS -------------------- MODERN_CSS = """ @import url('https://fonts.googleapis.com/css2?family=Poppins:wght@400;500;600;700&display=swap'); :root { --body-bg: linear-gradient(135deg, #10141a 0%, #06090f 100%); --chat-bg: #0b0f19; --border-color: rgba(255, 255, 255, 0.08); --text-color: #E6EEF8; --input-bg: #131926; --user-msg-bg: #1B2336; --bot-msg-bg: #0F1522; --primary-color: #0084ff; --primary-hover: #006fdb; --font-family: 'Poppins', sans-serif; } body, .gradio-container { background: var(--body-bg) !important; color: var(--text-color) !important; font-family: var(--font-family) !important; } .gradio-container * { font-family: var(--font-family) !important; } h1, h2, h3, .markdown { color: var(--text-color) !important; } .gr-block, .gr-box, .gr-row, .gr-column { background: transparent !important; border: none !important; box-shadow: none !important; } .gr-chatbot { background: var(--chat-bg) !important; border: 1px solid var(--border-color) !important; border-radius: 12px !important; box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2) !important; } .gr-chatbot .message { border-radius: 8px !important; padding: 12px !important; box-shadow: 0 2px 4px rgba(0,0,0,0.1) !important; border: none !important; } .gr-chatbot .message.user { background: var(--user-msg-bg) !important; color: var(--text-color) !important; } .gr-chatbot .message.bot { background: var(--bot-msg-bg) !important; color: var(--text-color) !important; } .gr-chatbot .message p { margin: 0; } .gr-textbox, .gr-textbox textarea { background: var(--input-bg) !important; color: var(--text-color) !important; border: 1px solid var(--border-color) !important; border-radius: 8px !important; transition: all 0.2s ease-in-out; } .gr-textbox:focus, .gr-textbox textarea:focus { border-color: var(--primary-color) !important; box-shadow: 0 0 0 2px rgba(0, 132, 255, 0.3) !important; } .gr-button { background: var(--primary-color) !important; color: white !important; border: none !important; border-radius: 8px !important; box-shadow: 0 4px 12px rgba(0, 132, 255, 0.2) !important; transition: all 0.2s ease-in-out !important; font-weight: 500 !important; display: flex; justify-content: center; align-items: center; gap: 8px; /* Space between icon and text */ } .gr-button:hover { background: var(--primary-hover) !important; transform: translateY(-2px); box-shadow: 0 6px 16px rgba(0, 132, 255, 0.3) !important; } /* Hide default Gradio button text when we add our own */ .send-btn span { font-size: 1rem; } /* Add a send icon to the button */ .send-btn::before { content: ''; display: inline-block; width: 20px; height: 20px; background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='white'%3E%3Cpath d='M2.01 21L23 12 2.01 3 2 10l15 2-15 2z'/%3E%3C/svg%3E"); background-size: contain; background-repeat: no-repeat; background-position: center; } footer, .footer { display: none !important; } """ # -------------------- MODIFIED: Gradio UI with Logo -------------------- with gr.Blocks(css=MODERN_CSS, title="DimChi FOIA Assistant") as demo: # NEW: Centered header with logo with gr.Row(): gr.Markdown( """
Your intelligent chat partner for FOIA inquiries.