
Thursday, May 30, 2024

What am I getting wrong here?

The implementation of Xie and Grossman's crystal graph convolutional neural network [1] in Julia using Flux.jl and GraphNeuralNetworks.jl [2] is moving along. As I mentioned in my previous post, I got the data wrangling and processing done correctly. Well, at least I think so. It is hard to make an apples-to-apples comparison between the Python implementation and my Julia code for the graph structure. In the original PyTorch version by Xie and Grossman, the graph structure was just conceptual, as there were no class objects that specifically represented a graph. It would be useful if someone familiar with PyTorch Geometric, CGCNN, and Flux.jl/GraphNeuralNetworks.jl [2] could check things over for me.

But assuming the graph structure of the data is correct, the even more difficult part is the construction of the neural network layers. In the original implementation, the node features and edge features are concatenated into a single feature vector on the nodes, and then a convolutional weight matrix (i.e., kernel) is applied to "reshape" the feature vector and update the node features. Let me back up: the concatenated node/edge features are combined by a summed aggregation over a node's neighbors, where the neighbors are the atoms within some bond cutoff distance. In GNN parlance, I think this is just a message passing scheme, but I'm not entirely sure.

So we have an update that takes the edge features, concatenates them to the node features, and then applies a convolutional weight matrix. We also include a self-update of the original node features using a separate weight matrix. This type of operation updates the node features but not the edge features, which may make sense because, in the original CGCNN implementation, the edge features are represented as radial basis function expansions of the bond distances. To summarize, this is the equation:

$$ \begin{align} z^{(t)}_{(i,j)_k} &= v_{i}^{(t)} \oplus v_{j}^{(t)} \oplus u_{(i,j)_k} \label{eq:feature} \\ v_i^{(t+1)} &= v_i^{(t)} + \sum_{j,k} \sigma\left(z^{(t)}_{(i,j)_k} \mathbf{W}_f^{(t)} + \mathbf{b}_f^{(t)} \right) \odot g\left(z^{(t)}_{(i,j)_k} \mathbf{W}_s^{(t)} + \mathbf{b}_s^{(t)} \right) \label{eq:update} \end{align} $$

Here $t$ is the layer number (see footnote 1). The key aspect is that the convolutional weight matrix, $\mathbf{W}^{(t)}_f$, is not globally learned, but rather learned for the $i$-th atom environment over its $k$ bonds (see footnote 2). I'm pretty sure my current implementation does not do this and instead does what is described in eq. 4 of the original paper:

$$ \begin{align} \mathbf{v}_i\left(t+1\right) = g\Bigg[ &\left(\sum_{j,k} \mathbf{v}_j\left(t\right) \oplus \mathbf{u}_{\left(i,j\right)_k}\right)\mathbf{W}_c\left(t\right) \nonumber \\ &+ \mathbf{v}_i(t)\;\mathbf{W}_s\left(t\right) + \mathbf{b}\left(t\right)\Bigg] \label{eq:Xie_Grossman_eq4} \end{align} $$
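To make the difference concrete for myself, here is a minimal sketch of the eq. 4 update for a single atom $i$, written directly from the formula above, with softplus standing in for the activation $g$. The names here (neighbors_of, the v and u containers, and the weight matrices) are placeholders for illustration, not functions from my actual code.

using Flux: softplus

# Sketch of the eq. 4 update for atom i (placeholder names):
#   v[:, j]          -> node feature vector of atom j
#   u[(i, j, k)]     -> edge feature vector of the k-th bond between atoms i and j
#   neighbors_of(i)  -> iterable of (j, k) pairs for atom i
function eq4_update(v, u, Wc, Ws, b, i, neighbors_of)
    msg = sum(vcat(v[:, j], u[(i, j, k)]) for (j, k) in neighbors_of(i))
    return softplus.(Wc * msg .+ Ws * v[:, i] .+ b)
end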

The other confusion I have is that the pooling operation appears to be applied to the node features from every layer update, not just the final one. That makes sense, but I don't think I'm doing it correctly.

Here is my current CGCNN model Flux.jl implementation:

# Define the CGCNN struct
struct CGCNN
    embedding::Dense
    convs::Vector{CGConv}
    conv_to_fc::Dense
    conv_to_fc_softplus::Function
    fc_out::Dense
end

# Constructor for CGCNN ...

# Pooling function
function pooling(atom_fea)
    return sum(atom_fea, dims=2) ./ size(atom_fea, 2)
end

# Forward pass
function (model::CGCNN)(g::GNNGraph)
    atom_fea = g.ndata[:x]
    edge_fea = g.edata[:e]

    atom_fea = model.embedding(atom_fea)
    for conv in model.convs
        atom_fea = conv(g, atom_fea, edge_fea)
    end
    crys_fea = pooling(atom_fea)
    crys_fea = model.conv_to_fc_softplus(model.conv_to_fc(crys_fea))

    out = model.fc_out(crys_fea)

    return out
end

The concern is the pooling operation: as you can see, it is applied only to the node/atom features from the final convolutional layer, which is not what is shown in eq. 2 of Xie and Grossman's paper. The problem is that Zygote.jl, which does the automatic differentiation for back-propagation, throws a mutating-arrays error if I try to store the atom_fea outputs at each graph convolutional layer.
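One workaround I'm considering (just a sketch, not yet wired into my model) is to accumulate the per-layer node features functionally, for example by growing a tuple, which avoids in-place mutation and should be something Zygote.jl can differentiate through:

# Sketch: collect the node features after every convolution without mutating an array.
# Each step builds a new tuple, so there is no in-place mutation for Zygote.jl to choke on.
function layer_features(model::CGCNN, g::GNNGraph, atom_fea, edge_fea)
    feas = (atom_fea,)
    for conv in model.convs
        atom_fea = conv(g, atom_fea, edge_fea)
        feas = (feas..., atom_fea)
    end
    return feas
end

Each element of feas could then be pooled and combined, which would look closer to pooling over every layer update as in eq. 2.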

Well, I will keep working through this until I get something that trains to roughly the same mean absolute error as the original paper. Also, the softmax should be applied to each layer's node features, $v_i^{(t)}$, prior to pooling, which I'm not doing (see eq. S1 in [1]).
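Based purely on my reading above (so treat this as an assumption about eq. S1, not a verified implementation), the pooling function would change to something like:

using Flux: softmax

# Sketch: apply a softmax over each atom's feature vector (columns of atom_fea)
# before the mean-pool, instead of pooling the raw node features directly.
function normalized_pooling(atom_fea)
    return sum(softmax(atom_fea; dims=1), dims=2) ./ size(atom_fea, 2)
end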

It would have been nice if a detailed neural network diagram had been shown on top of Fig. 1 in Xie and Grossman's paper. I'm probably missing some small implementation detail. The one thing I know is correct is the conv call, because it is just CGConv from GraphNeuralNetworks.jl, which is an exact implementation of eq. 5 in [1].
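For completeness, here is a toy example of constructing and applying a CGConv layer; the feature dimensions are arbitrary and just for illustration. It matches the conv(g, atom_fea, edge_fea) call pattern in my forward pass above.

using Flux, GraphNeuralNetworks

# Toy example: one CGConv layer on a random graph with 92-dim node features
# and 41-dim edge features (sizes are made up for illustration).
g = rand_graph(6, 10)                    # 6 nodes, 10 directed edges
x = rand(Float32, 92, 6)                 # node features
e = rand(Float32, 41, 10)                # edge features

conv = CGConv((92, 41) => 64, softplus)  # (node dim, edge dim) => output dim
h = conv(g, x, e)                        # updated node features, size (64, 6)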

Footnotes


  1. I think it is easier to think of $t$ as the iteration step, because the node features are updated at each layer while the edge features are not. 

  2. If you look at Fig. S1 in Xie and Grossman's paper, you see that they actually represent the graph using symmetric bonds, and thus a pair of nodes in their representation can have multiple edges/bonds between them. This is why the $k$ index is used. In my implementation, I just build an adjacency matrix and use it to define the graph. 


References

[1] T. Xie, J.C. Grossman, Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties, Phys. Rev. Lett. 120 (2018) 145301. https://doi.org/10.1103/PhysRevLett.120.145301.

[2] C. Lucibello, other contributors, GraphNeuralNetworks.jl: a geometric deep learning library for the Julia programming language, (2021). https://github.com/CarloLucibello/GraphNeuralNetworks.jl.




Thursday, May 16, 2024

Update on Crystal Graph CNN

I've finally made progress! Thanks (see footnote 1) to a shift from GeometricFlux.jl to GraphNeuralNetworks.jl. Both are good, but I find GraphNeuralNetworks.jl easier to use. I originally wrote about my interest in doing something with GNNs in this blog post and then in a few other general posts.

If you recall, the original post was on Crystal Graph CNNs for property prediction. The original paper was by Xie and Grossman [1] and is, in my opinion, a very easy paper to follow and understand, hence why I chose it. I also think it set off a good bit of the interest in GNNs in materials science. The nice thing is that the paper used Materials Project data, which I'm familiar with.

Original CGCNN Dataset

One thing I've noticed is that the Materials Project IDs in the .csv files in the original CGCNN repo are no longer valid for querying the Materials Project API. I'm not sure if this is an issue with my Julia REST API function, or if the Materials Project drops IDs for structures when more accurate calculations become available for those structures and chemical systems. My guess is that the REST API endpoints I am using are wrong and I need to make modifications.
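For debugging, a bare-bones check of whether a single ID still resolves could look like the sketch below with HTTP.jl and JSON3.jl. Note that the base URL, endpoint path, X-API-KEY header, and response layout are my assumptions about the current next-gen Materials Project API, so double-check them against the API documentation before trusting this:

using HTTP, JSON3

# Assumed endpoint and header for the next-gen Materials Project API -- verify
# against the official API docs before relying on this sketch.
function mp_id_exists(mp_id::AbstractString, api_key::AbstractString)
    resp = HTTP.get("https://api.materialsproject.org/materials/summary/",
                    ["X-API-KEY" => api_key];
                    query = Dict("material_ids" => mp_id, "_fields" => "material_id"),
                    status_exception = false)
    resp.status == 200 || return false
    data = JSON3.read(resp.body)
    return !isempty(get(data, :data, []))
end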

So how did I make progress? Well, I finally got all the innards working to prepare the graphs. The final Julia code to do that looks like this:

""" get_item(cifdata::CIFData, idx; dtype=Float32)
Prepare a graph for a Graph Neural Network (GNN) from CIF data.
# Arguments - cifdata::CIFData: An instance of CIFData containing crystallographic information. - idx: An index identifying the specific crystal structure within cifdata. - dtype: The data type for the node and edge features (default is Float32).
# Returns - gnn_graph::GNNGraph: A graph object compatible with GNNs, containing node and edge features.
# Details This function performs the following steps: 1. Loads the crystal structure from a .cif file. 2. Extracts atomic numbers and computes node features. 3. Initializes a graph with nodes corresponding to atoms in the crystal. 4. Builds a neighbor list and constructs the adjacency matrix. 5. Computes edge features using an expanded Gaussian Radial Basis Function (RBF). 6. Creates a GNNGraph object with the node and edge features. """ function get_item(cifdata::CIFData, idx; dtype=Float32) cif_id, target = cifdata.id_prop_data[idx] crystal = load_system(joinpath(cifdata.root_dir, join([cif_id, ".cif"])))
# Atom/Node Features at_nums = @. atomic_number(crystal) atom_features = [dtype.(get_atom_fea(cifdata.ari, at)) for at in at_nums] node_features = hcat(atom_features...) # Convert to matrix form
num_atoms = length(crystal) g = SimpleGraph(num_atoms)
# Build Neighbor list -> Adjacency Matrix nlist = PairList(crystal, cifdata.radius*u"Γ…") adj_mat = construct_adjacency_mat(nlist) edge_features = [] processed_edges = Set{Tuple{Int, Int}}() # Construct edge features in expanded Gaussian RBF # 1. We have to track whether a edge/pair has been # assigned a feature vector. # 2. The implementation here differs from CGCNN.py # in that we don't process all edge features for # a node/atom in one shot. for i in 1:num_atoms nbrs, nbrs_dist_vecs = neigs(nlist, i) for (j, dist_vec) in zip(nbrs, nbrs_dist_vecs) edge = i < j ? (i, j) : (j, i) # Ensure unique representation if edge processed_edges add_edge!(g, edge[1], edge[2]) # Add edge to the graph dist = norm(dist_vec) dist_basis = expand(dist, cifdata.gdf) push!(edge_features, dtype.(dist_basis[:])) push!(processed_edges, edge) end end end edge_feature_matrix = hcat(edge_features...)
gnn_graph = GNNGraph(g; ndata=node_features, edata=edge_feature_matrix)
return gnn_graph end
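A hypothetical call, just to show what comes out (the CIFData construction is omitted; the :x and :e keys are where GNNGraph stores ndata/edata passed as plain arrays):

# Hypothetical usage; cifdata is assumed to be an already-constructed CIFData.
g = get_item(cifdata, 1)
@show g.num_nodes g.num_edges
@show size(g.ndata[:x])   # (number of node features, number of atoms)
@show size(g.edata[:e])   # (number of edge features, number of edges)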

There are a lot of supporting functions that are implemented but not shown here. Hopefully from the docstring and function names you can see what they do. Once I figured out how to do this, it was pretty simple to construct a neural network with different convolutional, pooling, and dense layers. I haven't fully implemented the correct architecture as described in [1], so I won't show that code here yet. The model I do have isn't doing too well given the MAE loss curves below (~1K data points), but this is not the model in the paper. I also don't have the compute to deal with the total dataset size (~70K).

I can at least start from here! Obviously it looks pretty bad.

If you're looking to utilize the function above, just shoot me an email and I'll share the current Pluto notebook. The notebook itself will become a post on my computational blog once I've successfully reproduced, within some reasonable error, a metric from ref. [1]. I don't know what the timeline looks like because it will depend on my ability to train the model on my computational resources. At the moment I don't have a GPU to train on.


Training Update

I think I've managed to get the correct model representation with Flux.jl; I'm now getting training and validation curves that look more reasonable.

Training curves looking better, still need to improve things.
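For context, the rough shape of the training loop behind these curves is sketched below. This is a simplification with placeholder names (train_data as a vector of (GNNGraph, target) pairs); batching, the validation split, and data loading are omitted:

using Flux, Statistics

# Rough sketch of a per-graph training loop with an MAE loss.
# Assumes `model` is callable on a GNNGraph and returns a prediction, and that
# `train_data` is a vector of (GNNGraph, target) pairs.
function train_model!(model, train_data; epochs=100, lr=1e-3)
    opt_state = Flux.setup(Adam(lr), model)
    for epoch in 1:epochs
        losses = Float32[]
        for (g, y) in train_data
            loss, grads = Flux.withgradient(model) do m
                Flux.mae(vec(m(g)), y)
            end
            Flux.update!(opt_state, model, grads[1])
            push!(losses, loss)
        end
        @info "epoch $epoch" train_mae = mean(losses)
    end
    return model
end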

Footnotes


  1. Also thanks to the GitHub user @aurorarossi, who helped me understand how GNNGraph treats undirected graphs as having two opposing directed edges for each node pair. 
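     A quick way to see this for yourself (my own toy check, not from that discussion):

     using Graphs, GraphNeuralNetworks

     sg = SimpleGraph(3)
     add_edge!(sg, 1, 2)
     add_edge!(sg, 2, 3)

     g = GNNGraph(sg)
     ne(sg)        # 2 undirected edges
     g.num_edges   # 4 directed edges (each undirected edge counted in both directions)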


References

[1] T. Xie, J.C. Grossman, Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties, Phys. Rev. Lett. 120 (2018) 145301. https://doi.org/10.1103/PhysRevLett.120.145301.


