# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
# thank you @NimaBoscarino

import os
import re
from pathlib import Path
from uuid import uuid4

from minydra import resolved_args
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from skimage.color import rgba2rgb
from skimage.transform import resize

from climategan.trainer import Trainer

CUDA = torch.cuda.is_available()


def concat_events(output_dict, events, i=None, axis=1):
    """
    Concatenates the `i`th data in `output_dict` according to the keys listed
    in `events` on dimension `axis`.

    Args:
        output_dict (dict[Union[list[np.array], np.array]]): A dictionary mapping
            events to their corresponding data:
            {k: [HxWxC]} (for i != None) or {k: BxHxWxC}.
        events (list[str]): output_dict's keys to concatenate.
        i (int, optional): If not None, only the `i`th element of each event's
            data is concatenated. Defaults to None.
        axis (int, optional): Concatenation axis. Defaults to 1.
    """
    cs = [e for e in events if e in output_dict]
    if i is not None:
        return uint8(np.concatenate([output_dict[c][i] for c in cs], axis=axis))
    return uint8(np.concatenate([output_dict[c] for c in cs], axis=axis))
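
# Illustrative example (placeholder arrays): concatenating two batched events
# side by side along the width axis, as done for the "concat" output below.
#   out = {"input": np.zeros((1, 640, 640, 3)), "mask": np.ones((1, 640, 640, 3))}
#   strip = concat_events(out, ["input", "mask"], axis=2)  # (1, 640, 1280, 3) uint8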


def clear(folder):
    """
    Deletes all the images with the inference separator "---" in their name.

    Args:
        folder (Union[str, Path]): The folder to clear.
    """
    for i in list(Path(folder).iterdir()):
        if i.is_file() and "---" in i.stem:
            i.unlink()


def uint8(array, rescale=False):
    """
    Converts an array to np.uint8, optionally rescaling it to [0, 255] first.

    Args:
        array (np.array): Array to convert.
        rescale (bool, optional): If True, data in [-1, 1] or [0, 1] is rescaled
            to [0, 255] before the dtype conversion. Defaults to False.

    Returns:
        np.array(np.uint8): Converted array.
    """
    if rescale:
        if array.min() < 0:
            if array.min() >= -1 and array.max() <= 1:
                array = (array + 1) / 2
            else:
                raise ValueError(
                    f"Data range mismatch for image: ({array.min()}, {array.max()})"
                )
        if array.max() <= 1:
            array = array * 255
    return array.astype(np.uint8)
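
# Illustrative example (placeholder data): rescaling maps [-1, 1] floats to [0, 255] uint8.
#   x = np.linspace(-1, 1, 12, dtype=np.float32).reshape(3, 4)
#   uint8(x, rescale=True)  # values in [0, 255], dtype uint8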


def resize_and_crop(img, to=640):
    """
    Resizes an image so that it keeps the aspect ratio and its smallest dimension
    is `to`, then crops this resized image in its center so that the output is
    `to x to`, without aspect ratio distortion.

    Args:
        img (np.array): np.uint8 255 image.
        to (int, optional): Side length of the square output. Defaults to 640.

    Returns:
        np.array: [0, 1] np.float32 image
    """
    # resize keeping aspect ratio: smallest dim is `to`
    h, w = img.shape[:2]
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)

    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    # crop in the center
    H, W = r_img.shape[:2]
    top = (H - to) // 2
    left = (W - to) // 2

    rc_img = r_img[top : top + to, left : left + to, :]

    return rc_img / 255.0
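
# Illustrative example (random placeholder image): a 480x854 frame becomes a
# centered 640x640 crop in [0, 1].
#   frame = np.random.randint(0, 256, (480, 854, 3), dtype=np.uint8)
#   square = resize_and_crop(frame)  # shape (640, 640, 3), float in [0, 1]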


def to_m1_p1(img):
    """
    Rescales a [0, 1] image to [-1, +1].

    Args:
        img (np.array): float32 numpy array of an image in [0, 1]

    Raises:
        ValueError: If the image is not in [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    if img.min() >= 0 and img.max() <= 1:
        return (img.astype(np.float32) - 0.5) * 2
    raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")


# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    def __init__(self, model_path, dev_mode=False) -> None:
        """
        A wrapper for the ClimateGAN model that you can use to generate
        events from images or folders containing images.

        Args:
            model_path (Union[str, Path]): Where to load the Masker from.
            dev_mode (bool, optional): If True, no model is loaded and inference
                methods return random placeholder outputs. Defaults to False.
        """
        torch.set_grad_enabled(False)
        self.target_size = 640
        self._stable_diffusion_is_setup = False
        self.dev_mode = dev_mode
        if self.dev_mode:
            return

        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        if CUDA:
            self.trainer.G.half()
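
    # Illustrative instantiation (paths are placeholders):
    #   cg = ClimateGAN("models/climategan")                 # loads the Masker weights
    #   cg_dev = ClimateGAN("unused/path", dev_mode=True)    # random outputs, no weights needed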

    def _setup_stable_diffusion(self):
        """
        Sets up the stable diffusion pipeline for in-painting.
        Make sure you have accepted the license on the model's card:
        https://huggingface.co/runwayml/stable-diffusion-inpainting
        """
        if self.dev_mode:
            return

        try:
            self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting",
                revision="fp16" if CUDA else "main",
                torch_dtype=torch.float16 if CUDA else torch.float32,
                safety_checker=None,
                use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
            ).to(self.trainer.device)
            self._stable_diffusion_is_setup = True
        except Exception as e:
            print(
                "\nCould not load stable diffusion model. "
                + "Please make sure you have accepted the license on the model's"
                + " card https://huggingface.co/runwayml/stable-diffusion-inpainting\n"
            )
            raise e

    def _preprocess_image(self, img):
        """
        Turns a HxWxC uint8 numpy array into a 640x640x3 float32 numpy array
        in [-1, 1].

        Args:
            img (np.array): Image to resize, crop and rescale.

        Returns:
            np.array: Resized, cropped and rescaled image.
        """
        # rgba to rgb
        data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)

        # to self.target_size
        data = resize_and_crop(data, self.target_size)

        # resize_and_crop() produces [0, 1] images, rescale to [-1, 1]
        data = to_m1_p1(data)

        return data

    # Does all three inferences at the moment.
    def infer_single(
        self,
        orig_image,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        concats=[
            "input",
            "masked_input",
            "climategan_flood",
            "stable_flood",
            "stable_copy_flood",
        ],
        as_pil_image=False,
    ):
        """
        Infers the image with the ClimateGAN model.
        Importantly (and unlike self.infer_preprocessed_batch), the image is
        pre-processed by self._preprocess_image before going through the networks.

        Output dict contains the following keys:
            - "input": The input image
            - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
            - "masked_input": The input image with the mask applied
            - "climategan_flood": The flooded image generated by ClimateGAN's Painter
                on the masked input (only if "painter" is "climategan" or "both").
            - "stable_flood": The flooded image in-painted by the stable diffusion model
                from the mask and the input image (only if "painter" is "stable_diffusion"
                or "both").
            - "stable_copy_flood": The flooded image in-painted by the stable diffusion
                model with its original context pasted back in:
                y = m * flooded + (1-m) * input
                (only if "painter" is "stable_diffusion" or "both").

        Args:
            orig_image (Union[str, np.array]): Image to infer on. Can be a path to
                an image which will be read.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (list, optional): List of keys in `output` to concatenate together
                in a new `{original_stem}_concat` image written. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            as_pil_image (bool, optional): If True, the original image is also passed
                as a PIL.Image to infer_preprocessed_batch, where it is used for the
                stable diffusion in-painting (batch size 1). Defaults to False.

        Returns:
            dict: a dictionary containing the output images {k: HxWxC}. C is omitted
                for masks (HxW).
        """
        if self.dev_mode:
            return {
                "input": orig_image,
                "mask": np.random.randint(0, 255, (640, 640)),
                "masked_input": np.random.randint(0, 255, (640, 640, 3)),
                "climategan_flood": np.random.randint(0, 255, (640, 640, 3)),
                "stable_flood": np.random.randint(0, 255, (640, 640, 3)),
                "stable_copy_flood": np.random.randint(0, 255, (640, 640, 3)),
                "concat": np.random.randint(0, 255, (640, 640 * 5, 3)),
                "smog": np.random.randint(0, 255, (640, 640, 3)),
                "wildfire": np.random.randint(0, 255, (640, 640, 3)),
                "depth": np.random.randint(0, 255, (640, 640, 1)),
                "segmentation": np.random.randint(0, 255, (640, 640, 3)),
            }

        image_array = (
            np.array(Image.open(orig_image))
            if isinstance(orig_image, str)
            else orig_image
        )

        pil_image = None
        if as_pil_image:
            pil_image = Image.fromarray(image_array)

        print("Preprocessing image")
        image = self._preprocess_image(image_array)

        output_dict = self.infer_preprocessed_batch(
            images=image[None, ...],
            painter=painter,
            prompt=prompt,
            concats=concats,
            pil_image=pil_image,
        )

        print("Inference done")
        return {k: v[0] for k, v in output_dict.items()}
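
    # Illustrative usage (paths are placeholders; assumes a loaded model):
    #   cg = ClimateGAN("models/climategan")
    #   out = cg.infer_single("street.png", painter="climategan")
    #   Image.fromarray(uint8(out["climategan_flood"])).save("street_flood.png")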

    def infer_preprocessed_batch(
        self,
        images,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        concats=[
            "input",
            "masked_input",
            "climategan_flood",
            "stable_flood",
            "stable_copy_flood",
        ],
        pil_image=None,
    ):
        """
        Infers ClimateGAN predictions on a batch of preprocessed images.
        It assumes that each image in the batch has been preprocessed with
        self._preprocess_image().

        Output dict contains the following keys:
            - "input": The input image
            - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
            - "masked_input": The input image with the mask applied
            - "climategan_flood": The flooded image generated by ClimateGAN's Painter
                on the masked input (only if "painter" is "climategan" or "both").
            - "stable_flood": The flooded image in-painted by the stable diffusion model
                from the mask and the input image (only if "painter" is "stable_diffusion"
                or "both").
            - "stable_copy_flood": The flooded image in-painted by the stable diffusion
                model with its original context pasted back in:
                y = m * flooded + (1-m) * input
                (only if "painter" is "stable_diffusion" or "both").

        Args:
            images (np.array): A batch of input images BxHxWx3
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (list, optional): List of keys in `output` to concatenate together
                in a new `{original_stem}_concat` image written. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            pil_image (PIL.Image, optional): The original PIL image. If provided,
                will be used for a single inference (batch_size=1).

        Returns:
            dict: a dictionary containing the output images
        """
        assert painter in [
            "both",
            "stable_diffusion",
            "climategan",
        ], f"Unknown painter: {painter}"

        ignore_event = set()
        if painter == "stable_diffusion":
            ignore_event.add("flood")

        if pil_image is not None:
            print("Warning: `pil_image` has been provided, it will override `images`")
            images = self._preprocess_image(np.array(pil_image))[None, ...]
            pil_image = Image.fromarray(((images[0] + 1) / 2 * 255).astype(np.uint8))

        # Retrieve numpy events as a dict {event: array[BxHxWxC]}
        print("Inferring ClimateGAN events")
        outputs = self.trainer.infer_all(
            images,
            numpy=True,
            bin_value=0.5,
            half=CUDA,
            ignore_event=ignore_event,
            return_intermediates=True,
        )

        outputs["input"] = uint8(images, True)

        # from Bx1xHxW to BxHxWx1
        outputs["masked_input"] = outputs["input"] * (
            outputs["mask"].squeeze(1)[..., None] == 0
        )

        if painter in {"both", "climategan"}:
            outputs["climategan_flood"] = outputs.pop("flood")
        else:
            del outputs["flood"]

        if painter != "climategan":
            if not self._stable_diffusion_is_setup:
                print("Setting up stable diffusion in-painting pipeline")
                self._setup_stable_diffusion()

            mask = outputs["mask"].squeeze(1)
            input_images = (
                torch.tensor(images).permute(0, 3, 1, 2).to(self.trainer.device)
                if pil_image is None
                else pil_image
            )
            input_mask = (
                torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
                if pil_image is None
                else Image.fromarray(mask[0])
            )

            print("Inferring stable diffusion in-painting for 50 steps")
            floods = self.sdip_pipeline(
                prompt=[prompt] * images.shape[0],
                image=input_images,
                mask_image=input_mask,
                height=640,
                width=640,
                num_inference_steps=50,
            )
            print("Stable diffusion in-painting done")

            bin_mask = mask[..., None] > 0
            flood = np.stack([np.array(i) for i in floods.images])
            copy_flood = flood * bin_mask + uint8(images, True) * (1 - bin_mask)
            outputs["stable_flood"] = flood
            outputs["stable_copy_flood"] = copy_flood

        if concats:
            print("Concatenating flood images")
            outputs["concat"] = concat_events(outputs, concats, axis=2)

        return {k: v.squeeze(1) if v.shape[1] == 1 else v for k, v in outputs.items()}
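
    # Illustrative batch call (placeholder names `cg` and `paths`; assumes a loaded model):
    #   batch = np.stack([cg._preprocess_image(np.array(Image.open(p))) for p in paths])
    #   out = cg.infer_preprocessed_batch(batch, painter="climategan")
    #   out["climategan_flood"].shape  # (B, 640, 640, 3)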

    def infer_folder(
        self,
        folder_path,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        batch_size=4,
        concats=[
            "input",
            "masked_input",
            "climategan_flood",
            "stable_flood",
            "stable_copy_flood",
        ],
        write=True,
        overwrite=False,
    ):
        """
        Infers the images in a folder with the ClimateGAN model, batching images for
        inference according to the batch_size.

        Images must end in .jpg, .jpeg or .png (not case-sensitive).
        Images must not contain the separator ("---") in their name.

        Images will be written to disk in the same folder as the input images, with
        a name that depends on the output type, potentially the prompt, and a random
        identifier in case multiple inferences are run in the folder.

        Output dict contains the following keys:
            - "input": The input image
            - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
            - "masked_input": The input image with the mask applied
            - "climategan_flood": The flooded image generated by ClimateGAN's Painter
                on the masked input (only if "painter" is "climategan" or "both").
            - "stable_flood": The flooded image in-painted by the stable diffusion model
                from the mask and the input image (only if "painter" is "stable_diffusion"
                or "both").
            - "stable_copy_flood": The flooded image in-painted by the stable diffusion
                model with its original context pasted back in:
                y = m * flooded + (1-m) * input
                (only if "painter" is "stable_diffusion" or "both").

        Args:
            folder_path (Union[str, Path]): Where to read images from.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            batch_size (int, optional): Size of inference batches. Defaults to 4.
            concats (list, optional): List of keys in `output` to concatenate together
                in a new `{original_stem}_concat` image written. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            write (bool, optional): Whether or not to write the outputs to the input
                folder. Defaults to True.
            overwrite (Union[bool, str], optional): Whether to overwrite the images or
                not. If a string is provided, it will be included in the name.
                Defaults to False.

        Returns:
            dict: a dictionary containing the output images
        """
        folder_path = Path(folder_path).expanduser().resolve()
        assert folder_path.exists(), f"Folder {str(folder_path)} does not exist"
        assert folder_path.is_dir(), f"{str(folder_path)} is not a directory"

        im_paths = [
            p
            for p in folder_path.iterdir()
            if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
        ]
        assert im_paths, f"No images found in {str(folder_path)}"

        ims = [self._preprocess_image(np.array(Image.open(p))) for p in im_paths]
        batches = [
            np.stack(ims[i : i + batch_size]) for i in range(0, len(ims), batch_size)
        ]
        inferences = [
            self.infer_preprocessed_batch(b, painter, prompt, concats) for b in batches
        ]

        outputs = {
            k: [i for e in inferences for i in e[k]] for k in inferences[0].keys()
        }

        if write:
            self.write(outputs, im_paths, painter, overwrite, prompt)

        return outputs
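
    # Illustrative usage (folder path is a placeholder):
    #   cg = ClimateGAN("models/climategan")
    #   outputs = cg.infer_folder("~/photos/streets", painter="both", batch_size=2)
    #   # writes {stem}---{id}_{event}{suffix} files next to the inputs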

    def write(
        self,
        outputs,
        im_paths,
        painter="both",
        overwrite=False,
        prompt="",
    ):
        """
        Writes the outputs of the inference to disk, in the input folder.

        Images will be named like:
        f"{original_stem}---{overwrite_prefix}_{painter_type}_{output_type}.{suffix}"
        `painter_type` is either "climategan" or f"stable_diffusion_{prompt}"

        Args:
            outputs (dict): The inference procedure's output dict.
            im_paths (list[Path]): The list of input images paths.
            painter (str, optional): Which painter was used. Defaults to "both".
            overwrite (Union[bool, str], optional): Whether to overwrite the images or
                not. If a string is provided, it will be included in the name.
                If False, a random identifier will be added to the name.
                Defaults to False.
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "".
        """
        prompt = re.sub("[^0-9a-zA-Z]+", "", prompt).lower()
        overwrite_prefix = ""
        if not overwrite:
            overwrite_prefix = str(uuid4())[:8]
            print("Writing events with prefix", overwrite_prefix)
        else:
            if isinstance(overwrite, str):
                overwrite_prefix = overwrite
                print("Writing events with prefix", overwrite_prefix)

        # for each image, for each event/data type
        for i, im_path in enumerate(im_paths):
            for event, ims in outputs.items():
                painter_prefix = ""
                if painter == "climategan" and event == "flood":
                    painter_prefix = "climategan"
                elif (
                    painter in {"stable_diffusion", "both"} and event == "stable_flood"
                ):
                    painter_prefix = f"_stable_{prompt}"
                elif painter == "both" and event == "climategan_flood":
                    painter_prefix = ""

                im = ims[i]
                im = Image.fromarray(uint8(im))
                imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
                im.save(im_path.parent / (imstem + im_path.suffix))


if __name__ == "__main__":
    print("Run `$ python climategan_wrapper.py help` for usage instructions\n")

    # parse arguments
    args = resolved_args(
        defaults={
            "input_folder": None,
            "output_folder": None,
            "painter": "both",
            "help": False,
        }
    )

    # print help
    if args.help:
        print(
            "Usage: python climategan_wrapper.py input_folder=/path/to/folder\n"
            + "By default inferences will be stored in the input folder.\n"
            + "Add `output_folder=/path/to/folder` for a different output folder.\n"
            + "By default, both ClimateGAN and Stable Diffusion will be used."
            + " Change this by adding `painter=climategan` or"
            + " `painter=stable_diffusion`.\n"
            + "Make sure you have agreed to the terms of use for the models."
            + " In particular, visit SD's model card to agree to the terms of use:"
            + " https://huggingface.co/runwayml/stable-diffusion-inpainting"
        )

    # print args
    args.pretty_print()

    # load models
    cg = ClimateGAN("models/climategan")

    # check painter type
    assert args.painter in {
        "climategan",
        "stable_diffusion",
        "both",
    }, (
        f"Unknown painter {args.painter}. "
        + "Allowed values are 'climategan', 'stable_diffusion' and 'both'."
    )

    # load SD pipeline if need be
    if args.painter != "climategan":
        cg._setup_stable_diffusion()

    # resolve input folder path
    in_path = Path(args.input_folder).expanduser().resolve()
    assert in_path.exists(), f"Folder {str(in_path)} does not exist"

    # output is input if not specified
    if args.output_folder is None:
        out_path = in_path
    else:
        out_path = Path(args.output_folder).expanduser().resolve()
        out_path.mkdir(parents=True, exist_ok=True)

    # find images in input folder
    im_paths = [
        p
        for p in in_path.iterdir()
        if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
    ]
    assert im_paths, f"No images found in {str(in_path)}"
    print(f"\nFound {len(im_paths)} images in {str(in_path)}\n")

    # infer and write
    for i, im_path in enumerate(im_paths):
        print(">>> Processing", f"{i}/{len(im_paths)}", im_path.name)
        outs = cg.infer_single(
            np.array(Image.open(im_path)),
            args.painter,
            as_pil_image=True,
            concats=[
                "input",
                "masked_input",
                "climategan_flood",
                "stable_copy_flood",
            ],
        )

        for k, v in outs.items():
            name = f"{im_path.stem}---{k}{im_path.suffix}"
            im = Image.fromarray(uint8(v))
            im.save(out_path / name)

        print(">>> Done", f"{i}/{len(im_paths)}", im_path.name, end="\n\n")