Burf commited on
Commit
e600401
·
1 Parent(s): d797ff7

change custom pipeline init issue

Browse files
Files changed (1) hide show
  1. wrapper.py +2 -5
wrapper.py CHANGED
@@ -183,14 +183,11 @@ def peca(pipeline, save_path = "./weight", n_layer = 10):
183
  return encoder, feature_encoder.get_text_feature, size, num_inference_steps, skip
184
 
185
  class DrUM(DiffusionPipeline):
186
- def __init__(self, pipeline = None, repo_id = "Burf/DrUM", weight = None, pretrained_model_name_or_path = None, torch_dtype = torch.bfloat16, device = "cuda"):
187
  """
188
  DrUM for various T2I diffusion models
189
  """
190
- if pipeline is None and pretrained_model_name_or_path is None:
191
- raise ValueError("pipeline or pretrained_model_name_or_path must be provided")
192
-
193
- self.pipeline = pipeline if pipeline is not None else self.load_pipeline(pretrained_model_name_or_path, torch_dtype = torch_dtype, device = device)
194
  self.repo_id = repo_id
195
 
196
  self.adapter, self.feature_encoder, self.size, self.num_inference_steps, self.skip = self.load_peca(self.pipeline, repo_id, weight)
 
183
  return encoder, feature_encoder.get_text_feature, size, num_inference_steps, skip
184
 
185
  class DrUM(DiffusionPipeline):
186
+ def __init__(self, pipeline, repo_id = "Burf/DrUM", weight = None, torch_dtype = torch.bfloat16, device = "cuda"):
187
  """
188
  DrUM for various T2I diffusion models
189
  """
190
+ self.pipeline = pipeline if not isinstance(pipeline, str) else self.load_pipeline(pipeline, torch_dtype = torch_dtype, device = device)
 
 
 
191
  self.repo_id = repo_id
192
 
193
  self.adapter, self.feature_encoder, self.size, self.num_inference_steps, self.skip = self.load_peca(self.pipeline, repo_id, weight)