#include #include #include #include #include #include #include #include #include #include #include "tcpnet.h" #define TCPNET_BUFFER_SIZE 16384 void error(const char *msg) { // perror(msg); throw std::runtime_error(msg); } // TODO do it as RAII TcpNet::TcpNet() = default; int TcpNet::server(int portno, const std::function(std::string)> &process_request) const { int sockfd, newsockfd; socklen_t clilen; struct sockaddr_in serv_addr, cli_addr; sockfd = socket(AF_INET, SOCK_STREAM, 0); if (sockfd < 0) error("ERROR opening socket"); memset((char *) &serv_addr, 0, sizeof(serv_addr)); serv_addr.sin_family = AF_INET; serv_addr.sin_addr.s_addr = INADDR_ANY; serv_addr.sin_port = htons(portno); // this allows immediate bind after exit of ml int reuse = 1; if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (const char *) &reuse, sizeof(reuse)) < 0) error("setsockopt(SO_REUSEADDR) failed"); #ifdef SO_REUSEPORT if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEPORT, (const char *) &reuse, sizeof(reuse)) < 0) error("setsockopt(SO_REUSEPORT) failed"); #endif if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) error("ERROR on binding"); listen(sockfd, 5); clilen = sizeof(cli_addr); int requests_processed = 0; bool shutdown = false; while (!shutdown) { newsockfd = accept(sockfd, (struct sockaddr *) &cli_addr, &clilen); if (newsockfd < 0) error("ERROR on accept"); while (true) { std::string request = read_from_socket(newsockfd); if (request.empty()) break; std::pair response = process_request(request); shutdown = response.first; std::string response_str = response.second; write_to_socket(newsockfd, response_str); requests_processed++; } close(newsockfd); } close(sockfd); return requests_processed; } std::vector TcpNet::client(const std::string &address, int portno, const std::vector &requests) const { int sockfd; struct sockaddr_in serv_addr; struct hostent *server; std::vector responses; sockfd = socket(AF_INET, SOCK_STREAM, 0); if (sockfd < 0) error("ERROR opening socket"); server = gethostbyname(address.c_str()); if (server == NULL) { fprintf(stderr, "ERROR, no such host\n"); exit(0); } memset((char *) &serv_addr, 0, sizeof(serv_addr)); serv_addr.sin_family = AF_INET; bcopy((char *) server->h_addr, (char *) &serv_addr.sin_addr.s_addr, server->h_length); serv_addr.sin_port = htons(portno); if (connect(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) error("ERROR connecting"); responses.reserve(requests.size()); for (const auto &req : requests) { write_to_socket(sockfd, req); std::string response = read_from_socket(sockfd); responses.push_back(response); } close(sockfd); return responses; } std::string TcpNet::client(const std::string &address, int portno, const std::string &request) const { std::vector c{request}; auto response = client(address, portno, c); return response[0]; } std::string TcpNet::read_from_socket(int sockfd) { char buffer[TCPNET_BUFFER_SIZE]; std::string request; // read length header unsigned long long readdatalen = 0; if (USE_LENGTH_HEADER) { long n = read(sockfd, &readdatalen, sizeof(readdatalen)); if (n != 0 && n != sizeof(readdatalen)) error("ERROR reading length header failed"); } // read data long n; long readed = 0; do { memset(buffer, 0, TCPNET_BUFFER_SIZE); n = read(sockfd, buffer, TCPNET_BUFFER_SIZE - 1); if (n == 0) break; // nothing to read from client anymore if (n < 0) { if (errno == ECONNRESET) break; // connection reset by peer std::string err_desc{"ERROR reading from socket "}; err_desc.append(strerror(errno)); error(err_desc.c_str()); } std::string part{buffer}; request.append(part); readed += n; } while ((USE_LENGTH_HEADER && readed < readdatalen) || n == TCPNET_BUFFER_SIZE - 1); // TODO what if data exactly of this size return request; } void TcpNet::write_to_socket(int sockfd, const std::string &str) { const char *buffer = str.c_str(); long n; // write length header unsigned long long writedatalen = str.length(); if (USE_LENGTH_HEADER) { n = write(sockfd, &writedatalen, (int) sizeof(writedatalen)); if (n < 0) error("ERROR writing size number to socket"); } // write data int pos = 0; do { n = write(sockfd, buffer + pos, (int) (str.length() - pos)); if (n < 0) error("ERROR writing to socket"); } while (pos + n < str.length()); }