
API Reference

Data Module

cgcnn2.data

AtomCustomJSONInitializer

Bases: AtomInitializer

Initialize atom feature vectors from a JSON file. The file holds a dictionary that maps each element's atomic number (stored as a string key, since JSON keys are strings) to a list representing that element's feature vector.

Parameters:

elem_embedding_file (str, required): The path to the .json file
Source code in cgcnn2/data.py
class AtomCustomJSONInitializer(AtomInitializer):
    """
    Initialize atom feature vectors from a JSON file. The file holds a
    dictionary that maps each element's atomic number (as a string key) to
    a list representing that element's feature vector.

    Args:
        elem_embedding_file (str): The path to the .json file
    """

    def __init__(self, elem_embedding_file):
        with open(elem_embedding_file) as f:
            elem_embedding = json.load(f)
        elem_embedding = {int(key): value for key, value in elem_embedding.items()}
        atom_types = set(elem_embedding.keys())
        super(AtomCustomJSONInitializer, self).__init__(atom_types)
        for key, value in elem_embedding.items():
            self._embedding[key] = np.array(value, dtype=float)
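A minimal sketch of the file format AtomCustomJSONInitializer expects, based only on the constructor above: keys are atomic numbers written as strings (JSON object keys are always strings), and values are equal-length lists that become float arrays. The helper `load_elem_embedding` is hypothetical and the three-element vectors are made up for illustration.

```python
import json
import os
import tempfile

import numpy as np


def load_elem_embedding(path):
    """Hypothetical helper mirroring AtomCustomJSONInitializer.__init__."""
    with open(path) as f:
        raw = json.load(f)
    # JSON keys are strings, so convert them back to integer atomic numbers.
    return {int(key): np.array(value, dtype=float) for key, value in raw.items()}


# Made-up 3-dimensional feature vectors for H (Z=1) and O (Z=8).
toy_embedding = {1: [1.0, 0.0, 0.0], 8: [0.0, 1.0, 0.5]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "atom_init.json")
    with open(path, "w") as f:
        json.dump(toy_embedding, f)  # int keys are written as "1" and "8"
    embedding = load_elem_embedding(path)

print(sorted(embedding))  # [1, 8]
```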

AtomInitializer

Bases: object

Base class for initializing the vector representation for atoms.

Use one AtomInitializer per dataset.

Source code in cgcnn2/data.py
class AtomInitializer(object):
    """
    Base class for initializing the vector representation for atoms.

    Use one `AtomInitializer` per dataset.
    """

    def __init__(self, atom_types):
        self.atom_types = set(atom_types)
        self._embedding = {}

    def get_atom_fea(self, atom_type):
        assert atom_type in self.atom_types
        return self._embedding[atom_type]

    def load_state_dict(self, state_dict):
        self._embedding = state_dict
        self.atom_types = set(self._embedding.keys())
        self._decodedict = {
            idx: atom_type for atom_type, idx in self._embedding.items()
        }

    def state_dict(self):
        return self._embedding

    def decode(self, idx):
        if not hasattr(self, "_decodedict"):
            self._decodedict = {
                idx: atom_type for atom_type, idx in self._embedding.items()
            }
        return self._decodedict[idx]
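The base class leaves `self._embedding` empty, and subclasses such as AtomCustomJSONInitializer fill it with one vector per atom type. To illustrate that contract with a different encoding, the sketch below builds one-hot vectors. `OneHotAtomInitializer` is hypothetical and, to stay self-contained, re-declares the relevant base-class methods instead of importing cgcnn2.

```python
import numpy as np


class OneHotAtomInitializer:
    """Hypothetical initializer with the same interface as AtomInitializer.

    Each atomic number in `atom_types` gets a one-hot vector whose length
    equals the number of distinct atom types.
    """

    def __init__(self, atom_types):
        self.atom_types = set(atom_types)
        self._embedding = {}
        for position, z in enumerate(sorted(self.atom_types)):
            vec = np.zeros(len(self.atom_types))
            vec[position] = 1.0
            self._embedding[z] = vec

    def get_atom_fea(self, atom_type):
        assert atom_type in self.atom_types
        return self._embedding[atom_type]

    def state_dict(self):
        return self._embedding


init = OneHotAtomInitializer([8, 1, 26])  # O, H, Fe
# O is at position 1 of the sorted types [1, 8, 26].
o_vec = init.get_atom_fea(8)
```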

CIFData

Bases: Dataset

The CIFData dataset wraps a collection of crystal structures stored as CIF files. The dataset directory should have the following structure:

root_dir
|-- id_prop.csv
|-- atom_init.json
|-- id0.cif
|-- id1.cif
|-- ...

id_prop.csv: a CSV file with two columns and no header row. The first column records a unique ID for each crystal, and the second column records the value of the target property.

atom_init.json: a JSON file that stores the initialization vector for each element.

ID.cif: a CIF file that records the crystal structure, where ID is the unique ID of the crystal.
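Because CIFData unpacks every row of id_prop.csv into exactly two values, calls `float` on the target, and only opens `root_dir/<ID>.cif` when an item is first loaded in `__getitem__`, layout mistakes surface late. The stdlib-only check below is a hypothetical helper, not part of cgcnn2, that looks for those problems up front; it does not parse the CIF files themselves.

```python
import csv
import os
import tempfile


def check_cifdata_layout(root_dir):
    """Hypothetical pre-flight check for the CIFData directory layout.

    Returns a list of problems; an empty list means the layout looks valid.
    """
    problems = []
    if not os.path.exists(os.path.join(root_dir, "atom_init.json")):
        problems.append("missing atom_init.json")
    id_prop = os.path.join(root_dir, "id_prop.csv")
    if not os.path.exists(id_prop):
        return problems + ["missing id_prop.csv"]
    with open(id_prop) as f:
        for row_num, row in enumerate(csv.reader(f), start=1):
            if len(row) != 2:
                problems.append(f"row {row_num}: expected 2 columns, got {len(row)}")
                continue
            cif_id, target = row
            try:
                float(target)  # CIFData calls float(target); a header row fails here
            except ValueError:
                problems.append(f"row {row_num}: target {target!r} is not a number")
            if not os.path.exists(os.path.join(root_dir, cif_id + ".cif")):
                problems.append(f"row {row_num}: missing {cif_id}.cif")
    return problems


with tempfile.TemporaryDirectory() as root:
    for name in ("atom_init.json", "id0.cif"):
        open(os.path.join(root, name), "w").close()  # empty placeholder files
    with open(os.path.join(root, "id_prop.csv"), "w", newline="") as f:
        csv.writer(f).writerows([["id0", "1.25"], ["id1", "0.5"]])
    report = check_cifdata_layout(root)

print(report)  # ['row 2: missing id1.cif']
```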

Parameters:

root_dir (str, required): The path to the root directory of the dataset
max_num_nbr (int, default 12): The maximum number of neighbors per atom when constructing the crystal graph
radius (float, default 8): The cutoff radius (angstrom) for searching neighbors
dmin (float, default 0): The minimum distance for constructing GaussianDistance
step (float, default 0.2): The step size for constructing GaussianDistance
random_seed (int, default 123): Seed passed to Python's random.seed; CIFData itself does not shuffle the data

Returns:

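Each item is a tuple ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id), where n_i is the number of atoms in the crystal and M is max_num_nbr:

atom_fea (torch.Tensor): shape (n_i, atom_fea_len)
nbr_fea (torch.Tensor): shape (n_i, M, nbr_fea_len)
nbr_fea_idx (torch.LongTensor): shape (n_i, M)
target (torch.Tensor): shape (1, )
cif_id (str or int)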

Source code in cgcnn2/data.py
class CIFData(Dataset):
    """
    The CIFData dataset wraps a collection of crystal structures stored as
    CIF files. The dataset directory should have the following structure:

    root_dir
    |-- id_prop.csv
    |-- atom_init.json
    |-- id0.cif
    |-- id1.cif
    |-- ...

    id_prop.csv: a CSV file with two columns and no header row. The first
    column records a unique ID for each crystal, and the second column
    records the value of the target property.

    atom_init.json: a JSON file that stores the initialization vector for each
    element.

    ID.cif: a CIF file that records the crystal structure, where ID is the
    unique ID of the crystal.

    Args:
        root_dir (str): The path to the root directory of the dataset
        max_num_nbr (int): The maximum number of neighbors per atom in the crystal graph
        radius (float): The cutoff radius (angstrom) for searching neighbors
        dmin (float): The minimum distance for constructing GaussianDistance
        step (float): The step size for constructing GaussianDistance
        random_seed (int): Seed passed to random.seed (the dataset is not shuffled here)

    Returns:
        ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id), where
        atom_fea: torch.Tensor shape (n_i, atom_fea_len)
        nbr_fea: torch.Tensor shape (n_i, M, nbr_fea_len)
        nbr_fea_idx: torch.LongTensor shape (n_i, M)
        target: torch.Tensor shape (1, )
        cif_id: str or int
    """

    def __init__(
        self, root_dir, max_num_nbr=12, radius=8, dmin=0, step=0.2, random_seed=123
    ):
        self.root_dir = root_dir
        self.max_num_nbr, self.radius = max_num_nbr, radius
        assert os.path.exists(root_dir), "root_dir does not exist!"
        id_prop_file = os.path.join(self.root_dir, "id_prop.csv")
        assert os.path.exists(id_prop_file), "id_prop.csv does not exist!"
        with open(id_prop_file) as f:
            reader = csv.reader(f)
            self.id_prop_data = [row for row in reader]
        random.seed(random_seed)
        atom_init_file = os.path.join(self.root_dir, "atom_init.json")
        assert os.path.exists(atom_init_file), "atom_init.json does not exist!"
        self.ari = AtomCustomJSONInitializer(atom_init_file)
        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)

    def __len__(self):
        return len(self.id_prop_data)

    @functools.lru_cache(maxsize=1024)  # Cache loaded structures
    def __getitem__(self, idx):
        cif_id, target = self.id_prop_data[idx]
        crystal = Structure.from_file(os.path.join(self.root_dir, cif_id + ".cif"))
        atom_fea = np.vstack(
            [
                self.ari.get_atom_fea(crystal[i].specie.number)
                for i in range(len(crystal))
            ]
        )
        atom_fea = torch.Tensor(atom_fea)
        all_nbrs = crystal.get_all_neighbors(self.radius, include_index=True)
        all_nbrs = [sorted(nbrs, key=lambda x: x[1]) for nbrs in all_nbrs]
        nbr_fea_idx, nbr_fea = [], []
        for nbr in all_nbrs:
            if len(nbr) < self.max_num_nbr:
                warnings.warn(
                    "{} did not find enough neighbors to build the graph. "
                    "If this happens frequently, consider increasing "
                    "radius.".format(cif_id),
                    stacklevel=2,
                )
                nbr_fea_idx.append(
                    list(map(lambda x: x[2], nbr)) + [0] * (self.max_num_nbr - len(nbr))
                )
                nbr_fea.append(
                    list(map(lambda x: x[1], nbr))
                    + [self.radius + 1.0] * (self.max_num_nbr - len(nbr))
                )
            else:
                nbr_fea_idx.append(list(map(lambda x: x[2], nbr[: self.max_num_nbr])))
                nbr_fea.append(list(map(lambda x: x[1], nbr[: self.max_num_nbr])))
        nbr_fea_idx, nbr_fea = np.array(nbr_fea_idx), np.array(nbr_fea)
        nbr_fea = self.gdf.expand(nbr_fea)
        nbr_fea = torch.Tensor(nbr_fea)
        nbr_fea_idx = torch.LongTensor(nbr_fea_idx)
        target = torch.Tensor([float(target)])
        return (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id
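The neighbor handling in `__getitem__` gives every atom exactly `max_num_nbr` neighbor slots: neighbors are sorted by distance and truncated, and any shortfall is filled with index 0 and distance `radius + 1.0`, which lies beyond the cutoff. The numpy sketch below isolates that step. It is a simplification: it takes plain `(distance, neighbor_index)` pairs instead of pymatgen neighbor objects (the code above reads the distance as `x[1]` and the index as `x[2]`) and it skips the warning. `pad_neighbors` is a hypothetical helper.

```python
import numpy as np


def pad_neighbors(all_nbrs, max_num_nbr, radius):
    """Mirror the padding/truncation step in CIFData.__getitem__.

    all_nbrs: one list per atom of (distance, neighbor_index) pairs.
    Returns (nbr_fea_idx, nbr_dist), both of shape (n_atoms, max_num_nbr).
    """
    nbr_fea_idx, nbr_dist = [], []
    for nbrs in all_nbrs:
        nbrs = sorted(nbrs, key=lambda x: x[0])[:max_num_nbr]  # nearest first
        n_missing = max_num_nbr - len(nbrs)
        # Missing slots point at atom 0 with a distance beyond the cutoff.
        nbr_fea_idx.append([j for _, j in nbrs] + [0] * n_missing)
        nbr_dist.append([d for d, _ in nbrs] + [radius + 1.0] * n_missing)
    return np.array(nbr_fea_idx), np.array(nbr_dist)


# Toy input: atom 0 has three neighbors, atom 1 has only one.
toy_nbrs = [[(2.5, 1), (1.9, 2), (3.1, 1)], [(1.9, 0)]]
idx, dist = pad_neighbors(toy_nbrs, max_num_nbr=2, radius=8.0)
# idx  -> [[2, 1], [0, 0]]
# dist -> [[1.9, 2.5], [1.9, 9.0]]
```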

CIFData_NoTarget

Bases: Dataset

The CIFData_NoTarget dataset wraps a collection of crystal structures stored as CIF files, for use when no target values are available. The dataset directory should have the following structure:

root_dir
|-- atom_init.json
|-- id0.cif
|-- id1.cif
|-- ...

atom_init.json: a JSON file that stores the initialization vector for each element.

ID.cif: a CIF file that records the crystal structure, where ID is the unique ID of the crystal.
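CIFData_NoTarget builds its index from the directory listing rather than from a CSV file: every `*.cif` file stem becomes an ID with a placeholder target of 0, and the entries are sorted by ID. The sketch below mirrors that logic (`discover_cif_ids` is a hypothetical name). Sorting compares strings, so `id10` comes before `id2`.

```python
import os
import tempfile


def discover_cif_ids(root_dir):
    """Mirror how CIFData_NoTarget builds id_prop_data from a directory."""
    ids = [name[:-4] for name in os.listdir(root_dir) if name.endswith(".cif")]
    # Every structure gets a placeholder target of 0; entries are sorted by ID.
    return sorted(((cif_id, 0) for cif_id in ids), key=lambda x: x[0])


with tempfile.TemporaryDirectory() as root:
    for name in ("id2.cif", "id10.cif", "atom_init.json", "notes.txt"):
        open(os.path.join(root, name), "w").close()  # empty placeholder files
    entries = discover_cif_ids(root)

print(entries)  # [('id10', 0), ('id2', 0)]
```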

Parameters:

root_dir (str, required): The path to the root directory of the dataset
max_num_nbr (int, default 12): The maximum number of neighbors per atom when constructing the crystal graph
radius (float, default 8): The cutoff radius (angstrom) for searching neighbors
dmin (float, default 0): The minimum distance for constructing GaussianDistance
step (float, default 0.2): The step size for constructing GaussianDistance
random_seed (int, default 123): Seed passed to Python's random.seed; CIFData_NoTarget itself does not shuffle the data

Returns:

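Each item is a tuple ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id), where n_i is the number of atoms in the crystal and M is max_num_nbr:

atom_fea (torch.Tensor): shape (n_i, atom_fea_len)
nbr_fea (torch.Tensor): shape (n_i, M, nbr_fea_len)
nbr_fea_idx (torch.LongTensor): shape (n_i, M)
target (torch.Tensor): shape (1, ). Always 0, a placeholder that keeps items in the CIFData format
cif_id (str): The CIF file name without the .cif extension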

Source code in cgcnn2/data.py
class CIFData_NoTarget(Dataset):
    """
    The CIFData_NoTarget dataset wraps a collection of crystal structures
    stored as CIF files, for use when no target values are available. The
    dataset directory should have the following structure:

    root_dir
    |-- atom_init.json
    |-- id0.cif
    |-- id1.cif
    |-- ...

    atom_init.json: a JSON file that stores the initialization vector for each
    element.

    ID.cif: a CIF file that records the crystal structure, where ID is the
    unique ID of the crystal.

    Args:
        root_dir (str): The path to the root directory of the dataset
        max_num_nbr (int): The maximum number of neighbors per atom in the crystal graph
        radius (float): The cutoff radius (angstrom) for searching neighbors
        dmin (float): The minimum distance for constructing GaussianDistance
        step (float): The step size for constructing GaussianDistance
        random_seed (int): Seed passed to random.seed (the dataset is not shuffled here)

    Returns:
        ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id), where
        atom_fea: torch.Tensor shape (n_i, atom_fea_len)
        nbr_fea: torch.Tensor shape (n_i, M, nbr_fea_len)
        nbr_fea_idx: torch.LongTensor shape (n_i, M)
        target: torch.Tensor shape (1, ), always 0 (placeholder)
        cif_id: str
    """

    def __init__(
        self, root_dir, max_num_nbr=12, radius=8, dmin=0, step=0.2, random_seed=123
    ):
        self.root_dir = root_dir
        self.max_num_nbr, self.radius = max_num_nbr, radius
        assert os.path.exists(root_dir), "root_dir does not exist!"
        id_prop_data = []
        for file in os.listdir(root_dir):
            if file.endswith(".cif"):
                id_prop_data.append(file[:-4])
        id_prop_data = [(cif_id, 0) for cif_id in id_prop_data]
        id_prop_data.sort(key=lambda x: x[0])
        self.id_prop_data = id_prop_data
        random.seed(random_seed)
        atom_init_file = os.path.join(self.root_dir, "atom_init.json")
        assert os.path.exists(atom_init_file), "atom_init.json does not exist!"
        self.ari = AtomCustomJSONInitializer(atom_init_file)
        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)

    def __len__(self):
        return len(self.id_prop_data)

    @functools.lru_cache(maxsize=1024)  # Cache loaded structures
    def __getitem__(self, idx):
        cif_id, target = self.id_prop_data[idx]
        crystal = Structure.from_file(os.path.join(self.root_dir, cif_id + ".cif"))
        atom_fea = np.vstack(
            [
                self.ari.get_atom_fea(crystal[i].specie.number)
                for i in range(len(crystal))
            ]
        )
        atom_fea = torch.Tensor(atom_fea)
        all_nbrs = crystal.get_all_neighbors(self.radius, include_index=True)
        all_nbrs = [sorted(nbrs, key=lambda x: x[1]) for nbrs in all_nbrs]
        nbr_fea_idx, nbr_fea = [], []
        for nbr in all_nbrs:
            if len(nbr) < self.max_num_nbr:
                warnings.warn(
                    "{} did not find enough neighbors to build the graph. "
                    "If this happens frequently, consider increasing "
                    "radius.".format(cif_id),
                    stacklevel=2,
                )
                nbr_fea_idx.append(
                    list(map(lambda x: x[2], nbr)) + [0] * (self.max_num_nbr - len(nbr))
                )
                nbr_fea.append(
                    list(map(lambda x: x[1], nbr))
                    + [self.radius + 1.0] * (self.max_num_nbr - len(nbr))
                )
            else:
                nbr_fea_idx.append(list(map(lambda x: x[2], nbr[: self.max_num_nbr])))
                nbr_fea.append(list(map(lambda x: x[1], nbr[: self.max_num_nbr])))
        nbr_fea_idx, nbr_fea = np.array(nbr_fea_idx), np.array(nbr_fea)
        nbr_fea = self.gdf.expand(nbr_fea)
        nbr_fea = torch.Tensor(nbr_fea)
        nbr_fea_idx = torch.LongTensor(nbr_fea_idx)
        target = torch.Tensor([float(target)])
        return (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id

GaussianDistance

Bases: object

Expands interatomic distances in a Gaussian basis.

Unit: angstrom

Source code in cgcnn2/data.py
class GaussianDistance(object):
    """
    Expands interatomic distances in a Gaussian basis.

    Unit: angstrom
    """

    def __init__(self, dmin, dmax, step, var=None):
        """
        Args:
            dmin (float): Minimum interatomic distance
            dmax (float): Maximum interatomic distance
            step (float): Step size for the Gaussian filter
        """
        assert dmin < dmax
        assert dmax - dmin > step
        self.filter = np.arange(dmin, dmax + step, step)
        if var is None:
            var = step
        self.var = var

    def expand(self, distances):
        """
        Apply Gaussian distance filter to a numpy distance array

        Args:
            distances (np.ndarray): A distance matrix of any shape

        Returns:
            expanded_distance (np.ndarray): array with one more dimension
              than distances; the new last dimension has length
              len(self.filter)
        """
        return np.exp(-((distances[..., np.newaxis] - self.filter) ** 2) / self.var**2)
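`expand` maps each distance d to a vector whose k-th entry is exp(-(d - mu_k)^2 / var^2), where the centers mu_k = dmin, dmin + step, ... come from `np.arange(dmin, dmax + step, step)` and `var` defaults to `step`. The standalone function below mirrors that computation; `gaussian_expand` is a hypothetical name. With the CIFData defaults (dmin=0, radius=8, step=0.2) there are 41 centers, so the last dimension of nbr_fea is 41.

```python
import numpy as np


def gaussian_expand(distances, dmin=0.0, dmax=8.0, step=0.2, var=None):
    """Standalone mirror of GaussianDistance(dmin, dmax, step, var).expand."""
    centers = np.arange(dmin, dmax + step, step)  # Gaussian centers in angstrom
    width = step if var is None else var
    # A trailing axis lets every distance be compared with every center.
    return np.exp(-((distances[..., np.newaxis] - centers) ** 2) / width**2)


# Toy (n_atoms=2, M=2) distance matrix; 9.0 is the padding value radius + 1.0.
dist = np.array([[1.9, 2.5], [1.9, 9.0]])
fea = gaussian_expand(dist)
print(fea.shape)  # (2, 2, 41)
```

Under these defaults the padded distance 9.0 sits 1 angstrom past the last center, so all of its features are at most exp(-25), effectively zero.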

__init__(dmin, dmax, step, var=None)

Parameters:

dmin (float, required): Minimum interatomic distance
dmax (float, required): Maximum interatomic distance
step (float, required): Step size for the Gaussian filter
var (float, default None): Width of each Gaussian; if None, step is used

expand(distances)

Apply Gaussian distance filter to a numpy distance array

Parameters:

distances (np.ndarray, required): A distance matrix of any shape

Returns:

expanded_distance (np.ndarray): An array with one more dimension than distances; the new last dimension has length len(self.filter)

collate_pool(dataset_list)

Collate a list of data and return a batch for predicting crystal properties.

Parameters:

dataset_list (list of tuples, required): One tuple per data point, of the form ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id):
- atom_fea (torch.Tensor): shape (n_i, atom_fea_len). Atom features for each atom in the crystal
- nbr_fea (torch.Tensor): shape (n_i, M, nbr_fea_len). Bond features for each atom's M neighbors
- nbr_fea_idx (torch.LongTensor): shape (n_i, M). Indices of the M neighbors of each atom
- target (torch.Tensor): shape (1, ). Target value for prediction
- cif_id (str or int): Unique ID for the crystal

Returns:

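N = sum(n_i) is the total number of atoms in the batch, and N0 = len(dataset_list) is the number of crystals. The batch is returned as ((batch_atom_fea, batch_nbr_fea, batch_nbr_fea_idx, crystal_atom_idx), target, batch_cif_ids), where:

batch_atom_fea (torch.Tensor): shape (N, orig_atom_fea_len). Atom features from atom type
batch_nbr_fea (torch.Tensor): shape (N, M, nbr_fea_len). Bond features of each atom's M neighbors
batch_nbr_fea_idx (torch.LongTensor): shape (N, M). Indices of the M neighbors of each atom, offset into the batched atom array
crystal_atom_idx (list of torch.LongTensor, length N0): Mapping from each crystal index to its atom indices
target (torch.Tensor): shape (N0, 1). Target values for prediction
batch_cif_ids (list, length N0): Crystal IDs in batch order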

Source code in cgcnn2/data.py
def collate_pool(dataset_list):
    """
    Collate a list of data and return a batch for predicting crystal
    properties.

    Args:
        dataset_list (list of tuples): List of tuples for each data point.
          Each tuple contains:
          - atom_fea (torch.Tensor): shape (n_i, atom_fea_len)
            Atom features for each atom in the crystal
          - nbr_fea (torch.Tensor): shape (n_i, M, nbr_fea_len)
            Bond features for each atom's M neighbors
          - nbr_fea_idx (torch.LongTensor): shape (n_i, M)
            Indices of M neighbors of each atom
          - target (torch.Tensor): shape (1, )
            Target value for prediction
          - cif_id (str or int)
            Unique ID for the crystal

    Returns:
        N = sum(n_i); N0 = len(dataset_list)

        batch_atom_fea: torch.Tensor shape (N, orig_atom_fea_len)
          Atom features from atom type
        batch_nbr_fea: torch.Tensor shape (N, M, nbr_fea_len)
          Bond features of each atom's M neighbors
        batch_nbr_fea_idx: torch.LongTensor shape (N, M)
          Indices of M neighbors of each atom, offset into the batch
        crystal_atom_idx: list of torch.LongTensor of length N0
          Mapping from the crystal idx to atom idx
        target: torch.Tensor shape (N0, 1)
          Target value for prediction
        batch_cif_ids: list of length N0
          Crystal IDs in batch order
    """
    batch_atom_fea, batch_nbr_fea, batch_nbr_fea_idx = [], [], []
    crystal_atom_idx, batch_target = [], []
    batch_cif_ids = []
    base_idx = 0
    for i, ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id) in enumerate(
        dataset_list
    ):
        n_i = atom_fea.shape[0]  # number of atoms for this crystal
        batch_atom_fea.append(atom_fea)
        batch_nbr_fea.append(nbr_fea)
        batch_nbr_fea_idx.append(nbr_fea_idx + base_idx)
        new_idx = torch.LongTensor(np.arange(n_i) + base_idx)
        crystal_atom_idx.append(new_idx)
        batch_target.append(target)
        batch_cif_ids.append(cif_id)
        base_idx += n_i
    return (
        (
            torch.cat(batch_atom_fea, dim=0),
            torch.cat(batch_nbr_fea, dim=0),
            torch.cat(batch_nbr_fea_idx, dim=0),
            crystal_atom_idx,
        ),
        torch.stack(batch_target, dim=0),
        batch_cif_ids,
    )
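The only non-trivial bookkeeping in collate_pool is the index shift: each crystal's nbr_fea_idx holds indices local to that crystal, so collate_pool adds `base_idx` (the number of atoms in earlier crystals) before concatenating and records each crystal's atom range in crystal_atom_idx. The numpy sketch below reproduces that logic without torch; `batch_neighbor_indices` is a hypothetical name. Because the arrays are concatenated along the atom axis, every crystal must share the same M, which CIFData guarantees by padding to max_num_nbr.

```python
import numpy as np


def batch_neighbor_indices(nbr_idx_per_crystal):
    """Mirror the index bookkeeping in collate_pool using numpy arrays.

    Each input array has shape (n_i, M) and holds crystal-local indices.
    """
    batched, crystal_atom_idx = [], []
    base_idx = 0
    for nbr_fea_idx in nbr_idx_per_crystal:
        n_i = nbr_fea_idx.shape[0]
        batched.append(nbr_fea_idx + base_idx)  # shift into the batched atom array
        crystal_atom_idx.append(np.arange(n_i) + base_idx)
        base_idx += n_i
    return np.concatenate(batched, axis=0), crystal_atom_idx


crystal_a = np.array([[1, 0], [0, 1]])          # 2 atoms, M = 2
crystal_b = np.array([[1, 2], [2, 0], [0, 1]])  # 3 atoms, M = 2
nbr_idx, atom_idx = batch_neighbor_indices([crystal_a, crystal_b])
# nbr_idx  -> [[1, 0], [0, 1], [3, 4], [4, 2], [2, 3]]
# atom_idx -> [array([0, 1]), array([2, 3, 4])]
```

A padded neighbor slot (local index 0) therefore points at the first atom of its own crystal after the shift, never into another crystal.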

train_force_split(total_set, train_ratio_force_set, train_ratio)

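Split a dataset into a training set and a validation/test set, forcing every structure in a second directory into the training set. The random part of the split uses random.sample from Python's global random module. Both datasets are built from copies written to new temporary directories (tempfile.mkdtemp), which are not removed automatically.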

Parameters:

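total_set (str, required): Path to the dataset directory to split
train_ratio_force_set (str, required): Path to the directory whose structures are always placed in the training set
train_ratio (float, required): Fraction of the structures in total_set that the training set should contain, counting the forced structures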

Returns:

train_dataset (CIFData): The forced structures plus a random sample from total_set
valid_test_dataset (CIFData): The remaining structures of total_set, for validation and testing

Source code in cgcnn2/data.py
def train_force_split(total_set, train_ratio_force_set, train_ratio):
    """
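    Split a dataset into a training set and a validation/test set, forcing
    every structure in a second directory into the training set.

    Args:
        total_set (str): Path to the dataset directory to split
        train_ratio_force_set (str): Path to the directory whose structures
            are always placed in the training set
        train_ratio (float): Fraction of the structures in total_set that
            the training set should contain, counting forced structures

    Returns:
        train_dataset: CIFData
            The forced structures plus a random sample from total_set
        valid_test_dataset: CIFData
            The remaining structures of total_set, for validation and testing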
    """
    # create temporary directories for the training and validation/test sets
    temp_train_dir = tempfile.mkdtemp()
    temp_valid_test_dir = tempfile.mkdtemp()

    shutil.copy(f"{total_set}/atom_init.json", temp_train_dir)
    shutil.copy(f"{total_set}/atom_init.json", temp_valid_test_dir)

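    # Combine both ID/property tables. IDs are read as strings so that they
    # match the CIF file stems used below, even when the IDs look numeric.
    train_force_csv = pd.read_csv(
        f"{train_ratio_force_set}/id_prop.csv", header=None, dtype={0: str}
    )
    split_csv = pd.read_csv(f"{total_set}/id_prop.csv", header=None, dtype={0: str})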
    total_csv = pd.concat([train_force_csv, split_csv])

    train_force_cif_files = [
        f for f in os.listdir(train_ratio_force_set) if f.endswith(".cif")
    ]
    total_cif_files = [f for f in os.listdir(total_set) if f.endswith(".cif")]

    for file in train_force_cif_files:
        shutil.copy(
            os.path.join(train_ratio_force_set, file),
            os.path.join(temp_train_dir, file),
        )

    train_force_size = len(train_force_cif_files)
    total_size = len(total_cif_files)
    train_size = int(total_size * train_ratio)
    train_split_size = int(max(train_size - train_force_size, 0))

    if train_split_size > 0:
        train_split_cif_files = random.sample(total_cif_files, train_split_size)
        train_split_set = set(train_split_cif_files)  # O(1) membership checks
        valid_test_cif_files = [f for f in total_cif_files if f not in train_split_set]
        valid_test_cif_ids = [f[:-4] for f in valid_test_cif_files]

        for file in train_split_cif_files:
            shutil.copy(
                os.path.join(total_set, file),
                os.path.join(temp_train_dir, file),
            )

        for file in valid_test_cif_files:
            shutil.copy(
                os.path.join(total_set, file),
                os.path.join(temp_valid_test_dir, file),
            )

        train_csv = total_csv[~total_csv[total_csv.columns[0]].isin(valid_test_cif_ids)]
        train_csv.to_csv(f"{temp_train_dir}/id_prop.csv", index=False, header=False)

        valid_test_csv = total_csv[
            total_csv[total_csv.columns[0]].isin(valid_test_cif_ids)
        ]
        valid_test_csv.to_csv(
            f"{temp_valid_test_dir}/id_prop.csv", index=False, header=False
        )

        train_dataset = CIFData(temp_train_dir)
        valid_test_dataset = CIFData(temp_valid_test_dir)

        return train_dataset, valid_test_dataset

    else:
        raise ValueError(
            "Forced training set is at least as large as the expected training set. "
            f"Expected: {train_size}, Forced: {train_force_size}"
        )
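The size arithmetic in train_force_split decides how many structures are sampled from total_set: `train_size = int(total_size * train_ratio)` counts only the CIF files in total_set, and the number of forced structures is subtracted from it. The helper below (`forced_split_sizes`, hypothetical) reproduces just that arithmetic, so the outcome can be checked before any files are copied.

```python
def forced_split_sizes(total_size, train_force_size, train_ratio):
    """Mirror the size arithmetic in train_force_split (hypothetical helper).

    Returns (train_size, train_split_size, valid_test_size).
    """
    train_size = int(total_size * train_ratio)
    # Forced structures count toward the training budget.
    train_split_size = max(train_size - train_force_size, 0)
    if train_split_size == 0:
        raise ValueError(
            f"Forced set too large. Expected: {train_size}, Forced: {train_force_size}"
        )
    valid_test_size = total_size - train_split_size
    return train_size, train_split_size, valid_test_size


sizes = forced_split_sizes(100, 10, 0.8)  # (80, 70, 30)
```

With 100 structures in total_set, 10 forced structures and train_ratio=0.8, 70 structures are sampled, so the training set holds 80 and the validation/test set holds 30.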

Model Framework

cgcnn2.model

ConvLayer

Bases: Module

Convolutional layer for graph data.

Performs a convolutional operation on graphs, updating atom features based on their neighbors.

Source code in cgcnn2/model.py
class ConvLayer(nn.Module):
    """
    Convolutional layer for graph data.

    Performs a convolutional operation on graphs, updating atom features based on their neighbors.
    """

    def __init__(self, atom_fea_len: int, nbr_fea_len: int) -> None:
        """
        Initialize the ConvLayer.

        Args:
            atom_fea_len (int): Number of atom hidden features.
            nbr_fea_len (int): Number of bond (neighbor) features.
        """
        super(ConvLayer, self).__init__()
        self.atom_fea_len = atom_fea_len
        self.nbr_fea_len = nbr_fea_len
        self.fc_full = nn.Linear(
            2 * self.atom_fea_len + self.nbr_fea_len, 2 * self.atom_fea_len
        )
        self.sigmoid = nn.Sigmoid()
        self.softplus1 = nn.Softplus()
        self.bn1 = nn.BatchNorm1d(2 * self.atom_fea_len)
        self.bn2 = nn.BatchNorm1d(self.atom_fea_len)
        self.softplus2 = nn.Softplus()

    def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx):
        """
        Forward pass

        N: Total number of atoms in the batch
        M: Max number of neighbors

        Args:
            atom_in_fea (torch.Tensor): shape (N, atom_fea_len)
              Atom hidden features before convolution
            nbr_fea (torch.Tensor): shape (N, M, nbr_fea_len)
              Bond features of each atom's M neighbors
            nbr_fea_idx (torch.LongTensor): shape (N, M)
              Indices of M neighbors of each atom

        Returns:
            atom_out_fea (torch.Tensor): shape (N, atom_fea_len)
              Atom hidden features after convolution

        """
        N, M = nbr_fea_idx.shape
        # convolution
        atom_nbr_fea = atom_in_fea[nbr_fea_idx, :]
        total_nbr_fea = torch.cat(
            [
                atom_in_fea.unsqueeze(1).expand(N, M, self.atom_fea_len),
                atom_nbr_fea,
                nbr_fea,
            ],
            dim=2,
        )
        total_gated_fea = self.fc_full(total_nbr_fea)
        total_gated_fea = self.bn1(
            total_gated_fea.view(-1, self.atom_fea_len * 2)
        ).view(N, M, self.atom_fea_len * 2)
        nbr_filter, nbr_core = total_gated_fea.chunk(2, dim=2)
        nbr_filter = self.sigmoid(nbr_filter)
        nbr_core = self.softplus1(nbr_core)
        nbr_sumed = torch.sum(nbr_filter * nbr_core, dim=1)
        nbr_sumed = self.bn2(nbr_sumed)
        out = self.softplus2(atom_in_fea + nbr_sumed)
        return out
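For each atom i and neighbor slot j, ConvLayer concatenates z_ij = v_i ⊕ v_j ⊕ u_ij (own features, neighbor features, bond features), passes it through fc_full, splits the result into a sigmoid gate and a softplus core, sums gate × core over the M neighbors, and adds that sum back to v_i before a final softplus. The numpy sketch below reproduces this data flow with random weights. It deliberately omits bn1 and bn2 (the BatchNorm1d layers), so its outputs will not match the PyTorch layer numerically; `conv_layer_sketch`, `weight` and `bias` are hypothetical names.

```python
import numpy as np


def softplus(x):
    return np.logaddexp(0.0, x)  # log(1 + exp(x)), numerically stable


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def conv_layer_sketch(atom_in_fea, nbr_fea, nbr_fea_idx, weight, bias):
    """ConvLayer.forward in numpy, with both BatchNorm1d layers omitted.

    weight: shape (2 * atom_fea_len + nbr_fea_len, 2 * atom_fea_len)
    """
    N, M = nbr_fea_idx.shape
    atom_fea_len = atom_in_fea.shape[1]
    atom_nbr_fea = atom_in_fea[nbr_fea_idx, :]  # (N, M, atom_fea_len)
    self_fea = np.broadcast_to(atom_in_fea[:, np.newaxis, :], (N, M, atom_fea_len))
    total_nbr_fea = np.concatenate([self_fea, atom_nbr_fea, nbr_fea], axis=2)
    gated = total_nbr_fea @ weight + bias  # (N, M, 2 * atom_fea_len)
    nbr_filter, nbr_core = np.split(gated, 2, axis=2)
    nbr_summed = np.sum(sigmoid(nbr_filter) * softplus(nbr_core), axis=1)
    return softplus(atom_in_fea + nbr_summed)  # residual update, (N, atom_fea_len)


rng = np.random.default_rng(0)
N, M, atom_fea_len, nbr_fea_len = 5, 3, 4, 6
atom_in_fea = rng.normal(size=(N, atom_fea_len))
nbr_fea = rng.normal(size=(N, M, nbr_fea_len))
nbr_fea_idx = rng.integers(0, N, size=(N, M))
weight = rng.normal(size=(2 * atom_fea_len + nbr_fea_len, 2 * atom_fea_len))
bias = np.zeros(2 * atom_fea_len)
out = conv_layer_sketch(atom_in_fea, nbr_fea, nbr_fea_idx, weight, bias)
```

The output keeps the input's shape (N, atom_fea_len), which is what lets CrystalGraphConvNet stack n_conv of these layers.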

__init__(atom_fea_len, nbr_fea_len)

Initialize the ConvLayer.

Parameters:

atom_fea_len (int, required): Number of atom hidden features.
nbr_fea_len (int, required): Number of bond (neighbor) features.

forward(atom_in_fea, nbr_fea, nbr_fea_idx)

Forward pass

N: Total number of atoms in the batch
M: Max number of neighbors

Parameters:

atom_in_fea (torch.Tensor, required): shape (N, atom_fea_len). Atom hidden features before convolution
nbr_fea (torch.Tensor, required): shape (N, M, nbr_fea_len). Bond features of each atom's M neighbors
nbr_fea_idx (torch.LongTensor, required): shape (N, M). Indices of M neighbors of each atom

Returns:

atom_out_fea (torch.Tensor): shape (N, atom_fea_len). Atom hidden features after convolution

CrystalGraphConvNet

Bases: Module

Create a crystal graph convolutional neural network for predicting total material properties.
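CrystalGraphConvNet receives crystal_atom_idx from collate_pool and presumably reduces the per-atom features to one vector per crystal before conv_to_fc. That pooling code falls outside this excerpt, so the numpy sketch below assumes mean pooling, which is what the original CGCNN implementation uses; `mean_pool` is a hypothetical name, not a cgcnn2 function.

```python
import numpy as np


def mean_pool(atom_fea, crystal_atom_idx):
    """Average atom features per crystal (assumed pooling; see lead-in)."""
    return np.stack([atom_fea[idx].mean(axis=0) for idx in crystal_atom_idx])


# 5 atoms with 2 features each: crystal 0 owns atoms 0-1, crystal 1 owns 2-4.
atom_fea = np.array([[1.0, 2.0], [3.0, 4.0], [0.0, 0.0], [6.0, 3.0], [3.0, 3.0]])
crystal_atom_idx = [np.array([0, 1]), np.array([2, 3, 4])]
crystal_fea = mean_pool(atom_fea, crystal_atom_idx)  # [[2., 3.], [3., 2.]]
```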

Source code in cgcnn2/model.py
class CrystalGraphConvNet(nn.Module):
    """
    Create a crystal graph convolutional neural network for predicting total
    material properties.
    """

    def __init__(
        self,
        orig_atom_fea_len: int,
        nbr_fea_len: int,
        atom_fea_len: int = 64,
        n_conv: int = 3,
        h_fea_len: int = 128,
        n_h: int = 1,
        classification: bool = False,
    ) -> None:
        """
        Initialize CrystalGraphConvNet.

        Args:
            orig_atom_fea_len (int): Number of atom features in the input.
            nbr_fea_len (int): Number of bond features.
            atom_fea_len (int): Number of hidden atom features in the convolutional layers
            n_conv (int): Number of convolutional layers
            h_fea_len (int): Number of hidden features after pooling
            n_h (int): Number of hidden layers after pooling
            classification (bool): Whether to use classification or regression
        """
        super(CrystalGraphConvNet, self).__init__()
        self.classification = classification
        self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len)
        self.convs = nn.ModuleList(
            [
                ConvLayer(atom_fea_len=atom_fea_len, nbr_fea_len=nbr_fea_len)
                for _ in range(n_conv)
            ]
        )
        self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len)
        self.conv_to_fc_softplus = nn.Softplus()
        if n_h > 1:
            self.fcs = nn.ModuleList(
                [nn.Linear(h_fea_len, h_fea_len) for _ in range(n_h - 1)]
            )
            self.softpluses = nn.ModuleList([nn.Softplus() for _ in range(n_h - 1)])

        if self.classification:
            self.fc_out = nn.Linear(h_fea_len, 2)
        else:
            self.fc_out = nn.Linear(h_fea_len, 1)

        if self.classification:
            self.logsoftmax = nn.LogSoftmax(dim=1)
            self.dropout = nn.Dropout()

    def forward(
        self,
        atom_fea: torch.Tensor,
        nbr_fea: torch.Tensor,
        nbr_fea_idx: torch.LongTensor,
        crystal_atom_idx: list[torch.LongTensor],
    ):
        """
        Forward pass

        N: Total number of atoms in the batch
        M: Max number of neighbors
        N0: Total number of crystals in the batch

        Args:
            atom_fea (torch.Tensor): Variable(torch.Tensor) shape (N, orig_atom_fea_len)
              Atom features from atom type
            nbr_fea (torch.Tensor): Variable(torch.Tensor) shape (N, M, nbr_fea_len)
              Bond features of each atom's M neighbors
            nbr_fea_idx (torch.LongTensor): shape (N, M)
              Indices of M neighbors of each atom
            crystal_atom_idx (list of torch.LongTensor): Mapping from the crystal idx to atom idx

        Returns:
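            out (torch.Tensor): shape (N0, 1), or (N0, 2) log-probabilities
              when classification=True
            crys_fea (torch.Tensor): shape (N0, h_fea_len)
              Crystal features after the last hidden layer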

        """
        atom_fea = self.embedding(atom_fea)
        for conv_func in self.convs:
            atom_fea = conv_func(atom_fea, nbr_fea, nbr_fea_idx)
        crys_fea = self.pooling(atom_fea, crystal_atom_idx)
        crys_fea = self.conv_to_fc(self.conv_to_fc_softplus(crys_fea))
        crys_fea = self.conv_to_fc_softplus(crys_fea)
        if self.classification:
            crys_fea = self.dropout(crys_fea)
        if hasattr(self, "fcs") and hasattr(self, "softpluses"):
            for fc, softplus in zip(self.fcs, self.softpluses):
                crys_fea = softplus(fc(crys_fea))
        out = self.fc_out(crys_fea)
        if self.classification:
            out = self.logsoftmax(out)
        return out, crys_fea

    def pooling(
        self, atom_fea: torch.Tensor, crystal_atom_idx: list[torch.LongTensor]
    ) -> torch.Tensor:
        """
        Pooling the atom features to crystal features

        N: Total number of atoms in the batch
        N0: Total number of crystals in the batch

        Args:
            atom_fea (torch.Tensor): Variable(torch.Tensor) shape (N, atom_fea_len)
              Atom feature vectors of the batch
            crystal_atom_idx (list of torch.LongTensor): Mapping from the crystal idx to atom idx
        """
        assert (
            sum([len(idx_map) for idx_map in crystal_atom_idx])
            == atom_fea.data.shape[0]
        )
        summed_fea = [
            torch.mean(atom_fea[idx_map], dim=0, keepdim=True)
            for idx_map in crystal_atom_idx
        ]
        return torch.cat(summed_fea, dim=0)

__init__(orig_atom_fea_len, nbr_fea_len, atom_fea_len=64, n_conv=3, h_fea_len=128, n_h=1, classification=False)

Initialize CrystalGraphConvNet.

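Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| orig_atom_fea_len | int | Number of atom features in the input. | required |
| nbr_fea_len | int | Number of bond (neighbor) features. | required |
| atom_fea_len | int | Number of hidden atom features in the convolutional layers. | 64 |
| n_conv | int | Number of convolutional layers. | 3 |
| h_fea_len | int | Number of hidden features after pooling. | 128 |
| n_h | int | Number of hidden layers after pooling, counting conv_to_fc; n_h - 1 additional Linear + Softplus layers are created. | 1 |
| classification | bool | If True, build a two-class classifier (two outputs, log-softmax, dropout after pooling); if False, a single-output regressor. | False |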
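The effect of n_h and classification on the layers after pooling can be checked with a little arithmetic. The sketch below counts the weights and biases of conv_to_fc, the optional fcs layers, and fc_out as they are built in the __init__ source above (Softplus, LogSoftmax and Dropout have no parameters). head_param_count is a hypothetical helper, not part of cgcnn2, and it deliberately leaves out the embedding and convolution layers.

```python
def head_param_count(atom_fea_len: int = 64, h_fea_len: int = 128,
                     n_h: int = 1, classification: bool = False) -> int:
    """Count weights + biases of the post-pooling layers of CrystalGraphConvNet."""

    def linear(n_in: int, n_out: int) -> int:
        # nn.Linear(n_in, n_out) holds an (n_out x n_in) weight matrix plus n_out biases
        return n_in * n_out + n_out

    total = linear(atom_fea_len, h_fea_len)                # conv_to_fc
    total += max(n_h - 1, 0) * linear(h_fea_len, h_fea_len)  # fcs, only created when n_h > 1
    n_out = 2 if classification else 1                     # fc_out: 2 logits or 1 value
    total += linear(h_fea_len, n_out)
    return total


print(head_param_count())                     # 8449
print(head_param_count(n_h=2))                # 24961
print(head_param_count(classification=True))  # 8578
```

With the defaults, the head is dominated by conv_to_fc; each extra unit of n_h adds another h_fea_len x h_fea_len layer.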

forward(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

Forward pass.

N: total number of atoms in the batch
M: maximum number of neighbors
N0: total number of crystals in the batch

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| atom_fea | Tensor | Shape (N, orig_atom_fea_len). Atom features from atom type. | required |
| nbr_fea | Tensor | Shape (N, M, nbr_fea_len). Bond features of each atom's M neighbors. | required |
| nbr_fea_idx | LongTensor | Shape (N, M). Indices of the M neighbors of each atom. | required |
| crystal_atom_idx | list of torch.LongTensor | Mapping from each crystal index to the indices of its atoms in atom_fea. | required |

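Returns:

A tuple (out, crys_fea):

| Name | Type | Description |
|---|---|---|
| out | Tensor | Shape (N0, 1) regression output, or (N0, 2) log-probabilities when classification=True. |
| crys_fea | Tensor | Shape (N0, h_fea_len). Crystal features after the last hidden layer, before fc_out. |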
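crystal_atom_idx ties the batched atom rows back to their crystals. A minimal sketch, assuming (as in the original CGCNN collate function) that the atoms of each crystal occupy a contiguous block of rows in atom_fea; build_crystal_atom_idx is a hypothetical helper, and plain lists stand in for torch.LongTensor:

```python
def build_crystal_atom_idx(atoms_per_crystal: list[int]) -> list[list[int]]:
    """Return, for each crystal, the row indices of its atoms in the batched atom_fea."""
    crystal_atom_idx = []
    start = 0
    for n_atoms in atoms_per_crystal:
        # The atoms of one crystal are assumed to occupy a contiguous block of rows
        crystal_atom_idx.append(list(range(start, start + n_atoms)))
        start += n_atoms
    return crystal_atom_idx


idx = build_crystal_atom_idx([2, 3])
print(idx)                       # [[0, 1], [2, 3, 4]]
N = sum(len(i) for i in idx)     # 5 atom rows: first dimension of atom_fea
N0 = len(idx)                    # 2 crystals: first dimension of out and crys_fea
```

The number of lists (N0) sets the first dimension of both returned tensors, and the lists together must account for all N atom rows, which pooling asserts.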


pooling(atom_fea, crystal_atom_idx)

Pool the atom features into crystal features by averaging the feature vectors of each crystal's atoms (mean pooling).

N: total number of atoms in the batch
N0: total number of crystals in the batch

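Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| atom_fea | Tensor | Shape (N, atom_fea_len). Atom feature vectors of the batch. | required |
| crystal_atom_idx | list of torch.LongTensor | Mapping from each crystal index to its atom indices; the index lists must contain N entries in total. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Shape (N0, atom_fea_len). Mean atom features of each crystal. |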
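Mean pooling can be sketched in plain Python, with nested lists standing in for tensors. mean_pool is an illustrative re-implementation of the logic in the pooling source above, not cgcnn2 code:

```python
def mean_pool(atom_fea: list[list[float]],
              crystal_atom_idx: list[list[int]]) -> list[list[float]]:
    """Average atom feature vectors per crystal, like CrystalGraphConvNet.pooling."""
    # Same consistency check as the source: the index lists cover exactly N rows in total
    assert sum(len(idx_map) for idx_map in crystal_atom_idx) == len(atom_fea)
    crys_fea = []
    for idx_map in crystal_atom_idx:
        rows = [atom_fea[i] for i in idx_map]
        # Column-wise mean over the atoms of this crystal
        crys_fea.append([sum(col) / len(rows) for col in zip(*rows)])
    return crys_fea


atom_fea = [[1.0, 2.0], [3.0, 4.0],               # crystal 0
            [0.0, 0.0], [6.0, 3.0], [3.0, 0.0]]   # crystal 1
print(mean_pool(atom_fea, [[0, 1], [2, 3, 4]]))   # [[2.0, 3.0], [3.0, 1.0]]
```

Averaging (rather than summing) keeps the crystal features on the same scale regardless of how many atoms a unit cell contains.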

Utility Functions

cgcnn2.util

Normalizer

Normalizes a PyTorch tensor and allows restoring it later.

This class stores the mean and standard deviation of a sample tensor, computed over all of its elements, and provides methods to normalize tensors with these statistics and to reverse the normalization.
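A pure-Python sketch of the same statistics, with lists standing in for tensors and a hypothetical ListNormalizer class in place of Normalizer. statistics.stdev divides by n - 1, as torch.std does by default, and the last lines restore saved statistics into a placeholder instance the way cgcnn_pred does with Normalizer(torch.zeros(3)):

```python
import statistics


class ListNormalizer:
    """Illustrative pure-Python counterpart of Normalizer (not part of cgcnn2)."""

    def __init__(self, values: list[float]) -> None:
        self.mean = statistics.mean(values)
        # Sample standard deviation (divides by n - 1), matching torch.std's default
        self.std = statistics.stdev(values)

    def norm(self, values: list[float]) -> list[float]:
        return [(v - self.mean) / self.std for v in values]

    def denorm(self, normed: list[float]) -> list[float]:
        return [v * self.std + self.mean for v in normed]

    def state_dict(self) -> dict[str, float]:
        return {"mean": self.mean, "std": self.std}

    def load_state_dict(self, state_dict: dict[str, float]) -> None:
        self.mean = state_dict["mean"]
        self.std = state_dict["std"]


targets = [1.0, 2.0, 3.0, 4.0, 5.0]
normalizer = ListNormalizer(targets)
normed = normalizer.norm(targets)      # zero mean, unit sample standard deviation
restored = normalizer.denorm(normed)   # equal to targets up to rounding

# Placeholder statistics are overwritten by the saved ones
placeholder = ListNormalizer([0.0, 0.0, 0.0])
placeholder.load_state_dict(normalizer.state_dict())
```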

Source code in cgcnn2/util.py
class Normalizer:
    """
    Normalizes a PyTorch tensor and allows restoring it later.

    This class keeps track of the mean and standard deviation of a tensor and provides methods
    to normalize and denormalize tensors using these statistics.
    """

    def __init__(self, tensor: torch.Tensor) -> None:
        """
        Initialize the Normalizer with a sample tensor to calculate mean and standard deviation.

        Args:
            tensor (torch.Tensor): Sample tensor to compute mean and standard deviation.
        """
        self.mean: torch.Tensor = torch.mean(tensor)
        self.std: torch.Tensor = torch.std(tensor)

    def norm(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Normalize a tensor using the stored mean and standard deviation.

        Args:
            tensor (torch.Tensor): Tensor to normalize.

        Returns:
            torch.Tensor: Normalized tensor.
        """
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor: torch.Tensor) -> torch.Tensor:
        """
        Denormalize a tensor using the stored mean and standard deviation.

        Args:
            normed_tensor (torch.Tensor): Normalized tensor to denormalize.

        Returns:
            torch.Tensor: Denormalized tensor.
        """
        return normed_tensor * self.std + self.mean

    def state_dict(self) -> dict[str, torch.Tensor]:
        """
        Returns the state dictionary containing the mean and standard deviation.

        Returns:
            dict[str, torch.Tensor]: State dictionary.
        """
        return {"mean": self.mean, "std": self.std}

    def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
        """
        Loads the mean and standard deviation from a state dictionary.

        Args:
            state_dict (dict[str, torch.Tensor]): State dictionary containing 'mean' and 'std'.
        """
        self.mean = state_dict["mean"]
        self.std = state_dict["std"]

__init__(tensor)

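Initialize the Normalizer with a sample tensor from which the mean and standard deviation are calculated.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Sample tensor used to compute the mean and standard deviation (torch.std with its default n - 1 correction). | required |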

denorm(normed_tensor)

Denormalize a tensor using the stored mean and standard deviation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| normed_tensor | Tensor | Normalized tensor to denormalize. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Denormalized tensor, computed as normed_tensor * std + mean. |


load_state_dict(state_dict)

Loads the mean and standard deviation from a state dictionary.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| state_dict | dict[str, Tensor] | State dictionary containing 'mean' and 'std'. | required |

norm(tensor)

Normalize a tensor using the stored mean and standard deviation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | Tensor to normalize. | required |

Returns:

| Type | Description |
|---|---|
| Tensor | Normalized tensor, computed as (tensor - mean) / std. |


state_dict()

Returns the state dictionary containing the mean and standard deviation.

Returns:

| Type | Description |
|---|---|
| dict[str, Tensor] | State dictionary with the keys 'mean' and 'std'. |


cgcnn_descriptor(model, loader, device, verbose)

This function takes a pre-trained CGCNN model and a data loader, runs inference, and returns the model predictions together with the crystal features from the last hidden layer. Target values are not needed for the predicted set: the loader must still yield a target for each batch, but its values are ignored.

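Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The trained CGCNN model. | required |
| loader | DataLoader | DataLoader for the dataset; each batch must yield (input, target, cif_id). | required |
| device | str | The device ('cuda' or 'cpu') on which the model is run. | required |
| verbose | int | Verbosity level. At 4 or higher, the batch index, CIF ID, and prediction are printed for every batch. | required |

Returns:

| Name | Type | Description |
|---|---|---|
| tuple | tuple[list[float], list[Tensor]] | A tuple (predictions, features). predictions is a flat list of the model outputs for all crystals (two log-probabilities per crystal for a classification model); features holds one NumPy array of last-hidden-layer crystal features per batch, each of shape (batch_size, h_fea_len). |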

Notes

This function is intended for programmatic use, where predictions or features (descriptors) generated by the model feed into further analysis. For the command-line interface, consider using the cgcnn_pr script instead.
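Because predictions are flattened across batches while features are kept per batch, downstream code usually has to line the two up. A sketch with a hypothetical regroup_outputs helper, using plain lists for the per-batch arrays; with NumPy, np.concatenate(crys_feas_list, axis=0) performs the feature stacking in one call:

```python
def regroup_outputs(outputs_list, crys_feas_list, n_outputs=1):
    """Pair flat predictions with per-crystal feature rows (hypothetical helper).

    outputs_list: flat list of floats, as returned by cgcnn_descriptor
    crys_feas_list: one 2-D array-like per batch, shape (batch_size, h_fea_len)
    n_outputs: 1 for regression, 2 for a classification model (two log-probabilities)
    """
    if len(outputs_list) % n_outputs != 0:
        raise ValueError("outputs_list length is not a multiple of n_outputs")
    # n_outputs consecutive values belong to one crystal
    preds = [outputs_list[i:i + n_outputs]
             for i in range(0, len(outputs_list), n_outputs)]
    # Concatenate the per-batch feature blocks along the crystal axis
    feats = [list(row) for batch in crys_feas_list for row in batch]
    if len(preds) != len(feats):
        raise ValueError("number of predictions and feature rows differ")
    return preds, feats


outputs = [0.5, 1.5, -0.2]                  # three crystals, regression
features = [[[0.1, 0.2], [0.3, 0.4]],       # batch 1: two crystals
            [[0.5, 0.6]]]                   # batch 2: one crystal
preds, rows = regroup_outputs(outputs, features)
print(preds)  # [[0.5], [1.5], [-0.2]]
print(rows)   # [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
```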

Source code in cgcnn2/util.py
def cgcnn_descriptor(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    device: str,
    verbose: int,
) -> tuple[list[float], list[torch.Tensor]]:
    """
    This function takes a pre-trained CGCNN model and a dataset, runs inference
    to generate predictions and features from the last layer, and returns the
    predictions and features. It is not necessary to have target values for the
    predicted set.

    Args:
        model (torch.nn.Module): The trained CGCNN model.
        loader (torch.utils.data.DataLoader): DataLoader for the dataset.
        device (str): The device ('cuda' or 'cpu') where the model will be run.
        verbose (int): The verbosity level of the output.

    Returns:
        tuple: A tuple containing:
            - list: Model predictions
            - list: Crystal features from the last layer

    Notes:
        This function is intended for use in programmatic downstream analysis,
        where the user wants to continue downstream analysis using predictions or
        features (descriptors) generated by the model. For the command-line interface,
        consider using the cgcnn_pr script instead.
    """

    model.eval()
    targets_list = []
    outputs_list = []
    crys_feas_list = []
    index = 0

    with torch.no_grad():
        for input, target, cif_id in loader:
            atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx = input
            atom_fea = atom_fea.to(device)
            nbr_fea = nbr_fea.to(device)
            nbr_fea_idx = nbr_fea_idx.to(device)
            crystal_atom_idx = [idx_map.to(device) for idx_map in crystal_atom_idx]
            target = target.to(device)

            output, crys_fea = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

            targets_list.extend(target.cpu().numpy().ravel().tolist())
            outputs_list.extend(output.cpu().numpy().ravel().tolist())
            crys_feas_list.append(crys_fea.cpu().numpy())

            index += 1

            # Extract the actual values from cif_id and output tensor
            cif_id_value = cif_id[0] if cif_id and isinstance(cif_id, list) else cif_id
            prediction_value = output.item() if output.numel() == 1 else output.tolist()

            if verbose >= 4:
                print(
                    "index:",
                    index,
                    "| cif id:",
                    cif_id_value,
                    "| prediction:",
                    prediction_value,
                )

    return outputs_list, crys_feas_list

cgcnn_pred(model_path, full_set, verbose=4, cuda=False, num_workers=0)

This function takes the path to a pre-trained CGCNN model and a directory of CIF files, runs inference, and returns the predictions together with the crystal features from the last hidden layer. Target values are not needed for the predicted set. A FileNotFoundError is raised if model_path is not a file.

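Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model_path | str | Path to the checkpoint file containing the pre-trained model parameters. | required |
| full_set | str | Path to the directory containing all CIF files for the dataset. | required |
| verbose | int | Verbosity level. At 3 or higher, the checkpoint epoch and validation error are printed; at 4 or higher, every prediction is printed as well. | 4 |
| cuda | bool | Whether to run on CUDA. | False |
| num_workers | int | Number of subprocesses for data loading. | 0 |

Returns:

| Name | Type | Description |
|---|---|---|
| tuple | tuple[list[float], list[Tensor]] | A tuple (predictions, features) as returned by cgcnn_descriptor with a batch size of 1, so there is one prediction and one (1, h_fea_len) feature array per crystal, in dataset order. The normalizer stored in the checkpoint is loaded but not applied, so the predictions are not denormalized. |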

Notes

This function is intended for programmatic use, where predictions or features (descriptors) generated by the model feed into further analysis. For the command-line interface, consider using the cgcnn_pr script instead.
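cgcnn_pred rebuilds the model from the checkpoint dictionary: it reads "args", "state_dict" and "normalizer" (plus "epoch" and "best_mae_error" when verbose >= 3) and forwards only atom_fea_len, n_conv, h_fea_len and n_h to CrystalGraphConvNet, so classification keeps its default of False. A hypothetical pre-flight check, model_kwargs_from_checkpoint, that mirrors those reads without loading any weights (the checkpoint values shown are made up):

```python
import argparse

# Keys cgcnn_pred always reads; "epoch" and "best_mae_error" are only read for logging
REQUIRED_KEYS = ("args", "state_dict", "normalizer")
# Hyperparameters forwarded from checkpoint["args"] to CrystalGraphConvNet
MODEL_ARG_NAMES = ("atom_fea_len", "n_conv", "h_fea_len", "n_h")


def model_kwargs_from_checkpoint(checkpoint: dict) -> dict:
    """Validate a loaded checkpoint dict and return the model keyword arguments."""
    missing = [key for key in REQUIRED_KEYS if key not in checkpoint]
    if missing:
        raise KeyError(f"checkpoint is missing keys: {missing}")
    # cgcnn_pred wraps the saved args dict in a Namespace the same way
    model_args = argparse.Namespace(**checkpoint["args"])
    return {name: getattr(model_args, name) for name in MODEL_ARG_NAMES}


ckpt = {
    "args": {"atom_fea_len": 64, "n_conv": 3, "h_fea_len": 128, "n_h": 1, "lr": 0.01},
    "state_dict": {},
    "normalizer": {"mean": 0.0, "std": 1.0},
}
print(model_kwargs_from_checkpoint(ckpt))
# {'atom_fea_len': 64, 'n_conv': 3, 'h_fea_len': 128, 'n_h': 1}
```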

Source code in cgcnn2/util.py
def cgcnn_pred(
    model_path: str,
    full_set: str,
    verbose: int = 4,
    cuda: bool = False,
    num_workers: int = 0,
) -> tuple[list[float], list[torch.Tensor]]:
    """
    This function takes the path to a pre-trained CGCNN model and a dataset,
    runs inference to generate predictions, and returns the predictions. It is
    not necessary to have target values for the predicted set.

    Args:
        model_path (str): Path to the file containing the pre-trained model parameters.
        full_set (str): Path to the directory containing all CIF files for the dataset.
        verbose (int, optional): Verbosity level of the output. Defaults to 4.
        cuda (bool, optional): Whether to use CUDA. Defaults to False.
        num_workers (int, optional): Number of subprocesses for data loading. Defaults to 0.

    Returns:
        tuple: A tuple containing:
            - list: Model predictions
            - list: Features from the last layer

    Notes:
        This function is intended for use in programmatic downstream analysis,
        where the user wants to continue downstream analysis using predictions or
        features (descriptors) generated by the model. For the command-line interface,
        consider using the cgcnn_pr script instead.
    """
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"=> No model params found at '{model_path}'")

    total_dataset = CIFData_NoTarget(full_set)

    checkpoint = torch.load(
        model_path,
        map_location=lambda storage, loc: storage if not cuda else None,
        weights_only=False,
    )
    structures, _, _ = total_dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model_args = argparse.Namespace(**checkpoint["args"])
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=model_args.atom_fea_len,
        n_conv=model_args.n_conv,
        h_fea_len=model_args.h_fea_len,
        n_h=model_args.n_h,
    )
    if cuda:
        model.cuda()

    normalizer = Normalizer(torch.zeros(3))
    normalizer.load_state_dict(checkpoint["normalizer"])
    model.load_state_dict(checkpoint["state_dict"])

    if verbose >= 3:
        print(
            f"=> Loaded model from '{model_path}' (epoch {checkpoint['epoch']}, validation error {checkpoint['best_mae_error']})"
        )

    device = "cuda" if cuda else "cpu"
    model.to(device).eval()

    full_loader = DataLoader(
        total_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_pool,
        pin_memory=cuda,
    )

    pred, last_layer = cgcnn_descriptor(model, full_loader, device, verbose)

    return pred, last_layer

cgcnn_test(model, loader, device, plot_file='parity_plot.svg', results_file='results.csv', axis_limits=None, **kwargs)

This function takes a pre-trained CGCNN model and a test data loader, runs inference, prints the mean squared error (MSE) and R2 score, writes the results (sorted by CIF ID) to a CSV file, and saves a parity plot of predicted versus actual values.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The pre-trained CGCNN model. | required |
| loader | DataLoader | DataLoader for the test dataset; each batch must yield (input, target, cif_id). | required |
| device | str | The device ('cuda' or 'cpu') on which the model is run. | required |
| plot_file | str | File path for saving the parity plot (always written in SVG format). | 'parity_plot.svg' |
| results_file | str | File path for saving the results as CSV. | 'results.csv' |
| axis_limits | list | Optional [min, max]. If given, an additional CSV and parity plot containing only the points whose actual and predicted values both lie within these limits are saved, with _axis_limits_<min>_<max> appended to the file names. | None |
| **kwargs | Any | Additional keyword arguments: xlabel (str), the x-axis label of the parity plot, defaults to "Actual"; ylabel (str), the y-axis label, defaults to "Predicted". | {} |
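The axis-limited outputs can be reproduced with the sketch below; axis_limited_name and within_limits are hypothetical helpers. One deliberate difference: the source builds the base name with results_file.split(".")[0], which cuts at the first dot anywhere in the path (for example "./results.csv" becomes an empty base name), whereas this sketch uses os.path.splitext to strip only the extension.

```python
import os


def axis_limited_name(path: str, axis_limits: list[float], ext: str) -> str:
    """Build an '<base>_axis_limits_<min>_<max><ext>' file name."""
    base = os.path.splitext(path)[0]  # strips only the final extension
    return f"{base}_axis_limits_{axis_limits[0]}_{axis_limits[1]}{ext}"


def within_limits(actual, predicted, axis_limits):
    """Keep (actual, predicted) pairs where both values lie inside [min, max]."""
    lo, hi = axis_limits
    return [(a, p) for a, p in zip(actual, predicted)
            if lo <= a <= hi and lo <= p <= hi]


print(axis_limited_name("results.csv", [0, 5], ".csv"))  # results_axis_limits_0_5.csv
print(within_limits([1.0, 4.0, 7.0], [1.2, 6.0, 6.5], [0, 5]))  # [(1.0, 1.2)]
```

A point is dropped as soon as either coordinate falls outside the window, so outliers in the predictions disappear from the zoomed plot as well.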

Returns:

| Type | Description |
|---|---|
| None | Nothing is returned; results are written to results_file and plot_file. |

Notes

This function is intended for use in a command-line interface, providing direct output of results. For programmatic downstream analysis, consider using the API functions instead, i.e. cgcnn_pred and cgcnn_descriptor.
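The two printed metrics follow the standard definitions implemented by scikit-learn's mean_squared_error and r2_score. A plain-Python sketch for reference (R2 is undefined when all actual values are equal, because SS_tot is then zero):

```python
def mse(actual: list[float], predicted: list[float]) -> float:
    """Mean squared error: average of the squared residuals."""
    return sum((a - p) ** 2 for a, p in zip(actual, predicted)) / len(actual)


def r2(actual: list[float], predicted: list[float]) -> float:
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_actual = sum(actual) / len(actual)
    ss_res = sum((a - p) ** 2 for a, p in zip(actual, predicted))
    ss_tot = sum((a - mean_actual) ** 2 for a in actual)
    return 1.0 - ss_res / ss_tot


actual = [1.0, 2.0, 3.0, 4.0]
predicted = [1.1, 1.9, 3.2, 3.8]
# Same format string as cgcnn_test
print(f"MSE: {mse(actual, predicted):.4f}, R2 Score: {r2(actual, predicted):.4f}")
# MSE: 0.0250, R2 Score: 0.9800
```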

Source code in cgcnn2/util.py
def cgcnn_test(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    device: str,
    plot_file: str = "parity_plot.svg",
    results_file: str = "results.csv",
    axis_limits: list[float] | None = None,
    **kwargs: Any,
) -> None:
    """
    This function takes a pre-trained CGCNN model and a test dataset, runs
    inference to generate predictions, creates a parity plot comparing predicted
    versus actual values, and writes the results to a CSV file.

    Args:
        model (torch.nn.Module): The pre-trained CGCNN model.
        loader (torch.utils.data.DataLoader): DataLoader for the dataset.
        device (str): The device ('cuda' or 'cpu') where the model will be run.
        plot_file (str, optional): File path for saving the parity plot. Defaults to 'parity_plot.svg'.
        results_file (str, optional): File path for saving results as CSV. Defaults to 'results.csv'.
        axis_limits (list, optional): Limits for x and y axes of the parity plot. Defaults to None.
        **kwargs: Additional keyword arguments:
            xlabel (str): x-axis label for the parity plot. Defaults to "Actual".
            ylabel (str): y-axis label for the parity plot. Defaults to "Predicted".

    Returns:
        None

    Notes:
        This function is intended for use in a command-line interface, providing
        direct output of results. For programmatic downstream analysis, consider
        using the API functions instead, i.e. cgcnn_pred and cgcnn_descriptor.
    """

    # Extract optional plot labels from kwargs, with defaults
    xlabel = kwargs.get("xlabel", "Actual")
    ylabel = kwargs.get("ylabel", "Predicted")

    model.eval()
    targets_list = []
    outputs_list = []
    cif_ids = []

    with torch.no_grad():
        for input_batch, target, cif_id in loader:
            atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx = input_batch
            atom_fea = atom_fea.to(device)
            nbr_fea = nbr_fea.to(device)
            nbr_fea_idx = nbr_fea_idx.to(device)
            crystal_atom_idx = [idx_map.to(device) for idx_map in crystal_atom_idx]
            target = target.to(device)
            output, _ = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

            targets_list.extend(target.cpu().numpy().ravel().tolist())
            outputs_list.extend(output.cpu().numpy().ravel().tolist())
            cif_ids.extend(cif_id)

    mse = mean_squared_error(targets_list, outputs_list)
    r2 = r2_score(targets_list, outputs_list)
    print(f"MSE: {mse:.4f}, R2 Score: {r2:.4f}")

    # Save results to CSV
    sorted_rows = sorted(zip(cif_ids, targets_list, outputs_list), key=lambda x: x[0])
    with open(results_file, "w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(["cif_id", "Actual", "Predicted"])
        writer.writerows(sorted_rows)
    print(f"Prediction results have been saved to {results_file}")

    # Create parity plot
    fig, ax = plt.subplots(figsize=(8, 6))
    df = pd.DataFrame({"Actual": targets_list, "Predicted": outputs_list})

    ax = density_hexbin(
        x="Actual",
        y="Predicted",
        df=df,
        ax=ax,
        xlabel=xlabel,
        ylabel=ylabel,
        best_fit_line=False,
        gridsize=40,
    )
    ax.set_aspect("auto")
    ax.set_box_aspect(1)
    plt.tight_layout()
    plt.savefig(plot_file, format="svg")
    print(f"Parity plot has been saved to {plot_file}")
    plt.close()

    # If axis limits are provided, save the csv file with the specified limits
    if axis_limits:
        results_file = (
            results_file.split(".")[0]
            + "_axis_limits_"
            + str(axis_limits[0])
            + "_"
            + str(axis_limits[1])
            + ".csv"
        )
        plot_file = (
            plot_file.split(".")[0]
            + "_axis_limits_"
            + str(axis_limits[0])
            + "_"
            + str(axis_limits[1])
            + ".svg"
        )

        df = df[
            (df["Actual"] >= axis_limits[0])
            & (df["Actual"] <= axis_limits[1])
            & (df["Predicted"] >= axis_limits[0])
            & (df["Predicted"] <= axis_limits[1])
        ]

        df.to_csv(
            results_file,
            index=False,
        )

        # Create parity plot
        fig, ax = plt.subplots(figsize=(8, 6))

        ax = density_hexbin(
            x="Actual",
            y="Predicted",
            df=df,
            ax=ax,
            xlabel=xlabel,
            ylabel=ylabel,
            best_fit_line=False,
            gridsize=40,
        )
        ax.set_aspect("auto")
        ax.set_box_aspect(1)
        plt.tight_layout()
        plt.savefig(plot_file, format="svg")
        print(f"Parity plot has been saved to {plot_file}")
        plt.close()

get_lr(optimizer)

Extracts learning rates from a PyTorch optimizer.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| optimizer | Optimizer | The PyTorch optimizer to extract learning rates from. | required |

Returns:

| Type | Description |
|---|---|
| list[float] | A list of learning rates, one for each parameter group in the optimizer. |
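get_lr only reads optimizer.param_groups, so its behavior can be shown with a stand-in object instead of a real torch.optim optimizer. The two groups and their rates below are made up for illustration; a real optimizer keeps param_groups as a list of dicts, each with its own "lr" entry.

```python
from types import SimpleNamespace


def get_lr(optimizer) -> list[float]:
    # Same one-liner as cgcnn2.util.get_lr: one entry per parameter group
    return [param_group["lr"] for param_group in optimizer.param_groups]


# Stand-in exposing only the attribute get_lr reads
fake_optimizer = SimpleNamespace(param_groups=[
    {"params": [], "lr": 0.01},   # e.g. most of the network
    {"params": [], "lr": 0.001},  # e.g. a layer trained with a smaller rate
])
print(get_lr(fake_optimizer))  # [0.01, 0.001]
```

With a single parameter group, the usual case, the result is a one-element list.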

Source code in cgcnn2/util.py
def get_lr(optimizer: torch.optim.Optimizer) -> list[float]:
    """
    Extracts learning rates from a PyTorch optimizer.

    Args:
        optimizer (torch.optim.Optimizer): The PyTorch optimizer to extract learning rates from.

    Returns:
        list[float]: A list of learning rates, one for each parameter group in the optimizer.
    """

    return [param_group["lr"] for param_group in optimizer.param_groups]

id_prop_gen(cif_dir)

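Generates an id_prop.csv file in cif_dir that lists every CIF file in that directory. Each row contains the file name up to its first "." as the ID and a placeholder property value of 0; no header is written.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| cif_dir | str | Directory containing the CIF files. | required |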
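A sketch of the file id_prop_gen writes, using the csv module instead of pandas. id_prop_gen_sketch is a hypothetical name, and the files are sorted here for a stable order, whereas glob.glob returns them in arbitrary order:

```python
import csv
import glob
import os
import tempfile


def id_prop_gen_sketch(cif_dir: str) -> str:
    """Write cif_dir/id_prop.csv with one 'id,0' row per CIF file (no header)."""
    cif_list = sorted(glob.glob(os.path.join(cif_dir, "*.cif")))
    out_path = os.path.join(cif_dir, "id_prop.csv")
    with open(out_path, "w", newline="") as f:
        writer = csv.writer(f)
        for cif in cif_list:
            # Same ID rule as the source: file name up to the first "."
            writer.writerow([os.path.basename(cif).split(".")[0], 0])
    return out_path


with tempfile.TemporaryDirectory() as tmp:
    for name in ("NaCl.cif", "Si.cif"):
        open(os.path.join(tmp, name), "w").close()  # empty placeholder CIFs
    path = id_prop_gen_sketch(tmp)
    with open(path) as f:
        content = f.read()
print(content.splitlines())  # ['NaCl,0', 'Si,0']
```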
Source code in cgcnn2/util.py
def id_prop_gen(cif_dir: str) -> None:
    """Generates a CSV file containing IDs and properties of CIF files.

    Args:
        cif_dir (str): Directory containing the CIF files.
    """

    cif_list = glob.glob(f"{cif_dir}/*.cif")

    id_prop_cif = pd.DataFrame(
        {
            "id": [os.path.basename(cif).split(".")[0] for cif in cif_list],
            "prop": [0 for _ in range(len(cif_list))],
        }
    )

    id_prop_cif.to_csv(
        f"{cif_dir}/id_prop.csv",
        index=False,
        header=False,
    )

output_id_gen()

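Generates an output identifier from the current date and time, to minute resolution, so calls within the same minute return the same identifier.

Returns:

| Name | Type | Description |
|---|---|---|
| str | str | A string in the format 'output_mmdd_HHMM' representing the current date and time. |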
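Because the identifier only has minute resolution, two runs started within the same minute get the same name. A hypothetical variant, output_id_gen_at, takes the timestamp as an argument so the format is easy to check:

```python
from datetime import datetime


def output_id_gen_at(now: datetime) -> str:
    """Same formatting as output_id_gen, with the timestamp passed in."""
    return f"output_{now.strftime('%m%d_%H%M')}"


print(output_id_gen_at(datetime(2024, 3, 7, 9, 5)))  # output_0307_0905
```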

Source code in cgcnn2/util.py
def output_id_gen() -> str:
    """
    Generates an output folder name from the current date and time (minute resolution).

    Returns:
        str: A string in the format 'output_mmdd_HHMM' representing the current date and time.
    """

    now = datetime.now()
    timestamp = now.strftime("%m%d_%H%M")
    folder_name = f"output_{timestamp}"

    return folder_name
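Since output_id_gen only has minute resolution, two runs started in the same minute get the same folder name. The sketch below shows one way to avoid overwriting an earlier run. output_id_gen_at reproduces the documented format for a fixed time, which is handy for testing. The hypothetical unique_output_dir appends a numeric suffix when the folder already exists. Neither function is part of cgcnn2.

```python
import os
import tempfile
from datetime import datetime
from typing import Optional


def output_id_gen_at(now: datetime) -> str:
    """Same 'output_mmdd_HHMM' format as output_id_gen, for a given time."""
    return f"output_{now.strftime('%m%d_%H%M')}"


def unique_output_dir(parent: str, now: Optional[datetime] = None) -> str:
    """Hypothetical helper: append _1, _2, ... if the folder already exists."""
    base = output_id_gen_at(now or datetime.now())
    candidate, counter = base, 1
    while os.path.exists(os.path.join(parent, candidate)):
        candidate = f"{base}_{counter}"
        counter += 1
    return candidate


stamp = datetime(2024, 3, 5, 14, 7)
print(output_id_gen_at(stamp))  # output_0305_1407

with tempfile.TemporaryDirectory() as tmp:
    first = unique_output_dir(tmp, stamp)
    os.makedirs(os.path.join(tmp, first))
    second = unique_output_dir(tmp, stamp)  # same minute, folder now exists

print(first, second)  # output_0305_1407 output_0305_1407_1
```

The check-then-create pattern is not safe against several processes starting at once. In that case, creating the folder with os.makedirs(..., exist_ok=False) and retrying on FileExistsError is more robust.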

unique_structures_clean(dataset_dir, delete_duplicates=False)

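Groups structurally equivalent structures in a directory of CIF files using pymatgen's StructureMatcher and returns the groups as lists of file names. The number of unique structures is the length of the returned list. If delete_duplicates is True, every file except the first in each group is deleted from the directory.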

Parameters:

Name | Type | Description | Default
dataset_dir | str | The path to the dataset directory containing CIF files. | required
delete_duplicates | bool | Whether to delete duplicate structures, keeping the first file in each group. | False

Returns:

Type | Description
list[list[str]] | A list of lists, where each sublist contains the file names of structurally equivalent structures.

Source code in cgcnn2/util.py
def unique_structures_clean(dataset_dir, delete_duplicates=False):
    """
    Checks for duplicate (structurally equivalent) structures in a directory
    of CIF files using pymatgen's StructureMatcher and returns the count
    of unique structures.

    Args:
        dataset_dir (str): The path to the dataset containing CIF files.
        delete_duplicates (bool): Whether to delete the duplicate structures, default is False.

    Returns:
        grouped: list
            A list of lists, where each sublist contains structurally equivalent
        structures.
    """
    cif_files = [f for f in os.listdir(dataset_dir) if f.endswith(".cif")]
    structures = []
    filenames = []

    for fname in cif_files:
        full_path = os.path.join(dataset_dir, fname)
        structures.append(Structure.from_file(full_path))
        filenames.append(fname)

    # Relies on group_structures returning the same Structure objects that
    # were passed in, so id() maps each grouped structure back to its file
    id_to_fname = {id(s): fn for s, fn in zip(structures, filenames)}

    matcher = StructureMatcher()
    grouped = matcher.group_structures(structures)

    grouped_fnames = [[id_to_fname[id(s)] for s in group] for group in grouped]

    if delete_duplicates:
        for file_group in grouped_fnames:
            # keep the first file, delete the rest
            for dup_fname in file_group[1:]:
                os.remove(os.path.join(dataset_dir, dup_fname))

    return grouped_fnames
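The return value is a list of file-name groups. The number of unique structures is therefore its length, and delete_duplicates=True would remove every entry after the first in each group. The dry-run sketch below assumes a grouped result shaped like the one returned above. The summarize_groups helper and the file names are hypothetical, and the sketch does not call pymatgen.

```python
def summarize_groups(grouped_fnames: list[list[str]]) -> tuple[int, list[str]]:
    """Return (number of unique structures, files that would be deleted)."""
    n_unique = len(grouped_fnames)
    # Mirrors delete_duplicates=True: keep file_group[0], drop the rest
    to_delete = [fname for group in grouped_fnames for fname in group[1:]]
    return n_unique, to_delete


# Hypothetical result: two files describe the same structure
grouped = [["NaCl_a.cif", "NaCl_b.cif"], ["Si.cif"], ["GaAs.cif"]]
n_unique, to_delete = summarize_groups(grouped)
print(n_unique)   # 3
print(to_delete)  # ['NaCl_b.cif']
```

Which file ends up first in a group depends on the order in which os.listdir returned the files, and that order is not guaranteed. If it matters which copy is kept, inspect the groups this way before deleting anything.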