Embedding multimodal data for similarity search using 🤗 transformers, 🤗 datasets and FAISS
Authored by: Merve Noyan
Embeddings are semantically meaningful compressions of information. They can be used to do similarity search, zero-shot classification or simply train a new model. Use cases for similarity search include searching for similar products in e-commerce, content search in social media and more. This notebook walks you through using 🤗 transformers, 🤗 datasets and FAISS to create and index embeddings from a feature extraction model to later use them for similarity search. Let’s install the necessary libraries.

!pip install -q datasets faiss-gpu transformers sentencepiece
For this tutorial, we will use the CLIP model to extract the features. CLIP is a revolutionary model that introduced joint training of a text encoder and an image encoder to connect the two modalities.
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
import faiss
import numpy as np
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModel.from_pretrained("openai/clip-vit-base-patch16").to(device)
processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch16")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch16")
Load the dataset. To keep this notebook light, we will use a small captioning dataset, jmhessel/newyorker_caption_contest.
from datasets import load_dataset
ds = load_dataset("jmhessel/newyorker_caption_contest", "explanation")
See an example.
"train"][0]["image"] ds[
"train"][0]["image_description"] ds[
'Two women are looking out a window. There is snow outside, and there is a snowman with human arms.'
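As a quick sanity check before building any index, we can embed this example’s image and its caption and compare them directly. The snippet below is a small illustrative sketch that reuses the model, tokenizer and processor loaded above; since CLIP projects both modalities into a shared space, a matching caption and image should have a relatively high cosine similarity.

example = ds["train"][0]

with torch.no_grad():
    text_emb = model.get_text_features(
        **tokenizer([example["image_description"]], truncation=True, return_tensors="pt").to(device))
    image_emb = model.get_image_features(
        **processor([example["image"]], return_tensors="pt").to(device))

# both modalities are projected to the same dimensionality (512 for clip-vit-base-patch16)
print(text_emb.shape, image_emb.shape)
print(torch.nn.functional.cosine_similarity(text_emb, image_emb).item())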
We don’t have to write any function to embed examples or create an index. The 🤗 datasets library’s FAISS integration abstracts these processes. We can simply use the map method of the dataset to create a new column with the embeddings for each example, as shown below. Let’s create one for the text features of the image_description column.
= ds["train"]
dataset = dataset.map(lambda example:
ds_with_embeddings 'embeddings': model.get_text_features(
{**tokenizer([example["image_description"]],
=True, return_tensors="pt")
truncation"cuda"))[0].detach().cpu().numpy()}) .to(
='embeddings') ds_with_embeddings.add_faiss_index(column
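Mapping one example at a time keeps the code short, but if you need more throughput the same column can be built in batches. The version below is a sketch of that alternative (the batch size of 16 and the padding strategy are arbitrary choices, not something the notebook prescribes); it produces the same embeddings column as the call above.

# optional, batched alternative to the per-example map above
def embed_text_batch(batch):
    inputs = tokenizer(batch["image_description"], padding=True, truncation=True,
                       return_tensors="pt").to(device)
    with torch.no_grad():
        features = model.get_text_features(**inputs)
    return {"embeddings": features.cpu().numpy()}

ds_with_embeddings = dataset.map(embed_text_batch, batched=True, batch_size=16)
ds_with_embeddings.add_faiss_index(column='embeddings')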
We can do the same and get the image embeddings.
ds_with_embeddings = ds_with_embeddings.map(lambda example:
    {'image_embeddings': model.get_image_features(
        **processor([example["image"]], return_tensors="pt")
        .to("cuda"))[0].detach().cpu().numpy()})

ds_with_embeddings.add_faiss_index(column='image_embeddings')
Querying the data with text prompts
We can now query the dataset with text or image to get similar items from it.
= "a snowy day"
prmt = model.get_text_features(**tokenizer([prmt], return_tensors="pt", truncation=True).to("cuda"))[0].detach().cpu().numpy()
prmt_embedding = ds_with_embeddings.get_nearest_examples('embeddings', prmt_embedding, k=1) scores, retrieved_examples
def downscale_images(image):
    width = 200
    ratio = width / float(image.size[0])
    height = int(float(image.size[1]) * float(ratio))
    img = image.resize((width, height), Image.Resampling.LANCZOS)
    return img
images = [downscale_images(image) for image in retrieved_examples["image"]]

# see the closest text and image
print(retrieved_examples["image_description"])
display(images[0])
['A man is in the snow. A boy with a huge snow shovel is there too. They are outside a house.']
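get_nearest_examples also returns the FAISS scores alongside the examples, and nothing forces k=1. As a small sketch reusing the prompt embedding from above, we can pull the three closest captions and print each one with its score (with the default flat index, lower scores mean closer matches):

scores, retrieved = ds_with_embeddings.get_nearest_examples('embeddings', prmt_embedding, k=3)
for score, description in zip(scores, retrieved["image_description"]):
    print(round(float(score), 2), description)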
Querying the data with image prompts
Image similarity inference works the same way; you just call get_image_features instead.
import requests
# image of a beaver
= "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/beaver.png"
url = Image.open(requests.get(url, stream=True).raw)
image display(downscale_images(image))
Search for the most similar image.
img_embedding = model.get_image_features(
    **processor([image], return_tensors="pt").to("cuda"))[0].detach().cpu().numpy()
scores, retrieved_examples = ds_with_embeddings.get_nearest_examples('image_embeddings', img_embedding, k=1)
Display the most similar image to the beaver image.
images = [downscale_images(image) for image in retrieved_examples["image"]]

# see the closest text and image
print(retrieved_examples["image_description"])
display(images[0])
['Salmon swim upstream but they see a grizzly bear and are in shock. The bear has a smug look on his face when he sees the salmon.']
Saving, pushing and loading the embeddings
We can save the dataset’s FAISS indexes with save_faiss_index.
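The calls below write the index files into an embeddings/ folder. As a small precaution (not strictly required on every setup), we can create that folder first, since writing into a folder that doesn’t exist yet may fail:

import os

# make sure the target folder for the index files exists
os.makedirs("embeddings", exist_ok=True)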
ds_with_embeddings.save_faiss_index('embeddings', 'embeddings/embeddings.faiss')
ds_with_embeddings.save_faiss_index('image_embeddings', 'embeddings/image_embeddings.faiss')
It’s a good practice to store the embeddings in a dataset repository, so we will create one and push our embeddings there to pull later. We will log in to the Hugging Face Hub, create a dataset repository, push our indexes to it, and load them back using snapshot_download.
from huggingface_hub import HfApi, notebook_login, snapshot_download
notebook_login()
api = HfApi()
api.create_repo("merve/faiss_embeddings", repo_type="dataset")

api.upload_folder(
    folder_path="./embeddings",
    repo_id="merve/faiss_embeddings",
    repo_type="dataset",
)

snapshot_download(repo_id="merve/faiss_embeddings", repo_type="dataset",
                  local_dir="downloaded_embeddings")
We can load the embeddings into a dataset that has no embeddings yet using load_faiss_index.
= ds["train"]
ds 'embeddings', './downloaded_embeddings/embeddings.faiss')
ds.load_faiss_index(# infer again
= "people under the rain" prmt
= model.get_text_features(
prmt_embedding **tokenizer([prmt], return_tensors="pt", truncation=True)
"cuda"))[0].detach().cpu().numpy()
.to(
= ds.get_nearest_examples('embeddings', prmt_embedding, k=1) scores, retrieved_examples
"image"][0]) display(retrieved_examples[