from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Union
import meerkat as mk
import numpy as np
import torch.nn as nn
from sklearn.base import BaseEstimator
@dataclass
class Config:
pass
[docs]class Slicer(ABC, BaseEstimator):
def __init__(self, n_slices: int):
super().__init__()
self.config = Config()
self.config.n_slices = n_slices
[docs] @abstractmethod
def fit(
self,
model: nn.Module = None,
data_dp: mk.DataPanel = None,
) -> Slicer:
"""
Fit the slicer to data.
Args:
data (mk.DataPanel, optional): A `Meerkat DataPanel` with columns for
embeddings, targets, and prediction probabilities. The names of the
columns can be specified with the ``embeddings``, ``targets``, and
``pred_probs`` arguments. Defaults to None.
embeddings (Union[str, np.ndarray], optional): The name of a column in
``data`` holding embeddings. If ``data`` is ``None``, then an np.ndarray
of shape (n_samples, dimension of embedding). Defaults to
"embedding".
targets (Union[str, np.ndarray], optional): The name of a column in
``data`` holding class labels. If ``data`` is ``None``, then an
np.ndarray of shape (n_samples,). Defaults to "target".
pred_probs (Union[str, np.ndarray], optional): The name of
a column in ``data`` holding model predictions (can either be "soft"
probability scores or "hard" 1-hot encoded predictions). If
``data`` is ``None``, then an np.ndarray of shape (n_samples, n_classes)
or (n_samples,) in the binary case. Defaults to "pred_probs".
losses (Union[str, np.ndarray], optional): The name of a column in ``data``
holding the loss of the model predictions. If ``data`` is ``None``,
then an np.ndarray of shape (n_samples,). Defaults to "loss".
Returns:
Slicer: Returns a fit instance of the slicer.
"""
raise NotImplementedError()
[docs] @abstractmethod
def predict(
self,
data: mk.DataPanel,
embeddings: Union[str, np.ndarray] = "embedding",
targets: Union[str, np.ndarray] = "target",
pred_probs: Union[str, np.ndarray] = "pred_probs",
) -> np.ndarray:
"""
Get slice membership for data using the fit slicer.
.. caution::
Must call ``Slicer.fit`` prior to calling ``Slicer.predict``.
Args:
data (mk.DataPanel, optional): A `Meerkat DataPanel` with columns for
embeddings, targets, and prediction probabilities. The names of the
columns can be specified with the ``embeddings``, ``targets``, and
``pred_probs`` arguments. Defaults to None.
embeddings (Union[str, np.ndarray], optional): The name of a colum in
``data`` holding embeddings. If ``data`` is ``None``, then an np.ndarray
of shape (n_samples, dimension of embedding). Defaults to
"embedding".
targets (Union[str, np.ndarray], optional): The name of a column in
``data`` holding class labels. If ``data`` is ``None``, then an
np.ndarray of shape (n_samples,). Defaults to "target".
pred_probs (Union[str, np.ndarray], optional): The name of
a column in ``data`` holding model predictions (can either be "soft"
probability scores or "hard" 1-hot encoded predictions). If
``data`` is ``None``, then an np.ndarray of shape (n_samples, n_classes)
or (n_samples,) in the binary case. Defaults to "pred_probs".
losses (Union[str, np.ndarray], optional): The name of a column in ``data``
holding the loss of the model predictions. If ``data`` is ``None``,
then an np.ndarray of shape (n_samples,). Defaults to "loss".
Returns:
np.ndarray: A binary ``np.ndarray`` of shape (n_samples, n_slices) where
values are either 1 or 0.
"""
raise NotImplementedError()
[docs] @abstractmethod
def predict_proba(
self,
data: mk.DataPanel,
embeddings: Union[str, np.ndarray] = "embedding",
targets: Union[str, np.ndarray] = "target",
pred_probs: Union[str, np.ndarray] = "pred_probs",
) -> np.ndarray:
"""
Get probablisitic (**i.e.** soft) slice membership for data using the fit
slicer.
.. caution::
Must call ``Slicer.fit`` prior to calling ``Slicer.predict``.
Args:
data (mk.DataPanel, optional): A `Meerkat DataPanel` with columns for
embeddings, targets, and prediction probabilities. The names of the
columns can be specified with the ``embeddings``, ``targets``, and
``pred_probs`` arguments. Defaults to None.
embeddings (Union[str, np.ndarray], optional): The name of a colum in
``data`` holding embeddings. If ``data`` is ``None``, then an np.ndarray
of shape (n_samples, dimension of embedding). Defaults to
"embedding".
targets (Union[str, np.ndarray], optional): The name of a column in
``data`` holding class labels. If ``data`` is ``None``, then an
np.ndarray of shape (n_samples,). Defaults to "target".
pred_probs (Union[str, np.ndarray], optional): The name of
a column in ``data`` holding model predictions (can either be "soft"
probability scores or "hard" 1-hot encoded predictions). If
``data`` is ``None``, then an np.ndarray of shape (n_samples, n_classes)
or (n_samples,) in the binary case. Defaults to "pred_probs".
losses (Union[str, np.ndarray], optional): The name of a column in ``data``
holding the loss of the model predictions. If ``data`` is ``None``,
then an np.ndarray of shape (n_samples,). Defaults to "loss".
Returns:
np.ndarray: A binary ``np.ndarray`` of shape (n_samples, n_slices) where
values are either 1 or 0.
"""
raise NotImplementedError()
[docs] def get_params(self) -> Dict[str, Any]:
"""
Get the parameters of this slicer. Returns a dictionary mapping from the names
of the parameters (as they are defined in the ``__init__``) to their values.
Returns:
Dict[str, Any]: A dictionary of parameters.
"""
return self.config.__dict__
[docs] def set_params(self, **params):
raise ValueError(
f"Slicer of type {self.__class__.__name__} does not support `set_params`."
)
def to(self, device: Union[str, int]):
if device != "cpu":
raise ValueError(f"Slicer of type {type(self)} does not support GPU.")
# by default this is a no-op, but subclasses can override