22
33import jax
44import jraph
5+ import optax
56from tensorial import gcnn
67from tensorial .gcnn import atomic
78from tensorial .gcnn .keys import predicted
@@ -23,37 +24,61 @@ def response_loss(
2324 if energy :
2425 weights .append (1.0 if isinstance (energy , bool ) else energy )
2526 loss_terms .append (
26- gcnn .Loss (f"globals.{ predicted (atomic .TOTAL_ENERGY )} " , f"globals.{ atomic .TOTAL_ENERGY } " )
27+ gcnn .Loss (
28+ optax .squared_error ,
29+ f"globals.{ atomic .TOTAL_ENERGY } " ,
30+ f"globals.{ predicted (atomic .TOTAL_ENERGY )} " ,
31+ )
2732 )
2833
2934 if forces :
3035 weights .append (1.0 if isinstance (forces , bool ) else forces )
31- loss_terms .append (gcnn .Loss (f"nodes.{ predicted (atomic .FORCES )} " , f"nodes.{ atomic .FORCES } " ))
36+ loss_terms .append (
37+ gcnn .Loss (
38+ optax .squared_error ,
39+ f"nodes.{ atomic .FORCES } " ,
40+ f"nodes.{ predicted (atomic .FORCES )} " ,
41+ )
42+ )
3243
3344 if born_charges :
3445 weights .append (1.0 if isinstance (born_charges , bool ) else born_charges )
3546 loss_terms .append (
36- gcnn .Loss (f"nodes.{ predicted (keys .BORN_CHARGES )} " , f"nodes.{ keys .BORN_CHARGES } " )
47+ gcnn .Loss (
48+ optax .squared_error ,
49+ f"nodes.{ keys .BORN_CHARGES } " ,
50+ f"nodes.{ predicted (keys .BORN_CHARGES )} " ,
51+ )
3752 )
3853
3954 if polarization_tensors :
4055 weights .append (1.0 if isinstance (polarization_tensors , bool ) else polarization_tensors )
4156 loss_terms .append (
42- gcnn .Loss (f"globals.{ predicted (keys .POLARIZATION )} " , f"globals.{ keys .POLARIZATION } " )
57+ gcnn .Loss (
58+ optax .squared_error ,
59+ f"globals.{ keys .POLARIZATION } " ,
60+ f"globals.{ predicted (keys .POLARIZATION )} " ,
61+ )
4362 )
4463
4564 if dielectric_tensor :
4665 weights .append (1.0 if isinstance (dielectric_tensor , bool ) else dielectric_tensor )
4766 loss_terms .append (
4867 gcnn .Loss (
49- f"globals.{ predicted (keys .DIELECTRIC_TENSOR )} " , f"globals.{ keys .DIELECTRIC_TENSOR } "
68+ optax .squared_error ,
69+ f"globals.{ keys .DIELECTRIC_TENSOR } " ,
70+ f"globals.{ predicted (keys .DIELECTRIC_TENSOR )} " ,
5071 )
5172 )
5273
5374 if raman_tensors :
5475 weights .append (1.0 if isinstance (raman_tensors , bool ) else raman_tensors )
5576 loss_terms .append (
56- gcnn .Loss (f"nodes.{ predicted (keys .RAMAN_TENSORS )} " , f"nodes.{ keys .RAMAN_TENSORS } " )
77+ gcnn .Loss (
78+ optax .squared_error ,
79+ f"nodes.{ keys .RAMAN_TENSORS } " ,
80+ f"nodes.{ predicted (keys .RAMAN_TENSORS )} " ,
81+ )
5782 )
5883
5984 if not loss_terms :
0 commit comments