requirements.txt (2 changes: 1 addition & 1 deletion)
@@ -10,5 +10,5 @@ pytest # tests runner
 pytest-xdist # allows running parallel testing with pytest -n <num_workers>
 pytest-cov # allows to get coverage report
 ruff # format and lint code
-jenn >= 1.0.2, <2.0
+jenn >= 2.0.0, <3.0
 egobox >= 0.25.0, <1.0
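Note: the pinned range moves from the 1.x series to 2.x, a breaking release whose new import style drives the genn.py changes below. A minimal sketch (not part of the diff) for checking that an existing environment satisfies the new pin; the distribution name "jenn" is taken from requirements.txt:

```python
# Sanity-check the installed jenn against the new ">= 2.0.0, <3.0" pin.
from importlib.metadata import version

jenn_version = version("jenn")  # e.g. "2.0.0"
major = int(jenn_version.split(".")[0])
assert major == 2, f"jenn {jenn_version} does not satisfy >= 2.0.0, <3.0"
```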
smt/surrogate_models/genn.py (7 changes: 3 additions & 4 deletions)
@@ -9,7 +9,7 @@
 from typing import Dict, List, Tuple, Union

 import numpy as np
-from jenn.model import NeuralNet
+import jenn

 from smt.surrogate_models.surrogate_model import SurrogateModel
 from smt.utils import persistence
@@ -70,7 +70,6 @@ class GENN(SurrogateModel):
     def load_data(self, xt, yt, dyt_dxt=None):
         """Load all training data into surrogate model in one step.

-        :param model: SurrogateModel object for which to load training data
         :param xt: smt data points at which response is evaluated
         :param yt: response at xt
         :param dyt_dxt: gradient at xt
@@ -167,7 +166,7 @@ def _initialize(self):
         )
         self.options.declare(
             "is_normalize",
-            default=False,
+            default=True,
             types=bool,
             desc="normalize training by mean and variance",
         )
@@ -183,7 +182,7 @@ def _final_initialize(self):
         output = [1]  # will be overwritten during training (dummy value)
         hidden = self.options["hidden_layer_sizes"]
         layer_sizes = inputs + hidden + output
-        self.model = NeuralNet(layer_sizes)
+        self.model = jenn.NeuralNet(layer_sizes)

     def _train(self):
         X, Y, J = _smt_to_genn(self.training_points)
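Note: besides the import change, the `is_normalize` default flips to True, so GENN now standardizes training data unless told otherwise. A minimal sketch of mean/variance normalization as the option's description states it; this illustrates the idea only and is not SMT's or jenn's actual implementation:

```python
import numpy as np

def normalize(x):
    """Center and scale each column by its mean and standard deviation."""
    mu = x.mean(axis=0)
    sigma = x.std(axis=0)
    sigma = np.where(sigma == 0.0, 1.0, sigma)  # guard against constant columns
    return (x - mu) / sigma, mu, sigma
```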
smt/surrogate_models/tests/test_genn.py (59 changes: 59 additions & 0 deletions)
@@ -0,0 +1,59 @@
+import numpy as np
+import unittest
+import jenn
+import smt
+
+
+class TestGENN(unittest.TestCase):
+    def test_rosenbrock(self):
+        """Check GENN predictions on Rosenbrock function."""
+
+        # Generate synthetic training data outside of SMT (using JENN)
+        x_train, y_train, dydx_train = jenn.utilities.sample(
+            f=jenn.synthetic_data.rosenbrock.compute,
+            f_prime=jenn.synthetic_data.rosenbrock.compute_partials,
+            m_random=0,
+            m_levels=3,
+            lb=[-np.pi, -np.pi],
+            ub=[np.pi, np.pi],
+        )
+
+        # Generate synthetic test data outside of SMT (using JENN)
+        x_test, y_test, dydx_test = jenn.utilities.sample(
+            f=jenn.synthetic_data.rosenbrock.compute,
+            f_prime=jenn.synthetic_data.rosenbrock.compute_partials,
+            m_random=0,
+            m_levels=100,
+            lb=[-np.pi, -np.pi],
+            ub=[np.pi, np.pi],
+        )
+
+        # SMT and JENN data structures are transposed from each other
+        x_train = x_train.T
+        y_train = y_train.T
+        dydx_train = dydx_train.squeeze().T
+
+        x_test = x_test.T
+        y_test = y_test.T
+        dydx_test = dydx_test.squeeze().T
+
+        # Train the model using the SMT API as usual
+        genn = smt.surrogate_models.GENN()
+        genn.options["hidden_layer_sizes"] = [12, 12]
+        genn.options["alpha"] = 0.01
+        genn.options["lambd"] = 0.01
+        genn.options["gamma"] = 1
+        genn.options["num_iterations"] = 5000
+        genn.options["is_backtracking"] = True
+        genn.options["is_normalize"] = True
+        genn.options["seed"] = 123
+        genn.load_data(x_train, y_train, dydx_train)
+        genn.train()
+
+        # Predict test data
+        y_pred = genn.predict_values(x_test)
+
+        # Make sure the prediction is good
+        rsquare = jenn.metrics.rsquare(y_pred.ravel(), y_test.ravel())
+        tol = 0.99
+        self.assertGreater(rsquare, tol, msg=f"R^2 = {rsquare:.3f} is less than {tol}")
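Note: the assertion bounds the coefficient of determination on held-out data. A minimal sketch of the quantity being checked, assuming `jenn.metrics.rsquare` computes the usual R^2 (for reference only, not jenn's implementation):

```python
import numpy as np

def rsquare(y_pred, y_true):
    """R^2 = 1 - SS_res / SS_tot; 1.0 indicates a perfect fit."""
    ss_res = np.sum((y_true - y_pred) ** 2)           # residual sum of squares
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)  # total sum of squares
    return 1.0 - ss_res / ss_tot
```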
smt/tests/test_all.py (4 changes: 2 additions & 2 deletions)
@@ -64,11 +64,11 @@ def setUp(self):
         sms["MGP"] = MGP(theta0=[1e-2] * ndim)
         sms["GEKPLS"] = GEKPLS(theta0=[1e-2] * 2, n_comp=2, delta_x=1e-1)
         sms["GENN"] = GENN(
-            num_iterations=1000,
+            num_iterations=5000,
             hidden_layer_sizes=[
                 24,
             ],
-            alpha=1e-1,
+            alpha=1e-2,
             lambd=1e-2,
             is_backtracking=True,
             is_normalize=True,
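Note: the retune gives the shared test suite's GENN more iterations with a smaller `alpha`. An annotated reading of the new settings; the comments are assumptions from the option names and common gradient-enhanced neural net usage, not restated anywhere in this diff:

```python
# Hypothetical annotated equivalent of the block above.
sms["GENN"] = GENN(
    num_iterations=5000,      # was 1000: more optimizer iterations
    hidden_layer_sizes=[24],  # single hidden layer of 24 neurons
    alpha=1e-2,               # was 1e-1: assumed learning rate, reduced 10x
    lambd=1e-2,               # assumed regularization penalty
    is_backtracking=True,     # assumed backtracking line search during training
    is_normalize=True,        # standardize training data (see genn.py change)
)
```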