adithyagv commited on
Commit
7a04dae
·
verified ·
1 Parent(s): 49970a3

Upload Gradio_UI.py

Browse files
Files changed (1) hide show
  1. Gradio_UI.py +331 -0
Gradio_UI.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import mimetypes
17
+ import os
18
+ import re
19
+ import shutil
20
+ from typing import Optional
21
+
22
+ from smolagents.agent_types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
23
+ from smolagents.agents import ActionStep, MultiStepAgent
24
+ from smolagents.memory import MemoryStep
25
+ from smolagents.utils import _is_package_available
26
+
27
+
28
import logging

# Configure logging (do this once, ideally at the top of your file)
# NOTE(review): calling logging.basicConfig at import time configures the
# ROOT logger for every application that imports this module. Library code
# would normally create a module logger via logging.getLogger(__name__) and
# leave configuration to the application — confirm this global DEBUG setup
# is intentional before shipping.
logging.basicConfig(level=logging.DEBUG, # Set the logging level to capture detailed information
                    format='%(asctime)s - %(levelname)s - %(message)s')
33
+
34
+
35
def interact_with_agent(self, prompt, stored_messages):
    """Stream a chat turn: echo the user prompt, then relay agent messages.

    Appends the user's message to ``stored_messages``, then streams every
    ChatMessage produced by the agent run, yielding the growing message
    list after each addition so Gradio can re-render incrementally.
    """
    logging.debug(f"Type of prompt in interact_with_agent: {type(prompt)}")
    logging.debug(f"Value of prompt in interact_with_agent: {prompt}")
    # Deferred import keeps gradio optional at module import time.
    import gradio as gr

    print(prompt)
    user_message = gr.ChatMessage(role="user", content=prompt)
    stored_messages.append(user_message)
    yield stored_messages

    # Relay each agent-produced message as it arrives.
    for agent_message in stream_to_gradio(self.agent, task=prompt, reset_agent_memory=False):
        stored_messages.append(agent_message)
        yield stored_messages

    # Final yield so the UI shows the completed conversation state.
    yield stored_messages
46
+
47
+
48
def log_user_message(self, text_input, file_uploads_log):
    """Build the task message from the text box, appending any uploaded-file hint.

    Returns a ``(message, None)`` pair; the second slot is a placeholder for
    the Gradio output it is wired to.
    """
    # Only mention uploads when at least one file path has been logged.
    if file_uploads_log:
        file_hint = f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
    else:
        file_hint = ""
    message = text_input + file_hint
    logging.debug(f"Type of message in log_user_message: {type(message)}")
    logging.debug(f"Value of message in log_user_message: {message}")
    return message, None
57
+
58
def pull_messages_from_step(
    step_log: MemoryStep,
):
    """Extract ChatMessage objects from agent steps with proper nesting.

    Yields ``gr.ChatMessage`` objects for an ``ActionStep``: the step header,
    the model's reasoning, a parent "tool call" message with execution logs
    and errors nested under it, and a footnote with token/duration stats.
    Non-ActionStep logs yield nothing.

    Bug fix: the original set ``step_duration = None`` when
    ``step_log.duration`` was falsy and then did ``step_footnote += None``,
    raising TypeError; an empty string is used instead.
    """
    import gradio as gr

    if isinstance(step_log, ActionStep):
        # Output the step number (empty when the step has no number).
        step_number = f"Step {step_log.step_number}" if step_log.step_number is not None else ""
        yield gr.ChatMessage(role="assistant", content=f"**{step_number}**")

        # First yield the thought/reasoning from the LLM.
        if hasattr(step_log, "model_output") and step_log.model_output is not None:
            # Clean up the LLM output: remove any trailing <end_code> and
            # extra backticks, handling multiple possible formats.
            model_output = step_log.model_output.strip()
            model_output = re.sub(r"```\s*<end_code>", "```", model_output)  # handles ```<end_code>
            model_output = re.sub(r"<end_code>\s*```", "```", model_output)  # handles <end_code>```
            model_output = re.sub(r"```\s*\n\s*<end_code>", "```", model_output)  # handles ```\n<end_code>
            model_output = model_output.strip()
            yield gr.ChatMessage(role="assistant", content=model_output)

        # For tool calls, create a parent message; logs/errors nest under it.
        if hasattr(step_log, "tool_calls") and step_log.tool_calls is not None:
            first_tool_call = step_log.tool_calls[0]
            used_code = first_tool_call.name == "python_interpreter"
            parent_id = f"call_{len(step_log.tool_calls)}"

            # Handle arguments based on type (dicts may carry an "answer" key).
            args = first_tool_call.arguments
            if isinstance(args, dict):
                content = str(args.get("answer", str(args)))
            else:
                content = str(args).strip()

            if used_code:
                # Strip any pre-existing code fences / end_code tags, then
                # re-wrap so the UI renders a single python code block.
                content = re.sub(r"```.*?\n", "", content)  # Remove existing code blocks
                content = re.sub(r"\s*<end_code>\s*", "", content)  # Remove end_code tags
                content = content.strip()
                if not content.startswith("```python"):
                    content = f"```python\n{content}\n```"

            parent_message_tool = gr.ChatMessage(
                role="assistant",
                content=content,
                metadata={
                    "title": f"🛠️ Used tool {first_tool_call.name}",
                    "id": parent_id,
                    "status": "pending",
                },
            )
            yield parent_message_tool

            # Nest execution logs under the tool call if there is content.
            if hasattr(step_log, "observations") and (
                step_log.observations is not None and step_log.observations.strip()
            ):
                log_content = step_log.observations.strip()
                if log_content:
                    log_content = re.sub(r"^Execution logs:\s*", "", log_content)
                    yield gr.ChatMessage(
                        role="assistant",
                        content=f"{log_content}",
                        metadata={"title": "📝 Execution Logs", "parent_id": parent_id, "status": "done"},
                    )

            # Nest any errors under the tool call.
            if hasattr(step_log, "error") and step_log.error is not None:
                yield gr.ChatMessage(
                    role="assistant",
                    content=str(step_log.error),
                    metadata={"title": "💥 Error", "parent_id": parent_id, "status": "done"},
                )

            # Mark the parent message done in place (no new message yielded).
            parent_message_tool.metadata["status"] = "done"

        # Handle standalone errors that did not come from tool calls.
        elif hasattr(step_log, "error") and step_log.error is not None:
            yield gr.ChatMessage(role="assistant", content=str(step_log.error), metadata={"title": "💥 Error"})

        # Footnote with duration and token information.
        step_footnote = f"{step_number}"
        if hasattr(step_log, "input_token_count") and hasattr(step_log, "output_token_count"):
            token_str = (
                f" | Input-tokens:{step_log.input_token_count:,} | Output-tokens:{step_log.output_token_count:,}"
            )
            step_footnote += token_str
        if hasattr(step_log, "duration"):
            # FIX: use "" (not None) when duration is falsy so the += below
            # never raises TypeError.
            step_duration = f" | Duration: {round(float(step_log.duration), 2)}" if step_log.duration else ""
            step_footnote += step_duration
        step_footnote = f"""<span style="color: #bbbbc2; font-size: 12px;">{step_footnote}</span> """
        yield gr.ChatMessage(role="assistant", content=f"{step_footnote}")
        yield gr.ChatMessage(role="assistant", content="-----")
154
+
155
+
156
def stream_to_gradio(
    agent,
    task: str,
    reset_agent_memory: bool = False,
    additional_args: Optional[dict] = None,
):
    """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.

    Args:
        agent: A smolagents agent whose ``run(..., stream=True)`` yields step logs.
        task: The user task to execute.
        reset_agent_memory: Whether to reset the agent's memory before running.
        additional_args: Extra arguments forwarded to ``agent.run``.

    Yields:
        ``gr.ChatMessage`` objects for each step, then a final-answer message
        whose content depends on the answer type (text / image / audio / other).

    Raises:
        ModuleNotFoundError: If gradio is not installed.

    Robustness fix: some models report ``None`` for
    ``last_input_token_count`` / ``last_output_token_count``; the original
    crashed on ``total += None`` and stored ``None`` on the step (breaking
    the ``:,`` formatting downstream). Counts are now coerced with ``or 0``.
    """
    if not _is_package_available("gradio"):
        raise ModuleNotFoundError(
            "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
        )
    import gradio as gr

    total_input_tokens = 0
    total_output_tokens = 0

    for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
        # Track tokens if the model provides them; coerce None to 0.
        if hasattr(agent.model, "last_input_token_count"):
            total_input_tokens += agent.model.last_input_token_count or 0
            total_output_tokens += agent.model.last_output_token_count or 0
            if isinstance(step_log, ActionStep):
                step_log.input_token_count = agent.model.last_input_token_count or 0
                step_log.output_token_count = agent.model.last_output_token_count or 0

        for message in pull_messages_from_step(
            step_log,
        ):
            yield message

    final_answer = step_log  # Last log is the run's final_answer
    final_answer = handle_agent_output_types(final_answer)

    if isinstance(final_answer, AgentText):
        yield gr.ChatMessage(
            role="assistant",
            content=f"**Final answer:**\n{final_answer.to_string()}\n",
        )
    elif isinstance(final_answer, AgentImage):
        yield gr.ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "image/png"},
        )
    elif isinstance(final_answer, AgentAudio):
        yield gr.ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
        )
    else:
        yield gr.ChatMessage(role="assistant", content=f"**Final answer:** {str(final_answer)}")
206
+
207
+
208
class GradioUI:
    """A one-line interface to launch your agent in Gradio"""

    def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None = None):
        # Fail fast if the optional gradio dependency is missing.
        if not _is_package_available("gradio"):
            raise ModuleNotFoundError(
                "Please install 'gradio' extra to use the GradioUI: `pip install 'smolagents[gradio]'`"
            )
        self.agent = agent
        self.file_upload_folder = file_upload_folder
        if self.file_upload_folder is not None:
            # NOTE(review): os.mkdir only creates the last path segment —
            # os.makedirs(..., exist_ok=True) would be more robust for
            # nested folders; confirm the expected folder layout.
            if not os.path.exists(file_upload_folder):
                os.mkdir(file_upload_folder)

    def interact_with_agent(self, prompt, stored_messages):
        """Stream the conversation for one prompt, yielding the message list after each update.

        NOTE(review): the incoming ``stored_messages`` argument is discarded
        and replaced with a fresh list, so chat history never accumulates
        across turns — confirm this reset is intentional.
        """
        import gradio as gr
        print(prompt)
        stored_messages = [] # Initialize as an empty list HERE
        stored_messages.append(gr.ChatMessage(role="user", content=prompt))
        yield stored_messages
        for msg in stream_to_gradio(self.agent, task=prompt, reset_agent_memory=False):
            stored_messages.append(msg)
            yield stored_messages
        yield stored_messages

    def upload_file(
        self,
        file,
        file_uploads_log,
        # NOTE(review): mutable default argument — shared across calls; safe
        # only because it is never mutated here.
        allowed_file_types=[
            "application/pdf",
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            "text/plain",
        ],
    ):
        """
        Handle file uploads, default allowed types are .pdf, .docx, and .txt
        """
        import gradio as gr

        if file is None:
            return gr.Textbox("No file uploaded", visible=True), file_uploads_log

        try:
            mime_type, _ = mimetypes.guess_type(file.name)
        except Exception as e:
            return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log

        if mime_type not in allowed_file_types:
            return gr.Textbox("File type disallowed", visible=True), file_uploads_log

        # Sanitize file name
        original_name = os.path.basename(file.name)
        sanitized_name = re.sub(
            r"[^\w\-.]", "_", original_name
        ) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores

        # Build a reverse map mime-type -> first extension seen in types_map.
        type_to_ext = {}
        for ext, t in mimetypes.types_map.items():
            if t not in type_to_ext:
                type_to_ext[t] = ext

        # Ensure the extension correlates to the mime type
        # NOTE(review): for a name with no dot, split(".")[:-1] is empty, so
        # the saved name becomes just the extension; also type_to_ext lookup
        # assumes every allowed mime type appears in mimetypes.types_map.
        sanitized_name = sanitized_name.split(".")[:-1]
        sanitized_name.append("" + type_to_ext[mime_type])
        sanitized_name = "".join(sanitized_name)

        # Save the uploaded file to the specified folder
        file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
        shutil.copy(file.name, file_path)

        return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]

    def log_user_message(self, text_input, file_uploads_log):
        """Build the task message, appending an uploaded-files hint when any exist."""
        message = text_input + (
            f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
            if len(file_uploads_log) > 0
            else ""
        )
        return message, None
        # Return just the message string, second value can be None (or anything other than the list you were appending to) since it's being passed to stored_messages

    def launch(self, **kwargs):
        """Build the Gradio Blocks app and launch it (debug + public share enabled)."""
        import gradio as gr

        with gr.Blocks(fill_height=True) as demo:
            stored_messages = gr.State([])
            file_uploads_log = gr.State([])
            chatbot = gr.Chatbot(
                label="Agent",
                type="messages",
                avatar_images=(
                    None,
                    "https://huggingface.co/datasets/agents-course/course-images/resolve/main/en/communication/Alfred.png",
                ),
                # NOTE(review): newer Gradio spells this kwarg "resizable";
                # confirm "resizeable" is accepted by the pinned version.
                resizeable=True,
                scale=1,
            )
            # If an upload folder is provided, enable the upload feature
            if self.file_upload_folder is not None:
                upload_file = gr.File(label="Upload a file")
                upload_status = gr.Textbox(label="Upload Status", interactive=False, visible=False)
                upload_file.change(
                    self.upload_file,
                    [upload_file, file_uploads_log],
                    [upload_status, file_uploads_log],
                )
            text_input = gr.Textbox(lines=1, label="Enter Timezone") # add location input as we ask
            # NOTE(review): log_user_message writes (message, None) into
            # (stored_messages, text_input); the lambda then receives that
            # message string as x and indexes x[1] — i.e. its second
            # CHARACTER, not a tuple element. Verify the intended prompt
            # actually reaches interact_with_agent.
            text_input.submit(
                self.log_user_message,
                [text_input, file_uploads_log],
                [stored_messages,text_input], # Combine the outputs into a list
            ).then(
                lambda x,y,z : self.interact_with_agent(x[1],y), # Access the second element of the tuple x, which is the text input
                [stored_messages, stored_messages, chatbot], #Inputs to Lambda
                stored_messages, #Outputs
            )
        demo.launch(debug=True, share=True, **kwargs)
326
+
327
+
328
+
329
+
330
+
331
# Public API of this module: the streaming helper and the UI wrapper.
__all__ = ["stream_to_gradio", "GradioUI"]