"""Client Interface Module"""
import copy
import warnings
from typing import Dict, List, Optional, Union
import datasets
import numpy as np
import pandas as pd
import torch
from tqdm.autonotebook import tqdm
from .datasets import BaseDataset
from .datasets.datamanagers import HateXplainDataset, MovieReviews, SSTDataset
from .evaluators.class_measures import AOPC_Comprehensiveness_Evaluation_by_class
from .evaluators.evaluation import EvaluationMetricOutput, ExplanationEvaluation
from .evaluators.faithfulness_measures import (
AOPC_Comprehensiveness_Evaluation,
AOPC_Sufficiency_Evaluation,
TauLOO_Evaluation,
)
from .evaluators.plausibility_measures import (
AUPRC_PlausibilityEvaluation,
Tokenf1_PlausibilityEvaluation,
TokenIOU_PlausibilityEvaluation,
)
from .explainers.explanation import Explanation, ExplanationWithRationale
from .explainers.gradient import GradientExplainer, IntegratedGradientExplainer
from .explainers.lime import LIMEExplainer
from .explainers.shap import SHAPExplainer
from .modeling import create_helper
from .visualization import show_evaluation_table, show_table
NONE_RATIONALE = []
def lp_normalize(explanations, ord=1):
"""Run Lp-noramlization of explanation attribution scores
Args:
explanations (List[Explanation]): list of explanations to normalize
ord (int, optional): order of the norm. Defaults to 1.
Returns:
List[Explanation]: list of normalized explanations
"""
new_exps = list()
for exp in explanations:
new_exp = copy.copy(exp)
if isinstance(new_exp.scores, np.ndarray) and new_exp.scores.size > 0:
norm_axis = (
-1 if new_exp.scores.ndim == 1 else (0, 1)
) # handle axis correctly
norm = np.linalg.norm(new_exp.scores, axis=norm_axis, ord=ord)
if norm != 0: # avoid division by zero
new_exp.scores /= norm
new_exps.append(new_exp)
return new_exps
[docs]
class Benchmark:
"""Generic interface to compute multiple explanations."""
[docs]
def __init__(
self,
model,
tokenizer,
task_name: str = "text-classification",
explainers: List = None,
evaluators: List = None,
class_based_evaluators: List = None,
):
self.model = model
self.tokenizer = tokenizer
self.task_name = task_name
self.helper = create_helper(self.model, self.tokenizer, self.task_name)
self.explainers = explainers
self.evaluators = evaluators
self.class_based_evaluators = class_based_evaluators
if not explainers:
self.explainers = [
SHAPExplainer(self.model, self.tokenizer, self.helper),
LIMEExplainer(self.model, self.tokenizer, self.helper),
GradientExplainer(
self.model, self.tokenizer, self.helper, multiply_by_inputs=False
),
GradientExplainer(
self.model, self.tokenizer, self.helper, multiply_by_inputs=True
),
IntegratedGradientExplainer(
self.model, self.tokenizer, self.helper, multiply_by_inputs=False
),
IntegratedGradientExplainer(
self.model, self.tokenizer, self.helper, multiply_by_inputs=True
),
]
else:
for explainer in explainers:
if explainer.helper is not None:
warnings.warn(f"Overriding helper for explainer {explainer}")
explainer.helper = self.helper
if not evaluators:
self._used_evaluators = [
AOPC_Comprehensiveness_Evaluation,
AOPC_Sufficiency_Evaluation,
TauLOO_Evaluation,
AUPRC_PlausibilityEvaluation,
Tokenf1_PlausibilityEvaluation,
TokenIOU_PlausibilityEvaluation,
]
self.evaluators = [
ev(self.model, self.tokenizer, self.task_name)
for ev in self._used_evaluators
]
if not class_based_evaluators:
self._used_class_evaluators = [AOPC_Comprehensiveness_Evaluation_by_class]
self.class_based_evaluators = [
class_ev(self.model, self.tokenizer, self.task_name)
for class_ev in self._used_class_evaluators
]
############################
# Utilities
############################
def _forward(self, text):
item = self.tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = self.model(**item)
return outputs
[docs]
def score(self, text: str, return_dict: bool = True, **kwargs):
"""Compute prediction scores for a single query
:param text str: query to compute the logits from
:param return_dict bool: return a dict in the format Class Label -> score. Otherwise, return softmaxed logits as torch.Tensor. Default True
"""
return self.helper._score(text, return_dict, **kwargs)
@property
def targets(self):
return self.helper.targets
############################
# Interpretability API
############################
[docs]
def explain(
self,
text,
target=1,
show_progress: bool = True,
normalize_scores: bool = True,
order: int = 1,
target_token: Optional[str] = None,
target_option: Optional[str] = None,
) -> List[Explanation]:
"""
Compute explanations using all the explainers stored in the class.
Parameters
----------
text : str
Text string to explain.
target : int
Class label to produce the explanations for.
show_progress : bool, default False
Enable progress bar.
normalize_scores : bool, default True
Apply lp-normalization across tokens to make attribution weights comparable across different explainers.
order : int, default 1
If *normalize_scores=True*, this is the normalization order, as passed to *numpy.linalg.norm*.
Returns
-------
List[Explanation]
List of all explanations produced.
Notes
-----
Please reference to :ref:`User Guide <explaining>` for more information.
Examples
--------
>>> bench = Benchmark(model, tokenizer)
>>> explanations = bench.explain("I love your style!", target=2)
Please note that by default we apply L1 normalization across tokens, to make feature attribution weights comparable among explainers. To turn it off, you should use:
>>> bench = Benchmark(model, tokenizer)
>>> explanations = bench.explain("I love your style!", target=2, normalize_scores=False)
"""
# sanity check and transformation to integer targets (if required)
# here we are assuming the same target format (e.g., positional integer will work
# for every explanation method. We might need to chage this in the future, when
# we will add new explanation methods.
target = self.helper._check_target(target)
target_token = self.helper._check_target_token(text, target_token)
text = self.helper._check_sample(text)
text = self.helper._prepare_sample(text, target_option=target_option)
# we might optimize running the loop in parallel
explanations = list()
for explainer in tqdm(
self.explainers,
total=len(self.explainers),
desc="Explainer",
leave=False,
disable=not show_progress,
):
exp = explainer(text, target, target_token)
explanations.append(exp)
if normalize_scores:
explanations = lp_normalize(explanations, order)
return explanations
############################
# Evaluation API
############################
[docs]
def evaluate_explanation(
self,
explanation: Union[Explanation, ExplanationWithRationale],
human_rationale=None,
class_explanation: List[Union[Explanation, ExplanationWithRationale]] = None,
show_progress: bool = True,
**evaluation_args,
) -> ExplanationEvaluation:
"""Evaluate an explanation using all the evaluators stored in the class.
Args:
explanation (Union[Explanation, ExplanationWithRationale]): explanation to evaluate.
target (int): class label for which the explanation is evaluated
human_rationale (list): list with values 0 or 1. A value of 1 means that the corresponding token is part of the human (or ground truth) rationale, 0 otherwise. Tokens are indexed by position. The size of the list is the number of tokens.
class_explanation (list): list of explanations. The explanation in position i is computed using as target class the class label i. The size is #target classes. If available, class-based scores are computed.
show_progress (bool): enable progress bar
Returns:
ExplanationEvaluation: the evaluation of the explanation
"""
evaluations = list()
if show_progress:
total_evaluators = (
len(self.evaluators) + len(self.class_based_evaluators)
if class_explanation is not None
else len(self.evaluators)
)
pbar = tqdm(total=total_evaluators, desc="Evaluator", leave=False)
add_first_last = evaluation_args.get("add_first_last", True)
explanation = (
self._add_rationale(explanation, human_rationale, add_first_last)
if human_rationale is not None
else explanation
)
for evaluator in self.evaluators:
evaluation = evaluator.compute_evaluation(explanation, **evaluation_args)
if (
evaluation is not None
): # return None for plausibility measure if rationale is not available
evaluations.append(evaluation)
if show_progress:
pbar.update(1)
if class_explanation is not None:
for class_based_evaluator in self.class_based_evaluators:
class_based_evaluation = class_based_evaluator.compute_evaluation(
class_explanation, **evaluation_args
)
evaluations.append(class_based_evaluation)
if show_progress:
pbar.update(1)
if show_progress:
pbar.close()
explanation_eval = ExplanationEvaluation(explanation, evaluations)
return explanation_eval
[docs]
def evaluate_explanations(
self,
explanations: List[Union[Explanation, ExplanationWithRationale]],
human_rationale=None,
class_explanations=None,
show_progress=True,
**evaluation_args,
) -> List[ExplanationEvaluation]:
"""Evaluate explanations using all the evaluators stored in the class.
Args:
explanation ( List[Union[Explanation, ExplanationWithRationale]]): list of explanations to evaluate.
target (int): class label for which the explanations are evaluated
human rationale (list): one-hot-encoding indicating if the token is in the human rationale (1) or not (0). If available, all explanations are evaluated for the human rationale (if provided)
class_explanation (list): list of list of explanations. The k-th element represents the list of explanations computed varying the target class: the explanation in position k, i is computed using as target class the class label i. The size is # explanation, #target classes. If available, class-based scores are computed.
show_progress (bool): enable progress bar
Returns:
List[ExplanationEvaluation]: the evaluation for each explanation
"""
explanation_evaluations = list()
class_explanations_by_explainer = self._get_class_explanations_by_explainer(
class_explanations
)
if show_progress:
pbar = tqdm(total=len(explanations), desc="Explanation eval", leave=False)
for i, explanation in enumerate(explanations):
class_explanation = None
if class_explanations_by_explainer is not None:
class_explanation = class_explanations_by_explainer[i]
explanation_evaluations.append(
self.evaluate_explanation(
explanation,
human_rationale,
class_explanation,
show_progress=False,
**evaluation_args,
)
)
if show_progress:
pbar.update(1)
if show_progress:
pbar.close()
return explanation_evaluations
def _add_rationale(
self,
explanation: Explanation,
rationale: List,
add_first_last=True,
) -> ExplanationWithRationale:
"""Add the ground truth rationale to the explanation.
Args:
explanation (Explanation): explanation
rationale (list): one-hot-encoding indicating if the token is in the human rationale (1) or not (0)
add_first_last (bool): consider the first and last tokens. Set it to True if the scores of the explanation also include the importance of the first and last tokens (typically cls and eos tokens)
Returns:
ExplanationWithRationale: explanation with the ground truth rationale
"""
if rationale == NONE_RATIONALE:
return explanation
else:
if add_first_last:
# Include the first and last token (0 as default)
rationale = [0] + rationale + [0]
if len(explanation.tokens) != len(rationale):
raise ValueError()
return ExplanationWithRationale(
text=explanation.text,
tokens=explanation.tokens,
scores=explanation.scores,
explainer=explanation.explainer,
target_pos_idx=explanation.target_pos_idx,
helper_type=explanation.helper_type,
target_token_pos_idx=explanation.target_token_pos_idx,
target=explanation.target,
target_token=explanation.target_token,
rationale=rationale,
)
def _get_class_explanations_by_explainer(self, class_explanations):
"""
We convert from #target, #explainer to #explainer, #target
"""
class_explanations_by_explainer = None
if class_explanations is not None:
n_explainers = len(class_explanations[0])
n_targets = len(class_explanations)
class_explanations_by_explainer = [
[class_explanations[i][explainer_type] for i in range(n_targets)]
for explainer_type in range(n_explainers)
]
return class_explanations_by_explainer
##############################
# Dataset API
##############################
[docs]
def load_dataset(self, dataset_name: str, **kwargs):
if dataset_name == "hatexplain":
data = HateXplainDataset(self.tokenizer)
elif dataset_name == "movie_rationales":
data = MovieReviews(self.tokenizer)
elif dataset_name == "sst":
data = SSTDataset(self.tokenizer)
else:
try:
data = datasets.load_dataset(dataset_name)
except:
raise ValueError(f"Dataset {dataset_name} is not supported")
return data
[docs]
def evaluate_samples(
self,
dataset: BaseDataset,
sample: Union[int, List[int]],
target: int = None,
show_progress_bar: bool = True,
n_workers: int = 1,
**evaluation_args,
) -> Dict:
"""Explain a dataset sample, evaluate explanations, and compute average scores.
Args:
dataset (BaseDataset): XAI dataset to explain and evaluate
sample (Union[int, List[int]]): index or list of indexes
target (int): class label for which the explanations are computed and evaluated. If None, explanations are computed and evaluated for the predicted class
show_progress (bool): enable progress bar
n_workers (int) : number of workers
Returns:
Dict : the average evaluation scores and their standard deviation for each explainer. The form is the following: {explainer: {"evaluation_measure": (avg_score, std)}
"""
raise DeprecationWarning(
"This method is deprecated. You can achieve a similar result by computing each individual explanation and evaluation and averaging them."
)
# Use list to index datasets
if isinstance(sample, int):
sample = [sample]
sample = list(map(int, sample))
instances = [dataset[s] for s in sample]
# For the IOU and Token F1 plausibility scores we specify the K for deriving the top-k rationale
# As in DeYoung et al. 2020, we set it as the average size of the human rationales of the dataset
evaluation_args["top_k_rationale"] = dataset.avg_rationale_size
# is_thermostatdata = isinstance(dataset, ThermostatDataset) --> problem with reload
is_thermostatdata = dataset.NAME == "Thermostat"
# Set the explanation target class
if is_thermostatdata:
# The explanations in thermostat are pre-computed for the predicted class
targets = [i["predicted_label"] for i in instances]
else:
# Default, w.r.t. predicted class
if target is None:
# Compute explanations for the predicted class
predicted_classes = [
self.score(i["text"], return_dict=False).argmax(-1).tolist()
for i in instances
]
targets = predicted_classes
else:
targets = [target] * len(sample)
if is_thermostatdata:
name_explainers = dataset.explainers
else:
name_explainers = [e.NAME for e in self.explainers]
if show_progress_bar:
pbar = tqdm(total=len(targets), desc="explain", leave=False)
# Create an empty dict of dict to collect the results
evaluation_scores_by_explainer = {}
for explainer in name_explainers:
evaluation_scores_by_explainer[explainer] = {}
for evaluator in self.evaluators:
evaluation_scores_by_explainer[explainer][evaluator.SHORT_NAME] = []
if n_workers > 1:
raise NotImplementedError()
# Parallel(n_jobs=2)(delayed(sqrt)(i ** 2) for i in range(10))
else:
for instance, target in zip(instances, targets):
# Generate explanations - list of explanations (one for each explainers)
explanations = self.explain(
instance["text"], target, show_progress=False
)
# If available, we add the human rationale
# It will be used in the evaluation of plausibility
if "rationale" in instance and len(instance["rationale"]) > target:
# Add the human rationale for the corresponding class
explanations = [
self._add_rationale(explanation, instance["rationale"][target])
for explanation in explanations
]
for explanation in explanations:
# We evaluate the explanation and we obtain an ExplanationEvaluation
evaluation = self.evaluate_explanation(
explanation, target, show_progress=False, **evaluation_args
)
# We accumulate the results for each explainer
for evaluation_score in evaluation.evaluation_outputs:
evaluation_scores_by_explainer[explanation.explainer][
evaluation_score.name
].append(evaluation_score.score)
if show_progress_bar:
pbar.update(1)
# We compute mean and std, separately for each explainer and evaluator
for explainer in evaluation_scores_by_explainer:
for score_name in list(evaluation_scores_by_explainer[explainer]):
list_scores = evaluation_scores_by_explainer[explainer][score_name]
if list_scores:
# Compute mean and standard deviation
evaluation_scores_by_explainer[explainer][score_name] = (
np.mean(list_scores),
np.std(list_scores),
)
else:
evaluation_scores_by_explainer[explainer].pop(score_name, None)
if show_progress_bar:
pbar.close()
return evaluation_scores_by_explainer
############################
# Visualization API
############################
[docs]
def show_table(
self,
explanations: List[Explanation],
remove_first_last: bool = True,
style: None = "heatmap",
) -> pd.DataFrame:
return show_table(explanations, remove_first_last, style)
[docs]
def show_evaluation_table(
self,
explanation_evaluations: List[ExplanationEvaluation],
style: Optional[str] = "heatmap",
):
return show_evaluation_table(explanation_evaluations, style)
[docs]
def show_samples_evaluation_table(
self,
evaluation_scores_by_explainer,
apply_style: bool = True,
) -> pd.DataFrame:
"""Format average evaluation scores into a colored table.
Args:
evaluation_scores_by_explainer (Dict): the average evaluation scores and their standard deviation for each explainer (output of the evaluate_samples function)
apply_style (bool): color the table of average evaluation scores
Returns:
pd.DataFrame: a colored (styled) pandas dataframe of average evaluation scores of explanations of a sample
"""
raise DeprecationWarning(
"This method has been deprecated. See `show_evaluation_table` for an alternative."
)
# We only vizualize the average
table = pd.DataFrame(
{
explainer: {
evaluator: mean_std[0] for evaluator, mean_std in inner.items()
}
for explainer, inner in evaluation_scores_by_explainer.items()
}
).T
# Avoid visualizing a columns with all nan (default value if plausibility could not computed)
table = table.dropna(axis=1, how="all")
if apply_style:
table_style = self._style_evaluation(table)
return table_style
else:
return table