1 ///
2 module imap.socket;
3 import imap.defines;
4 import imap.session;
5 
6 import core.stdc.stdio;
7 import core.stdc.string;
8 import core.stdc.errno;
9 import std.socket;
10 import core.time : Duration;
11 
12 import deimos.openssl.ssl;
13 import deimos.openssl.err;
14 import deimos.openssl.sha;
15 
/// Signed counterpart of `size_t`, named after the POSIX C `ssize_t`.
alias ssize_t = ptrdiff_t;

// OpenSSL version-flexible TLS method constructors, declared here by hand
// (presumably because the imported deimos bindings do not expose them — TODO confirm).
extern(C) @nogc nothrow SSL_METHOD* TLS_method();
extern(C) @nogc nothrow SSL_METHOD* TLS_server_method();
extern(C) @nogc nothrow SSL_METHOD* TLS_client_method();
24 
/**
 * Create an OpenSSL context for client (default) or server use.
 *
 * Params:
 *   caFile = PEM bundle of trusted CA certificates; empty to skip.
 *   caPath = directory of CA certificates; empty to skip.
 *   certificateFile = PEM certificate to present to the peer; empty for none.
 *   keyFile = PEM private key for `certificateFile`; empty for none.
 *   asServer = build a server-side context instead of a client one.
 * Returns: a newly created `SSL_CTX*` (presumably released by the caller with `SSL_CTX_free`).
 * Throws: `Exception` if any OpenSSL call fails; the partially built context is freed first.
 */
SSL_CTX* getContext(string caFile, string caPath, string certificateFile, string keyFile, bool asServer = false)
{
	import std.exception : enforce;
	import std.string : toStringz;
	import std.format : format;

	SSL_CTX* ret = SSL_CTX_new(asServer ? TLS_server_method() : TLS_client_method());
	enforce(ret !is null, "unable to create new SSL context");
	// Do not leak the context if any of the configuration steps below throws.
	scope(failure)
		SSL_CTX_free(ret);

	enforce(SSL_CTX_set_default_verify_paths(ret) == 1, "unable to set context default verify paths");
	if (caFile.length > 0 || caPath.length > 0)
	{
		// OpenSSL expects NULL (not an empty string) for whichever location is unused.
		auto file = caFile.length > 0 ? caFile.toStringz : null;
		auto path = caPath.length > 0 ? caPath.toStringz : null;
		enforce(SSL_CTX_load_verify_locations(ret, file, path) == 1, "unable to load context verify locations");
	}

	// NOTE(review): mode 0 is SSL_VERIFY_NONE — the peer certificate chain is NOT verified,
	// which permits man-in-the-middle attacks. Kept for compatibility; consider SSL_VERIFY_PEER.
	SSL_CTX_set_verify(ret, 0, null);

	if (certificateFile.length > 0)
	{
		enforce(SSL_CTX_use_certificate_file(ret, certificateFile.toStringz, SSL_FILETYPE_PEM) > 0,
			format!"unable to set SSL certificate file as PEM to %s"(certificateFile));
	}
	if (keyFile.length > 0)
	{
		// A private key must be loaded with the PrivateKey loader; loading it as a certificate fails.
		enforce(SSL_CTX_use_PrivateKey_file(ret, keyFile.toStringz, SSL_FILETYPE_PEM) > 0,
			format!"unable to set SSL key file as PEM to %s"(keyFile));
	}
	if (certificateFile.length > 0 && keyFile.length > 0)
		enforce(SSL_CTX_check_private_key(ret) == 1, "check private key failed");
	return ret;
}
50 
51 
/**
 * Connect to the mail server named by `session.server`/`session.port`.
 *
 * Resolves the host, connects over TCP using the first resolved address, then
 * switches the socket to non-blocking mode. When `session.useSSL` is set and
 * STARTTLS is not requested, TLS is negotiated immediately via `openSecureConnection`.
 *
 * Returns: the (updated) session.
 * Throws: `Exception`/`SocketException` if resolution or connection fails; the
 *   socket created here is closed before the exception propagates.
 */
Session openConnection(ref Session session)
{
	import core.time : seconds;
	import std.format : format;
	import std.exception : enforce;
	import std.range : front;

	auto addressInfos = getAddressInfo(session.server, session.port);
	// An empty result must be rejected before calling `.front` on it.
	enforce(addressInfos.length > 0, format!"unable to get address info for %s:%s"(session.server,session.port));
	// just use the first address returned by the resolver
	session.addressInfo = addressInfos.front;

	// Create the socket in the resolved address family so IPv6 results can connect too.
	session.socket = new TcpSocket(session.addressInfo.family);
	scope(failure)
		session.socket.close();
	session.socket.blocking(true);
	session.socket.connect(session.addressInfo.address);
	session.socket.blocking(false);
	session.socket.setOption(SocketOptionLevel.SOCKET, SocketOption.SNDTIMEO, 1.seconds);
	enforce(session.socket.isAlive(), format!"connecting to %s:%s failed"(session.server, session.port));
	if (session.useSSL && !session.options.startTLS)
		return openSecureConnection(session);
	return session;
}
74 
/// SSL/TLS protocol-version selector (`none` presumably means "no encryption" — confirm with callers).
enum ProtocolSSL { none, ssl3, tls1, tls1_1, tls1_2 }
84 
/**
 * Wrap the already-connected `session.socket` in a TLS client session and
 * complete the handshake.
 *
 * The socket is typically non-blocking (see `openConnection`), so `SSL_connect`
 * is retried, waiting on the socket whenever OpenSSL asks for more I/O, until
 * the handshake finishes, fails, or a fixed handshake timeout expires.
 * Unless `session.noCerts` is set, `getCert` is then called on the session.
 *
 * Returns: the session, with `sslContext` and `sslConnection` populated.
 * Throws: `Exception` on any failure; the `SSL*` is freed and
 *   `session.sslConnection` reset to null before the exception propagates.
 */
ref Session openSecureConnection(ref Session session)
{
	import core.time : MonoTime, seconds;
	import std.exception : enforce;
	import std.format : format;
	import imap.ssl;

	// Upper bound on how long the handshake may wait for the server.
	immutable handshakeTimeout = 30.seconds;

	enforce(session.socket.isAlive, "trying to secure a disconnected socket");
	// NOTE(review): CA locations are hard-coded; on systems without /etc/ssl/cert.pem
	// getContext will throw. Consider making these configurable.
	session.sslContext = getContext("/etc/ssl/cert.pem","/etc/ssl",null,null,false); // "cacert.pem","/etc/pki/CA",null,null,false);
	enforce(session.sslContext !is null, "unable to create new SSL context");
	session.sslConnection = SSL_new(session.sslContext);
	enforce(session.sslConnection !is null, "unable to create new SSL connection");
	// Free (not merely forget) the SSL object if anything below fails.
	scope(failure)
	{
		SSL_free(session.sslConnection);
		session.sslConnection = null;
	}
	enforce(SSL_set_fd(session.sslConnection, session.socket.handle) == 1,
		"unable to attach socket to SSL connection");

	immutable deadline = MonoTime.currTime + handshakeTimeout;
	for (;;)
	{
		immutable r = SSL_connect(session.sslConnection);
		if (r == 1)
			break; // handshake complete

		// The raw return value is not an SSL_ERROR_* code; translate it first.
		immutable err = SSL_get_error(session.sslConnection, r);
		if (isSSLError(err))
			throw new Exception(sslConnectionError(session, r));
		enforce(err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE,
			format!"initiating SSL connection to %s; unexpected SSL_get_error result %s"(session.server, err));

		immutable remaining = deadline - MonoTime.currTime;
		enforce(remaining > Duration.zero,
			format!"initiating SSL connection to %s; handshake timed out"(session.server));

		// Wait for the direction OpenSSL needs instead of spinning on SSL_connect.
		auto waitSet = new SocketSet(1);
		waitSet.add(session.socket);
		immutable ready = err == SSL_ERROR_WANT_READ
			? Socket.select(waitSet, null, null, remaining)
			: Socket.select(null, waitSet, null, remaining);
		// 0 = the wait timed out; -1 = interrupted by a signal, simply retry.
		enforce(ready != 0,
			format!"initiating SSL connection to %s; handshake timed out"(session.server));
	}

	if(!session.noCerts)
		getCert(session);
	return session;
}
110 
/**
 * Classify an `SSL_ERROR_*` code (as produced by `SSL_get_error`).
 *
 * Returns: true for the fatal codes SSL_ERROR_SYSCALL, SSL_ERROR_ZERO_RETURN
 *   and SSL_ERROR_SSL; false for every other value, including the WANT_* retry codes.
 */
bool isSSLError(int socketStatus)
{
	immutable fatal = socketStatus == SSL_ERROR_SYSCALL
		|| socketStatus == SSL_ERROR_ZERO_RETURN
		|| socketStatus == SSL_ERROR_SSL;
	return fatal;
}
134 
/**
 * Build a human-readable description of why an SSL handshake call failed.
 *
 * Params:
 *   session = session whose `sslConnection` produced `socketStatus`; `server` is named in the message.
 *   socketStatus = raw return value of the failing OpenSSL call (e.g. `SSL_connect`).
 * Returns: the message, or "" for codes not recognised here.
 */
string sslConnectionError(ref Session session, int socketStatus)
{
	import std.format : format;
	import std.string : fromStringz;
	auto result = SSL_get_error(session.sslConnection, socketStatus);
	switch(result)
	{
		case SSL_ERROR_SYSCALL:
			// The helper needs the raw return value (0 = EOF, -1 = errno set). Passing
			// `result` (always SSL_ERROR_SYSCALL here) made both of those branches unreachable.
			return format!"initiating SSL connection to %s; %s"(session.server, sslConnectionSysCallError(socketStatus));

		case SSL_ERROR_ZERO_RETURN:
			return format!"initiating SSL connection to %s; connection has been closed cleanly"(session.server);

		case SSL_ERROR_SSL:
			return format!"initiating SSL connection to %s; %s\n"(session.server, ERR_error_string(ERR_get_error(), null).fromStringz);

		case SSL_ERROR_NONE,
			SSL_ERROR_WANT_CONNECT,
			SSL_ERROR_WANT_ACCEPT,
			SSL_ERROR_WANT_X509_LOOKUP,
			SSL_ERROR_WANT_READ,
			SSL_ERROR_WANT_WRITE:
			// Non-fatal codes: describe them instead of aborting the process with assert(0).
			return format!"initiating SSL connection to %s; operation not finished (SSL_get_error returned %s)"(session.server, result);

		default:
			return "";
	}
}
164 
/**
 * Describe an SSL_ERROR_SYSCALL failure.
 *
 * Pops the oldest entry from the OpenSSL error queue. If the queue held an entry,
 * that entry is described. Otherwise `socketStatus` (the raw OpenSSL return value)
 * decides: 0 means an unexpected EOF, -1 means the C `errno` explains the failure.
 */
private string sslConnectionSysCallError(int socketStatus)
{
	import std.string : fromStringz;
	import std.format : format;

	immutable queued = ERR_get_error();
	if (queued != 0)
		return ERR_error_string(queued, null).fromStringz.idup;

	switch (socketStatus)
	{
		case 0:
			return format!"EOF in violation of the protocol";
		case -1:
			return strerror(errno).fromStringz.idup;
		default:
			return ERR_error_string(queued, null).fromStringz.idup;
	}
}
176 
/**
 * Disconnect from the mail server.
 *
 * Tears down any TLS session first, then closes the TCP socket if it is still open.
 * A session without a socket or without TLS is handled without error.
 */
void closeConnection(ref Session session)
{
	// closeSecureConnection is a no-op when there is no SSL*, and this module uses
	// OpenSSL unconditionally, so the former version(SSL) guard only leaked the SSL*.
	closeSecureConnection(session);
	if (session.socket !is null && session.socket.isAlive)
	{
		session.socket.close();
	}
}
186 
/**
 * Shut down and free the session's TLS connection, if any.
 *
 * `SSL_shutdown` is attempted once and its result ignored, then the `SSL*` is freed
 * and `session.sslConnection` cleared.
 * Returns: always 0.
 */
int closeSecureConnection(ref Session session)
{
	auto ssl = session.sslConnection;
	if (ssl is null)
		return 0;

	session.sslConnection = null;
	SSL_shutdown(ssl);
	SSL_free(ssl);
	return 0;
}
199 
/// Outcome flag carried by `Result`.
enum Status { success, failure }

/// A value paired with the `Status` of the operation that produced it.
struct Result(T)
{
	Status status; /// whether the producing operation succeeded
	T value;       /// the payload produced by the operation
}

/// Convenience constructor: builds a `Result!T`, inferring `T` from `value`.
Result!T result(T)(Status status, T value)
{
	return Result!T(status, value);
}
220 
/**
 * Read data from the session's socket.
 *
 * Waits up to `timeout` for `session.socket` to become readable, retrying the wait if
 * a signal interrupts it. TLS sessions (`session.sslConnection` set) then delegate to
 * `socketSecureRead`; `timeoutFail` is not applied on that path. Plain sockets read
 * at most 16 KiB.
 *
 * Params:
 *   session = connected session; its connection is closed if this function throws.
 *   timeout = maximum time to wait for the socket to become readable.
 *   timeoutFail = when true an expired timeout throws; when false it yields an
 *     empty successful result.
 * Returns: `Status.success` with the bytes read (possibly empty after a tolerated timeout).
 * Throws: `Exception` on a read error, end of stream, or (if `timeoutFail`) timeout.
 */
Result!string socketRead(ref Session session, Duration timeout, bool timeoutFail = true)
{
	import core.time : MonoTime;
	import std.experimental.logger : tracef;
	import std.exception : enforce;
	import std.format : format;
	import std.string : fromStringz;

	scope(failure)
		closeConnection(session);

	auto socketSet = new SocketSet(1);
	immutable deadline = MonoTime.currTime + timeout;
	int selectResult;
	do
	{
		// select() rewrites the set, so rebuild it for every attempt.
		socketSet.reset();
		socketSet.add(session.socket);
		auto remaining = deadline - MonoTime.currTime;
		if (remaining < Duration.zero)
			remaining = Duration.zero;
		selectResult = Socket.select(socketSet, null, null, remaining);
	} while (selectResult == -1); // -1 = interrupted by a signal; wait for the remaining time

	if (session.sslConnection)
		return socketSecureRead(session);

	if (selectResult == 0)
	{
		// The select result used to be discarded, which made `timeoutFail` dead code.
		enforce(!timeoutFail, "timeout period expired while waiting to read data");
		return result(Status.success, "");
	}

	auto buf = new char[16384];
	immutable received = session.socket.receive(cast(void[]) buf);
	enforce(received != Socket.ERROR, format!"reading data; %s"(strerror(errno).fromStringz));
	enforce(received != 0, "read returned no data");
	// Log only the bytes actually received, not the whole uninitialised buffer.
	tracef("socketRead: %s / %s", session.socket, buf[0 .. received]);
	return result(Status.success, cast(string) buf[0 .. received]);
}
269 
/**
 * Decide whether an OpenSSL I/O return value (e.g. from `SSL_read`) is fatal.
 *
 * Returns: true when `SSL_get_error` maps `status` to SSL_ERROR_ZERO_RETURN,
 *   SSL_ERROR_SYSCALL or SSL_ERROR_SSL; false otherwise.
 */
bool isSSLReadError(ref Session session, int status)
{
	immutable code = SSL_get_error(session.sslConnection, status);
	return code == SSL_ERROR_ZERO_RETURN
		|| code == SSL_ERROR_SYSCALL
		|| code == SSL_ERROR_SSL;
}
293 
/**
 * Decide whether an OpenSSL I/O call should simply be repeated.
 *
 * A positive `status` means data was transferred, so there is nothing to retry.
 * Any other value is retryable unless `isSSLReadError` reports it as fatal.
 */
bool isTryAgain(ref Session session, int status)
{
	return status <= 0 && !session.isSSLReadError(status);
}
316 
/**
 * Read the next chunk of decrypted data from the session's TLS connection.
 *
 * `SSL_read` is retried while OpenSSL reports a non-fatal condition. When it needs
 * more socket I/O (WANT_READ/WANT_WRITE), the call blocks in `select` until the
 * socket is ready instead of busy-spinning.
 *
 * Returns: `Status.success` with the bytes produced by one successful `SSL_read`.
 * Throws: `Exception` if there is no TLS connection or `SSL_read` fails fatally;
 *   on failure the connection is marked shut down in both directions.
 */
Result!string socketSecureRead(ref Session session)
{
	import std.experimental.logger : tracef;
	import std.exception : enforce;
	import std.conv : to;
	import std.format : format;
	version(Trace) import std.stdio: writefln,stderr;

	enforce(session.sslConnection !is null, "trying to read from unconnected SSL socket");
	// One SSL_read yields at most one TLS record (<= 16 KiB of plaintext), so the
	// former 16 MiB per-call allocation was never filled.
	auto buf = new char[16384];
	scope(failure)
		SSL_set_shutdown(session.sslConnection, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);

	int res;
	do
	{
		res = SSL_read(session.sslConnection, cast(void*)buf.ptr, buf.length.to!int);
		enforce(!session.isSSLReadError(res), session.sslReadErrorMessage(res));
		if (res <= 0)
		{
			// Non-blocking socket: sleep until it is ready rather than spinning on SSL_read.
			auto waitSet = new SocketSet(1);
			waitSet.add(session.socket);
			switch (SSL_get_error(session.sslConnection, res))
			{
				case SSL_ERROR_WANT_READ:
					Socket.select(waitSet, null, null);
					break;
				case SSL_ERROR_WANT_WRITE:
					Socket.select(null, waitSet, null);
					break;
				default:
					break;
			}
		}
	} while(session.isTryAgain(res));

	enforce(res > 0, format!"SSL_read returned %s and expecting a positive number of bytes"(res));
	version(Trace) tracef("socketSecureRead: %s / %s",session.socket,buf[0..res]);
	version(Trace) stderr.writefln("socketSecureRead: %s / %s",session.socket,buf[0..res]);
	return result(Status.success,buf[0..res].idup);
}
341 
/**
 * Describe a fatal `SSL_read` failure.
 *
 * Params:
 *   status = raw `SSL_read` return value; must satisfy `isSSLReadError`.
 * Returns: a message prefixed with "reading data through SSL; ", or "" for an
 *   unrecognised error code.
 * Throws: `Exception` if `status` is not a fatal SSL error.
 */
string sslReadErrorMessage(ref Session session, int status)
{
	import std.format : format;
	import std.string : fromStringz;
	import std.exception : enforce;

	enforce(session.isSSLReadError(status),"ssl error that is not an error!");
	immutable kind = SSL_get_error(session.sslConnection, status);
	switch (kind)
	{
		case SSL_ERROR_ZERO_RETURN:
			return "reading data through SSL; the connection has been closed cleanly";

		case SSL_ERROR_SSL:
			return format!"reading data through SSL; %s\n"(ERR_error_string(ERR_get_error(), null).fromStringz);

		case SSL_ERROR_SYSCALL:
		{
			immutable queued = ERR_get_error();
			if (queued == 0)
			{
				if (status == 0)
					return "reading data through SSL; EOF in violation of the protocol";
				if (status == -1)
					return format!"reading data through SSL; %s"(strerror(errno).fromStringz);
			}
			return format!"reading data through SSL; %s"(ERR_error_string(queued, null).fromStringz);
		}

		default:
			return "";
	}
}
370 
/**
 * Write all of `buf` to the session, via TLS when `session.sslConnection` is set.
 *
 * On the plain-socket path the socket is typically non-blocking. When its send buffer
 * is full (`send` would block), this waits up to a fixed timeout for the socket to
 * become writable instead of failing immediately.
 *
 * Returns: the number of bytes written.
 * Throws: `Exception` on a send error or write timeout; on the plain path the
 *   connection is closed before the exception propagates.
 */
ssize_t socketWrite(ref Session session, string buf)
{
	import core.time : seconds;
	import std.experimental.logger : tracef;
	import std.exception : enforce;
	import std.format : format;
	import std.string : fromStringz;

	// NOTE(review): upper bound for waiting on a full send buffer; consider making it configurable.
	immutable writeTimeout = 30.seconds;

	tracef("socketWrite: %s / %s",session.socket,buf);
	if (session.sslConnection)
	{
		tracef("socketSecureWrite: %s / %s",session.socket,buf);
		return session.socketSecureWrite(buf);
	}

	scope(failure)
		closeConnection(session);

	auto socketSet = new SocketSet(1);
	ssize_t total = 0;

	tracef("entering loop with buf.length=%s",buf.length);
	while (buf.length > 0)
	{
		tracef("buf is of length %s",buf.length);
		immutable sent = session.socket.send(cast(const(void)[]) buf);
		// Capture errno-derived state before logging, which may overwrite errno.
		immutable blocked = sent == Socket.ERROR && wouldHaveBlocked();
		immutable sendErrno = errno;
		tracef("r=: %s",sent);

		if (blocked)
		{
			socketSet.reset();
			socketSet.add(session.socket);
			immutable ready = Socket.select(null, socketSet, null, writeTimeout);
			enforce(ready != 0, "timeout period expired while waiting to write data");
			continue; // writable again, or -1 (interrupted by a signal): retry the send
		}

		enforce(sent != Socket.ERROR, format!"writing data; %s"(strerror(sendErrno).fromStringz));
		enforce(sent != 0, "unknown error");
		enforce(sent <= buf.length, "send to socket returned more bytes than we sent!");
		buf = buf[sent .. $];
		total += sent;
	}

	return total;
}
420 
/**
 * Write `buf` through the session's TLS connection with `SSL_write`.
 *
 * Non-fatal conditions are retried. When OpenSSL needs socket I/O (WANT_READ/WANT_WRITE),
 * the call blocks in `select` until the socket is ready instead of busy-spinning.
 *
 * Returns: the positive byte count from `SSL_write`, or 0 when `buf` is empty.
 * Throws: `Exception` if there is no TLS connection, the peer closed it, or OpenSSL
 *   reports a syscall/protocol error. On such failures the connection is marked
 *   shut down in both directions.
 */
auto socketSecureWrite(ref Session session, string buf)
{
	import std.experimental.logger : tracef;
	import std.string : fromStringz;
	import std.format : format;
	import std.conv : to;
	import std.exception : enforce;

	tracef("socketSecureWrite: %s / %s",session.socket,buf);
	enforce(session.sslConnection, "no SSL connection has been established");
	if (buf.length == 0)
		return 0;
	immutable len = buf.length.to!int;

	scope(failure)
		SSL_set_shutdown(session.sslConnection, SSL_SENT_SHUTDOWN | SSL_RECEIVED_SHUTDOWN);

	int r;
	while ((r = SSL_write(session.sslConnection, buf.ptr, len)) <= 0)
	{
		immutable err = SSL_get_error(session.sslConnection, r);
		switch (err)
		{
			case SSL_ERROR_ZERO_RETURN:
				throw new Exception("writing data through SSL; the connection has been closed cleanly");

			case SSL_ERROR_WANT_READ:
			case SSL_ERROR_WANT_WRITE:
			{
				// Wait for the socket instead of spinning on SSL_write.
				auto waitSet = new SocketSet(1);
				waitSet.add(session.socket);
				if (err == SSL_ERROR_WANT_READ)
					Socket.select(waitSet, null, null);
				else
					Socket.select(null, waitSet, null);
				break;
			}

			case SSL_ERROR_NONE:
			case SSL_ERROR_WANT_CONNECT:
			case SSL_ERROR_WANT_ACCEPT:
			case SSL_ERROR_WANT_X509_LOOKUP:
				break;

			case SSL_ERROR_SYSCALL:
			{
				immutable e = ERR_get_error();
				if (e == 0 && r == 0)
					throw new Exception("writing data through SSL; EOF in violation of the protocol");
				enforce(!(e == 0 && r == -1), format!"writing data through SSL; %s\n"(strerror(errno).fromStringz.idup));
				// Previously `enforce(true, ...)`, which never threw and looped forever.
				throw new Exception(format!"writing data through SSL; %s"(ERR_error_string(e, null).fromStringz.idup));
			}

			case SSL_ERROR_SSL:
				throw new Exception(format!"writing data through SSL; %s"(ERR_error_string(ERR_get_error(), null).fromStringz.idup));

			default:
				throw new Exception(format!"writing data through SSL; unexpected SSL_get_error result %s"(err));
		}
	}
	return r;
}
471 
472 
/// Phase of a TLS session: client handshake in progress, server handshake in progress, or established.
enum StateTLS { connecting, accepting, connected }
479 
480