Skip to content

Commit 2d9cf58

Browse files
committed
Updated to latest tensorial 0.6.0
1 parent 2c81cf7 commit 2d9cf58

7 files changed

Lines changed: 54 additions & 20 deletions

File tree

configs/listeners/default.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ defaults:
22
- model_checkpoint
33
- early_stopping
44
# - model_summary
5-
- rich_progress_bar
5+
- metrics_printer
6+
# - rich_progress_bar
67
- _self_
78

89
model_checkpoint:
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
metrics_printer:
2+
_target_: tensorial.reaxkit.MetricsPrinter

configs/model/mace.yaml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ loss_fn:
1515
weights: [ 1000., 1. ]
1616
loss_fns:
1717
- _target_: tensorial.gcnn.Loss
18-
field: nodes.born_charges_predicted
19-
target_field: nodes.born_charges
18+
loss_fn: squared_error
19+
targets: nodes.born_charges
20+
predictions: nodes.born_charges_predicted
2021

2122
- _target_: tensorial.gcnn.Loss
22-
field: nodes.raman_tensors_predicted
23-
target_field: nodes.raman_tensors
23+
loss_fn: squared_error
24+
predictions: nodes.raman_tensors_predicted
25+
targets: nodes.raman_tensors
2426

2527

2628
model:

configs/model/nequip.yaml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ loss_fn:
1515
weights: [ 1000., 1. ]
1616
loss_fns:
1717
- _target_: tensorial.gcnn.Loss
18-
field: nodes.born_charges_predicted
19-
target_field: nodes.born_charges
18+
loss_fn: squared_error
19+
predictions: nodes.born_charges_predicted
20+
targets: nodes.born_charges
2021

2122
- _target_: tensorial.gcnn.Loss
22-
field: nodes.raman_tensors_predicted
23-
target_field: nodes.raman_tensors
23+
loss_fn: squared_error
24+
predictions: nodes.raman_tensors_predicted
25+
targets: nodes.raman_tensors
2426

2527
model:
2628
_target_: tensorial.nn.Sequential

configs/model/nequip_nmr.yaml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ loss_fn:
1515
weights: [ 1000., 1. ]
1616
loss_fns:
1717
- _target_: tensorial.gcnn.Loss
18-
field: nodes.born_charges_predicted
19-
target_field: nodes.born_charges
18+
loss_fn: squared_error
19+
targets: nodes.born_charges
20+
predictions: nodes.born_charges_predicted
2021

2122
- _target_: tensorial.gcnn.Loss
22-
field: nodes.raman_tensors_predicted
23-
target_field: nodes.raman_tensors
23+
loss_fn: squared_error
24+
targets: nodes.raman_tensors
25+
predictions: nodes.raman_tensors_predicted
2426

2527
model:
2628
_target_: tensorial.nn.Sequential

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ dependencies = [
3333
'e3nn-jax',
3434
"equinox",
3535
"reax>=0.2,<0.7",
36-
"tensorial>=0.5.1",
36+
"tensorial~=0.6",
3737
"pymatgen",
3838
]
3939

src/e3response/losses.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import jax
44
import jraph
5+
import optax
56
from tensorial import gcnn
67
from tensorial.gcnn import atomic
78
from tensorial.gcnn.keys import predicted
@@ -23,37 +24,61 @@ def response_loss(
2324
if energy:
2425
weights.append(1.0 if isinstance(energy, bool) else energy)
2526
loss_terms.append(
26-
gcnn.Loss(f"globals.{predicted(atomic.TOTAL_ENERGY)}", f"globals.{atomic.TOTAL_ENERGY}")
27+
gcnn.Loss(
28+
optax.squared_error,
29+
f"globals.{atomic.TOTAL_ENERGY}",
30+
f"globals.{predicted(atomic.TOTAL_ENERGY)}",
31+
)
2732
)
2833

2934
if forces:
3035
weights.append(1.0 if isinstance(forces, bool) else forces)
31-
loss_terms.append(gcnn.Loss(f"nodes.{predicted(atomic.FORCES)}", f"nodes.{atomic.FORCES}"))
36+
loss_terms.append(
37+
gcnn.Loss(
38+
optax.squared_error,
39+
f"nodes.{atomic.FORCES}",
40+
f"nodes.{predicted(atomic.FORCES)}",
41+
)
42+
)
3243

3344
if born_charges:
3445
weights.append(1.0 if isinstance(born_charges, bool) else born_charges)
3546
loss_terms.append(
36-
gcnn.Loss(f"nodes.{predicted(keys.BORN_CHARGES)}", f"nodes.{keys.BORN_CHARGES}")
47+
gcnn.Loss(
48+
optax.squared_error,
49+
f"nodes.{keys.BORN_CHARGES}",
50+
f"nodes.{predicted(keys.BORN_CHARGES)}",
51+
)
3752
)
3853

3954
if polarization_tensors:
4055
weights.append(1.0 if isinstance(polarization_tensors, bool) else polarization_tensors)
4156
loss_terms.append(
42-
gcnn.Loss(f"globals.{predicted(keys.POLARIZATION)}", f"globals.{keys.POLARIZATION}")
57+
gcnn.Loss(
58+
optax.squared_error,
59+
f"globals.{keys.POLARIZATION}",
60+
f"globals.{predicted(keys.POLARIZATION)}",
61+
)
4362
)
4463

4564
if dielectric_tensor:
4665
weights.append(1.0 if isinstance(dielectric_tensor, bool) else dielectric_tensor)
4766
loss_terms.append(
4867
gcnn.Loss(
49-
f"globals.{predicted(keys.DIELECTRIC_TENSOR)}", f"globals.{keys.DIELECTRIC_TENSOR}"
68+
optax.squared_error,
69+
f"globals.{keys.DIELECTRIC_TENSOR}",
70+
f"globals.{predicted(keys.DIELECTRIC_TENSOR)}",
5071
)
5172
)
5273

5374
if raman_tensors:
5475
weights.append(1.0 if isinstance(raman_tensors, bool) else raman_tensors)
5576
loss_terms.append(
56-
gcnn.Loss(f"nodes.{predicted(keys.RAMAN_TENSORS)}", f"nodes.{keys.RAMAN_TENSORS}")
77+
gcnn.Loss(
78+
optax.squared_error,
79+
f"nodes.{keys.RAMAN_TENSORS}",
80+
f"nodes.{predicted(keys.RAMAN_TENSORS)}",
81+
)
5782
)
5883

5984
if not loss_terms:

0 commit comments

Comments
 (0)