Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2620,6 +2620,8 @@ int save_model(const char *model_file_name, const struct model *model_)

fprintf(fp, "bias %.16g\n", model_->bias);

fprintf(fp, "normalization %d\n", model_->normal);

fprintf(fp, "w\n");
for(i=0; i<w_size; i++)
{
Expand All @@ -2645,6 +2647,7 @@ struct model *load_model(const char *model_file_name)
int nr_feature;
int n;
int nr_class;
int normalfac;
double bias;
model *model_ = Malloc(model,1);
parameter& param = model_->param;
Expand Down Expand Up @@ -2686,6 +2689,11 @@ struct model *load_model(const char *model_file_name)
fscanf(fp,"%d",&nr_class);
model_->nr_class=nr_class;
}
else if(strcmp(cmd,"normalization")==0)
{
fscanf(fp,"%d",&normalfac);
model_->normal=normalfac;
}
else if(strcmp(cmd,"nr_feature")==0)
{
fscanf(fp,"%d",&nr_feature);
Expand Down
1 change: 1 addition & 0 deletions linear.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ struct model
struct parameter param;
int nr_class; /* number of classes */
int nr_feature;
int normal;
double *w;
int *label; /* label of each class */
double bias;
Expand Down
12 changes: 12 additions & 0 deletions predict.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <stdio.h>
#include <ctype.h>
#include <stdlib.h>
#include <math.h>
#include <string.h>
#include <errno.h>
#include "linear.h"
Expand Down Expand Up @@ -133,6 +134,17 @@ void do_predict(FILE *input, FILE *output)
}
x[i].index = -1;

if(model_->normal){
double length = 0;
for(int kk = 0; x[kk].index != -1; kk++)
length += x[kk].value * x[kk].value;

length = sqrt(length);

for(int kk = 0; x[kk].index != -1; kk++)
x[kk].value /= length;
}

if(flag_predict_probability)
{
int j;
Expand Down
27 changes: 26 additions & 1 deletion train.c
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,24 @@ int flag_cross_validation;
int nr_fold;
double bias;

int normal = 0;

void normalize(problem *pb){
feature_node **p = pb->x;

for(int i = 0; i < pb->l; i++){
double length = 0;

for(int j = 0; p[i][j].index != -1; j++)
length += (p[i][j].value) * (p[i][j].value);

length = sqrt(length);

for(int j = 0; p[i][j].index != -1; j++)
p[i][j].value /= length;
}
}

int main(int argc, char **argv)
{
char input_file_name[1024];
Expand All @@ -101,6 +119,7 @@ int main(int argc, char **argv)

parse_command_line(argc, argv, input_file_name, model_file_name);
read_problem(input_file_name);
if(normal) normalize(&prob);
error_msg = check_parameter(&prob,&param);

if(error_msg)
Expand All @@ -116,6 +135,7 @@ int main(int argc, char **argv)
else
{
model_=train(&prob, &param);
model_->normal = normal;
if(save_model(model_file_name, model_))
{
fprintf(stderr,"can't save model to file %s\n",model_file_name);
Expand Down Expand Up @@ -197,6 +217,11 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode
exit_with_help();
switch(argv[i-1][1])
{
case 'n':
normal = 1;
i--;
break;

case 's':
param.solver_type = atoi(argv[i]);
break;
Expand Down Expand Up @@ -252,7 +277,7 @@ void parse_command_line(int argc, char **argv, char *input_file_name, char *mode
// determine filenames
if(i>=argc)
exit_with_help();

strcpy(input_file_name, argv[i]);

if(i<argc-1)
Expand Down