# app.py — FLAN-T5 chatbot with on-the-fly PPO reinforcement learning (commit f97a8f6).
import json
import traceback
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, LogitsProcessorList
from trl import PPOTrainer, PPOConfig
import gradio as gr
# -----------------------------------------------------------------------------
# 1. Helpers
# -----------------------------------------------------------------------------
def make_json_serializable(obj):
    """
    Recursively convert any torch.Tensor in *obj* to plain Python values.

    Tensors become nested lists (or scalars, for 0-d tensors) via
    ``tolist()``; dicts, lists, and tuples are walked recursively.
    Tuples come back as lists, matching how ``json.dumps`` encodes
    them anyway. Everything else is returned unchanged.

    Parameters
    ----------
    obj : Any
        Arbitrary (possibly nested) structure of dicts/lists/tuples,
        tensors, and JSON-native scalars.

    Returns
    -------
    Any
        The same structure with every tensor replaced by lists/scalars.
    """
    if isinstance(obj, torch.Tensor):
        # .cpu() first so CUDA tensors can be converted as well.
        return obj.cpu().tolist()
    if isinstance(obj, dict):
        return {k: make_json_serializable(v) for k, v in obj.items()}
    # Bug fix: the original handled only lists, so a tensor hidden inside a
    # tuple slipped through untouched and json.dumps would still crash on it.
    if isinstance(obj, (list, tuple)):
        return [make_json_serializable(v) for v in obj]
    return obj
def safe_json_dumps(data):
    """
    Serialize *data* to pretty-printed JSON.

    Runs the structure through ``make_json_serializable`` first so that
    embedded torch.Tensor values cannot make ``json.dumps`` raise.
    """
    serializable = make_json_serializable(data)
    return json.dumps(serializable, indent=2, ensure_ascii=False)
# -----------------------------------------------------------------------------
# 2. Load Models and Initialize PPO Agent
# -----------------------------------------------------------------------------
# Hugging Face model id used for both the tokenizer and the seq2seq model.
MODEL_NAME = "google/flan-t5-base"
# Core seq2seq model & tokenizer (downloaded/cached at import time).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# PPO configuration for single-example online updates.
# NOTE(review): this uses the legacy trl API (PPOConfig(model_name=...),
# PPOTrainer(config=..., model=..., tokenizer=...)) — confirm against the
# installed trl version; newer releases changed both constructors.
ppo_config = PPOConfig(
    model_name=MODEL_NAME,
    learning_rate=1e-5,  # small LR: each user rating is a single noisy sample
    batch_size=1,  # one (query, response, reward) triple per PPO step
    log_with=None # switch to "wandb" or "tensorboard" if you like
)
# Wrap FLAN-T5 in a PPO agent so rate_and_train() can call ppo_trainer.step().
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    tokenizer=tokenizer
)
# -----------------------------------------------------------------------------
# 3. Session State
# -----------------------------------------------------------------------------
# Mutable per-process conversation store. reset_session() replaces it
# wholesale; chat_with_agent() appends turns; rate_and_train() mutates the
# last turn's "reward". Shared across all Gradio users of this process.
current_session = {
    "dialog": [] # each entry: {"user": str, "bot": str, "reward": float or None}
}
# -----------------------------------------------------------------------------
# 4. Core Callback Functions
# -----------------------------------------------------------------------------
def reset_session():
    """
    Drop all recorded turns and return an empty history so the
    Gradio Chatbot widget is cleared.
    """
    global current_session
    fresh_state = {"dialog": []}
    current_session = fresh_state
    return []
def chat_with_agent(user_input: str):
    """
    Produce a model reply for *user_input*, record the exchange in the
    session, and return the full chat history as (user, bot) pairs.

    On any failure the session is left untouched and a single error
    pair is returned for display instead.
    """
    global current_session
    try:
        # Encode the prompt and sample a reply from the model.
        prompt_ids = tokenizer(user_input, return_tensors="pt").input_ids
        generated = model.generate(
            prompt_ids,
            max_new_tokens=128,
            do_sample=True,  # sampling, so replies are non-deterministic
            top_p=0.9,
            temperature=0.8
        )
        reply_text = tokenizer.decode(generated[0], skip_special_tokens=True)

        # Record the turn; reward stays None until the user rates it.
        current_session["dialog"].append({
            "user": user_input,
            "bot": reply_text,
            "reward": None
        })

        # Gradio's Chatbot expects a list of (user, bot) tuples.
        return [(t["user"], t["bot"]) for t in current_session["dialog"]]
    except Exception as exc:
        print("πŸ”₯ Error in chat_with_agent:", exc)
        traceback.print_exc()
        # On failure, leave session untouched
        return [("Error:", "Failed to generate reply. Check logs.")]
def rate_and_train(rating: float):
    """
    Attach *rating* to the most recent bot reply, run one PPO
    optimization step on that single (query, response, reward)
    example, and return the whole session serialized as JSON.

    Returns a plain message string when there is nothing to rate
    or when the training step fails.
    """
    global current_session
    try:
        if not current_session["dialog"]:
            return "No dialog to rate. Chat first."

        # Reward the most recent exchange.
        latest = current_session["dialog"][-1]
        latest["reward"] = float(rating)

        # Re-tokenize the stored texts into 1-D id tensors for PPO.
        query = tokenizer(latest["user"], return_tensors="pt").input_ids.squeeze(0)
        response = tokenizer(latest["bot"], return_tensors="pt").input_ids.squeeze(0)

        # Single-example PPO update (batch_size=1 in ppo_config).
        stats = ppo_trainer.step([query], [response], [latest["reward"]])
        print("πŸš€ PPO step stats:", stats)

        # Hand the full session back for the "Session JSON" textbox.
        return safe_json_dumps(current_session)
    except Exception as exc:
        print("πŸ”₯ Error in rate_and_train:", exc)
        traceback.print_exc()
        return "Failed to apply training step. See logs."
# -----------------------------------------------------------------------------
# 5. Gradio UI
# -----------------------------------------------------------------------------
# Build the Gradio UI and wire each widget to its callback.
with gr.Blocks() as demo:
    gr.Markdown("## FLAN-T5 Chatbot with On-the-Fly Reinforcement Learning")
    # Conversation display fed by chat_with_agent()/reset_session().
    chat_box = gr.Chatbot(label="Chat History")
    user_input = gr.Textbox(placeholder="Type your message here…", label="You")
    send_btn = gr.Button("Send")
    reset_btn = gr.Button("Reset Conversation")
    with gr.Row():
        # Integer 0-5 reward for the most recent bot reply.
        rating = gr.Slider(0, 5, step=1, value=0, label="Rate Last Reply")
        rate_btn = gr.Button("Apply Rating & Train")
    # Read-only dump of the whole session after each training step.
    export_json = gr.Textbox(label="Session JSON", lines=10)
    # Reset chat
    reset_btn.click(
        fn=reset_session,
        inputs=None,
        outputs=chat_box
    )
    # Send user message
    send_btn.click(
        fn=chat_with_agent,
        inputs=user_input,
        outputs=chat_box
    )
    # Rate & train
    rate_btn.click(
        fn=rate_and_train,
        inputs=rating,
        outputs=export_json
    )
# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()