# Embrace-Vision / features_extraction.py
# Uploaded by Reverb ("Upload 123 files", commit b239c75); 1.35 kB.
import torch
from torchvision import models, transforms
from PIL import Image
import pickle
import os
from tqdm import tqdm # Import tqdm for the progress bar
# Load an ImageNet-pretrained ResNet-50 and put it in inference mode
# (eval() makes BatchNorm layers use their stored running statistics).
# NOTE(review): the final fc layer is kept, so every "feature" vector is the
# classifier output (1000 ImageNet scores), not the 2048-d pooled embedding.
# Replacing model.fc with torch.nn.Identity() would change the vector size of
# everything saved/compared downstream, so it is deliberately left as is.
try:
    # torchvision >= 0.13 deprecates `pretrained=True`; IMAGENET1K_V1 is the
    # exact checkpoint that `pretrained=True` resolves to.
    _resnet_weights = models.ResNet50_Weights.IMAGENET1K_V1
except AttributeError:
    # Older torchvision has no weight enums; use the legacy flag.
    model = models.resnet50(pretrained=True)
else:
    model = models.resnet50(weights=_resnet_weights)
model = model.eval()

# Preprocessing applied to every image before it reaches the model:
# resize to exactly 224x224 (aspect ratio is NOT preserved), convert to a
# float tensor in [0, 1], then normalize with the ImageNet channel mean/std
# that torchvision's pretrained weights were trained with.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def extract_features(image_path):
    """Return the model's output vector for a single image file.

    The image is converted to 3-channel RGB, passed through the module-level
    ``preprocess`` pipeline, and run through the module-level ``model`` as a
    batch of one; the batch dimension is then dropped from the result.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        A NumPy array with the model output for the image (1-D for the
        stock ResNet-50 classifier head).

    Raises:
        OSError: If the file cannot be opened or decoded by PIL
            (PIL's ``UnidentifiedImageError`` subclasses ``OSError``).
    """
    # Use a context manager so the file handle is released as soon as
    # convert() has copied the pixels, rather than whenever PIL/GC gets to it;
    # this matters when iterating over thousands of files.
    with Image.open(image_path) as image:
        rgb_image = image.convert('RGB')
    # preprocess yields (C, H, W); add a leading batch dimension of size 1.
    input_batch = preprocess(rgb_image).unsqueeze(0)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        output = model(input_batch)
    # .cpu() is a no-op for CPU tensors and makes .numpy() safe if the model
    # is ever moved to another device.
    return output.squeeze().cpu().numpy()
# Directory containing the images to process.
images_directory = "photos/"
# Extensions (matched case-insensitively) treated as images.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Compute a feature vector for every image, keyed by file name.
# NOTE(review): this runs at import time; consider an
# `if __name__ == "__main__":` guard if extract_features is reused elsewhere.
image_features = {}
# Filter before building the progress bar so its total counts only images,
# and sort so the processing order (and dict order) is deterministic.
image_filenames = sorted(
    filename
    for filename in os.listdir(images_directory)
    if filename.lower().endswith(IMAGE_EXTENSIONS)
)
for filename in tqdm(image_filenames, desc="Processing Images"):
    image_path = os.path.join(images_directory, filename)
    try:
        features = extract_features(image_path)
    except OSError as err:
        # A single unreadable/corrupt file should not abort a long run and
        # discard every vector computed so far; report it and continue.
        tqdm.write(f"Skipping {filename}: {err}")
        continue
    image_features[filename] = features

# Save the filename -> feature-vector mapping to a pickle file.
output_file = "unsplash-25k-embeddings.pkl"
with open(output_file, 'wb') as f:
    pickle.dump(image_features, f)