How to convert .safetensors or .ckpt Files and Using in FlaxStableDiffusionImg2ImgPipeline?

3k views Asked by At

I am trying to convert a .safetensors model to a diffusers model using the Python script found at https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py. The command I tried is python3 convert_original_stable_diffusion_to_diffusers.py --checkpoint_path /home/aero/stable-diffusion-webui/models/Stable-diffusion/chilloutmix_NiPrunedFp32Fix.safetensors --scheduler_type euler-ancestral --dump_path /home/aero/diffusers/models/chilloutmix_NiPrunedFp32Fix --from_safetensors. After the conversion, I intend to use the diffusers model within the FlaxStableDiffusionImg2ImgPipeline.

However, I encountered an error when running the script I provided below (full code):

First error: OSError: diffusion_pytorch_model.bin file found in directory /home/aero/diffusers/models/chilloutmix_NiPrunedFp32Fix/vae. Please load the model using from_pt=True.
I modified the code by adding from_pt=True.
Second error: TypeError: getattr(): attribute name must be string

My question is how I can fix these issues and properly convert the .safetensors model to a diffusers model, so I can use it with FlaxStableDiffusionImg2ImgPipeline without encountering any errors?

Full Code:

import jax
import numpy as np
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
import requests
from io import BytesIO
from PIL import Image
from diffusers import FlaxStableDiffusionImg2ImgPipeline
import time
from datetime import datetime
import random

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

start_time = time.time()

url = "https://i.pinimg.com/564x/e6/36/a6/e636a664f860a1ec9f7b5f3c4e2f634b.jpg"
response = requests.get(url)
init_img = Image.open(BytesIO(response.content)).convert("RGB")
init_img = init_img.resize((784, 784))

prompts = "hyperreal, artstation, (masterpiece:1.0), (best quality:1.4), (ultra highres:1.2), (photorealistic:1.4), (8k, RAW photo:1.2), (soft focus:1.4),  (sharp focus:1.4)"

num_samples = jax.device_count()

pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
    "/home/aero/diffusers/models/chilloutmix_NiPrunedFp32Fix",
    dtype=jnp.bfloat16,
    safety_checker=None,
    # from_pt=True,
)

for x in range(4):
    rng = create_key(random.randint(0, 7183698734589870))
    rng = jax.random.split(rng, num_samples)
    prompt_ids, processed_image = pipeline.prepare_inputs(
        prompt=[prompts] * num_samples, image=[init_img] * num_samples
    )
    p_params = replicate(params)
    prompt_ids = shard(prompt_ids)
    processed_image = shard(processed_image)

    output = pipeline(
        prompt_ids=prompt_ids,
        image=processed_image,
        params=p_params,
        prng_seed=rng,
        strength=0.6,
        num_inference_steps=50,
        jit=True,
        height=784,
        width=784,
    ).images

    output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))

    # Get timestamp
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

    # Loop over images and save to output directory with unique name
    for i, image in enumerate(output_images):
        filename = f"./{timestamp}_{x}_{i}.jpg"
        image.save(filename)

duration = time.time() - start_time
print(f"Elapsed time: {duration:.4f} seconds")

Error Stack:

╭──────────────────────────── Traceback (most recent call last) ─────────────────────────────╮
│ /home/aero/diffusers/./test.py:28 in <module>                                              │
│                                                                                            │
│   25                                                                                       │
│   26 num_samples = jax.device_count()                                                      │
│   27                                                                                       │
│ ❱ 28 pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(                │
│   29 │   "/home/aero/diffusers/models/chilloutmix_NiPrunedFp32Fix",                        │
│   30 │   dtype=jnp.bfloat16,                                                               │
│   31 │   safety_checker=None,                                                              │
│                                                                                            │
│ /home/aero/.local/lib/python3.8/site-packages/diffusers/pipelines/pipeline_flax_utils.py:4 │
│ 46 in from_pretrained                                                                      │
│                                                                                            │
│   443 │   │   │   │   │   if class_candidate is not None and issubclass(class_obj, class_c │
│   444 │   │   │   │   │   │   load_method_name = importable_classes[class_name][1]         │
│   445 │   │   │   │                                                                        │
│ ❱ 446 │   │   │   │   load_method = getattr(class_obj, load_method_name)                   │
│   447 │   │   │   │                                                                        │
│   448 │   │   │   │   # check if the module is in a subdirectory                           │
│   449 │   │   │   │   if os.path.isdir(os.path.join(cached_folder, name)):                 │
╰────────────────────────────────────────────────────────────────────────────────────────────╯
0

There are 0 answers