
Implementation of MLP #3

@tangh18


I appreciate your splendid work.
However, the MLP and the forward function in DRPO do not seem to be implemented correctly.
In model/drpo.py, the MLP is defined as:

        self.MLP = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size * self.stock_num,
                            self.hidden_size * 7, bias=self.bias),
            torch.nn.SiLU(),
            torch.nn.Dropout(self.dropout),
            torch.nn.Linear(self.hidden_size * 7, 128, bias=self.bias),
            torch.nn.SiLU(),
            torch.nn.Dropout(self.dropout),
            torch.nn.Linear(128, 64, bias=self.bias),
            torch.nn.SiLU(),
            torch.nn.Dropout(self.dropout),
            torch.nn.Linear(64, 1, bias=self.bias),
        )
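
To make the shape problem concrete, here is a standalone rebuild of the same layer stack. The sizes (`hidden_size=16`, `stock_num=30`, `batch_size=4`, `dropout=0.1`) are hypothetical values chosen only for illustration, and the random `lstm_out` tensor stands in for the flattened LSTM output that the forward pass feeds into the MLP. Whatever the sizes, the last `Linear(64, 1)` collapses every stock into a single column.

```python
import torch

# Hypothetical sizes for illustration only; the real values come from the DRPO config.
hidden_size, stock_num, dropout, bias = 16, 30, 0.1, True
batch_size = 4

# Same layer stack as the MLP quoted above, rebuilt standalone.
mlp = torch.nn.Sequential(
    torch.nn.Linear(hidden_size * stock_num, hidden_size * 7, bias=bias),
    torch.nn.SiLU(),
    torch.nn.Dropout(dropout),
    torch.nn.Linear(hidden_size * 7, 128, bias=bias),
    torch.nn.SiLU(),
    torch.nn.Dropout(dropout),
    torch.nn.Linear(128, 64, bias=bias),
    torch.nn.SiLU(),
    torch.nn.Dropout(dropout),
    torch.nn.Linear(64, 1, bias=bias),
)

# Stand-in for the LSTM output after output.reshape(bs, -1):
# all stocks' hidden states are concatenated into one row per sample.
lstm_out = torch.randn(batch_size, hidden_size * stock_num)
out = mlp(lstm_out)
print(out.shape)  # torch.Size([4, 1]): one scalar per sample, not one per stock
```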

The relevant part of the forward function in the same file is:

        # inputs = self.faltten(inputs)                                       # bs, 140, 1470
        for i in range(ts):  # each timestep
            input = inputs[:, i:i+1, :]                                         # bs, 30, 49
            input = input.reshape(-1, fn)
            
            input_temp = torch.concat([input, obs_omega.reshape(-1,1)], dim=1)
            input_temp = input_temp.reshape(bs* sn, 1, -1)
            output, (hx, cx) = self.LSTM(input_temp, (hx, cx))
            output = output.reshape(bs, -1)
            output = self.MLP(output)    # bs, 1
           
            # output = output - torch.mean(output,dim = -1,keepdim=True)
            # output = torch.sigmoid(output)
            out_time_stock.append(output)  # (batchsize,stocknum)

    
            ''' Step 2: Calculate Next State '''
            with torch.no_grad():
                this_state = state[:, i, :]                                     # (bs, 30)   
                output = output - torch.mean(output,dim = -1,keepdim=True)      # not (bs, 30), actual (bs, 1)
                # output is all zero here
                next_state = this_state + output                                # change to next state
                # check next hold > 0
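
The "all zero" annotation can be checked in isolation. In this sketch `bs` and `sn` are hypothetical sizes, and random tensors stand in for the MLP output and for `state[:, i, :]`. Because the mean over a length-1 dimension is the element itself, the demeaned tensor is exactly zero, and adding it (broadcast from `(bs, 1)` to `(bs, sn)`) leaves the state unchanged.

```python
import torch

bs, sn = 4, 30                       # hypothetical batch size and stock count
output = torch.randn(bs, 1)          # shape the MLP actually returns
this_state = torch.rand(bs, sn)      # stand-in for state[:, i, :]

# Same demean step as in the forward pass.
demeaned = output - torch.mean(output, dim=-1, keepdim=True)
print(torch.count_nonzero(demeaned).item())  # 0

next_state = this_state + demeaned   # (bs, 1) broadcasts over (bs, sn)
print(torch.equal(next_state, this_state))   # True: the state never moves
```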

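The MLP output has shape (bs, 1), so subtracting the mean along the last dimension (dim=-1, i.e. dimension 1 here) makes the output all zeros, since the mean of a single element is the element itself. Hence the state never changes. What should the correct implementation be? Thanks!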
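
One possible direction, sketched under assumptions rather than taken from the DRPO code: the comment `# (batchsize,stocknum)` on `out_time_stock.append(output)` suggests one score per stock was intended. The sketch below assumes the LSTM is `batch_first=True` (so its output is `(bs*sn, 1, hidden_size)`) and that rows are ordered batch-major, and it replaces the flattened MLP with a hypothetical per-stock `head` shared across stocks. The names `head`, `scores` and `delta`, and all sizes, are illustrative only.

```python
import torch

torch.manual_seed(0)
bs, sn, hidden_size, dropout = 4, 30, 16, 0.1  # hypothetical sizes

# Hypothetical per-stock head: maps each stock's hidden state to one score,
# with weights shared across stocks (not the original flattened MLP).
head = torch.nn.Sequential(
    torch.nn.Linear(hidden_size, 64),
    torch.nn.SiLU(),
    torch.nn.Dropout(dropout),
    torch.nn.Linear(64, 1),
)

# Stand-in for the LSTM output, assuming batch_first=True: (bs*sn, seq_len=1, hidden).
lstm_out = torch.randn(bs * sn, 1, hidden_size)
# Score every stock, then regroup into (bs, sn), assuming batch-major row order.
scores = head(lstm_out.reshape(bs * sn, hidden_size)).reshape(bs, sn)

with torch.no_grad():
    this_state = torch.rand(bs, sn)                     # stand-in for state[:, i, :]
    delta = scores - scores.mean(dim=-1, keepdim=True)  # zero-sum across stocks
    next_state = this_state + delta

print(scores.shape)  # torch.Size([4, 30])
```

With scores of shape (bs, sn), demeaning across stocks yields a nonzero, zero-sum adjustment per sample, so the state can actually change. Another option would be to keep the flattened MLP and change its last layer to output stock_num values; which one matches the intended design is for the authors to confirm.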
