1 module more.sha;
2 
3 import std.format : formattedWrite;
4 import std.bigint : BigInt;
5 
/// A 160-bit SHA-1 digest.
///
/// The same storage can be read either as five named 32-bit words
/// (`_0` first through `_4` last, the order in which `toString` prints them)
/// or as the array `array`.
union Sha1
{
    struct
    {
        uint _0;
        uint _1;
        uint _2;
        uint _3;
        uint _4;
    }
    uint[5] array;

    /// Returns true when every one of the five words matches `rhs`.
    bool opEquals(const(Sha1) rhs) const
    {
        // Static arrays compare element-wise with `==`.
        return array == rhs.array;
    }

    /// Emits the digest to `sink` as 40 zero-padded lowercase hex digits,
    /// word `_0` first.
    void toString(scope void delegate(const(char)[]) sink) const
    {
        formattedWrite(sink, "%08x%08x%08x%08x%08x", _0, _1, _2, _3, _4);
    }
}
/// One-shot convenience wrapper: hashes all of `data` and returns its SHA-1
/// digest. Accepts any array whose element type is one byte wide
/// (e.g. `string` or `ubyte[]`).
Sha1 sha1Hash(T)(const(T)[] data) if(T.sizeof == 1)
{
    Sha1Builder hasher;
    hasher.put(data);
    return hasher.finish();
}
39 
/// Rotates `value` left by `shift` bits; `shift` is taken modulo 32.
///
/// Both shift counts are masked to 0..31 so that `shift == 0` (or any multiple
/// of 32) never asks for a 32-bit shift of a 32-bit value. The D spec does not
/// allow that shift, and the unmasked form `value >> (32 - shift)` produced it.
uint circularLeftShift(uint value, uint shift)
{
    shift &= 31;
    return (value << shift) | (value >> ((32 - shift) & 31));
}
44 
/// Incremental SHA-1 hasher.
///
/// Feed bytes with `put` (any number of calls, any chunk sizes), then call
/// `finish` once to apply the message padding and obtain the digest.
/// `finish` leaves the builder in a finished state; it does not reset it for reuse.
struct Sha1Builder
{
    enum HashByteLength = 20;
    enum BlockByteLength = 64;

    /// Initial hash words loaded into `currentHash` before any block is processed.
    enum InitialHash = Sha1(
        0x67452301,
        0xEFCDAB89,
        0x98BADCFE,
        0x10325476,
        0xC3D2E1F0);

    // Additive round constants for rounds 0-19, 20-39, 40-59 and 60-79.
    enum uint K_0 = 0x5A827999;
    enum uint K_1 = 0x6ED9EBA1;
    enum uint K_2 = 0x8F1BBCDC;
    enum uint K_3 = 0xCA62C1D6;

    // Bytes received but not yet hashed; only [0 .. blockIndex] is meaningful.
    ubyte[BlockByteLength] unhashedBlock;
    // Number of 64-byte blocks passed through hashBlock so far.
    ulong totalBlocksHashed;
    // Running hash state, updated after every block.
    Sha1 currentHash = InitialHash;
    // Count of valid bytes buffered in unhashedBlock (< BlockByteLength between calls).
    ubyte blockIndex;

    /// Appends `data` to the message being hashed.
    void put(T)(const(T)[] data) if(T.sizeof == 1)
    {
        auto bytes = cast(const(ubyte)[])data;

        // Everything still fits in the partial block: just buffer it.
        if(blockIndex + bytes.length < BlockByteLength)
        {
            unhashedBlock[blockIndex .. blockIndex + bytes.length] = bytes[];
            blockIndex = cast(ubyte)(blockIndex + bytes.length);
            return;
        }

        // Top up the partial block and hash it.
        {
            immutable copyLength = BlockByteLength - blockIndex;
            unhashedBlock[blockIndex .. $] = bytes[0 .. copyLength];
            bytes = bytes[copyLength .. $];
        }
        hashBlock(unhashedBlock[]);

        // Hash whole blocks directly from the input without copying them.
        while(bytes.length >= BlockByteLength)
        {
            hashBlock(bytes[0 .. BlockByteLength]);
            bytes = bytes[BlockByteLength .. $];
        }

        // Buffer the remainder. This must also run when nothing is left, so that
        // blockIndex is reset to 0 after the partial block above was consumed.
        unhashedBlock[0 .. bytes.length] = bytes[];
        blockIndex = cast(ubyte)bytes.length;
    }

    /// Pads the message, processes the final block(s) and returns the digest.
    /// Call at most once per message.
    Sha1 finish()
    {
        // Message length in bits, measured before any padding is appended.
        immutable ulong totalBitsHashed =
            totalBlocksHashed * (BlockByteLength * 8) + blockIndex * 8UL;

        // Append the single '1' bit (as byte 0x80).
        unhashedBlock[blockIndex++] = 0x80;

        // The last 8 bytes of a block hold the length; if they no longer fit,
        // zero-fill and flush this block, then start a fresh one.
        enum lengthOffset = BlockByteLength - 8;
        if(blockIndex > lengthOffset)
        {
            unhashedBlock[blockIndex .. $] = 0;
            hashBlock(unhashedBlock[]);
            blockIndex = 0;
        }
        unhashedBlock[blockIndex .. lengthOffset] = 0;

        // Store the bit count as a big-endian 64-bit integer.
        foreach(i; 0 .. 8)
        {
            unhashedBlock[lengthOffset + i] = cast(ubyte)(totalBitsHashed >> (56 - 8 * i));
        }
        hashBlock(unhashedBlock[]);
        return currentHash;
    }

    // Runs the 80-round compression function on one 64-byte block and folds
    // the result into currentHash.
    private void hashBlock(const(ubyte)[] block)
    {
        assert(block.length == BlockByteLength);

        // Message schedule: 16 big-endian words from the block, expanded to 80.
        uint[80] W;
        foreach(i; 0 .. 16)
        {
            immutable offset = i * 4;
            W[i] = (cast(uint)block[offset + 0] << 24) |
                   (cast(uint)block[offset + 1] << 16) |
                   (cast(uint)block[offset + 2] <<  8) |
                   (cast(uint)block[offset + 3]      );
        }
        foreach(i; 16 .. 80)
        {
            W[i] = circularLeftShift(W[i - 3] ^ W[i - 8] ^ W[i - 14] ^ W[i - 16], 1);
        }

        Sha1 h = currentHash;

        // One round: `f` is the round's boolean function of the current words,
        // evaluated by the caller before the words are rotated.
        void step(uint f, uint k, uint w)
        {
            immutable temp = circularLeftShift(h._0, 5) + f + h._4 + w + k;
            h._4 = h._3;
            h._3 = h._2;
            h._2 = circularLeftShift(h._1, 30);
            h._1 = h._0;
            h._0 = temp;
        }

        foreach(i;  0 .. 20) step((h._1 & h._2) | ((~h._1) & h._3), K_0, W[i]);
        foreach(i; 20 .. 40) step(h._1 ^ h._2 ^ h._3, K_1, W[i]);
        foreach(i; 40 .. 60) step((h._1 & h._2) | (h._1 & h._3) | (h._2 & h._3), K_2, W[i]);
        foreach(i; 60 .. 80) step(h._1 ^ h._2 ^ h._3, K_3, W[i]);

        foreach(i; 0 .. 5)
        {
            currentHash.array[i] += h.array[i];
        }
        totalBlocksHashed++;
        // TODO: assert if the maximum message length (2^64 - 1 bits) is exceeded;
        // totalBitsHashed in finish() would overflow beyond 2^55 blocks.
    }
}
187 
unittest
{
    import more.test;
    import std.format : format;
    mixin(scopedTest!"sha");

    // Empty message.
    assert(Sha1(0xda39a3ee, 0x5e6b4b0d, 0x3255bfef, 0x95601890, 0xafd80709) == sha1Hash(null));
    // FIPS 180 single-block example.
    assert(Sha1(0xA9993E36, 0x4706816A, 0xBA3E2571, 0x7850C26C, 0x9Cd0d89D) == sha1Hash("abc"));
    // 56-byte message: padding forces an extra block.
    assert(Sha1(0x84983e44, 0x1c3bd26e, 0xbaae4aa1, 0xf95129e5, 0xe54670f1) ==
        sha1Hash("abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"));
    // 112-byte message: spans more than one input block.
    assert(Sha1(0xa49b2446, 0xa02c645b, 0xf419f995, 0xb6709125, 0x3a04a259) ==
        sha1Hash("abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmnhijklmnoijklmnopjklmnopqklmnopqrlmnopqrsmnopqrstnopqrstu"));
    // Common reference string.
    assert(Sha1(0x2fd4e1c6, 0x7a2d28fc, 0xed849ee1, 0xbb76e739, 0x1b93eb12) ==
        sha1Hash("The quick brown fox jumps over the lazy dog"));

    // toString renders 40 lowercase hex digits.
    assert(format("%s", sha1Hash("abc")) == "a9993e364706816aba3e25717850c26c9cd0d89d");
}