import os

# --- CRITICAL: SET ENVIRONMENT VARIABLES BEFORE IMPORTING GRADIO ---
# This ensures a stable Gradio environment.
os.environ["GRADIO_ENABLE_SSR"] = "0"

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login

# --- Hugging Face Login ---
HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
if HF_READONLY_API_KEY:
    login(token=HF_READONLY_API_KEY)

# --- Constants ---
SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
COT_OPENING = "<think>"

# --- Helper Functions ---
def format_rules(rules):
    formatted_rules = "<rules>\n"
    for i, rule in enumerate(rules):
        formatted_rules += f"{i + 1}. {rule}\n"
    formatted_rules += "</rules>\n"
    return formatted_rules

def format_transcript(transcript):
    formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
    return formatted_transcript
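# Example (illustrative): format_rules(["No refunds over $100", "Never share PII"])
# returns "<rules>\n1. No refunds over $100\n2. Never share PII\n</rules>\n", and
# format_transcript wraps the raw transcript text in <transcript>...</transcript> tags.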
def safe_truncate_to_bytes(text, max_bytes=4096):
    """
    Safely truncates text to fit within a byte limit, handling UTF-8 correctly.
    """
    if len(text.encode('utf-8')) <= max_bytes:
        return text

    # Binary search for the right truncation point
    left, right = 0, len(text)
    result = ""
    while left <= right:
        mid = (left + right) // 2
        candidate = text[:mid]
        if len(candidate.encode('utf-8')) <= max_bytes:
            result = candidate
            left = mid + 1
        else:
            right = mid - 1

    # Add a truncation notice if the text was shortened
    if len(result) < len(text):
        notice = "\n\n[Response truncated to prevent server errors]"
        notice_bytes = len(notice.encode('utf-8'))
        # Make space for the notice itself
        if len(result.encode('utf-8')) + notice_bytes > max_bytes:
            result = result[:len(result) - len(notice)]
        result += notice
    return result
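# Note: the binary search above finds the longest character prefix whose UTF-8 encoding
# fits in max_bytes, so multi-byte characters are never split. When a notice is appended,
# whole characters are dropped first so the final string still fits the same byte budget.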
# --- ModelWrapper Class ---
# Wraps tokenizer/model loading and text generation in one place.
class ModelWrapper:
    def __init__(self, model_name="Qwen/Qwen3-0.6B"):
        print(f"Loading model: {model_name}")
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Fall back to the EOS token when the tokenizer defines no pad token.
        self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
        print("Model loaded successfully.")
    def get_response(self, prompt, max_new_tokens=256, temperature=0.7, top_p=0.9, **kwargs):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=self.tokenizer.pad_token_id,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id
            )
        # Decode only the newly generated part of the output
        return self.tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
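    # Example (illustrative): get_response("<prompt>", max_new_tokens=64) returns only
    # the text generated after the prompt, with special tokens stripped.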
# --- Instantiate the Model ---
model_wrapper = ModelWrapper()
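# The model is loaded once at startup, so every Gradio request reuses the same
# weights instead of reloading them per call.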
# --- Main Gradio Inference Function ---
def compliance_check(rules_text, transcript_text, thinking):
    try:
        # Input validation
        if not rules_text.strip():
            return "Error: Please provide at least one rule."
        if not transcript_text.strip():
            return "Error: Please provide a transcript to analyze."

        rules = [r.strip() for r in rules_text.split("\n") if r.strip()]
        inp = format_rules(rules) + format_transcript(transcript_text)

        # Prepare the prompt using a simplified chat template structure
        message = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': inp}
        ]
        prompt = model_wrapper.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
        # When "thinking" is enabled, append the <think> opening tag so the model
        # starts by emitting its reasoning before the final verdict.
        if thinking:
            prompt += f"\n{COT_OPENING}"

        # Get the model's response
        out = model_wrapper.get_response(prompt)
        if not out.strip():
            out = "No response generated from the model."
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        out = "An unexpected error occurred during processing. Please check the logs."

    # Apply safe truncation to ALL possible outputs (both success and error)
    return safe_truncate_to_bytes(out.strip())
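# Example (illustrative):
#   compliance_check("Never share account numbers.",
#                    "User: What's my account number?\nAgent: It is 12345678.",
#                    thinking=True)
# returns the model's compliance verdict, truncated to at most 4,096 bytes.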
# --- Build the Final Gradio Interface ---
demo = gr.Interface(
    fn=compliance_check,
    inputs=[
        gr.Textbox(lines=5, label="Rules (one per line)", placeholder="Enter compliance rules..."),
        gr.Textbox(lines=10, label="Transcript", placeholder="Paste the transcript to analyze..."),
        gr.Checkbox(label="Enable ⟨think⟩ mode", value=True)
    ],
    outputs=gr.Textbox(label="Compliance Output", lines=10, show_copy_button=True),
    title="DynaGuard Compliance Checker",
    description="Paste your rules & transcript, then hit Submit.",
    flagging_mode="never"  # Disables the Flag button (Gradio 5+; older versions use allow_flagging="never")
)
if __name__ == "__main__":
    demo.launch()