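"""Forge diffusion engine for the Chroma model.

Chroma is driven by a Flux-style flow-matching predictor (PredictionFlux)
and conditions on T5-XXL text embeddings only.
"""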
import torch

from huggingface_guess import model_list
from backend.diffusion_engine.base import ForgeDiffusionEngine, ForgeObjects
from backend.patcher.clip import CLIP
from backend.patcher.vae import VAE
from backend.patcher.unet import UnetPatcher
from backend.text_processing.t5_engine import T5TextProcessingEngine
from backend.args import dynamic_args
from backend.modules.k_prediction import PredictionFlux
from backend import memory_management

class Chroma(ForgeDiffusionEngine):
    def __init__(self, estimated_config, huggingface_components):
        super().__init__(estimated_config, huggingface_components)
        self.is_inpaint = False

        # Chroma conditions on a single T5-XXL text encoder.
        clip = CLIP(
            model_dict={
                't5xxl': huggingface_components['text_encoder']
            },
            tokenizer_dict={
                't5xxl': huggingface_components['tokenizer']
            }
        )

        vae = VAE(model=huggingface_components['vae'])

        # Flux-style flow-matching prediction with a fixed shift of mu=1.0.
        k_predictor = PredictionFlux(
            mu=1.0
        )

        unet = UnetPatcher.from_model(
            model=huggingface_components['transformer'],
            diffusers_scheduler=None,
            k_predictor=k_predictor,
            config=estimated_config
        )

        self.text_processing_engine_t5 = T5TextProcessingEngine(
            text_encoder=clip.cond_stage_model.t5xxl,
            tokenizer=clip.tokenizer.t5xxl,
            emphasis_name=dynamic_args['emphasis_name'],
            min_length=1
        )

        self.forge_objects = ForgeObjects(unet=unet, clip=clip, vae=vae, clipvision=None)
        self.forge_objects_original = self.forge_objects.shallow_copy()
        self.forge_objects_after_applying_lora = self.forge_objects.shallow_copy()

    def set_clip_skip(self, clip_skip):
        # CLIP skip does not apply to a T5-only text encoder; intentionally a no-op.
        pass

    def get_learned_conditioning(self, prompt: list[str]):
        # Make sure the text encoder is resident on the GPU before encoding.
        memory_management.load_model_gpu(self.forge_objects.clip.patcher)
        return self.text_processing_engine_t5(prompt)

    def get_prompt_lengths_on_ui(self, prompt):
        # Report the current token count and the displayed budget (at least 255).
        token_count = len(self.text_processing_engine_t5.tokenize([prompt])[0])
        return token_count, max(255, token_count)

    def encode_first_stage(self, x):
        # Move channels last and map pixels from [-1, 1] to [0, 1] for the VAE.
        sample = self.forge_objects.vae.encode(x.movedim(1, -1) * 0.5 + 0.5)
        # Apply the VAE's latent scaling before handing latents to the sampler.
        sample = self.forge_objects.vae.first_stage_model.process_in(sample)
        return sample.to(x)

    def decode_first_stage(self, x):
        # Undo the latent scaling, decode, then map back from [0, 1] to [-1, 1].
        sample = self.forge_objects.vae.first_stage_model.process_out(x)
        sample = self.forge_objects.vae.decode(sample).movedim(-1, 1) * 2.0 - 1.0
        return sample.to(x)
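

# Minimal usage sketch. `estimated_config` and `huggingface_components` are
# assumed to come from Forge's checkpoint-loading pipeline; the names below
# are illustrative, not a documented API:
#
#   engine = Chroma(estimated_config, huggingface_components)
#   cond = engine.get_learned_conditioning(['a watercolor fox'])
#   tokens, shown_max = engine.get_prompt_lengths_on_ui('a watercolor fox')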