Spaces:
Running
Running
# One-time setup: download the Pokemon ruDALL-E checkpoint from the Hugging Face
# Hub and load the model, VAE and tokenizer into module-level globals that the
# generation code below reads.
print("Preparing for inference...")  # noqa
from rudalle.pipelines import generate_images
from rudalle import get_rudalle_model, get_tokenizer, get_vae
from huggingface_hub import hf_hub_url, cached_download
import torch
from io import BytesIO
import base64
import os

print(f"GPUs available: {torch.cuda.device_count()}")

# Decide once whether CUDA is usable and derive every device-related setting
# from that single answer, so the model, its fp16 mode and the loaded weights
# always agree on where they live.
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
fp16 = use_cuda  # fp16 is only requested when running on a GPU

file_dir = "./models"
file_name = "pytorch_model.bin"
config_file_url = hf_hub_url(
    repo_id="minimaxir/ai-generated-pokemon-rudalle", filename=file_name)
# NOTE(review): cached_download is deprecated in newer huggingface_hub releases
# (hf_hub_download is its replacement); kept as-is while the installed version
# still provides it. force_filename stores the file as file_dir/file_name.
cached_download(config_file_url, cache_dir=file_dir, force_filename=file_name)
weights_path = os.path.join(file_dir, file_name)

# Build the 'Malevich' architecture without its stock weights, then load the
# downloaded checkpoint into it.
model = get_rudalle_model('Malevich', pretrained=False,
                          fp16=fp16, device=device)
# torch.load unpickles the file, so only point it at a trusted checkpoint.
model.load_state_dict(torch.load(weights_path, map_location=device))
vae = get_vae().to(device)
tokenizer = get_tokenizer()
print("Ready for inference")
# English Pokemon type names mapped to the Russian prompt strings sent to the
# model (presumably because ruDALL-E expects Russian text — confirm). This is
# the single source of truth for which prompts get translated.
_TYPE_TRANSLATIONS = {
    "grass": "трава",
    "fire": "Пожар",
    "water": "вода",
    "lightning": "молния",
    "fighting": "борьба",
    "psychic": "психический",
    "colorless": "бесцветный",
    "darkness": "темнота",
    "metal": "металл",
    "dragon": "Дракон",
    "fairy": "сказочный",
}


def _normalize_type_name(text):
    """Return the lookup key for *text*: surrounding whitespace removed, lowercased."""
    return text.strip().lower()


def english_to_russian(english_string):
    """Translate an English Pokemon type name into its Russian prompt string.

    Args:
        english_string: An English type name such as ``"fire"``; matching
            ignores letter case and surrounding whitespace.

    Returns:
        The Russian prompt string associated with that type.

    Raises:
        KeyError: If ``english_string`` is not one of the known type names.
    """
    return _TYPE_TRANSLATIONS[_normalize_type_name(english_string)]


def generate_image(prompt):
    """Generate one image for *prompt* and return it as a PNG data URI.

    If *prompt* is one of the known English type names (case-insensitive,
    surrounding whitespace ignored) it is first replaced by its Russian
    prompt; any other prompt is passed to the model unchanged.

    Args:
        prompt: Text prompt for the image generator.

    Returns:
        A string of the form ``"data:image/png;base64,<payload>"``.
    """
    # Membership is checked against the same table used for translation, so
    # the two can never disagree about which prompts are type names.
    if _normalize_type_name(prompt) in _TYPE_TRANSLATIONS:
        prompt = english_to_russian(prompt)
    result, _ = generate_images(
        prompt, tokenizer, model, vae, top_k=2048, images_num=1, top_p=0.995)
    # Encode the single generated image as an inline PNG data URI.
    buffer = BytesIO()
    result[0].save(buffer, format="PNG")
    base64_string = base64.b64encode(buffer.getvalue()).decode("UTF-8")
    return "data:image/png;base64," + base64_string