diff --git a/gensen.py b/gensen.py
index 2842429..d63debd 100644
--- a/gensen.py
+++ b/gensen.py
@@ -18,11 +18,12 @@ class Encoder(nn.Module):
     def __init__(
         self, vocab_size, embedding_dim,
-        hidden_dim, num_layers, rnn_type='GRU'
+        hidden_dim, num_layers, rnn_type='GRU', cuda=False
     ):
         """Initialize params."""
         super(Encoder, self).__init__()
         self.rnn_type = rnn_type
+        self.cuda = cuda
         rnn = getattr(nn, rnn_type)
         self.src_embedding = nn.Embedding(
             num_embeddings=vocab_size,
@@ -56,12 +57,13 @@ def set_pretrained_embeddings(self, embedding_matrix):
         self.src_vocab_size = embedding_matrix.shape[0]
         self.src_emb_dim = embedding_matrix.shape[1]
 
-        try:
-            self.src_embedding.weight.data.set_(torch.from_numpy(embedding_matrix))
-        except:
+        if self.cuda:
             self.src_embedding.weight.data.set_(torch.from_numpy(embedding_matrix).cuda())
-
-        self.src_embedding.cuda()
+        else:
+            self.src_embedding.weight.data.set_(torch.from_numpy(embedding_matrix))
+
+        if self.cuda:
+            self.src_embedding.cuda()
 
     def forward(self, input, lengths, return_all=False, pool='last'):
         """Propogate input through the encoder."""
@@ -154,11 +156,17 @@ def _load_params(self):
         self.id2word = model_vocab['id2word']
         self.task_word2id = self.word2id
         self.id2word = self.id2word
-
-        encoder_model = torch.load(os.path.join(
-            self.model_folder,
-            '%s.model' % (self.filename_prefix)
-        ))
+
+        if self.cuda:
+            encoder_model = torch.load(os.path.join(
+                self.model_folder,
+                '%s.model' % (self.filename_prefix)
+            ))
+        else:
+            encoder_model = torch.load(os.path.join(
+                self.model_folder,
+                '%s.model' % (self.filename_prefix)
+            ), map_location='cpu')
 
         # Initialize encoders
         self.encoder = Encoder(
@@ -166,7 +174,8 @@ def _load_params(self):
             embedding_dim=encoder_model['src_embedding.weight'].size(1),
             hidden_dim=encoder_model['encoder.weight_hh_l0'].size(1),
             num_layers=1 if len(encoder_model) < 10 else 2,
-            rnn_type=self.rnn_type
+            rnn_type=self.rnn_type,
+            cuda=self.cuda
         )
 
         # Load pretrained sentence encoder weights
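
Usage sketch (illustrative, not part of the patch): with the new cuda flag defaulting to False, the Encoder and its pretrained embeddings can stay on the CPU instead of relying on the old bare try/except to guess the device. The Encoder signature and the set_pretrained_embeddings call below come straight from the diff; the dimensions and the random embedding matrix are made-up placeholders.

    import numpy as np

    # Hypothetical CPU-only construction of the patched Encoder.
    encoder = Encoder(
        vocab_size=30000,    # placeholder vocabulary size
        embedding_dim=300,   # placeholder embedding width
        hidden_dim=1024,
        num_layers=1,
        rnn_type='GRU',
        cuda=False,          # new flag introduced by this patch
    )

    # float32 so torch.from_numpy matches the embedding weight dtype.
    embedding_matrix = np.random.rand(30000, 300).astype('float32')
    encoder.set_pretrained_embeddings(embedding_matrix)  # CPU branch: no .cuda() call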