Source code for domino.gui

from typing import List, Union

import ipywidgets as widgets
import matplotlib.pyplot as plt
import meerkat as mk
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display

from domino.utils import unpack_args
from ._describe import describe


def explore(
    data: mk.DataPanel = None,
    embeddings: Union[str, np.ndarray] = "embedding",
    targets: Union[str, np.ndarray] = "target",
    pred_probs: Union[str, np.ndarray] = "pred_prob",
    slices: Union[str, np.ndarray] = "slices",
    text: mk.DataPanel = None,
    text_embeddings: Union[str, np.ndarray] = "embedding",
    phrase: Union[str, np.ndarray] = "output_phrase",
) -> None:
    """Creates an IPyWidget GUI for exploring discovered slices. The GUI includes
    two sections: (1) The first section displays data visualizations summarizing
    the model predictions and accuracy stratified by slice. (2) The second section
    displays a table (i.e. Meerkat DataPanel) of the data examples most
    representative of each slice. The DataPanel passed to ``data`` should include
    columns for embeddings, targets, pred_probs and slices. Any additional columns
    will be included in the visualization in section (2).

    .. caution::
        This GUI works best in the original Jupyter Notebook, and may not work
        properly in a Jupyter Lab or VSCode environment.

    Args:
        data (mk.DataPanel, optional): A `Meerkat DataPanel` with columns for
            embeddings, targets, and prediction probabilities. The names of the
            columns can be specified with the ``embeddings``, ``targets``, and
            ``pred_probs`` arguments. Defaults to None.
        embeddings (Union[str, np.ndarray], optional): The name of a column in
            ``data`` holding embeddings. If ``data`` is ``None``, then an
            np.ndarray of shape (n_samples, dimension of embedding). Defaults to
            "embedding".
        targets (Union[str, np.ndarray], optional): The name of a column in
            ``data`` holding class labels. If ``data`` is ``None``, then an
            np.ndarray of shape (n_samples,). Defaults to "target".
        pred_probs (Union[str, np.ndarray], optional): The name of a column in
            ``data`` holding model predictions (can either be "soft" probability
            scores or "hard" 1-hot encoded predictions). If ``data`` is ``None``,
            then an np.ndarray of shape (n_samples, n_classes) or (n_samples,) in
            the binary case. When two-dimensional, the histogram in section (1)
            plots column 1. Defaults to "pred_prob".
        slices (Union[str, np.ndarray], optional): The name of a column in
            ``data`` holding discovered slices. If ``data`` is ``None``, then an
            np.ndarray of shape (num_examples, num_slices). Defaults to "slices".
        text (mk.DataPanel, optional): A `Meerkat DataPanel` with columns for text
            phrases and their embeddings. The names of the columns can be
            specified with the ``text_embeddings`` and ``phrase`` arguments. When
            ``None``, no natural language descriptions are generated and that
            part of the GUI stays empty. Defaults to None.
        text_embeddings (Union[str, np.ndarray], optional): The name of a column
            in ``text`` holding embeddings. Only used when ``text`` is provided,
            in which case it is forwarded unchanged to ``describe``. Defaults to
            "embedding".
        phrase (Union[str, np.ndarray], optional): The name of a column in
            ``text`` holding text phrases. Only used when ``text`` is provided,
            in which case it is forwarded unchanged to ``describe``. Defaults to
            "output_phrase".

    Examples
    --------
    .. code-block:: python
        :name: Example:

        from domino import explore, DominoSDM
        dp = ...  # prepare the dataset as a Meerkat DataPanel

        # split dataset
        valid_dp = dp.lz[dp["split"] == "valid"]
        test_dp = dp.lz[dp["split"] == "test"]

        domino = DominoSDM()
        domino.fit(data=valid_dp)
        test_dp["slices"] = domino.transform(
            data=test_dp, embeddings="emb", targets="target", pred_probs="probs"
        )
        explore(
            data=test_dp, embeddings="emb", targets="target", pred_probs="probs"
        )

    """
    # Resolve column names into the underlying arrays; every closure below
    # works on these resolved values rather than on the column names.
    embeddings, targets, pred_probs, slices = unpack_args(
        data, embeddings, targets, pred_probs, slices
    )

    if data is None:
        dp = mk.DataPanel(
            {
                "embeddings": embeddings,
                "targets": targets,
                "pred_probs": pred_probs,
                "domino_slices": slices,
            }
        )
    else:
        dp = data if isinstance(data, mk.DataPanel) else mk.DataPanel(data)

    plot_output = widgets.Output()

    # define functions for generating visualizations
    def plot_slice(slice_idx, slice_threshold: float):
        """Draw per-target histograms of the model output, split by slice
        membership at ``slice_threshold``, into ``plot_output``."""
        # TODO (Sabri): Support a confusion matrix for the multiclass case.
        with plot_output:
            probs = pred_probs[:, 1] if len(pred_probs.shape) == 2 else pred_probs
            # Tensor-like inputs expose ``.numpy()``; plain np.ndarray inputs
            # (the documented type when ``data`` is None) do not and are
            # already usable as-is.
            if hasattr(probs, "numpy"):
                probs = probs.numpy()

            plot_df = pd.DataFrame(
                {
                    "in-slice": slices[:, slice_idx] > slice_threshold,
                    "pred_probs": probs,
                    "target": targets,
                }
            )
            g = sns.displot(
                data=plot_df,
                hue="in-slice",
                x="pred_probs",
                col="target",
                aspect=1.7,
                height=2,
                facet_kws={"sharey": False},
                hue_order=[False, True],
                palette=["#bdbdbd", "#2396f3"],
                stat="percent",
                common_norm=False,
                bins=20,
            )
            g.set_axis_labels("Model's output probability", "% of examples")
            for target in np.unique(targets):
                in_slice = np.sum(
                    (slices[:, slice_idx] > slice_threshold) & (targets == target)
                )
                # NOTE(review): the facet is looked up by ``int(target)``, which
                # presumes targets are the integers 0..k-1 -- confirm.
                g.axes[0, int(target)].set_title(
                    f"target={target} \n (# of examples in-slice={in_slice})"
                )
            # Only drop the previous figure once the new one is ready to show.
            plot_output.clear_output(wait=True)
            plt.show()

    description_output = widgets.Output()

    def show_descriptions(slice_idx: int, slice_threshold: float):
        """Display the five highest-scoring descriptions returned by
        ``describe`` for the slice; does nothing beyond clearing the output
        when ``text`` is None."""
        description_output.clear_output(wait=False)
        if text is not None:
            description_dp = describe(
                data=dp,
                embeddings=embeddings,
                targets=targets,
                slices=slices,
                slice_idx=slice_idx,
                text=text,
                text_embeddings=text_embeddings,
                phrases=phrase,
                slice_threshold=slice_threshold,
            )
            with description_output:
                display(description_dp[(-description_dp["score"]).argsort()[:5]])

    dp_output = widgets.Output()

    def show_dp(
        slice_idx,
        page_idx: int,
        page_size: int,
        columns: List[str],
        slice_threshold: float,
    ):
        """Display one page of examples, ordered by descending slice score and
        cut off at the number of examples above ``slice_threshold``."""
        # This sets a Meerkat-wide display option, not one local to this GUI.
        mk.config.DisplayOptions.max_rows = page_size
        dp_output.clear_output(wait=False)
        num_examples_in_slice = np.sum(slices[:, slice_idx] > slice_threshold)

        with dp_output:
            display(
                dp.lz[
                    (-slices[:, slice_idx]).argsort()[
                        page_size
                        * page_idx : min(
                            page_size * (page_idx + 1), num_examples_in_slice
                        )
                    ]
                ][list(columns)]
            )

    # Create widgets
    # Start on slice 0 so the dropdown is valid even when only a single slice
    # exists (an initial value of 1 is not among the options in that case).
    slice_idx_widget = widgets.Dropdown(
        value=0,
        options=list(range(slices.shape[-1])),
        description="Slice",
        layout=widgets.Layout(width="150px"),
    )
    slice_threshold_widget = widgets.FloatSlider(
        value=0.5,
        min=0,
        max=1.0,
        step=0.025,
        description="Slice Inclusion Threshold",
        disabled=False,
        continuous_update=False,
        orientation="horizontal",
        readout=True,
        readout_format=".3f",
        style={"description_width": "initial"},
    )
    # TODO(Sabri): Add a widget for the # of examples in the slice at the current
    # threshold. It will have to be linked with the threshold widget above.

    column_selector = widgets.SelectMultiple(
        options=dp.columns, value=dp.columns, description="Columns", disabled=False
    )
    page_size_widget = widgets.RadioButtons(
        options=[10, 25, 50], description="Page size"
    )
    # NOTE(review): the page index is capped at 10 regardless of slice size.
    page_idx_widget = widgets.BoundedIntText(
        value=0,
        min=0,
        max=10,
        step=1,
        description="Page",
        disabled=False,
        readout=True,
        readout_format="d",
        layout=widgets.Layout(width="150px"),
    )

    # Establish interactions between widgets and the visualization functions.
    # The interactive objects are kept (in registration order) so that the
    # first render can be requested explicitly once everything is laid out.
    renderers = [
        widgets.interactive(
            show_descriptions,
            slice_idx=slice_idx_widget,
            slice_threshold=slice_threshold_widget,
        ),
        widgets.interactive(
            show_dp,
            slice_idx=slice_idx_widget,
            columns=column_selector,
            page_idx=page_idx_widget,
            page_size=page_size_widget,
            slice_threshold=slice_threshold_widget,
        ),
        widgets.interactive(
            plot_slice,
            slice_idx=slice_idx_widget,
            slice_threshold=slice_threshold_widget,
        ),
    ]

    # Layout and display the widgets
    display(
        widgets.HBox(
            [
                widgets.HTML(value="<p><strong> Domino Slice Explorer </strong></p>"),
                slice_idx_widget,
            ]
        )
    )
    display(slice_threshold_widget)
    display(plot_output)
    display(
        widgets.VBox(
            [
                widgets.HTML(
                    value=(
                        "<p> <strong> Natural language descriptions of the slice: "
                        "</strong> </p>"
                    )
                ),
                description_output,
            ]
        )
    )
    display(
        widgets.HBox(
            [
                widgets.VBox(
                    [
                        widgets.HTML(
                            value=(
                                "<style>p{word-wrap: break-word}</style> <p>"
                                + "Select multiple columns with <em>cmd-click</em>."
                                + " </p>"
                            )
                        ),
                        column_selector,
                    ]
                ),
                widgets.VBox([page_idx_widget, page_size_widget]),
            ],
        )
    )
    display(
        widgets.VBox(
            [
                widgets.HTML(
                    value=(
                        "<p> <strong> Examples in the slice, ranked by likelihood: "
                        "</strong> </p>"
                    )
                ),
                dp_output,
            ]
        )
    )

    # The interactive containers themselves are never displayed, so run each
    # of them once by hand to populate the outputs. Doing this explicitly
    # (rather than by flipping the dropdown from one slice to another) also
    # works when there is just one slice to choose from.
    for renderer in renderers:
        renderer.update()