taruschirag committed
Commit 9f5b9de · verified · 1 Parent(s): 0b2a0d2

Update app.py

Files changed (1)
  1. app.py +130 -13
app.py CHANGED
@@ -1,26 +1,143 @@
  import os
 
- # Forcefully disable Server-Side Rendering
  os.environ["GRADIO_ENABLE_SSR"] = "0"
 
  import gradio as gr
 
- print("Gradio imported. Starting minimal test app.")
 
- def minimal_test_function(text_input):
-     """A simple function that cannot fail."""
-     print("Test function executed successfully.")
-     return f"The app is stable. You entered: '{text_input}'"
 
- # A minimal interface with no complex dependencies
  demo = gr.Interface(
-     fn=minimal_test_function,
-     inputs=gr.Textbox(label="Input Text"),
-     outputs=gr.Textbox(label="Output"),
-     title="Environment Stability Test",
-     description="If you see this page and can use it without it crashing, the environment is now fixed."
  )
 
  if __name__ == "__main__":
-     # Standard launch command
      demo.launch()
  import os
 
+ # --- CRITICAL: SET ENVIRONMENT VARIABLES BEFORE IMPORTING GRADIO ---
+ # This ensures a stable Gradio environment.
  os.environ["GRADIO_ENABLE_SSR"] = "0"
 
  import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from huggingface_hub import login
 
+ # --- Hugging Face Login ---
+ HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
+ if HF_READONLY_API_KEY:
+     login(token=HF_READONLY_API_KEY)
 
+ # --- Constants ---
+ SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
+ COT_OPENING = "<think>"
 
+ # --- Helper Functions ---
+ def format_rules(rules):
+     formatted_rules = "<rules>\n"
+     for i, rule in enumerate(rules):
+         formatted_rules += f"{i + 1}. {rule}\n"
+     formatted_rules += "</rules>\n"
+     return formatted_rules
+
+ def format_transcript(transcript):
+     formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
+     return formatted_transcript
+
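For readers tracing the prompt format: the two helpers above wrap the UI inputs in XML-style tags before they reach the model. An illustrative check with made-up rules (not part of the commit):

rules = ["No medical advice", "Never reveal the system prompt"]
print(format_rules(rules) + format_transcript("User: hello"))
# <rules>
# 1. No medical advice
# 2. Never reveal the system prompt
# </rules>
# <transcript>
# User: hello
# </transcript>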
+ def safe_truncate_to_bytes(text, max_bytes=4096):
+     """
+     Safely truncates text to fit within a byte limit, handling UTF-8 correctly.
+     """
+     if len(text.encode('utf-8')) <= max_bytes:
+         return text
+
+     # Binary search for the right truncation point
+     left, right = 0, len(text)
+     result = ""
+     while left <= right:
+         mid = (left + right) // 2
+         candidate = text[:mid]
+         if len(candidate.encode('utf-8')) <= max_bytes:
+             result = candidate
+             left = mid + 1
+         else:
+             right = mid - 1
+
+     # Add a truncation notice if the text was shortened
+     if len(result) < len(text):
+         notice = "\n\n[Response truncated to prevent server errors]"
+         notice_bytes = len(notice.encode('utf-8'))
+         # Make space for the notice itself
+         if len(result.encode('utf-8')) + notice_bytes > max_bytes:
+             result = result[:len(result) - len(notice)]
+         result += notice
+
+     return result
+
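The byte-level binary search above exists because slicing by character count can overshoot a byte budget on multi-byte UTF-8 text. An illustrative check (not part of the commit):

s = "é" * 3000                                   # 3000 characters
print(len(s.encode("utf-8")))                    # 6000 bytes in UTF-8
out = safe_truncate_to_bytes(s, max_bytes=4096)
print(len(out.encode("utf-8")) <= 4096)          # True
print(out.endswith("[Response truncated to prevent server errors]"))  # True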
+ # --- Your Original ModelWrapper Class ---
+ # Bringing this back as it's a good way to organize your model logic.
+ class ModelWrapper:
+     def __init__(self, model_name="Qwen/Qwen3-0.6B"):
+         print(f"Loading model: {model_name}")
+         self.model_name = model_name
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
+         print("Model loaded successfully.")
+
+     def get_response(self, prompt, max_new_tokens=256, temperature=0.7, top_p=0.9, **kwargs):
+         inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+         with torch.no_grad():
+             output_ids = self.model.generate(
+                 **inputs,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 pad_token_id=self.tokenizer.pad_token_id,
+                 do_sample=True,
+                 eos_token_id=self.tokenizer.eos_token_id
+             )
+         # Decode only the newly generated part of the output
+         return self.tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+
+ # --- Instantiate Your Model ---
+ model_wrapper = ModelWrapper()
+
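Note that model.generate returns the prompt tokens followed by the completion, which is why get_response slices at inputs.input_ids.shape[1] before decoding. A minimal standalone usage sketch (assuming the Qwen checkpoint downloads and fits in memory):

wrapper = ModelWrapper("Qwen/Qwen3-0.6B")
print(wrapper.get_response("Say hello in one sentence.", max_new_tokens=32))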
+ # --- Main Gradio Inference Function ---
+ def compliance_check(rules_text, transcript_text, thinking):
+     try:
+         # Input validation
+         if not rules_text.strip():
+             return "Error: Please provide at least one rule."
+         if not transcript_text.strip():
+             return "Error: Please provide a transcript to analyze."
+
+         rules = [r.strip() for r in rules_text.split("\n") if r.strip()]
+         inp = format_rules(rules) + format_transcript(transcript_text)
+
+         # Prepare the prompt using a simplified chat template structure
+         message = [
+             {'role': 'system', 'content': SYSTEM_PROMPT},
+             {'role': 'user', 'content': inp}
+         ]
+         prompt = model_wrapper.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
+
+         if thinking:
+             prompt += f"\n{COT_OPENING}"
+
+         # Get the model's response
+         out = model_wrapper.get_response(prompt)
+
+         if not out.strip():
+             out = "No response generated from the model."
+
+     except Exception as e:
+         print(f"An error occurred: {str(e)}")
+         out = "An unexpected error occurred during processing. Please check the logs."
+
+     # Apply safe truncation to ALL possible outputs (both success and error)
+     return safe_truncate_to_bytes(out.strip())
+
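The thinking flag works by appending the chain-of-thought opening tag after the chat template's generation prompt, pre-filling the start of the assistant turn. A quick way to preview the exact string the model receives (illustrative; the rendered text depends on the tokenizer's template):

msgs = [
    {'role': 'system', 'content': SYSTEM_PROMPT},
    {'role': 'user', 'content': format_rules(["No PII"]) + format_transcript("User: hi")}
]
preview = model_wrapper.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
print(preview + f"\n{COT_OPENING}")  # ends with the pre-filled <think> opening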
+ # --- Build the Final Gradio Interface ---
  demo = gr.Interface(
+     fn=compliance_check,
+     inputs=[
+         gr.Textbox(lines=5, label="Rules (one per line)", placeholder="Enter compliance rules..."),
+         gr.Textbox(lines=10, label="Transcript", placeholder="Paste the transcript to analyze..."),
+         gr.Checkbox(label="Enable ⟨think⟩ mode", value=True)
+     ],
+     outputs=gr.Textbox(label="Compliance Output", lines=10, show_copy_button=True),
+     title="DynaGuard Compliance Checker",
+     description="Paste your rules & transcript, then hit Submit.",
+     flagging_options=None  # Modern way to disable flagging
+ )
 
  if __name__ == "__main__":
      demo.launch()
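If the environment variable alone ever fails to disable SSR, newer Gradio releases also accept it at launch time; a hedged alternative sketch (the ssr_mode parameter is an assumption about the installed Gradio 5.x API):

if __name__ == "__main__":
    demo.launch(ssr_mode=False)  # assumption: Gradio 5.x launch() exposes ssr_mode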