E(3) Equivariant Graph Neural Network Checklist
Here is my checklist for training E(3) Equivariant GNNs. I wanted to put this together to condense months of lessons I've learned.
Table of Contents
- General GNN Checklist
- Equivariant GNN Checklist
- Material Science GNN Checklist
- Diffusion Model Checklist
- Bonus Checklist
General GNN Checklist
- 1.1 When performing message passing, try encoding the distance between the sender and receiver nodes into the edge message.
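A minimal sketch of what this looks like (plain NumPy; the function and argument names are mine, not from any particular library):

```python
import numpy as np

def edge_messages(pos, senders, receivers, node_feats):
    """Build edge messages that include the sender-receiver distance.

    pos:                (N, 3) atom positions
    senders, receivers: (E,) int arrays defining the edges
    node_feats:         (N, F) node features
    Illustrative sketch only -- a real model would usually expand the
    distance in a radial basis before concatenating.
    """
    rel = pos[receivers] - pos[senders]                 # (E, 3) relative vectors
    dist = np.linalg.norm(rel, axis=-1, keepdims=True)  # (E, 1) edge lengths

    # Concatenate sender features with the distance to form the message
    return np.concatenate([node_feats[senders], dist], axis=-1)
```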
- 1.2 Get creative with the loss! It doesn't have to come from the node states alone.
- You can use the edge messages as well.
- Often this matters when the distance between the nodes is important to the final prediction
e.g. In diffusion models (like MatterGen), the edge distance loss teaches the model how close to place adjacent nodes
- 1.3 Know when to predict scaled versions of your targets
- e.g. Predicting x per number of atoms may be better than just predicting x.
- 1.4
When creating the neighbor graph, don't naively build it by computing the distance between all pairs of atoms, which takes O(N²) operations.
Instead, use a k-d tree to find the nearest neighbors of each atom. This reduces the time complexity to O(N log N).
Here is an example of using a k-d tree to construct the neighbor graph (Orbital Materials' Orb model)
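I can't inline the Orb code here, but the core idea with `scipy.spatial.cKDTree` looks roughly like this (my own sketch; it ignores periodic boundary conditions, though `cKDTree` does have a `boxsize` argument for orthorhombic cells):

```python
import numpy as np
from scipy.spatial import cKDTree

def build_neighbor_graph(pos, cutoff, max_neighbors=12):
    """Neighbor list via a k-d tree instead of the O(N^2) pairwise loop.

    pos: (N, 3) atom positions. Returns (senders, receivers) edge arrays.
    Sketch only -- no periodic images, not Orb's actual implementation.
    """
    tree = cKDTree(pos)
    # k + 1 because each atom's nearest neighbor is itself (distance 0)
    dists, idx = tree.query(pos, k=max_neighbors + 1,
                            distance_upper_bound=cutoff)
    senders, receivers = [], []
    for i in range(len(pos)):
        for d, j in zip(dists[i], idx[i]):
            # Skip self-loops and padded "no neighbor" slots
            # (SciPy marks missing neighbors with index == N, dist == inf)
            if j == i or j == len(pos):
                continue
            senders.append(i)
            receivers.append(int(j))
    return np.array(senders), np.array(receivers)
```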
Equivariant GNN Checklist
- 2.1 Think equivariantly. Use equivariant features.
- E.g. In diffusion models for material science, treat the lattice vectors as three 3D features, not nine scalar values to learn on
- Another example: if you pass the symmetric matrix of the lattice vectors as input features, is that equivariant?
- Three 3D vectors could represent a symmetric matrix, but a symmetric matrix only has 6 degrees of freedom. So how should you pass it in? If at all?
- 2.2 Only construct losses that are equivariant.
- 2.3 Test your models for equivariance
- You can perform this test with an untrained model.
- Your goal is to test that F(R(x)) = R(F(x)), where F is your model and R is a random rotation
- Perform this check five times with different random rotations, and you can be confident that your model is rotationally equivariant.
Here is an example equivariance test on the Fairchem Repo
- Here is another test you should do. This test is more end-to-end:
- Overfit your model on a single example.
- Then, rotate the example and pass it back into your model.
- If your model is equivariant, the predicted value should be correct despite the rotation.
- Do this test five times with different rotations to ensure it's equivariant.
- If your model is not equivariant, you must debug it layer by layer. Comment out each layer and test it individually.
- It may not even be your model! Your features/loss may not be equivariant.
This is why using an equivariant framework (like e3nn) is not enough on its own. Hence, why tests are important.
- In general, try to have equivariance tests for every type of output of your model; testing the scalar output (e.g. total energy) for invariance isn't enough.
- Catching equivariance issues is very important because it affects the performance of your model.
- Writing tests helps you catch them as your model architecture evolves
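Here's roughly what such a test harness can look like (a sketch assuming a model that maps (N, 3) positions to (N, 3) vector outputs like forces; the interface is hypothetical):

```python
import numpy as np
from scipy.spatial.transform import Rotation

def check_equivariance(model, pos, n_trials=5, atol=1e-6):
    """Check F(R(x)) == R(F(x)) for random rotations R.

    `model` maps (N, 3) positions to (N, 3) vector outputs.
    Works on an untrained model; illustrative interface only.
    """
    for _ in range(n_trials):
        R = Rotation.random().as_matrix()      # random 3x3 rotation matrix
        out_of_rotated = model(pos @ R.T)      # F(R(x)): rotate input first
        rotated_output = model(pos) @ R.T      # R(F(x)): rotate output after
        if not np.allclose(out_of_rotated, rotated_output, atol=atol):
            return False
    return True
```

The same harness extends to higher-order outputs (e.g. stress tensors) by rotating each output according to how it transforms.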
- 2.4 Make a nice API to feed in input features.
- Most often, we only want to pass in 1D or 3D features. So invest in a simple interface where the model accepts concatenated tensors of 1D or 3D feature inputs.
- This reduces bugs because you aren't fiddling with irrep definitions (e.g. 0x1e + 5x1o)
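One possible shape for such an interface (hypothetical; the idea is that callers pass plain arrays and the model derives its irrep layout internally):

```python
import numpy as np

def pack_inputs(scalar_feats, vector_feats):
    """Pack per-node features without exposing irrep strings to callers.

    scalar_feats: list of (N,) or (N, k) arrays -> concatenated to (N, S)
    vector_feats: list of (N, 3) arrays          -> stacked to (N, V, 3)
    The model can then build its own irrep layout (e.g. for e3nn) from
    S and V, so users never write strings by hand. Sketch only.
    """
    scalars = np.concatenate(
        [s.reshape(len(s), -1) for s in scalar_feats], axis=1)
    vectors = np.stack(vector_feats, axis=1)
    return scalars, vectors
```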
Material Science GNN Checklist
- 3.1 The number of max neighbors you allow biases your model.
- For example, in BCC crystals, the central atom can have 14 neighbors. So don't be tempted to decrease the number of neighbors too much
- If your max_neighbors cutoff is 8, you're biasing the model to do WORSE on FCC crystals (which can have 12 nearest neighbors)
- Note: Account for self-loops (edges that point back to the same node). They can hog a neighbor slot (and increase computation)
- 3.2 Use float64 (doubles). Molecular dynamics requires that level of precision
This is advice from Albert Musaelian (an author of Allegro)!
- 3.3 Make sure your data is at 0 Kelvin if your model expects it
- Materials can settle into different relaxed configurations if they are not at 0 Kelvin.
If you have non-zero-Kelvin datasets, you need to denoise them first (see Generalizing Denoising to Non-Equilibrium Structures Improves Equivariant Force Fields)
- 3.4 Visualize your dataset
- Please, please, please do this! I've seen so many teams lose days because they skipped it.
- Visualization isn't considered fun work. But you still have to do it eventually!
- Especially because it helps you understand the samples your model struggles with. And building the visualization infra isn't that much work.
- The best place to visualize the data is right before it enters the model.
- Visualize the forces in your training data! (Make sure they aren't too large!)
- Make sure the atoms are centered within their cells. If they're not centered, they can end up in the corners of the giant lattice (if you turn off periodic boundary conditions)
- 3.5 Check for exact duplicates in the data
- You'd be surprised by the number of EXACT duplicates across the training AND validation datasets!
- In general, material science datasets are pretty messy
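A quick way to catch exact duplicates is to hash each structure and intersect the train/val key sets (a sketch; rounding guards against float noise, but this doesn't canonicalize orientation or atom order):

```python
import hashlib
import numpy as np

def structure_key(positions, atomic_numbers, lattice, decimals=6):
    """Hash a structure so exact duplicates collide.

    Only catches structures stored with the same orientation and atom
    ordering -- a sketch, not a full canonicalization (e.g. Niggli
    reduction). Argument names are illustrative.
    """
    h = hashlib.sha256()
    h.update(np.round(np.asarray(positions, float), decimals).tobytes())
    h.update(np.asarray(atomic_numbers, np.int64).tobytes())
    h.update(np.round(np.asarray(lattice, float), decimals).tobytes())
    return h.hexdigest()
```

Build one set of keys per split, then check that `train_keys & val_keys` is empty.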
- 3.6 Avoid training on near-duplicates
- For models dealing with molecules, training on 5 different conformer configurations of the same molecule won't yield as much signal as training on different molecules
Shoutout to Corin Wagen for this!
Diffusion Model Checklist
- 4.1 For each training sample, after you noise the graph, recompute the graph edges, since atoms may enter or leave the cutoff radius (which determines the neighbors)
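The key point is that the neighbor list must be built from the noised coordinates, not the clean ones. A sketch (names are mine; `build_edges` stands in for whatever neighbor-list builder you use):

```python
import numpy as np

def noise_and_rebuild(pos, sigma, build_edges, rng):
    """Add Gaussian noise to positions, then rebuild the graph.

    Rebuilding from the NOISED coordinates is the point: the clean
    graph's edges are stale once the atoms have moved. `build_edges`
    is any neighbor-list function, e.g. a k-d-tree-based one.
    """
    noised = pos + sigma * rng.standard_normal(pos.shape)
    edges = build_edges(noised)   # NOT build_edges(pos)
    return noised, edges
```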
- 4.2 Periodic boundaries affect the loss. The error of pred_frac_x isn't abs(target_frac_x - pred_frac_x)
- If your model predicts 0.1 but the target frac_coord is 0.9, your model's prediction is NOT off by 0.8. Across the periodic boundary, it's only off by 0.2
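In fractional coordinates the error should be the wrapped (torus) distance, e.g.:

```python
import numpy as np

def periodic_frac_error(pred, target):
    """Shortest distance between fractional coordinates on the unit torus.

    abs(0.9 - 0.1) = 0.8, but across the periodic boundary the atoms
    are only 0.2 apart, so the loss should use the wrapped distance.
    Works elementwise on scalars or arrays.
    """
    diff = np.abs(pred - target) % 1.0
    return np.minimum(diff, 1.0 - diff)
```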
- 4.3 Don't forget to input these global features
- The current lattice as a feature (yes, even though it's implicit, this is very important)
- The num_atoms
- This seems important since it scales the lengths of the lattice (though I need to test this more)
- 4.4 Print out individual losses for each target
- When training atomic diffusion models, we tend to have multiple losses: lattice loss, atomic position loss, atomic type loss
- Seeing each loss individually helps you debug which one is the worst. This gives you more ideas on how to improve your model
- This is a general ML technique. A few years ago, I could have done better in a Kaggle competition if I had focused more on specific target losses.
- It also helps you determine how much to scale each loss
- (e.g. multiply frac_x_loss by 10 so it's generally on the same "scale" as the atomic_loss)
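A tiny sketch of logging per-target losses before combining them (names like `frac_x` are illustrative):

```python
def combine_losses(losses, weights):
    """Log each target's loss separately, then form the weighted total.

    losses, weights: dicts keyed by target name. In a real training
    loop you'd send each value to your logger (wandb, tensorboard)
    instead of printing.
    """
    for name, value in losses.items():
        print(f"{name}: {value:.4f}")
    return sum(weights[name] * value for name, value in losses.items())
```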
Bonus Checklist
These are tips I've seen that aren't often followed:
- 5.1 Try to follow the tips in Karpathy's A Recipe for Training Neural Networks
- 5.2 Come up with a list of hypotheses first, then test the ones most likely to improve your model
- Trying your ideas in the order you thought of them may not be efficient
This tip is from Andrew Ng!
I hope you find this checklist useful!
- Curtis