 import os
 from functools import partial
 import argparse
+import logging
 
 from ..utils.compactmodel_io import get_model_classifier_name
 from ..utils.classification_exceptions import AlgorithmNotExistException, WordEmbeddingModelNotExistException
 from ..utils import load_word2vec_model, load_fasttext_model, load_poincare_model
 from ..smartload import smartload_compact_model
 from ..classifiers import TopicVectorCosineDistanceClassifier
 
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
-allowed_classifiers = ['ldatopic', 'lsitopic', 'rptopic', 'kerasautoencoder', 'topic_sklearn',
-                       'nnlibvec', 'sumvec', 'maxent']
+allowed_classifiers = [
+    'ldatopic', 'lsitopic', 'rptopic', 'kerasautoencoder',
+    'topic_sklearn', 'nnlibvec', 'sumvec', 'maxent'
+]
 needembedded_classifiers = ['nnlibvec', 'sumvec']
 topicmodels = ['ldatopic', 'lsitopic', 'rptopic', 'kerasautoencoder']
 
 load_word2vec_nonbinary_model = partial(load_word2vec_model, binary=False)
 load_poincare_binary_model = partial(load_poincare_model, binary=True)
 
-typedict = {'word2vec': load_word2vec_model,
-            'word2vec_nonbinary': load_word2vec_nonbinary_model,
-            'fasttext': load_fasttext_model,
-            'poincare': load_poincare_model,
-            'poincare_binary': load_poincare_binary_model}
+typedict = {
+    'word2vec': load_word2vec_model,
+    'word2vec_nonbinary': load_word2vec_nonbinary_model,
+    'fasttext': load_fasttext_model,
+    'poincare': load_poincare_model,
+    'poincare_binary': load_poincare_binary_model
+}
 
 
 def get_argparser():
-    argparser = argparse.ArgumentParser(description='Perform prediction on short text with a given trained model.')
-    argparser.add_argument('model_filepath', help='Path of the trained (compact) model.')
-    argparser.add_argument('--wv', default='', help='Path of the pre-trained Word2Vec model. (None if not needed)')
-    argparser.add_argument('--vecsize', default=300, type=int, help='Vector dimensions. (Default: 300)')
-    argparser.add_argument('--topn', type=int, default=10, help='Number of top-scored results displayed. (Default: 10)')
-    argparser.add_argument('--inputtext', default=None, help='single input text for classification. Run console if set to None. (Default: None)')
-    argparser.add_argument('--type', default='word2vec',
-                           help='Type of word-embedding model (default: "word2vec"; other options: "fasttext", "poincare", "word2vec_nonbinary", "poincare_binary")')
-    return argparser
+    parser = argparse.ArgumentParser(
+        description='Perform prediction on short text with a given trained model.'
+    )
+    parser.add_argument('model_filepath', help='Path of the trained (compact) model.')
+    parser.add_argument('--wv', default='', help='Path of the pre-trained Word2Vec model.')
+    parser.add_argument('--vecsize', default=300, type=int, help='Vector dimensions. (Default: 300)')
+    parser.add_argument('--topn', type=int, default=10, help='Number of top results to show.')
+    parser.add_argument('--inputtext', default=None, help='Single input text for classification. If omitted, will enter console mode.')
+    parser.add_argument('--type', default='word2vec', choices=typedict.keys(),
+                        help='Type of word-embedding model (default: word2vec)')
+    return parser
 
 # main block
 def main():
@@ -43,51 +52,52 @@ def main():
 
     # check if the model file is given
     if not os.path.exists(args.model_filepath):
-        raise IOError('Model file ' + args.model_filepath + ' not found!')
-
+        raise IOError(f'Model file "{args.model_filepath}" not found!')
+
     # get the name of the classifier
-    print('Retrieving classifier name...')
+    logger.info('Retrieving classifier name...')
     classifier_name = get_model_classifier_name(args.model_filepath)
-    if not (classifier_name in allowed_classifiers):
+
+    if classifier_name not in allowed_classifiers:
         raise AlgorithmNotExistException(classifier_name)
 
     # load the Word2Vec model if necessary
     wvmodel = None
     if classifier_name in needembedded_classifiers:
-        # check if thw word embedding model is available
+        # check if the word embedding model is available
         if not os.path.exists(args.wv):
             raise WordEmbeddingModelNotExistException(args.wv)
         # if there, load it
-        print('Loading word-embedding model: ' + args.wv)
+        logger.info(f'Loading word-embedding model from {args.wv}...')
         wvmodel = typedict[args.type](args.wv)
 
     # load the classifier
-    print('Initializing the classifier...')
-    classifier = None
+    logger.info('Initializing the classifier...')
     if classifier_name in topicmodels:
         topicmodel = smartload_compact_model(args.model_filepath, wvmodel, vecsize=args.vecsize)
         classifier = TopicVectorCosineDistanceClassifier(topicmodel)
     else:
         classifier = smartload_compact_model(args.model_filepath, wvmodel, vecsize=args.vecsize)
 
-
-    if args.inputtext != None:
-        if len(args.inputtext) > 0:
-            scoredict = classifier.score(args.inputtext)
-            for label, score in sorted(scoredict.items(), key=lambda s: s[1], reverse=True)[:args.topn]:
-                print(label, ' : ', score)
-        else:
-            print('No input text available!')
+    # predict single input or run in console mode
+    if args.inputtext is not None:
+        if len(args.inputtext.strip()) == 0:
+            print('No input text provided.')
+            return
+        scoredict = classifier.score(args.inputtext)
+        for label, score in sorted(scoredict.items(), key=lambda x: x[1], reverse=True)[:args.topn]:
+            print(f'{label} : {score:.4f}')
     else:
-        # Console
-        run = True
-        while run:
-            shorttext = input('text> ')
-            if len(shorttext) > 0:
-                scoredict = classifier.score(shorttext)
-                for label, score in sorted(scoredict.items(), key=lambda s: s[1], reverse=True)[:args.topn]:
-                    print(label + ' : ' + '%.4f' % (score))
-            else:
-                run = False
-
+        # Console
+        print('Enter text to classify (empty input to quit):')
+        while True:
+            shorttext = input('text> ').strip()
+            if not shorttext:
+                break
+            scoredict = classifier.score(shorttext)
+            for label, score in sorted(scoredict.items(), key=lambda x: x[1], reverse=True)[:args.topn]:
+                print(f'{label} : {score:.4f}')
     print('Done.')
+
+if __name__ == "__main__":
+    main()
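
For reference, the two patterns this change leans on, pre-binding loader keyword arguments with functools.partial (as in typedict) and ranking a classifier's score dictionary for the --topn display, can be tried in isolation. The sketch below uses a stand-in loader and a hard-coded score dictionary; fake_load, the labels, and the scores are hypothetical and only illustrate the pattern, not the shorttext API.

from functools import partial

# hypothetical loader standing in for load_word2vec_model(path, binary=...)
def fake_load(path, binary=True):
    return {'path': path, 'binary': binary}

# same dispatch pattern as typedict above: keyword arguments are pre-bound
loaders = {
    'word2vec': fake_load,
    'word2vec_nonbinary': partial(fake_load, binary=False),
}
print(loaders['word2vec_nonbinary']('/tmp/model.bin'))  # binary=False already bound

# same ranking pattern as the --topn display: sort label scores descending
scoredict = {'sports': 0.71, 'politics': 0.12, 'technology': 0.54}
topn = 2
for label, score in sorted(scoredict.items(), key=lambda x: x[1], reverse=True)[:topn]:
    print(f'{label} : {score:.4f}')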