diff --git a/docs/Project.toml b/docs/Project.toml index 241acbf2..b8526875 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" GraphPPL = "b3f8163a-e979-4e85-b43e-1f63d8c8b42c" GraphPlot = "a2cc645c-3eea-5389-862e-a155d0052231" +GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0" [compat] Documenter = "1.0" diff --git a/docs/make.jl b/docs/make.jl index 20cab497..5899281c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -6,7 +6,7 @@ makedocs( modules = [GraphPPL], clean = true, sitename = "GraphPPL.jl", - pages = ["Home" => "index.md", "Getting Started" => "getting_started.md", "Syntax Guide" => "syntax_guide.md", "Nested Models" => "nested_models.md", "Plugins" => ["Overview" => "plugins/overview.md", "Variational Inference & Constraints" => "plugins/constraint_specification.md", "Attaching metadata to nodes" => "plugins/meta_specification.md", "Tracking creation of nodes" => "plugins/created_by.md", "Setting tag of nodes" => "plugins/node_tag.md", "Setting ID of nodes" => "plugins/node_id.md"], "Migration Guide (from v3 to v4)" => "migration_3_to_4.md", "Developers Guide" => "developers_guide.md", "Custom backend" => "custom_backend.md"], + pages = ["Home" => "index.md", "Getting Started" => "getting_started.md", "Syntax Guide" => "syntax_guide.md", "Nested Models" => "nested_models.md", "Visualization" => "visualization.md", "Plugins" => ["Overview" => "plugins/overview.md", "Variational Inference & Constraints" => "plugins/constraint_specification.md", "Attaching metadata to nodes" => "plugins/meta_specification.md", "Tracking creation of nodes" => "plugins/created_by.md", "Setting tag of nodes" => "plugins/node_tag.md", "Setting ID of nodes" => "plugins/node_id.md"], "Migration Guide (from v3 to v4)" => "migration_3_to_4.md", "Developers Guide" => "developers_guide.md", "Custom backend" => "custom_backend.md"], format = Documenter.HTML(prettyurls = get(ENV, "CI", nothing) == "true"), warnonly = false ) diff --git a/docs/src/getting_started.md b/docs/src/getting_started.md index 52567925..7981acaf 100644 --- a/docs/src/getting_started.md +++ b/docs/src/getting_started.md @@ -135,7 +135,7 @@ data_for_x = [ 1.0, 0.0, 0.0, 1.0 ] model = GraphPPL.create_model(coin_toss()) do model, context return (; - # This expression creates data handle for `x` in the model using the `xdata` as the underlying collection + # This expression creates data handle for `x` in the model using the `data_for_x` as the underlying collection x = GraphPPL.datalabel(model, context, GraphPPL.NodeCreationOptions(kind = GraphPPL.VariableKindData), :x, data_for_x) ) end diff --git a/docs/src/index.md b/docs/src/index.md index 5e7dab18..d990b9f1 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -11,12 +11,11 @@ For inference, you may need a `GraphPPL.jl` compatible package, for example [`Rx Pages = [ "getting_started.md", "nested_models.md", - "constraint_specification.md", - "plugins.md", - "migration.md", + "plugins/constraint_specification.md", + "plugins/overview.md", + "migration_3_to_4.md", "developers_guide.md", - "custom_backend.md", - "reference.md" + "custom_backend.md" ] Depth = 2 ``` diff --git a/docs/src/nested_models.md b/docs/src/nested_models.md index 430deb17..5794eb9e 100644 --- a/docs/src/nested_models.md +++ b/docs/src/nested_models.md @@ -19,12 +19,12 @@ Here, we see that the `κ, ω, z, x` and `y` variables define the boundary of th If we want to chain these `gcv` submodels together into a Hierarchical Gaussian Filter, we still use the `~` operator. Here, in the arguments to `gcv`, we specify all-but-one interface. `GraphPPL` will interpolate which interface is missing and assign it to the left-hand-side: ``` @example nested-models -@model function hgf(κ, ω, θ, prior_x, depth) +@model function hgf(κ, ω, z, prior_x, depth) for i = 1:depth - if i == 0 - means[i] ~ gcv(κ = κ, ω = ω, θ = θ, x = prior_x) + if i == 1 + means[i] ~ gcv(κ = κ, ω = ω, z = z, x = prior_x) else - means[i] ~ gcv(κ = κ, ω = ω, θ = θ, x = means[i - 1]) + means[i] ~ gcv(κ = κ, ω = ω, z = z, x = means[i - 1]) end end end diff --git a/docs/src/visualization.md b/docs/src/visualization.md new file mode 100644 index 00000000..149fe0db --- /dev/null +++ b/docs/src/visualization.md @@ -0,0 +1,131 @@ +# [Visualization](@id visualization) + +`GraphPPL.jl` ships with two optional visualization extensions that let you inspect the factor graph of a model. Both are loaded automatically through Julia's package extension mechanism — no explicit `using GraphPPL.Ext...` call is needed. Simply load the relevant packages alongside `GraphPPL`. + +## GraphViz extension + +The `GraphPPLGraphVizExt` extension is activated when `GraphViz.jl` is loaded alongside `GraphPPL`. It renders the model as a [DOT](https://graphviz.org/doc/info/lang.html)-format graph using GraphViz's layout engines, producing high-quality SVG output that displays inline in notebooks and IDEs. + +### Basic usage + +```@example visualization-graphviz +using GraphPPL, GraphViz, Distributions +import GraphPPL: @model + +@model function coin_toss(x) + θ ~ Beta(1, 1) + x .~ Bernoulli(θ) +end + +model = GraphPPL.create_model(coin_toss()) do model, context + return (; + x = GraphPPL.datalabel(model, context, GraphPPL.NodeCreationOptions(kind = GraphPPL.VariableKindData), :x, [1.0, 0.0, 1.0]) + ) +end + +GraphViz.load(model; strategy = :simple) +``` + +The return value is a `GraphVizGraphWrapper`. It renders as SVG in any environment that supports it. The underlying objects are accessible via: +- `viz.graph` — the raw `GraphViz.Graph` object +- `viz.dot_string` — the generated DOT source string + +### Saving to a file + +To write the visualization to disk as an SVG file, pass a path to `save_to`: + +```julia +GraphViz.load(model; strategy = :simple, save_to = "model.svg") +``` + +### Traversal strategies + +The `strategy` keyword controls the order in which nodes and edges are written into the DOT source, which influences how the layout engine positions them. + +- **`:simple`** — iterates directly over all vertices and edges. Fast and sufficient for most models. +- **`:bfs`** — traverses the graph breadth-first starting from the first created node. Tends to produce more structured layouts for models with a natural sequential or hierarchical order. + +### Visual encoding + +The extension distinguishes node types visually: + +| Node type | Shape | Fill | Text | +|:------------- |:-------- |:----------------- |:----- | +| Factor node | square | blue (`#4A90D9`) | white | +| Variable node | circle | white | black | + +Variable labels are rendered depending on their kind: +- **Constants** — shown as their quoted value (e.g. `"1.0"`) +- **Indexed variables** — rendered with an HTML subscript (e.g. `x₁`) +- **Plain variables** — shown as their quoted name (e.g. `"x"`) + +Factor node labels use `GraphPPL.prettyname` on the node's properties. + +### Configuration options + +| Keyword | Type | Default | Description | +|:------------- |:--------------------- |:---------- |:--------------------------------------------------------- | +| `strategy` | `Symbol` | (required) | Traversal order: `:simple` or `:bfs` | +| `layout` | `String` | `"dot"` | GraphViz layout engine (`"dot"`, `"neato"`, `"fdp"`, …) | +| `font_size` | `Int` | `12` | Font size for node labels | +| `edge_length` | `Float64` | `1.0` | Visual length of edges (interpreted by the layout engine) | +| `overlap` | `Bool` | `false` | Whether nodes are allowed to overlap | +| `width` | `Float64` | `10.0` | Canvas width in inches | +| `height` | `Float64` | `10.0` | Canvas height in inches | +| `save_to` | `String` or `Nothing` | `nothing` | If set, writes the SVG to this file path | + +!!! tip + For dense or large models, try `layout = "fdp"` or `layout = "dot"` combined with `overlap = false` to reduce visual clutter. + +## GraphPlot extension + +The `GraphPPLPlottingExt` extension activates when both `GraphPlot` and `Cairo` are loaded. It is a lighter-weight alternative that renders the graph through GraphPlot and saves the result as a PNG. + +### Basic usage + +```@example visualization +using GraphPPL, GraphPlot, Cairo +import GraphPPL: @model +using Distributions + +@model function coin_toss(x) + θ ~ Beta(1, 1) + x .~ Bernoulli(θ) +end + +model = GraphPPL.create_model(coin_toss()) do model, context + return (; + x = GraphPPL.datalabel(model, context, GraphPPL.NodeCreationOptions(kind = GraphPPL.VariableKindData), :x, [1.0, 0.0, 1.0]) + ) +end + +GraphPlot.gplot(model) +``` + +The plot is saved to `tmp.png` in the current directory and the plot object is returned. + +### Local subgraph visualization + +For large models it is often more useful to visualize only the neighborhood around a specific node. Pass a `NodeLabel` (or a vector of `NodeLabel`s) and a `depth` to expand the local neighborhood by that many hops: + +```julia +# show all nodes within 2 hops of `my_node` +GraphPlot.gplot(model, my_node; depth = 2) +``` + +This extracts the induced subgraph over the expanded node set and plots only that portion of the factor graph. + +| Keyword | Default | Description | +|:----------- |:----------- |:---------------------------------------------- | +| `depth` | `1` | Number of hops to expand from the seed node(s) | +| `file_name` | `"tmp.png"` | Output PNG file path | + +!!! note + The GraphPlot extension does not distinguish factor nodes from variable nodes visually — all nodes are rendered as circles with their label as the name. Use the GraphViz extension for richer visual encoding. + +## Choosing an extension + +The two extensions serve different purposes: + +- Use the **GraphViz extension** when you want publication-quality SVG output, need control over the layout engine, or want nodes color- and shape-coded by type. +- Use the **GraphPlot extension** when you want a quick PNG render or need to zoom into a local neighborhood of the graph using the `depth` parameter. diff --git a/ext/GraphPPLGraphVizExt.jl b/ext/GraphPPLGraphVizExt.jl index c9141fe3..e1ddfe40 100644 --- a/ext/GraphPPLGraphVizExt.jl +++ b/ext/GraphPPLGraphVizExt.jl @@ -218,8 +218,9 @@ Returns a quoted display label for a factor node. - `String`: The factor node's pretty name enclosed in double quotes """ function get_displayed_label(properties::GraphPPL.FactorNodeProperties) - # Ensure that the result of prettyname is enclosed in quotes label = GraphPPL.prettyname(properties) + # Strip module prefix (e.g. "Distributions.Normal" -> "Normal") + label = last(split(label, ".")) return "\"" * label * "\"" end @@ -255,7 +256,7 @@ end Writes DOT notation for nodes in a graph using simple iteration. Iterates through vertices and writes DOT format for: -- Factor nodes: Light gray squares +- Factor nodes: Blue filled squares - Variable nodes: Circles # Arguments @@ -294,7 +295,7 @@ end Writes DOT syntax for nodes in a graph visualization using breadth-first search traversal. Traverses the graph starting from the first created node and writes DOT notation for each node: -- Factor nodes are drawn as light gray squares +- Factor nodes are drawn as blue filled squares - Variable nodes are drawn as circles # Arguments @@ -330,9 +331,9 @@ function add_nodes!(io_buffer::IOBuffer, model_graph::GraphPPL.Model, global_nam if isa(properties, GraphPPL.FactorNodeProperties) displayed_label = replace(displayed_label, "\"" => "", "#" => "") - write(io_buffer, " \"$(san_label)\" [shape=square, style=filled, fillcolor=lightgray, label=\"$(displayed_label)\"];\n") + write(io_buffer, " \"$(san_label)\" [shape=square, style=filled, fillcolor=\"#4A90D9\", fontcolor=white, penwidth=1.5, label=\"$(displayed_label)\"];\n") elseif isa(properties, GraphPPL.VariableNodeProperties) - write(io_buffer, " \"$(san_label)\" [shape=circle, label=$(displayed_label)];\n") + write(io_buffer, " \"$(san_label)\" [shape=circle, style=filled, fillcolor=white, penwidth=1.5, label=$(displayed_label)];\n") else error("Unknown node type for label $(san_label)") end @@ -501,7 +502,7 @@ Converts a GraphPPL.Model to a DOT string for visualization with GraphViz.jl. - `strategy::Symbol`: Graph traversal strategy (`:simple` or `:bfs`) - `font_size::Int=12`: Font size for node labels - `edge_length::Float64=1.0`: Visual length of edges -- `layout::String="neato"`: GraphViz layout engine ("dot", "neato", "fdp", etc) +- `layout::String="dot"`: GraphViz layout engine ("dot", "neato", "fdp", etc) - `overlap::Bool=false`: Whether to allow node overlap - `width::Float64=10.0`: Display width in inches - `height::Float64=10.0`: Display height in inches @@ -521,7 +522,7 @@ function GraphViz.load( strategy::Symbol, font_size::Int = 12, edge_length::Float64 = 1.0, - layout::String = "neato", + layout::String = "dot", overlap::Bool = false, width::Float64 = 10.0, height::Float64 = 10.0, @@ -537,9 +538,12 @@ function GraphViz.load( write(io_buffer, "graph G {\n") write(io_buffer, " layout=$(layout);\n") - write(io_buffer, " overlap =$(string(overlap));\n") # control if allowing node overlaps + write(io_buffer, " rankdir=LR;\n") + write(io_buffer, " splines=ortho;\n") + write(io_buffer, " overlap=$(string(overlap));\n") write(io_buffer, " size=\"$(width),$(height)!\";\n") - write(io_buffer, " node [shape=circle, fontsize=$(font_size)];\n") + write(io_buffer, " node [fontsize=$(font_size), fontname=Helvetica];\n") + write(io_buffer, " edge [color=\"#888888\", penwidth=1.2];\n") # Nodes add_nodes!(io_buffer, model_graph, global_namespace_dict, traversal_strategy) diff --git a/src/GraphPPL.jl b/src/GraphPPL.jl index c7b0b0c3..354ed989 100644 --- a/src/GraphPPL.jl +++ b/src/GraphPPL.jl @@ -23,7 +23,7 @@ include("backends/default.jl") end Note that the `@model` macro is not exported by default and the recommended way of using it is -in the combination with some inference backend. The `GraphPPL` package provides the `DefaultGraphPPLBackend` structure +in the combination with some inference backend. The `GraphPPL` package provides the [`DefaultBackend`](@ref) structure for plotting and test purposes, but some backends may specify different behaviour for different structures. For example, the interface names of a node `Normal` or its behaviour may (and should) depend on the specified backend.