1 module more.uri;
2 
3 import more.format : toHexLower, toHexUpper, formatEscapeByPolicy;
4 import more.parse : hexValue;
5 
6 version(unittest)
7 {
8     import std.stdio : stdout;
9     import more.test;
10 }
11 
// Scheme syntax per RFC 2396, section 3.1: http://www.ietf.org/rfc/rfc2396.txt

/// Returns: true if c may be the first character of a URI scheme,
///          i.e. an ASCII letter.
bool isValidSchemeFirstChar(char c)
{
    switch (c)
    {
        case 'a': .. case 'z':
        case 'A': .. case 'Z':
            return true;
        default:
            return false;
    }
}
/// Returns: true if c may appear after the first character of a URI scheme:
///          an ASCII letter, a digit, '+', '-' or '.'.
bool isValidSchemeChar(char c)
{
    switch (c)
    {
        case 'a': .. case 'z':
        case 'A': .. case 'Z':
        case '0': .. case '9':
        case '+', '-', '.':
            return true;
        default:
            return false;
    }
}
/** Finds the scheme prefix of a URI (RFC 2396, section 3.1).

    A scheme is an ASCII letter followed by any number of letters, digits,
    '+', '-' or '.', and is terminated by a ':'.

    Returns: 0 if no scheme is found, otherwise the length of the scheme
             including the colon character */
uint parseScheme(const(char)[] uri)
{
    // A scheme must start with a letter.
    if (uri.length == 0 || !isValidSchemeFirstChar(uri[0]))
        return 0;

    foreach (pos; 1 .. uri.length)
    {
        const ch = uri[pos];
        if (ch == ':')
            return cast(uint)(pos + 1); // the length includes the ':'
        if (!isValidSchemeChar(ch))
            return 0; // an invalid character appeared before any ':'
    }
    return 0; // ran out of input without finding a ':'
}
unittest
{
    mixin(scopedTest!"uri - parseScheme");

    // No scheme: empty input, no colon, or nothing before the colon.
    assert(0 == parseScheme(null));
    assert(0 == parseScheme(""));
    assert(0 == parseScheme("a"));
    assert(0 == parseScheme("abc"));
    assert(0 == parseScheme("/"));
    assert(0 == parseScheme(":"));

    // Minimal and typical schemes; the result includes the ':'.
    assert(2 == parseScheme("a:"));
    assert(2 == parseScheme("A:"));
    assert(4 == parseScheme("abc:"));
    assert(5 == parseScheme("http://example.com"));
    assert(2 == parseScheme("a:b:")); // stops at the first ':'

    // The first character must be a letter.
    assert(0 == parseScheme("0bc:"));
    assert(0 == parseScheme("-bc:"));
    assert(0 == parseScheme("+bc:"));
    assert(0 == parseScheme(".bc:"));
    assert(0 == parseScheme("_bc:"));

    // Digits, '+', '-' and '.' are allowed after the first character.
    assert(5 == parseScheme("a0bc:"));
    assert(5 == parseScheme("a-bc:"));
    assert(5 == parseScheme("a+bc:"));
    assert(5 == parseScheme("a.bc:"));
    assert(7 == parseScheme("Zz9+-.:"));

    // Any other character before the ':' means there is no scheme.
    assert(0 == parseScheme("a_bc:"));
    assert(0 == parseScheme("ab/c:"));
    assert(0 == parseScheme("a b:"));
}
75 
/// Returns: true if c can appear in a URI without being escaped.
///          The accepted set is ASCII letters, digits, '-', '.', '_' and '~'
///          (this matches the RFC 3986 "unreserved" characters).
bool isValidUriChar(char c) pure
{
    switch (c)
    {
        case 'a': .. case 'z':
        case 'A': .. case 'Z':
        case '0': .. case '9':
        case '-', '.', '_', '~':
            return true;
        default:
            return false;
    }
}
/**
Returns a value that writes str in URI-encoded form when it is formatted
(for example with `format("%s", formatUriEncoded(str))`).

A space is written as '+'. Characters accepted by isValidUriChar are written
unchanged. Every other byte is written as a '%XX' escape with uppercase hex
digits. The escaping is driven by more.format.formatEscapeByPolicy through the
Hooks struct below.
*/
auto formatUriEncoded(const(char)[] str)
{
    static struct Hooks
    {
        // Longest escape produced by escapeCheck: "%XX".
        enum escapeBufferLength = 3;

        static void initEscapeBuffer(char* escapeBuffer) pure
        {
            // Nothing to pre-fill; escapeCheck writes every byte it uses.
        }

        // Writes the escape for charToCheck (if any) into escapeBuffer and
        // returns how many characters were written; 0 means "emit as-is".
        static auto escapeCheck(char* escapeBuffer, char charToCheck) pure
        {
            if (isValidUriChar(charToCheck))
                return 0;

            if (charToCheck == ' ')
            {
                escapeBuffer[0] = '+';
                return 1;
            }

            const ubyte code = cast(ubyte)charToCheck;
            escapeBuffer[0] = '%';
            escapeBuffer[1] = toHexUpper(code >> 4);
            escapeBuffer[2] = toHexUpper(code & 0x0F);
            return 3;
        }
    }
    return formatEscapeByPolicy!Hooks(str);
}
unittest
{
    mixin(scopedTest!"uri - formatUriEncoded");

    import std.format : format;

    // Renders the encoder's output so it can be compared with a string.
    static string render(const(char)[] input)
    {
        return format("%s", formatUriEncoded(input));
    }

    assert(render(``) == ``);
    assert(render(`a`) == `a`);
    assert(render(`abcd`) == `abcd`);
    assert(render(`abcd efgh`) == `abcd+efgh`);
    assert(render("\0\n!/") == `%00%0A%21%2F`);

    // Check every possible single-character input.
    foreach (int code; char.min .. char.max + 1)
    {
        char[1] input = cast(char)code;
        const output = render(input[]);

        if (input[0] == ' ')
        {
            assert(output == "+");
        }
        else if (isValidUriChar(input[0]))
        {
            assert(output == input[]);
        }
        else
        {
            char[3] escaped;
            escaped[0] = '%';
            escaped[1] = toHexUpper((cast(ubyte)code) >> 4);
            escaped[2] = toHexUpper((cast(ubyte)code) & 0x0F);
            assert(output == escaped[]);
        }
    }
}
158 
// TODO: use a function defined in a more common module
/// Returns: true if needle occurs anywhere in haystack (linear search).
private bool contains(T)(const(T)[] haystack, const(T) needle)
{
    auto remaining = haystack;
    while (remaining.length != 0)
    {
        if (remaining[0] == needle)
            return true;
        remaining = remaining[1 .. $];
    }
    return false;
}
169 
/**
Copies an invalid URI escape into dst for error reporting.

bad points to the '%' that starts the bad URI encoding. Characters are copied
from bad to dst until a null character has been copied or max characters have
been copied. In the latter case a terminating null is stored at dst[max], so
dst needs room for max + 1 characters.
*/
void copyBadUriEncoding(char* dst, const(char)* bad, size_t max)
{
    size_t copied = 0;
    while (copied < max)
    {
        dst[copied] = bad[copied];
        if (dst[copied] == '\0')
            return; // copied the source's own terminator
        copied++;
    }
    dst[copied] = '\0';
}
185 
/**
Decodes a null-terminated, URI-encoded string from src into dst.

'+' decodes to a space and '%XX' decodes to the byte with hex value XX.
Decoding ends at the first source character that is listed in terminatingChars.
It also always ends at the first null character in src: src carries no length,
so its null terminator is the only reliable end of input. Without that check,
a terminatingChars set such as "&" would let decoding run off the end of the
string.

'+' and '%' are always treated as encoded characters, so listing them in
terminatingChars has no effect.

Returns: On success, a pointer to the terminating null character written to
the dst buffer.

If the src buffer contains an invalid '%XX' sequence, this function stops
decoding at that point. It copies the invalid sequence (along with a terminating
null) to the dst buffer and returns a pointer to the start of that copy, which
begins with '%'. Callers can tell the cases apart by checking whether the
returned pointer points at a null character.

Note that dst and src can point to the same string. The decoding is performed
left-to-right and the output never overtakes the input, so it still works.
*/
char* uriDecode(const(char)* src, char* dst, const(char)[] terminatingChars = "\0")
{
    for (;; dst++)
    {
        char c = src[0];
        src++;
        if (c == '+')
        {
            dst[0] = ' ';
        }
        else if (c == '%')
        {
            c = src[0];
            src++;
            const hexNibble1 = hexValue(c);
            if (hexNibble1 == ubyte.max)
            {
                // src - 2 is the '%'; report "%<bad>" (or just "%" if c was null).
                copyBadUriEncoding(dst, src - 2, 2);
                return dst;
            }
            c = src[0];
            src++;
            const hexNibble2 = hexValue(c);
            if (hexNibble2 == ubyte.max)
            {
                // src - 3 is the '%'; report "%<hex><bad>".
                copyBadUriEncoding(dst, src - 3, 3);
                return dst;
            }
            dst[0] = cast(char)(hexNibble1 << 4 | hexNibble2);
        }
        else if (c == '\0' || terminatingChars.contains(c))
        {
            dst[0] = '\0';
            return dst; // success
        }
        else
        {
            dst[0] = c;
        }
    }
}
231 
/// Decodes the URI-encoded slice src into dst by forwarding to the
/// (src, srcLimit, dst) overload with the slice's bounds.
char* uriDecode(const(char)[] src, char* dst)
{
    const(char)* begin = src.ptr;
    const(char)* end = begin + src.length;
    return uriDecode(begin, end, dst);
}
/**
Decodes the URI-encoded characters in the range [src, srcLimit) into dst.

'+' decodes to a space and '%XX' decodes to the byte with hex value XX.
Only srcLimit ends the input; null characters inside the range are copied
like any other character.

Returns: On success, a pointer to the null character written to dst after the
decoded text. If the range contains an invalid or truncated '%XX' sequence,
decoding stops there. The part of that sequence that lies inside the range
(stopping early at a null) is copied to dst, followed by a null, and a pointer
to that copy is returned. The copy always starts with '%'.

A dst buffer of (srcLimit - src) + 1 characters is always large enough.
*/
char* uriDecode(const(char)* src, const(char)* srcLimit, char* dst)
{
    for (;; dst++)
    {
        if (src >= srcLimit)
        {
            *dst = '\0';
            return dst; // success
        }
        char c = src[0];
        src++;
        if (c == '+')
        {
            dst[0] = ' ';
        }
        else if (c == '%')
        {
            // src now points just past the '%'; two hex digits must follow
            // inside the range.
            if (src + 1 >= srcLimit)
            {
                // Truncated escape: copy only the '%' and the (at most one)
                // character after it that is still inside the range, so we
                // never read past srcLimit or write past dst's capacity.
                const available = cast(size_t)(srcLimit - (src - 1));
                copyBadUriEncoding(dst, src - 1, available);
                return dst; // fail
            }
            c = src[0];
            src++;
            const hexNibble1 = hexValue(c);
            if (hexNibble1 == ubyte.max)
            {
                copyBadUriEncoding(dst, src - 2, 2);
                return dst; // fail
            }
            c = src[0];
            src++;
            const hexNibble2 = hexValue(c);
            if (hexNibble2 == ubyte.max)
            {
                copyBadUriEncoding(dst, src - 3, 3);
                return dst; // fail
            }
            dst[0] = cast(char)(hexNibble1 << 4 | hexNibble2);
        }
        else
        {
            dst[0] = c;
        }
    }
}
279 
/**
Decodes the URI-encoded, null-terminated string at value in place, stopping at
a character from terminatingChars (see the null-terminated uriDecode overload).

Returns: the decoded characters as a slice of value on success, or null if
value contains an invalid '%XX' sequence.
*/
char[] tryUriDecodeInPlace(char* value, const(char)[] terminatingChars)
{
    char* end = uriDecode(value, value, terminatingChars);
    if (*end != '\0')
        return null; // error
    return value[0 .. end - value];
}
290 
/**
Decodes encoded into a newly allocated buffer.

Returns: the decoded string, or null on error (an invalid '%XX' sequence).
The buffer holds a null character just past the end of the returned slice.
*/
char[] tryUriDecode(const(char)[] encoded)
{
    // One extra character for the terminating null written by uriDecode.
    char[] buffer = new char[](encoded.length + 1);
    char* end = uriDecode(encoded, buffer.ptr);
    return (*end == '\0') ? buffer[0 .. end - buffer.ptr] : null;
}
302 
unittest
{
    import core.stdc.stdlib : alloca;
    import core.stdc.string : strlen;
    mixin(scopedTest!("uri encode/decode"));

    // Checks that `before` decodes to `expectedAfter` through every decoding
    // entry point. `before` must be followed in memory by a null character
    // (string literals are), because the null-terminated overload is used.
    static void test(const(char)[] before, const(char)[] expectedAfter)
    {
        {
            auto actualAfter = tryUriDecode(before);
            assert(actualAfter == expectedAfter);
        }
        static char[] makeCopy(const(char)[] str, const(char)[] postfix)
        {
            auto copy = new char[str.length + postfix.length];
            copy[0 .. str.length] = str[];
            copy[str.length .. $] = postfix[];
            return copy;
        }

        {
            auto actualAfter = cast(char*)alloca(before.length + 1);
            auto result = uriDecode(before.ptr, actualAfter);
            assert(result[0] == '\0');
            assert(actualAfter[0 .. result - actualAfter] == expectedAfter);
        }
        {
            auto copy = makeCopy(before, "&");
            auto actualAfter = cast(char*)alloca(before.length + 1);
            auto result = uriDecode(copy.ptr, actualAfter, "&");
            assert(result[0] == '\0');
            assert(actualAfter[0 .. result - actualAfter] == expectedAfter);
        }
        {
            auto copy = makeCopy(before, "\0");
            auto afterCopy = tryUriDecodeInPlace(copy.ptr, "\0");
            assert(afterCopy == expectedAfter);
        }
    }
    test("", "");
    test("a", "a");
    test("abcd", "abcd");
    test("abcd+efgh", "abcd efgh");
    test("a%00b%01c%02", "a\x00b\x01c\x02");
    test("%41%62c", "Abc");
    for (ushort valueAsUShort = ubyte.min; valueAsUShort <= ubyte.max; valueAsUShort++)
    {
        auto value = cast(ubyte)valueAsUShort;
        char[2] expected;
        expected[0] = cast(char)value;
        expected[1] = '\0';

        // str[3] is the null terminator required by the null-terminated overload.
        char[4] str;
        str[0] = '%';
        str[1] = toHexLower(cast(ubyte)(value >> 4));
        str[2] = toHexLower(cast(ubyte)(value & 0x0F));
        str[3] = '\0';

        test(str[0 .. 3], expected[0 .. 1]);

        str[1] = toHexUpper(cast(ubyte)(value >> 4));
        str[2] = toHexUpper(cast(ubyte)(value & 0x0F));

        test(str[0 .. 3], expected[0 .. 1]);
    }

    // Checks that badEncoding fails to decode and that the reported bad part
    // equals badPart. badEncoding must be null-terminated in memory.
    static void testError(const(char)[] badEncoding, const(char)[] badPart)
    {
        assert(!tryUriDecode(badEncoding));
        {
            auto actualAfter = cast(char*)alloca(badEncoding.length + 1);
            auto result = uriDecode(badEncoding.ptr, actualAfter);
            assert(result[0 .. strlen(result)] == badPart);
        }
        {
            auto actualAfter = cast(char*)alloca(badEncoding.length + 1);
            auto result = uriDecode(badEncoding, actualAfter);
            assert(result[0 .. strlen(result)] == badPart);
        }
    }
    testError("%", "%");
    testError("foo%", "%");
    testError("foo%a", "%a");
    testError("foo%aZ", "%aZ");
    testError("foo%Zb", "%Z");
}