libsocket 1.5
|
00001 /* 00002 ** socket.cc 00003 ** Login : Julien Lemoine <speedblue@happycoders.org> 00004 ** Started on Sat Mar 1 23:01:09 2003 Julien Lemoine 00005 ** $Id: socket.cc,v 1.16 2004/11/24 21:25:36 speedblue Exp $ 00006 ** 00007 ** Copyright (C) 2003,2004 Julien Lemoine 00008 ** This program is free software; you can redistribute it and/or modify 00009 ** it under the terms of the GNU Lesser General Public License as published by 00010 ** the Free Software Foundation; either version 2 of the License, or 00011 ** (at your option) any later version. 00012 ** 00013 ** This program is distributed in the hope that it will be useful, 00014 ** but WITHOUT ANY WARRANTY; without even the implied warranty of 00015 ** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00016 ** GNU Lesser General Public License for more details. 00017 ** 00018 ** You should have received a copy of the GNU Lesser General Public License 00019 ** along with this program; if not, write to the Free Software 00020 ** Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA 00021 */ 00022 00023 #include <iostream> 00024 #include <fstream> 00025 #include <sys/types.h> 00026 #include "socket.hh" 00027 00028 namespace Network 00029 { 00030 00031 Socket::Socket(SOCKET_KIND kind, SOCKET_VERSION version) : 00032 _kind(kind), _version(version), _state_timeout(0), 00033 _socket(0), _recv_flags(kind), _proto_kind(text), _empty_lines(false), 00034 _buffer(""), _tls(false) 00035 { 00036 _delim.push_back("\0"); 00037 #ifdef LIBSOCKET_WIN 00038 WSADATA wsadata; 00039 if (WSAStartup(MAKEWORD(1, 1), &wsadata) != 0) 00040 throw WSAStartupError("WSAStartup failed", HERE); 00041 #endif 00042 #ifndef IPV6_ENABLED 00043 if (version == V6) 00044 throw Ipv6SupportError("lib was not compiled with ipv6 support", HERE); 00045 #endif 00046 } 00047 00048 Socket::Socket(SOCKET_KIND kind, PROTO_KIND pkind, SOCKET_VERSION version) : 00049 _kind(kind), _version(version), _state_timeout(0), 00050 _socket(0), _recv_flags(kind), _proto_kind(pkind), _empty_lines(false), 00051 _buffer(""), _tls(false) 00052 { 00053 _delim.push_back("\0"); 00054 #ifdef LIBSOCKET_WIN 00055 WSADATA wsadata; 00056 if (WSAStartup(MAKEWORD(1, 1), &wsadata) != 0) 00057 throw WSAStartupError("WSAStartup failed", HERE); 00058 #endif 00059 #ifndef IPV6_ENABLED 00060 if (version == V6) 00061 throw Ipv6SupportError("lib was not compiled with ipv6 support", HERE); 00062 #endif 00063 } 00064 00065 Socket::~Socket() 00066 { 00067 } 00068 00069 void Socket::enable_tls() 00070 { 00071 #ifdef TLS 00072 int ret; 00073 00074 if (_kind != TCP) 00075 throw TLSError("You need to have a TCP connection", HERE); 00076 if (!connected()) 00077 throw NoConnection("You need to have a connection", HERE); 00078 00079 gnutls_transport_set_ptr(_session, (gnutls_transport_ptr)_socket); 00080 ret = gnutls_handshake(_session); 00081 if (ret < 0) 00082 { 00083 close(_socket); 00084 gnutls_deinit(_session); 00085 throw TLSError(gnutls_strerror(ret), HERE); 00086 } 00087 #else 00088 throw TLSSupportError("lib was not compiled with TLS support", HERE); 00089 #endif 00090 } 00091 00092 void Socket::init_tls(GnuTLSKind kind, 00093 unsigned size, const std::string &certfile, 00094 const std::string &keyfile, 00095 const std::string &trustfile, 00096 const std::string &crlfile) 00097 { 00098 #ifdef TLS 00099 static bool init = false; 00100 static gnutls_dh_params dh_params; 00101 const int protocol_tls[] = { GNUTLS_TLS1, 0 }; 00102 const int protocol_ssl[] = { GNUTLS_SSL3, 0 }; 00103 const int cert_type_priority[] = { GNUTLS_CRT_X509, 00104 GNUTLS_CRT_OPENPGP, 0 }; 00105 00106 if (!init) 00107 { 00108 gnutls_global_init(); 00109 init = true; 00110 } 00111 _tls = true; 00112 _tls_main = true; 00113 gnutls_certificate_allocate_credentials(&_x509_cred); 00114 if (keyfile.size() > 0 && certfile.size() > 0) 00115 { 00116 std::ifstream key(keyfile.c_str()), cert(certfile.c_str()); 00117 if (!key.is_open() || !cert.is_open()) 00118 throw InvalidFile("key or cert invalid", HERE); 00119 key.close(); 00120 cert.close(); 00121 // Only for server... 00122 _nbbits = size; 00123 if (trustfile.size() > 0) 00124 gnutls_certificate_set_x509_trust_file(_x509_cred, trustfile.c_str(), 00125 GNUTLS_X509_FMT_PEM); 00126 if (crlfile.size() > 0) 00127 gnutls_certificate_set_x509_crl_file(_x509_cred, crlfile.c_str(), 00128 GNUTLS_X509_FMT_PEM); 00129 gnutls_certificate_set_x509_key_file(_x509_cred, certfile.c_str(), 00130 keyfile.c_str(), 00131 GNUTLS_X509_FMT_PEM); 00132 gnutls_dh_params_init(&dh_params); 00133 gnutls_dh_params_generate2(dh_params, _nbbits); 00134 gnutls_certificate_set_dh_params(_x509_cred, dh_params); 00135 00136 if (gnutls_init(&_session, GNUTLS_SERVER)) 00137 throw TLSError("gnutls_init failed", HERE); 00138 } 00139 else 00140 { 00141 if (gnutls_init(&_session, GNUTLS_CLIENT)) 00142 throw TLSError("gnutls_init failed", HERE); 00143 } 00144 00145 gnutls_set_default_priority(_session); 00146 if (kind == TLS) 00147 gnutls_protocol_set_priority(_session, protocol_tls); 00148 else 00149 gnutls_protocol_set_priority(_session, protocol_ssl); 00150 00151 if (keyfile.size() > 0 && certfile.size() > 0) 00152 { 00153 gnutls_credentials_set(_session, GNUTLS_CRD_CERTIFICATE, _x509_cred); 00154 gnutls_certificate_server_set_request(_session, GNUTLS_CERT_REQUEST); 00155 gnutls_dh_set_prime_bits(_session, _nbbits); 00156 } 00157 else 00158 { 00159 gnutls_certificate_type_set_priority(_session, cert_type_priority); 00160 gnutls_credentials_set(_session, GNUTLS_CRD_CERTIFICATE, _x509_cred); 00161 } 00162 #else 00163 throw TLSSupportError("lib was not compiled with TLS support", HERE); 00164 #endif 00165 } 00166 00167 void Socket::_close(int socket) const 00168 { 00169 #ifndef LIBSOCKET_WIN 00170 if (socket < 0 || close(socket) < 0) 00171 throw CloseError("Close Error", HERE); 00172 socket = 0; 00173 #else 00174 if (socket < 0 || closesocket(socket) < 0) 00175 throw CloseError("Close Error", HERE); 00176 socket = 0; 00177 #endif 00178 #ifdef TLS 00179 if (_tls) 00180 { 00181 std::cout << "Deletion..." << std::endl; 00182 gnutls_deinit(_session); 00183 if (_tls_main) 00184 { 00185 gnutls_certificate_free_credentials(_x509_cred); 00186 gnutls_global_deinit(); 00187 } 00188 } 00189 #endif 00190 } 00191 00192 void Socket::_listen(int socket) const 00193 { 00194 if (socket < 0 || listen(socket, 5) < 0) 00195 throw ListenError("Listen Error", HERE); 00196 } 00197 00198 void Socket::_write_str(int socket, const std::string& str) const 00199 { 00200 int res = 1; 00201 unsigned int count = 0; 00202 const char *buf; 00203 00204 buf = str.c_str(); 00205 if (socket < 0) 00206 throw NoConnection("No Socket", HERE); 00207 while (res && count < str.size()) 00208 { 00209 #ifdef IPV6_ENABLED 00210 if (V4 == _version) 00211 #endif 00212 #ifdef TLS 00213 if (_tls) 00214 res = gnutls_record_send(_session, buf + count, str.size() - count); 00215 else 00216 #endif 00217 res = sendto(socket, buf + count, str.size() - count, SENDTO_FLAGS, 00218 (const struct sockaddr*)&_addr, sizeof(_addr)); 00219 #ifdef IPV6_ENABLED 00220 else 00221 res = sendto(socket, buf + count, str.size() - count, SENDTO_FLAGS, 00222 (const struct sockaddr*)&_addr6, sizeof(_addr6)); 00223 #endif 00224 if (res <= 0) 00225 throw ConnectionClosed("Connection Closed", HERE); 00226 count += res; 00227 } 00228 } 00229 00230 void Socket::_write_str_bin(int socket, const std::string& str) const 00231 { 00232 int res = 1; 00233 unsigned int count = 0; 00234 #ifdef LIBSOCKET_WIN 00235 char* buf = new char[str.size() + 2]; 00236 #else 00237 char buf[str.size() + 2]; 00238 #endif 00239 buf[0] = str.size() / 256; 00240 buf[1] = str.size() % 256; 00241 memcpy(buf + 2, str.c_str(), str.size()); 00242 if (socket < 0) 00243 throw NoConnection("No Socket", HERE); 00244 while (res && count < str.size() + 2) 00245 { 00246 #ifdef IPV6_ENABLED 00247 if (V4 == _version) 00248 #endif 00249 #ifdef TLS 00250 if (_tls) 00251 res = gnutls_record_send(_session, buf + count, str.size() + 2 - count); 00252 else 00253 #endif 00254 res = sendto(socket, buf + count, str.size() + 2 - count, 00255 SENDTO_FLAGS, 00256 (const struct sockaddr*)&_addr, sizeof(_addr)); 00257 #ifdef IPV6_ENABLED 00258 else 00259 res = sendto(socket, buf + count, str.size() + 2 - count, 00260 \ SENDTO_FLAGS, 00261 (const struct sockaddr*)&_addr6, sizeof(_addr6)); 00262 #endif 00263 if (res <= 0) 00264 throw ConnectionClosed("Connection Closed", HERE); 00265 count += res; 00266 } 00267 #ifdef LIBSOCKET_WIN 00268 delete[] buf; 00269 #endif 00270 } 00271 00272 void Socket::_set_timeout(bool enable, int socket, int timeout) 00273 { 00274 fd_set fdset; 00275 struct timeval timetowait; 00276 int res; 00277 00278 if (enable) 00279 timetowait.tv_sec = timeout; 00280 else 00281 timetowait.tv_sec = 65535; 00282 timetowait.tv_usec = 0; 00283 FD_ZERO(&fdset); 00284 FD_SET(socket, &fdset); 00285 if (enable) 00286 res = select(socket + 1, &fdset, NULL, NULL, &timetowait); 00287 else 00288 res = select(socket + 1, &fdset, NULL, NULL, NULL); 00289 if (res < 0) 00290 throw SelectError("Select error", HERE); 00291 if (res == 0) 00292 throw Timeout("Timeout on socket", HERE); 00293 } 00294 00295 void Socket::write(const std::string& str) 00296 { 00297 if (_proto_kind == binary) 00298 _write_str_bin(_socket, str); 00299 else 00300 _write_str(_socket, str); 00301 } 00302 00303 bool Socket::connected() const 00304 { 00305 return _socket != 0; 00306 } 00307 00308 void Socket::allow_empty_lines() 00309 { 00310 _empty_lines = true; 00311 } 00312 00313 int Socket::get_socket() 00314 { 00315 return _socket; 00316 } 00317 00318 void Socket::add_delim(const std::string& delim) 00319 { 00320 _delim.push_back(delim); 00321 } 00322 00323 void Socket::del_delim(const std::string& delim) 00324 { 00325 std::list<std::string>::iterator it, it2; 00326 00327 for (it = _delim.begin(); it != _delim.end(); ) 00328 { 00329 if (*it == delim) 00330 { 00331 it2 = it++; 00332 _delim.erase(it2); 00333 } 00334 else 00335 it++; 00336 } 00337 } 00338 00339 std::pair<int, int> Socket::_find_delim(const std::string& str, int start) const 00340 { 00341 int i = -1; 00342 int pos = -1, size = 0; 00343 std::list<std::string>::const_iterator it; 00344 00345 // Looking for the first delimiter. 00346 if (_delim.size() > 0) 00347 { 00348 it = _delim.begin(); 00349 while (it != _delim.end()) 00350 { 00351 if (*it == "") 00352 i = str.find('\0', start); 00353 else 00354 i = str.find(*it, start); 00355 if ((i >= 0) && ((unsigned int)i < str.size()) && 00356 (pos < 0 || i < pos)) 00357 { 00358 pos = i; 00359 size = it->size() ? it->size() : 1; 00360 } 00361 it++; 00362 } 00363 } 00364 return std::pair<int, int>(pos, size); 00365 } 00366 00367 Socket& operator<<(Socket& s, const std::string& str) 00368 { 00369 s.write(str); 00370 return s; 00371 } 00372 00373 Socket& operator>>(Socket& s, std::string& str) 00374 { 00375 str = s.read(); 00376 return s; 00377 } 00378 }