# Hub page residue (commented out so the module parses):
# John6666's picture — Upload 3 files — 7691c26 verified
import spaces
import gradio as gr
from huggingface_hub import HfApi
from transformers.image_transforms import pad
import numpy as np
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel
# Load the CLIP ViT-B/32 vision tower and its matching image preprocessor once
# at import time. NOTE(review): `from_pretrained` downloads weights from the
# Hub on first run — requires network access; `processor` is currently unused
# by the rest of this file.
model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
def _expand_for_data_format(values, input_data_format):
"""
Convert values to be in the format expected by np.pad based on the data format.
"""
if isinstance(values, (int, float)):
values = ((values, values), (values, values))
elif isinstance(values, tuple) and len(values) == 1:
values = ((values[0], values[0]), (values[0], values[0]))
elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], int):
values = (values, values)
elif isinstance(values, tuple) and len(values) == 2 and isinstance(values[0], tuple):
values = values
else:
raise ValueError(f"Unsupported format: {values}")
# add 0 for channel dimension
#values = ((0, 0), *values) if input_data_format == ChannelDimension.FIRST else (*values, (0, 0))
values = ((0, 0), *values) if input_data_format == "channels_first" else (*values, (0, 0))
# Add additional padding if there's a batch dimension
#values = (0, *values) if image.ndim == 4 else values
return values
#@spaces.GPU
def infer(height: int, width: int, channels: int, input_data_format: str, mode: str,
          is_numpy: bool=True, is_mul: bool=True, is_int: bool=True, is_abs: bool=True):
    """
    Debug harness: build a random image, pad it two ways, and return the results.

    Pads a random H x W x C array — both the raw array and a PIL round-tripped
    copy — using either ``np.pad`` (``is_numpy=True``) or
    ``transformers.image_transforms.pad``, so dtype/value-range quirks of each
    path can be inspected visually in the Gradio UI.

    Args:
        height, width, channels: dimensions of the random test image.
        input_data_format: "None", "channels_first" or "channels_last";
            forwarded to the transformers ``pad`` only when not "None".
        mode: padding mode string passed to the padding function.
        is_numpy: pad with ``np.pad`` instead of the transformers helper.
        is_mul: scale the [0, 1) floats by 255 before padding.
        is_int: cast the array to uint8 before padding.
        is_abs: take the absolute value before padding.

    Returns:
        (original, padded, padded_via_pil) as PIL images.

    Raises:
        gr.Error: wrapping any exception raised along the way. Several input
            combinations are expected to fail (e.g. float64 arrays handed to
            ``Image.fromarray``) — surfacing those failures is the point of
            this tool.
    """
    try:
        pad_kwargs = {}
        pad_kwargs["mode"] = mode
        if input_data_format != "None":
            pad_kwargs["input_data_format"] = input_data_format
        pad_kwargs["data_format"] = "channels_last"
        # Example image as a NumPy array: float64 values in [0, 1).
        image = np.random.rand(height, width, channels) # Height x Width x Channels
        # NOTE(review): fromarray on a float64 array with an explicit 'RGB' mode
        # is deliberately dubious — kept as-is to observe PIL's behavior.
        image_pil = np.array(Image.fromarray(image, 'RGB')) # Open with PIL and save
        # The transforms below apply ONLY to `image`; `image_pil` was captured
        # above, before any of them run — intentional, for comparison.
        if is_mul: image = image * 255
        if is_int: image = image.astype(np.uint8)
        if is_abs: image = np.abs(image)
        print(image)
        print(image.dtype)
        print(image_pil)
        print(image_pil.dtype)
        # Define padding: ((before_height, after_height), (before_width, after_width))
        padding = ((0, 0), (112, 112)) # Pads width to make it 448
        # Apply padding
        if is_numpy:
            # np.pad needs the channel axis included in pad_width, hence
            # _expand_for_data_format. NOTE(review): `input_data_format` may be
            # the literal string "None" here, which that helper treats as
            # channels-last — confirm this is intended.
            padded_image = np.pad(image, _expand_for_data_format(padding, input_data_format), mode="constant",
                                  constant_values=_expand_for_data_format(0.0, input_data_format))
            padded_image_pil = np.pad(image_pil, _expand_for_data_format(padding, input_data_format), mode="constant",
                                      constant_values=_expand_for_data_format(0.0, input_data_format))
        else:
            # transformers' pad: the raw array gets defaults; the PIL copy gets
            # the user-selected mode/data-format kwargs.
            padded_image = pad(image, padding=padding)
            padded_image_pil = pad(image_pil, padding=padding, **pad_kwargs)
        print("Original Image Shape:", image.shape)
        print("Padded Image Shape:", padded_image.shape)
        print("Padded Image Shape (PIL):", padded_image_pil.shape)
        # Torch conversions are only used for the shape printouts below.
        image_torch = torch.tensor(image).permute(2, 0, 1).unsqueeze(0)
        padded_image_torch = torch.tensor(padded_image).permute(2, 0, 1).unsqueeze(0)
        padded_image_pil_torch = torch.tensor(padded_image_pil).permute(2, 0, 1).unsqueeze(0)
        print("Original Image Shape (Torch):", image_torch.shape)
        print("Padded Image Shape (Torch):", padded_image_torch.shape)
        print("Padded Image Shape (PIL) (Torch):", padded_image_pil_torch.shape)
        # Step 5: Pass the padded image through the model
        #outputs_padded = model(pixel_values=padded_image_torch, interpolate_pos_encoding=True)
        #outputs_original = model(pixel_values=image_torch)
        # Step 6: Extract the results for comparison
        #original = outputs_original.pooler_output
        #padded = outputs_padded.pooler_output
        #print(torch.mean(original - padded))
        # Save images. NOTE(review): fromarray will raise for non-uint8 arrays
        # (e.g. when "Cast to uint8" is unchecked) — caught and re-raised below.
        original_im = Image.fromarray(image, 'RGB')
        padded_im = Image.fromarray(padded_image, 'RGB')
        padded_im_pil = Image.fromarray(padded_image_pil, 'RGB')
        #original_im.save("_pad_original.png")
        #padded_im.save("_pad_padded.png")
        #padded_im_pil.save("_pad_padded_pil.png")
        return original_im, padded_im, padded_im_pil
    except Exception as e:
        raise gr.Error(e)
# Gradio front-end: choose image size and dtype tweaks, then compare the
# padding paths side by side. Component creation order determines the layout.
with gr.Blocks() as demo:
    with gr.Row(equal_height=True):
        num_width = gr.Number(label="Width", value=224, minimum=1, maximum=4096, step=1)
        num_height = gr.Number(label="Height", value=224, minimum=1, maximum=4096, step=1)
        num_channels = gr.Number(label="Channels", value=3, minimum=2, maximum=5, step=1)
        radio_df = gr.Radio(label="Input data format", choices=["None", "channels_first", "channels_last"], value="None")
        radio_mode = gr.Radio(label="Mode", choices=["constant", "reflect", "replicate", "symmetric"], value="constant")
        chk_mul = gr.Checkbox(label="Multiply by 255", value=False)
        chk_int = gr.Checkbox(label="Cast to uint8", value=False)
        chk_abs = gr.Checkbox(label="Absolute value", value=False)
        chk_numpy = gr.Checkbox(label="Pad by numpy", value=False)
        btn_run = gr.Button("Run", variant="primary")
    with gr.Row(equal_height=True):
        out_original = gr.Image(label="Original")
        out_padded = gr.Image(label="Padded")
        out_padded_pil = gr.Image(label="Padded (with PIL)")
    # infer() takes height before width, although the widgets are laid out
    # width-first — keep this argument order in sync with the signature.
    infer_inputs = [num_height, num_width, num_channels, radio_df, radio_mode,
                    chk_numpy, chk_mul, chk_int, chk_abs]
    infer_outputs = [out_original, out_padded, out_padded_pil]
    btn_run.click(infer, infer_inputs, infer_outputs)
demo.launch()