I'm working on a Pluto.jl notebook blog post that implements the Crystal Graph Convolutional Neural Network (CGCNN) paper by Xie and Grossman [1]. My approach is to use FeaturedGraph and the GeometricFlux.jl package. The problem is that at times I have a really hard time wrapping my head around how a graph gets fed into a NN. My guess is that I'm struggling for two reasons:
- I'm not following the math closely enough.
- I don't have a good grasp of how a regular convolutional neural network is implemented.
As a result, I'll probably write two or three blog posts with my notes on this, so that I stay honest with myself (i.e., I need to be thorough since I'm posting). This post focuses only on the first reason, so it's basically a walk-through of the original paper by Xie and Grossman.
What is a Crystal Graph?
A crystal in 3D is best represented by a lattice and a basis (the group of atoms attached to each lattice point), which combined form the crystal structure; its symmetry is described by a space group. We can illustrate the operation:
Adapted from ref. [2] (https://tikz.janosh.dev/basis-plus-lattice).
However, this representation derives mostly from group theory applied to matter (e.g., atoms) and space. We could instead use a different approach to describe the same physical system: graph theory¹. A graph is defined as:
$$ \begin{equation} \label{eq:graph} G = (V,E) \end{equation} $$
with $V$ being the set of vertices (nodes) and $E$ the set of edges. In an undirected graph, each edge is an unordered pair of distinct vertices, $E \subseteq \lbrace \lbrace x,y \rbrace : x,y \in V,\ x \neq y \rbrace$; in a directed graph, each edge is instead an ordered pair $(x,y)$. To represent the full set of edges, we can use an adjacency matrix $A$: a binary matrix with $A_{xy}=1$ if nodes $x$ and $y$ form an edge and $0$ otherwise (it is symmetric for undirected graphs). Furthermore, we can weight the adjacency matrix, meaning the nonzero elements can take values other than one. A weight might indicate a chemical bond strength, for example.
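A minimal plain-Julia sketch of the adjacency matrix idea. The 4-node graph and its edge weights are made up for illustration: each undirected edge sets two symmetric entries, and the weighted version stores the weight instead of 1.

```julia
# Build binary and weighted adjacency matrices from an undirected edge list.
# The graph and the weights are hypothetical examples.
n = 4
edges = [(1, 2), (1, 3), (2, 3), (3, 4)]   # undirected edges {x, y}, x ≠ y
weights = [1.0, 0.5, 2.0, 1.5]            # hypothetical edge weights

A = zeros(Int, n, n)        # binary (unweighted) adjacency matrix
W = zeros(Float64, n, n)    # weighted adjacency matrix
for ((x, y), w) in zip(edges, weights)
    A[x, y] = A[y, x] = 1   # undirected: store both (x, y) and (y, x)
    W[x, y] = W[y, x] = w
end

println(A)
println("degree of each node: ", vec(sum(A, dims = 2)))
```

For a directed graph you would set only `A[x, y]`, so the matrix would no longer be symmetric.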
That's the basic structure of a graph. There is much more one can discuss, but I'm content with that understanding for now. We can now take a crystal and convert it to a graph: we define the nodes as the atoms and the edges as the neighbors or bonds around the atoms. The question becomes: how does one decide whether an atom is a neighbor of another atom? A common answer is to use a spherical neighborhood truncated at a cutoff radius, $N(i) = \lbrace j \neq i : ||\vec{r}_i - \vec{r}_j|| \leq r_{\text{cutoff}}\rbrace$, with $i$ and $j$ being atom indices (in a crystal, $\vec{r}_j$ also runs over periodic images of the atoms). Using $N(i)$, we can then iterate over unique pairs to define the edges of our graph, $G$. This means the sparsity of the adjacency matrix, i.e., how many neighbors each atom has, depends on the value of $r_{\text{cutoff}}$.
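To make the cutoff rule concrete, here is a small sketch that builds $N(i)$ and the unique-pair edge list from Cartesian positions. The coordinates and `r_cutoff` are hypothetical, and, unlike a real crystal graph, periodic images are ignored, so this describes a finite cluster rather than a crystal.

```julia
using LinearAlgebra  # for norm

# Neighbor lists N(i) = { j ≠ i : ‖r_i - r_j‖ ≤ r_cutoff }.
# Assumption: made-up Cartesian coordinates, no periodic images.
positions = [[0.0, 0.0, 0.0],
             [1.0, 0.0, 0.0],
             [0.0, 1.0, 0.0],
             [2.2, 0.0, 0.0]]
r_cutoff = 1.5

function neighbor_lists(positions, r_cutoff)
    n = length(positions)
    return [[j for j in 1:n if j != i && norm(positions[i] - positions[j]) <= r_cutoff]
            for i in 1:n]
end

# Unique pairs i < j give the undirected edges of the graph
edge_list(nbrs) = [(i, j) for i in eachindex(nbrs) for j in nbrs[i] if i < j]

nbrs = neighbor_lists(positions, r_cutoff)
edges = edge_list(nbrs)
println(nbrs)
println(edges)
```

Increasing `r_cutoff` adds more pairs, which is how the cutoff controls the sparsity of the adjacency matrix.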
There is one other thing about graphs in the context of machine learning: we can associate features with nodes and edges. For example, imagine that each node has a vector of properties corresponding to some quantity that is either prescribed or to be learned; we can do the same for the edges. This is very useful in chemistry and materials science applications because we can assign elemental properties as node features and bond characteristics as edge features. Then we can either learn the feature representations across the data set or just use the prescribed features. The former is the path usually taken.
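One simple way to hold a graph together with its features is sketched below: node features as a matrix with one column per node, and edge features as a matrix with one column per edge, ordered like the edge list. `CrystalGraph`, `num_nodes`, and `num_edges` are hypothetical names for this illustration; they are not the GeometricFlux.jl `FeaturedGraph` API, and the numbers are placeholders.

```julia
# Hypothetical container for a graph with node and edge features.
struct CrystalGraph
    edges::Vector{Tuple{Int,Int}}   # undirected edges (i, j), i < j
    node_feats::Matrix{Float64}     # n_node_features × n_nodes
    edge_feats::Matrix{Float64}     # n_edge_features × n_edges
    function CrystalGraph(edges, node_feats, edge_feats)
        size(edge_feats, 2) == length(edges) ||
            throw(ArgumentError("need one edge-feature column per edge"))
        new(edges, node_feats, edge_feats)
    end
end

num_nodes(g::CrystalGraph) = size(g.node_feats, 2)
num_edges(g::CrystalGraph) = length(g.edges)

# 3 nodes with 2 features each; 2 edges with 1 feature each (e.g., a bond length)
g = CrystalGraph([(1, 2), (2, 3)],
                 [0.1 0.4 0.7;
                  0.2 0.5 0.8],
                 [1.0 1.5])
println(num_nodes(g), " nodes, ", num_edges(g), " edges")
```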
There are different ways to visualize the graph representation of a crystal; the one by Xie and Grossman is pretty straightforward. Because symmetry makes the chemical environments of equivalent atoms identical, all that matters is that a graph of one repeating unit of the crystal is provided. For example, in the illustrative exercise² in the supplemental material of ref. [1], Xie and Grossman use 8 nodes for the NaCl structure, while for KCl they use 2.
Crystal graphs for NaCl (a) and KCl (b). (Figure S1 from ref. [1]; https://doi.org/10.1103/PhysRevLett.120.145301)
Convolution of a crystal graph
A convolution slides a kernel (filter) over the input and computes a weighted sum of each local neighborhood; combined with strides or pooling, it typically reduces the dimensionality of the data. Convolutional layers in a neural network learn the kernel weights, i.e., which characteristics "pop out" when data is passed through the layer. This has been particularly successful for images, where the input is a 2D array of pixels and each pixel may have a single channel or multiple channels. With a CNN, each layer learns characteristics of the image based on how the pixels and channels "transform".
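Here is a minimal sketch of the 2D convolution a CNN layer performs, with "valid" boundaries so a 5×5 input shrinks to 3×3. As in most deep-learning frameworks, the kernel is applied without flipping (strictly a cross-correlation). The image and the edge-detecting kernel are made-up examples; in a CNN the kernel values would be learned.

```julia
# "Valid" 2D convolution (cross-correlation, kernel not flipped).
function conv2d_valid(img::AbstractMatrix, kernel::AbstractMatrix)
    kh, kw = size(kernel)
    H, W = size(img)
    out = zeros(promote_type(eltype(img), eltype(kernel)), H - kh + 1, W - kw + 1)
    for j in axes(out, 2), i in axes(out, 1)
        # weighted sum of the kh × kw patch whose top-left corner is (i, j)
        out[i, j] = sum(img[i:i+kh-1, j:j+kw-1] .* kernel)
    end
    return out
end

img = Float64[0 0 0 0 0;
              0 1 1 1 0;
              0 1 1 1 0;
              0 1 1 1 0;
              0 0 0 0 0]
edge_kernel = Float64[-1 0 1;
                      -1 0 1;
                      -1 0 1]   # responds to left-to-right intensity changes

println(conv2d_valid(img, edge_kernel))   # 5×5 input gives a 3×3 output
```

The output is large in magnitude only where the square's left and right borders are, which is the kind of "characteristic" a learned kernel can pick out.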
It turns out we can adopt the same concept for graphs because, if you really think about it, an image is just a special graph whose nodes (pixels) sit on a regular Cartesian grid; a general graph drops the grid. Each node plays the role of a pixel, the adjacency matrix defines which nodes are "neighboring pixels", and, as with a multi-channel image, each node carries a high-dimensional feature vector. We construct our feature vector by somehow aggregating the node and edge features (more on this later).
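The image/graph analogy can be checked directly. In the sketch below (a made-up 3×3 image, with 4-neighbor connectivity assumed), each pixel becomes a node, and summing neighbor values with the adjacency matrix gives the same result as convolving with a plus-shaped kernel and zero padding.

```julia
# A 2D image viewed as a graph: pixel (r, c) is a node connected to its
# up/down/left/right neighbors.
H, W = 3, 3
idx = LinearIndices((H, W))          # pixel (r, c) ↦ node number
A = zeros(Int, H * W, H * W)
for c in 1:W, r in 1:H
    for (dr, dc) in ((-1, 0), (1, 0), (0, -1), (0, 1))
        rr, cc = r + dr, c + dc
        if 1 <= rr <= H && 1 <= cc <= W
            A[idx[r, c], idx[rr, cc]] = 1
        end
    end
end

img = [1 2 3;
       4 5 6;
       7 8 9]
graph_sum = reshape(A * vec(img), H, W)   # neighbor aggregation on the graph

# Same operation as a zero-padded convolution with a plus-shaped kernel
kernel = [0 1 0; 1 0 1; 0 1 0]
padded = zeros(Int, H + 2, W + 2)
padded[2:H+1, 2:W+1] = img
conv = [sum(padded[r:r+2, c:c+2] .* kernel) for r in 1:H, c in 1:W]

println(graph_sum == conv)   # true
```

On a general graph the neighborhoods are irregular, so there is no fixed kernel shape, but the `A * x` aggregation still works.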
So what does a CGCNN learn? It learns, layer by layer, node (atom) feature vectors that capture the distinct characteristics of an atom, its bonds (through the edge features), and its environment.
This brings us to the main equation used in ref. [1]:
$$ \begin{align} \mathbf{v}_i\left(t+1\right) = g\Bigg[ &\left(\sum_{j,k} \mathbf{v}_j\left(t\right) \oplus \mathbf{u}_{\left(i,j\right)_k}\right)\mathbf{W}_c\left(t\right) \nonumber \\ &+ \mathbf{v}_i(t)\;\mathbf{W}_s\left(t\right) + \mathbf{b}\left(t\right)\Bigg]\label{eq:Xie_Grossman_eq4} \end{align} $$
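Here $g$ is a nonlinear activation function, and the sum runs over the neighbors $j$ of atom $i$ and over $k$, which labels the $k$-th bond (edge) between atoms $i$ and $j$; there can be more than one because of periodic images. The symbol $\oplus$ denotes concatenation of the neighbor's node features $\mathbf{v}_j$ with the bond features $\mathbf{u}_{(i,j)_k}$. The summed, concatenated neighbor features are multiplied by the convolution weight matrix $\mathbf{W}_c(t)$, the node's own features are multiplied by the self weight matrix $\mathbf{W}_s(t)$, and a bias $\mathbf{b}(t)$ is added. Each convolutional layer $t$ has its own $\mathbf{W}_c(t)$, $\mathbf{W}_s(t)$, and $\mathbf{b}(t)$.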
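Below is a direct, unoptimized sketch of eq. \eqref{eq:Xie_Grossman_eq4} for every node. It works with column vectors internally, so the row-vector products $\mathbf{z}\mathbf{W}_c$ and $\mathbf{v}_i\mathbf{W}_s$ become `Wc' * z` and `Ws' * v_i`. The function name `cgcnn_conv`, the neighbor-list format `nbrs[i] = [(j, k), ...]`, and the dictionary of edge features keyed by `(i, j, k)` are my own assumptions for illustration, not the paper's code or the GeometricFlux.jl API.

```julia
# Sketch of eq. (4): v_i(t+1) = g[(Σ_{j,k} v_j ⊕ u_(i,j)_k) W_c + v_i W_s + b].
# Assumed (hypothetical) data layout:
#   V          : n × F matrix, row i is v_i(t)
#   nbrs[i]    : vector of (j, k) pairs, k = index of the k-th edge between i and j
#   edge_feats : Dict mapping (i, j, k) => u_(i,j)_k as a length-B vector
#   Wc : (F + B) × F′,  Ws : F × F′,  b : length-F′ vector
function cgcnn_conv(V, nbrs, edge_feats, Wc, Ws, b; g = identity)
    n = size(V, 1)
    out = zeros(n, size(Ws, 2))
    for i in 1:n
        z = zeros(size(Wc, 1))                          # running Σ of v_j ⊕ u
        for (j, k) in nbrs[i]
            z .+= vcat(V[j, :], edge_feats[(i, j, k)])  # ⊕ = concatenation
        end
        # row-vector form z*Wc + v_i*Ws + b, written with transposes
        out[i, :] = g.(Wc' * z .+ Ws' * V[i, :] .+ b)
    end
    return out
end

# Toy usage: 2 nodes with F = 2 features, one edge between them with B = 1 feature
V = [1.0 0.0;
     0.0 1.0]
nbrs = [[(2, 1)], [(1, 1)]]
edge_feats = Dict((1, 2, 1) => [0.5], (2, 1, 1) => [0.5])
Wc = reshape([1.0, 2.0, 3.0], 3, 1)   # (F + B) × F′ with F′ = 1
Ws = reshape([10.0, 20.0], 2, 1)      # F × F′
b = [0.0]

# By hand: node 1 gets (0*1 + 1*2 + 0.5*3) + 1*10 = 13.5,
#          node 2 gets (1*1 + 0*2 + 0.5*3) + 1*20 = 22.5
println(cgcnn_conv(V, nbrs, edge_feats, Wc, Ws, b))
```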
The convolution weight matrix, $\mathbf{W}_c(t)$, is shared by all neighbors of node $i$, which neglects differences in interaction strength between neighbors (e.g., a strongly versus a weakly bonded neighbor). Xie and Grossman address this later in the paper with a gated convolution; however, I'm going to stick with eq. \eqref{eq:Xie_Grossman_eq4} and work through the example they provide in the supplemental material.
So what does the CGCNN look like from a visualization perspective? In panel (b) of the figure below, you can see the sequence of operations. After the convolutional layers, there are additional fully connected hidden layers (L$_1$) applied to each node's features, then a pooling operation (dimensionality reduction) that produces a single vector for the whole crystal, followed by another set of fully connected hidden layers (L$_2$) and the output layer that makes the final target property prediction.
Fig. 1 from Xie and Grossman [1]. (See https://doi.org/10.1103/PhysRevLett.120.145301)
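To see how the pieces in Fig. 1(b) fit together, here is a heavily simplified forward pass: one convolution layer, one per-node hidden layer (L$_1$), mean pooling, one hidden layer (L$_2$), and a scalar output. Assumptions: no edge features, so the neighbor sum in eq. \eqref{eq:Xie_Grossman_eq4} becomes `A * X` with `A` the adjacency matrix and node features stored as rows of `X`; ReLU activations; and random, untrained weights. The function names are hypothetical.

```julia
using Random, Statistics

# Simplified pipeline: conv → per-node dense (L1) → mean pooling → dense (L2) → output.
relu(x) = max(x, zero(x))

# eq. (4) without edge features, for all nodes at once (rows of X are v_i)
conv_layer(A, X, Wc, Ws, b) = relu.(A * X * Wc .+ X * Ws .+ b')
dense(X, W, b) = relu.(X * W .+ b')

function forward(A, X, p)
    H = conv_layer(A, X, p.Wc, p.Ws, p.bc)   # N × F1 node features
    H = dense(H, p.W1, p.b1)                  # L1: applied to every node (row)
    crystal = mean(H, dims = 1)               # pooling: 1 × F2 crystal vector
    h = dense(crystal, p.W2, p.b2)            # L2
    return (h * p.Wout)[1] + p.bout           # scalar property prediction
end

rng = MersenneTwister(0)
F, F1, F2, F3 = 3, 4, 4, 2
p = (Wc = randn(rng, F, F1), Ws = randn(rng, F, F1), bc = zeros(F1),
     W1 = randn(rng, F1, F2), b1 = zeros(F2),
     W2 = randn(rng, F2, F3), b2 = zeros(F3),
     Wout = randn(rng, F3, 1), bout = 0.0)

# Two-node toy graph with one-hot features; the two nodes are neighbors
A = [0 1; 1 0]
X = [1.0 0.0 0.0;
     0.0 1.0 0.0]
println(forward(A, X, p))
```

Mean pooling makes the prediction independent of how the atoms are ordered: swapping the two rows of `X` gives the same output.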
Note
One thing that seems cool to me about this architecture is that if you learn the graph features (i.e., the output after L$_1$) from, say, an electronic target property, you could then take this learned CGCNN part and plug it into a different L$_2$ to learn a different property. This is transfer learning, but I wonder how well it works in this case.
Working through an example
This is just going to be a reworking of the illustrative example provided by Xie and Grossman in the supplemental material of their paper [1]. Taking the NaCl and KCl graphs from the figure above, we express the initial feature vectors for each node (i.e., atom) as:
$$ \begin{align*} \mathbf{v}_{\text{Cl}} &= \left( 1\; 0\; 0\right) \\ \mathbf{v}_{\text{Na}} &= \left( 0\; 1\; 0\right) \\ \mathbf{v}_{\text{K}} &= \left( 0\; 0\;1\right) \\ \end{align*} $$
This is just a one-hot encoding of the element type. We will focus on classification, i.e., assigning a label indicating whether a crystal is NaCl or KCl.
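A small sketch of building these one-hot node features from element symbols. The element order (Cl, Na, K) matches the vectors above; `onehot_features` is a hypothetical helper, and nodes are stored as rows to match the row-vector convention of eq. \eqref{eq:Xie_Grossman_eq4}.

```julia
const ELEMENTS = ["Cl", "Na", "K"]   # order follows v_Cl, v_Na, v_K above

# One-hot node features, one row per node (hypothetical helper)
function onehot_features(symbols)
    X = zeros(length(symbols), length(ELEMENTS))
    for (i, s) in enumerate(symbols)
        col = findfirst(==(s), ELEMENTS)
        col === nothing && throw(ArgumentError("unknown element: $s"))
        X[i, col] = 1.0
    end
    return X
end

X_NaCl = onehot_features(["Na", "Cl"])   # one node of each species
X_KCl  = onehot_features(["K", "Cl"])
println(X_NaCl)
println(X_KCl)
```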
Note
Notice that we have excluded the edge features. These can be added, in which case the concatenation ($\oplus$) in eq. \eqref{eq:Xie_Grossman_eq4} has to be included.
The next step is to work through eq. \eqref{eq:Xie_Grossman_eq4} by applying both the $\mathbf{W}_c$ and $\mathbf{W}_s$ matrices to each feature vector. Each weight matrix will just be a $3\times1$ matrix, with elements (weights) $w_{c_i}$ for $\mathbf{W}_c$ and $w_{s_i}$ for $\mathbf{W}_s$, where $i$ is the row index.
We assume a bias of $\mathbf{b} = \mathbf{0}$ and an identity activation function, $g: x \mapsto x$. This simplifies the problem to straightforward matrix algebra and gives the following feature vectors for the NaCl graph:
$$ \begin{align*} \mathbf{v}_{\text{Na}}(t=1) &= 6w_{c_1} + w_{s_2} \\ \mathbf{v}_{\text{Cl}}(t=1) &= 6w_{c_2} + w_{s_1} \\ \end{align*} $$
and for the KCl graph,
$$ \begin{align*} \mathbf{v}_{\text{K}}(t=1) &= 8w_{c_1} + w_{s_3} \\ \mathbf{v}_{\text{Cl}}(t=1) &= 8w_{c_3} + w_{s_1}. \end{align*} $$
Here $t=1$ indicates the output of a single convolutional layer. It's important to note that the factors 6 and 8 multiplying the $w_{c}$ terms for NaCl and KCl come from the summation in eq. \eqref{eq:Xie_Grossman_eq4}, which captures the connectivity of the nodes, i.e., the number of neighbors. Each node now has a feature vector of length 1.
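The $t=1$ expressions can be checked numerically. The sketch below applies eq. \eqref{eq:Xie_Grossman_eq4} with $\mathbf{b}=\mathbf{0}$, $g(x)=x$, and no edge features to a node whose neighbors are all the same species, so only the neighbor counts (6 and 8) enter. The weight values are arbitrary test numbers, not trained weights, and `conv1` is a hypothetical helper.

```julia
# One-hot features as 1×3 rows (Cl, Na, K order)
v = Dict("Cl" => [1.0 0.0 0.0], "Na" => [0.0 1.0 0.0], "K" => [0.0 0.0 1.0])

# One layer of eq. (4) (b = 0, g = identity, no edge features) for a node
# whose neighbors are `count` copies of the same species
conv1(self, neighbor, count, Wc, Ws) = (count * v[neighbor] * Wc + v[self] * Ws)[1]

Wc = reshape([0.3, -1.2, 2.0], 3, 1)   # arbitrary w_c1, w_c2, w_c3
Ws = reshape([0.7, 0.1, -0.5], 3, 1)   # arbitrary w_s1, w_s2, w_s3

vNa  = conv1("Na", "Cl", 6, Wc, Ws)    # should equal 6w_c1 + w_s2
vClN = conv1("Cl", "Na", 6, Wc, Ws)    # should equal 6w_c2 + w_s1
vK   = conv1("K",  "Cl", 8, Wc, Ws)    # should equal 8w_c1 + w_s3
vClK = conv1("Cl", "K",  8, Wc, Ws)    # should equal 8w_c3 + w_s1

println((vNa, vClN, vK, vClK))
```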
The next step is to perform an average (normalized) pooling operation for each crystal graph, so that we get a single feature vector representing the whole crystal. This is done by simply taking the arithmetic mean of all the node features at the layer output, $\mathbf{v}_i(t=1)$:
$$ \begin{align*} \mathbf{v}_{\text{NaCl}} &= \frac{1}{8}\sum_{i=1}^{8}\mathbf{v}_i = \frac{4\mathbf{v}_{\text{Na}}+4\mathbf{v}_{\text{Cl}}}{8} \\ &= \frac{4\cdot6 w_{c_1} + 4 w_{s_1} + 4\cdot6 w_{c_2} + 4 w_{s_2}}{8} \\ &= 3 w_{c_1} + 0.5 w_{s_1} + 3 w_{c_2} + 0.5 w_{s_2} \end{align*} $$
I wrote this out explicitly so we can see that all the nodes of each species (4 Na and 4 Cl) are used; if it's not clear, review Figure S1 above. Now for KCl, which has one node per species,
$$ \begin{align*} \mathbf{v}_{\text{KCl}} &= \frac{\mathbf{v}_{\text{K}}+\mathbf{v}_{\text{Cl}}}{2} \\ &= \frac{1\cdot8 w_{c_1} + w_{s_3} + 1\cdot8 w_{c_3} + w_{s_1}}{2} \\ &= 4 w_{c_1} + 0.5 w_{s_1} + 4 w_{c_3} + 0.5 w_{s_3} \end{align*} $$
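Because each node output at $t=1$ is linear in the weights, the pooling step can be sketched by storing each node output as a coefficient vector over $(w_{c_1}, w_{s_1}, w_{c_2}, w_{s_2}, w_{c_3}, w_{s_3})$ and averaging. This bookkeeping is my own illustration. The averages reproduce the pooled NaCl and KCl expressions, and repeating the KCl cell shows that the pooled vector does not depend on how many copies of the cell are used.

```julia
# Node outputs at t = 1 as coefficient vectors over
# x = (w_c1, w_s1, w_c2, w_s2, w_c3, w_s3)
Na_NaCl = [6.0, 0.0, 0.0, 1.0, 0.0, 0.0]   # 6w_c1 + w_s2
Cl_NaCl = [0.0, 1.0, 6.0, 0.0, 0.0, 0.0]   # 6w_c2 + w_s1
K_KCl   = [8.0, 0.0, 0.0, 0.0, 0.0, 1.0]   # 8w_c1 + w_s3
Cl_KCl  = [0.0, 1.0, 0.0, 0.0, 8.0, 0.0]   # 8w_c3 + w_s1

# Mean pooling: arithmetic average over all node outputs of a graph
pool(nodes) = sum(nodes) / length(nodes)

nodes_NaCl = [fill(Na_NaCl, 4); fill(Cl_NaCl, 4)]   # 4 Na + 4 Cl nodes
nodes_KCl  = [K_KCl, Cl_KCl]                         # 1 K + 1 Cl node

println(pool(nodes_NaCl))   # coefficients of 3w_c1 + 0.5w_s1 + 3w_c2 + 0.5w_s2
println(pool(nodes_KCl))    # coefficients of 4w_c1 + 0.5w_s1 + 4w_c3 + 0.5w_s3

# Two copies of the KCl cell (4 nodes) leave the pooled vector unchanged
println(pool([nodes_KCl; nodes_KCl]) == pool(nodes_KCl))
```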
Since each crystal is now represented by a single-valued feature vector, we don't need an additional layer to map inputs to target outputs; we can treat the pooled values as the predictions and solve the set of equations:
$$ \begin{align} \hat{y}_{\text{NaCl}} &= 3 w_{c_1} + 0.5 w_{s_1} + 3 w_{c_2} + 0.5 w_{s_2} \label{eq:nacl}\\ \hat{y}_{\text{KCl}} &= 4 w_{c_1} + 0.5 w_{s_1} + 4 w_{c_3} + 0.5 w_{s_3} \label{eq:kcl} \\ \end{align} $$
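If our goal is just to distinguish between NaCl and KCl, we can set the targets $\hat{y}_{\text{NaCl}}=1$ and $\hat{y}_{\text{KCl}}=-1$. This is a simple linear system, $A\mathbf{x}=\mathbf{b}$, with 2 equations and 6 unknown weights, so it is underdetermined and has infinitely many solutions. We can solve it by hand or with code. I'll use Julia, whose backslash operator returns the minimum-norm least-squares solution for a rectangular $A$: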
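```julia
# Rows: NaCl, KCl; columns: w_c1, w_s1, w_c2, w_s2, w_c3, w_s3
A = [3.0 0.5 3.0 0.5 0.0 0.0;
     4.0 0.5 0.0 0.0 4.0 0.5]
b = [1.0; -1.0]   # targets: NaCl = 1, KCl = -1
x = A \ b         # minimum-norm least-squares solution
```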
which gives the weights:
$$\begin{align*} w_{c_1} &= 0.0249 \\ w_{s_1} &= 0.0155 \\ w_{c_2} &= 0.2975 \\ w_{s_2} &= 0.0496 \\ w_{c_3} &= -0.2726 \\ w_{s_3} &= -0.0341 \end{align*}$$
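A quick check of this solution, assuming the column order $(w_{c_1}, w_{s_1}, w_{c_2}, w_{s_2}, w_{c_3}, w_{s_3})$. Both targets are reproduced, the result matches the explicit minimum-norm formula $\mathbf{x} = A^\top (A A^\top)^{-1}\mathbf{b}$ (valid here because $A$ has full row rank), and adding any null-space vector of $A$ gives another, larger-norm solution with the same predictions.

```julia
using LinearAlgebra

A = [3.0 0.5 3.0 0.5 0.0 0.0;
     4.0 0.5 0.0 0.0 4.0 0.5]
b = [1.0, -1.0]

x = A \ b
x_minnorm = A' * ((A * A') \ b)  # explicit minimum-norm solution

println(round.(x; digits = 4))   # w_c1, w_s1, w_c2, w_s2, w_c3, w_s3
println(A * x ≈ b)               # both targets reproduced
println(x ≈ x_minnorm)           # same as the minimum-norm solution

# Any null-space direction can be added without changing the predictions
z = nullspace(A)[:, 1]
println(A * (x + z) ≈ b)         # also satisfies the system, with a larger norm
```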
Question
If we now feed in feature vectors corresponding to a different structure, will we get a label prediction? I assume that if we changed the output to logits passed through a softmax, we would get the probability that it is a NaCl or a KCl structure.
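One way to explore the question is sketched below: plug the rounded weights from above into a hypothetical new input, NaCl arranged like the 8-coordinated KCl graph of the example (2 nodes, each with 8 neighbors of the other species), then map the score to probabilities with a two-class softmax. Treating `[y_hat, 0]` as the two logits is an assumption (it makes the softmax equal to a logistic sigmoid of `y_hat`), and since the weights were fit to ±1 targets rather than trained with a classification loss, the numbers are only illustrative.

```julia
# Rounded weights from the solution above
w = (c1 = 0.0249, s1 = 0.0155, c2 = 0.2975, s2 = 0.0496, c3 = -0.2726, s3 = -0.0341)

# Hypothetical input: Na with 8 Cl neighbors and Cl with 8 Na neighbors
vNa = 8 * w.c1 + w.s2             # one layer, b = 0, g(x) = x
vCl = 8 * w.c2 + w.s1
y_hat = (vNa + vCl) / 2           # mean pooling over the 2 nodes

# Two-class softmax (hypothetical helper); logits [NaCl, KCl] = [y_hat, 0]
softmax(z) = exp.(z .- maximum(z)) ./ sum(exp.(z .- maximum(z)))
p = softmax([y_hat, 0.0])         # [P(NaCl), P(KCl)]

println(y_hat)
println(p)
```

The score is a linear extrapolation from just two training crystals, so with this tiny model it mostly reflects which species and neighbor counts appear rather than a trustworthy classification.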
In an actual CGCNN implementation, the weights are trained using backpropagation, and the edge features, biases, and activation functions are included. My follow-up post will try to understand a little more about the forward-pass mechanics and how this is coded in practice.
Footnotes
1. I'm pretty sure graph theory is ultimately related to group theory.
2. I'm reusing the figures here from the arXiv preprint version, but I also obtained permission and a reuse license from APS. I'm a bit confused about how this works, but my understanding is that, after obtaining permission, I'm OK as long as I link to the DOI.
References
[1] T. Xie, J.C. Grossman, Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties, Phys. Rev. Lett. 120 (2018) 145301. https://doi.org/10.1103/PhysRevLett.120.145301.
[2] J. Riebesell, S. Bringuier, Collection of standalone TikZ images, (2022). https://doi.org/10.5281/zenodo.7486911.