Source code for domino._embed

import os
from typing import Callable, Union

import meerkat as mk
import torch

from domino._embed.encoder import Encoder

from ..registry import Registry
from .bit import bit
from .clip import clip
from .robust import robust
from .transformers import transformers

# Names exported when this module is star-imported.
# NOTE(review): only ``clip`` and ``bit`` are listed, although ``robust`` and
# ``transformers`` are imported and registered below as well -- confirm that
# the omission is intentional.
__all__ = ["clip", "bit"]

# Registry holding the available encoders; ``embed`` resolves its ``encoder``
# argument through ``encoders.get``.
encoders = Registry(name="encoders")

# Register each encoder imported above, with an empty alias list.
encoders.register(clip, aliases=[])
encoders.register(bit, aliases=[])
encoders.register(robust, aliases=[])
encoders.register(transformers, aliases=[])


def infer_modality(col: mk.AbstractColumn) -> str:
    """Infer the modality of a column from its column type.

    Args:
        col (mk.AbstractColumn): The column whose modality should be inferred.

    Returns:
        str: ``"image"`` if ``col`` is an ``mk.ImageColumn`` and ``"text"`` if it
        is an ``mk.PandasSeriesColumn``.

    Raises:
        ValueError: If ``col`` is of any other column type.
    """
    if isinstance(col, mk.ImageColumn):
        return "image"
    if isinstance(col, mk.PandasSeriesColumn):
        return "text"
    raise ValueError(f"Cannot infer modality from column of type {type(col)}.")


def embed(
    data: mk.DataPanel,
    input_col: str,
    encoder: Union[str, Encoder] = "clip",
    modality: str = None,
    out_col: str = None,
    device: Union[int, str] = "cpu",
    mmap_dir: str = None,
    num_workers: int = 4,
    batch_size: int = 128,
    **kwargs,
) -> mk.DataPanel:
    """Embed a column of data with an encoder from the encoder registry.

    Examples
    --------
    Suppose you have an Image dataset (e.g. Imagenette, CIFAR-10) loaded into a
    `Meerkat DataPanel <https://github.com/robustness-gym/meerkat>`_. You can
    embed the images in the dataset with CLIP using a code snippet like:

    .. code-block:: python

        import meerkat as mk
        from domino import embed

        dp = mk.datasets.get("imagenette")

        dp = embed(
            data=dp,
            input_col="img",
            encoder="clip"
        )

    Args:
        data (mk.DataPanel): A DataPanel containing the data to embed.
        input_col (str): The name of the column to embed.
        encoder (Union[str, Encoder], optional): Name of the encoder to use. List
            supported encoders with ``domino.encoders``. Defaults to "clip".
            Alternatively, pass an :class:`~domino._embed.encoder.Encoder` object
            containing a custom encoder; it is then used directly, without a
            registry lookup.
        modality (str, optional): The modality of the data to be embedded.
            Defaults to None, in which case the modality is inferred from the
            type of the input column. Only used to select an encoder from the
            registry; ignored when ``encoder`` is an ``Encoder`` object.
        out_col (str, optional): The name of the column where the embeddings are
            stored. Defaults to None, in which case it is
            ``"{encoder}({input_col})"`` (for an ``Encoder`` object, the class
            name is used in place of ``{encoder}``).
        device (Union[int, str], optional): The device handed to the encoder and
            to which tensor batches are moved before encoding. Defaults to
            "cpu".
        mmap_dir (str, optional): The path to directory where a memory-mapped
            file containing the embeddings will be written. Defaults to None, in
            which case the embeddings are not memmapped.
        num_workers (int, optional): Number of worker processes used to load the
            data from disk. Defaults to 4.
        batch_size (int, optional): Size of the batches passed to the encoder.
            Defaults to 128.
        **kwargs: Additional keyword arguments are passed to the encoder when it
            is looked up in the registry. To see supported arguments for each
            encoder, see the encoder documentation (e.g.
            :func:`~domino._embed.clip`). Ignored when ``encoder`` is an
            ``Encoder`` object.

    Returns:
        mk.DataPanel: ``data`` itself, with a new column containing the
        embeddings added to it. This column will be named according to the
        ``out_col`` parameter.

    Raises:
        ValueError: If the modality cannot be inferred from the input column, or
            if the requested encoder does not support the modality.
    """
    is_custom_encoder = isinstance(encoder, Encoder)
    # Keep a printable name around: ``encoder`` is rebound below, and the name
    # is needed both for the default column name and for the error message.
    encoder_name = type(encoder).__name__ if is_custom_encoder else encoder

    # The modality only selects an entry of a registry encoder, so it is not
    # inferred (and cannot fail) for a custom ``Encoder`` object.
    if modality is None and not is_custom_encoder:
        modality = infer_modality(col=data[input_col])

    if out_col is None:
        out_col = f"{encoder_name}({input_col})"

    if not is_custom_encoder:
        modality_encoders = encoders.get(encoder, device=device, **kwargs)
        if modality not in modality_encoders:
            raise ValueError(
                f'Encoder "{encoder_name}" does not support modality "{modality}".'
            )
        encoder = modality_encoders[modality]

    # NOTE(review): assumes a custom ``Encoder`` exposes the same ``encode``,
    # ``preprocess`` and ``collate`` attributes as the registry entries.
    return _embed(
        data=data,
        input_col=input_col,
        out_col=out_col,
        encode=encoder.encode,
        preprocess=encoder.preprocess,
        collate=encoder.collate,
        device=device,
        mmap_dir=mmap_dir,
        num_workers=num_workers,
        batch_size=batch_size,
    )
def _embed(
    data: mk.DataPanel,
    input_col: str,
    out_col: str,
    encode: Callable,
    preprocess: Callable,
    collate: Callable,
    device: int = None,
    mmap_dir: str = None,
    num_workers: int = 4,
    batch_size: int = 128,
):
    """Encode ``data[input_col]`` batch by batch and store it in ``data[out_col]``.

    Args:
        data (mk.DataPanel): DataPanel holding the column to embed. The new
            column is assigned on this object, which is also returned.
        input_col (str): Name of the column to embed.
        out_col (str): Name of the column that receives the embeddings.
        encode (Callable): Called on each prepared batch; its result is moved to
            the CPU, detached and converted to a NumPy array.
        preprocess (Callable): If not None, applied lazily to the input column
            through ``to_lambda``.
        collate (Callable): If not None, set as ``collate_fn`` of the column
            that is mapped over.
        device (int, optional): Device that tensor batches are moved to. Defaults
            to None, in which case "cuda" is used when available, else "cpu".
        mmap_dir (str, optional): If not None, the map is run with ``mmap=True``
            and ``mmap_path`` set to ``emb_mmap.npy`` inside this directory.
            Defaults to None.
        num_workers (int, optional): Forwarded to the column's ``map``.
            Defaults to 4.
        batch_size (int, optional): Forwarded to the column's ``map``.
            Defaults to 128.

    Returns:
        mk.DataPanel: ``data``, with the embeddings stored under ``out_col``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    column = data[input_col]
    embed_input = column if preprocess is None else column.to_lambda(preprocess)

    if collate is not None:
        embed_input.collate_fn = collate

    use_mmap = mmap_dir is not None
    mmap_path = os.path.join(mmap_dir, "emb_mmap.npy") if use_mmap else None

    def _encode_batch(batch):
        # Unwrap meerkat columns to their underlying data before encoding.
        if isinstance(batch, mk.AbstractColumn):
            batch = batch.data
        # Only tensors are moved; any other batch type is passed through as is.
        if torch.is_tensor(batch):
            batch = batch.to(device)
        return encode(batch).cpu().detach().numpy()

    with torch.no_grad():
        data[out_col] = embed_input.map(
            _encode_batch,
            pbar=True,
            is_batched_fn=True,
            batch_size=batch_size,
            num_workers=num_workers,
            mmap=use_mmap,
            mmap_path=mmap_path,
            flush_size=128,
        )
    return data