diff --git a/linear.cpp b/linear.cpp index 6843833..39ba61b 100644 --- a/linear.cpp +++ b/linear.cpp @@ -2620,6 +2620,8 @@ int save_model(const char *model_file_name, const struct model *model_) fprintf(fp, "bias %.16g\n", model_->bias); + fprintf(fp, "normalization %d\n", model_->normal); + fprintf(fp, "w\n"); for(i=0; iparam; @@ -2686,6 +2689,11 @@ struct model *load_model(const char *model_file_name) fscanf(fp,"%d",&nr_class); model_->nr_class=nr_class; } + else if(strcmp(cmd,"normalization")==0) + { + fscanf(fp,"%d",&normalfac); + model_->normal=normalfac; + } else if(strcmp(cmd,"nr_feature")==0) { fscanf(fp,"%d",&nr_feature); diff --git a/linear.h b/linear.h index 22a3567..d4334c3 100644 --- a/linear.h +++ b/linear.h @@ -39,6 +39,7 @@ struct model struct parameter param; int nr_class; /* number of classes */ int nr_feature; + int normal; double *w; int *label; /* label of each class */ double bias; diff --git a/predict.c b/predict.c index c5b3f1d..878f2ea 100644 --- a/predict.c +++ b/predict.c @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include "linear.h" @@ -133,6 +134,17 @@ void do_predict(FILE *input, FILE *output) } x[i].index = -1; + if(model_->normal){ + double length = 0; + for(int kk = 0; x[kk].index != -1; kk++) + length += x[kk].value * x[kk].value; + + length = sqrt(length); + + for(int kk = 0; x[kk].index != -1; kk++) + x[kk].value /= length; + } + if(flag_predict_probability) { int j; diff --git a/train.c b/train.c index d388175..e7f0ecc 100644 --- a/train.c +++ b/train.c @@ -93,6 +93,24 @@ int flag_cross_validation; int nr_fold; double bias; +int normal = 0; + +void normalize(problem *pb){ + feature_node **p = pb->x; + + for(int i = 0; i < pb->l; i++){ + double length = 0; + + for(int j = 0; p[i][j].index != -1; j++) + length += (p[i][j].value) * (p[i][j].value); + + length = sqrt(length); + + for(int j = 0; p[i][j].index != -1; j++) + p[i][j].value /= length; + } +} + int main(int argc, char **argv) { char input_file_name[1024]; @@ -101,6 +119,7 @@ int main(int argc, char **argv) parse_command_line(argc, argv, input_file_name, model_file_name); read_problem(input_file_name); + if(normal) normalize(&prob); error_msg = check_parameter(&prob,¶m); if(error_msg) @@ -116,6 +135,7 @@ int main(int argc, char **argv) else { model_=train(&prob, ¶m); + model_->normal = normal; if(save_model(model_file_name, model_)) { fprintf(stderr,"can't save model to file %s\n",model_file_name); @@ -197,6 +217,11 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode exit_with_help(); switch(argv[i-1][1]) { + case 'n': + normal = 1; + i--; + break; + case 's': param.solver_type = atoi(argv[i]); break; @@ -252,7 +277,7 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode // determine filenames if(i>=argc) exit_with_help(); - + strcpy(input_file_name, argv[i]); if(i