module more.net.socketserver;

import std.traits : hasMember;
import more.net.sock;

// NOTE: the code below reads the Winsock fd_set layout (fd_count/fd_array)
// and uses GetTickCount/GetLastError/timeval, which are only imported on
// Windows, so instantiating SocketServerTemplate currently requires Windows.
version(Windows)
{
    import core.sys.windows.windows : GetTickCount, timeval, GetLastError;
    // Type of the millisecond tick value returned by getCurrentTimeMillis.
    alias TimerTime = uint;
    // Returns the current millisecond tick count (GetTickCount).
    auto getCurrentTimeMillis()
    {
        return GetTickCount();
    }
}

/// Bit flags selecting which select() events a socket is waiting for.
enum EventFlags : ubyte
{
    none  = 0x00,
    read  = 0x01,
    write = 0x02,
    error = 0x04,
    all   = read | write | error,
}

/// Value of EventSocket.timer meaning "no timer is set".
enum NO_TIMER = 0;

/// Identifies one of the three select() sets; used to index selectSetProps.
enum SelectSet { read = 0, write = 1, error = 2}
/// Per-set metadata: the EventFlags bit it corresponds to and a display name.
struct SelectSetProperties
{
    EventFlags eventFlag;
    string name;
}
/// Properties for each SelectSet, indexed by the SelectSet value.
__gshared immutable selectSetProps = [
    immutable SelectSetProperties(EventFlags.read, "read"),
    immutable SelectSetProperties(EventFlags.write, "write"),
    immutable SelectSetProperties(EventFlags.error, "error"),
];

/// A select()-based socket event loop configured at compile time by Policy.
///
/// Policy must declare: MaxSocketCount (capacity of the socket table),
/// ReadEvents, WriteEvents, ErrorEvents, TimerEvents and ImplementStop
/// (booleans enabling the corresponding features). It may also declare a
/// mixin template named EventSocketMixinTemplate, which is mixed into
/// EventSocket to add user fields.
struct SocketServerTemplate(Policy)
{
    /// A registered socket: its handle, the events it waits for, an optional
    /// timer and the callback invoked when one of its events is reported.
    struct EventSocket
    {
        SocketHandle handle;
        EventFlags flags;

        void function(EventSocket*) handler;

        // The number of milliseconds until the timer event occurs.
        // 0 means no timer.
        uint timer;
        // The time when the timer expires
        TimerTime timerExpireTime;

        static if( hasMember!(Policy, "EventSocketMixinTemplate") )
        {
            mixin Policy.EventSocketMixinTemplate;
        }

        /// Clears all flags and the timer; run() drops sockets in this
        /// state at the start of its next loop iteration.
        void setToRemove()
        {
            flags = EventFlags.none;
            timer = NO_TIMER;
        }

        /// When a timer is set, recomputes timerExpireTime as now + timer.
        void updateTimerExpireTime()
        {
            if(timer != NO_TIMER)
            {
                timerExpireTime = getCurrentTimeMillis() + timer;
            }
        }
    }

    // Number of select() sets enabled by the Policy.
    enum SET_COUNT =
        cast(ubyte)Policy.ReadEvents +
        cast(ubyte)Policy.WriteEvents +
        cast(ubyte)Policy.ErrorEvents;
    // Maps an index into select_fd_sets.sets to the SelectSet it represents.
    // Order must match the field order inside select_fd_sets (read, write, error).
    static if(Policy.ReadEvents)
    {
        static if(Policy.WriteEvents)
        {
            static if(Policy.ErrorEvents)
                __gshared immutable setIndexToSetPropIndex = [SelectSet.read, SelectSet.write, SelectSet.error];
            else
                __gshared immutable setIndexToSetPropIndex = [SelectSet.read, SelectSet.write];
        }
        else
        {
            static if(Policy.ErrorEvents)
                __gshared immutable setIndexToSetPropIndex = [SelectSet.read, SelectSet.error];
            else
                __gshared immutable setIndexToSetPropIndex = [SelectSet.read];
        }
    }
    else
    {
        static if(Policy.WriteEvents)
        {
            static if(Policy.ErrorEvents)
                __gshared immutable setIndexToSetPropIndex = [SelectSet.write, SelectSet.error];
            else
                __gshared immutable setIndexToSetPropIndex = [SelectSet.write];
        }
        else
        {
            static if(Policy.ErrorEvents)
                __gshared immutable setIndexToSetPropIndex = [SelectSet.error];
            else
                __gshared immutable setIndexToSetPropIndex = null;
        }
    }

    // The enabled fd sets, accessible either by name or as the `sets` array
    // (the union overlays the named fields onto the array elements).
    union select_fd_sets
    {
        struct
        {
            static if(Policy.ReadEvents)
            {
                fd_set_storage!(Policy.MaxSocketCount) readSet;
                @property fd_set* readSetPointer() { return readSet.ptr; }
            }
            else
            {
                @property fd_set* readSetPointer() { return null; }
            }

            static if(Policy.WriteEvents)
            {
                fd_set_storage!(Policy.MaxSocketCount) writeSet;
                @property fd_set* writeSetPointer() { return writeSet.ptr; }
            }
            else
            {
                @property fd_set* writeSetPointer() { return null; }
            }

            static if(Policy.ErrorEvents)
            {
                fd_set_storage!(Policy.MaxSocketCount) errorSet;
                @property fd_set* errorSetPointer() { return errorSet.ptr; }
            }
            else
            {
                @property fd_set* errorSetPointer() { return null; }
            }
        }
        fd_set_storage!(Policy.MaxSocketCount)[SET_COUNT] sets;
    }

    EventSocket[Policy.MaxSocketCount] eventSockets;
    // Number of used slots in eventSockets (active plus newly added).
    size_t reservedSocketCount;

    // Only call inside callbacks, or before calling run
    // TODO: add an option in Policy to say, canAddFromOtherThread
    // if this is true then I can implement a locking mechanism
    // The socket gets added on the next loop iteration
    void add(const EventSocket socket)
    {
        // Fail with a descriptive message instead of writing past the table.
        assert(reservedSocketCount < eventSockets.length,
            "SocketServer: cannot add socket, all Policy.MaxSocketCount slots are in use");
        eventSockets[reservedSocketCount++] = cast(EventSocket)socket;
    }

    static if(Policy.ImplementStop)
    {
        bool stopOnNextIteration;
        // Requests that run() shut down and close every active socket and
        // return at the start of its next loop iteration.
        void stop()
        {
            stopOnNextIteration = true;
        }
    }

    /// Runs the event loop until no sockets remain (or, when
    /// Policy.ImplementStop is enabled, until stop() is called).
    ///
    /// Each iteration: activates sockets added since the last iteration,
    /// drops sockets with no flags and no timer, builds the fd sets, calls
    /// select() (with a timeout derived from the soonest timer when
    /// Policy.TimerEvents is enabled) and invokes the handler of every socket
    /// reported in a set whose flag is still set on that socket.
    ///
    /// NOTE: timer expiry dispatch is not implemented yet; if any socket has
    /// a timer set, the loop hits `assert(0, "not implemented")` after select.
    void run()
    {
        uint activeSocketCount;
        select_fd_sets socketSets;

        for(;;)
        {
            // Add new sockets
            if(reservedSocketCount > activeSocketCount)
            {
                import std.stdio;
                writefln("adding %s sockets (%s total)", reservedSocketCount - activeSocketCount, reservedSocketCount);
                static if(Policy.TimerEvents)
                {
                    do
                    {
                        eventSockets[activeSocketCount].updateTimerExpireTime();
                        activeSocketCount++;
                    } while(activeSocketCount < reservedSocketCount);
                }
                else
                {
                    // Explicit cast: size_t -> uint does not convert
                    // implicitly on 64-bit targets. The value is bounded by
                    // Policy.MaxSocketCount.
                    activeSocketCount = cast(uint)reservedSocketCount;
                }
            }

            // Remove sockets that have no event flags and no timer,
            // compacting the remaining ones towards the front.
            {
                uint removeCount = 0;
                for(uint i = 0; i < activeSocketCount; i++)
                {
                    if( (eventSockets[i].flags & EventFlags.all) == 0 &&
                        eventSockets[i].timer == NO_TIMER)
                    {
                        removeCount++;
                        continue;
                    }
                    if(removeCount)
                    {
                        eventSockets[i-removeCount] = eventSockets[i];
                    }
                }
                if(removeCount)
                {
                    activeSocketCount -= removeCount;
                    reservedSocketCount -= removeCount;
                    import std.stdio;
                    writefln("removed %s sockets (%s sockets left)", removeCount, activeSocketCount);
                }
            }

            if(activeSocketCount == 0)
            {
                break;
            }

            static if(Policy.ImplementStop)
            {
                if(stopOnNextIteration)
                {
                    import std.stdio; writeln("STOPPING!");
                    for(uint i = 0; i < activeSocketCount; i++)
                    {
                        shutdown(eventSockets[i].handle, Shutdown.both);
                        closesocket(eventSockets[i].handle);
                    }
                    return;
                }
                else
                {
                    import std.stdio; writeln("NOT STOPPING!");
                }
            }

            // Setup the select call
            static if(Policy.ReadEvents)
                socketSets.readSet.fd_count = 0;
            static if(Policy.WriteEvents)
                socketSets.writeSet.fd_count = 0;
            static if(Policy.ErrorEvents)
                socketSets.errorSet.fd_count = 0;

            static if(Policy.TimerEvents)
            {
                // uint.max means "no socket has a timer".
                uint soonestTimerEvent = uint.max;
                uint now;
            }

            for(uint i = 0; i < activeSocketCount; i++)
            {
                static if(Policy.ReadEvents)
                {
                    if(eventSockets[i].flags & EventFlags.read)
                    {
                        socketSets.readSet.addNoCheck(eventSockets[i].handle);
                    }
                }
                static if(Policy.WriteEvents)
                {
                    if(eventSockets[i].flags & EventFlags.write)
                    {
                        socketSets.writeSet.addNoCheck(eventSockets[i].handle);
                    }
                }
                static if(Policy.ErrorEvents)
                {
                    if(eventSockets[i].flags & EventFlags.error)
                    {
                        socketSets.errorSet.addNoCheck(eventSockets[i].handle);
                    }
                }
                static if(Policy.TimerEvents)
                {
                    if(eventSockets[i].timer != NO_TIMER)
                    {
                        // Read the clock once, when the first timer is seen.
                        if(soonestTimerEvent == uint.max)
                        {
                            now = getCurrentTimeMillis();
                        }

                        // Unsigned difference; a value this large means the
                        // expire time is already in the past, so fire now.
                        auto diff = eventSockets[i].timerExpireTime - now;
                        if(diff >= 0x7FFFFFFF)
                        {
                            diff = 0;
                        }
                        if(diff < soonestTimerEvent)
                        {
                            soonestTimerEvent = diff;
                        }
                    }
                }
            }

            // Debug output: print the handles placed in each select set.
            {
                import std.stdio;
                write("select");
                foreach(setIndex; 0..SET_COUNT)
                {
                    writef(" %s(", selectSetProps[setIndexToSetPropIndex[setIndex]].name);
                    bool atFirst = true;
                    foreach(handle; socketSets.sets[setIndex].fd_array[0..socketSets.sets[setIndex].fd_count])
                    {
                        if(atFirst) { atFirst = false; } else { write(", "); }
                        writef("%s", handle);
                    }
                    write(")");
                }
                writeln();
            }

            static if(Policy.TimerEvents)
            {
                timeval timeout;
                if(soonestTimerEvent != uint.max)
                {
                    if(soonestTimerEvent == 0)
                    {
                        timeout.tv_sec = 0;
                        timeout.tv_usec = 0;
                    }
                    else
                    {
                        timeout.tv_sec = soonestTimerEvent / 1000; // seconds
                        timeout.tv_usec = (soonestTimerEvent % 1000) * 1000; // microseconds
                    }
                }
                import std.stdio; writefln("calling select...");
                // A null timeout blocks until a socket event occurs.
                int selectResult = select(0,
                    socketSets.readSetPointer,
                    socketSets.writeSetPointer,
                    socketSets.errorSetPointer,
                    (soonestTimerEvent == uint.max) ? null : &timeout);
            }
            else
            {
                import std.stdio; writefln("calling select...");
                int selectResult = select(0,
                    socketSets.readSetPointer,
                    socketSets.writeSetPointer,
                    socketSets.errorSetPointer,
                    null);
            }
            import std.stdio; writefln("select returned %d", selectResult);
            if(selectResult < 0)
            {
                import std.conv : to;
                assert(0, "select failed, error: " ~ GetLastError().to!string);
            }

            // Handle read/write/error events
            if(selectResult > 0)
            {
                // go in reverse since we probably want to call error callbacks first
                foreach_reverse(setIndex; 0..SET_COUNT)
                {
                    // The hint keeps track of where the last popped socket was found;
                    // it is used to determine where to start searching for the next socket.
                    // If select keeps the sockets in order, then the hint will find the sockets
                    // in the most efficient way possible.
                    uint hint = 0;
                    foreach(handle; socketSets.sets[setIndex].fd_array[0..socketSets.sets[setIndex].fd_count])
                    {
                        uint eventSocketIndex = findSocket(activeSocketCount, eventSockets, handle, &hint);
                        if(eventSocketIndex == activeSocketCount)
                        {
                            import std.conv : to;
                            assert(0, "socket handle " ~ handle.to!string ~ " was in the select set, but not in the eventSockets array");
                            //continue;
                        }

                        // Check if the flag is still set, if not, it could have been removed
                        // by another event callback that was previously called.
                        if(0 == (eventSockets[eventSocketIndex].flags & selectSetProps[setIndexToSetPropIndex[setIndex]].eventFlag))
                        {
                            continue;
                        }

                        eventSockets[eventSocketIndex].handler(&eventSockets[eventSocketIndex]);
                        eventSockets[eventSocketIndex].updateTimerExpireTime();
                    }
                }
            }

            // Handle timer events
            static if(Policy.TimerEvents)
            {
                if(soonestTimerEvent != uint.max)
                {
                    // TODO: dispatch expired timers. EventSocket.handler
                    // currently has no way to tell a timer event apart from
                    // a socket event.
                    assert(0, "not implemented");
                }
            }
        }
    }

    // Searches eventSockets[0 .. count] for `handle`.
    // Returns: index on success, count on error
    // hint: contains the guess of where the next socket will be; on success
    //       it is updated to the index just past the found socket.
    uint findSocket(uint count, EventSocket[] eventSockets, SocketHandle handle, uint* hintReference)
    {
        enum checkHandleCode = q{
            if(handle == eventSockets[i].handle)
            {
                *hintReference = i + 1; // increment for next time
                return i;
            }
        };

        foreach(i; *hintReference..count)
        {
            mixin(checkHandleCode);
        }
        // If we get to this point in the function
        // then the select sockets were out of order so the
        // hint mechanism didn't quite work. Maybe we could log
        // when this happens to determine whether or not there is
        // a better mechanism to find the sockets.
        foreach(i; 0..*hintReference)
        {
            mixin(checkHandleCode);
        }

        return count; // ERROR
    }
}


unittest
{
    static struct Hooks1
    {
        enum MaxSocketCount = 64;
        enum ReadEvents = true;
        enum WriteEvents = true;
        enum ErrorEvents = true;
        enum TimerEvents = true;
        enum ImplementStop = true;
    }
    static struct Hooks2
    {
        enum MaxSocketCount = 64;
        enum ReadEvents = true;
        enum WriteEvents = true;
        enum ErrorEvents = true;
        enum TimerEvents = false;
        enum ImplementStop = true;
    }
    static struct Hooks3
    {
        enum MaxSocketCount = 64;
        enum ReadEvents = true;
        enum WriteEvents = false;
        enum ErrorEvents = true;
        enum TimerEvents = false;
        enum ImplementStop = true;
    }
    static struct Hooks4
    {
        enum MaxSocketCount = 64;
        enum ReadEvents = true;
        enum WriteEvents = false;
        enum ErrorEvents = true;
        enum TimerEvents = false;
        enum ImplementStop = false;
        mixin template EventSocketMixinTemplate()
        {
            int aCoolNewField;
        }
    }

    // Instantiate the server template with each Policy so every
    // static-if combination is compiled.
    {
        auto server = new SocketServerTemplate!Hooks1();
    }
    {
        auto server = new SocketServerTemplate!Hooks2();
    }
    {
        auto server = new SocketServerTemplate!Hooks3();
    }
    {
        auto server = new SocketServerTemplate!Hooks4();
    }
}