From 93eb84aca23526959b76401fd6509f151a589e9a Mon Sep 17 00:00:00 2001
From: Determinant <ted.sybil@gmail.com>
Date: Sun, 13 Mar 2016 16:18:36 +0800
Subject: add TNet tutorial; support converting global transf from TNet format

---
 kaldi_io/tools/kaldi_to_nerv.cpp | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

(limited to 'kaldi_io')

diff --git a/kaldi_io/tools/kaldi_to_nerv.cpp b/kaldi_io/tools/kaldi_to_nerv.cpp
index f16de44..aadac53 100644
--- a/kaldi_io/tools/kaldi_to_nerv.cpp
+++ b/kaldi_io/tools/kaldi_to_nerv.cpp
@@ -4,6 +4,7 @@
 #include <cstring>
 #include <cassert>
 #include <cstdlib>
+#include <map>
 
 char token[1024];
 char output[1024];
@@ -23,6 +24,18 @@ void free_matrix(double **mat, int nrow, int ncol) {
     delete [] mat;
 }
 
+int cnt0;
+std::map<std::string, int> param_cnt;
+int get_param_cnt(const std::string &key) {
+    std::map<std::string, int>::iterator it = param_cnt.find(key);
+    if (it == param_cnt.end())
+    {
+        param_cnt[key] = cnt0 + 1;
+        return cnt0;
+    }
+    return it->second++;
+}
+
 int main(int argc, char **argv) {
     FILE *fin;
     std::ofstream fout;
@@ -30,13 +43,14 @@ int main(int argc, char **argv) {
     fin = fopen(argv[1], "r");
     fout.open(argv[2]);
     assert(fin != NULL);
-    int cnt = argc > 3 ? atoi(argv[3]) : 0;
+    cnt0 = argc > 3 ? atoi(argv[3]) : 0;
     bool shift;
     while (fscanf(fin, "%s", token) != EOF)
     {
         int nrow, ncol;
         int i, j;
         double **mat;
+        int cnt = get_param_cnt(token);
         if (strcmp(token, "<AffineTransform>") == 0)
         {
             double lrate, blrate, mnorm;
@@ -91,7 +105,6 @@ int main(int argc, char **argv) {
                 sprintf(output, "[%13lu]\n", length);
                 fout << output;
                 fout.seekp(0, std::ios_base::end);
-                cnt++;
             }
             free_matrix(mat, nrow, ncol);
         }
-- 
cgit v1.2.3-70-g09d2