import gradio as gr
import torch
from transformers import (
AutoTokenizer,
AutoModelForSeq2SeqLM,
)
from sentence_transformers import SentenceTransformer, util
from typing import List, Tuple, Dict, Optional
import re
import difflib
# Initialize similarity model
similarity_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Model configurations
PARAPHRASE_MODELS = {
"T5-Base": "Vamsi/T5_Paraphrase_Paws",
# "PEGASUS-Paraphrase": "tuner007/pegasus_paraphrase",
"Parrot-Paraphraser": "prithivida/parrot_paraphraser_on_T5",
"BART-Paraphrase": "eugenesiow/bart-paraphrase",
"ChatGPT-Style-T5": "humarin/chatgpt_paraphraser_on_T5_base",
}
EXPANSION_MODELS = {
"Flan-T5-Base": "google/flan-t5-base",
"Flan-T5-Large": "google/flan-t5-large",
}
# Cache for loaded models
model_cache = {}
def load_model(model_name: str, model_path: str):
"""Load model and tokenizer with caching"""
if model_name in model_cache:
return model_cache[model_name]
print(f"Loading {model_name}...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
model_cache[model_name] = (model, tokenizer, device)
return model, tokenizer, device
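# Usage sketch (hypothetical session): the first call downloads the weights and
# moves them to the device; later calls for the same name are served from model_cache.
#   model, tokenizer, device = load_model("Flan-T5-Base", "google/flan-t5-base")
#   model, tokenizer, device = load_model("Flan-T5-Base", "google/flan-t5-base")  # cache hit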
def chunk_text(text: str, max_sentences: int = 4) -> List[str]:
"""Split text into chunks based on number of sentences"""
sentences = re.split(r'(?<=[.!?]) +', text.strip())
chunks = [' '.join(sentences[i:i+max_sentences]) for i in range(0, len(sentences), max_sentences)]
return [chunk for chunk in chunks if chunk.strip()]
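# For example (illustrative input):
#   chunk_text("One. Two! Three? Four. Five.", max_sentences=2)
#   -> ['One. Two!', 'Three? Four.', 'Five.']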
def estimate_tokens(text: str) -> int:
"""Estimate number of tokens in text (approximate: 1 token ≈ 0.75 words)"""
word_count = len(text.split())
return int(word_count / 0.75)
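# For example, a 75-word text gives int(75 / 0.75) = 100 estimated tokens.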
def calculate_max_length(input_text: str, mode: str, base_max_length: int) -> int:
"""Calculate appropriate max_length based on input tokens"""
input_tokens = estimate_tokens(input_text)
if mode == "Paraphrase":
# For paraphrasing: output should be 1.2-1.5x input tokens
calculated_max = int(input_tokens * 1.5) + 50
else:
# For expansion: output should be 2-3x input tokens
calculated_max = int(input_tokens * 3) + 100
# Use the larger of calculated or user-specified max_length
final_max_length = max(calculated_max, base_max_length)
# Cap at reasonable maximum to avoid memory issues
return min(final_max_length, 1024)
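# Worked example: a 300-word input is ~400 estimated tokens; in Paraphrase mode
# that yields int(400 * 1.5) + 50 = 650, which wins over a base_max_length of 512
# and stays under the 1024 cap.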
def calculate_similarity(text1: str, text2: str) -> float:
    """Calculate cosine similarity between two texts, returned as a percentage (0-100)"""
    if not text1.strip() or not text2.strip():
        return 0.0
    embeddings = similarity_model.encode([text1, text2], convert_to_tensor=True)
    similarity = util.cos_sim(embeddings[0], embeddings[1]).item()
    return round(similarity * 100, 2)
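# For example, calculate_similarity("The cat sat.", "The cat sat.") returns 100.0;
# semantically unrelated sentences typically score much lower.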
def highlight_differences(original: str, generated: str) -> Tuple[str, str, Dict]:
"""
Create highlighted HTML versions of both texts showing differences
Returns: (highlighted_original, highlighted_generated, statistics)
"""
# Split into words for comparison
original_words = original.split()
generated_words = generated.split()
# Use difflib to find differences
diff = difflib.SequenceMatcher(None, original_words, generated_words)
highlighted_original = []
highlighted_generated = []
changes_count = 0
additions_count = 0
deletions_count = 0
unchanged_count = 0
word_substitutions = []
for tag, i1, i2, j1, j2 in diff.get_opcodes():
original_chunk = ' '.join(original_words[i1:i2])
generated_chunk = ' '.join(generated_words[j1:j2])
if tag == 'equal':
# Unchanged text
highlighted_original.append(original_chunk)
highlighted_generated.append(generated_chunk)
unchanged_count += (i2 - i1)
elif tag == 'replace':
# Changed text
            highlighted_original.append(
                f'<span style="background:rgba(255,99,99,0.35);border-radius:3px;">{original_chunk}</span>'
            )
            highlighted_generated.append(
                f'<span style="background:rgba(99,255,132,0.30);border-radius:3px;">{generated_chunk}</span>'
            )
            changes_count += max(i2 - i1, j2 - j1)
            # Track word substitutions (limit to single-word changes for clarity)
            if i2 - i1 == 1 and j2 - j1 == 1:
                word_substitutions.append((original_chunk, generated_chunk))
        elif tag == 'delete':
            # Text removed in generated: strike through and highlight in the original
            highlighted_original.append(
                f'<span style="background:rgba(255,99,99,0.35);text-decoration:line-through;'
                f'border-radius:3px;">{original_chunk}</span>'
            )
            deletions_count += (i2 - i1)
        elif tag == 'insert':
            # Text added in generated: highlight in the generated text
            highlighted_generated.append(
                f'<span style="background:rgba(99,255,132,0.30);border-radius:3px;">{generated_chunk}</span>'
            )
            additions_count += (j2 - j1)
# Join with spaces
final_original = ' '.join(highlighted_original)
final_generated = ' '.join(highlighted_generated)
# Calculate statistics
total_original_words = len(original_words)
total_generated_words = len(generated_words)
percentage_changed = (changes_count + deletions_count + additions_count) / max(total_original_words, 1) * 100
percentage_unchanged = (unchanged_count / max(total_original_words, 1)) * 100
statistics = {
'total_original': total_original_words,
'total_generated': total_generated_words,
'unchanged': unchanged_count,
'changed': changes_count,
'added': additions_count,
'deleted': deletions_count,
'percentage_changed': percentage_changed,
'percentage_unchanged': percentage_unchanged,
'substitutions': word_substitutions[:10] # Limit to first 10
}
return final_original, final_generated, statistics
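# Illustration: diffing "the quick fox" vs "the slow fox" word-by-word produces opcodes
#   [('equal', 0, 1, 0, 1), ('replace', 1, 2, 1, 2), ('equal', 2, 3, 2, 3)]
# so "quick" is marked as changed in the original and "slow" in the generated text.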
def format_statistics(stats: Dict) -> str:
    """Format statistics into a readable HTML string with dark theme"""
    html = f"""
    <div style="background:#1f2937;color:#e5e7eb;padding:16px;border-radius:8px;">
        <h3 style="margin-top:0;">📊 Change Analysis</h3>
        <div style="display:flex;gap:24px;flex-wrap:wrap;text-align:center;">
            <div><div style="font-size:1.5em;font-weight:bold;">{stats['total_original']}</div>Original Words</div>
            <div><div style="font-size:1.5em;font-weight:bold;">{stats['total_generated']}</div>Generated Words</div>
            <div><div style="font-size:1.5em;font-weight:bold;">{stats['unchanged']}</div>Unchanged</div>
            <div><div style="font-size:1.5em;font-weight:bold;">{stats['changed']}</div>Changed</div>
        </div>
        <p><b>Modification Rate:</b> {stats['percentage_changed']:.1f}% modified, {stats['percentage_unchanged']:.1f}% preserved</p>
        <p>✚ Added: {stats['added']} words | ✖ Removed: {stats['deleted']} words</p>
    """
    if stats['substitutions']:
        html += "<p><b>🔄 Sample Word Substitutions:</b></p><ul>"
        for orig, new in stats['substitutions']:
            html += f"<li>{orig} → {new}</li>"
        html += "</ul>"
    html += """
        <p><b>Legend:</b>
            <span style="background:rgba(255,99,99,0.35);padding:2px 6px;border-radius:4px;">Removed/Changed</span>
            <span style="background:rgba(99,255,132,0.30);padding:2px 6px;border-radius:4px;">Added/New</span>
        </p>
    </div>
    """
    return html
def paraphrase_text(
text: str,
model_name: str,
temperature: float,
top_p: float,
max_length: int,
num_beams: int,
max_sentences: int,
    target_words: Optional[int] = None,
mode: str = "Paraphrase"
) -> Tuple[str, float]:
"""Paraphrase or expand text based on mode"""
if not text.strip():
return "Please enter some text to process.", 0.0
# Select appropriate model based on mode
if mode == "Paraphrase":
models_dict = PARAPHRASE_MODELS
if model_name not in models_dict:
model_name = list(models_dict.keys())[0]
model_path = models_dict[model_name]
prefix = "paraphrase: " if "T5" in model_name else ""
else: # Expand mode
models_dict = EXPANSION_MODELS
if model_name not in models_dict:
model_name = list(models_dict.keys())[0]
model_path = models_dict[model_name]
target_words = target_words or 300
prefix = f"Expand the following text to approximately {target_words} words, adding more details and context: "
# Load model
model, tokenizer, device = load_model(model_name, model_path)
# Chunk text based on sentences
chunks = chunk_text(text, max_sentences=max_sentences)
processed_chunks = []
print(f"\n{'='*60}")
print(f"Processing {len(chunks)} chunk(s) with {max_sentences} sentences per chunk")
print(f"{'='*60}")
for i, chunk in enumerate(chunks):
# Calculate dynamic max_length for this chunk
chunk_max_length = calculate_max_length(chunk, mode, max_length)
input_tokens = estimate_tokens(chunk)
        # Prepare input; some T5 paraphrase checkpoints (e.g. Vamsi/T5_Paraphrase_Paws)
        # are documented as expecting an explicit " </s>" suffix on the prompt
        input_text = prefix + chunk + " </s>" if mode == "Paraphrase" else prefix + chunk
inputs = tokenizer.encode(
input_text,
return_tensors="pt",
max_length=512,
truncation=True
)
inputs = inputs.to(device)
# Calculate min_length to ensure output isn't too short
if mode == "Paraphrase":
min_length_calc = int(input_tokens * 0.8)
else:
min_length_calc = int(input_tokens * 1.5)
# Generate
with torch.no_grad():
outputs = model.generate(
inputs,
max_length=chunk_max_length,
min_length=min(min_length_calc, chunk_max_length - 10),
num_beams=num_beams,
temperature=temperature if temperature > 0 else 1.0,
top_p=top_p,
top_k=120 if mode == "Paraphrase" else 50,
do_sample=temperature > 0,
early_stopping=True,
no_repeat_ngram_size=3 if mode == "Expand" else 2,
length_penalty=1.0 if mode == "Paraphrase" else 1.5,
repetition_penalty=1.2,
)
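        # When temperature == 0, do_sample is False and decoding falls back to
        # deterministic beam search; temperature and top_p then have no effect.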
# Decode output
processed_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
processed_chunks.append(processed_text.strip())
output_tokens = estimate_tokens(processed_text)
print(f"Chunk {i+1}/{len(chunks)}:")
print(f" Input: {len(chunk.split())} words (~{input_tokens} tokens)")
print(f" Output: {len(processed_text.split())} words (~{output_tokens} tokens)")
print(f" Max length used: {chunk_max_length}")
print(f"-" * 60)
# Combine chunks with double newline
final_text = "\n\n".join(processed_chunks)
# Calculate similarity
similarity_score = calculate_similarity(text, final_text)
print(f"{'='*60}")
print(f"Total: {len(text.split())} → {len(final_text.split())} words")
print(f"Similarity: {similarity_score:.4f}")
print(f"{'='*60}\n")
return final_text, similarity_score
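# Usage sketch (hypothetical values):
#   rewritten, score = paraphrase_text(
#       "The meeting was postponed until next week.", "T5-Base",
#       temperature=0.7, top_p=0.9, max_length=512, num_beams=4, max_sentences=4,
#   )
#   # rewritten -> a reworded version; score -> a similarity percentage such as 92.31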
def update_model_choices(mode: str):
"""Update model dropdown based on selected mode"""
if mode == "Paraphrase":
choices = list(PARAPHRASE_MODELS.keys())
else:
choices = list(EXPANSION_MODELS.keys())
return gr.Dropdown(choices=choices, value=choices[0])
def update_parameters_visibility(mode: str):
"""Show/hide target words parameter based on mode"""
if mode == "Expand":
return gr.Number(visible=True)
else:
return gr.Number(visible=False)
def process_text(
input_text: str,
mode: str,
model_name: str,
temperature: float,
top_p: float,
max_length: int,
num_beams: int,
max_sentences: int,
target_words: int
):
"""Main processing function"""
try:
output_text, similarity = paraphrase_text(
input_text,
model_name,
temperature,
top_p,
max_length,
num_beams,
max_sentences,
target_words,
mode
)
word_count_original = len(input_text.split())
word_count_output = len(output_text.split())
# Generate highlighted comparison
highlighted_original, highlighted_generated, statistics = highlight_differences(
input_text,
output_text
)
# Format statistics
stats_html = format_statistics(statistics)
# Basic stats line
basic_stats = f"**Original:** {word_count_original} words | **Generated:** {word_count_output} words"
return output_text, basic_stats, similarity, highlighted_original, highlighted_generated, stats_html
except Exception as e:
import traceback
error_msg = f"Error: {str(e)}\n\n{traceback.format_exc()}"
print(error_msg)
return error_msg, "Error occurred", 0.0, "", "", ""
# Create Gradio interface
with gr.Blocks(title="Text Paraphraser & Expander", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"""
# 📝 Text Paraphraser & Expander
Transform your text with AI-powered paraphrasing and expansion capabilities.
"""
)
with gr.Row():
with gr.Column(scale=1):
mode = gr.Radio(
choices=["Paraphrase", "Expand"],
value="Paraphrase",
label="Mode",
info="Choose to paraphrase or expand your text"
)
model_dropdown = gr.Dropdown(
choices=list(PARAPHRASE_MODELS.keys()),
value=list(PARAPHRASE_MODELS.keys())[0],
label="Model Selection",
info="Choose the model for processing"
)
with gr.Row():
gr.Markdown("## ⚙️ Configuration")
with gr.Accordion("Advanced Parameters", open=False):
temperature = gr.Slider(
minimum=0.0,
maximum=2.0,
value=0.7,
step=0.1,
label="Temperature",
info="Higher = more creative, Lower = more focused"
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.05,
label="Top-p (Nucleus Sampling)",
info="Probability threshold for token selection"
)
max_length = gr.Slider(
minimum=128,
maximum=1024,
value=512,
step=32,
label="Max Length (tokens)",
info="Maximum length of generated text per chunk"
)
num_beams = gr.Slider(
minimum=1,
maximum=10,
value=4,
step=1,
label="Number of Beams",
info="Higher = better quality but slower"
)
max_sentences = gr.Slider(
minimum=1,
maximum=10,
value=4,
step=1,
label="Sentences per Chunk",
info="Number of sentences to process together"
)
target_words = gr.Number(
value=300,
label="Target Word Count (Expand mode)",
info="Approximate number of words for expansion",
visible=False
)
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### 📥 Input Text")
input_text = gr.Textbox(
lines=10,
placeholder="Enter your text here...",
label="Original Text",
show_copy_button=True
)
with gr.Column(scale=1):
gr.Markdown("### 📤 Generated Text")
output_text = gr.Textbox(
lines=10,
label="Processed Text",
show_copy_button=True
)
with gr.Row():
process_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
        clear_btn = gr.Button("🗑️ Clear", size="lg")
stats_display = gr.Markdown()
similarity_display = gr.Number(
label="Content Similarity (%)",
precision=2,
interactive=False
)
# Highlighted comparison section
gr.Markdown("---")
gr.Markdown("## 🧩 Visual Comparison - Original vs Paraphrased Text")
gr.HTML("""
""")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### 📄 Original Text (with changes highlighted)")
highlighted_original = gr.HTML(
label="Original with Changes",
show_label=False,
elem_id="highlighted_original"
)
with gr.Column(scale=1):
gr.Markdown("### ✨ Generated Text (with changes highlighted)")
highlighted_generated = gr.HTML(
label="Generated with Changes",
show_label=False,
elem_id="highlighted_original"
)
    change_stats = gr.HTML(label="Change Statistics", elem_id="change_stats")
# Event handlers
mode.change(
fn=update_model_choices,
inputs=[mode],
outputs=[model_dropdown]
)
mode.change(
fn=update_parameters_visibility,
inputs=[mode],
outputs=[target_words]
)
process_btn.click(
fn=process_text,
inputs=[
input_text,
mode,
model_dropdown,
temperature,
top_p,
max_length,
num_beams,
max_sentences,
target_words
],
outputs=[
output_text,
stats_display,
similarity_display,
highlighted_original,
highlighted_generated,
change_stats
]
)
    clear_btn.click(
        fn=lambda: ("", "", "", 0.0, "", "", ""),
        inputs=[],
        outputs=[
            input_text,
            output_text,
            stats_display,
            similarity_display,
            highlighted_original,
            highlighted_generated,
            change_stats
        ]
    )
gr.Markdown(
"""
---
### 💡 Tips:
- **Paraphrase Mode**: Rewrites text while preserving meaning
- **Expand Mode**: Adds details and elaboration to make text longer
- **Sentences per Chunk**: Controls how many sentences are processed together (4 recommended)
- Adjust temperature for creativity (0.7-1.0 for paraphrase, 1.0-1.5 for expansion)
- Higher beam count = better quality but slower processing
- Max length is automatically calculated based on input, but can be overridden
- Output chunks are separated by double newlines for readability
"""
)
if __name__ == "__main__":
demo.launch(share=True)