from typing import Any, Dict, List

import base64
import gc
import io

import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig

class EndpointHandler:
    """Custom inference handler for an image + text (multimodal) causal LM."""

    def __init__(self, path: str = ""):
        # Load the processor and model from the checkpoint directory.
        # trust_remote_code is required because the checkpoint ships its own
        # processing and generation code.
        self.processor = AutoProcessor.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            low_cpu_mem_usage=True,
        )
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Free any cached GPU memory left over from a previous request.
        torch.cuda.empty_cache()
        gc.collect()

        inputs = data.get("inputs", {})
        image_url = inputs.get("image_url")
        image_data = inputs.get("image")
        text_prompt = inputs.get("text_prompt", "Describe this image.")

        # Accept either an image URL or a base64-encoded image payload.
        if image_url:
            try:
                image = Image.open(requests.get(image_url, stream=True).raw)
            except Exception as e:
                return [{"error": f"Failed to load image from URL: {str(e)}"}]
        elif image_data:
            try:
                image = Image.open(io.BytesIO(base64.b64decode(image_data)))
            except Exception as e:
                return [{"error": f"Failed to decode image data: {str(e)}"}]
        else:
            return [{"error": "No image_url or image data provided in inputs"}]

        # The processor expects RGB input.
        if image.mode != "RGB":
            image = image.convert("RGB")
        try:
            # Preprocess and generate under bfloat16 autocast.
            with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                model_inputs = self.processor.process(
                    images=[image],
                    text=text_prompt,
                )
                # Move tensors to the model device and add a batch dimension.
                model_inputs = {
                    k: v.to(self.model.device).unsqueeze(0) for k, v in model_inputs.items()
                }
                # generate_from_batch is provided by the checkpoint's remote code.
                output = self.model.generate_from_batch(
                    model_inputs,
                    GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
                    tokenizer=self.processor.tokenizer,
                )

            # Strip the prompt tokens and decode only the newly generated text.
            generated_tokens = output[0, model_inputs["input_ids"].size(1):]
            generated_text = self.processor.tokenizer.decode(
                generated_tokens, skip_special_tokens=True
            )

            torch.cuda.empty_cache()
            gc.collect()
            return [{"generated_text": generated_text}]
        except Exception as e:
            return [{"error": f"Error during generation: {str(e)}"}]