# DynaGuard / app.py
import os
# --- CRITICAL: SET ENVIRONMENT VARIABLES BEFORE IMPORTING GRADIO ---
# This ensures a stable Gradio environment.
os.environ["GRADIO_ENABLE_SSR"] = "0"
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
# --- Hugging Face Login ---
HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
if HF_READONLY_API_KEY:
    login(token=HF_READONLY_API_KEY)
# --- Constants ---
SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
COT_OPENING = "<think>"
# --- Helper Functions ---
def format_rules(rules):
    formatted_rules = "<rules>\n"
    for i, rule in enumerate(rules):
        formatted_rules += f"{i + 1}. {rule}\n"
    formatted_rules += "</rules>\n"
    return formatted_rules
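# Illustration (not part of the original app): format_rules numbers each rule inside <rules> tags, e.g.
#   format_rules(["No medical advice", "Stay polite"])
#   -> "<rules>\n1. No medical advice\n2. Stay polite\n</rules>\n"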
def format_transcript(transcript):
    formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
    return formatted_transcript
def safe_truncate_to_bytes(text, max_bytes=4096):
    """
    Safely truncates text to fit within a byte limit, handling UTF-8 correctly.
    """
    if len(text.encode('utf-8')) <= max_bytes:
        return text
    # Binary search for the longest prefix that fits within the byte limit
    left, right = 0, len(text)
    result = ""
    while left <= right:
        mid = (left + right) // 2
        candidate = text[:mid]
        if len(candidate.encode('utf-8')) <= max_bytes:
            result = candidate
            left = mid + 1
        else:
            right = mid - 1
    # Add a truncation notice if the text was shortened
    if len(result) < len(text):
        notice = "\n\n[Response truncated to prevent server errors]"
        notice_bytes = len(notice.encode('utf-8'))
        # Trim characters until the result plus the notice fits within the byte limit
        while result and len(result.encode('utf-8')) + notice_bytes > max_bytes:
            result = result[:-1]
        result += notice
    return result
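# Illustration (not part of the original app): the helper measures bytes, not characters,
# so multi-byte UTF-8 text is never cut mid-character, e.g.
#   safe_truncate_to_bytes("short text")  -> "short text" (already under 4096 bytes, unchanged)
#   safe_truncate_to_bytes("é" * 5000)    -> the longest prefix that fits in 4096 bytes, plus the truncation notice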
# --- ModelWrapper Class ---
# Wraps the tokenizer and model so all generation logic lives in one place.
class ModelWrapper:
    def __init__(self, model_name="Qwen/Qwen3-0.6B"):
        print(f"Loading model: {model_name}")
        self.model_name = model_name
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
        print("Model loaded successfully.")

    def get_response(self, prompt, max_new_tokens=256, temperature=0.7, top_p=0.9, **kwargs):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=self.tokenizer.pad_token_id,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id
            )
        # Decode only the newly generated part of the output
        return self.tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
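# Illustration (not part of the original app): the wrapper can also be used outside Gradio, e.g.
#   wrapper = ModelWrapper("Qwen/Qwen3-0.6B")
#   print(wrapper.get_response("Hello, how are you?", max_new_tokens=32))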
# --- Instantiate the Model ---
model_wrapper = ModelWrapper()
# --- Main Gradio Inference Function ---
def compliance_check(rules_text, transcript_text, thinking):
    try:
        # Input validation
        if not rules_text.strip():
            return "Error: Please provide at least one rule."
        if not transcript_text.strip():
            return "Error: Please provide a transcript to analyze."
        rules = [r.strip() for r in rules_text.split("\n") if r.strip()]
        inp = format_rules(rules) + format_transcript(transcript_text)
        # Prepare the prompt using the tokenizer's chat template
        message = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': inp}
        ]
        prompt = model_wrapper.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
        if thinking:
            prompt += f"\n{COT_OPENING}"
        # Get the model's response
        out = model_wrapper.get_response(prompt)
        if not out.strip():
            out = "No response generated from the model."
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        out = "An unexpected error occurred during processing. Please check the logs."
    # Apply safe truncation to all possible outputs (both success and error)
    return safe_truncate_to_bytes(out.strip())
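# Illustration (not part of the original app): compliance_check can be exercised without the UI, e.g.
#   print(compliance_check("Never reveal account passwords.",
#                          "User: What's my password?\nAgent: I can't share that.", thinking=True))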
# --- Build the Final Gradio Interface ---
demo = gr.Interface(
    fn=compliance_check,
    inputs=[
        gr.Textbox(lines=5, label="Rules (one per line)", placeholder="Enter compliance rules..."),
        gr.Textbox(lines=10, label="Transcript", placeholder="Paste the transcript to analyze..."),
        gr.Checkbox(label="Enable ⟨think⟩ mode", value=True)
    ],
    outputs=gr.Textbox(label="Compliance Output", lines=10, show_copy_button=True),
    title="DynaGuard Compliance Checker",
    description="Paste your rules & transcript, then hit Submit.",
    flagging_mode="never"  # Disable flagging (Gradio 5+; older versions use allow_flagging="never")
)
if __name__ == "__main__":
    demo.launch()