taruschirag committed · verified
Commit aefdf18 · 1 Parent(s): 9f5b9de

Update app.py

Files changed (1)
  1. app.py +125 -82
app.py CHANGED
@@ -1,24 +1,21 @@
-import os
-
-# --- CRITICAL: SET ENVIRONMENT VARIABLES BEFORE IMPORTING GRADIO ---
-# This ensures a stable Gradio environment.
 os.environ["GRADIO_ENABLE_SSR"] = "0"
-
+import os
 import gradio as gr
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from datasets import load_dataset
 from huggingface_hub import login

-# --- Hugging Face Login ---
 HF_READONLY_API_KEY = os.getenv("HF_READONLY_API_KEY")
-if HF_READONLY_API_KEY:
-    login(token=HF_READONLY_API_KEY)
+login(token=HF_READONLY_API_KEY)

-# --- Constants ---
+COT_OPENING = "<think>"
+EXPLANATION_OPENING = "<explanation>"
+LABEL_OPENING = "<answer>"
+LABEL_CLOSING = "</answer>"
+INPUT_FIELD = "question"
 SYSTEM_PROMPT = """You are a guardian model evaluating…</explanation>"""
-COT_OPENING = "<think>"

-# --- Helper Functions ---
 def format_rules(rules):
     formatted_rules = "<rules>\n"
     for i, rule in enumerate(rules):
@@ -30,113 +27,159 @@ def format_transcript(transcript):
     formatted_transcript = f"<transcript>\n{transcript}\n</transcript>\n"
     return formatted_transcript

-def safe_truncate_to_bytes(text, max_bytes=4096):
-    """
-    Safely truncates text to fit within a byte limit, handling UTF-8 correctly.
-    """
-    if len(text.encode('utf-8')) <= max_bytes:
-        return text
-
-    # Binary search for the right truncation point
-    left, right = 0, len(text)
-    result = ""
-    while left <= right:
-        mid = (left + right) // 2
-        candidate = text[:mid]
-        if len(candidate.encode('utf-8')) <= max_bytes:
-            result = candidate
-            left = mid + 1
-        else:
-            right = mid - 1
-
-    # Add a truncation notice if the text was shortened
-    if len(result) < len(text):
-        notice = "\n\n[Response truncated to prevent server errors]"
-        notice_bytes = len(notice.encode('utf-8'))
-        # Make space for the notice itself
-        if len(result.encode('utf-8')) + notice_bytes > max_bytes:
-            result = result[:len(result) - len(notice)]
-        result += notice
-
-    return result
+def get_example(
+    dataset_path="tomg-group-umd/compliance_benchmark",
+    subset="compliance",
+    split="test_handcrafted",
+    example_idx=0,
+):
+    dataset = load_dataset(dataset_path, subset, split=split)
+    example = dataset[example_idx]
+    return example[INPUT_FIELD]
+
+def get_message(model, input, system_prompt=SYSTEM_PROMPT, enable_thinking=True):
+    message = model.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
+    return message

-# --- Your Original ModelWrapper Class ---
-# Bringing this back as it's a good way to organize your model logic.
 class ModelWrapper:
     def __init__(self, model_name="Qwen/Qwen3-0.6B"):
-        print(f"Loading model: {model_name}")
         self.model_name = model_name
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        if "nemoguard" in model_name:
+            self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
+        else:
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
         self.tokenizer.pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
         self.model = AutoModelForCausalLM.from_pretrained(
             model_name, device_map="auto", torch_dtype=torch.bfloat16).eval()
-        print("Model loaded successfully.")

-    def get_response(self, prompt, max_new_tokens=256, temperature=0.7, top_p=0.9, **kwargs):
-        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+    def get_message_template(self, system_content=None, user_content=None, assistant_content=None):
+        """Compile sys, user, assistant inputs into the proper dictionaries"""
+        message = []
+        if system_content is not None:
+            message.append({'role': 'system', 'content': system_content})
+        if user_content is not None:
+            message.append({'role': 'user', 'content': user_content})
+        if assistant_content is not None:
+            message.append({'role': 'assistant', 'content': assistant_content})
+        if not message:
+            raise ValueError("No content provided for any role.")
+        return message
+
+    def apply_chat_template(self, system_content, user_content, assistant_content=None, enable_thinking=True):
+        """Call the tokenizer's chat template with exactly the right arguments for whether we want it to generate thinking before the answer (which differs depending on whether it is Qwen3 or not)."""
+        if assistant_content is not None:
+            # If assistant content is passed we simply use it.
+            # This works for both Qwen3 and non-Qwen3 models. With Qwen3 any time assistant_content is provided, it automatically adds the <think></think> pair before the content, which is what we want.
+            message = self.get_message_template(system_content, user_content, assistant_content)
+            prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
+        else:
+            if enable_thinking:
+                if "qwen3" in self.model_name.lower():
+                    # Let the Qwen chat template handle the thinking token
+                    message = self.get_message_template(system_content, user_content)
+                    prompt = self.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True, enable_thinking=True)
+                    # The way the Qwen3 chat template works is it adds a <think></think> pair when enable_thinking=False, but for enable_thinking=True, it adds nothing and lets the model decide. Here we force the <think> tag to be there.
+                    prompt = prompt + f"\n{COT_OPENING}"
+                else:
+                    message = self.get_message_template(system_content, user_content, assistant_content=COT_OPENING)
+                    prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True)
+            else:
+                # This works for both Qwen3 and non-Qwen3 models.
+                # When Qwen3 gets assistant_content, it automatically adds the <think></think> pair before the content like we want. And other models ignore the enable_thinking argument.
+                message = self.get_message_template(system_content, user_content, assistant_content=LABEL_OPENING)
+                prompt = self.tokenizer.apply_chat_template(message, tokenize=False, continue_final_message=True, enable_thinking=False)
+        return prompt
+
+    def get_response(self, input, temperature=0.7, top_k=20, top_p=0.8, max_new_tokens=256, enable_thinking=True, system_prompt=SYSTEM_PROMPT):
+        """Generate and decode the response with the recommended temperature settings for thinking and non-thinking."""
+        print("Generating response...")
+
+        if "qwen3" in self.model_name.lower() and enable_thinking:
+            # Use values from https://huggingface.co/Qwen/Qwen3-8B#switching-between-thinking-and-non-thinking-mode
+            temperature = 0.6
+            top_p = 0.95
+            top_k = 20
+
+        message = self.apply_chat_template(system_prompt, input, enable_thinking=enable_thinking)
+        inputs = self.tokenizer(message, return_tensors="pt").to(self.model.device)
+
         with torch.no_grad():
-            output_ids = self.model.generate(
+            output_content = self.model.generate(
                 **inputs,
                 max_new_tokens=max_new_tokens,
+                num_return_sequences=1,
                 temperature=temperature,
+                top_k=top_k,
                 top_p=top_p,
+                min_p=0,
                 pad_token_id=self.tokenizer.pad_token_id,
                 do_sample=True,
                 eos_token_id=self.tokenizer.eos_token_id
             )
-        # Decode only the newly generated part of the output
-        return self.tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+
+        output_text = self.tokenizer.decode(output_content[0], skip_special_tokens=True)
+
+        try:
+            sys_prompt_text = output_text.split("Brief explanation\n</explanation>")[0]
+            remainder = output_text.split("Brief explanation\n</explanation>")[-1]
+            rules_transcript_text = remainder.split("</transcript>")[0]
+            thinking_answer_text = remainder.split("</transcript>")[-1]
+            return thinking_answer_text
+        except:
+            # If parsing fails, return the portion after the input
+            input_length = len(message)
+            return output_text[input_length:] if len(output_text) > input_length else "No response generated."

-# --- Instantiate Your Model ---
-model_wrapper = ModelWrapper()
+# instantiate your model
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+model = ModelWrapper(MODEL_NAME)

-# --- Main Gradio Inference Function ---
+# Gradio inference function
 def compliance_check(rules_text, transcript_text, thinking):
     try:
-        # Input validation
-        if not rules_text.strip():
-            return "Error: Please provide at least one rule."
-        if not transcript_text.strip():
-            return "Error: Please provide a transcript to analyze."
-
-        rules = [r.strip() for r in rules_text.split("\n") if r.strip()]
+        rules = [r for r in rules_text.split("\n") if r.strip()]
         inp = format_rules(rules) + format_transcript(transcript_text)

-        # Prepare the prompt using a simplified chat template structure
-        message = [
-            {'role': 'system', 'content': SYSTEM_PROMPT},
-            {'role': 'user', 'content': inp}
-        ]
-        prompt = model_wrapper.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
+        # Limit max tokens to prevent oversized responses
+        out = model.get_response(inp, enable_thinking=thinking, max_new_tokens=256)

-        if thinking:
-            prompt += f"\n{COT_OPENING}"
+        # Clean up any malformed output and ensure it's a string
+        out = str(out).strip()
+        if not out:
+            out = "No response generated. Please try with different input."
+
+        # Ensure the response isn't too long for an HTTP response by checking byte length
+        max_bytes = 2500 # A more generous limit, in bytes
+        out_bytes = out.encode('utf-8')
+
+        if len(out_bytes) > max_bytes:
+            # Truncate the byte string, then decode back to a string, ignoring errors
+            # This prevents cutting a multi-byte character in half
+            truncated_bytes = out_bytes[:max_bytes]
+            out = truncated_bytes.decode('utf-8', errors='ignore')
+            out += "\n\n[Response truncated to prevent server errors]"

-        # Get the model's response
-        out = model_wrapper.get_response(prompt)
+        return out

-        if not out.strip():
-            out = "No response generated from the model."
-
     except Exception as e:
-        print(f"An error occurred: {str(e)}")
-        out = f"An unexpected error occurred during processing. Please check the logs."
+        error_msg = f"Error: {str(e)[:200]}" # Limit error message length
+        print(f"Full error: {e}")
+        return error_msg

-    # Apply safe truncation to ALL possible outputs (both success and error)
-    return safe_truncate_to_bytes(out.strip())

-# — Build the Final Gradio Interface
+# — build Gradio interface
 demo = gr.Interface(
     fn=compliance_check,
     inputs=[
-        gr.Textbox(lines=5, label="Rules (one per line)", placeholder="Enter compliance rules..."),
-        gr.Textbox(lines=10, label="Transcript", placeholder="Paste the transcript to analyze..."),
+        gr.Textbox(lines=5, label="Rules (one per line)", max_lines=10),
+        gr.Textbox(lines=10, label="Transcript", max_lines=15),
         gr.Checkbox(label="Enable ⟨think⟩ mode", value=True)
     ],
-    outputs=gr.Textbox(label="Compliance Output", lines=10, show_copy_button=True),
+    outputs=gr.Textbox(label="Compliance Output", lines=10, max_lines=15),
     title="DynaGuard Compliance Checker",
     description="Paste your rules & transcript, then hit Submit.",
-    flagging_options=None # Modern way to disable flagging
+    allow_flagging="never",
+    show_progress=True
 )

 if __name__ == "__main__":
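
Below is a minimal sketch, separate from the commit, of the two prompt shapes that the new apply_chat_template logic builds for Qwen3: thinking mode adds the generation prompt and then forces an open <think> tag, while non-thinking mode pre-fills an open <answer> tag as assistant content and continues that turn. It assumes a recent transformers release (with continue_final_message support) and network access to the Qwen/Qwen3-0.6B tokenizer; the system and user strings are placeholders, not the app's SYSTEM_PROMPT.

# Sketch of the thinking vs. non-thinking prompt shapes (placeholder messages).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
msgs = [
    {"role": "system", "content": "You are a guardian model."},  # placeholder for SYSTEM_PROMPT
    {"role": "user", "content": "<rules>\n1. ...\n</rules>\n<transcript>\n...\n</transcript>\n"},
]

# Thinking mode: add the generation prompt, then force the open <think> tag.
thinking_prompt = tok.apply_chat_template(
    msgs, tokenize=False, add_generation_prompt=True, enable_thinking=True) + "\n<think>"

# Non-thinking mode: pre-fill an open <answer> tag and continue the assistant turn.
answer_prompt = tok.apply_chat_template(
    msgs + [{"role": "assistant", "content": "<answer>"}],
    tokenize=False, continue_final_message=True, enable_thinking=False)

print(thinking_prompt[-80:])
print(answer_prompt[-80:])

The first form mirrors the enable_thinking branch of get_response; the second mirrors the branch that skips chain-of-thought by pre-filling LABEL_OPENING.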
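
The byte-capping added to compliance_check can also be exercised on its own. A self-contained sketch of the same approach (the helper name and sample text are illustrative, not part of app.py): cut the UTF-8 byte string at the cap, then decode with errors='ignore' so a multi-byte character split at the boundary is dropped rather than corrupted.

def truncate_utf8(text, max_bytes=2500):  # illustrative helper mirroring the inline logic above
    out_bytes = text.encode("utf-8")
    if len(out_bytes) <= max_bytes:
        return text
    # Truncating bytes can split a multi-byte character; errors='ignore' drops the fragment.
    out = out_bytes[:max_bytes].decode("utf-8", errors="ignore")
    return out + "\n\n[Response truncated to prevent server errors]"

sample = "compliance ✓ " * 400          # multi-byte characters near any cut point
clipped = truncate_utf8(sample)
print(len(clipped.encode("utf-8")))     # stays near the 2500-byte cap, with no decode errors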