Explainable AI (XAI) Tutorials - NeurIPS 2025
Explainable AI - Deep Dive 1
Tutorial Overview
- Tutorial page: link
- AI explainability evolution across three eras:
- Before 2014: Linear models and trees for explanation
- 2014-2020: Feature attribution → data attribution for interpreting DNNs
- After 2022: Component attribution (LLM era)
- Technical Deep Dive covers three attribution types:
- Feature attribution
- Data attribution
- Component attribution
Attribution Problem
- General framework: Training data → model → output
- Three perspectives for explaining AI system outputs
- Unified mathematical notation:
- Feature attribution scores: φᵢ
- Data attribution scores: ψⱼ
- Component attribution scores: γₖ
Feature Attribution
- Core question: How do features impact the output?
- Applications:
- Justify predictions and provide counterfactual explanations
- Identify spurious correlations (e.g., husky classified as wolf based on snow background)
- Example: A loan-application model that justifies a denial by salary and credit score, versus one that inappropriately relies on gender
Data Attribution
- Core question: Why this output for those training data points?
- Studies how training data influences model output
- Applications:
- Characterize training data properties
- Assign value to data points (data valuation) and identify harmful training examples
- Example: Fish classification traced back to semantically similar training image with coral background
Component Attribution
- Core question: Why this output for these model components?
- Components can be: neurons, attention heads, layers, subnetworks
- Example: Language model answering “When Mary and John went to the store, John gave…” → “Mary”
- Attribution over attention heads is sparse: only a small subset of heads is needed for indirect object identification
Perturbation-Based Feature Attribution
- Direct perturbation:
- Perturb features and observe output changes
- Problem: Feature interactions require considering all possible subsets
- Game theoretic perturbation:
- SHAP method using Shapley values
- Averages marginal contributions over all 2^d feature subsets
- Computational complexity: O(2^d) - not scalable for large feature sets (exact computation sketched after this list)
- Perturbation mask learning:
- Continuous and learnable masks instead of binary
- Generates saliency maps for computer vision
- Learned masking model applicable across multiple inputs
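A minimal sketch of exact Shapley-value feature attribution, assuming "absent" features are replaced by a baseline value (the toy model and zero baseline here are illustrative, not from the tutorial); it makes the O(2^d) enumeration concrete:

```python
from itertools import combinations
from math import factorial
import numpy as np

def shapley_values(model, x, baseline):
    """Exact Shapley attribution phi_i for each feature of input x."""
    d = len(x)
    phi = np.zeros(d)

    def value(subset):
        # Features outside `subset` are replaced by baseline values.
        z = baseline.copy()
        idx = list(subset)
        z[idx] = x[idx]
        return model(z)

    for i in range(d):
        others = [j for j in range(d) if j != i]
        for size in range(d):
            for S in combinations(others, size):
                # Shapley weight: |S|! (d - |S| - 1)! / d!
                wt = factorial(size) * factorial(d - size - 1) / factorial(d)
                phi[i] += wt * (value(S + (i,)) - value(S))
    return phi

# Sanity check: for a linear model with a zero baseline, phi_i = w_i * x_i.
w = np.array([1.0, -2.0, 0.5])
print(shapley_values(lambda z: float(w @ z), np.ones(3), baseline=np.zeros(3)))
```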
Gradient-Based Feature Attribution
- Key distinction: Feature gradients for attribution vs parameter gradients for training
- Measures output sensitivity with respect to input features
- SmoothGrad method:
- Adds noise to create multiple input versions
- Aggregates gradients across noisy versions
- More robust than vanilla gradients
- Successive refinements (vanilla gradients → SmoothGrad and beyond) yield progressively cleaner, more intuitive saliency maps; a SmoothGrad sketch follows
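A minimal SmoothGrad sketch (assuming PyTorch; the two-layer model, noise level, and sample count are illustrative stand-ins):

```python
import torch

def smoothgrad(model, x, target, n_samples=50, sigma=0.15):
    """Average input gradients over Gaussian-noised copies of x."""
    total = torch.zeros_like(x)
    for _ in range(n_samples):
        noisy = (x + sigma * torch.randn_like(x)).requires_grad_(True)
        model(noisy)[target].backward()   # gradient of the target logit
        total += noisy.grad
    return total / n_samples              # saliency map, same shape as x

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(),
                            torch.nn.Linear(16, 3))
saliency = smoothgrad(model, torch.randn(8), target=0)
```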
Linear Approximation for Feature Attribution
- LIME (Local Interpretable Model-agnostic Explanations):
- Fits a linear model that locally approximates the black-box model around the instance being explained
- Trains on binary indicators (feature present/absent) rather than actual input values
- Linear coefficients become attribution scores
- Successfully identified husky-wolf misclassification based on background snow
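A hedged LIME-style sketch: sample binary presence masks, query the black-box model on masked inputs, and fit a proximity-weighted linear surrogate (the baseline substitution for "absent" features and the Gaussian kernel are simplifying assumptions):

```python
import numpy as np
from sklearn.linear_model import Ridge

def lime_attributions(model, x, baseline, n_samples=1000, kernel_width=0.75):
    d = len(x)
    Z = np.random.randint(0, 2, size=(n_samples, d))   # binary presence masks
    X_pert = np.where(Z == 1, x, baseline)             # absent -> baseline value
    y = np.array([model(row) for row in X_pert])
    # Weight samples by proximity to the unperturbed input (all-ones mask).
    dist2 = ((Z - 1) ** 2).sum(axis=1)
    weights = np.exp(-dist2 / kernel_width ** 2)
    surrogate = Ridge(alpha=1.0).fit(Z, y, sample_weight=weights)
    return surrogate.coef_                             # phi_i per feature
```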
Data Attribution Methods
Perturbation-Based Data Attribution
- Leave-one-out (direct perturbation):
- Remove one training data point, retrain model, observe changes
- Computationally expensive: requires n retrainings, one per training point (sketched after this list)
- Game-theoretic perturbation:
- Data Shapley algorithm
- Requires 2^n marginal contributions with retraining
- Only approximate versions practical
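A leave-one-out sketch with a scikit-learn classifier as a stand-in learner; ψⱼ is the change in the test prediction when point j is removed, at the cost of n retrainings:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def leave_one_out_scores(X_train, y_train, x_test):
    base = LogisticRegression().fit(X_train, y_train).predict_proba([x_test])[0, 1]
    scores = np.zeros(len(X_train))
    for j in range(len(X_train)):                    # n retraining cycles
        keep = np.arange(len(X_train)) != j
        m = LogisticRegression().fit(X_train[keep], y_train[keep])
        scores[j] = base - m.predict_proba([x_test])[0, 1]   # psi_j
    return scores
```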
Gradient-based Data Attribution
- No perturbation required - avoids retraining
- Gradient similarity method:
- Compute gradients for test and training points
- Dot product measures similarity between gradient representations
- Problem: No causal interpretation, only similarity measure
- Influence functions:
- Efficiently approximates leave-one-out without retraining
- Introduces Hessian matrix for second-order information
- Derivation: down-weight one training point in the objective and take a first-order approximation around the optimum, recovering leave-one-out without retraining
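For reference, the standard forms (following Koh & Liang's influence-function formulation, up to sign/scale conventions), with θ̂ the trained parameters and H the empirical Hessian of the training loss:

```latex
% Gradient-similarity score (no causal interpretation):
\psi_j^{\mathrm{sim}} = \nabla_\theta L(z_{\mathrm{test}}, \hat\theta)^\top \nabla_\theta L(z_j, \hat\theta)

% Influence-function approximation to leave-one-out:
\psi_j^{\mathrm{IF}} = -\,\nabla_\theta L(z_{\mathrm{test}}, \hat\theta)^\top
    H_{\hat\theta}^{-1} \nabla_\theta L(z_j, \hat\theta),
\qquad
H_{\hat\theta} = \tfrac{1}{n}\sum_{i=1}^{n} \nabla_\theta^2 L(z_i, \hat\theta)
```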
Linear Approximation for Data Attribution
- Datamodel approach:
- Bypass the model's internals: directly predict the trained model's output from which training data was used
- Linear model G approximates the relationship: training subset → test output
- Counterfactual data collection:
- Each data point: subset of training data + prediction from model trained on that subset
- Binary indicator vector Z shows which training points included
- Linear coefficients become attribution scores
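A datamodel-style sketch; `train_and_eval` (train a model on the given subset indices, return its output on the test point) is a hypothetical stand-in for the expensive retraining step, and the sparse (lasso) regression follows the datamodels recipe:

```python
import numpy as np
from sklearn.linear_model import Lasso

def fit_datamodel(train_and_eval, n_train, n_subsets=500, subset_frac=0.5):
    # Z[s, j] = 1 iff training point j was included in subset s.
    Z = (np.random.rand(n_subsets, n_train) < subset_frac).astype(float)
    y = np.array([train_and_eval(np.flatnonzero(z)) for z in Z])
    g = Lasso(alpha=0.01).fit(Z, y)   # linear surrogate G: Z -> test output
    return g.coef_                    # psi_j per training point
```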
Perturbation-Based Component Attribution
- Causal mediation analysis:
- Replace components with dummy values, observe output changes
- Neuron Shapley applies game theory to capture component interactions
- Mask learning and subnetwork probing:
- Learnable continuous masks for component selection
- Optimize mask to recover original output while identifying important components
- Causal tracing (three-run patching):
- Clean input → model → output
- Perturbed input → model → baseline output
- Perturbed input + restored component K → model → recovered output
- Attribution score: difference between run 3 and run 2 (hook-based sketch after this list)
- Target perturbation:
- Control model behavior by optimizing component values
- Example: Change “capital of France” answer from Paris to London by modifying identified component
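A hook-based sketch of the three runs on a toy network (the `nn.Sequential` stands in for a transformer, and one layer's output stands in for component k):

```python
import torch
import torch.nn as nn

def capture(model, layer, x):
    """Run 1 helper: record the layer's activation on the clean input."""
    store = {}
    h = layer.register_forward_hook(lambda m, i, o: store.update(a=o.detach()))
    model(x)
    h.remove()
    return store["a"]

def run_patched(model, layer, x, a_clean):
    """Run 3: forward pass with the layer's output replaced by a_clean."""
    h = layer.register_forward_hook(lambda m, i, o: a_clean)
    out = model(x)
    h.remove()
    return out

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
layer = model[0]                                   # "component k"
clean, corrupt = torch.randn(4), torch.randn(4)

a_clean = capture(model, layer, clean)             # run 1 (clean input)
out_corrupt = model(corrupt)                       # run 2 (baseline)
out_patched = run_patched(model, layer, corrupt, a_clean)   # run 3
gamma_k = (out_patched - out_corrupt).item()       # attribution score
```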
Gradient-Based Component Attribution
- Approximates the three-run patching paradigm via a first-order Taylor expansion
- Efficiency advantage: Batches multiple inputs in single forward/backward pass
- Avoids expensive instance-wise patching step
- Attribution ≈ gradient of the output with respect to the component's activation (on the perturbed run) × the clean-minus-perturbed activation difference, as formalized below
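Concretely, expanding run 3 to first order around run 2 gives the approximation (with a_k the activation of component k and f the model output):

```latex
\gamma_k \;\approx\; \left(a_k^{\mathrm{clean}} - a_k^{\mathrm{corrupt}}\right)^{\!\top}
\frac{\partial f(x^{\mathrm{corrupt}})}{\partial a_k}
```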
Linear Approximation for Component Attribution
- Direct prediction: model components → test output
- Linear function G locally approximates component influence
- Counterfactual data collection from different component subsets
- Coefficients become component attribution scores
Unified Framework
- Three attribution problems solved by three method categories:
- Perturbations (direct, game-theoretic, mask learning)
- Gradients (similarity, influence functions, Taylor approximation)
- Linear approximations (LIME, datamodel, component linear models)
- Additional methods exist beyond these three categories
- Mechanistic interpretability includes: sparse autoencoders, logit lens, linear probing
XAI - Deep Dive 2
Inherent Interpretability Framework
- Goal: Design interpretable yet performant language models at scale
- Alternative to post-hoc explanation methods (gradients, probes, influence functions)
- Core approach: Add interpretability constraints during training pipeline
- Data constraints: Reprocess datasets for human understanding
- Architecture constraints: Modify transformer layers for traceability
- Representation constraints: Force interpretable concept encoding
- Training constraints: Add interpretability losses to standard task loss
Concept-Constrained Interpretable Models
- Architecture modification: Replace transformer layer with interpretable transformation
- Maps representations to basis of known concepts (blue neurons) vs unknown (black neurons)
- Example: Paper review system with clarity, novelty, significance concepts
- Loss function structure: L_total = L_task + λ L_interp (schematic training step after this list)
- Task loss: Standard next token prediction
- Interpretability loss: Forces specific neurons to represent target concepts
- Scaling results: Achieves comparable performance to GPT models at billions of parameters
- Tested on 33K supervised concepts, 160K unsupervised concepts
- Trained on billions of tokens with minimal performance degradation
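A schematic training step under these constraints; the model API returning hidden states plus logits, and the alignment of the first hidden units with supervised concept labels, are illustrative assumptions rather than the tutorial's exact design:

```python
import torch.nn.functional as F

def training_step(model, batch, lam=0.1):
    hidden, logits = model(batch["tokens"])      # hypothetical model API
    task_loss = F.cross_entropy(logits, batch["next_tokens"])
    # Interpretability loss: designated hidden units must predict
    # known concept labels (e.g., clarity / novelty / significance).
    n_concepts = batch["concepts"].shape[1]
    interp_loss = F.binary_cross_entropy_with_logits(
        hidden[:, :n_concepts], batch["concepts"].float())
    return task_loss + lam * interp_loss         # L_total = L_task + λ L_interp
```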
Post-Hoc Method Limitations
- Feature attribution mismatch
- Gradient-based methods often contradict occlusion analysis
- Example: Amino acid sequence task where gradients highlight distractors, not task-relevant features
- Concept probing challenges
- Can identify features in activations but not causal relevance to output
- Spurious correlations between causal and irrelevant features mislead probes
- Training data attribution complexity
- Influence functions require Hessian computation (billion × billion matrices)
- Computational intractability and convexity assumptions limit practical application
Training Solutions for Interpretability
- Input masking during training
- Randomly mask inputs to force robustness
- Aligns gradient behavior with human-expected occlusion patterns
- Makes the learned function smoother, so gradients better track the model's true sensitivity (masking sketch after this list)
- Adversarial training
- Train on adversarial examples for off-manifold robustness
- Similar alignment benefits between gradients and perturbation analysis
- Architecture modifications
- Backpack Language Models: Rewrite transformer as generalized additive model
- Split token embeddings into multiple sense vectors (e.g., apple the fruit vs. Apple the company)
- Enables surgical intervention and concept toggling
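A minimal random input-masking sketch (PyTorch); the zero baseline and mask rate are illustrative choices:

```python
import torch

def random_mask(x, p=0.2):
    """Zero each input feature independently with probability p."""
    keep = (torch.rand_like(x) > p).float()
    return x * keep

# Inside the training loop: loss = criterion(model(random_mask(x)), y)
```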
Scaling and Performance Results
- Concept-constrained models scale to billions of parameters with <2% performance drop
- Training data attribution achievable in single forward pass (no Hessian computation)
- Prototype-based clustering losses enable direct tracing from outputs to training data
- Maintains competitive performance on standard LM benchmarks while providing interpretability