-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel.js
More file actions
115 lines (87 loc) · 2.79 KB
/
model.js
File metadata and controls
115 lines (87 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
const fs = require('fs');
const path = require('path');
let tokenize = require('./tokenize.js');
// Resolve model.json next to this module -- the same place require('./model.json') looks --
// rather than relative to the process working directory, so the existence check and the
// actual load always refer to the same file.
const modelPath = path.join(__dirname, 'model.json');
if (!fs.existsSync(modelPath)) {
  console.warn('[!] Fatal error, could not locate a model.json file');
  // Non-zero status so whatever launched the process can tell startup failed.
  process.exit(1);
}
// Map from a context string to an object of { nextToken: weight } entries
// (shape inferred from how getRandomToken reads it -- TODO confirm against model.json).
const model = require(modelPath);
// Used by generateTokens when no starting phrase is supplied.
const defaultStartingPhrase = 'Hello there';
/**
 * Picks a successor token for `key` at random, weighted by the values stored in the model.
 * @param {string} key - Context string looked up in `model`.
 * @returns {string|null} A successor token, or null when `key` is not an own entry of the
 *   model or its entry lists no successor tokens.
 */
function getRandomToken(key) {
  // Own-property check: names such as "constructor" or "__proto__" (reachable from
  // user-supplied text) would otherwise resolve to Object.prototype members.
  if (!Object.prototype.hasOwnProperty.call(model, key) || !model[key]) return null;
  const entries = Object.entries(model[key]);
  if (entries.length === 0) return null;
  const total = entries.reduce((sum, [, weight]) => sum + weight, 0);
  let rand = Math.random() * total;
  for (const [token, weight] of entries) {
    rand -= weight;
    if (rand <= 0) return token;
  }
  // Floating-point rounding can leave rand slightly above 0 after the last subtraction.
  return entries[entries.length - 1][0];
}
/**
 * Reports whether `key` names an entry of the loaded model.
 * @param {string} key - Candidate context string.
 * @returns {boolean} true only when `key` is an own, truthy property of `model`.
 */
function checkKey(key) {
  // Inherited names such as "constructor" or "toString" must not count as model keys.
  return Object.prototype.hasOwnProperty.call(model, key) && Boolean(model[key]);
}
/**
 * Finds the first spelling of `seed` that exists as a key in the model.
 * Variants are tried in order: as-is, lower-cased, trimmed, lower-cased and trimmed,
 * with a leading space, and lower-cased with a leading space.
 * @param {string} seed - Joined text of the trailing tokens.
 * @returns {string|null} The matching model key, or null if no variant matches.
 */
function resolveModelKey(seed) {
  const variants = [
    seed,
    seed.toLowerCase(),
    seed.trim(),
    seed.toLowerCase().trim(),
    ` ${seed}`,
    ` ${seed.toLowerCase()}`,
  ];
  // `?? null` rather than `|| null`: an empty-string key is still a valid match.
  return variants.find((variant) => checkKey(variant)) ?? null;
}

/**
 * Returns an array of tokens generated by the prediction model.
 *
 * At each step, the longest suffix of the token list (at most `context` tokens) whose joined
 * text -- or one of its case/whitespace variants -- is a model key becomes the seed for the
 * next token. Generation stops when the joined output contains `<!end>`, when `maxTokens`
 * tokens have been added, or when no suffix matches.
 *
 * @param {number} context - Maximum number of trailing tokens used to build the lookup seed.
 * @param {string} [startingPhrase] - Text to continue; falsy values use the default phrase.
 * @param {number} [maxTokens=200] - Maximum number of tokens to append.
 * @param {number} [timeoutMs=5000] - Wall-clock budget in milliseconds.
 * @returns {string[]|null} The starting tokens followed by the generated ones, or null if
 *   the time budget was exceeded.
 */
function generateTokens(context, startingPhrase, maxTokens = 200, timeoutMs = 5000) {
  const sentence = tokenize(startingPhrase || defaultStartingPhrase);
  const deadline = Date.now() + timeoutMs;
  let generated = 0;
  while (generated < maxTokens && !sentence.join('').includes('<!end>')) {
    if (Date.now() > deadline) return null;
    let key = null;
    // Back off from the longest allowed suffix to shorter ones until one matches a model key.
    for (let size = Math.min(context, sentence.length); size > 0; size--) {
      if (Date.now() > deadline) return null;
      key = resolveModelKey(sentence.slice(-size).join(''));
      if (key !== null) break;
    }
    if (key === null) break;
    const next = getRandomToken(key);
    // Never append null/undefined: join('') renders them as '', so they would add nothing
    // to the output while still counting toward maxTokens.
    if (next == null) break;
    sentence.push(next);
    generated++;
  }
  return sentence;
}
/**
 * Formats and stringifies an array of tokens.
 * Joins the tokens, drops every `<!end>` marker, doubles each backslash, and then prefixes
 * each of the characters * _ ` | with a backslash.
 * @param {array} array - Tokens to join.
 * @returns {string|null} The escaped text, or null when `array` is not an array.
 */
function stringifyOutput(array) {
  if (!Array.isArray(array)) return null;
  const joined = array.join('').split('<!end>').join('');
  // Backslashes are escaped first so the escapes added on the next line are not doubled.
  const backslashesEscaped = joined.replace(/\\/g, '\\\\');
  return backslashesEscaped.replace(/([*_`|])/g, '\\$1');
}
module.exports = { generateTokens, stringifyOutput };