1 module more.net.socketserver;
2 
3 import std.traits : hasMember;
4 import more.net.sock;
5 
version(Windows)
{
    import core.sys.windows.windows : GetTickCount, timeval, GetLastError;

    /// Type of the millisecond tick values produced by `getCurrentTimeMillis`;
    /// it is exactly the return type of `GetTickCount` (a 32-bit DWORD, which
    /// wraps around, so differences must be computed with unsigned arithmetic).
    alias TimerTime = typeof(GetTickCount());

    /// Returns: the current millisecond tick count as reported by `GetTickCount`.
    auto getCurrentTimeMillis()
    {
        return GetTickCount();
    }
}
15 
/// Bit flags selecting which `select` sets a socket participates in.
enum EventFlags : ubyte
{
    none  = 0,                     /// participates in no set
    read  = 1 << 0,                /// the read set
    write = 1 << 1,                /// the write set
    error = 1 << 2,                /// the error (except) set
    all   = read | write | error,  /// every set
}
24 
/// Value of `EventSocket.timer` meaning that the socket has no timer.
enum int NO_TIMER = 0;
26 
/// Identifies one of the three `select` sets. The value doubles as the index
/// into `selectSetProps`.
enum SelectSet { read, write, error }

/// Static description of a select set: the `EventFlags` bit that enables it
/// and a human-readable name (used when printing set contents).
struct SelectSetProperties
{
    EventFlags eventFlag;
    string name;
}

/// Properties of every select set, indexed by `SelectSet`.
__gshared immutable SelectSetProperties[] selectSetProps = [
    SelectSetProperties(EventFlags.read,  "read"),
    SelectSetProperties(EventFlags.write, "write"),
    SelectSetProperties(EventFlags.error, "error"),
];
38 
/**
 * A fixed-capacity, single-threaded socket event loop built on `select`.
 *
 * `Policy` is a compile-time configuration type that must provide:
 *   - `MaxSocketCount`: capacity of the socket table;
 *   - `ReadEvents`, `WriteEvents`, `ErrorEvents`: which select sets are used;
 *   - `TimerEvents`: whether per-socket timers are tracked (timer dispatch is
 *     not implemented yet; `run` asserts if any socket has a timer);
 *   - `ImplementStop`: whether `stop` is available.
 * It may also provide `mixin template EventSocketMixinTemplate()`, whose
 * members are mixed into every `EventSocket`.
 *
 * Trace output can be enabled by compiling with `-debug=socketserver`.
 */
struct SocketServerTemplate(Policy)
{
    static assert(Policy.ReadEvents || Policy.WriteEvents || Policy.ErrorEvents,
        "SocketServerTemplate: at least one of ReadEvents, WriteEvents or ErrorEvents must be enabled");

    /// One entry in the socket table.
    struct EventSocket
    {
        SocketHandle handle;
        // Which select sets this socket is placed in. Clearing every flag
        // (and the timer) makes run() drop the entry on its next iteration.
        EventFlags flags;

        // Called with a pointer to this entry when select reports the socket
        // in an enabled set whose flag is still set on the entry.
        void function(EventSocket*) handler;

        // The number of milliseconds until the timer event occurs.
        // 0 (NO_TIMER) means no timer.
        uint timer;
        // The time when the timer expires
        TimerTime timerExpireTime;

        static if( hasMember!(Policy, "EventSocketMixinTemplate") )
        {
            mixin Policy.EventSocketMixinTemplate;
        }

        /// Clears the flags and the timer so run() drops this entry on its
        /// next iteration. This does not close the handle.
        void setToRemove()
        {
            flags = EventFlags.none;
            timer = NO_TIMER;
        }

        /// Restarts the timer countdown from the current time; no-op when the
        /// socket has no timer.
        void updateTimerExpireTime()
        {
            if(timer != NO_TIMER)
            {
                timerExpireTime = getCurrentTimeMillis() + timer;
            }
        }
    }

    /// Number of select sets enabled by the policy.
    enum SET_COUNT =
        cast(ubyte)Policy.ReadEvents +
        cast(ubyte)Policy.WriteEvents +
        cast(ubyte)Policy.ErrorEvents;

    // Lists the enabled SelectSets in the same order as the fields of the
    // anonymous struct in select_fd_sets (read, write, error). Strongly pure,
    // so its result converts implicitly to immutable.
    private static SelectSet[] enabledSelectSets() pure
    {
        SelectSet[] result;
        static if(Policy.ReadEvents)  result ~= SelectSet.read;
        static if(Policy.WriteEvents) result ~= SelectSet.write;
        static if(Policy.ErrorEvents) result ~= SelectSet.error;
        return result;
    }

    /// setIndexToSetPropIndex[i] is the SelectSet stored at select_fd_sets.sets[i].
    __gshared immutable SelectSet[] setIndexToSetPropIndex = enabledSelectSets();

    // Storage for the enabled fd_sets. The anonymous struct names each enabled
    // set (readSet/writeSet/errorSet) and the `sets` array overlays the same
    // storage by index. NOTE(review): this assumes consecutive
    // fd_set_storage fields have no padding between them.
    union select_fd_sets
    {
        struct
        {
            static if(Policy.ReadEvents)
            {
                fd_set_storage!(Policy.MaxSocketCount) readSet;
                @property fd_set* readSetPointer() { return readSet.ptr; }
            }
            else
            {
                @property fd_set* readSetPointer() { return null; }
            }

            static if(Policy.WriteEvents)
            {
                fd_set_storage!(Policy.MaxSocketCount) writeSet;
                @property fd_set* writeSetPointer() { return writeSet.ptr; }
            }
            else
            {
                @property fd_set* writeSetPointer() { return null; }
            }

            static if(Policy.ErrorEvents)
            {
                fd_set_storage!(Policy.MaxSocketCount) errorSet;
                @property fd_set* errorSetPointer() { return errorSet.ptr; }
            }
            else
            {
                @property fd_set* errorSetPointer() { return null; }
            }
        }
        fd_set_storage!(Policy.MaxSocketCount)[SET_COUNT] sets;
    }

    EventSocket[Policy.MaxSocketCount] eventSockets;
    // Number of used entries in eventSockets (active plus queued by add()).
    size_t reservedSocketCount;

    /**
     * Queues `socket` to be added to the event loop. The socket becomes active
     * on the next loop iteration.
     *
     * Only call inside callbacks, or before calling run.
     * TODO: add a Policy option such as canAddFromOtherThread; if it is set,
     *       a locking mechanism could be implemented here.
     *
     * Throws: Exception if all Policy.MaxSocketCount entries are in use.
     */
    void add(const EventSocket socket)
    {
        if(reservedSocketCount >= eventSockets.length)
        {
            throw new Exception("SocketServer: cannot add socket, the socket table is full");
        }
        eventSockets[reservedSocketCount++] = cast(EventSocket)socket;
    }

    static if(Policy.ImplementStop)
    {
        bool stopOnNextIteration;

        /// Requests that run() shut down and close every socket and then
        /// return, at the start of its next loop iteration.
        void stop()
        {
            stopOnNextIteration = true;
        }
    }

    /**
     * Runs the event loop until no sockets remain or, when Policy.ImplementStop
     * is set, until stop() has been requested.
     *
     * Each iteration activates sockets queued by add() and drops sockets whose
     * flags and timer are both cleared. It then calls select on the enabled sets
     * and invokes the handler of every socket reported ready.
     *
     * Throws: Exception if select fails.
     */
    void run()
    {
        // eventSockets[0 .. activeSocketCount] are being selected on. Entries in
        // [activeSocketCount, reservedSocketCount) were queued by add() and are
        // activated at the top of the next iteration.
        uint activeSocketCount;
        select_fd_sets socketSets;

        for(;;)
        {
            // Activate sockets queued by add().
            if(reservedSocketCount > activeSocketCount)
            {
                debug(socketserver)
                {
                    import std.stdio : writefln;
                    writefln("adding %s sockets (%s total)", reservedSocketCount - activeSocketCount, reservedSocketCount);
                }
                static if(Policy.TimerEvents)
                {
                    // Start each new socket's timer countdown from now.
                    foreach(i; activeSocketCount .. reservedSocketCount)
                    {
                        eventSockets[i].updateTimerExpireTime();
                    }
                }
                // reservedSocketCount never exceeds Policy.MaxSocketCount, so the
                // narrowing is lossless. An implicit size_t -> uint assignment
                // does not compile on 64-bit targets.
                activeSocketCount = cast(uint)reservedSocketCount;
            }

            // Drop sockets that have neither event flags nor a timer, compacting
            // the table in place. Every reserved socket is active at this point,
            // so no queued entries past activeSocketCount need to move.
            {
                uint removeCount = 0;
                foreach(i; 0 .. activeSocketCount)
                {
                    if( (eventSockets[i].flags & EventFlags.all) == 0 &&
                        eventSockets[i].timer == NO_TIMER)
                    {
                        removeCount++;
                        continue;
                    }
                    if(removeCount)
                    {
                        eventSockets[i - removeCount] = eventSockets[i];
                    }
                }
                if(removeCount)
                {
                    activeSocketCount -= removeCount;
                    reservedSocketCount -= removeCount;
                    debug(socketserver)
                    {
                        import std.stdio : writefln;
                        writefln("removed %s sockets (%s sockets left)", removeCount, activeSocketCount);
                    }
                }
            }

            if(activeSocketCount == 0)
            {
                break;
            }

            static if(Policy.ImplementStop)
            {
                if(stopOnNextIteration)
                {
                    debug(socketserver)
                    {
                        import std.stdio : writeln;
                        writeln("STOPPING!");
                    }
                    foreach(ref socket; eventSockets[0 .. activeSocketCount])
                    {
                        shutdown(socket.handle, Shutdown.both);
                        closesocket(socket.handle);
                    }
                    // The handles are closed now. Forget them and clear the
                    // request, so a later run() neither selects on stale handles
                    // nor stops immediately.
                    reservedSocketCount = 0;
                    stopOnNextIteration = false;
                    return;
                }
            }

            // Setup the select call.
            static if(Policy.ReadEvents)
                socketSets.readSet.fd_count = 0;
            static if(Policy.WriteEvents)
                socketSets.writeSet.fd_count = 0;
            static if(Policy.ErrorEvents)
                socketSets.errorSet.fd_count = 0;

            static if(Policy.TimerEvents)
            {
                // Milliseconds until the nearest timer expires; uint.max means
                // that no socket has a timer.
                uint soonestTimerEvent = uint.max;
                TimerTime now;
            }

            foreach(ref socket; eventSockets[0 .. activeSocketCount])
            {
                static if(Policy.ReadEvents)
                {
                    if(socket.flags & EventFlags.read)
                    {
                        socketSets.readSet.addNoCheck(socket.handle);
                    }
                }
                static if(Policy.WriteEvents)
                {
                    if(socket.flags & EventFlags.write)
                    {
                        socketSets.writeSet.addNoCheck(socket.handle);
                    }
                }
                static if(Policy.ErrorEvents)
                {
                    if(socket.flags & EventFlags.error)
                    {
                        socketSets.errorSet.addNoCheck(socket.handle);
                    }
                }
                static if(Policy.TimerEvents)
                {
                    if(socket.timer != NO_TIMER)
                    {
                        // Read the clock once, and only if some socket has a timer.
                        if(soonestTimerEvent == uint.max)
                        {
                            now = getCurrentTimeMillis();
                        }

                        // Unsigned subtraction tolerates tick-counter wrap-around.
                        // An already-expired timer yields a huge value, which is
                        // clamped to 0.
                        auto diff = socket.timerExpireTime - now;
                        if(diff >= 0x7FFFFFFF)
                        {
                            diff = 0;
                        }
                        if(diff < soonestTimerEvent)
                        {
                            soonestTimerEvent = diff;
                        }
                    }
                }
            }

            debug(socketserver)
            {
                import std.stdio : write, writef, writeln;
                // Print the handles in every set that is about to be passed to select.
                write("select");
                foreach(setIndex; 0 .. SET_COUNT)
                {
                    writef(" %s(", selectSetProps[setIndexToSetPropIndex[setIndex]].name);
                    foreach(n, handle; socketSets.sets[setIndex].fd_array[0 .. socketSets.sets[setIndex].fd_count])
                    {
                        if(n > 0) write(", ");
                        writef("%s", handle);
                    }
                    write(")");
                }
                writeln();
            }

            // A null timeout makes select block until a socket is ready.
            timeval* timeoutPointer = null;
            static if(Policy.TimerEvents)
            {
                timeval timeout;
                if(soonestTimerEvent != uint.max)
                {
                    timeout.tv_sec  = soonestTimerEvent / 1000;          // seconds
                    timeout.tv_usec = (soonestTimerEvent % 1000) * 1000; // microseconds
                    timeoutPointer = &timeout;
                }
            }

            debug(socketserver)
            {
                import std.stdio : writefln;
                writefln("calling select...");
            }
            // The first argument (nfds) is ignored by Winsock.
            int selectResult = select(0,
                socketSets.readSetPointer,
                socketSets.writeSetPointer,
                socketSets.errorSetPointer,
                timeoutPointer);
            if(selectResult < 0)
            {
                import core.sys.windows.winsock2 : WSAGetLastError;
                import std.conv : to;
                // Capture the error code immediately, before any other call
                // (including trace output) can overwrite it.
                const errorCode = WSAGetLastError();
                throw new Exception("select failed, error: " ~ errorCode.to!string);
            }
            debug(socketserver)
            {
                import std.stdio : writefln;
                writefln("select returned %d", selectResult);
            }

            // Handle read/write/error events
            if(selectResult > 0)
            {
                // Go in reverse so that error callbacks run before write and
                // read callbacks.
                foreach_reverse(setIndex; 0 .. SET_COUNT)
                {
                    const setFlag = selectSetProps[setIndexToSetPropIndex[setIndex]].eventFlag;

                    // The hint records where the previous handle was found and
                    // is where the search for the next one starts. If select
                    // keeps the sockets in table order, every lookup is O(1).
                    uint hint = 0;
                    foreach(handle; socketSets.sets[setIndex].fd_array[0 .. socketSets.sets[setIndex].fd_count])
                    {
                        const eventSocketIndex = findSocket(activeSocketCount, eventSockets, handle, &hint);
                        if(eventSocketIndex == activeSocketCount)
                        {
                            import std.conv : to;
                            assert(0, "socket handle " ~ handle.to!string ~ " was in the select set, but not in the eventSockets array");
                        }

                        // Check if the flag is still set. If not, a callback that
                        // ran earlier in this iteration cleared it.
                        if(0 == (eventSockets[eventSocketIndex].flags & setFlag))
                        {
                            continue;
                        }

                        eventSockets[eventSocketIndex].handler(&eventSockets[eventSocketIndex]);
                        eventSockets[eventSocketIndex].updateTimerExpireTime();
                    }
                }
            }

            // Handle timer events
            static if(Policy.TimerEvents)
            {
                if(soonestTimerEvent != uint.max)
                {
                    assert(0, "not implemented");
                }
            }
        }
    }

    /**
     * Finds `handle` among the first `count` entries of `eventSockets`.
     *
     * The search starts at `*hintReference` and wraps around to the entries
     * before it. On success, `*hintReference` is set to the index after the
     * match, so handles reported in table order are found without rescanning.
     *
     * Returns: the index on success, or `count` if the handle is not found.
     */
    uint findSocket(uint count, EventSocket[] eventSockets, SocketHandle handle, uint* hintReference)
    {
        enum checkHandleCode = q{
            if(handle == eventSockets[i].handle)
            {
                *hintReference = i + 1; // increment for next time
                return i;
            }
        };

        // An out-of-range hint would make the wrap-around scan below read
        // entries at or beyond `count`.
        if(*hintReference > count)
        {
            *hintReference = count;
        }

        foreach(i; *hintReference .. count)
        {
            mixin(checkHandleCode);
        }
        // Reaching this point means the select sockets were out of order, so
        // the hint did not help. Logging this could show whether a better
        // lookup mechanism is needed.
        foreach(i; 0 .. *hintReference)
        {
            mixin(checkHandleCode);
        }

        return count; // ERROR
    }
}
439 
440 
// The server depends on TimerTime/GetTickCount, which this module only
// declares for Windows, so the instantiation checks are Windows-only.
version(Windows) unittest
{
    static struct Hooks1
    {
        enum MaxSocketCount = 64;
        enum ReadEvents = true;
        enum WriteEvents = true;
        enum ErrorEvents = true;
        enum TimerEvents = true;
        enum ImplementStop = true;
    }
    static struct Hooks2
    {
        enum MaxSocketCount = 64;
        enum ReadEvents = true;
        enum WriteEvents = true;
        enum ErrorEvents = true;
        enum TimerEvents = false;
        enum ImplementStop = true;
    }
    static struct Hooks3
    {
        enum MaxSocketCount = 64;
        enum ReadEvents = true;
        enum WriteEvents = false;
        enum ErrorEvents = true;
        enum TimerEvents = false;
        enum ImplementStop = true;
    }
    static struct Hooks4
    {
        enum MaxSocketCount = 64;
        enum ReadEvents = true;
        enum WriteEvents = false;
        enum ErrorEvents = true;
        enum TimerEvents = false;
        enum ImplementStop = false;
        mixin template EventSocketMixinTemplate()
        {
            int aCoolNewField;
        }
    }

    // Instantiating each policy type-checks every member of the template,
    // including run(), for that combination of options.
    {
        auto server = new SocketServerTemplate!Hooks1();
        assert(server.reservedSocketCount == 0);
        static assert(SocketServerTemplate!Hooks1.SET_COUNT == 3);
        assert(SocketServerTemplate!Hooks1.setIndexToSetPropIndex == [SelectSet.read, SelectSet.write, SelectSet.error]);
    }
    {
        auto server = new SocketServerTemplate!Hooks2();
        assert(server.reservedSocketCount == 0);
    }
    {
        auto server = new SocketServerTemplate!Hooks3();
        static assert(SocketServerTemplate!Hooks3.SET_COUNT == 2);
        assert(SocketServerTemplate!Hooks3.setIndexToSetPropIndex == [SelectSet.read, SelectSet.error]);
    }
    {
        auto server = new SocketServerTemplate!Hooks4();
        // The policy's mixin template adds its members to every EventSocket.
        static assert(__traits(hasMember, SocketServerTemplate!Hooks4.EventSocket, "aCoolNewField"));
    }
}
497