1 ///
2 module imap.socket;
3 import imap.defines;
4 import imap.session;
5 
6 import core.stdc.stdio;
7 import core.stdc.string;
8 import core.stdc.errno;
9 import std.socket;
10 import core.time : Duration;
11 
12 import deimos.openssl.ssl;
13 import deimos.openssl.err;
14 import deimos.openssl.sha;
15 
/// Signed size type mirroring C's `ssize_t`; used for byte counts returned by
/// the socket read/write helpers in this module.
alias ssize_t = ptrdiff_t;

/// OpenSSL (1.1.0+) method constructors passed to `SSL_CTX_new` by `getContext`.
/// NOTE(review): presumably declared here because the deimos.openssl bindings in
/// use do not provide them — confirm before removing.
extern(C) @nogc nothrow
{
    SSL_METHOD * TLS_method();
    SSL_METHOD * TLS_server_method();
    SSL_METHOD * TLS_client_method();
}
24 
/// Create and configure an OpenSSL context.
///
/// Params:
///   caFile          = PEM bundle of trusted CA certificates; empty to skip.
///   caPath          = directory of hashed CA certificates; empty to skip.
///   certificateFile = PEM certificate to present to the peer; empty for none.
///   keyFile         = PEM private key for `certificateFile`; empty for none.
///   asServer        = create a server-side context instead of a client one.
/// Returns: a newly allocated `SSL_CTX*`; the caller owns it.
/// Throws: `Exception` when any OpenSSL configuration step fails. The partially
///   built context is freed before the exception propagates.
SSL_CTX* getContext(string caFile, string caPath, string certificateFile, string keyFile, bool asServer = false) {
    import std.exception : enforce;
    import std.string : toStringz;
    import std.format : format;

    SSL_CTX* ret = SSL_CTX_new(asServer ? TLS_server_method() : TLS_client_method());
    enforce(ret !is null, "unable to create new SSL context");
    // Do not leak the context if any later configuration step throws.
    scope (failure)
        SSL_CTX_free(ret);

    enforce(SSL_CTX_set_default_verify_paths(ret) == 1, "unable to set context default verify paths");
    if (caFile.length > 0 || caPath.length > 0) {
        // OpenSSL treats a null path as "not supplied"; an empty C string would be
        // treated as a real (nonexistent) location and make the call fail.
        auto caFileZ = caFile.length > 0 ? caFile.toStringz : null;
        auto caPathZ = caPath.length > 0 ? caPath.toStringz : null;
        enforce(SSL_CTX_load_verify_locations(ret, caFileZ, caPathZ) == 1,
                "unable to load context verify locations");
    }
    // 0 == SSL_VERIFY_NONE: the peer certificate is not verified during the handshake.
    SSL_CTX_set_verify(ret, 0, null);
    if (certificateFile.length > 0) {
        enforce(SSL_CTX_use_certificate_file(ret, certificateFile.toStringz, SSL_FILETYPE_PEM) > 0,
                format!"unable to set SSL certificate file as PEM to %s"(certificateFile));
    }
    if (keyFile.length > 0) {
        // The key must be loaded as a private key, not as a certificate.
        enforce(SSL_CTX_use_PrivateKey_file(ret, keyFile.toStringz, SSL_FILETYPE_PEM) > 0,
                format!"unable to set SSL key file as PEM to %s"(keyFile));
    }
    if (certificateFile.length > 0 && keyFile.length > 0)
        enforce(SSL_CTX_check_private_key(ret) == 1, "check private key failed");
    return ret;
}
47 
48 
/// Connect to mail server.
///
/// Resolves `session.server`:`session.port` and connects a `TcpSocket` to the
/// first resolved address. It then switches the socket to non-blocking mode.
/// When `session.useSSL` is set and `session.options.startTLS` is not, the
/// TLS handshake is performed immediately via `openSecureConnection`.
///
/// Returns: the updated session.
/// Throws: `Exception` if name resolution yields no address or the socket is
///   not alive after connecting. `std.socket` exceptions from
///   `getAddressInfo`/`connect` propagate unchanged.
Session openConnection(Session session) {
    import core.time : seconds;
    import std.format : format;
    import std.exception : enforce;
    import std.range : front;

    auto addressInfos = getAddressInfo(session.server, session.port);
    // An empty result must be rejected here, before `.front` is taken.
    enforce(addressInfos.length > 0, format!"unable to get address info for %s:%s"(session.server, session.port));
    // just use first address
    session.addressInfo = addressInfos.front;

    session.socket = new TcpSocket();
    session.socket.blocking(true);
    session.socket.connect(session.addressInfo.address);
    session.socket.blocking(false);
    // NOTE(review): SNDTIMEO only affects blocking sends; the socket is non-blocking from here on.
    session.socket.setOption(SocketOptionLevel.SOCKET, SocketOption.SNDTIMEO, 1.seconds);
    enforce(session.socket.isAlive(), format!"connecting to %s:%s failed"(session.server, session.port));
    if (session.useSSL && !session.options.startTLS)
        return openSecureConnection(session);
    return session;
}
70 
/// SSL/TLS protocol version selector.
/// NOTE(review): not referenced by the code in this module — confirm whether callers still use it.
enum ProtocolSSL {
    none   = 0, /// no protocol selected
    ssl3   = 1, /// SSLv3
    tls1   = 2, /// TLS 1.0
    tls1_1 = 3, /// TLS 1.1
    tls1_2 = 4, /// TLS 1.2
}
79 
/// Initialize SSL/TLS connection.
///
/// Creates an OpenSSL context and connection bound to `session.socket`, then
/// runs the TLS client handshake to completion. If the (non-blocking) socket
/// is not ready, the function waits on it. On Linux, the first CA bundle
/// found in a list of well-known locations is loaded. On every other platform
/// only OpenSSL's default verify paths are configured. Unless
/// `session.noCerts` is set, `getCert` runs after the handshake has finished.
///
/// Returns: the updated session with `sslContext` and `sslConnection` set.
/// Throws: `Exception` in any of these cases:
///   - the socket is not connected;
///   - no CA bundle is found (Linux only);
///   - the OpenSSL objects cannot be created;
///   - the handshake fails.
///   On failure the `SSL` object is freed and `session.sslConnection` is null.
Session openSecureConnection(Session session) {
    import std.exception : enforce;
    import std.format : format;
    import imap.ssl;

    enforce(session.socket.isAlive, "trying to secure a disconnected socket");

    // TODO: The CAs are *usually* found under /etc/ssl on *nix systems, but on Windows they
    // probably need to come from the certificate manager or somewhere?  I don't know yet.  For now
    // only Linux probes explicit bundle locations; other platforms rely on OpenSSL's defaults.
    version (linux) {
        import std.file : exists;
        // Taken from https://serverfault.com/questions/62496/ssl-certificate-location-on-unix-linux/722646#722646
        string cacertPath;
        foreach (path; [
                 "/etc/ssl/certs/ca-certificates.crt",                // Debian/Ubuntu/Gentoo etc.
                 "/etc/pki/tls/certs/ca-bundle.crt",                  // Fedora/RHEL 6
                 "/etc/ssl/ca-bundle.pem",                            // OpenSUSE
                 "/etc/pki/tls/cacert.pem",                           // OpenELEC
                 "/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem", // CentOS/RHEL 7
                 "/etc/ssl/cert.pem",                                 // Alpine Linux
                 ]) {
            if (path.exists) {
                cacertPath = path;
                break;
            }
        }
        enforce(cacertPath !is null, "Failed to find CA certificate path.");
        session.sslContext = getContext(cacertPath, "/etc/ssl", null, null, false);
    } else {
        // Windows, macOS, BSDs, ...: default verify paths only.
        session.sslContext = getContext("", "", null, null, false);
    }

    enforce(session.sslContext !is null, "unable to create new SSL context");
    session.sslConnection = SSL_new(session.sslContext);
    enforce(session.sslConnection !is null, "unable to create new SSL connection");
    // Release the SSL object on any later failure rather than leaking it.
    scope (failure) {
        SSL_free(session.sslConnection);
        session.sslConnection = null;
    }
    enforce(session.socket.isAlive, "trying to secure a disconnected socket");
    enforce(SSL_set_fd(session.sslConnection, cast(int) session.socket.handle) == 1,
            "unable to attach socket to SSL connection");

    // SSL_connect returns 1 on success. Any other value must be classified with
    // SSL_get_error. The socket is non-blocking, so WANT_READ/WANT_WRITE just
    // mean "wait for the socket and call again".
    int r;
    while ((r = SSL_connect(session.sslConnection)) != 1) {
        immutable code = SSL_get_error(session.sslConnection, r);
        if (code == SSL_ERROR_WANT_READ || code == SSL_ERROR_WANT_WRITE) {
            auto waitSet = new SocketSet(1);
            waitSet.add(session.socket);
            if (code == SSL_ERROR_WANT_READ)
                Socket.select(waitSet, null, null);
            else
                Socket.select(null, waitSet, null);
            continue;
        }
        throw new Exception(isSSLError(code)
                ? sslConnectionError(session, r)
                : format!"initiating SSL connection to %s; handshake failed (SSL error code %s)"(session.server, code));
    }

    if (!session.noCerts)
        getCert(session);
    return session;
}
128 
/// Classify an OpenSSL `SSL_ERROR_*` code.
///
/// Params:
///   socketStatus = an `SSL_ERROR_*` value (as produced by `SSL_get_error`).
/// Returns: `true` only for `SSL_ERROR_SYSCALL`, `SSL_ERROR_ZERO_RETURN` and
///   `SSL_ERROR_SSL`. Every other value yields `false`, including
///   `SSL_ERROR_NONE` and the `SSL_ERROR_WANT_*` retry codes.
bool isSSLError(int socketStatus) {
    immutable fatal = socketStatus == SSL_ERROR_SYSCALL
        || socketStatus == SSL_ERROR_ZERO_RETURN
        || socketStatus == SSL_ERROR_SSL;
    return fatal;
}
150 
/// Describe why an OpenSSL handshake call on `session.sslConnection` failed.
///
/// Params:
///   session      = session whose `sslConnection` produced `socketStatus`;
///                  `session.server` is included in the message.
///   socketStatus = the raw return value of the OpenSSL call (e.g. `SSL_connect`).
/// Returns: a message prefixed with "initiating SSL connection to <server>; ".
///   A message is produced for every error class; none of them asserts.
string sslConnectionError(Session session, int socketStatus) {
    import std.format : format;
    import std.string : fromStringz;

    immutable code = SSL_get_error(session.sslConnection, socketStatus);
    switch (code) {
        case SSL_ERROR_SYSCALL:
            // Pass the raw return value. The helper distinguishes 0 (EOF) from -1 (errno);
            // `code` itself is always SSL_ERROR_SYSCALL here.
            return format!"initiating SSL connection to %s; %s" (session.server, sslConnectionSysCallError(socketStatus));

        case SSL_ERROR_ZERO_RETURN:
            return format!"initiating SSL connection to %s; connection has been closed cleanly"(session.server);

        case SSL_ERROR_SSL:
            return format!"initiating SSL connection to %s; %s\n"(session.server, ERR_error_string(ERR_get_error(), null).fromStringz);

        case SSL_ERROR_NONE,
            SSL_ERROR_WANT_CONNECT,
            SSL_ERROR_WANT_ACCEPT,
            SSL_ERROR_WANT_X509_LOOKUP,
            SSL_ERROR_WANT_READ,
            SSL_ERROR_WANT_WRITE:
            // Not fatal by themselves; report them instead of aborting the process.
            return format!"initiating SSL connection to %s; operation did not complete (SSL error code %s)"(session.server, code);

        default:
            return format!"initiating SSL connection to %s; unexpected SSL error code %s"(session.server, code);
    }
}
178 
/// Explain an `SSL_ERROR_SYSCALL` failure.
///
/// Pops one entry from the OpenSSL error queue. If the queue was empty, the raw
/// return value decides the message: 0 means an unexpected EOF, and -1 means the
/// C `errno` describes the failure. Otherwise the queued OpenSSL error string is
/// returned.
private string sslConnectionSysCallError(int socketStatus) {
    import std.string : fromStringz;

    immutable queued = ERR_get_error();
    if (queued == 0) {
        if (socketStatus == 0)
            return "EOF in violation of the protocol";
        if (socketStatus == -1)
            return strerror(errno).fromStringz.idup;
    }
    return ERR_error_string(queued, null).fromStringz.idup;
}
190 
/// Disconnect from mail server.
///
/// When compiled with `-version=SSL`, the TLS layer is torn down first via
/// `closeSecureConnection`. The TCP socket is then closed if it exists and is
/// still alive.
/// NOTE(review): without `version (SSL)` the `SSL` object is not released here.
void closeConnection(Session session) {
    version (SSL) closeSecureConnection(session);
    auto sock = session.socket;
    if (sock is null || !sock.isAlive)
        return;
    sock.close();
}
198 
/// Shutdown SSL/TLS connection.
///
/// If `session.sslConnection` is set, calls `SSL_shutdown` and `SSL_free` on it
/// and then clears the field. Calling it without an open TLS connection does
/// nothing.
/// Returns: always 0.
int closeSecureConnection(Session session) {
    auto connection = session.sslConnection;
    if (connection is null)
        return 0;

    SSL_shutdown(connection);
    SSL_free(connection);
    session.sslConnection = null;
    return 0;
}
209 
/// Outcome flag carried by `Result`.
enum Status {
    success = 0, /// the operation succeeded
    failure = 1, /// the operation failed
}
215 
/// Generic pair of a `Status` and a value; used as the return type of the
/// socket read helpers in this module (`Result!string`).
struct Result(T) {
    Status status; /// success/failure flag
    T value;       /// payload (for reads: the data received)
}
221 
222 
/// Convenience constructor that infers `T` for `Result!T`.
///
/// Params:
///   status = the outcome flag to store
///   value  = the payload to store
/// Returns: `Result!T(status, value)`.
Result!T result(T)(Status status, T value) {
    auto packed = Result!T(status, value);
    return packed;
}
227 
/// Read data from socket.
///
/// Waits up to `timeout` for `session.socket` to become readable.
/// - On a TLS session the read is delegated to `socketSecureRead`.
/// - Otherwise, at most 16384 bytes are received once the socket is ready.
///
/// Params:
///   session     = connected session
///   timeout     = how long to wait for the socket to become readable
///   timeoutFail = when `true`, an expired timeout is an error; when `false`,
///                 an empty successful result is returned instead
/// Returns: `Status.success` with the bytes read (possibly empty on a
///   tolerated timeout).
/// Throws: `Exception` if the wait is interrupted, the timeout expires and
///   `timeoutFail` is set, the receive fails, or the peer closed the
///   connection. The connection is closed before the exception propagates.
Result!string socketRead(Session session, Duration timeout, bool timeoutFail = true) {
    import std.experimental.logger : tracef;
    import std.exception : enforce;
    import std.format : format;
    import std.string : fromStringz;

    scope (failure)
        closeConnection(session);

    auto socketSet = new SocketSet(1);
    socketSet.add(session.socket);
    // > 0: readable, 0: timeout expired, -1: interrupted (errno set).
    immutable selectResult = Socket.select(socketSet, null, null, timeout);

    // OpenSSL may already hold decrypted bytes that select() cannot see, so the
    // TLS path does not gate its read on the select result.
    if (session.sslConnection)
        return socketSecureRead(session);

    enforce(selectResult != -1, format!"waiting to read from socket; %s"(strerror(errno).fromStringz));
    enforce(selectResult != 0 || !timeoutFail, "timeout period expired while waiting to read data");

    auto buf = new char[16384];
    ssize_t r = 0;
    if (selectResult > 0) {
        r = session.socket.receive(cast(void[]) buf);
        enforce(r != -1, format!"reading data; %s"(strerror(errno).fromStringz));
        enforce(r != 0, "read returned no data");
    }

    version (Trace) tracef("socketRead: %s / %s", session.socket, buf[0 .. r]);
    // `buf` is freshly allocated and not shared, so viewing it as immutable is safe.
    return result(Status.success, cast(string) buf[0 .. r]);
}
273 
/// Decide whether an `SSL_read` result is fatal.
///
/// Params:
///   session = session whose `sslConnection` produced `status`
///   status  = the raw return value of `SSL_read`
/// Returns: `true` when `SSL_get_error` maps `status` to `SSL_ERROR_ZERO_RETURN`,
///   `SSL_ERROR_SYSCALL` or `SSL_ERROR_SSL`; `false` otherwise.
bool isSSLReadError(Session session, int status) {
    immutable reason = SSL_get_error(session.sslConnection, status);
    return reason == SSL_ERROR_ZERO_RETURN
        || reason == SSL_ERROR_SYSCALL
        || reason == SSL_ERROR_SSL;
}
295 
/// Decide whether an `SSL_read` should simply be retried.
///
/// Params:
///   session = session whose `sslConnection` produced `status`
///   status  = the raw return value of `SSL_read`
/// Returns: `false` when data was read (`status > 0`) or when the result is
///   fatal per `isSSLReadError`; `true` for every other (non-fatal,
///   e.g. WANT_READ/WANT_WRITE) result.
bool isTryAgain(Session session, int status) {
    if (status > 0)
        return false;
    // `status` is an SSL_read return value, not an SSL_ERROR_* code, so it is
    // not switched on directly. Anything isSSLReadError does not treat as fatal
    // is retryable.
    return !session.isSSLReadError(status);
}
316 
/// Read data from a TLS/SSL connection.
///
/// Calls `SSL_read` until it yields data. When OpenSSL reports WANT_READ or
/// WANT_WRITE on the non-blocking socket, the function waits for the socket
/// with `select` instead of spinning the CPU.
///
/// Returns: `Status.success` with the decrypted bytes (at most one TLS record).
/// Throws: `Exception` if there is no TLS connection or the read fails fatally
///   (see `sslReadErrorMessage`). In that case both shutdown flags are set on
///   the connection, so no close_notify is attempted later.
Result!string socketSecureRead(Session session) {
    import std.experimental.logger : tracef;
    import std.exception : enforce;
    import std.conv : to;
    version (Trace) import std.stdio : writefln, stderr;

    enforce(session.sslConnection !is null);
    // SSL_read returns at most the contents of one TLS record (<= 16 KiB of
    // plaintext), so a larger buffer would only waste memory on every call.
    auto buf = new char[16384];
    int res;
    scope (failure)
        SSL_set_shutdown(session.sslConnection, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);

    while (true) {
        enforce(session.sslConnection !is null, "trying to read from unconnected SSL socket");
        res = SSL_read(session.sslConnection, cast(void*) buf.ptr, buf.length.to!int);
        if (res > 0)
            break;
        enforce(!session.isSSLReadError(res), session.sslReadErrorMessage(res));

        // Non-fatal: wait for the direction OpenSSL needs before retrying.
        immutable code = SSL_get_error(session.sslConnection, res);
        if (code == SSL_ERROR_WANT_READ || code == SSL_ERROR_WANT_WRITE) {
            auto waitSet = new SocketSet(1);
            waitSet.add(session.socket);
            if (code == SSL_ERROR_WANT_READ)
                Socket.select(waitSet, null, null);
            else
                Socket.select(null, waitSet, null);
        }
    }

    version (Trace) tracef("socketSecureRead: %s / %s", session.socket, buf[0 .. res]);
    version (Trace) stderr.writefln("socketSecureRead: %s / %s", session.socket, buf[0 .. res]);
    return result(Status.success, buf[0 .. res].idup);
}
340 
/// Build a human-readable message for a fatal `SSL_read` result.
///
/// Params:
///   session = session whose `sslConnection` produced `status`
///   status  = the raw return value of `SSL_read`
/// Returns: a message prefixed with "reading data through SSL; ". Returns ""
///   for an error class not handled here.
/// Throws: `Exception` when `status` is not fatal according to `isSSLReadError`.
string sslReadErrorMessage(Session session, int status) {
    import std.format : format;
    import std.string : fromStringz;
    import std.exception : enforce;

    enforce(session.isSSLReadError(status), "ssl error that is not an error!");
    immutable reason = SSL_get_error(session.sslConnection, status);

    if (reason == SSL_ERROR_ZERO_RETURN)
        return "reading data through SSL; the connection has been closed cleanly";

    if (reason == SSL_ERROR_SSL)
        return format!"reading data through SSL; %s\n"(ERR_error_string(ERR_get_error(), null).fromStringz);

    if (reason != SSL_ERROR_SYSCALL)
        return "";

    // With an empty OpenSSL error queue, the raw return value tells EOF (0)
    // apart from an OS-level failure described by errno (-1).
    immutable queued = ERR_get_error();
    if (queued == 0) {
        if (status == 0)
            return "reading data through SSL; EOF in violation of the protocol";
        if (status == -1)
            return format!"reading data through SSL; %s"(strerror(errno).fromStringz);
    }
    return format!"reading data through SSL; %s"(ERR_error_string(queued, null).fromStringz);
}
368 
/// Write data to socket.
///
/// TLS sessions are delegated to `socketSecureWrite`. Otherwise the whole of
/// `buf` is sent over `session.socket`. Because the socket is non-blocking,
/// a full kernel send buffer (`wouldHaveBlocked`) is handled by waiting for
/// writability with `select` and retrying, rather than failing.
///
/// Returns: the total number of bytes written.
/// Throws: `Exception` on a socket error or an unexpected zero-byte send. The
///   connection is closed before the exception propagates (plain sockets only).
ssize_t socketWrite(Session session, string buf) {
    import std.experimental.logger : tracef;
    import std.exception : enforce;
    import std.format : format;
    import std.string : fromStringz;

    version (Trace) tracef("socketWrite: %s / %s", session.socket, buf);
    if (session.sslConnection) {
        version (Trace) tracef("socketSecureWrite: %s / %s", session.socket, buf);
        return session.socketSecureWrite(buf);
    }

    scope (failure)
        closeConnection(session);

    auto socketSet = new SocketSet(1);
    ssize_t total = 0;

    version (Trace) tracef("entering loop with buf.length=%s", buf.length);
    while (buf.length > 0) {
        version (Trace) tracef("buf is of length %s", buf.length);
        version (Trace) tracef("sending buf: %s", buf);
        immutable sent = session.socket.send(buf);
        version (Trace) tracef("r=: %s", sent);

        if (sent == -1 && wouldHaveBlocked()) {
            // Send buffer full: block until the socket is writable, then retry.
            socketSet.reset();
            socketSet.add(session.socket);
            Socket.select(null, socketSet, null);
            continue;
        }
        enforce(sent != -1, format!"writing data; %s"(strerror(errno).fromStringz));
        enforce(sent != 0, "unknown error");
        enforce(sent <= buf.length, "send to socket returned more bytes than we sent!");

        buf = buf[sent .. $];
        total += sent;
        version (Trace) tracef("buf now =: %s", buf);
    }

    return total;
}
416 
/// Write data to a TLS/SSL connection.
///
/// Calls `SSL_write` with the whole of `buf` until it succeeds.
/// - WANT_READ/WANT_WRITE on the non-blocking socket: waits for the socket
///   with `select`, then retries the identical `SSL_write`.
/// - Fatal OpenSSL results: throw.
///
/// Returns: the byte count reported by `SSL_write`, or 0 when `buf` is empty.
/// Throws: `Exception` if there is no TLS connection or the write fails. In
///   that case both shutdown flags are set on the connection, so no
///   close_notify is attempted later.
auto socketSecureWrite(Session session, string buf) {
    import std.experimental.logger : tracef;
    import std.string : fromStringz;
    import std.format : format;
    import std.conv : to;
    import std.exception : enforce;

    version (Trace) tracef("socketSecureWrite: %s / %s", session.socket, buf);
    enforce(session.sslConnection, "no SSL connection has been established");
    if (buf.length == 0)
        return 0;

    int r;
    while (true) {
        r = SSL_write(session.sslConnection, buf.ptr, buf.length.to!int);
        if (r > 0)
            break;

        scope (failure)
            SSL_set_shutdown(session.sslConnection, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);

        immutable code = SSL_get_error(session.sslConnection, r);
        switch (code) {
            case SSL_ERROR_ZERO_RETURN:
                throw new Exception("writing data through SSL; the connection has been closed cleanly");

            case SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE:
            {
                // Wait for the direction OpenSSL needs instead of spinning.
                auto waitSet = new SocketSet(1);
                waitSet.add(session.socket);
                if (code == SSL_ERROR_WANT_READ)
                    Socket.select(waitSet, null, null);
                else
                    Socket.select(null, waitSet, null);
                break;
            }

            case SSL_ERROR_NONE, SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT, SSL_ERROR_WANT_X509_LOOKUP:
                break; // retry

            case SSL_ERROR_SYSCALL:
            {
                immutable e = ERR_get_error();
                if (e == 0 && r == 0)
                    throw new Exception("writing data through SSL; EOF in violation of the protocol");
                enforce(!(e == 0 && r == -1), format!"writing data through SSL; %s\n"(strerror(errno).fromStringz.idup));
                // Previously `enforce(true, ...)`, which never threw and looped forever.
                throw new Exception(format!"writing data through SSL; %s"(ERR_error_string(e, null).fromStringz.idup));
            }

            case SSL_ERROR_SSL:
                throw new Exception(format!"writing data through SSL; %s"(ERR_error_string(ERR_get_error(), null).fromStringz.idup));

            default:
                throw new Exception(format!"writing data through SSL; unexpected SSL error code %s"(code));
        }
    }
    return r;
}
464 
465 
/// Phase of a TLS connection.
/// NOTE(review): not referenced by the code in this module — confirm external use.
enum StateTLS {
    connecting = 0, /// client handshake in progress
    accepting  = 1, /// server-side handshake in progress
    connected  = 2, /// handshake finished
}
471