module wagomu;

/*
 * Copyright (C) 2009 The Tegaki project contributors
 *
 * Ported to D by webfreak
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */

import std.algorithm;
import std.file;
import std.math;
import std.range;

@safe:

/// Value the first header word of a model file must hold.
enum MAGIC_NUMBER = 0x77778888;
/// Number of float slots stored per vector; only the first
/// `Recognizer.dimension` slots take part in distance computations.
enum VEC_DIM_MAX = 4;

/// One entry of the model's character table.
struct CharacterInfo
{
    /// Character this template represents.
    dchar unicode;
    /// Number of vectors in this character's template.
    uint n_vectors;
}

/// One entry of the model's group table: a run of templates that share
/// a stroke count.
struct CharacterGroup
{
    /// Stroke count of every character in the group.
    uint n_strokes;
    /// Number of characters (templates) in the group.
    uint n_chars;
    /// Byte offset, from the start of the model data, of the group's
    /// first template.
    uint offset;
    /// Padding; not read by this module.
    void[4] pad;
}

/// A recognition candidate and its DTW distance to the input
/// (lower means a better match).
struct CharDist
{
    dchar unicode;
    float distance = 0;
}

/// An input character to be recognized.
struct Character
{
    /**
     * Allocates room for `numVectors` input vectors.
     *
     * The vectors start out as `float.nan` (the default float value), so
     * every allocated vector should be filled before recognition.
     */
    this(uint numVectors, uint numStrokes)
    {
        this.numStrokes = numStrokes;

        if (numVectors)
            points = new float[VEC_DIM_MAX][numVectors];
    }

    /// Input vectors, in drawing order.
    float[VEC_DIM_MAX][] points;
    /// Number of strokes the input was drawn with.
    uint numStrokes;
}

/// Template-matching handwriting recognizer backed by a wagomu model file.
struct Recognizer
{
    /// Only groups whose stroke count lies within +-windowSize of the
    /// input's stroke count are compared with the input.
    int windowSize = 3;
    /// Raw model file contents; `characters`, `groups` and the templates
    /// all point into this buffer.
    ubyte[] data;
    /// Header values read by `load`. `downsampleThreshold` is stored but
    /// not used by this module.
    uint numCharacters, numGroups, dimension, downsampleThreshold;
    /// Character table (view into `data`).
    CharacterInfo[] characters;
    /// Group table (view into `data`).
    CharacterGroup[] groups;
    /// Scratch buffer holding one distance per evaluated character.
    CharDist[] distm;
    /// The two DTW columns used by `dtw`.
    float[] dtw1, dtw2;

    /**
     * Loads a model file.
     *
     * Layout read here: five `uint` header words (magic number, number of
     * characters, number of groups, vector dimension, downsample threshold),
     * then `numCharacters` `CharacterInfo` entries, then `numGroups`
     * `CharacterGroup` entries. Templates are found through
     * `CharacterGroup.offset`.
     *
     * Throws: `Exception` if the file is not a valid model or contains no
     * characters, and whatever `std.file.read` throws if it cannot be read.
     */
    void load(string path) @safe
    {
        data = (() @trusted => cast(ubyte[]) read(path))();

        enum headerSize = 5 * uint.sizeof;
        if (data.length < headerSize)
            throw new Exception("Not a valid file");

        uint[] header = (() @trusted => (cast(uint*) data.ptr)[0 .. 5])();
        if (header[0] != MAGIC_NUMBER)
            throw new Exception("Not a valid file");

        numCharacters = header[1];
        numGroups = header[2];
        dimension = header[3];
        downsampleThreshold = header[4];

        if (numCharacters == 0 || numGroups == 0)
            throw new Exception("No characters in this model");

        // Check that both tables fit, dividing instead of multiplying so
        // huge counts cannot overflow the arithmetic.
        size_t remaining = data.length - headerSize;
        if (numCharacters > remaining / CharacterInfo.sizeof)
            throw new Exception("Not a valid file");
        remaining -= numCharacters * CharacterInfo.sizeof;
        if (numGroups > remaining / CharacterGroup.sizeof)
            throw new Exception("Not a valid file");

        // localDistance reads `dimension` slots of VEC_DIM_MAX-sized vectors.
        if (dimension > VEC_DIM_MAX)
            throw new Exception("Not a valid file");

        (() @trusted {
            auto base = data.ptr + headerSize;
            characters = (cast(CharacterInfo*) base)[0 .. numCharacters];
            groups = (cast(CharacterGroup*)(base + numCharacters * CharacterInfo.sizeof))[0 .. numGroups];
        })();

        distm = new CharDist[numCharacters];

        const size_t slots = cast(size_t) maxNumVectors * VEC_DIM_MAX;
        dtw1 = new float[slots];
        dtw2 = new float[slots];
        dtw1[] = 0;
        dtw2[] = 0;
    }

    /// Returns: the largest `n_vectors` in the character table (0 if empty).
    uint maxNumVectors() const @safe
    {
        uint result;
        foreach (ref ch; characters)
            if (ch.n_vectors > result)
                result = ch.n_vectors;
        return result;
    }

    /* The euclidean distance is replaced by the sum of absolute
       differences for performance reasons... */
    /// Returns: the sum of absolute differences over the first `dimension`
    /// components of `v1` and `v2`.
    float localDistance(float[VEC_DIM_MAX] v1, float[VEC_DIM_MAX] v2) const
    {
        float sum = 0;
        foreach (i; 0 .. dimension)
            sum += abs(v2[i] - v1[i]);
        return sum;
    }

    /**
     * Dynamic time warping distance between input `s` and template `t`.
     *
     *     m [X][ ][ ][ ][ ][r]
     *       [X][ ][ ][ ][ ][ ]
     *       [X][ ][ ][ ][ ][ ]
     *       [X][ ][ ][ ][ ][ ]
     *       [0][X][X][X][X][X]
     *                        n
     *
     * Each cell in the n*m matrix is defined as follows:
     *
     *     dtw(i,j) = local_distance(i,j) + MIN3(dtw(i-1,j-1), dtw(i-1,j), dtw(i,j-1))
     *
     * Cells marked with an X are set to infinity (`float.max`), the
     * bottom-left cell is set to 0 and the top-right cell is the result.
     *
     * Only two columns of the matrix are needed at any time, so the arrays
     * `dtw1` (previous column) and `dtw2` (current column) are used:
     *
     *     [   ]   [   ]
     *     [ j ]   [ j ]
     *     [j-1]   [j-1]
     *     [   ]   [   ]
     *     [ X ]   [ X ]
     *      dtw1    dtw2
     *
     *     dtw2(j) = local_distance(i,j) + MIN3(dtw2(j-1), dtw1(j), dtw1(j-1))
     *
     * `dtw1` and `dtw2` must hold at least `t.length` elements.
     *
     * Returns: the accumulated distance, or `float.max` when `t` is empty.
     */
    float dtw(in float[VEC_DIM_MAX][] s, in float[VEC_DIM_MAX][] t)
    {
        // An empty template cannot be aligned with anything; indexing
        // dtw1[t.length - 1] below would underflow.
        if (t.length == 0)
            return float.max;

        // Only the first t.length cells of a column are ever read.
        dtw1[0 .. t.length] = float.max;
        dtw1[0] = 0;
        dtw2[0] = float.max;

        foreach (i; 1 .. s.length)
        {
            foreach (j; 1 .. t.length)
            {
                const cost = localDistance(s[i], t[j]);
                dtw2[j] = cost + min(dtw2[j - 1], dtw1[j], dtw1[j - 1]);
            }

            // The column just computed becomes the previous column.
            swap(dtw1, dtw2);
            dtw2[0] = float.max;
        }

        return dtw1[t.length - 1];
    }

    /**
     * Compares `ch` with every template whose group stroke count lies in
     * [ch.numStrokes - windowSize, ch.numStrokes + windowSize].
     *
     * Groups are visited in table order and the scan stops at the first
     * group above the window; this assumes groups are stored in ascending
     * stroke order (NOTE(review): confirm against the model writer).
     *
     * Params:
     *   ch = input character
     *   maxResults = maximum number of candidates to return
     * Returns: up to `maxResults` candidates, best (smallest distance) first.
     * Throws: `Exception` if the loaded model's tables or template offsets
     * are inconsistent with the model data.
     */
    CharDist[] recognize(in ref Character ch, uint maxResults)
    {
        enum size_t vectorBytes = VEC_DIM_MAX * float.sizeof;

        const numStrokes = ch.numStrokes;
        const input = ch.points;

        size_t numChars, charID;

        foreach (ref group; groups)
        {
            if (group.n_strokes > numStrokes + windowSize)
                break;
            // Skip groups with too few strokes. The `numStrokes > windowSize`
            // guard keeps the unsigned subtraction from wrapping around.
            if (numStrokes > windowSize && group.n_strokes < numStrokes - windowSize)
            {
                charID += group.n_chars;
                continue;
            }

            size_t pos = group.offset;
            foreach (_; 0 .. group.n_chars)
            {
                // Validate everything taken from the file before touching
                // memory, so the @trusted cast below cannot read out of bounds.
                if (charID >= characters.length || numChars >= distm.length)
                    throw new Exception("Not a valid file");
                const nvec = characters[charID].n_vectors;
                const size_t nbytes = vectorBytes * nvec;
                if (pos > data.length || data.length - pos < nbytes)
                    throw new Exception("Not a valid file");

                const(float[VEC_DIM_MAX])[] templ = (() @trusted =>
                    (cast(const(float[VEC_DIM_MAX])*)(data.ptr + pos))[0 .. nvec])();

                distm[numChars] = CharDist(characters[charID].unicode, dtw(input, templ));
                pos += nbytes;
                charID++;
                numChars++;
            }
        }

        auto candidates = distm[0 .. numChars];
        candidates.sort!((a, b) => charDistCmp(a, b) < 0);
        return candidates[0 .. min(numChars, maxResults)].dup;
    }
}

/// Three-way comparison of two candidates by distance.
/// Returns: -1, 1 or 0 as `a.distance` is smaller than, greater than, or
/// neither smaller nor greater than `b.distance`.
int charDistCmp(in CharDist a, in CharDist b)
{
    if (a.distance < b.distance)
        return -1;
    if (a.distance > b.distance)
        return 1;
    return 0;
}

///
unittest
{
    // This model is on a 1000x1000 canvas.
    enum modelPath = "/usr/share/tegaki/models/wagomu/joyo-kanji.model";
    // The model is an external system file; skip the example when absent.
    if (!exists(modelPath))
        return;

    Recognizer r = Recognizer(2);
    r.load(modelPath);

    // One horizontal stroke of 10 points. The character is sized to exactly
    // the points that are set, since unset points would stay float.nan.
    Character ch = Character(10, 1);
    foreach (x; 0 .. 10)
        ch.points[x] = [x * 80 + 100, 500, 0, 0];
    auto res = r.recognize(ch, 5);

    assert(res.length > 0);
    assert(res[0].unicode == '一');
}