Fast #10

1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
+bin
7 changes: 3 additions & 4 deletions Makefile
@@ -1,9 +1,8 @@
 bilbowa: src/bilbowa.v0.1.c
 	mkdir -p bin
-	gcc src/bilbowa.v0.1.c -g -o bin/bilbowa -lm -pthread -march=native -funroll-loops -w
+	gcc -std=c99 -Wall -O2 -march=native -funroll-loops src/bilbowa.v0.1.c -o bin/bilbowa -lm -pthread
 bidist: src/bidist.c
-	gcc src/bidist.c -g -o bin/bidist -lm -march=native -funroll-loops -w
+	gcc -std=c99 -Wall -O2 -march=native -funroll-loops src/bidist.c -o bin/bidist -lm
 all: bilbowa bidist
 clean:
-	rm bin/bilbowa bin/bidist
-
+	rm -fv bin/*
6 changes: 3 additions & 3 deletions src/bidist.c
@@ -54,7 +54,7 @@ int main(int argc, char **argv) {
     for (a = 0; a < size_lang1; a++) fread(&M_lang1[a + b * size_lang1], sizeof(float), 1, f);
     len = 0;
     for (a = 0; a < size_lang1; a++) len += M_lang1[a + b * size_lang1] * M_lang1[a + b * size_lang1];
-    len = sqrt(len);
+    len = sqrtf(len);
     for (a = 0; a < size_lang1; a++) M_lang1[a + b * size_lang1] /= len;
   }
   fclose(f);
@@ -79,7 +79,7 @@ int main(int argc, char **argv) {
     for (a = 0; a < size_lang2; a++) fread(&M_lang2[a + b * size_lang2], sizeof(float), 1, f);
     len = 0;
     for (a = 0; a < size_lang2; a++) len += M_lang2[a + b * size_lang2] * M_lang2[a + b * size_lang2];
-    len = sqrt(len);
+    len = sqrtf(len);
     for (a = 0; a < size_lang2; a++) M_lang2[a + b * size_lang2] /= len;
   }
   fclose(f);
@@ -134,7 +134,7 @@ int main(int argc, char **argv) {
     }
     len = 0;
     for (a = 0; a < size_lang2; a++) len += vec[a] * vec[a];
-    len = sqrt(len);
+    len = sqrtf(len);
     for (a = 0; a < size_lang2; a++) vec[a] /= len;
     for (a = 0; a < N; a++) bestd[a] = 0;
     for (a = 0; a < N; a++) bestw[a][0] = 0;
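A note on the three sqrt -> sqrtf swaps above: len and the matrix entries are floats, so sqrt(len) promotes the value to double and truncates it back on assignment, while sqrtf keeps the whole normalization in single precision. A minimal standalone sketch of the pattern, with hypothetical names (the real code normalizes rows of M_lang1/M_lang2 read from the binary embedding files):

    #include <math.h>
    #include <stdio.h>

    /* Normalize a float vector to unit L2 length, staying in single
     * precision throughout (sqrtf avoids the float -> double -> float
     * round trip that sqrt would introduce). */
    static void normalize(float *vec, int size) {
      float len = 0.0f;
      for (int a = 0; a < size; a++) len += vec[a] * vec[a];
      len = sqrtf(len);
      if (len > 0.0f)  /* guard added here; the diff assumes len != 0 */
        for (int a = 0; a < size; a++) vec[a] /= len;
    }

    int main(void) {
      float v[3] = {3.0f, 0.0f, 4.0f};
      normalize(v, 3);
      printf("%f %f %f\n", v[0], v[1], v[2]);  /* 0.6 0.0 0.8 */
      return 0;
    }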
23 changes: 14 additions & 9 deletions src/bilbowa.v0.1.c
@@ -12,12 +12,15 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#define _POSIX_C_SOURCE 200809L // for posix_memalign()
+
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
 #include <math.h>
 #include <pthread.h>
 #include <unistd.h>
+#include <stdint.h>
 
 #define MAX_STRING 100
 #define EXP_TABLE_SIZE 1000
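The new _POSIX_C_SOURCE define is needed because posix_memalign is POSIX, not C99: once the Makefile builds with -std=c99, <stdlib.h> hides its prototype unless a feature-test macro requests it, and the macro must come before the first #include. A minimal sketch of the dependency, with a hypothetical buffer size (the comment in the diff says the define is there for posix_memalign):

    #define _POSIX_C_SOURCE 200809L  /* must precede the first #include */
    #include <stdio.h>
    #include <stdlib.h>

    int main(void) {
      void *buf = NULL;
      /* request a 128-byte-aligned allocation, e.g. for cache-friendly
       * access to embedding rows */
      if (posix_memalign(&buf, 128, 1024 * sizeof(float)) != 0) {
        fprintf(stderr, "posix_memalign failed\n");
        return 1;
      }
      printf("aligned buffer at %p\n", buf);
      free(buf);
      return 0;
    }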
@@ -110,7 +113,8 @@ void ReadWord(char *word, FILE *fin) {
 /* Returns hash value of a word */
 int GetWordHash(char *word) {
   unsigned long long a, hash = 0;
-  for (a = 0; a < strlen(word); a++) hash = hash * 257 + word[a];
+  unsigned int len = strlen(word);
+  for (a = 0; a < len; a++) hash = hash * 257 + word[a];
   hash = hash % vocab_hash_size;
   return hash;
 }
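Hoisting strlen out of the loop condition is a small but real win: evaluated there, it can rescan the string on every iteration, making a hash of an n-character word cost O(n^2) byte reads; computed once, the loop is O(n). A self-contained sketch of the same multiplicative rolling hash (VOCAB_HASH_SIZE stands in for the global vocab_hash_size):

    #include <stdio.h>
    #include <string.h>

    #define VOCAB_HASH_SIZE 30000000  /* hypothetical stand-in for the global */

    /* hash = hash * 257 + byte, with strlen called exactly once */
    unsigned long long GetWordHash(const char *word) {
      unsigned long long hash = 0;
      size_t len = strlen(word);
      for (size_t a = 0; a < len; a++) hash = hash * 257 + (unsigned char)word[a];
      return hash % VOCAB_HASH_SIZE;
    }

    int main(void) {
      printf("%llu\n", GetWordHash("bilbowa"));
      return 0;
    }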
@@ -352,7 +356,7 @@ void InitNet(int lang_id) {
 
 char SubSample(int lang_id, long long word_id) {
   long long count = vocabs[lang_id][word_id].cn;
-  real thresh = (sqrt(count / (sample * train_words[lang_id])) + 1) *
+  real thresh = (sqrtf(count / (sample * train_words[lang_id])) + 1) *
                 (sample * train_words[lang_id]) / count;
   next_random = next_random * (unsigned long long)25214903917 + 11;
   if ((next_random & 0xFFFF) / (real)65536 > thresh) return 1;
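The hunk only changes sqrt to sqrtf, but the surrounding expression is worth decoding: it is word2vec-style frequent-word subsampling. Writing ratio = count / (sample * train_words), the threshold is (sqrt(ratio) + 1) / ratio, which is at least 1 for rare words (always kept) and falls toward zero for very frequent ones; SubSample returns 1 (drop the word) when a uniform draw in [0,1) exceeds it. A small sketch of just that formula, with hypothetical corpus numbers:

    #include <math.h>
    #include <stdio.h>

    /* word2vec-style keep probability for a word seen `count` times in a
     * corpus of `total` tokens, with subsampling rate `sample` */
    float KeepProbability(long long count, long long total, float sample) {
      float ratio = count / (sample * total);
      return (sqrtf(ratio) + 1.0f) / ratio;
    }

    int main(void) {
      /* hypothetical corpus: 10M tokens, sample = 1e-4 */
      printf("rare word (100 occurrences): %f\n",
             KeepProbability(100, 10000000, 1e-4f));
      printf("frequent word (500k occurrences): %f\n",
             KeepProbability(500000, 10000000, 1e-4f));
      return 0;
    }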
@@ -391,7 +395,7 @@ void UpdateEmbeddings(real *embeddings, real *grads, int offset,
   if (adagrad) {
     // Use Adagrad for automatic learning rate selection
     grads[offset + a] += (deltas[a] * deltas[a]);
-    step = (alpha / fmax(epsilon, sqrt(grads[offset + a]))) * deltas[a];
+    step = (alpha / fmaxf(epsilon, sqrtf(grads[offset + a]))) * deltas[a];
   } else {
     // Regular SGD
     step = alpha * deltas[a];
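Same single-precision cleanup in the Adagrad branch (fmaxf/sqrtf for a float accumulator). The update itself is standard Adagrad: accumulate the squared gradient per parameter and divide the base rate by its square root, with epsilon flooring the denominator so early steps cannot blow up. A minimal sketch of both branches, assuming real is float as in the source:

    #include <math.h>
    #include <stdio.h>

    typedef float real;

    /* One parameter update mirroring the diff: Adagrad scales alpha by
     * the inverse root of the accumulated squared gradients; the else
     * branch is plain SGD. */
    real Step(real alpha, real delta, real *grad_accum, int adagrad, real epsilon) {
      if (adagrad) {
        *grad_accum += delta * delta;
        return (alpha / fmaxf(epsilon, sqrtf(*grad_accum))) * delta;
      }
      return alpha * delta;  /* regular SGD */
    }

    int main(void) {
      real accum = 0.0f;
      /* identical gradients produce shrinking steps: 0.100, 0.071, 0.058 */
      for (int t = 0; t < 3; t++)
        printf("adagrad step %d: %f\n", t, Step(0.1f, 1.0f, &accum, 1, 1e-8f));
      return 0;
    }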
@@ -453,7 +457,7 @@ void BuildCDF(real *cdf, long long *sen, int lang_id, int len) {
   }
   else if (PAR_SAMPLE && sample > 0) { // subsample
     count = vocabs[lang_id][word].cn;
-    threshold = (sqrt(count / (sample * train_words[lang_id])) + 1) * (sample *
+    threshold = (sqrtf(count / (sample * train_words[lang_id])) + 1) * (sample *
                 train_words[lang_id]) / count;
     if (threshold < 0) threshold = 0.0;
   }
@@ -502,7 +506,7 @@ void *BilbowaThread(void *id) {
   // TODO: Change this for more than two languages
   int lang_id1 = 0, lang_id2 = 1;
   // Each thread reads a portion of both the lang_id1 and lang_id2 files; portion size: file_size / num_threads
-  int thread_id = (int)id % num_threads; // total_sampled;
+  int thread_id = ((int)(uintptr_t) id) % num_threads; // total_sampled;
   long long par_sen1[MAX_SEN_LEN], par_sen2[MAX_SEN_LEN],
              // sampled_sen1[10], sampled_sen2[10],
              updates_l1 = 1, updates_l2 = 1,
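The (int)(uintptr_t) double cast is what <stdint.h> was added for: pthread_create passes the thread index through a void *, and on 64-bit platforms casting that pointer straight to int draws a pointer-to-int truncation warning once -w is gone. Casting through uintptr_t, the integer type sized to hold a pointer, makes the narrowing explicit. A minimal sketch of the round trip:

    #include <pthread.h>
    #include <stdint.h>
    #include <stdio.h>

    #define NUM_THREADS 4

    /* Recover the small integer smuggled through the void * argument
     * with the same (int)(uintptr_t) cast the diff introduces. */
    void *Worker(void *id) {
      int thread_id = ((int)(uintptr_t) id) % NUM_THREADS;
      printf("thread %d running\n", thread_id);
      return NULL;
    }

    int main(void) {
      pthread_t threads[NUM_THREADS];
      for (long a = 0; a < NUM_THREADS; a++)
        pthread_create(&threads[a], NULL, Worker, (void *)(uintptr_t)a);
      for (long a = 0; a < NUM_THREADS; a++)
        pthread_join(threads[a], NULL);
      return 0;
    }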
@@ -603,7 +607,8 @@ void *MonoModelThread(void *id) {
   long long word_count = 0, last_word_count = 0, all_train_words = 0;
   long long mono_sen[MAX_SEN_LEN + 1];
   long long l1, l2, c, target, label;
-  int lang_id = (int)id / num_threads, thread_id = (int)id % num_threads;
+  int id_int = (int)(uintptr_t) id;
+  int lang_id = id_int / num_threads, thread_id = id_int % num_threads;
   char *train_file = mono_train_files[lang_id];
   long long vocab_size = vocab_sizes[lang_id];
   real f, g;
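Here the recovered integer also carries an encoding: the launcher packs a language index and a per-language thread index into one value, id = lang_id * num_threads + thread_id, and the worker unpacks it with division and modulo. A tiny standalone check of that arithmetic (num_threads = 4 is hypothetical):

    #include <stdio.h>

    int main(void) {
      int num_threads = 4;
      /* ids 0..3 decode to language 0, ids 4..7 to language 1 */
      for (int id_int = 0; id_int < 2 * num_threads; id_int++) {
        int lang_id = id_int / num_threads;
        int thread_id = id_int % num_threads;
        printf("id %d -> lang %d, thread %d\n", id_int, lang_id, thread_id);
      }
      return 0;
    }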
@@ -674,9 +679,9 @@
       }
       if (EARLY_STOP) {
         if (word_count_actual > EARLY_STOP) {
-          fprintf(stderr, "EARLY STOP point reached (thread %d)\n", (int)id);
+          fprintf(stderr, "EARLY STOP point reached (thread %d)\n", (int)(uintptr_t) id);
           break;
         }
       }
     }
     word = mono_sen[sentence_position];
     if (word == -1) continue;
@@ -800,7 +805,7 @@ void TrainModel() {
   expTable = malloc((EXP_TABLE_SIZE + 1) * sizeof(real));
   for (i = 0; i < EXP_TABLE_SIZE; i++) {
     // Precompute the exp() table
-    expTable[i] = exp((i / (real)EXP_TABLE_SIZE * 2 - 1) * MAX_EXP);
+    expTable[i] = expf((i / (real)EXP_TABLE_SIZE * 2 - 1) * MAX_EXP);
     // Precompute sigmoid f(x) = x / (x + 1)
     expTable[i] = expTable[i] / (expTable[i] + 1);
   }
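The table stores sigmoid(z) = e^z / (e^z + 1) at EXP_TABLE_SIZE evenly spaced points z in [-MAX_EXP, MAX_EXP]; the comment's f(x) = x / (x + 1) is that formula with x = e^z. Switching exp to expf keeps the fill in float, matching the real typedef. A sketch of the fill plus the matching lookup; the index formula is the usual word2vec idiom, assumed here since this hunk does not show it:

    #include <math.h>
    #include <stdio.h>
    #include <stdlib.h>

    #define EXP_TABLE_SIZE 1000
    #define MAX_EXP 6
    typedef float real;

    static real *expTable;

    /* table lookup for sigmoid(z), valid for z in (-MAX_EXP, MAX_EXP) */
    static real Sigmoid(real z) {
      int i = (int)((z + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2));
      return expTable[i];
    }

    int main(void) {
      expTable = malloc((EXP_TABLE_SIZE + 1) * sizeof(real));
      for (int i = 0; i < EXP_TABLE_SIZE; i++) {
        expTable[i] = expf((i / (real)EXP_TABLE_SIZE * 2 - 1) * MAX_EXP);
        expTable[i] = expTable[i] / (expTable[i] + 1);  /* now sigmoid(z_i) */
      }
      printf("table: %f  exact: %f\n", Sigmoid(1.0f), 1.0f / (1.0f + expf(-1.0f)));
      free(expTable);
      return 0;
    }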