Spaces:
Runtime error
Runtime error
| import json | |
| import traceback | |
| import torch | |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, LogitsProcessorList | |
| from trl import PPOTrainer, PPOConfig | |
| import gradio as gr | |
| # ----------------------------------------------------------------------------- | |
| # 1. Helpers | |
| # ----------------------------------------------------------------------------- | |
def make_json_serializable(obj):
    """
    Recursively convert any torch.Tensor in obj to plain Python lists.

    Handles tensors nested inside dicts, lists, and tuples; any other
    value is returned unchanged. Tuples are normalized to lists, which
    matches how json.dumps would encode them anyway (JSON arrays).
    """
    if isinstance(obj, torch.Tensor):
        # .cpu() first so CUDA tensors can be converted without error.
        return obj.cpu().tolist()
    if isinstance(obj, dict):
        return {k: make_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Also recurse into tuples so nested tensors don't slip through.
        return [make_json_serializable(v) for v in obj]
    return obj
def safe_json_dumps(data):
    """
    Serialize *data* to a pretty-printed JSON string.

    Tensors anywhere inside *data* are converted to lists via
    make_json_serializable first, so json.dumps never fails on
    torch objects.
    """
    cleaned = make_json_serializable(data)
    return json.dumps(cleaned, indent=2, ensure_ascii=False)
# -----------------------------------------------------------------------------
# 2. Load Models and Initialize PPO Agent
# -----------------------------------------------------------------------------
# Checkpoint used for both generation and PPO fine-tuning.
MODEL_NAME = "google/flan-t5-base"

# Core seq2seq model & tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# PPO configuration
ppo_config = PPOConfig(
    model_name=MODEL_NAME,
    learning_rate=1e-5,
    batch_size=1,  # single-example online updates from user ratings
    log_with=None  # switch to "wandb" or "tensorboard" if you like
)

# Wrap FLAN-T5 in a PPO agent
# NOTE(review): TRL's PPOTrainer expects a value-head wrapper model
# (e.g. AutoModelForSeq2SeqLMWithValueHead), not a plain
# AutoModelForSeq2SeqLM -- confirm against the installed trl version;
# this mismatch is a likely cause of the Space's runtime error.
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    tokenizer=tokenizer
)
# -----------------------------------------------------------------------------
# 3. Session State
# -----------------------------------------------------------------------------
# Mutable, process-global conversation state shared by all Gradio callbacks.
current_session = {
    "dialog": []  # each entry: {"user": str, "bot": str, "reward": float or None}
}
| # ----------------------------------------------------------------------------- | |
| # 4. Core Callback Functions | |
| # ----------------------------------------------------------------------------- | |
def reset_session():
    """Drop all stored turns and hand Gradio an empty chat history."""
    global current_session
    current_session = {"dialog": []}
    return []
def chat_with_agent(user_input: str):
    """
    Run one conversational turn.

    Samples a reply from the model, records the turn (with an
    as-yet-unset reward) in the session, and returns the full history
    as a list of (user, bot) pairs for the Gradio Chatbot widget.
    """
    global current_session
    try:
        # Encode the prompt and sample a reply.
        prompt_ids = tokenizer(user_input, return_tensors="pt").input_ids
        generated = model.generate(
            prompt_ids,
            max_new_tokens=128,
            do_sample=True,
            top_p=0.9,
            temperature=0.8
        )
        reply = tokenizer.decode(generated[0], skip_special_tokens=True)

        # Record the turn; reward stays None until the user rates it.
        turn = {"user": user_input, "bot": reply, "reward": None}
        current_session["dialog"].append(turn)

        # Gradio Chatbot expects a list of (user, bot) tuples.
        return [(t["user"], t["bot"]) for t in current_session["dialog"]]
    except Exception as exc:
        print("π₯ Error in chat_with_agent:", exc)
        traceback.print_exc()
        # On failure, leave session untouched
        return [("Error:", "Failed to generate reply. Check logs.")]
def rate_and_train(rating: float):
    """
    Attach the user's rating to the last bot reply and run one PPO step.

    Parameters:
        rating: score from the UI slider (0-5), used as the RL reward.

    Returns:
        The whole session serialized as JSON on success, or a short
        error string if there is nothing to rate or the step fails.
    """
    global current_session
    try:
        if not current_session["dialog"]:
            return "No dialog to rate. Chat first."

        # Attach reward to the most recent turn.
        last = current_session["dialog"][-1]
        last["reward"] = float(rating)

        # Token IDs for PPO: 1-D tensors, one per example.
        query_ids = tokenizer(last["user"], return_tensors="pt").input_ids.squeeze(0)
        response_ids = tokenizer(last["bot"], return_tensors="pt").input_ids.squeeze(0)

        # TRL's PPOTrainer.step expects rewards as a list of scalar
        # torch tensors, not plain Python floats -- passing a float
        # raises at runtime.
        reward_tensor = torch.tensor(last["reward"])

        # Run PPO optimization with this single example.
        stats = ppo_trainer.step(
            [query_ids],
            [response_ids],
            [reward_tensor]
        )
        print("π PPO step stats:", stats)

        # Return the entire session as JSON.
        return safe_json_dumps(current_session)
    except Exception as e:
        print("π₯ Error in rate_and_train:", e)
        traceback.print_exc()
        return "Failed to apply training step. See logs."
# -----------------------------------------------------------------------------
# 5. Gradio UI
# -----------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## FLAN-T5 Chatbot with On-the-Fly Reinforcement Learning")

    # Conversation display plus chat input controls.
    chat_box = gr.Chatbot(label="Chat History")
    user_input = gr.Textbox(placeholder="Type your message hereβ¦", label="You")
    send_btn = gr.Button("Send")
    reset_btn = gr.Button("Reset Conversation")

    # Rating controls: slider value feeds the PPO reward signal.
    with gr.Row():
        rating = gr.Slider(0, 5, step=1, value=0, label="Rate Last Reply")
        rate_btn = gr.Button("Apply Rating & Train")

    # Read-only view of the serialized session after each training step.
    export_json = gr.Textbox(label="Session JSON", lines=10)

    # Reset chat
    reset_btn.click(
        fn=reset_session,
        inputs=None,
        outputs=chat_box
    )

    # Send user message
    send_btn.click(
        fn=chat_with_agent,
        inputs=user_input,
        outputs=chat_box
    )

    # Rate & train
    rate_btn.click(
        fn=rate_and_train,
        inputs=rating,
        outputs=export_json
    )

if __name__ == "__main__":
    demo.launch()