# Crystal Graph Convolutional Neural Network (CGCNN)

## Introduction
The Crystal Graph Convolutional Neural Network (CGCNN) is a deep learning framework for predicting material properties directly from crystal structures. It was introduced by Xie and Grossman in the paper "Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties" (Phys. Rev. Lett. 120, 145301, 2018).
## Graph Representation
The main idea in CGCNN is to represent the crystal structure by a crystal graph that encodes both atomic information and the bonding interactions between atoms. A crystal graph \(\mathcal{G}\) is an undirected multigraph defined by nodes representing atoms and edges representing connections between atoms in a crystal.
Each node \(i\) is represented by a feature vector \(v_i\), encoding the property of the atom corresponding to node \(i\). Similarly, each edge \((i,j)_k\) is represented by a feature vector \(u_{(i,j)_k}\) corresponding to the \(k\)th bond connecting atom \(i\) and atom \(j\).
A crystal graph differs from an ordinary graph in that it allows multiple edges between the same pair of nodes: because a crystal is periodic, an atom can be bonded to several periodic images of the same neighboring atom, a situation that does not arise in molecular graphs. The diagram below illustrates this for NaCl, where every connected pair of atoms is joined by two edges; a sketch of how such a graph can be constructed in code follows the diagram.
```mermaid
graph LR
    A[Na1] --- B[Cl1]
    A --- B
    A --- C[Cl2]
    A --- C
    A --- E[Cl3]
    A --- E
    B --- D[Na2]
    B --- D
    B --- F[Na3]
    B --- F
    C --- D
    C --- D
    C --- G[Na4]
    C --- G
    D --- H[Cl4]
    D --- H
    E --- F
    E --- F
    E --- G
    E --- G
    F --- H
    F --- H
    G --- H
    G --- H
```
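One way to assemble such a crystal graph in practice is to collect, for every atom, all neighbors within a distance cutoff (including periodic images) and to featurize each bond by a Gaussian expansion of its length. The sketch below uses pymatgen for the neighbor search; the cutoff, the Gaussian filter width, and the bare atomic-number node features are illustrative choices, not the exact featurization used in the original paper.

```python
import numpy as np
from pymatgen.core import Lattice, Structure

def build_crystal_graph(structure, cutoff=6.0, gauss_step=0.2):
    """Turn a pymatgen Structure into (nodes, edges) for a crystal graph.

    Nodes carry the atomic number here (a real featurizer would use a
    richer encoding of atomic properties); each edge carries a Gaussian
    expansion of the bond distance.
    """
    centers = np.arange(0.0, cutoff + gauss_step, gauss_step)
    nodes = [site.specie.Z for site in structure]
    edges = []
    # get_all_neighbors returns, for every site, all periodic images within
    # the cutoff, so the same pair (i, j) can legitimately appear more than
    # once -- this is the multigraph structure described above.
    for i, neighbors in enumerate(structure.get_all_neighbors(r=cutoff)):
        for nbr in neighbors:
            dist_feat = np.exp(-((nbr.nn_distance - centers) ** 2) / gauss_step**2)
            edges.append((i, nbr.index, dist_feat))
    return nodes, edges

# Example: the conventional rock-salt NaCl cell, as drawn above.
nacl = Structure.from_spacegroup(
    "Fm-3m", Lattice.cubic(5.64), ["Na", "Cl"], [[0, 0, 0], [0.5, 0.5, 0.5]]
)
nodes, edges = build_crystal_graph(nacl)
print(len(nodes), "atoms,", len(edges), "directed edges")
```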
## Model Architecture

### Graph Neural Network
The convolutional neural network built on top of the crystal graph consists of two major components: convolutional layers and pooling layers. The convolutional layers iteratively update the atom feature vector \(v_i\) by "convolution" with the surrounding atoms and bonds through a nonlinear graph convolution function,

\[
v_i^{(t+1)} = \mathrm{Conv}\left(v_i^{(t)}, v_j^{(t)}, u_{(i,j)_k}\right), \qquad (i,j)_k \in \mathcal{G}.
\]
After \(R\) convolutions, the network automatically learns the feature vector \(v_i^{(R)}\) for each atom by iteratively incorporating its surrounding environment. A pooling layer is then used to produce an overall feature vector \(v_c\) for the crystal, which can be represented by a pooling function

\[
v_c = \mathrm{Pool}\left(v_0^{(0)}, v_1^{(0)}, \ldots, v_N^{(0)}, \ldots, v_N^{(R)}\right)
\]
that satisfies permutational invariance with respect to atom indexing and size invariance with respect to the choice of unit cell. The original work uses a normalized summation as the pooling function for simplicity, but other pooling functions can also be used. In addition to the convolutional and pooling layers, two fully connected hidden layers with depths \(L_1\) and \(L_2\) are added to capture the complex mapping between crystal structure and property. Finally, an output layer connects the \(L_2\) hidden layer to the predicted target property \(\hat{y}\).
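A normalized summation over atoms is simply a per-crystal mean, which is straightforward to batch. The function below is a minimal PyTorch sketch; the name `mean_pool` and the `crystal_idx` batching convention are our assumptions, not taken from the reference implementation.

```python
import torch

def mean_pool(atom_fea, crystal_idx, num_crystals):
    """Normalized-sum (mean) pooling over atoms, batched by crystal.

    atom_fea:    (N_atoms, d) atom feature vectors after R convolutions
    crystal_idx: (N_atoms,)   long tensor: which crystal each atom belongs to
    Averaging makes the result invariant to atom ordering and to the
    choice of (super)cell, as required of the pooling function.
    """
    d = atom_fea.size(1)
    summed = torch.zeros(num_crystals, d).index_add_(0, crystal_idx, atom_fea)
    counts = torch.zeros(num_crystals).index_add_(
        0, crystal_idx, torch.ones_like(crystal_idx, dtype=torch.float)
    )
    return summed / counts.unsqueeze(1)
```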
```mermaid
graph LR
    A[Embed] --> B[Conv] --> C[Pool] --> D[L1] -->|softplus| E[L2] --> F[Out]
    E -->|softmax| G[Out]
```
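After pooling, the remaining layers are ordinary fully connected ones. Below is a sketch of that head in PyTorch, with softplus activations after the \(L_1\) and \(L_2\) layers and a linear output for regression (for classification, the softmax output shown in the diagram would replace it); the layer widths are illustrative, not values from the paper.

```python
import torch.nn as nn

class CGCNNHead(nn.Module):
    """Fully connected head: crystal feature vector v_c -> predicted property."""

    def __init__(self, fea_dim, L1=128, L2=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(fea_dim, L1),
            nn.Softplus(),
            nn.Linear(L1, L2),
            nn.Softplus(),
            nn.Linear(L2, 1),  # linear output for the regression target y_hat
        )

    def forward(self, crystal_fea):
        return self.net(crystal_fea)
```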
### Convolutional Layer
The convolutional operation in CGCNN can be expressed as

\[
v_i^{(t+1)} = v_i^{(t)} + \sum_{j,k} \sigma\left(z_{(i,j)_k}^{(t)} W_f^{(t)} + b_f^{(t)}\right) \odot g\left(z_{(i,j)_k}^{(t)} W_s^{(t)} + b_s^{(t)}\right),
\]

with the concatenated neighbor vector

\[
z_{(i,j)_k}^{(t)} = v_i^{(t)} \oplus v_j^{(t)} \oplus u_{(i,j)_k},
\]

where:
- \(v_i^{(t)}\) is the feature vector of atom \(i\) at layer \(t\)
- \(\sigma\) is the sigmoid function, acting as a gate that weights each neighbor's contribution
- \(g\) is the softplus activation function for introducing nonlinear coupling between layers
- \(\odot\) is the element-wise multiplication
- \(\oplus\) is the concatenation operation
- \(z_{(i,j)_k}^{(t)}\) is the concatenation of the feature vectors of atom \(i\), atom \(j\), and the \(k\)th bond between atom \(i\) and atom \(j\) at layer \(t\)
- \(W_f^{(t)}\) and \(b_f^{(t)}\) are the learnable weights and biases for the sigmoid function
- \(W_s^{(t)}\) and \(b_s^{(t)}\) are the learnable weights and biases for the softplus function
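Putting the pieces together, one convolution step can be sketched in PyTorch as below. The edge-list layout (`edge_src`, `edge_dst`, one row of `bond_fea` per edge \((i,j)_k\)) is our assumption for illustration; the reference implementation organizes neighbors differently and also applies batch normalization, which is omitted here.

```python
import torch
import torch.nn as nn

class ConvLayer(nn.Module):
    """One CGCNN convolution: the gated update

        v_i <- v_i + sum_{j,k} sigmoid(z W_f + b_f) * softplus(z W_s + b_s),

    where z = v_i (+) v_j (+) u_(i,j)_k is the concatenated neighbor vector.
    Dimensions and the edge-index layout are illustrative assumptions.
    """

    def __init__(self, atom_dim, bond_dim):
        super().__init__()
        self.fc_gate = nn.Linear(2 * atom_dim + bond_dim, atom_dim)  # W_f, b_f
        self.fc_core = nn.Linear(2 * atom_dim + bond_dim, atom_dim)  # W_s, b_s
        self.g = nn.Softplus()

    def forward(self, atom_fea, bond_fea, edge_src, edge_dst):
        # z_(i,j)_k = v_i (+) v_j (+) u_(i,j)_k, one row per edge
        z = torch.cat([atom_fea[edge_src], atom_fea[edge_dst], bond_fea], dim=1)
        # sigmoid gate (element-wise product is the \odot in the equation)
        gated = torch.sigmoid(self.fc_gate(z)) * self.g(self.fc_core(z))
        # sum the gated messages back onto the center atom i, then residual add
        out = torch.zeros_like(atom_fea).index_add_(0, edge_src, gated)
        return atom_fea + out
```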