I am trying to convert a .safetensors model to a diffusers model using the Python script found at https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py. The command I tried is python3 convert_original_stable_diffusion_to_diffusers.py --checkpoint_path /home/aero/stable-diffusion-webui/models/Stable-diffusion/chilloutmix_NiPrunedFp32Fix.safetensors --scheduler_type euler-ancestral --dump_path /home/aero/diffusers/models/chilloutmix_NiPrunedFp32Fix --from_safetensors
. After the conversion, I intend to use the diffusers model within the FlaxStableDiffusionImg2ImgPipeline.
However, I encountered an error when running the script I provided below (full code):
First error: OSError: diffusion_pytorch_model.bin file found in directory /home/aero/diffusers/models/chilloutmix_NiPrunedFp32Fix/vae. Please load the model using from_pt=True.
I modified the code by adding from_pt=True.
Second error: TypeError: getattr(): attribute name must be string
My question is how I can fix these issues and properly convert the .safetensors model to a diffusers model, so I can use it with FlaxStableDiffusionImg2ImgPipeline without encountering any errors?
Full Code:
import jax
import numpy as np
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
import requests
from io import BytesIO
from PIL import Image
from diffusers import FlaxStableDiffusionImg2ImgPipeline
import time
from datetime import datetime
import random
def create_key(seed=0):
return jax.random.PRNGKey(seed)
start_time = time.time()
url = "https://i.pinimg.com/564x/e6/36/a6/e636a664f860a1ec9f7b5f3c4e2f634b.jpg"
response = requests.get(url)
init_img = Image.open(BytesIO(response.content)).convert("RGB")
init_img = init_img.resize((784, 784))
prompts = "hyperreal, artstation, (masterpiece:1.0), (best quality:1.4), (ultra highres:1.2), (photorealistic:1.4), (8k, RAW photo:1.2), (soft focus:1.4), (sharp focus:1.4)"
num_samples = jax.device_count()
pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
"/home/aero/diffusers/models/chilloutmix_NiPrunedFp32Fix",
dtype=jnp.bfloat16,
safety_checker=None,
# from_pt=True,
)
for x in range(4):
rng = create_key(random.randint(0, 7183698734589870))
rng = jax.random.split(rng, num_samples)
prompt_ids, processed_image = pipeline.prepare_inputs(
prompt=[prompts] * num_samples, image=[init_img] * num_samples
)
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
processed_image = shard(processed_image)
output = pipeline(
prompt_ids=prompt_ids,
image=processed_image,
params=p_params,
prng_seed=rng,
strength=0.6,
num_inference_steps=50,
jit=True,
height=784,
width=784,
).images
output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
# Get timestamp
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
# Loop over images and save to output directory with unique name
for i, image in enumerate(output_images):
filename = f"./{timestamp}_{x}_{i}.jpg"
image.save(filename)
duration = time.time() - start_time
print(f"Elapsed time: {duration:.4f} seconds")
Error Stack:
╭──────────────────────────── Traceback (most recent call last) ─────────────────────────────╮
│ /home/aero/diffusers/./test.py:28 in <module> │
│ │
│ 25 │
│ 26 num_samples = jax.device_count() │
│ 27 │
│ ❱ 28 pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained( │
│ 29 │ "/home/aero/diffusers/models/chilloutmix_NiPrunedFp32Fix", │
│ 30 │ dtype=jnp.bfloat16, │
│ 31 │ safety_checker=None, │
│ │
│ /home/aero/.local/lib/python3.8/site-packages/diffusers/pipelines/pipeline_flax_utils.py:4 │
│ 46 in from_pretrained │
│ │
│ 443 │ │ │ │ │ if class_candidate is not None and issubclass(class_obj, class_c │
│ 444 │ │ │ │ │ │ load_method_name = importable_classes[class_name][1] │
│ 445 │ │ │ │ │
│ ❱ 446 │ │ │ │ load_method = getattr(class_obj, load_method_name) │
│ 447 │ │ │ │ │
│ 448 │ │ │ │ # check if the module is in a subdirectory │
│ 449 │ │ │ │ if os.path.isdir(os.path.join(cached_folder, name)): │
╰────────────────────────────────────────────────────────────────────────────────────────────╯