source: box/trunk/lib/server/SocketStreamTLS.cpp @ 3096

Revision 3096, 11.6 KB checked in by chris, 4 weeks ago (diff)

Move LogError? out of server/SSLLib so we can use it in Crypto.

  • Property svn:eol-style set to native
Line 
1// --------------------------------------------------------------------------
2//
3// File
4//              Name:    SocketStreamTLS.cpp
5//              Purpose: Socket stream encrpyted and authenticated by TLS
6//              Created: 2003/08/06
7//
8// --------------------------------------------------------------------------
9
10#include "Box.h"
11
12#define TLS_CLASS_IMPLEMENTATION_CPP
13#include <openssl/ssl.h>
14#include <openssl/bio.h>
15#include <errno.h>
16#include <fcntl.h>
17
18#ifndef WIN32
19#include <poll.h>
20#endif
21
22#include "BoxTime.h"
23#include "CryptoUtils.h"
24#include "ServerException.h"
25#include "SocketStreamTLS.h"
26#include "SSLLib.h"
27#include "TLSContext.h"
28
29#include "MemLeakFindOn.h"
30
31// Allow 5 minutes to handshake (in milliseconds)
32#define TLS_HANDSHAKE_TIMEOUT (5*60*1000)
33
34// --------------------------------------------------------------------------
35//
36// Function
37//              Name:    SocketStreamTLS::SocketStreamTLS()
38//              Purpose: Constructor
39//              Created: 2003/08/06
40//
41// --------------------------------------------------------------------------
42SocketStreamTLS::SocketStreamTLS()
43        : mpSSL(0), mpBIO(0)
44{
45        ResetCounters();
46}
47
48// --------------------------------------------------------------------------
49//
50// Function
51//              Name:    SocketStreamTLS::SocketStreamTLS(int)
52//              Purpose: Constructor, taking previously connected socket
53//              Created: 2003/08/06
54//
55// --------------------------------------------------------------------------
56SocketStreamTLS::SocketStreamTLS(int socket)
57        : SocketStream(socket),
58          mpSSL(0), mpBIO(0)
59{
60}
61
62// --------------------------------------------------------------------------
63//
64// Function
65//              Name:    SocketStreamTLS::~SocketStreamTLS()
66//              Purpose: Destructor
67//              Created: 2003/08/06
68//
69// --------------------------------------------------------------------------
70SocketStreamTLS::~SocketStreamTLS()
71{
72        if(mpSSL)
73        {
74                // Attempt to close to avoid problems
75                Close();
76               
77                // And if that didn't work...
78                if(mpSSL)
79                {
80                        ::SSL_free(mpSSL);
81                        mpSSL = 0;
82                        mpBIO = 0;      // implicity freed by the SSL_free call
83                }
84        }
85       
86        // If we only got to creating that BIO.
87        if(mpBIO)
88        {
89                ::BIO_free(mpBIO);
90                mpBIO = 0;
91        }
92}
93
94
95// --------------------------------------------------------------------------
96//
97// Function
98//              Name:    SocketStreamTLS::Open(const TLSContext &, int, const char *, int)
99//              Purpose: Open connection, and perform TLS handshake
100//              Created: 2003/08/06
101//
102// --------------------------------------------------------------------------
103void SocketStreamTLS::Open(const TLSContext &rContext, Socket::Type Type,
104        const std::string& rName, int Port)
105{
106        SocketStream::Open(Type, rName, Port);
107        Handshake(rContext);
108        ResetCounters();
109}
110
111
112// --------------------------------------------------------------------------
113//
114// Function
115//              Name:    SocketStreamTLS::Handshake(const TLSContext &, bool)
116//              Purpose: Perform TLS handshake
117//              Created: 2003/08/06
118//
119// --------------------------------------------------------------------------
120void SocketStreamTLS::Handshake(const TLSContext &rContext, bool IsServer)
121{
122        if(mpBIO || mpSSL) {THROW_EXCEPTION(ServerException, TLSAlreadyHandshaked)}
123
124        // Create a BIO for this socket
125        mpBIO = ::BIO_new(::BIO_s_socket());
126        if(mpBIO == 0)
127        {
128                CryptoUtils::LogError("creating socket bio");
129                THROW_EXCEPTION(ServerException, TLSAllocationFailed)
130        }
131
132        tOSSocketHandle socket = GetSocketHandle();
133        BIO_set_fd(mpBIO, socket, BIO_NOCLOSE);
134       
135        // Then the SSL object
136        mpSSL = ::SSL_new(rContext.GetRawContext());
137        if(mpSSL == 0)
138        {
139                CryptoUtils::LogError("creating SSL object");
140                THROW_EXCEPTION(ServerException, TLSAllocationFailed)
141        }
142
143        // Make the socket non-blocking so timeouts on Read work
144
145#ifdef WIN32
146        u_long nonblocking = 1;
147        ioctlsocket(socket, FIONBIO, &nonblocking);
148#else // !WIN32
149        // This is more portable than using ioctl with FIONBIO
150        int statusFlags = 0;
151        if(::fcntl(socket, F_GETFL, &statusFlags) < 0
152           || ::fcntl(socket, F_SETFL, statusFlags | O_NONBLOCK) == -1)
153        {
154                THROW_EXCEPTION(ServerException, SocketSetNonBlockingFailed)
155        }
156#endif
157       
158        // FIXME: This is less portable than the above. However, it MAY be needed
159        // for cygwin, which has/had bugs with fcntl
160        //
161        // int nonblocking = true;
162        // if(::ioctl(socket, FIONBIO, &nonblocking) == -1)
163        // {
164        //      THROW_EXCEPTION(ServerException, SocketSetNonBlockingFailed)
165        // }
166
167        // Set the two to know about each other
168        ::SSL_set_bio(mpSSL, mpBIO, mpBIO);
169
170        bool waitingForHandshake = true;
171        while(waitingForHandshake)
172        {
173                // Attempt to do the handshake
174                int r = 0;
175                if(IsServer)
176                {
177                        r = ::SSL_accept(mpSSL);
178                }
179                else
180                {
181                        r = ::SSL_connect(mpSSL);
182                }
183
184                // check return code
185                int se;
186                switch((se = ::SSL_get_error(mpSSL, r)))
187                {
188                case SSL_ERROR_NONE:
189                        // No error, handshake succeeded
190                        waitingForHandshake = false;
191                        break;
192
193                case SSL_ERROR_WANT_READ:
194                case SSL_ERROR_WANT_WRITE:
195                        // wait for the requried data
196                        if(WaitWhenRetryRequired(se, TLS_HANDSHAKE_TIMEOUT) == false)
197                        {
198                                // timed out
199                                THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeTimedOut)
200                        }
201                        break;
202                       
203                default: // (and SSL_ERROR_ZERO_RETURN)
204                        // Error occured
205                        if(IsServer)
206                        {
207                                CryptoUtils::LogError("accepting connection");
208                                THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeFailed)
209                        }
210                        else
211                        {
212                                CryptoUtils::LogError("connecting");
213                                THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeFailed)
214                        }
215                }
216        }
217       
218        // And that's it
219}
220
221// --------------------------------------------------------------------------
222//
223// Function
224//              Name:    WaitWhenRetryRequired(int, int)
225//              Purpose: Waits until the condition required by the TLS layer is met.
226//                               Returns true if the condition is met, false if timed out.
227//              Created: 2003/08/15
228//
229// --------------------------------------------------------------------------
230bool SocketStreamTLS::WaitWhenRetryRequired(int SSLErrorCode, int Timeout)
231{
232        struct pollfd p;
233        p.fd = GetSocketHandle();
234        switch(SSLErrorCode)
235        {
236        case SSL_ERROR_WANT_READ:
237                p.events = POLLIN;
238                break;
239               
240        case SSL_ERROR_WANT_WRITE:
241                p.events = POLLOUT;
242                break;
243
244        default:
245                // Not good!
246                THROW_EXCEPTION(ServerException, Internal)
247                break;
248        }
249        p.revents = 0;
250
251        int64_t start, end;
252        start = BoxTimeToMilliSeconds(GetCurrentBoxTime());
253        end   = start + Timeout;
254        int result;
255
256        do
257        {
258                int64_t now = BoxTimeToMilliSeconds(GetCurrentBoxTime());
259                int poll_timeout = (int)(end - now);
260                if (poll_timeout < 0) poll_timeout = 0;
261                if (Timeout == IOStream::TimeOutInfinite)
262                {
263                        poll_timeout = INFTIM;
264                }
265                result = ::poll(&p, 1, poll_timeout);
266        }
267        while(result == -1 && errno == EINTR);
268
269        switch(result)
270        {
271        case -1:
272                // error - Bad!
273                THROW_EXCEPTION(ServerException, SocketPollError)
274                break;
275
276        case 0:
277                // Condition not met, timed out
278                return false;
279                break;
280
281        default:
282                // good to go!
283                return true;
284                break;
285        }
286
287        return true;
288}
289
290// --------------------------------------------------------------------------
291//
292// Function
293//              Name:    SocketStreamTLS::Read(void *, int, int Timeout)
294//              Purpose: See base class
295//              Created: 2003/08/06
296//
297// --------------------------------------------------------------------------
298int SocketStreamTLS::Read(void *pBuffer, int NBytes, int Timeout)
299{
300        if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}
301
302        // Make sure zero byte reads work as expected
303        if(NBytes == 0)
304        {
305                return 0;
306        }
307       
308        while(true)
309        {
310                int r = ::SSL_read(mpSSL, pBuffer, NBytes);
311
312                int se;
313                switch((se = ::SSL_get_error(mpSSL, r)))
314                {
315                case SSL_ERROR_NONE:
316                        // No error, return number of bytes read
317                        mBytesRead += r;
318                        return r;
319                        break;
320
321                case SSL_ERROR_ZERO_RETURN:
322                        // Connection closed
323                        MarkAsReadClosed();
324                        return 0;
325                        break;
326
327                case SSL_ERROR_WANT_READ:
328                case SSL_ERROR_WANT_WRITE:
329                        // wait for the required data
330                        // Will only get once around this loop, so don't need to calculate timeout values
331                        if(WaitWhenRetryRequired(se, Timeout) == false)
332                        {
333                                // timed out
334                                return 0;
335                        }
336                        break;
337                       
338                default:
339                        CryptoUtils::LogError("reading");
340                        THROW_EXCEPTION(ConnectionException, Conn_TLSReadFailed)
341                        break;
342                }
343        }
344}
345
346// --------------------------------------------------------------------------
347//
348// Function
349//              Name:    SocketStreamTLS::Write(const void *, int)
350//              Purpose: See base class
351//              Created: 2003/08/06
352//
353// --------------------------------------------------------------------------
354void SocketStreamTLS::Write(const void *pBuffer, int NBytes)
355{
356        if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}
357       
358        // Make sure zero byte writes work as expected
359        if(NBytes == 0)
360        {
361                return;
362        }
363       
364        // from man SSL_write
365        //
366        // SSL_write() will only return with success, when the
367        // complete contents of buf of length num has been written.
368        //
369        // So no worries about partial writes and moving the buffer around
370       
371        while(true)
372        {
373                // try the write
374                int r = ::SSL_write(mpSSL, pBuffer, NBytes);
375               
376                int se;
377                switch((se = ::SSL_get_error(mpSSL, r)))
378                {
379                case SSL_ERROR_NONE:
380                        // No error, data sent, return success
381                        mBytesWritten += r;
382                        return;
383                        break;
384
385                case SSL_ERROR_ZERO_RETURN:
386                        // Connection closed
387                        MarkAsWriteClosed();
388                        THROW_EXCEPTION(ConnectionException, Conn_TLSClosedWhenWriting)
389                        break;
390
391                case SSL_ERROR_WANT_READ:
392                case SSL_ERROR_WANT_WRITE:
393                        // wait for the requried data
394                        {
395                        #ifndef BOX_RELEASE_BUILD
396                                bool conditionmet = 
397                        #endif
398                                WaitWhenRetryRequired(se, IOStream::TimeOutInfinite);
399                                ASSERT(conditionmet);
400                        }
401                        break;
402               
403                default:
404                        CryptoUtils::LogError("writing");
405                        THROW_EXCEPTION(ConnectionException, Conn_TLSWriteFailed)
406                        break;
407                }
408        }
409}
410
411// --------------------------------------------------------------------------
412//
413// Function
414//              Name:    SocketStreamTLS::Close()
415//              Purpose: See base class
416//              Created: 2003/08/06
417//
418// --------------------------------------------------------------------------
419void SocketStreamTLS::Close()
420{
421        if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}
422
423        // Base class to close
424        SocketStream::Close();
425
426        // Free resources
427        ::SSL_free(mpSSL);
428        mpSSL = 0;
429        mpBIO = 0;      // implicitly freed by SSL_free
430}
431
432// --------------------------------------------------------------------------
433//
434// Function
435//              Name:    SocketStreamTLS::Shutdown()
436//              Purpose: See base class
437//              Created: 2003/08/06
438//
439// --------------------------------------------------------------------------
440void SocketStreamTLS::Shutdown(bool Read, bool Write)
441{
442        if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}
443
444        if(::SSL_shutdown(mpSSL) < 0)
445        {
446                CryptoUtils::LogError("shutting down");
447                THROW_EXCEPTION(ConnectionException, Conn_TLSShutdownFailed)
448        }
449
450        // Don't ask the base class to shutdown -- BIO does this, apparently.
451}
452
453// --------------------------------------------------------------------------
454//
455// Function
456//              Name:    SocketStreamTLS::GetPeerCommonName()
457//              Purpose: Returns the common name of the other end of the connection
458//              Created: 2003/08/06
459//
460// --------------------------------------------------------------------------
461std::string SocketStreamTLS::GetPeerCommonName()
462{
463        if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}
464
465        // Get certificate
466        X509 *cert = ::SSL_get_peer_certificate(mpSSL);
467        if(cert == 0)
468        {
469                ::X509_free(cert);
470                THROW_EXCEPTION(ConnectionException, Conn_TLSNoPeerCertificate)
471        }
472
473        // Subject details     
474        X509_NAME *subject = ::X509_get_subject_name(cert); 
475        if(subject == 0)
476        {
477                ::X509_free(cert);
478                THROW_EXCEPTION(ConnectionException, Conn_TLSPeerCertificateInvalid)
479        }
480       
481        // Common name
482        char commonName[256];
483        if(::X509_NAME_get_text_by_NID(subject, NID_commonName, commonName, sizeof(commonName)) <= 0)
484        {
485                ::X509_free(cert);
486                THROW_EXCEPTION(ConnectionException, Conn_TLSPeerCertificateInvalid)
487        }
488        // Terminate just in case
489        commonName[sizeof(commonName)-1] = '\0';
490       
491        // Done.
492        return std::string(commonName);
493}
Note: See TracBrowser for help on using the repository browser.