/*
 * argPos — locate a command-line flag in argv.
 *
 * Scans argv[1] .. argv[argc-1] for an exact string match with `str`.
 *
 * Parameters:
 *   str  — flag to search for (e.g. "-train"); must be non-NULL.
 *   argc — argument count as passed to main().
 *   argv — argument vector as passed to main().
 *
 * Returns the index of the matching argument, or -1 when the flag is
 * not present.  (The original fragment lacked the declaration of the
 * loop variable and fell off the end without returning a value on a
 * miss — undefined behavior for a value-returning function; fixed by
 * declaring `a` and returning -1 explicitly.)
 */
int argPos(
char *str,
int argc,
char **argv)
{
    int a;

    /* Start at 1: argv[0] is the program name, never a flag. */
    for (a=1; a<argc; a++)
        if (!strcmp(str, argv[a]))
            return a;

    return -1;   /* flag not found */
}
// Program entry point: parses command-line options, prints usage text,
// validates the option combination, and configures a CRnnLM-style model
// object (model1) for training, testing, or word generation.
// NOTE(review): this chunk is an elided extraction — the declarations of
// i, f, debug_mode, class_size, model1, etc., and the `if (i>0) { ... }`
// wrappers around each option block are not visible here.  Comments below
// describe only what the visible lines show; error printf()s presumably
// sit inside "flag given but value missing" branches — confirm against
// the full file.
29 int main(
int argc,
char **argv)
// Defaults for training hyper-parameters; each may be overridden by a
// command-line flag parsed further below.
42 int alpha_set=0, train_file_set=0;
47 float gradient_cutoff=15;
49 float starting_alpha=0.1;
50 float regularization=0.0000001;
51 float min_improvement=1.003;
53 int compression_size=0;
// Fixed-size buffers for file-name arguments (MAX_STRING defined elsewhere;
// the strcpy() calls below do not bound-check — overlong paths would overflow).
66 char train_file[MAX_STRING];
67 char valid_file[MAX_STRING];
68 char test_file[MAX_STRING];
69 char rnnlm_file[MAX_STRING];
70 char lmprob_file[MAX_STRING];
// ---- Usage / help text ----
// Printed banner plus one printf pair per option (flag line, description
// line).  Trigger condition (presumably argc==1) is not visible in this chunk.
77 printf(
"Recurrent neural network based language modeling toolkit v 0.3d\n\n");
82 printf(
"Parameters for training phase:\n");
84 printf(
"\t-train <file>\n");
85 printf(
"\t\tUse text data from <file> to train rnnlm model\n");
87 printf(
"\t-class <int>\n");
88 printf(
"\t\tWill use specified amount of classes to decompose vocabulary; default is 100\n");
90 printf(
"\t-old-classes\n");
91 printf(
"\t\tThis will use old algorithm to compute classes, which results in slower models but can be a bit more precise\n");
93 printf(
"\t-rnnlm <file>\n");
94 printf(
"\t\tUse <file> to store rnnlm model\n");
96 printf(
"\t-binary\n");
97 printf(
"\t\tRnnlm model will be saved in binary format (default is plain text)\n");
99 printf(
"\t-valid <file>\n");
100 printf(
"\t\tUse <file> as validation data\n");
102 printf(
"\t-alpha <float>\n");
103 printf(
"\t\tSet starting learning rate; default is 0.1\n");
105 printf(
"\t-beta <float>\n");
106 printf(
"\t\tSet L2 regularization parameter; default is 1e-7\n");
108 printf(
"\t-hidden <int>\n");
109 printf(
"\t\tSet size of hidden layer; default is 30\n");
111 printf(
"\t-compression <int>\n");
112 printf(
"\t\tSet size of compression layer; default is 0 (not used)\n");
114 printf(
"\t-direct <int>\n");
115 printf(
"\t\tSets size of the hash for direct connections with n-gram features in millions; default is 0\n");
117 printf(
"\t-direct-order <int>\n");
118 printf(
"\t\tSets the n-gram order for direct connections (max %d); default is 3\n", MAX_NGRAM_ORDER);
120 printf(
"\t-bptt <int>\n");
121 printf(
"\t\tSet amount of steps to propagate error back in time; default is 0 (equal to simple RNN)\n");
123 printf(
"\t-bptt-block <int>\n");
124 printf(
"\t\tSpecifies amount of time steps after which the error is backpropagated through time in block mode (default 10, update at each time step = 1)\n");
126 printf(
"\t-one-iter\n");
127 printf(
"\t\tWill cause training to perform exactly one iteration over training data (useful for adapting final models on different data etc.)\n");
129 printf(
"\t-anti-kasparek <int>\n");
130 printf(
"\t\tModel will be saved during training after processing specified amount of words\n");
132 printf(
"\t-min-improvement <float>\n");
133 printf(
"\t\tSet minimal relative entropy improvement for training convergence; default is 1.003\n");
135 printf(
"\t-gradient-cutoff <float>\n");
136 printf(
"\t\tSet maximal absolute gradient value (to improve training stability, use lower values; default is 15, to turn off use 0)\n");
140 printf(
"Parameters for testing phase:\n");
142 printf(
"\t-rnnlm <file>\n");
143 printf(
"\t\tRead rnnlm model from <file>\n");
145 printf(
"\t-test <file>\n");
146 printf(
"\t\tUse <file> as test data to report perplexity\n");
148 printf(
"\t-lm-prob\n");
149 printf(
"\t\tUse other LM probabilities for linear interpolation with rnnlm model; see examples at the rnnlm webpage\n");
151 printf(
"\t-lambda <float>\n");
152 printf(
"\t\tSet parameter for linear interpolation of rnnlm and other lm; default weight of rnnlm is 0.75\n");
154 printf(
"\t-dynamic <float>\n");
155 printf(
"\t\tSet learning rate for dynamic model updates during testing phase; default is 0 (static model)\n");
159 printf(
"Additional parameters:\n");
161 printf(
"\t-gen <int>\n");
162 printf(
"\t\tGenerate specified amount of words given distribution from current model\n");
164 printf(
"\t-independent\n");
165 printf(
"\t\tWill erase history at end of each sentence (if used for training, this switch should be used also for testing & rescoring)\n");
167 printf(
"\nExamples:\n");
168 printf(
"rnnlm -train train -rnnlm model -valid valid -hidden 50\n");
169 printf(
"rnnlm -rnnlm model -test test\n");
// ---- Command-line argument parsing ----
// Each section calls argPos() to find a flag, then (in elided code) reads
// the following argv entry as its value.
// -debug <int>: verbosity level.
177 i=
argPos((
char *)
"-debug", argc, argv);
180 printf(
"ERROR: debug mode not specified!\n");
184 debug_mode=atoi(argv[i+1]);
187 printf(
"debug mode: %d\n", debug_mode);
// -train <file>: training data; fopen() below verifies the file exists.
192 i=
argPos((
char *)
"-train", argc, argv);
195 printf(
"ERROR: training data file not specified!\n");
199 strcpy(train_file, argv[i+1]);
202 printf(
"train file: %s\n", train_file);
204 f=fopen(train_file,
"rb");
206 printf(
"ERROR: training data file not found!\n");
// -one-iter: boolean flag, run exactly one pass over the training data.
217 i=
argPos((
char *)
"-one-iter", argc, argv);
222 printf(
"Training for one iteration\n");
// -valid <file>: validation data; existence verified via fopen().
227 i=
argPos((
char *)
"-valid", argc, argv);
230 printf(
"ERROR: validation data file not specified!\n");
234 strcpy(valid_file, argv[i+1]);
237 printf(
"valid file: %s\n", valid_file);
239 f=fopen(valid_file,
"rb");
241 printf(
"ERROR: validation data file not found!\n");
// Training requires validation data (used for the convergence criterion).
248 if (train_mode && !valid_data_set) {
250 printf(
"ERROR: validation data file must be specified for training!\n");
// -nbest: boolean flag, treat test input as n-best lists for rescoring.
257 i=
argPos((
char *)
"-nbest", argc, argv);
261 printf(
"Processing test data as list of nbests\n");
// -test <file>: test data; "-" with -nbest means read from stdin, so the
// existence check via fopen() is skipped in that case.
266 i=
argPos((
char *)
"-test", argc, argv);
269 printf(
"ERROR: test data file not specified!\n");
273 strcpy(test_file, argv[i+1]);
276 printf(
"test file: %s\n", test_file);
279 if (nbest && (!strcmp(test_file,
"-"))) ;
else {
280 f=fopen(test_file,
"rb");
282 printf(
"ERROR: test data file not found!\n");
// -class <int>: number of vocabulary classes for the factorized output layer.
292 i=
argPos((
char *)
"-class", argc, argv);
295 printf(
"ERROR: amount of classes not specified!\n");
299 class_size=atoi(argv[i+1]);
302 printf(
"class size: %d\n", class_size);
// -old-classes: boolean flag, select the legacy class-assignment algorithm.
307 i=
argPos((
char *)
"-old-classes", argc, argv);
312 printf(
"Old algorithm for computing classes will be used\n");
// -lambda <float>: interpolation weight between rnnlm and the external LM.
317 i=
argPos((
char *)
"-lambda", argc, argv);
320 printf(
"ERROR: lambda not specified!\n");
324 lambda=atof(argv[i+1]);
327 printf(
"Lambda (interpolation coefficient between rnnlm and other lm): %f\n", lambda);
// -gradient-cutoff <float>: clip gradients to this absolute value (0 = off).
332 i=
argPos((
char *)
"-gradient-cutoff", argc, argv);
335 printf(
"ERROR: gradient cutoff not specified!\n");
339 gradient_cutoff=atof(argv[i+1]);
342 printf(
"Gradient cutoff: %f\n", gradient_cutoff);
// -dynamic <float>: learning rate for on-line updates during testing.
347 i=
argPos((
char *)
"-dynamic", argc, argv);
350 printf(
"ERROR: dynamic learning rate not specified!\n");
354 dynamic=atof(argv[i+1]);
357 printf(
"Dynamic learning rate: %f\n", dynamic);
// -gen <int>: number of words to sample from the model
// (assignment to `gen` is elided from this chunk).
362 i=
argPos((
char *)
"-gen", argc, argv);
365 printf(
"ERROR: gen parameter not specified!\n");
372 printf(
"Generating # words: %d\n", gen);
// -independent: boolean flag, reset hidden state at sentence boundaries.
377 i=
argPos((
char *)
"-independent", argc, argv);
382 printf(
"Sentences will be processed independently...\n");
// -alpha <float>: starting learning rate.
387 i=
argPos((
char *)
"-alpha", argc, argv);
390 printf(
"ERROR: alpha not specified!\n");
394 starting_alpha=atof(argv[i+1]);
397 printf(
"Starting learning rate: %f\n", starting_alpha);
// -beta <float>: L2 regularization coefficient.
403 i=
argPos((
char *)
"-beta", argc, argv);
406 printf(
"ERROR: beta not specified!\n");
410 regularization=atof(argv[i+1]);
413 printf(
"Regularization: %f\n", regularization);
// -min-improvement <float>: convergence threshold on relative entropy gain.
418 i=
argPos((
char *)
"-min-improvement", argc, argv);
421 printf(
"ERROR: minimal improvement value not specified!\n");
425 min_improvement=atof(argv[i+1]);
428 printf(
"Min improvement: %f\n", min_improvement);
// -anti-kasparek <int>: checkpoint interval in words; nonzero values are
// clamped up to a minimum of 10000.
433 i=
argPos((
char *)
"-anti-kasparek", argc, argv);
436 printf(
"ERROR: anti-kasparek parameter not set!\n");
440 anti_k=atoi(argv[i+1]);
442 if ((anti_k!=0) && (anti_k<10000)) anti_k=10000;
445 printf(
"Model will be saved after each # words: %d\n", anti_k);
// -hidden <int>: hidden layer size.
450 i=
argPos((
char *)
"-hidden", argc, argv);
453 printf(
"ERROR: hidden layer size not specified!\n");
457 hidden_size=atoi(argv[i+1]);
460 printf(
"Hidden layer size: %d\n", hidden_size);
// -compression <int>: optional compression layer size.
465 i=
argPos((
char *)
"-compression", argc, argv);
468 printf(
"ERROR: compression layer size not specified!\n");
472 compression_size=atoi(argv[i+1]);
475 printf(
"Compression layer size: %d\n", compression_size);
// -direct <int>: hash size for direct n-gram connections; negative values
// are clamped to 0, and the size is reported in millions.
480 i=
argPos((
char *)
"-direct", argc, argv);
483 printf(
"ERROR: direct connections not specified!\n");
487 direct=atoi(argv[i+1]);
490 if (direct<0) direct=0;
493 printf(
"Direct connections: %dM\n", (
int)(direct/1000000));
// -direct-order <int>: n-gram order for direct connections, capped at
// MAX_NGRAM_ORDER.
498 i=
argPos((
char *)
"-direct-order", argc, argv);
501 printf(
"ERROR: direct order not specified!\n");
505 direct_order=atoi(argv[i+1]);
506 if (direct_order>MAX_NGRAM_ORDER) direct_order=MAX_NGRAM_ORDER;
509 printf(
"Order of direct connections: %d\n", direct_order);
// -bptt <int>: BPTT unrolling depth; note the printed value is bptt-1,
// suggesting the stored value is offset by one internally — TODO confirm.
514 i=
argPos((
char *)
"-bptt", argc, argv);
517 printf(
"ERROR: bptt value not specified!\n");
521 bptt=atoi(argv[i+1]);
526 printf(
"BPTT: %d\n", bptt-1);
// -bptt-block <int>: steps between truncated-BPTT updates, minimum 1.
531 i=
argPos((
char *)
"-bptt-block", argc, argv);
534 printf(
"ERROR: bptt block value not specified!\n");
538 bptt_block=atoi(argv[i+1]);
539 if (bptt_block<1) bptt_block=1;
542 printf(
"BPTT block: %d\n", bptt_block);
// -rand-seed <int>: RNG seed for reproducible weight initialization.
547 i=
argPos((
char *)
"-rand-seed", argc, argv);
550 printf(
"ERROR: Random seed variable not specified!\n");
554 rand_seed=atoi(argv[i+1]);
557 printf(
"Rand seed: %d\n", rand_seed);
// -lm-prob <file>: external LM probabilities for interpolation; existence
// verified via fopen().
562 i=
argPos((
char *)
"-lm-prob", argc, argv);
565 printf(
"ERROR: other lm file not specified!\n");
569 strcpy(lmprob_file, argv[i+1]);
572 printf(
"other lm probabilities specified in: %s\n", lmprob_file);
574 f=fopen(lmprob_file,
"rb");
576 printf(
"ERROR: other lm file not found!\n");
// -binary: boolean flag, save model in binary instead of plain text.
585 i=
argPos((
char *)
"-binary", argc, argv);
588 printf(
"Model will be saved in binary format\n");
// -rnnlm <file>: model file to load/store; the fopen() below presumably
// probes whether an existing model is present (result handling elided).
595 i=
argPos((
char *)
"-rnnlm", argc, argv);
598 printf(
"ERROR: model file not specified!\n");
602 strcpy(rnnlm_file, argv[i+1]);
605 printf(
"rnnlm file: %s\n", rnnlm_file);
607 f=fopen(rnnlm_file,
"rb");
// ---- Option-combination sanity checks ----
611 if (train_mode && !rnnlm_file_set) {
612 printf(
"ERROR: rnnlm file must be specified for training!\n");
615 if (test_data_set && !rnnlm_file_set) {
616 printf(
"ERROR: rnnlm file must be specified for testing!\n");
619 if (!test_data_set && !train_mode && gen==0) {
620 printf(
"ERROR: training or testing must be specified!\n");
623 if ((gen>0) && !rnnlm_file_set) {
624 printf(
"ERROR: rnnlm file must be specified to generate words!\n");
// ---- Training mode: push parsed options into the model object ----
634 model1.setTrainFile(train_file);
635 model1.setRnnLMFile(rnnlm_file);
636 model1.setFileType(fileformat);
638 model1.setOneIter(one_iter);
639 if (one_iter==0) model1.setValidFile(valid_file);
641 model1.setClassSize(class_size);
642 model1.setOldClasses(old_classes);
643 model1.setLearningRate(starting_alpha);
644 model1.setGradientCutoff(gradient_cutoff);
645 model1.setRegularization(regularization);
646 model1.setMinImprovement(min_improvement);
647 model1.setHiddenLayerSize(hidden_size);
648 model1.setCompressionLayerSize(compression_size);
649 model1.setDirectSize(direct);
650 model1.setDirectOrder(direct_order);
651 model1.setBPTT(bptt);
652 model1.setBPTTBlock(bptt_block);
653 model1.setRandSeed(rand_seed);
654 model1.setDebugMode(debug_mode);
655 model1.setAntiKasparek(anti_k);
656 model1.setIndependent(independent);
// These two flags are written directly to public members rather than
// through setters — lets the model know whether the user overrode them.
658 model1.alpha_set=alpha_set;
659 model1.train_file_set=train_file_set;
// ---- Testing mode: configure and run perplexity / n-best evaluation ----
664 if (test_data_set && rnnlm_file_set) {
667 model1.setLambda(lambda);
668 model1.setRegularization(regularization);
669 model1.setDynamic(dynamic);
670 model1.setTestFile(test_file);
671 model1.setRnnLMFile(rnnlm_file);
672 model1.setRandSeed(rand_seed);
673 model1.useLMProb(use_lmprob);
674 if (use_lmprob) model1.setLMProbFile(lmprob_file);
675 model1.setDebugMode(debug_mode);
677 if (nbest==0) model1.testNet();
678 else model1.testNbest();
// ---- Generation mode: sample words from a stored model ----
684 model1.setRnnLMFile(rnnlm_file);
685 model1.setDebugMode(debug_mode);
686 model1.setRandSeed(rand_seed);
int main(int argc, char **argv)
int argPos(char *str, int argc, char **argv)