1 module wagomu;
2 
3 /*
4 * Copyright (C) 2009 The Tegaki project contributors
5 *
6 * Ported to D by webfreak
7 *
8 * This program is free software; you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation; either version 2 of the License, or
11 * (at your option) any later version.
12 *
13 * This program is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License along
19 * with this program; if not, write to the Free Software Foundation, Inc.,
20 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 */
22 
23 import std.algorithm;
24 import std.file;
25 import std.math;
26 import std.range;
27 
28 @safe:
29 
/// Value the first header word of a model file must hold.
enum int MAGIC_NUMBER = 0x7777_8888;
/// Number of float components stored per point vector.
enum int VEC_DIM_MAX = 4;
32 
/// One entry of the model's character table.
///
/// `Recognizer.load` reinterprets bytes of the model file as an array of
/// this struct, so the field order and sizes must match the file layout
/// exactly; do not reorder or retype the fields.
struct CharacterInfo
{
	/// Code point of the character this template represents.
	dchar unicode;
	/// Number of point vectors in this character's template (each vector
	/// occupies VEC_DIM_MAX floats in the model data).
	uint n_vectors;
}
38 
/// One entry of the model's group table. Each group covers the next
/// `n_chars` entries of the character table, in table order.
///
/// `Recognizer.load` reinterprets bytes of the model file as an array of
/// this struct, so the field order and sizes must match the file layout
/// exactly; do not reorder or retype the fields.
struct CharacterGroup
{
	/// Stroke count of the group's templates; `Recognizer.recognize`
	/// compares it with the input's stroke count.
	uint n_strokes;
	/// Number of consecutive character-table entries in this group.
	uint n_chars;
	/// Byte offset into the model file where the template vectors of this
	/// group's characters start (stored back to back).
	uint offset;
	/// Not read by this module; presumably padding so the record matches the
	/// on-disk record size — TODO confirm against the model writer.
	void[4] pad;
}
46 
/// A recognition candidate: a code point paired with its DTW distance.
struct CharDist
{
	dchar unicode;         /// Code point of the candidate character.
	float distance = 0.0f; /// Distance to the input; smaller means closer. Starts at 0, not NaN.
}
52 
/// An input sample: a list of VEC_DIM_MAX-wide point vectors plus the number
/// of strokes it was drawn with.
struct Character
{
	/**
	 * Allocates storage for the input vectors.
	 *
	 * Params:
	 *   numVectors = number of point vectors to allocate; 0 leaves `points` null
	 *   numStrokes = stroke count of the sample
	 *
	 * Freshly allocated vectors hold float.init (NaN), so every vector must be
	 * filled in before the character is recognized.
	 */
	this(uint numVectors, uint numStrokes)
	{
		if (numVectors != 0)
		{
			points = new float[VEC_DIM_MAX][](numVectors);
		}
		this.numStrokes = numStrokes;
	}

	/// The sample's point vectors.
	float[VEC_DIM_MAX][] points;
	/// The sample's stroke count.
	uint numStrokes;
}
66 
/**
 * Handwriting recognizer backed by a Wagomu model file.
 *
 * `load` reads a model into `data` and exposes its character and group
 * tables as slices into that buffer; `recognize` then ranks the model's
 * templates against an input `Character` by dynamic time warping distance.
 */
struct Recognizer
{
	/// Stroke-count tolerance: `recognize` only compares templates whose group
	/// stroke count is within +/- windowSize of the input's stroke count.
	int windowSize = 3;
	/// Raw model file contents; `characters`, `groups` and the template
	/// vectors are views into this buffer.
	ubyte[] data;
	/// Header fields of the loaded model.
	uint numCharacters, numGroups, dimension, downsampleThreshold;
	/// Character table of the loaded model (a view into `data`).
	CharacterInfo[] characters;
	/// Group table of the loaded model (a view into `data`).
	CharacterGroup[] groups;
	/// Scratch buffer of per-template distances, one slot per character.
	CharDist[] distm;
	/// The two DTW columns reused by every `dtw` call.
	float[] dtw1, dtw2;

	/// Size in bytes of one stored point vector.
	private enum vectorBytes = VEC_DIM_MAX * float.sizeof;
	/// Size in bytes of the file header (five uints).
	private enum headerBytes = 5 * uint.sizeof;

	/**
	 * Loads a model file and prepares the DTW work buffers.
	 *
	 * The file starts with five uints (magic number, character count, group
	 * count, vector dimension, downsample threshold), followed by the
	 * character table and then the group table.
	 *
	 * Throws: Exception if the file is truncated, has the wrong magic number,
	 * has no characters or groups, or declares a vector dimension larger than
	 * VEC_DIM_MAX; std.file.FileException if the file cannot be read.
	 */
	void load(string path) @safe
	{
		data = (() @trusted => cast(ubyte[]) read(path))();

		if (data.length < headerBytes)
			throw new Exception("Not a valid file");

		uint[] header = (() @trusted => (cast(uint*) data.ptr)[0 .. 5])();
		if (header[0] != MAGIC_NUMBER)
			throw new Exception("Not a valid file");

		numCharacters = header[1];
		numGroups = header[2];
		dimension = header[3];
		downsampleThreshold = header[4];

		if (numCharacters == 0 || numGroups == 0)
			throw new Exception("No characters in this model");

		// localDistance indexes fixed-size VEC_DIM_MAX vectors with
		// 0 .. dimension, so a larger dimension cannot be honoured.
		if (dimension > VEC_DIM_MAX)
			throw new Exception("Unsupported vector dimension");

		// Sizes are computed in 64 bits so a hostile header cannot overflow
		// this bounds check on 32-bit targets.
		const ulong charsEnd = headerBytes + cast(ulong) numCharacters * CharacterInfo.sizeof;
		const ulong groupsEnd = charsEnd + cast(ulong) numGroups * CharacterGroup.sizeof;
		if (data.length < groupsEnd)
			throw new Exception("Not a valid file");

		(() @trusted{
			characters = (cast(CharacterInfo*)(data.ptr + headerBytes))[0 .. numCharacters];
			groups = (cast(CharacterGroup*)(data.ptr + cast(size_t) charsEnd))[0 .. numGroups];
		})();

		distm = new CharDist[numCharacters];

		const maxnvec = maxNumVectors;

		dtw1 = new float[cast(size_t) maxnvec * VEC_DIM_MAX];
		dtw2 = new float[cast(size_t) maxnvec * VEC_DIM_MAX];
		dtw1[] = 0;
		dtw2[] = 0;
	}

	/// Returns: the largest `n_vectors` in the character table (0 if empty).
	uint maxNumVectors() const @safe
	{
		uint result;
		foreach (const ref info; characters)
			result = max(result, info.n_vectors);
		return result;
	}

	/* The euclidean distance is replaced by the sum of absolute
	   differences for performance reasons. Only the first `dimension`
	   components are compared; `load` rejects models whose dimension
	   exceeds VEC_DIM_MAX. */
	float localDistance(float[VEC_DIM_MAX] v1, float[VEC_DIM_MAX] v2) const
	{
		float sum = 0;
		foreach (i; 0 .. dimension)
			sum += abs(v2[i] - v1[i]);
		return sum;
	}

	/**
	 * Dynamic time warping distance between the input `s` (n vectors) and a
	 * template `t` (m vectors).
	 *
	 * Conceptually an n*m matrix is filled in, with
	 *
	 *     dtw(i, j) = localDistance(s[i], t[j])
	 *                 + min(dtw(i-1, j-1), dtw(i-1, j), dtw(i, j-1))
	 *
	 * Column 0 and row 0 are set to "infinity" (`float.max`), except the
	 * corner cell (0, 0), which is 0. Index 0 of each sequence therefore only
	 * anchors the recurrence and is never passed to localDistance. The cell
	 * (n-1, m-1) is the result.
	 *
	 * Only the previous column (`dtw1`) and the current column (`dtw2`), both
	 * indexed by j, are needed at any time:
	 *
	 *     dtw2[j] = localDistance(s[i], t[j]) + min(dtw2[j-1], dtw1[j], dtw1[j-1])
	 *
	 * Returns: the accumulated distance, or `float.max` when `t` is empty.
	 */
	float dtw(in float[VEC_DIM_MAX][] s, in float[VEC_DIM_MAX][] t)
	{
		if (t.length == 0)
			return float.max;

		// Only cells 0 .. t.length-1 are ever read, so reset just those.
		dtw1[0 .. t.length] = float.max;
		dtw1[0] = 0;
		dtw2[0] = float.max;

		foreach (i; 1 .. s.length)
		{
			foreach (j; 1 .. t.length)
			{
				const cost = localDistance(s[i], t[j]);
				dtw2[j] = cost + min(dtw2[j - 1], dtw1[j], dtw1[j - 1]);
			}

			// The current column becomes the previous one for the next i.
			swap(dtw1, dtw2);
			dtw2[0] = float.max;
		}

		return dtw1[t.length - 1];
	}

	/// Reinterprets raw template bytes as point vectors; trailing bytes that
	/// do not form a whole vector are ignored. Callers obtain `bytes` through
	/// a bounds-checked slice of `data`, so the result never extends past it.
	/// NOTE(review): assumes template offsets in the model are float-aligned.
	private static const(float[VEC_DIM_MAX])[] asVectors(const(ubyte)[] bytes) @trusted
	{
		return (cast(const(float[VEC_DIM_MAX])*) bytes.ptr)[0 .. bytes.length / vectorBytes];
	}

	/**
	 * Compares `ch` with the templates whose group stroke count lies in
	 * [ch.numStrokes - windowSize, ch.numStrokes + windowSize] and returns
	 * the closest ones.
	 *
	 * The early `break` assumes `groups` is ordered by ascending `n_strokes`.
	 * Returns an empty array when no model is loaded (`groups` is empty).
	 *
	 * Params:
	 *   ch = input character
	 *   maxResults = maximum number of candidates to return
	 * Returns: up to `maxResults` candidates sorted by ascending distance.
	 */
	CharDist[] recognize(in ref Character ch, uint maxResults)
	{
		const numStrokes = ch.numStrokes;
		const input = ch.points;

		uint numChars, charID;

		foreach (ref group; groups)
		{
			// Past the upper end of the stroke window: no later group can match.
			if (group.n_strokes > (numStrokes + windowSize))
				break;
			// Below the lower end of the window. The `numStrokes > windowSize`
			// guard keeps the unsigned subtraction from wrapping around.
			if (numStrokes > windowSize && group.n_strokes < (numStrokes - windowSize))
			{
				charID += group.n_chars;
				continue;
			}

			// The group's templates are stored back to back from `offset`.
			size_t pos = group.offset;
			foreach (_; 0 .. group.n_chars)
			{
				const info = characters[charID];
				const end = pos + info.n_vectors * vectorBytes;
				// Bounds-checked: a malformed offset raises a RangeError
				// instead of reading outside the model buffer.
				const bytes = data[pos .. end];

				distm[numChars] = CharDist(info.unicode, dtw(input, asVectors(bytes)));

				pos = end;
				charID++;
				numChars++;
			}
		}

		auto scored = distm[0 .. numChars];
		scored.sort!((a, b) => charDistCmp(a, b) < 0);
		// Copy out so later calls, which reuse distm, do not alter the result.
		return scored[0 .. min(numChars, maxResults)].dup;
	}
}
230 
/**
 * Three-way comparison of two candidates by distance.
 *
 * Returns: -1 if `a` is closer than `b`, 1 if it is farther, otherwise 0
 * (including when either distance is NaN, since both comparisons are false).
 */
int charDistCmp(in CharDist a, in CharDist b)
{
	const bool farther = a.distance > b.distance;
	const bool closer = a.distance < b.distance;
	return farther - closer;
}
239 
///
unittest
{
	enum modelPath = "/usr/share/tegaki/models/wagomu/joyo-kanji.model";
	// The model is an external system file; skip the example when it is not
	// installed instead of failing the whole test suite.
	if (!exists(modelPath))
		return;

	Recognizer r = Recognizer(2);
	r.load(modelPath);
	// this model is on a 1000x1000 canvas

	// A single horizontal stroke of 10 points. Allocate exactly as many
	// vectors as are filled in: unfilled vectors would keep float.init (NaN)
	// and poison every DTW distance.
	Character ch = Character(10, 1);
	for (int x = 0; x < 10; x++)
		ch.points[x] = [x * 80 + 100, 500, 0, 0];
	auto res = r.recognize(ch, 5);

	assert(res.length > 0);
	assert(res[0].unicode == '一');
}