| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import json |
| import base64 |
| import io |
| from PIL import Image |
| import svgwrite |
| from typing import Dict, Any, List, Optional, Union |
| import diffusers |
| from diffusers import StableDiffusionPipeline, DDIMScheduler |
| from transformers import CLIPTextModel, CLIPTokenizer |
| import torchvision.transforms as transforms |
| from torchvision.transforms.functional import to_pil_image |
| import random |
| import math |
|
|
class EndpointHandler:
    """Inference endpoint handler that turns text prompts into SVG sketches.

    Loads a Stable Diffusion pipeline at construction time (only its CLIP
    text encoder/tokenizer are actually used for guidance) and degrades to a
    model-free heuristic mode when loading fails.
    """

    def __init__(self, path: str = ""):
        # `path` is the model directory supplied by the hosting framework;
        # it is unused here because the pipeline is always pulled by model id.
        # Prefer GPU when available; fp16 weights are only requested on CUDA.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model_id = "runwayml/stable-diffusion-v1-5"
        
        try:
            # Safety checker disabled: output is abstract line sketches.
            self.pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                safety_checker=None,
                requires_safety_checker=False
            ).to(self.device)
            
            # Swap in a DDIM scheduler built from the pipeline's own config.
            self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
            
            # Reuse the pipeline's CLIP text stack for prompt embeddings.
            self.clip_model = self.pipe.text_encoder
            self.clip_tokenizer = self.pipe.tokenizer
            
            print("DiffSketcher handler initialized successfully!")
        except Exception as e:
            # Degrade gracefully: downstream methods check these for None and
            # fall back to heuristic (model-free) sketch generation.
            print(f"Warning: Could not load diffusion model: {e}")
            self.pipe = None
            self.clip_model = None
            self.clip_tokenizer = None
|
|
| def __call__(self, inputs: Union[str, Dict[str, Any]]) -> Image.Image: |
| """ |
| Generate SVG sketch from text prompt using DiffSketcher approach |
| """ |
| try: |
| |
| if isinstance(inputs, str): |
| prompt = inputs |
| parameters = {} |
| else: |
| prompt = inputs.get("inputs", inputs.get("prompt", "a simple sketch")) |
| parameters = inputs.get("parameters", {}) |
| |
| |
| num_paths = parameters.get("num_paths", 64) |
| num_iter = parameters.get("num_iter", 500) |
| width = parameters.get("width", 224) |
| height = parameters.get("height", 224) |
| guidance_scale = parameters.get("guidance_scale", 7.5) |
| seed = parameters.get("seed", None) |
| |
| if seed is not None: |
| torch.manual_seed(seed) |
| np.random.seed(seed) |
| random.seed(seed) |
| |
| print(f"Generating sketch for: '{prompt}' with {num_paths} paths") |
| |
| |
| svg_content, metadata = self.generate_diffsketcher_svg( |
| prompt, width, height, num_paths, num_iter, guidance_scale |
| ) |
| |
| |
| pil_image = self.svg_to_pil_image(svg_content, width, height) |
| |
| |
| pil_image.info['svg_content'] = svg_content |
| pil_image.info['prompt'] = prompt |
| pil_image.info['parameters'] = json.dumps(parameters) |
| pil_image.info['num_paths'] = str(num_paths) |
| pil_image.info['method'] = 'diffsketcher' |
| |
| return pil_image |
| |
| except Exception as e: |
| print(f"Error in DiffSketcher handler: {e}") |
| |
| fallback_svg = self.create_fallback_svg(prompt if 'prompt' in locals() else "error", 224, 224) |
| fallback_image = self.svg_to_pil_image(fallback_svg, 224, 224) |
| fallback_image.info['error'] = str(e) |
| return fallback_image |
|
|
| def generate_diffsketcher_svg(self, prompt: str, width: int, height: int, |
| num_paths: int, num_iter: int, guidance_scale: float): |
| """ |
| Generate SVG using DiffSketcher-inspired approach with diffusion guidance |
| """ |
| |
| text_embeddings = self.get_text_embeddings(prompt) |
| |
| |
| paths = self.initialize_paths(num_paths, width, height) |
| |
| |
| optimized_paths = self.optimize_paths_with_diffusion( |
| paths, text_embeddings, prompt, width, height, num_iter, guidance_scale |
| ) |
| |
| |
| svg_content = self.paths_to_svg(optimized_paths, width, height) |
| |
| metadata = { |
| "method": "diffsketcher", |
| "prompt": prompt, |
| "num_paths": num_paths, |
| "num_iter": num_iter, |
| "guidance_scale": guidance_scale, |
| "width": width, |
| "height": height |
| } |
| |
| return svg_content, metadata |
|
|
| def get_text_embeddings(self, prompt: str): |
| """Get CLIP text embeddings for the prompt""" |
| if self.clip_model is None or self.clip_tokenizer is None: |
| |
| return torch.zeros((2, 77, 768)) |
| |
| try: |
| with torch.no_grad(): |
| text_inputs = self.clip_tokenizer( |
| prompt, |
| padding="max_length", |
| max_length=self.clip_tokenizer.model_max_length, |
| truncation=True, |
| return_tensors="pt" |
| ).to(self.device) |
| |
| text_embeddings = self.clip_model(text_inputs.input_ids)[0] |
| |
| |
| uncond_inputs = self.clip_tokenizer( |
| "", |
| padding="max_length", |
| max_length=self.clip_tokenizer.model_max_length, |
| return_tensors="pt" |
| ).to(self.device) |
| |
| uncond_embeddings = self.clip_model(uncond_inputs.input_ids)[0] |
| |
| |
| text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) |
| |
| return text_embeddings |
| except Exception as e: |
| print(f"Error getting text embeddings: {e}") |
| return torch.zeros((2, 77, 768)) |
|
|
| def initialize_paths(self, num_paths: int, width: int, height: int): |
| """Initialize random Bezier paths""" |
| paths = [] |
| |
| for i in range(num_paths): |
| |
| start_x = random.uniform(0.1 * width, 0.9 * width) |
| start_y = random.uniform(0.1 * height, 0.9 * height) |
| |
| |
| cp1_x = start_x + random.uniform(-width*0.2, width*0.2) |
| cp1_y = start_y + random.uniform(-height*0.2, height*0.2) |
| cp2_x = start_x + random.uniform(-width*0.2, width*0.2) |
| cp2_y = start_y + random.uniform(-height*0.2, height*0.2) |
| |
| |
| end_x = start_x + random.uniform(-width*0.3, width*0.3) |
| end_y = start_y + random.uniform(-height*0.3, height*0.3) |
| |
| |
| cp1_x = max(0, min(width, cp1_x)) |
| cp1_y = max(0, min(height, cp1_y)) |
| cp2_x = max(0, min(width, cp2_x)) |
| cp2_y = max(0, min(height, cp2_y)) |
| end_x = max(0, min(width, end_x)) |
| end_y = max(0, min(height, end_y)) |
| |
| |
| color_intensity = random.uniform(0.1, 0.7) |
| color = ( |
| int(color_intensity * 255), |
| int(color_intensity * 255), |
| int(color_intensity * 255) |
| ) |
| |
| |
| stroke_width = random.uniform(0.5, 3.0) |
| |
| path = { |
| 'start': (start_x, start_y), |
| 'cp1': (cp1_x, cp1_y), |
| 'cp2': (cp2_x, cp2_y), |
| 'end': (end_x, end_y), |
| 'color': color, |
| 'stroke_width': stroke_width, |
| 'opacity': random.uniform(0.3, 0.8) |
| } |
| paths.append(path) |
| |
| return paths |
|
|
| def optimize_paths_with_diffusion(self, paths: List[Dict], text_embeddings: torch.Tensor, |
| prompt: str, width: int, height: int, |
| num_iter: int, guidance_scale: float): |
| """ |
| Optimize paths using diffusion model guidance (simplified approach) |
| """ |
| |
| semantic_features = self.extract_semantic_features(prompt) |
| |
| |
| for iteration in range(min(num_iter // 10, 50)): |
| |
| paths = self.apply_semantic_guidance(paths, semantic_features, width, height) |
| |
| |
| if iteration % 5 == 0: |
| paths = self.apply_aesthetic_refinement(paths, width, height) |
| |
| return paths |
|
|
| def extract_semantic_features(self, prompt: str): |
| """Extract semantic features from prompt to guide path generation""" |
| |
| features = { |
| 'complexity': 'medium', |
| 'style': 'sketch', |
| 'density': 'medium', |
| 'organic': False, |
| 'geometric': False, |
| 'detailed': False |
| } |
| |
| prompt_lower = prompt.lower() |
| |
| |
| complex_words = ['detailed', 'intricate', 'complex', 'elaborate'] |
| simple_words = ['simple', 'minimal', 'basic', 'clean'] |
| |
| if any(word in prompt_lower for word in complex_words): |
| features['complexity'] = 'high' |
| features['detailed'] = True |
| elif any(word in prompt_lower for word in simple_words): |
| features['complexity'] = 'low' |
| |
| |
| if any(word in prompt_lower for word in ['sketch', 'drawing', 'pencil', 'charcoal']): |
| features['style'] = 'sketch' |
| elif any(word in prompt_lower for word in ['painting', 'artistic', 'painted']): |
| features['style'] = 'artistic' |
| |
| |
| organic_words = ['tree', 'flower', 'animal', 'person', 'face', 'natural', 'organic'] |
| geometric_words = ['building', 'house', 'geometric', 'square', 'circle', 'triangle'] |
| |
| if any(word in prompt_lower for word in organic_words): |
| features['organic'] = True |
| if any(word in prompt_lower for word in geometric_words): |
| features['geometric'] = True |
| |
| return features |
|
|
| def apply_semantic_guidance(self, paths: List[Dict], features: Dict, width: int, height: int): |
| """Apply semantic guidance to modify paths""" |
| modified_paths = [] |
| |
| for path in paths: |
| new_path = path.copy() |
| |
| |
| if features['complexity'] == 'high': |
| |
| variation = 0.15 |
| new_path['cp1'] = ( |
| new_path['cp1'][0] + random.uniform(-width*variation, width*variation), |
| new_path['cp1'][1] + random.uniform(-height*variation, height*variation) |
| ) |
| new_path['cp2'] = ( |
| new_path['cp2'][0] + random.uniform(-width*variation, width*variation), |
| new_path['cp2'][1] + random.uniform(-height*variation, height*variation) |
| ) |
| elif features['complexity'] == 'low': |
| |
| start_x, start_y = new_path['start'] |
| end_x, end_y = new_path['end'] |
| new_path['cp1'] = ( |
| start_x + (end_x - start_x) * 0.33, |
| start_y + (end_y - start_y) * 0.33 |
| ) |
| new_path['cp2'] = ( |
| start_x + (end_x - start_x) * 0.66, |
| start_y + (end_y - start_y) * 0.66 |
| ) |
| |
| |
| if features['organic']: |
| |
| new_path['stroke_width'] *= random.uniform(0.8, 1.2) |
| new_path['opacity'] *= random.uniform(0.9, 1.1) |
| elif features['geometric']: |
| |
| |
| grid_size = 20 |
| for key in ['start', 'cp1', 'cp2', 'end']: |
| x, y = new_path[key] |
| new_path[key] = ( |
| round(x / grid_size) * grid_size, |
| round(y / grid_size) * grid_size |
| ) |
| |
| |
| for key in ['start', 'cp1', 'cp2', 'end']: |
| x, y = new_path[key] |
| new_path[key] = ( |
| max(0, min(width, x)), |
| max(0, min(height, y)) |
| ) |
| |
| modified_paths.append(new_path) |
| |
| return modified_paths |
|
|
| def apply_aesthetic_refinement(self, paths: List[Dict], width: int, height: int): |
| """Apply aesthetic refinements to improve visual quality""" |
| |
| center_x, center_y = width / 2, height / 2 |
| |
| def distance_from_center(path): |
| start_x, start_y = path['start'] |
| return math.sqrt((start_x - center_x)**2 + (start_y - center_y)**2) |
| |
| |
| paths.sort(key=distance_from_center, reverse=True) |
| |
| |
| for i, path in enumerate(paths): |
| |
| layer_factor = 1.0 - (i / len(paths)) * 0.3 |
| path['opacity'] = min(0.9, path['opacity'] * layer_factor) |
| |
| return paths |
|
|
| def paths_to_svg(self, paths: List[Dict], width: int, height: int): |
| """Convert optimized paths to SVG format""" |
| dwg = svgwrite.Drawing(size=(width, height)) |
| dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white')) |
| |
| for path in paths: |
| start_x, start_y = path['start'] |
| cp1_x, cp1_y = path['cp1'] |
| cp2_x, cp2_y = path['cp2'] |
| end_x, end_y = path['end'] |
| |
| |
| path_data = f"M {start_x},{start_y} C {cp1_x},{cp1_y} {cp2_x},{cp2_y} {end_x},{end_y}" |
| |
| color = path['color'] |
| stroke_color = f"rgb({color[0]},{color[1]},{color[2]})" |
| |
| dwg.add(dwg.path( |
| d=path_data, |
| stroke=stroke_color, |
| stroke_width=path['stroke_width'], |
| stroke_opacity=path['opacity'], |
| fill='none', |
| stroke_linecap='round', |
| stroke_linejoin='round' |
| )) |
| |
| return dwg.tostring() |
|
|
| def svg_to_pil_image(self, svg_content: str, width: int, height: int): |
| """Convert SVG content to PIL Image""" |
| try: |
| import cairosvg |
| |
| |
| png_bytes = cairosvg.svg2png( |
| bytestring=svg_content.encode('utf-8'), |
| output_width=width, |
| output_height=height |
| ) |
| |
| |
| image = Image.open(io.BytesIO(png_bytes)).convert('RGB') |
| return image |
| |
| except ImportError: |
| print("cairosvg not available, creating simple image representation") |
| |
| image = Image.new('RGB', (width, height), 'white') |
| return image |
| except Exception as e: |
| print(f"Error converting SVG to image: {e}") |
| |
| image = Image.new('RGB', (width, height), 'white') |
| return image |
|
|
| def create_fallback_svg(self, prompt: str, width: int, height: int): |
| """Create simple fallback SVG""" |
| dwg = svgwrite.Drawing(size=(width, height)) |
| dwg.add(dwg.rect(insert=(0, 0), size=(width, height), fill='white')) |
| |
| |
| dwg.add(dwg.text( |
| f"DiffSketcher\n{prompt[:30]}...", |
| insert=(width/2, height/2), |
| text_anchor="middle", |
| font_size="12px", |
| fill="black" |
| )) |
| |
| return dwg.tostring() |