// Phonetisaurus 1.0 — FST-based Grapheme-to-Phoneme conversion
// rnnlm.cc (documentation listing header, converted to a comment)
//
// Recurrent neural network based statistical language modeling toolkit
// Version 0.3e
// (c) 2010-2012 Tomas Mikolov (tmikolov@gmail.com)
//
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <fstream>
#include <iostream>

#include "rnnlmlib.h"

using namespace std;
//
// Return the index of the option string `str` within argv, searching
// argv[1..argc-1] (argv[0], the program name, is deliberately skipped).
// Returns -1 when the option is not present.
//
// The parameter is `const char *` so callers can pass string literals
// directly; the original non-const signature forced `(char *)` casts.
//
int argPos(const char *str, int argc, char **argv)
{
  for (int a = 1; a < argc; a++)
    if (!strcmp(str, argv[a])) return a;

  return -1;
}
28 
29 int main(int argc, char **argv)
30 {
31  int i;
32 
33  int debug_mode=1;
34 
35  int fileformat=TEXT;
36 
37  int train_mode=0;
38  int valid_data_set=0;
39  int test_data_set=0;
40  int rnnlm_file_set=0;
41 
42  int alpha_set=0, train_file_set=0;
43 
44  int class_size=100;
45  int old_classes=0;
46  float lambda=0.75;
47  float gradient_cutoff=15;
48  float dynamic=0;
49  float starting_alpha=0.1;
50  float regularization=0.0000001;
51  float min_improvement=1.003;
52  int hidden_size=30;
53  int compression_size=0;
54  long long direct=0;
55  int direct_order=3;
56  int bptt=0;
57  int bptt_block=10;
58  int gen=0;
59  int independent=0;
60  int use_lmprob=0;
61  int rand_seed=1;
62  int nbest=0;
63  int one_iter=0;
64  int anti_k=0;
65 
66  char train_file[MAX_STRING];
67  char valid_file[MAX_STRING];
68  char test_file[MAX_STRING];
69  char rnnlm_file[MAX_STRING];
70  char lmprob_file[MAX_STRING];
71 
72  FILE *f;
73 
74  if (argc==1) {
75  //printf("Help\n");
76 
77  printf("Recurrent neural network based language modeling toolkit v 0.3d\n\n");
78 
79  printf("Options:\n");
80 
81  //
82  printf("Parameters for training phase:\n");
83 
84  printf("\t-train <file>\n");
85  printf("\t\tUse text data from <file> to train rnnlm model\n");
86 
87  printf("\t-class <int>\n");
88  printf("\t\tWill use specified amount of classes to decompose vocabulary; default is 100\n");
89 
90  printf("\t-old-classes\n");
91  printf("\t\tThis will use old algorithm to compute classes, which results in slower models but can be a bit more precise\n");
92 
93  printf("\t-rnnlm <file>\n");
94  printf("\t\tUse <file> to store rnnlm model\n");
95 
96  printf("\t-binary\n");
97  printf("\t\tRnnlm model will be saved in binary format (default is plain text)\n");
98 
99  printf("\t-valid <file>\n");
100  printf("\t\tUse <file> as validation data\n");
101 
102  printf("\t-alpha <float>\n");
103  printf("\t\tSet starting learning rate; default is 0.1\n");
104 
105  printf("\t-beta <float>\n");
106  printf("\t\tSet L2 regularization parameter; default is 1e-7\n");
107 
108  printf("\t-hidden <int>\n");
109  printf("\t\tSet size of hidden layer; default is 30\n");
110 
111  printf("\t-compression <int>\n");
112  printf("\t\tSet size of compression layer; default is 0 (not used)\n");
113 
114  printf("\t-direct <int>\n");
115  printf("\t\tSets size of the hash for direct connections with n-gram features in millions; default is 0\n");
116 
117  printf("\t-direct-order <int>\n");
118  printf("\t\tSets the n-gram order for direct connections (max %d); default is 3\n", MAX_NGRAM_ORDER);
119 
120  printf("\t-bptt <int>\n");
121  printf("\t\tSet amount of steps to propagate error back in time; default is 0 (equal to simple RNN)\n");
122 
123  printf("\t-bptt-block <int>\n");
124  printf("\t\tSpecifies amount of time steps after which the error is backpropagated through time in block mode (default 10, update at each time step = 1)\n");
125 
126  printf("\t-one-iter\n");
127  printf("\t\tWill cause training to perform exactly one iteration over training data (useful for adapting final models on different data etc.)\n");
128 
129  printf("\t-anti-kasparek <int>\n");
130  printf("\t\tModel will be saved during training after processing specified amount of words\n");
131 
132  printf("\t-min-improvement <float>\n");
133  printf("\t\tSet minimal relative entropy improvement for training convergence; default is 1.003\n");
134 
135  printf("\t-gradient-cutoff <float>\n");
136  printf("\t\tSet maximal absolute gradient value (to improve training stability, use lower values; default is 15, to turn off use 0)\n");
137 
138  //
139 
140  printf("Parameters for testing phase:\n");
141 
142  printf("\t-rnnlm <file>\n");
143  printf("\t\tRead rnnlm model from <file>\n");
144 
145  printf("\t-test <file>\n");
146  printf("\t\tUse <file> as test data to report perplexity\n");
147 
148  printf("\t-lm-prob\n");
149  printf("\t\tUse other LM probabilities for linear interpolation with rnnlm model; see examples at the rnnlm webpage\n");
150 
151  printf("\t-lambda <float>\n");
152  printf("\t\tSet parameter for linear interpolation of rnnlm and other lm; default weight of rnnlm is 0.75\n");
153 
154  printf("\t-dynamic <float>\n");
155  printf("\t\tSet learning rate for dynamic model updates during testing phase; default is 0 (static model)\n");
156 
157  //
158 
159  printf("Additional parameters:\n");
160 
161  printf("\t-gen <int>\n");
162  printf("\t\tGenerate specified amount of words given distribution from current model\n");
163 
164  printf("\t-independent\n");
165  printf("\t\tWill erase history at end of each sentence (if used for training, this switch should be used also for testing & rescoring)\n");
166 
167  printf("\nExamples:\n");
168  printf("rnnlm -train train -rnnlm model -valid valid -hidden 50\n");
169  printf("rnnlm -rnnlm model -test test\n");
170  printf("\n");
171 
172  return 0; //***
173  }
174 
175 
176  //set debug mode
177  i=argPos((char *)"-debug", argc, argv);
178  if (i>0) {
179  if (i+1==argc) {
180  printf("ERROR: debug mode not specified!\n");
181  return 0;
182  }
183 
184  debug_mode=atoi(argv[i+1]);
185 
186  if (debug_mode>0)
187  printf("debug mode: %d\n", debug_mode);
188  }
189 
190 
191  //search for train file
192  i=argPos((char *)"-train", argc, argv);
193  if (i>0) {
194  if (i+1==argc) {
195  printf("ERROR: training data file not specified!\n");
196  return 0;
197  }
198 
199  strcpy(train_file, argv[i+1]);
200 
201  if (debug_mode>0)
202  printf("train file: %s\n", train_file);
203 
204  f=fopen(train_file, "rb");
205  if (f==NULL) {
206  printf("ERROR: training data file not found!\n");
207  return 0;
208  }
209 
210  train_mode=1;
211 
212  train_file_set=1;
213  }
214 
215 
216  //set one-iter
217  i=argPos((char *)"-one-iter", argc, argv);
218  if (i>0) {
219  one_iter=1;
220 
221  if (debug_mode>0)
222  printf("Training for one iteration\n");
223  }
224 
225 
226  //search for validation file
227  i=argPos((char *)"-valid", argc, argv);
228  if (i>0) {
229  if (i+1==argc) {
230  printf("ERROR: validation data file not specified!\n");
231  return 0;
232  }
233 
234  strcpy(valid_file, argv[i+1]);
235 
236  if (debug_mode>0)
237  printf("valid file: %s\n", valid_file);
238 
239  f=fopen(valid_file, "rb");
240  if (f==NULL) {
241  printf("ERROR: validation data file not found!\n");
242  return 0;
243  }
244 
245  valid_data_set=1;
246  }
247 
248  if (train_mode && !valid_data_set) {
249  if (one_iter==0) {
250  printf("ERROR: validation data file must be specified for training!\n");
251  return 0;
252  }
253  }
254 
255 
256  //set nbest rescoring mode
257  i=argPos((char *)"-nbest", argc, argv);
258  if (i>0) {
259  nbest=1;
260  if (debug_mode>0)
261  printf("Processing test data as list of nbests\n");
262  }
263 
264 
265  //search for test file
266  i=argPos((char *)"-test", argc, argv);
267  if (i>0) {
268  if (i+1==argc) {
269  printf("ERROR: test data file not specified!\n");
270  return 0;
271  }
272 
273  strcpy(test_file, argv[i+1]);
274 
275  if (debug_mode>0)
276  printf("test file: %s\n", test_file);
277 
278 
279  if (nbest && (!strcmp(test_file, "-"))) ; else {
280  f=fopen(test_file, "rb");
281  if (f==NULL) {
282  printf("ERROR: test data file not found!\n");
283  return 0;
284  }
285  }
286 
287  test_data_set=1;
288  }
289 
290 
291  //set class size parameter
292  i=argPos((char *)"-class", argc, argv);
293  if (i>0) {
294  if (i+1==argc) {
295  printf("ERROR: amount of classes not specified!\n");
296  return 0;
297  }
298 
299  class_size=atoi(argv[i+1]);
300 
301  if (debug_mode>0)
302  printf("class size: %d\n", class_size);
303  }
304 
305 
306  //set old class
307  i=argPos((char *)"-old-classes", argc, argv);
308  if (i>0) {
309  old_classes=1;
310 
311  if (debug_mode>0)
312  printf("Old algorithm for computing classes will be used\n");
313  }
314 
315 
316  //set lambda
317  i=argPos((char *)"-lambda", argc, argv);
318  if (i>0) {
319  if (i+1==argc) {
320  printf("ERROR: lambda not specified!\n");
321  return 0;
322  }
323 
324  lambda=atof(argv[i+1]);
325 
326  if (debug_mode>0)
327  printf("Lambda (interpolation coefficient between rnnlm and other lm): %f\n", lambda);
328  }
329 
330 
331  //set gradient cutoff
332  i=argPos((char *)"-gradient-cutoff", argc, argv);
333  if (i>0) {
334  if (i+1==argc) {
335  printf("ERROR: gradient cutoff not specified!\n");
336  return 0;
337  }
338 
339  gradient_cutoff=atof(argv[i+1]);
340 
341  if (debug_mode>0)
342  printf("Gradient cutoff: %f\n", gradient_cutoff);
343  }
344 
345 
346  //set dynamic
347  i=argPos((char *)"-dynamic", argc, argv);
348  if (i>0) {
349  if (i+1==argc) {
350  printf("ERROR: dynamic learning rate not specified!\n");
351  return 0;
352  }
353 
354  dynamic=atof(argv[i+1]);
355 
356  if (debug_mode>0)
357  printf("Dynamic learning rate: %f\n", dynamic);
358  }
359 
360 
361  //set gen
362  i=argPos((char *)"-gen", argc, argv);
363  if (i>0) {
364  if (i+1==argc) {
365  printf("ERROR: gen parameter not specified!\n");
366  return 0;
367  }
368 
369  gen=atoi(argv[i+1]);
370 
371  if (debug_mode>0)
372  printf("Generating # words: %d\n", gen);
373  }
374 
375 
376  //set independent
377  i=argPos((char *)"-independent", argc, argv);
378  if (i>0) {
379  independent=1;
380 
381  if (debug_mode>0)
382  printf("Sentences will be processed independently...\n");
383  }
384 
385 
386  //set learning rate
387  i=argPos((char *)"-alpha", argc, argv);
388  if (i>0) {
389  if (i+1==argc) {
390  printf("ERROR: alpha not specified!\n");
391  return 0;
392  }
393 
394  starting_alpha=atof(argv[i+1]);
395 
396  if (debug_mode>0)
397  printf("Starting learning rate: %f\n", starting_alpha);
398  alpha_set=1;
399  }
400 
401 
402  //set regularization
403  i=argPos((char *)"-beta", argc, argv);
404  if (i>0) {
405  if (i+1==argc) {
406  printf("ERROR: beta not specified!\n");
407  return 0;
408  }
409 
410  regularization=atof(argv[i+1]);
411 
412  if (debug_mode>0)
413  printf("Regularization: %f\n", regularization);
414  }
415 
416 
417  //set min improvement
418  i=argPos((char *)"-min-improvement", argc, argv);
419  if (i>0) {
420  if (i+1==argc) {
421  printf("ERROR: minimal improvement value not specified!\n");
422  return 0;
423  }
424 
425  min_improvement=atof(argv[i+1]);
426 
427  if (debug_mode>0)
428  printf("Min improvement: %f\n", min_improvement);
429  }
430 
431 
432  //set anti kasparek
433  i=argPos((char *)"-anti-kasparek", argc, argv);
434  if (i>0) {
435  if (i+1==argc) {
436  printf("ERROR: anti-kasparek parameter not set!\n");
437  return 0;
438  }
439 
440  anti_k=atoi(argv[i+1]);
441 
442  if ((anti_k!=0) && (anti_k<10000)) anti_k=10000;
443 
444  if (debug_mode>0)
445  printf("Model will be saved after each # words: %d\n", anti_k);
446  }
447 
448 
449  //set hidden layer size
450  i=argPos((char *)"-hidden", argc, argv);
451  if (i>0) {
452  if (i+1==argc) {
453  printf("ERROR: hidden layer size not specified!\n");
454  return 0;
455  }
456 
457  hidden_size=atoi(argv[i+1]);
458 
459  if (debug_mode>0)
460  printf("Hidden layer size: %d\n", hidden_size);
461  }
462 
463 
464  //set compression layer size
465  i=argPos((char *)"-compression", argc, argv);
466  if (i>0) {
467  if (i+1==argc) {
468  printf("ERROR: compression layer size not specified!\n");
469  return 0;
470  }
471 
472  compression_size=atoi(argv[i+1]);
473 
474  if (debug_mode>0)
475  printf("Compression layer size: %d\n", compression_size);
476  }
477 
478 
479  //set direct connections
480  i=argPos((char *)"-direct", argc, argv);
481  if (i>0) {
482  if (i+1==argc) {
483  printf("ERROR: direct connections not specified!\n");
484  return 0;
485  }
486 
487  direct=atoi(argv[i+1]);
488 
489  direct*=1000000;
490  if (direct<0) direct=0;
491 
492  if (debug_mode>0)
493  printf("Direct connections: %dM\n", (int)(direct/1000000));
494  }
495 
496 
497  //set order of direct connections
498  i=argPos((char *)"-direct-order", argc, argv);
499  if (i>0) {
500  if (i+1==argc) {
501  printf("ERROR: direct order not specified!\n");
502  return 0;
503  }
504 
505  direct_order=atoi(argv[i+1]);
506  if (direct_order>MAX_NGRAM_ORDER) direct_order=MAX_NGRAM_ORDER;
507 
508  if (debug_mode>0)
509  printf("Order of direct connections: %d\n", direct_order);
510  }
511 
512 
513  //set bptt
514  i=argPos((char *)"-bptt", argc, argv);
515  if (i>0) {
516  if (i+1==argc) {
517  printf("ERROR: bptt value not specified!\n");
518  return 0;
519  }
520 
521  bptt=atoi(argv[i+1]);
522  bptt++;
523  if (bptt<1) bptt=1;
524 
525  if (debug_mode>0)
526  printf("BPTT: %d\n", bptt-1);
527  }
528 
529 
530  //set bptt block
531  i=argPos((char *)"-bptt-block", argc, argv);
532  if (i>0) {
533  if (i+1==argc) {
534  printf("ERROR: bptt block value not specified!\n");
535  return 0;
536  }
537 
538  bptt_block=atoi(argv[i+1]);
539  if (bptt_block<1) bptt_block=1;
540 
541  if (debug_mode>0)
542  printf("BPTT block: %d\n", bptt_block);
543  }
544 
545 
546  //set random seed
547  i=argPos((char *)"-rand-seed", argc, argv);
548  if (i>0) {
549  if (i+1==argc) {
550  printf("ERROR: Random seed variable not specified!\n");
551  return 0;
552  }
553 
554  rand_seed=atoi(argv[i+1]);
555 
556  if (debug_mode>0)
557  printf("Rand seed: %d\n", rand_seed);
558  }
559 
560 
561  //use other lm
562  i=argPos((char *)"-lm-prob", argc, argv);
563  if (i>0) {
564  if (i+1==argc) {
565  printf("ERROR: other lm file not specified!\n");
566  return 0;
567  }
568 
569  strcpy(lmprob_file, argv[i+1]);
570 
571  if (debug_mode>0)
572  printf("other lm probabilities specified in: %s\n", lmprob_file);
573 
574  f=fopen(lmprob_file, "rb");
575  if (f==NULL) {
576  printf("ERROR: other lm file not found!\n");
577  return 0;
578  }
579 
580  use_lmprob=1;
581  }
582 
583 
584  //search for binary option
585  i=argPos((char *)"-binary", argc, argv);
586  if (i>0) {
587  if (debug_mode>0)
588  printf("Model will be saved in binary format\n");
589 
590  fileformat=BINARY;
591  }
592 
593 
594  //search for rnnlm file
595  i=argPos((char *)"-rnnlm", argc, argv);
596  if (i>0) {
597  if (i+1==argc) {
598  printf("ERROR: model file not specified!\n");
599  return 0;
600  }
601 
602  strcpy(rnnlm_file, argv[i+1]);
603 
604  if (debug_mode>0)
605  printf("rnnlm file: %s\n", rnnlm_file);
606 
607  f=fopen(rnnlm_file, "rb");
608 
609  rnnlm_file_set=1;
610  }
611  if (train_mode && !rnnlm_file_set) {
612  printf("ERROR: rnnlm file must be specified for training!\n");
613  return 0;
614  }
615  if (test_data_set && !rnnlm_file_set) {
616  printf("ERROR: rnnlm file must be specified for testing!\n");
617  return 0;
618  }
619  if (!test_data_set && !train_mode && gen==0) {
620  printf("ERROR: training or testing must be specified!\n");
621  return 0;
622  }
623  if ((gen>0) && !rnnlm_file_set) {
624  printf("ERROR: rnnlm file must be specified to generate words!\n");
625  return 0;
626  }
627 
628 
629  srand(1);
630 
631  if (train_mode) {
632  CRnnLM model1;
633 
634  model1.setTrainFile(train_file);
635  model1.setRnnLMFile(rnnlm_file);
636  model1.setFileType(fileformat);
637 
638  model1.setOneIter(one_iter);
639  if (one_iter==0) model1.setValidFile(valid_file);
640 
641  model1.setClassSize(class_size);
642  model1.setOldClasses(old_classes);
643  model1.setLearningRate(starting_alpha);
644  model1.setGradientCutoff(gradient_cutoff);
645  model1.setRegularization(regularization);
646  model1.setMinImprovement(min_improvement);
647  model1.setHiddenLayerSize(hidden_size);
648  model1.setCompressionLayerSize(compression_size);
649  model1.setDirectSize(direct);
650  model1.setDirectOrder(direct_order);
651  model1.setBPTT(bptt);
652  model1.setBPTTBlock(bptt_block);
653  model1.setRandSeed(rand_seed);
654  model1.setDebugMode(debug_mode);
655  model1.setAntiKasparek(anti_k);
656  model1.setIndependent(independent);
657 
658  model1.alpha_set=alpha_set;
659  model1.train_file_set=train_file_set;
660 
661  model1.trainNet();
662  }
663 
664  if (test_data_set && rnnlm_file_set) {
665  CRnnLM model1;
666 
667  model1.setLambda(lambda);
668  model1.setRegularization(regularization);
669  model1.setDynamic(dynamic);
670  model1.setTestFile(test_file);
671  model1.setRnnLMFile(rnnlm_file);
672  model1.setRandSeed(rand_seed);
673  model1.useLMProb(use_lmprob);
674  if (use_lmprob) model1.setLMProbFile(lmprob_file);
675  model1.setDebugMode(debug_mode);
676 
677  if (nbest==0) model1.testNet();
678  else model1.testNbest();
679  }
680 
681  if (gen>0) {
682  CRnnLM model1;
683 
684  model1.setRnnLMFile(rnnlm_file);
685  model1.setDebugMode(debug_mode);
686  model1.setRandSeed(rand_seed);
687  model1.setGen(gen);
688 
689  model1.testGen();
690  }
691 
692 
693  return 0;
694 }
// Doxygen cross-reference residue from the generated listing:
//   int main(int argc, char **argv)        — Definition: rnnlm.cc:29
//   int argPos(char *str, int argc, char **argv) — Definition: rnnlm.cc:20