
Thursday, March 13, 2025

Petar Veličković's review talk on GNNs

If you're new to graph neural networks, Petar Veličković's review talk [1], given at KHIPU, is a really nice introductory lecture. I watched it and decided to take some notes to go along with my other GNN blogs. Here is the video if you want to watch it.

🖋️ Style

The format of this post will look more like notes. I think I'll do this more often since I take most of my notes in markdown anyway, and I use this blog mainly for tracking my thoughts and learning. The notes follow the sections of the talk pretty closely, so I've titled them the same way.

Why Graphs?

  • Key building blocks of a graph are nodes and edges: $u \in \mathcal{V}$ and $(u,v) \in \mathcal{E}$. Together they form a graph: $\mathcal{G} = (\mathcal{V}, \mathcal{E})$. (see Figure 1)
  • Nodes can be: documents, roads, atoms, etc.
  • Edges can be: document links, road intersections, chemical bonds, etc.
  • What can we do with graphs?
    • Classify nodes: what type of document, what type of atom, etc.
    • Node regression: predict traffic on a particular road, etc.
    • Graph-level classification/regression: label specific molecules (i.e., graphs) or predict their properties¹.
    • There has been considerable success in molecule/drug discovery using GNNs.
    • Google DeepMind's GNoME effort is one example.
    • Link prediction: predict which edges should exist in a graph, e.g., in social networks, edges connect users who purchased the same product.
    • Arguably, most data has a suitable graph representation. Think of an image where, instead of convolving over a fixed neighborhood of pixels, the neighborhood structure is supplied by (and learned over) the graph.
    • Graph structures are generalizable because you can map projected spaces to graph structures much more easily, without additional constraints.
    • Agentic graphs: nodes are memories and edges are the connections/flow of information between memories.

Figure 1. Structure of a graph.
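
To make the node/edge bookkeeping above concrete, here is a minimal sketch (my own, not from the talk) that represents a water molecule as $\mathcal{G} = (\mathcal{V}, \mathcal{E})$; the atom labels and indices are illustrative assumptions.

```python
# A water molecule as a graph G = (V, E): nodes are atoms, edges are bonds.
nodes = {0: "O", 1: "H", 2: "H"}   # V: node id -> atom type
edges = [(0, 1), (0, 2)]           # E: undirected bonds as (u, v) pairs

# Neighborhood of node u: N_u = {v : (u, v) in E}
def neighbors(u, edges):
    return {v for a, b in edges for v in (a, b) if u in (a, b) and v != u}

print(neighbors(0, edges))  # {1, 2}: the oxygen is bonded to both hydrogens
```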

What's in a graph neural network?

  • Graphs make it easy to depict data in an abstract, sparse form.
  • This is not machine-learning friendly/ready: we need a tensorial representation.
  • Construct a node feature matrix, $\mathbf{X} \in \mathbb{R}^{|\mathcal{V}| \times k}$, and a connectivity matrix (adjacency matrix), $\mathbf{A} \in \{0,1\}^{|\mathcal{V}| \times |\mathcal{V}|}$². (see Figure 2 and the sketch after it)

Figure 2. Graph as tensors
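
As a quick illustration of the tensorial representation, here is a hedged sketch (again mine, not the talk's) that turns the same toy 3-node graph into a node feature matrix $\mathbf{X}$ and a binary adjacency matrix $\mathbf{A}$; the feature values are arbitrary placeholders.

```python
import numpy as np

# Node feature matrix X: one row per node, k = 2 placeholder features.
X = np.array([[8.0, 2.0],   # node 0
              [1.0, 1.0],   # node 1
              [1.0, 1.0]])  # node 2  -> shape (|V|, k)

# Adjacency matrix A: A[u, v] = 1 iff (u, v) is an edge (undirected -> symmetric).
A = np.zeros((3, 3))
for u, v in [(0, 1), (0, 2)]:
    A[u, v] = A[v, u] = 1.0

print(A.sum(axis=1))  # node degrees d_i = [2. 1. 1.]
```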

  • The choice of the ordering of rows in $\mathbf{X}$ is arbitrary, and thus we have imposed an ordering on the graph which we do not want the neural network to depend on or learn.
  • If we shuffle the rows of $\mathbf{X}$, and correspondingly the rows and columns of $\mathbf{A}$, then the output of the NN should be the same.
    • This is permutation invariance: $f(\mathbf{X}, \mathbf{A}) = f(\mathbf{P}\mathbf{X}, \mathbf{P}\mathbf{A}\mathbf{P}^\top) = f(\mathbf{X}', \mathbf{A}') = \mathbf{y}$
    • This is a very important property for GNNs when classifying or regressing graph-level properties (e.g., predicting the energy of a molecule represented as a graph).
  • For node/edge-level prediction tasks we don't necessarily expect permutation invariance, but rather equivariance.
  • Equivariance relates well-defined transformations of the inputs to transformations of the outputs, i.e., it's predictable: $f(\mathbf{P}\mathbf{X}, \mathbf{P}\mathbf{A}\mathbf{P}^\top) = \mathbf{P}f(\mathbf{X}, \mathbf{A})$. (see Figure 3)
    • What's important is that the permutation commutes with the function, i.e., it can be applied to the input or to the output with the same result.
  • On a graph, equivariance can be handled by applying a function locally that acts on each node's features.
  • $f(\mathbf{x}_i, \mathbf{X}_{\mathcal{N}_i})$ is an invariant NN and thus ensures the output is invariant, as long as $f$ doesn't depend on the ordering of the node neighbors $\mathcal{N}_i$. (see Figure 4; an important property of GNNs)
  • GNN: we take a single node, its features, and its neighbors, apply a function to them, and get the latent-space features of the node. (see Figure 5; a small numerical check of these properties follows after the figure captions below)
    • Node classification/regression: equivariant³
    • Graph classification/regression: invariant
    • Link prediction: equivariant

Figure 3. Equivariance in graphs

Figure 4. Local invariant functions applied to nodes

Figure 5. General blueprint for learning on graphs
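
Below is a small numerical check of the invariance/equivariance claims above. It is a sketch under my own assumptions: the "GNN layer" is just a neighbor average with no learned weights, and the graph is a toy 4-node example.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k = 4, 3
X = rng.normal(size=(n, k))
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 1],
              [1, 0, 0, 1],
              [0, 1, 1, 0]], dtype=float)

def f(X, A):
    """Toy node-level layer: average each node's neighbors."""
    d = A.sum(axis=1, keepdims=True)
    return (A @ X) / d

P = np.eye(n)[[2, 0, 3, 1]]  # a permutation matrix

# Equivariance of the node-level output: f(PX, PAP^T) == P f(X, A)
print(np.allclose(f(P @ X, P @ A @ P.T), P @ f(X, A)))            # True

# Invariance of a graph-level readout: sum-pooling the node outputs ignores ordering.
print(np.allclose(f(P @ X, P @ A @ P.T).sum(0), f(X, A).sum(0)))  # True
```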

  • Geometric Deep Learning is about identifying which symmetry transformations you want your ML model to be invariant or equivariant to, i.e., which transformations we constrain the model to respect. (see Figure 6)

Figure 6. Geometric Deep Learning

How to implement a graph neural network?

  • How do we build local functions? How do we process a node to get latent space features?

Graph Convolutional Networks

  • Simple: let each node update as the average of its neighbors, $\mathbf{h}_i = \sigma\left(\frac{1}{d_i} \sum_{j\in\mathcal{N}_i} \mathbf{W}\mathbf{x}_j\right)$. (a small sketch follows after this list)
  • $d_i$ is the number of neighbors (degree) of node $i$ and $\sigma$ is an activation function.
  • This is a kind of convolution over the local graph structure.
  • Symmetric normalization (i.e., the geometric mean of the degrees), $\frac{1}{\sqrt{d_i d_j}}$, seems to perform better.
  • A lot of this came from spectral methods on a graph.
  • An issue with the simple convolutional GNN, which can also be used for images: you end up with radially symmetric kernel filters, because the average looks the same for all nodes (pixels) that have the same neighbors.
  • You cannot rely on the graph structure alone, because for a regular graph (e.g., an image grid where $d_i$ is the same for every node) every neighborhood looks the same.
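
Here is a hedged sketch of the convolutional update above, implementing both the $1/d_i$ average and the symmetric $1/\sqrt{d_i d_j}$ normalization with dense NumPy arrays; the weights, activation, and toy graph are assumptions of mine, not from the talk.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, out = 4, 3, 2
X = rng.normal(size=(n, k))
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 1],
              [1, 0, 0, 1],
              [0, 1, 1, 0]], dtype=float)
W = rng.normal(size=(k, out))
relu = lambda z: np.maximum(z, 0.0)   # sigma

d = A.sum(axis=1)                     # degrees d_i
H_mean = relu((A / d[:, None]) @ X @ W)             # h_i = relu( (1/d_i) sum_j W x_j )
D_inv_sqrt = np.diag(1.0 / np.sqrt(d))
H_sym = relu(D_inv_sqrt @ A @ D_inv_sqrt @ X @ W)   # symmetric 1/sqrt(d_i d_j) variant
print(H_mean.shape, H_sym.shape)      # (4, 2) (4, 2): one latent vector per node
```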

Graph Attention Networks

  • To circumvent the limitations of the convolutional GNN, use graph attention networks, where the features of the nodes tell us which neighbors are important. (see Figure 7 and the sketch after it)
    • $\mathbf{h}_i = \sigma\left(\sum_{j\in\mathcal{N}_i} \alpha\left(\mathbf{x}_i, \mathbf{x}_j\right) \mathbf{W}\mathbf{x}_j\right)$
    • $\alpha$ is the attention mechanism, $\mathbb{R}^{k} \times \mathbb{R}^{k} \rightarrow \mathbb{R}$
    • Post-2018 this is usually a learned function, e.g., an MLP.

Figure 7. Graph Attention
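
The following is a minimal sketch of a graph-attention update consistent with the formula above. The scoring function $\alpha$ is a made-up single linear layer on concatenated features with a softmax over each neighborhood (real GAT layers add more machinery, e.g., a LeakyReLU and multiple heads).

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, out = 4, 3, 2
X = rng.normal(size=(n, k))
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 1],
              [1, 0, 0, 1],
              [0, 1, 1, 0]], dtype=float)
W = rng.normal(size=(k, out))
a = rng.normal(size=(2 * k,))   # parameters of alpha: R^k x R^k -> R

# Raw scores alpha(x_i, x_j), masked so each node only attends over its neighbors N_i.
scores = np.array([[a @ np.concatenate([X[i], X[j]]) for j in range(n)] for i in range(n)])
scores = np.where(A > 0, scores, -np.inf)
alpha = np.exp(scores - scores.max(axis=1, keepdims=True))
alpha = alpha / alpha.sum(axis=1, keepdims=True)    # softmax over N_i

H = np.maximum(alpha @ X @ W, 0.0)          # h_i = ReLU( sum_j alpha_ij W x_j )
print(np.allclose(alpha.sum(axis=1), 1.0))  # attention weights sum to 1 per node
```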

Message Passing Networks

  • Graph attention is essentially a weighted sum, but possibly not general enough. What's next? Message passing NNs. (a sketch follows after Figure 8)
    • The receiving node now actively takes part in computing the data being sent to it via messages. (see Figure 8)
    • $\mathbf{h}_i = \phi\left(\mathbf{x}_i, \bigoplus_{j\in\mathcal{N}_i} \psi\left(\mathbf{x}_i, \mathbf{x}_j\right)\right)$
    • $\psi$ is the message function (kernel-like function?), $\mathbb{R}^{k} \times \mathbb{R}^{k} \rightarrow \mathbb{R}^m$. It tells us the new vector representation being sent along the edge.
    • $\bigoplus_{j\in\mathcal{N}_i}: \mathbb{R}^{m} \rightarrow \mathbb{R}^m$ is a permutation-invariant aggregation function (e.g., sum, max) that tells a node's "mailbox" how to "summarize" the messages.
    • $\phi$ is the update function, $\mathbb{R}^{k} \times \mathbb{R}^{m} \rightarrow \mathbb{R}^l$ (again a kernel-like function?)
    • $\phi$ and $\psi$ are typically shallow MLPs, e.g., $\psi(\mathbf{x}_i, \mathbf{x}_j) = \text{ReLU}(\mathbf{W}_1 \mathbf{x}_i + \mathbf{b}_1 + \mathbf{W}_2 \mathbf{x}_j + \mathbf{b}_2)$
  • Multiple communities independently came up with this idea of message passing.

Figure 8. Message Passing graph
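
Below is a hedged sketch of the message-passing update $\mathbf{h}_i = \phi\left(\mathbf{x}_i, \bigoplus_{j\in\mathcal{N}_i} \psi(\mathbf{x}_i, \mathbf{x}_j)\right)$, with $\psi$ and $\phi$ as tiny single-layer "MLPs" and sum as the aggregator; the edge list, dimensions, and parameters are all illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, m, l = 4, 3, 5, 2
X = rng.normal(size=(n, k))
edges = [(0, 1), (0, 2), (1, 3), (2, 3)]   # undirected edge list
nbrs = {i: [j for u, v in edges for j in (u, v) if i in (u, v) and j != i] for i in range(n)}

W1, W2 = rng.normal(size=(m, k)), rng.normal(size=(m, k))  # psi parameters
W3, W4 = rng.normal(size=(l, k)), rng.normal(size=(l, m))  # phi parameters
relu = lambda z: np.maximum(z, 0.0)

psi = lambda xi, xj: relu(W1 @ xi + W2 @ xj)   # message function: R^k x R^k -> R^m
agg = lambda msgs: np.sum(msgs, axis=0)        # permutation-invariant aggregator (sum)
phi = lambda xi, mi: relu(W3 @ xi + W4 @ mi)   # update function: R^k x R^m -> R^l

H = np.stack([phi(X[i], agg([psi(X[i], X[j]) for j in nbrs[i]])) for i in range(n)])
print(H.shape)  # (4, 2): one l-dimensional latent vector per node
```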

Diving into computational graphs

  • A graph NN modifies the node features but leaves the graph structure intact (i.e., the adjacency matrix is static).
  • One has to pay attention to how a graph is represented (i.e., connected) as it may be suboptimal.
  • In processes (think chemical reactions) the graph structure is dynamic and changes⁴.
  • A graph connected by a "barbell" is an example of a suboptimal representation (see Figure 9). That single bridging edge has to act as the postman to ensure messages are passed among all higher-order neighbors. It is often better to use a rewired (sub)graph to represent the data.

Figure 9. Two graphs bottlenecked

Rewire a graph for messaging

  • If we are being naive, there are two extreme ways to rewire:
    • Assume no good edges, so just self-messaging/update: $\mathbf{h}_i = \psi(\mathbf{x}_i)$ with neighbors $\mathcal{N}_i = \{i\}$ (called Deep Sets?)
    • Assume every node is connected to every other node, $\mathcal{N}_i = \mathcal{V}$: $\mathbf{h}_i = \phi\left(\mathbf{x}_i, \bigoplus_{j\in\mathcal{V}} \alpha(\mathbf{x}_i, \mathbf{x}_j)\psi(\mathbf{x}_i, \mathbf{x}_j)\right)$, hence graph attention message passing⁵.
    • Attention is inferring a "soft" adjacency matrix, i.e., a dense weighted adjacency matrix. This is ideal for GPU matrix/array operations.
  • Non-parametric rewiring of edges: done algorithmically, i.e., nothing is learned.
  • Parametric rewiring (latent graph inference): learn a new graph structure. Hard to do in practice.
  • Non-parametric rewiring options (a sketch of the first follows after this list):
    • Use diffusion on the graph to rewire an already good set of edges. Take the adjacency matrix and raise it to a power $k$ to get a measure of the diffusive distance between nodes.
    • Surgical adjustments: target bottlenecked regions of edges in the graph by assessing the curvature of the graph.
    • Precompute a template graph: start with no edges and iterate to build a really good (i.e., information-dense) set of graph edges. Takes $\sim\log{N}$ steps. Sparse graphs like this are called expander graphs⁶.
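
As a rough illustration of the diffusion-style rewiring in the first option, here is a sketch (my own toy example, not the talk's) that rewires a small "barbell" graph by thresholding powers of $(\mathbf{I} + \mathbf{A})$, which count walks of length up to $k$.

```python
import numpy as np

# Toy barbell: two triangles (0-1-2 and 3-4-5) joined by the single bridge edge (2, 3).
A = np.zeros((6, 6))
for u, v in [(0, 1), (1, 2), (0, 2), (2, 3), (3, 4), (4, 5), (3, 5)]:
    A[u, v] = A[v, u] = 1.0

k = 2
walks = np.linalg.matrix_power(A + np.eye(6), k)   # (I + A)^k: walks of length <= k
A_rewired = ((walks > 0) & ~np.eye(6, dtype=bool)).astype(float)

# Rewiring adds edges between nodes within k hops, easing the bottleneck at the bridge.
print(int(A.sum() // 2), "->", int(A_rewired.sum() // 2))  # 7 -> 11 edges
```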

LLM Stack

  • Since causal attention in LLMs is a directed graph, information propagation can be biased towards earlier tokens. (Does this affect BERT?)

Figure 10. Causal attention is a graph
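
Just to visualize the point above, here is a tiny sketch (mine) of the causal attention mask of a decoder-only model viewed as the adjacency matrix of a directed graph, where earlier tokens feed into many more neighborhoods than later ones.

```python
import numpy as np

T = 5                                # toy sequence length
A_causal = np.tril(np.ones((T, T)))  # A[i, j] = 1 iff token i may attend to token j (j <= i)
in_degree = A_causal.sum(axis=0)     # how many positions each token can send messages to
print(in_degree)                     # [5. 4. 3. 2. 1.]: earlier tokens reach more nodes
```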

Getting in on the action!

  • Rich ecosystem of libraries/frameworks.
Framework     Sub-libraries
PyTorch       PyTorch Geometric, Deep Graph Library
TensorFlow    Spektral, TensorFlow GNN
JAX           Jraph
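
As a starting point, here is a hedged example using one of these libraries: a single graph-convolution layer in PyTorch Geometric. The toy graph and feature sizes are made up; check the library documentation for the current API.

```python
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv

# Toy graph: 3 nodes, 8 random features each, with bidirectional edges 0-1 and 1-2.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)  # shape [2, num_edges]
x = torch.randn(3, 8)
data = Data(x=x, edge_index=edge_index)

conv = GCNConv(in_channels=8, out_channels=4)  # one GCN layer
h = conv(data.x, data.edge_index)              # latent node features, shape [3, 4]
print(h.shape)
```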

Footnotes


  1. Petar points out that in the 1990s there was already work in cheminformatics using graph representations of molecules to predict properties, and that much of this work precedes the modern activities in the ML/AI community.

    Figure 11. Chemistry-led GNNs
     

  2. The adjacency matrix here just encodes whether two nodes are connected or not. It assumes no features on the edges. If there were features on the edges, we would write $\mathbf{A} \in \mathbb{R}^{|\mathcal{V}| \times |\mathcal{V}| \times d}$ where $d$ is the feature dimension.

  3. At the node level, the classification/regression function is invariant because its output does not depend on the ordering of the neighbors; however, the entire graph output vector $\mathbf{y}$ is equivariant because it will depend on the ordering of the nodes. The node outputs being invariant is important to ensure the graph output simply commutes with the permutation matrix $\mathbf{P}$.

  4. If we consider graph neural network potentials, the graph is in essence dynamically determined, because these use radial cutoff values/functions to decide which nodes are connected at any given time. (a small sketch follows below)
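
A minimal sketch of the radial-cutoff construction this footnote refers to, assuming made-up atomic positions and a made-up cutoff radius: the edges are re-derived from the geometry, so moving the atoms rewires the graph.

```python
import numpy as np

positions = np.array([[0.0, 0.0, 0.0],   # made-up atomic coordinates
                      [0.9, 0.0, 0.0],
                      [3.0, 0.0, 0.0]])
r_cut = 1.5                              # made-up cutoff radius

# Pairwise distances; connect nodes closer than r_cut (excluding self-loops).
dists = np.linalg.norm(positions[:, None, :] - positions[None, :, :], axis=-1)
A = ((dists < r_cut) & (dists > 0)).astype(float)
print(A)  # only the first two atoms are bonded at this cutoff
```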

  5. The equation for updating node features with graph attention and message passing looks similar to the equation for the transformer model. In fact, transformers are fully connected graphs. So the question I ask myself is: can the tokens be thought of as nodes, with next-token prediction corresponding to the next node to connect/label in the graph? I think the answer is yes. And Petar says GPT uses directed edges (i.e., fully connected causal graphs). Also, attention is a graph-level operation, so equivariant behavior is still preserved despite text being a sequence (i.e., it's ordered by position).

  6. Expander graphs might be a good use case for incorporating long-range interactions in crystal graphs. The original graphs based on cutoffs could be used in addition to an expander graph, which finds optimal connections between nodes to efficiently propagate long-range interactions. Just an idea; of course graph attention would achieve the same thing, but it would quickly become intractable for large systems.

References

[1] P. Veličković, Geometric Deep Learning, in: KHIPU 2025: Latin American Meeting in Artificial Intelligence, KHIPU, Santiago, Chile, 2025. URL


