Usage¶
Basic Usage¶
Here's a basic tutorial that walks through the prediction script using the functions provided by the package.
1. Importing the Package¶
There are three main modules available in the package:
- cgcnn2.data: For loading and preprocessing the data.
- cgcnn2.model: Building blocks for the CGCNN model.
- cgcnn2.util: Some utility functions.
from cgcnn2.data import CIFData, collate_pool
from cgcnn2.model import CrystalGraphConvNet
from cgcnn2.util import cgcnn_test
2. Data Preparation¶
To input material structures into CGCNN, you need to define a custom dataset. Before doing so, make sure you have the following files:
- CIF files recording the structures of the materials you wish to study.
- Target properties for each material (not needed for prediction jobs).
Organize these files in a directory (root_dir) with the following structure; a quick sanity check of this layout is sketched right after the list:
- id_prop.csv (optional for prediction): A CSV file with two columns; the first column is a unique material ID, and the second column is the corresponding target property value.
- atom_init.json: A JSON file that provides the initialization vector for each element. You can use the example at /cgcnn2/asset/atom_init.json from the original CGCNN repository; it should work for most applications.
- CIF files: One .cif file per material, named ID.cif, where ID matches the entries in id_prop.csv.
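As a quick sanity check before loading anything, you can confirm the layout from Python. This is a minimal sketch; the path is the example directory used below, so adjust it to your own root_dir:

import os

root_dir = "/examples/data/sample_regression"  # replace with your own root_dir

# atom_init.json is always required; id_prop.csv is optional for prediction jobs
assert os.path.isfile(os.path.join(root_dir, "atom_init.json"))

# every structure should be a separate ID.cif file inside root_dir
cif_files = [f for f in os.listdir(root_dir) if f.endswith(".cif")]
print(f"Found {len(cif_files)} CIF files in {root_dir}")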
Once your root_dir (for example, /examples/data/sample_regression) contains these files, you can load the dataset using the CIFData class:
dataset = CIFData("/examples/data/sample_regression")
This will prepare your crystal structures (and, if provided, their target properties) for use with CGCNN. We can then build a torch.utils.data.DataLoader object to load the dataset in batches:
from torch.utils.data import DataLoader
# batch_size and workers come from the prediction script's parsed
# command-line arguments (args); see the minimal sketch below
loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.workers,
    collate_fn=collate_pool,
)
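The loader above (and the checkpoint loading in the next step) refer to an args object, which in the prediction script holds the parsed command-line arguments. If you are following along interactively, a minimal stand-in looks like this; the flag names and defaults are illustrative assumptions, not the package's actual CLI:

import argparse

parser = argparse.ArgumentParser()
# illustrative flags and defaults; the package's actual prediction script may differ
parser.add_argument("--model-path", default="model_best.pth.tar")
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--workers", type=int, default=0)
parser.add_argument("--device", default="cpu")
args = parser.parse_args([])  # an empty argument list keeps the defaults above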
3. Model Initialization¶
We need some information from the dataset to initialize the model, which we can obtain as follows:
atom_graph, _, _ = dataset[0]
orig_atom_fea_len = atom_graph[0].shape[-1]
nbr_fea_len = atom_graph[1].shape[-1]
Here dataset[0] is a tuple (atom_graph, target, cif_id), where atom_graph is itself a tuple (atom_fea, nbr_fea, nbr_fea_idx), target is the target property value, and cif_id is the unique ID of the material. The atom_graph tuple contains the atom features, neighbor features, and neighbor indices; the dimensions of the atom and neighbor features, given by orig_atom_fea_len and nbr_fea_len, are needed to initialize the model.
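A quick print can confirm the two feature lengths before building the model; the exact values depend on your atom_init.json and the package's neighbor-feature settings:

# the atom feature length is set by atom_init.json; the neighbor feature
# length is set by how the package expands interatomic distances
print(f"orig_atom_fea_len = {orig_atom_fea_len}, nbr_fea_len = {nbr_fea_len}")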
In addition, we need some information about the architecture of the pre-trained model, which can be read from its checkpoint:
import argparse

import torch

# args.model_path and args.device come from the prediction script's
# command-line arguments (see the minimal args sketch in step 2)
checkpoint = torch.load(args.model_path, map_location=args.device)
model_args = argparse.Namespace(**checkpoint["args"])

atom_fea_len = model_args.atom_fea_len
n_conv = model_args.n_conv
h_fea_len = model_args.h_fea_len
n_h = model_args.n_h
Here atom_fea_len, n_conv, h_fea_len, and n_h are the dimension of the atom features, the number of convolutional layers, the dimension of the hidden features, and the number of hidden layers, respectively. Now we can initialize the model:
model = CrystalGraphConvNet(
    orig_atom_fea_len=orig_atom_fea_len,
    nbr_fea_len=nbr_fea_len,
    atom_fea_len=atom_fea_len,
    n_conv=n_conv,
    h_fea_len=h_fea_len,
    n_h=n_h,
)
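Before loading the weights, it can be reassuring to inspect the freshly built (still untrained) model; this is plain PyTorch and does not rely on anything package-specific:

# print the layer structure and count the trainable parameters
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(model)
print(f"Trainable parameters: {n_params}")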
4. Model Loading and Prediction¶
The checkpoint from the pre-trained model includes the model's state dictionary and its training arguments. The state dictionary contains all the learned parameters of the model, while the training arguments store the hyperparameters used during training. The model can be loaded onto either a CPU or a GPU by setting device to cpu or cuda; using a GPU is recommended for faster inference when one is available.
model.load_state_dict(checkpoint["state_dict"])

# run on the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # switch to evaluation mode for inference
Now, we can run the prediction with the cgcnn_test utility function:
import os

# choose where to write the outputs; any writable directory works
output_folder = "output"
os.makedirs(output_folder, exist_ok=True)

cgcnn_test(
    model=model,
    loader=loader,
    device=device,
    plot_file=os.path.join(output_folder, "parity_plot.svg"),
)
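Judging from the arguments above, cgcnn_test evaluates the model on every batch produced by loader on the chosen device and writes a parity plot (predicted versus target values) to the path given by plot_file.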