A Go library for computing SHAP (SHapley Additive exPlanations) values for ML model explainability.
SHAP-Go provides a Go-native implementation of SHAP value computation for explaining machine learning model predictions. It supports:
- ONNX models via ONNX Runtime bindings
- Custom models via a simple function interface
- Permutation SHAP with antithetic sampling for variance reduction
- Sampling SHAP using Monte Carlo estimation
- JSON-serializable explanations for audit/compliance
| Status | Explainer | Model Type | Notes |
|---|---|---|---|
| ✅ | TreeSHAP | Trees | Exact & fast (O(TLD²)) for XGBoost, LightGBM, CatBoost; 40-100x faster than permutation |
| ✅ | KernelSHAP | Any | Black-box, weighted linear regression, model-agnostic baseline |
| ✅ | LinearSHAP | Linear | Exact closed-form solution for linear/logistic regression |
| ✅ | DeepSHAP | Neural Nets | Combines DeepLIFT with Shapley values, efficient for deep networks |
| ✅ | GradientSHAP | Any | Expected gradients using numerical differentiation, works with any differentiable model |
| ✅ | SamplingSHAP | Any | Monte Carlo approximation, fast, good for quick estimates |
| ✅ | PermutationSHAP | Any | Black-box, antithetic sampling for variance reduction, guarantees local accuracy |
| ✅ | ExactSHAP | Any | Brute-force exact computation, O(2^n), only for small feature sets (≤15) |
| ✅ | PartitionSHAP | Structured | Hierarchical Owen values for feature groupings, respects domain structure |
| ✅ | AdditiveSHAP | GAMs | Exact SHAP for Generalized Additive Models, O(n×b) complexity |
- ✅ Implemented
- ⬜ Not yet implemented
| Use Case | Recommended Explainer |
|---|---|
| Tree-based models (XGBoost, LightGBM) | TreeSHAP ✅ |
| Linear/logistic regression | LinearSHAP ✅ |
| Any model, need guaranteed accuracy | PermutationSHAP ✅ |
| Any model, weighted regression baseline | KernelSHAP ✅ |
| Any model, quick estimates | SamplingSHAP ✅ |
| Small feature sets (≤15 features) | ExactSHAP ✅ |
| Deep learning models (ONNX) | DeepSHAP ✅ |
| Differentiable models, gradient-based | GradientSHAP ✅ |
| Grouped/structured features | PartitionSHAP ✅ |
| Generalized Additive Models (GAMs) | AdditiveSHAP ✅ |
```sh
go get github.com/plexusone/shap-go
```

```go
package main

import (
	"context"
	"fmt"

	"github.com/plexusone/shap-go/explainer"
	"github.com/plexusone/shap-go/explainer/permutation"
	"github.com/plexusone/shap-go/model"
)

func main() {
	// Define a simple model
	predict := func(ctx context.Context, input []float64) (float64, error) {
		return input[0] + 2*input[1], nil
	}
	m := model.NewFuncModel(predict, 2)

	// Background data for SHAP computation
	background := [][]float64{
		{0.0, 0.0},
	}

	// Create explainer
	exp, _ := permutation.New(m, background,
		explainer.WithNumSamples(100),
		explainer.WithFeatureNames([]string{"x1", "x2"}),
	)

	// Explain a prediction
	ctx := context.Background()
	explanation, _ := exp.Explain(ctx, []float64{1.0, 2.0})
	fmt.Printf("Prediction: %.2f\n", explanation.Prediction)
	fmt.Printf("Base Value: %.2f\n", explanation.BaseValue)
	for name, shap := range explanation.Values {
		fmt.Printf("SHAP(%s): %.4f\n", name, shap)
	}

	// Verify local accuracy
	result := explanation.Verify(1e-10)
	fmt.Printf("Local accuracy valid: %v\n", result.Valid)
}
```

TreeSHAP computes exact SHAP values in O(TLD²) time, where T = trees, L = leaves, and D = depth. This is 40-100x faster than permutation-based methods for typical tree ensembles.
```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/plexusone/shap-go/explainer/tree"
)

func main() {
	// Load XGBoost model (saved with model.save_model("model.json"))
	ensemble, err := tree.LoadXGBoostModel("model.json")
	if err != nil {
		log.Fatal(err)
	}

	// Create TreeSHAP explainer
	explainer, err := tree.New(ensemble)
	if err != nil {
		log.Fatal(err)
	}

	// Explain a prediction
	ctx := context.Background()
	instance := []float64{0.5, 0.3, 0.8}
	explanation, err := explainer.Explain(ctx, instance)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Printf("Prediction: %.4f\n", explanation.Prediction)
	fmt.Printf("Base Value: %.4f\n", explanation.BaseValue)
	for _, feat := range explanation.TopFeatures(10) {
		fmt.Printf("  %s: %.4f\n", feat.Name, feat.Value)
	}
}
```

```go
// Load LightGBM JSON model (saved with booster.dump_model())
ensemble, err := tree.LoadLightGBMModel("model.json")
if err != nil {
	log.Fatal(err)
}

// Or load the text format (saved with booster.save_model())
ensemble, err = tree.LoadLightGBMTextModel("model.txt")

explainer, err := tree.New(ensemble)
// ... same as XGBoost
```

XGBoost:
```python
import xgboost as xgb

model = xgb.Booster()
model.load_model("model.bin")
model.save_model("model.json")  # JSON format for Go
```

LightGBM:
```python
import lightgbm as lgb
import json

model = lgb.Booster(model_file="model.txt")
with open("model.json", "w") as f:
    json.dump(model.dump_model(), f)
```

```go
// Explain multiple instances in parallel
instances := [][]float64{
	{0.1, 0.2, 0.3},
	{0.4, 0.5, 0.6},
	{0.7, 0.8, 0.9},
}
explanations, err := explainer.ExplainBatch(ctx, instances)
```

TreeSHAP can compute pairwise feature interactions, revealing how features work together:
```go
// Compute SHAP interaction values
result, err := explainer.ExplainInteractions(ctx, instance)
if err != nil {
	log.Fatal(err)
}

// Interaction matrix properties:
// - Diagonal Interactions[i][i]: main effect of feature i
// - Off-diagonal Interactions[i][j]: interaction between features i and j
// - Symmetric: Interactions[i][j] == Interactions[j][i]
// - Rows sum to SHAP values: sum(Interactions[i][:]) == SHAP[i]

// Get the interaction between two features
interaction := result.GetInteraction(0, 1)

// Get a main effect (diagonal)
mainEffect := result.GetMainEffect(0)

// Get a derived SHAP value (row sum)
shapValue := result.GetSHAPValue(0)

// Get the top k strongest interactions
topK := result.TopInteractions(5)
for _, inter := range topK {
	fmt.Printf("%s <-> %s: %.4f\n", inter.Name1, inter.Name2, inter.Value)
}
```

Core types for SHAP explanations:
- `Explanation` - Contains prediction, base value, SHAP values, and metadata
- `Verify()` - Checks local accuracy (sum of SHAP values = prediction - base)
- `TopFeatures()` - Returns features sorted by absolute SHAP value
- JSON serialization with `ToJSON()` and `FromJSON()`
Model interface for SHAP computation:
- `Model` interface with `Predict()`, `PredictBatch()`, and `NumFeatures()`
- `FuncModel` - Wraps a prediction function as a Model
ONNX Runtime wrapper:
- `Session` - Wraps an ONNX Runtime session
- Supports batch predictions
- Requires the ONNX Runtime shared library
TreeSHAP for tree-based models:
- Exact SHAP values (not approximations)
- O(TLD²) complexity - 40-100x faster than permutation
- XGBoost JSON model support
- LightGBM JSON and text format support
- Parallel batch processing
- Interaction values for pairwise feature interactions
LinearSHAP for linear models:
- Exact closed-form solution: `SHAP[i] = coef[i] * (x[i] - E[X[i]])`
- O(d) complexity, where d is the number of features
- Support for linear regression and logistic regression
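The closed form above is simple enough to sketch directly; `linearSHAP` below is an illustrative helper (not part of the library's API), where `mean[i]` plays the role of the background expectation E[X[i]]:

```go
package main

import "fmt"

// linearSHAP applies the closed form SHAP[i] = coef[i] * (x[i] - E[X[i]])
// for a linear model f(x) = intercept + sum_i coef[i]*x[i], where mean[i]
// is the background expectation E[X[i]].
func linearSHAP(coef, x, mean []float64) []float64 {
	shap := make([]float64, len(coef))
	for i := range coef {
		shap[i] = coef[i] * (x[i] - mean[i])
	}
	return shap
}

func main() {
	// f(x) = x0 + 2*x1 explained at (1, 2) against a zero background.
	fmt.Println(linearSHAP([]float64{1, 2}, []float64{1, 2}, []float64{0, 0})) // [1 4]
}
```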
KernelSHAP for model-agnostic explanations:
- Model-agnostic black-box method
- Weighted linear regression on binary coalition masks
- SHAP kernel weights: `(d-1) / (C(d,k) * k * (d-k))`
- Validated against the Python SHAP library
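For reference, the kernel weight formula is easy to evaluate on its own; this standalone helper (illustrative, not the library's API) computes it for a coalition of size k out of d features:

```go
package main

import "fmt"

// binom returns the binomial coefficient C(n, k) as a float64.
func binom(n, k int) float64 {
	res := 1.0
	for i := 1; i <= k; i++ {
		res *= float64(n-i+1) / float64(i)
	}
	return res
}

// kernelWeight is the Shapley kernel weight (d-1) / (C(d,k) * k * (d-k))
// for a coalition of size k out of d features. It is undefined for k == 0
// or k == d; KernelSHAP treats those as infinite-weight constraints.
func kernelWeight(d, k int) float64 {
	return float64(d-1) / (binom(d, k) * float64(k) * float64(d-k))
}

func main() {
	// For d = 4 the weights are symmetric in k and favor the smallest
	// and largest coalitions.
	for k := 1; k < 4; k++ {
		fmt.Printf("k=%d: %.4f\n", k, kernelWeight(4, k)) // 0.2500, 0.1250, 0.2500
	}
}
```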
ExactSHAP for brute-force exact Shapley values:
- Mathematically exact values by enumerating all 2^n coalitions
- O(n * 2^n) complexity - only practical for ≤15 features
- Useful for validating other SHAP implementations
- Reference implementation for small feature sets
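The brute-force enumeration is compact enough to sketch; `exactShapley` below is an illustrative implementation over coalition bitmasks (not the library's API):

```go
package main

import (
	"fmt"
	"math/bits"
)

// exactShapley enumerates all 2^n coalitions (as bitmasks) and applies the
// Shapley weight |S|! * (n-|S|-1)! / n! to each marginal contribution.
// v(mask) evaluates the model with masked-in features taken from the
// instance and the remaining features taken from the background.
func exactShapley(n int, v func(mask uint) float64) []float64 {
	fact := make([]float64, n+1)
	fact[0] = 1
	for i := 1; i <= n; i++ {
		fact[i] = fact[i-1] * float64(i)
	}
	shap := make([]float64, n)
	for mask := uint(0); mask < 1<<n; mask++ {
		s := bits.OnesCount(mask)
		if s == n {
			continue // full coalition: no feature left to add
		}
		w := fact[s] * fact[n-s-1] / fact[n]
		for i := 0; i < n; i++ {
			if mask&(1<<i) == 0 {
				shap[i] += w * (v(mask|1<<i) - v(mask))
			}
		}
	}
	return shap
}

func main() {
	instance := []float64{1.0, 2.0}
	background := []float64{0.0, 0.0}
	// v mixes instance and background according to the coalition mask,
	// then evaluates f(x) = x0 + 2*x1.
	v := func(mask uint) float64 {
		x := make([]float64, len(instance))
		for i := range x {
			if mask&(1<<uint(i)) != 0 {
				x[i] = instance[i]
			} else {
				x[i] = background[i]
			}
		}
		return x[0] + 2*x[1]
	}
	fmt.Println(exactShapley(2, v)) // [1 4]
}
```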
DeepSHAP for neural network explanations:
- Combines DeepLIFT with Shapley values for efficient neural network attribution
- Works with ONNX models via the `model/onnx` `ActivationSession`
- Efficient backward propagation using the DeepLIFT rescale rule
- Supports Dense, ReLU, Sigmoid, Tanh, and Softmax layers
Permutation SHAP with antithetic sampling:
- Guarantees local accuracy
- Supports parallel computation
- Lower variance than pure Monte Carlo
Monte Carlo sampling SHAP:
- Simple implementation
- Good for quick estimates
Feature masking strategies:
- `IndependentMasker` - Marginal/independent masking using background samples
Background dataset management:
- Dataset loading and statistics
- Random sampling and k-means summarization
The permutation explainer uses antithetic sampling for variance reduction:
- For each permutation sample:
  - Forward pass: start with the background, add features one by one
  - Reverse pass: start with the instance, remove features one by one
  - Average the contributions from both passes
- Average over all permutation samples
This guarantees that SHAP values sum exactly to (prediction - base value).
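The forward/reverse scheme for a single permutation can be sketched as follows (illustrative, with hypothetical names; the real explainer averages many such samples and uses background datasets rather than a single point):

```go
package main

import "fmt"

// antitheticPass estimates SHAP values from one permutation by walking it
// twice: forward (background -> instance, adding features) and reverse
// (instance -> background, removing features), averaging the two passes.
func antitheticPass(f func([]float64) float64, x, bg []float64, perm []int) []float64 {
	shap := make([]float64, len(x))

	// Forward pass: credit each feature with the output change when added.
	cur := append([]float64(nil), bg...)
	prev := f(cur)
	for _, i := range perm {
		cur[i] = x[i]
		next := f(cur)
		shap[i] += (next - prev) / 2
		prev = next
	}

	// Reverse pass: credit each feature with the output drop when removed.
	cur = append([]float64(nil), x...)
	prev = f(cur)
	for _, i := range perm {
		cur[i] = bg[i]
		next := f(cur)
		shap[i] += (prev - next) / 2
		prev = next
	}
	return shap
}

func main() {
	// For a linear model any single permutation already gives exact values.
	f := func(x []float64) float64 { return x[0] + 2*x[1] }
	fmt.Println(antitheticPass(f, []float64{1, 2}, []float64{0, 0}, []int{1, 0})) // [1 4]
}
```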
The sampling explainer uses simple Monte Carlo:
- Generate random permutations
- For each permutation, compute marginal contributions
- Average over all samples
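Those three steps can be sketched as a minimal Monte Carlo loop (illustrative helper, not the library's API):

```go
package main

import (
	"fmt"
	"math/rand"
)

// samplingSHAP averages marginal contributions over random feature
// orderings: each permutation walks from the background to the instance,
// crediting every feature with the output change it causes when added.
func samplingSHAP(f func([]float64) float64, x, bg []float64, samples int, seed int64) []float64 {
	rng := rand.New(rand.NewSource(seed))
	n := len(x)
	shap := make([]float64, n)
	for s := 0; s < samples; s++ {
		cur := append([]float64(nil), bg...)
		prev := f(cur)
		for _, i := range rng.Perm(n) {
			cur[i] = x[i]
			next := f(cur)
			shap[i] += next - prev
			prev = next
		}
	}
	for i := range shap {
		shap[i] /= float64(samples)
	}
	return shap
}

func main() {
	// A linear model is recovered exactly by any permutation average.
	f := func(x []float64) float64 { return x[0] + 2*x[1] }
	fmt.Println(samplingSHAP(f, []float64{1, 2}, []float64{0, 0}, 100, 42)) // [1 4]
}
```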
```go
exp, err := permutation.New(model, background,
	explainer.WithNumSamples(100),     // Number of permutation samples
	explainer.WithSeed(42),            // Random seed for reproducibility
	explainer.WithNumWorkers(4),       // Parallel workers
	explainer.WithFeatureNames(names), // Feature names
	explainer.WithModelID("my-model"), // Model identifier
)
```

```go
import "github.com/plexusone/shap-go/model/onnx"

// Initialize ONNX Runtime
onnx.InitializeRuntime("/path/to/libonnxruntime.so")
defer onnx.DestroyRuntime()

// Create session
session, err := onnx.NewSession(onnx.Config{
	ModelPath:   "model.onnx",
	InputName:   "float_input",
	OutputName:  "probabilities",
	NumFeatures: 10,
})
defer session.Close()

// Use with an explainer
exp, err := permutation.New(session, background)
```

Every SHAP explanation should satisfy local accuracy:
`sum(SHAP values) = prediction - base_value`
You can verify this with:
```go
result := explanation.Verify(tolerance)
if !result.Valid {
	fmt.Printf("Local accuracy failed: difference = %f\n", result.Difference)
}
```

Performance benchmarks on Apple M1 Max (arm64):
| Configuration | Time/op | Allocs/op |
|---|---|---|
| 10 trees, depth 4, 10 features | 20μs | 372 |
| 100 trees, depth 4, 10 features | 194μs | 3,612 |
| 1000 trees, depth 4, 10 features | 1.9ms | 36,012 |
| Tree Depth | Time/op | Notes |
|---|---|---|
| Depth 3 | 39μs | Shallow trees |
| Depth 6 | 598μs | Typical production depth |
| Depth 10 | 13.2ms | Very deep trees |
| Method | Time/op | Type |
|---|---|---|
| TreeSHAP | 8.8μs | Exact |
| PermutationSHAP (10 samples) | 16μs | Approximate |
| PermutationSHAP (50 samples) | 77μs | Approximate |
| PermutationSHAP (100 samples) | 153μs | Approximate |
TreeSHAP is ~17x faster than PermutationSHAP with 100 samples while providing exact values.
| Model Size | Trees | Depth | Features | Time/op |
|---|---|---|---|---|
| Small | 50 | 4 | 10 | 106μs |
| Medium | 200 | 6 | 30 | 2.7ms |
| Large | 500 | 8 | 50 | 31.7ms |
| Workers | 100 instances | Speedup |
|---|---|---|
| 1 (sequential) | 10.2ms | 1.0x |
| 4 (parallel) | 8.0ms | 1.3x |
| 8 (parallel) | 8.1ms | 1.3x |
Run benchmarks with:
```sh
go test -bench=. -benchmem ./explainer/tree/...
```

The `examples/` directory contains working examples:
| Example | Description |
|---|---|
| `examples/linear` | PermutationSHAP with a simple linear model |
| `examples/linearshap` | LinearSHAP for linear/logistic regression |
| `examples/treeshap` | TreeSHAP with manually constructed tree ensembles |
| `examples/kernelshap` | KernelSHAP weighted linear regression explainer |
| `examples/sampling` | SamplingSHAP Monte Carlo approximation |
| `examples/onnx_basic` | ONNX model with KernelSHAP explanations |
| `examples/deepshap` | DeepSHAP for neural network explanations |
| `examples/batch` | Batch processing with parallel workers |
| `examples/visualization` | Generating chart data for visualizations |
| `examples/markdown_report` | Generating Markdown reports with SHAP explanations |
Run an example:
```sh
go run ./examples/linear
go run ./examples/linearshap
go run ./examples/treeshap
go run ./examples/kernelshap
go run ./examples/sampling
go run ./examples/batch
go run ./examples/visualization
go run ./examples/markdown_report

# ONNX examples (require ONNX Runtime and model files)
cd examples/onnx_basic && python generate_model.py && go run main.go
cd examples/deepshap && python generate_model.py && go run main.go
```

MIT License