
SHAP-Go


A Go library for computing SHAP (SHapley Additive exPlanations) values for ML model explainability.

Overview

SHAP-Go provides a Go-native implementation of SHAP value computation for explaining machine learning model predictions. It supports:

  • 📦 ONNX models via ONNX Runtime bindings
  • ⚙️ Custom models via a simple function interface
  • 🔀 Permutation SHAP with antithetic sampling for variance reduction
  • 🎲 Sampling SHAP using Monte Carlo estimation
  • 📋 JSON-serializable explanations for audit/compliance

Explainer Types

| Status | Explainer | Model Type | Notes |
| --- | --- | --- | --- |
| ✅ | TreeSHAP | Trees | Exact & fast (O(TLD²)) for XGBoost, LightGBM, CatBoost; 40-100x faster than permutation |
| ✅ | KernelSHAP | Any | Black-box, weighted linear regression, model-agnostic baseline |
| ✅ | LinearSHAP | Linear | Exact closed-form solution for linear/logistic regression |
| ✅ | DeepSHAP | Neural Nets | Combines DeepLIFT with Shapley values, efficient for deep networks |
| ✅ | GradientSHAP | Any | Expected gradients using numerical differentiation, works with any differentiable model |
| ✅ | SamplingSHAP | Any | Monte Carlo approximation, fast, good for quick estimates |
| ✅ | PermutationSHAP | Any | Black-box, antithetic sampling for variance reduction, guarantees local accuracy |
| ✅ | ExactSHAP | Any | Brute-force exact computation, O(2^n) - only for small feature sets (≤15) |
| ✅ | PartitionSHAP | Structured | Hierarchical Owen values for feature groupings, respects domain structure |
| ✅ | AdditiveSHAP | GAMs | Exact SHAP for Generalized Additive Models, O(n×b) complexity |

Legend

  • ✅ Implemented
  • ⬜ Not yet implemented

Choosing an Explainer

| Use Case | Recommended Explainer |
| --- | --- |
| Tree-based models (XGBoost, LightGBM) | TreeSHAP ✅ |
| Linear/logistic regression | LinearSHAP ✅ |
| Any model, need guaranteed accuracy | PermutationSHAP ✅ |
| Any model, weighted regression baseline | KernelSHAP ✅ |
| Any model, quick estimates | SamplingSHAP ✅ |
| Small feature sets (≤15 features) | ExactSHAP ✅ |
| Deep learning models (ONNX) | DeepSHAP ✅ |
| Differentiable models, gradient-based | GradientSHAP ✅ |
| Grouped/structured features | PartitionSHAP ✅ |
| Generalized Additive Models (GAMs) | AdditiveSHAP ✅ |

Installation

go get github.com/plexusone/shap-go

Quick Start

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/plexusone/shap-go/explainer"
    "github.com/plexusone/shap-go/explainer/permutation"
    "github.com/plexusone/shap-go/model"
)

func main() {
    // Define a simple model
    predict := func(ctx context.Context, input []float64) (float64, error) {
        return input[0] + 2*input[1], nil
    }
    m := model.NewFuncModel(predict, 2)

    // Background data for SHAP computation
    background := [][]float64{
        {0.0, 0.0},
    }

    // Create explainer
    exp, err := permutation.New(m, background,
        explainer.WithNumSamples(100),
        explainer.WithFeatureNames([]string{"x1", "x2"}),
    )
    if err != nil {
        log.Fatal(err)
    }

    // Explain a prediction
    ctx := context.Background()
    explanation, err := exp.Explain(ctx, []float64{1.0, 2.0})
    if err != nil {
        log.Fatal(err)
    }

    fmt.Printf("Prediction: %.2f\n", explanation.Prediction)
    fmt.Printf("Base Value: %.2f\n", explanation.BaseValue)
    for name, shap := range explanation.Values {
        fmt.Printf("SHAP(%s): %.4f\n", name, shap)
    }

    // Verify local accuracy
    result := explanation.Verify(1e-10)
    fmt.Printf("Local accuracy valid: %v\n", result.Valid)
}

TreeSHAP for XGBoost/LightGBM

TreeSHAP computes exact SHAP values in O(TLD²) time, where T=trees, L=leaves, D=depth. This is 40-100x faster than permutation-based methods for typical tree ensembles.

XGBoost Example

package main

import (
    "context"
    "fmt"
    "log"

    "github.com/plexusone/shap-go/explainer/tree"
)

func main() {
    // Load XGBoost model (saved with model.save_model("model.json"))
    ensemble, err := tree.LoadXGBoostModel("model.json")
    if err != nil {
        log.Fatal(err)
    }

    // Create TreeSHAP explainer
    explainer, err := tree.New(ensemble)
    if err != nil {
        log.Fatal(err)
    }

    // Explain a prediction
    ctx := context.Background()
    instance := []float64{0.5, 0.3, 0.8}
    explanation, err := explainer.Explain(ctx, instance)
    if err != nil {
        log.Fatal(err)
    }

    fmt.Printf("Prediction: %.4f\n", explanation.Prediction)
    fmt.Printf("Base Value: %.4f\n", explanation.BaseValue)
    for _, feat := range explanation.TopFeatures(10) {
        fmt.Printf("  %s: %.4f\n", feat.Name, feat.Value)
    }
}

LightGBM Example

// Load LightGBM JSON model (saved with booster.dump_model())
ensemble, err := tree.LoadLightGBMModel("model.json")
if err != nil {
    log.Fatal(err)
}

// Or load text format (saved with booster.save_model())
ensemble, err := tree.LoadLightGBMTextModel("model.txt")

explainer, err := tree.New(ensemble)
// ... same as XGBoost

Python: Export Models for Go

XGBoost:

import xgboost as xgb

model = xgb.Booster()
model.load_model("model.bin")
model.save_model("model.json")  # JSON format for Go

LightGBM:

import lightgbm as lgb
import json

model = lgb.Booster(model_file="model.txt")
with open("model.json", "w") as f:
    json.dump(model.dump_model(), f)

Batch Processing

// Explain multiple instances in parallel
instances := [][]float64{
    {0.1, 0.2, 0.3},
    {0.4, 0.5, 0.6},
    {0.7, 0.8, 0.9},
}
explanations, err := explainer.ExplainBatch(ctx, instances)

Interaction Values

TreeSHAP can compute pairwise feature interactions, revealing how features work together:

// Compute SHAP interaction values
result, err := explainer.ExplainInteractions(ctx, instance)
if err != nil {
    log.Fatal(err)
}

// Interaction matrix properties:
// - Diagonal Interactions[i][i]: main effect of feature i
// - Off-diagonal Interactions[i][j]: interaction between features i and j
// - Symmetric: Interactions[i][j] == Interactions[j][i]
// - Rows sum to SHAP values: sum(Interactions[i][:]) == SHAP[i]

// Get the interaction between two features
interaction := result.GetInteraction(0, 1)

// Get main effect (diagonal)
mainEffect := result.GetMainEffect(0)

// Get derived SHAP value (row sum)
shapValue := result.GetSHAPValue(0)

// Get top k strongest interactions
topK := result.TopInteractions(5)
for _, inter := range topK {
    fmt.Printf("%s <-> %s: %.4f\n", inter.Name1, inter.Name2, inter.Value)
}

Packages

explanation

Core types for SHAP explanations:

  • 📊 Explanation - Contains prediction, base value, SHAP values, and metadata
  • ✔️ Verify() - Checks local accuracy (sum of SHAP values = prediction - base)
  • 🔍 TopFeatures() - Returns features sorted by absolute SHAP value
  • 📄 JSON serialization with ToJSON() and FromJSON()

model

Model interface for SHAP computation:

  • 🔌 Model interface with Predict(), PredictBatch(), and NumFeatures()
  • 🛠️ FuncModel - Wraps a prediction function as a Model

model/onnx

ONNX Runtime wrapper:

  • 🔗 Session - Wraps an ONNX Runtime session
  • 📦 Supports batch predictions
  • 📚 Requires ONNX Runtime shared library

explainer/tree

TreeSHAP for tree-based models:

  • 🎯 Exact SHAP values (not approximations)
  • ⚡ O(TLD²) complexity - 40-100x faster than permutation
  • 🌲 XGBoost JSON model support
  • 💡 LightGBM JSON and text format support
  • 🔄 Parallel batch processing
  • 🔗 Interaction values for pairwise feature interactions

explainer/linear

LinearSHAP for linear models:

  • 🎯 Exact closed-form solution: SHAP[i] = coef[i] * (x[i] - E[X[i]])
  • ⚡ O(d) complexity where d is number of features
  • 📈 Support for linear regression and logistic regression
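The closed form above is simple enough to check by hand. A minimal standalone sketch (plain Go, deliberately independent of the library's API; `linearSHAP` is an illustrative name, not a package function):

```go
package main

import "fmt"

// linearSHAP computes exact SHAP values for a linear model
// f(x) = intercept + sum(coef[i]*x[i]): SHAP[i] = coef[i] * (x[i] - mean[i]),
// where mean[i] is the background expectation E[X[i]].
func linearSHAP(coef, x, mean []float64) []float64 {
	shap := make([]float64, len(coef))
	for i := range coef {
		shap[i] = coef[i] * (x[i] - mean[i])
	}
	return shap
}

func main() {
	coef := []float64{1.0, 2.0} // f(x) = x1 + 2*x2
	mean := []float64{0.0, 0.0} // background feature means
	shap := linearSHAP(coef, []float64{1.0, 2.0}, mean)
	fmt.Println(shap) // SHAP(x1)=1, SHAP(x2)=4; they sum to prediction - base = 5
}
```

Because the model is linear and features are treated as independent, these values are exact, with no sampling involved.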

explainer/kernel

KernelSHAP for model-agnostic explanations:

  • 🔮 Model-agnostic black-box method
  • ⚖️ Weighted linear regression on binary coalition masks
  • 🧮 SHAP kernel weights: (d-1) / (C(d,k) * k * (d-k))
  • ✅ Validated against Python SHAP library
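The kernel weight formula can be evaluated directly; a standalone sketch (illustrative helper names, not the package's API):

```go
package main

import "fmt"

// binom returns the binomial coefficient C(n, k) as a float64.
func binom(n, k int) float64 {
	b := 1.0
	for i := 0; i < k; i++ {
		b = b * float64(n-i) / float64(i+1)
	}
	return b
}

// kernelWeight is the SHAP kernel weight for a coalition of size k
// out of d features: (d-1) / (C(d,k) * k * (d-k)).
func kernelWeight(d, k int) float64 {
	return float64(d-1) / (binom(d, k) * float64(k) * float64(d-k))
}

func main() {
	// Weights peak at the smallest and largest coalitions, which is why
	// KernelSHAP samples those coalition sizes most heavily.
	d := 4
	for k := 1; k < d; k++ {
		fmt.Printf("k=%d weight=%.4f\n", k, kernelWeight(d, k))
	}
}
```

Note the weight is undefined for the empty and full coalitions (k=0 and k=d); those are handled as constraints rather than regression points.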

explainer/exact

ExactSHAP for brute-force exact Shapley values:

  • 🎯 Mathematically exact values by enumerating all 2^n coalitions
  • ⏱️ O(n * 2^n) complexity - only practical for ≤15 features
  • 🔍 Useful for validating other SHAP implementations
  • 📏 Reference implementation for small feature sets

explainer/deepshap

DeepSHAP for neural network explanations:

  • 🧠 Combines DeepLIFT with Shapley values for efficient neural network attribution
  • 🔗 Works with ONNX models via model/onnx ActivationSession
  • ⚡ Efficient backward propagation using DeepLIFT rescale rule
  • 📊 Supports Dense, ReLU, Sigmoid, Tanh, Softmax layers

explainer/permutation

Permutation SHAP with antithetic sampling:

  • ✅ Guarantees local accuracy
  • 🔄 Supports parallel computation
  • 📉 Lower variance than pure Monte Carlo

explainer/sampling

Monte Carlo sampling SHAP:

  • ๐Ÿ› ๏ธ Simple implementation
  • โšก Good for quick estimates

masker

Feature masking strategies:

  • 🎭 IndependentMasker - Marginal/independent masking using background samples
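The idea behind marginal masking is small enough to show inline. A standalone sketch of the technique (not the library's actual IndependentMasker implementation; `maskIndependent` is an illustrative name):

```go
package main

import "fmt"

// maskIndependent builds a model input for a coalition: features in the
// coalition keep the instance's values, masked-out features take values
// from a background row (marginal/independent masking).
func maskIndependent(x, bgRow []float64, inCoalition []bool) []float64 {
	out := make([]float64, len(x))
	for i := range x {
		if inCoalition[i] {
			out[i] = x[i] // keep the instance's value
		} else {
			out[i] = bgRow[i] // substitute the background value
		}
	}
	return out
}

func main() {
	x := []float64{1.0, 2.0, 3.0}
	bg := []float64{0.0, 0.0, 0.0}
	fmt.Println(maskIndependent(x, bg, []bool{true, false, true})) // [1 0 3]
}
```

Averaging the model's output over many background rows for the same coalition approximates the expectation over the masked features.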

background

Background dataset management:

  • 📂 Dataset loading and statistics
  • 🎲 Random sampling and k-means summarization

Algorithms

Permutation SHAP with Antithetic Sampling

The permutation explainer uses antithetic sampling for variance reduction:

  1. For each permutation sample:
     • ▶️ Forward pass: Start with background, add features one by one
     • ◀️ Reverse pass: Start with instance, remove features one by one
     • ⚖️ Average contributions from both passes
  2. Average over all permutation samples

This guarantees that SHAP values sum exactly to (prediction - base value).

Sampling SHAP

The sampling explainer uses simple Monte Carlo:

  1. 🔀 Generate random permutations
  2. 📊 For each permutation, compute marginal contributions
  3. ⚖️ Average over all samples

Configuration Options

exp, err := permutation.New(model, background,
    explainer.WithNumSamples(100),     // Number of permutation samples
    explainer.WithSeed(42),            // Random seed for reproducibility
    explainer.WithNumWorkers(4),       // Parallel workers
    explainer.WithFeatureNames(names), // Feature names
    explainer.WithModelID("my-model"), // Model identifier
)

ONNX Runtime Usage

import "github.com/plexusone/shap-go/model/onnx"

// Initialize ONNX Runtime
onnx.InitializeRuntime("/path/to/libonnxruntime.so")
defer onnx.DestroyRuntime()

// Create session
session, err := onnx.NewSession(onnx.Config{
    ModelPath:   "model.onnx",
    InputName:   "float_input",
    OutputName:  "probabilities",
    NumFeatures: 10,
})
defer session.Close()

// Use with explainer
exp, err := permutation.New(session, background)

Local Accuracy Verification

Every SHAP explanation should satisfy local accuracy:

sum(SHAP values) = prediction - base_value

You can verify this with:

result := explanation.Verify(tolerance)
if !result.Valid {
    fmt.Printf("Local accuracy failed: difference = %f\n", result.Difference)
}

Benchmarks

Performance benchmarks on Apple M1 Max (arm64):

TreeSHAP Scaling

| Configuration | Time/op | Allocs/op |
| --- | --- | --- |
| 10 trees, depth 4, 10 features | 20μs | 372 |
| 100 trees, depth 4, 10 features | 194μs | 3,612 |
| 1000 trees, depth 4, 10 features | 1.9ms | 36,012 |

| Tree Depth | Time/op | Notes |
| --- | --- | --- |
| Depth 3 | 39μs | Shallow trees |
| Depth 6 | 598μs | Typical production depth |
| Depth 10 | 13.2ms | Very deep trees |

TreeSHAP vs PermutationSHAP

| Method | Time/op | Type |
| --- | --- | --- |
| TreeSHAP | 8.8μs | Exact |
| PermutationSHAP (10 samples) | 16μs | Approximate |
| PermutationSHAP (50 samples) | 77μs | Approximate |
| PermutationSHAP (100 samples) | 153μs | Approximate |

TreeSHAP is ~17x faster than PermutationSHAP with 100 samples while providing exact values.

Realistic Model Sizes

| Model Size | Trees | Depth | Features | Time/op |
| --- | --- | --- | --- | --- |
| Small | 50 | 4 | 10 | 106μs |
| Medium | 200 | 6 | 30 | 2.7ms |
| Large | 500 | 8 | 50 | 31.7ms |

Batch Processing

| Workers | 100 instances | Speedup |
| --- | --- | --- |
| 1 (sequential) | 10.2ms | 1.0x |
| 4 (parallel) | 8.0ms | 1.3x |
| 8 (parallel) | 8.1ms | 1.3x |

Run benchmarks with:

go test -bench=. -benchmem ./explainer/tree/...

Examples

The examples/ directory contains working examples:

| Example | Description |
| --- | --- |
| examples/linear | PermutationSHAP with a simple linear model |
| examples/linearshap | LinearSHAP for linear/logistic regression |
| examples/treeshap | TreeSHAP with manually constructed tree ensembles |
| examples/kernelshap | KernelSHAP weighted linear regression explainer |
| examples/sampling | SamplingSHAP Monte Carlo approximation |
| examples/onnx_basic | ONNX model with KernelSHAP explanations |
| examples/deepshap | DeepSHAP for neural network explanations |
| examples/batch | Batch processing with parallel workers |
| examples/visualization | Generating chart data for visualizations |
| examples/markdown_report | Generate Markdown reports with SHAP explanations |

Run an example:

go run ./examples/linear
go run ./examples/linearshap
go run ./examples/treeshap
go run ./examples/kernelshap
go run ./examples/sampling
go run ./examples/batch
go run ./examples/visualization
go run ./examples/markdown_report

# ONNX examples (requires ONNX Runtime and model files)
cd examples/onnx_basic && python generate_model.py && go run main.go
cd examples/deepshap && python generate_model.py && go run main.go

License

MIT License
