forked from aws-samples/process-optimization-workshop
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCSTRModel.py
More file actions
31 lines (27 loc) · 783 Bytes
/
CSTRModel.py
File metadata and controls
31 lines (27 loc) · 783 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
import torch
import torch.nn as nn
class CSTRModel(torch.nn.Module):
    """MLP model for the CSTR forward model.

    The network is a fully connected stack of six ``nn.Linear`` layers with
    widths ``in_features -> 16 -> 64 -> 256 -> 64 -> 16 -> out_features`` and
    a ``nn.ReLU`` between each pair of linear layers. No activation is applied
    after the final layer.
    """

    def __init__(self, in_features: int = 2, out_features: int = 3) -> None:
        """Build the layer stack.

        Args:
            in_features: Size of the last dimension of the input tensor.
                Defaults to 2, the width used by the original model.
            out_features: Size of the last dimension of the output tensor.
                Defaults to 3, the width used by the original model.
        """
        super().__init__()
        # NOTE(review): ``flatten`` is never applied in ``forward``. It holds
        # no parameters, so it does not affect the state dict; it is kept only
        # so that external code reading this attribute keeps working.
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, 16),
            nn.ReLU(),
            nn.Linear(16, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, out_features),
        )

    def forward(self, x):
        """Run ``x`` through the linear/ReLU stack and return the result.

        Args:
            x: Tensor whose last dimension has size ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``out_features``.
        """
        y_pred = self.linear_relu_stack(x)
        return y_pred