Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions dynet/lstm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,13 +293,13 @@ void CoupledLSTMBuilder::disable_dropout() {
enum { _X2I, _H2I, _BI, _X2F, _H2F, _BF, _X2O, _H2O, _BO, _X2G, _H2G, _BG };
enum { LN_GH, LN_BH, LN_GX, LN_BX, LN_GC, LN_BC};

VanillaLSTMBuilder::VanillaLSTMBuilder() : has_initial_state(false), layers(0), input_dim(0), hid(0), dropout_rate_h(0), ln_lstm(false), forget_bias(1.f), dropout_masks_valid(false) { }
VanillaLSTMBuilder::VanillaLSTMBuilder() : has_initial_state(false), layers(0), input_dim(0), hid(0), dropout_rate_h(0), ln_lstm(false), forget_bias(1.f), dropout_masks_valid(false), dropconnect_masks_valid(false), dropconnect_rate(0) { }

VanillaLSTMBuilder::VanillaLSTMBuilder(unsigned layers,
unsigned input_dim,
unsigned hidden_dim,
ParameterCollection& model,
bool ln_lstm, float forget_bias) : layers(layers), input_dim(input_dim), hid(hidden_dim), ln_lstm(ln_lstm), forget_bias(forget_bias), dropout_masks_valid(false) {
bool ln_lstm, float forget_bias) : layers(layers), input_dim(input_dim), hid(hidden_dim), ln_lstm(ln_lstm), forget_bias(forget_bias), dropout_masks_valid(false), dropconnect_masks_valid(false) {
unsigned layer_input_dim = input_dim;
local_model = model.add_subcollection("vanilla-lstm-builder");
for (unsigned i = 0; i < layers; ++i) {
Expand Down Expand Up @@ -327,6 +327,7 @@ VanillaLSTMBuilder::VanillaLSTMBuilder(unsigned layers,
} // layers
dropout_rate = 0.f;
dropout_rate_h = 0.f;
dropconnect_rate = 0.f;
}

void VanillaLSTMBuilder::new_graph_impl(ComputationGraph& cg, bool update) {
Expand Down Expand Up @@ -370,6 +371,7 @@ void VanillaLSTMBuilder::start_new_sequence_impl(const vector<Expression>& hinit
}

dropout_masks_valid = false;
dropconnect_masks_valid = false;
}

void VanillaLSTMBuilder::set_dropout_masks(unsigned batch_size) {
Expand All @@ -392,6 +394,19 @@ void VanillaLSTMBuilder::set_dropout_masks(unsigned batch_size) {
dropout_masks_valid = true;
}

// Sample a fresh DropConnect mask [Wan et al., 2013] for the recurrent
// (hidden-to-hidden) weight matrix of every layer. The masks are consumed
// in add_input_impl, which cmults them elementwise into vars[_H2I].
// Called lazily once per sequence (guarded by dropconnect_masks_valid).
void VanillaLSTMBuilder::set_dropconnect_masks() {
  dropconnect_masks.clear();
  if (dropconnect_rate > 0.f) {  // hoisted out of the loop: condition and rate are loop-invariant
    const float retention_rate = 1.f - dropconnect_rate;
    dropconnect_masks.reserve(layers);
    for (unsigned i = 0; i < layers; ++i) {
      // One independent Bernoulli(retention_rate) mask per layer, shaped
      // like that layer's hidden-to-hidden parameter matrix.
      // NOTE(review): the mask is not rescaled by 1/retention_rate here,
      // unlike typical inverted dropout — confirm whether scaling is
      // handled elsewhere or intentionally omitted.
      dropconnect_masks.push_back(
          random_bernoulli(*_cg, params[i][_H2I].dim(), retention_rate));
    }
  }
  dropconnect_masks_valid = true;
}

ParameterCollection & VanillaLSTMBuilder::get_parameter_collection() {
return local_model;
}
Expand Down Expand Up @@ -440,6 +455,8 @@ Expression VanillaLSTMBuilder::add_input_impl(int prev, const Expression& x) {
vector<Expression>& ct = c.back();
Expression in = x;
if ((dropout_rate > 0.f || dropout_rate_h > 0.f) && !dropout_masks_valid) set_dropout_masks(x.dim().bd);
if (dropconnect_rate > 0.f && !dropconnect_masks_valid) set_dropconnect_masks();

for (unsigned i = 0; i < layers; ++i) {
const vector<Expression>& vars = param_vars[i];

Expand All @@ -462,6 +479,14 @@ Expression VanillaLSTMBuilder::add_input_impl(int prev, const Expression& x) {
}
if (has_prev_state && dropout_rate_h > 0.f)
i_h_tm1 = cmult(i_h_tm1, masks[i][1]);

Expression h2h_weights;

h2h_weights = vars[_H2I];
if (dropconnect_rate > 0.f) {
h2h_weights = cmult(h2h_weights, dropconnect_masks[i]);
}

// input
Expression tmp;
Expression i_ait;
Expand All @@ -476,7 +501,7 @@ Expression VanillaLSTMBuilder::add_input_impl(int prev, const Expression& x) {
tmp = vars[_BI] + layer_norm(vars[_X2I] * in, ln_vars[LN_GX], ln_vars[LN_BX]);
}else{
if (has_prev_state)
tmp = affine_transform({vars[_BI], vars[_X2I], in, vars[_H2I], i_h_tm1});
tmp = affine_transform({vars[_BI], vars[_X2I], in, h2h_weights, i_h_tm1});
else
tmp = affine_transform({vars[_BI], vars[_X2I], in});
}
Expand Down Expand Up @@ -536,6 +561,16 @@ void VanillaLSTMBuilder::disable_dropout() {
dropout_rate_h = 0.f;
}

// Set the DropConnect rate applied to the recurrent (hidden-to-hidden)
// weight matrices [Wan et al., 2013].
// \param d probability of dropping each recurrent weight; must lie in [0, 1].
//          A rate of 0 disables DropConnect.
void VanillaLSTMBuilder::set_dropconnect(float d) {
  // Fixed message: previously said "dropweight", inconsistent with the
  // dropconnect naming used throughout this feature.
  DYNET_ARG_CHECK(d >= 0.f && d <= 1.f,
                  "dropconnect rate must be a probability (>=0 and <=1)");
  dropconnect_rate = d;
}

// Disable DropConnect by resetting the rate to 0; with a zero rate the
// masking branch in add_input_impl is skipped entirely.
void VanillaLSTMBuilder::disable_dropconnect() {
dropconnect_rate = 0.f;
}


CompactVanillaLSTMBuilder::CompactVanillaLSTMBuilder() : has_initial_state(false), layers(0), input_dim(0), hid(0), dropout_rate_h(0), weightnoise_std(0), dropout_masks_valid(false) { }

Expand Down
23 changes: 23 additions & 0 deletions dynet/lstm.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,23 @@ struct VanillaLSTMBuilder : public RNNBuilder {
* \param batch_size Batch size
*/
void set_dropout_masks(unsigned batch_size = 1);

/**
* \brief Set DropConnect rate
* \details Apply DropConnect [Wan et al., 2013] to the hidden-to-hidden weight matrix. Weights will be masked by a random Bernoulli matrix.
* Each layer of the network gets a unique mask.
* \param d Weight drop rate
*/
void set_dropconnect(float d);
/**
* \brief Set DropConnect rate to 0
*/
void disable_dropconnect();
/**
* \brief Set the masks for the recurrent weight matrices at the beginning of a new sequence
*/
void set_dropconnect_masks();

/**
* \brief Get parameters in VanillaLSTMBuilder
* \return list of points to ParameterStorage objects
Expand Down Expand Up @@ -294,6 +311,9 @@ struct VanillaLSTMBuilder : public RNNBuilder {
// first index is layer, then ...
std::vector<std::vector<Expression>> masks;

// one mask on recurrent weights per layer
std::vector<Expression> dropconnect_masks;

// first index is time, second is layer
std::vector<std::vector<Expression>> h, c;

Expand All @@ -309,6 +329,9 @@ struct VanillaLSTMBuilder : public RNNBuilder {
float forget_bias;
bool dropout_masks_valid;

float dropconnect_rate;
bool dropconnect_masks_valid;

private:
ComputationGraph* _cg; // Pointer to current cg

Expand Down
4 changes: 4 additions & 0 deletions python/_dynet.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,13 @@ cdef extern from "dynet/lstm.h" namespace "dynet":
CVanillaLSTMBuilder(unsigned layers, unsigned input_dim, unsigned hidden_dim, CModel &model, bool ln_lstm, float forget_bias)
void set_dropout(float d, float d_r)
void set_dropout_masks(unsigned batch_size)
void set_dropconnect(float d)
void disable_dropconnect()
void set_dropconnect_masks()

vector[vector[CParameters]] params
vector[vector[CExpression]] param_vars
vector[CExpression] dropconnect_masks

cdef cppclass CCoupledLSTMBuilder "dynet::CoupledLSTMBuilder" (CRNNBuilder):
CCoupledLSTMBuilder()
Expand Down
10 changes: 9 additions & 1 deletion python/_dynet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5288,7 +5288,6 @@ cdef class VanillaLSTMBuilder(_RNNBuilder): # {{{
exprs.append(layer_exprs)
return exprs


cpdef void set_dropouts(self, float d, float d_r):
"""Set the dropout rates

Expand Down Expand Up @@ -5329,6 +5328,15 @@ cdef class VanillaLSTMBuilder(_RNNBuilder): # {{{
"""
self.thisvanillaptr.set_dropout_masks(batch_size)

cpdef void set_dropconnect(self, float d):
    """Set the DropConnect rate

    Applies DropConnect [Wan et al., 2013] to the hidden-to-hidden weight
    matrices of the underlying C++ VanillaLSTMBuilder; each layer receives
    its own mask.

    Args:
        d (number): DropConnect rate (probability of dropping a recurrent
            weight, between 0 and 1)
    """
    self.thisvanillaptr.set_dropconnect(d)

cpdef void disable_dropconnect(self):
    """Set the DropConnect rate to 0, disabling DropConnect"""
    self.thisvanillaptr.disable_dropconnect()

cpdef set_dropconnect_masks(self):
    """Set the DropConnect masks for the recurrent weight matrices

    Samples one mask per layer; intended to be called at the beginning of
    a new sequence (the C++ side also triggers this lazily on add_input
    when the masks are stale).
    """
    self.thisvanillaptr.set_dropconnect_masks()

def whoami(self): return "VanillaLSTMBuilder"
# VanillaLSTMBuilder }}}

Expand Down