I've been implementing a multi-client socket communication system with a thread pool in C++. It has three main components — logger.h, socket.h, and thread.h, which handle logging, socket operations, and thread pooling respectively — plus a shared include.h that pulls in the common standard-library and Winsock headers. A test suite, tests_multisock.cpp, verifies the system's functionality.
The code lets multiple clients communicate with a server over sockets. I'd greatly appreciate a review for potential issues, optimizations, or other areas of improvement. Below is a brief summary of the components, followed by the full source, including the test code for context. Any suggestions or feedback are welcome.
Components:
logger.h: A logging class for capturing events and errors.
socket.h: A socket class that manages socket communication: creating, binding, listening, accepting, connecting, and sending/receiving data, plus a helper for sending a simple HTTP request.
thread.h: A thread pool class to handle concurrent execution of tasks.
Test Suite:
tests_multisock.cpp: A Google Test suite that validates the multi-client socket communication and the thread pool.
include.h
#pragma once
#include <iostream>
#include <vector>
#include <queue>
#include <thread>
#include <functional>
#include <mutex>
#include <condition_variable>
#include <iostream>
#include <ctime>
#include <string>
#include <fstream>
#include <thread>
#include <WinSock2.h>
#include <Ws2tcpip.h>
logger.h
#pragma once
#include "include.h"
/// Severity of a log entry, from least to most severe.
/// Logger::Log filters with `level >= minLogLevel`, so the declaration
/// order (and therefore the numeric values) is significant.
enum class LogLevel
{
    DEBUG = 0,   // verbose diagnostics
    INFO = 1,    // normal operational messages
    WARNING = 2, // something went wrong but execution continues
    ERR = 3      // errors; presumably spelled ERR to dodge the <windows.h> ERROR macro — TODO confirm
};
/// Writes timestamped log entries to stdout and, when the log file could be
/// opened, appends them to that file as well.
///
/// Log(), SetLogFile() and GetLogFile() are serialized by an internal mutex,
/// so a single Logger may be shared by several threads (e.g. every ThreadPool
/// worker logs through the same Logger reference).
class Logger
{
public:
    /// @param minLogLevel entries with a lower level are discarded.
    /// @param fileName    path of the log file, opened in append mode.
    Logger(LogLevel minLogLevel = LogLevel::INFO, const std::string& fileName = "default_log.txt")
        : minLogLevel(minLogLevel), fileName(fileName)
    {
        SetLogFile(fileName);
    }

    // Owns a std::ofstream and a std::mutex, neither of which is copyable.
    Logger(const Logger&) = delete;
    Logger& operator=(const Logger&) = delete;

    // The std::ofstream member flushes and closes the file on destruction.
    ~Logger() = default;

    /// Redirects file output to @p filename (append mode), closing any file
    /// that was open before. If the new file cannot be opened, a message is
    /// printed to std::cerr and later entries go to stdout only.
    void SetLogFile(const std::string& filename)
    {
        std::lock_guard<std::mutex> lock(logMutex);
        // Opening an already-open ofstream fails and sets failbit, which would
        // silently drop every later write; close the old file first.
        if (logFile.is_open())
        {
            logFile.close();
        }
        logFile.clear(); // reset error flags left over from a previous failed open
        logFile.open(filename, std::ios::app);
        fileName = filename;
        if (!logFile.is_open())
        {
            std::cerr << "Failed to open log file: " << filename << std::endl;
        }
    }

    /// Emits "<timestamp> [<LEVEL>] <message> [<file>:<line>]" if @p level is at
    /// least the configured minimum level.
    /// @param file source file name (typically __FILE__); may be null.
    /// @param line source line (typically __LINE__).
    void Log(LogLevel level, const char* file, int line, const std::string& message)
    {
        if (level < minLogLevel)
        {
            return;
        }
        // Format outside the lock; only the writes to the shared streams need
        // to be serialized.
        const std::string logEntry = GetTimeStamp() + " [" + LogLevelToString(level) + "] " + message +
            " [" + (file != nullptr ? file : "") + ":" + std::to_string(line) + "]" + "\n";

        std::lock_guard<std::mutex> lock(logMutex);
        std::cout << logEntry;
        if (logFile.is_open())
        {
            logFile << logEntry;
            logFile.flush();
        }
    }

    /// @return the path most recently passed to the constructor or SetLogFile().
    std::string GetLogFile() const
    {
        std::lock_guard<std::mutex> lock(logMutex);
        return fileName;
    }

private:
    LogLevel minLogLevel;   // written only by the constructor
    std::ofstream logFile;
    std::string fileName;
    mutable std::mutex logMutex; // guards logFile, fileName and writes to std::cout

    std::string LogLevelToString(LogLevel level) const
    {
        switch (level)
        {
        case LogLevel::DEBUG: return "DEBUG";
        case LogLevel::INFO: return "INFO";
        case LogLevel::WARNING: return "WARNING";
        case LogLevel::ERR: return "ERROR";
        default: return "UNKNOWN";
        }
    }

    /// Local time formatted as "YYYY-MM-DD HH:MM:SS" (empty string if
    /// formatting fails).
    std::string GetTimeStamp() const
    {
        const std::time_t now = std::time(nullptr);
        std::tm timeinfo{};
#ifdef _WIN32
        localtime_s(&timeinfo, &now);
#else
        localtime_r(&now, &timeinfo);
#endif
        char timestamp[32]; // the format needs 20 bytes incl. NUL; leave headroom
        // strftime returns 0 and leaves the buffer indeterminate on failure,
        // so only the reported number of characters is used.
        const std::size_t length = std::strftime(timestamp, sizeof(timestamp), "%Y-%m-%d %H:%M:%S", &timeinfo);
        return std::string(timestamp, length);
    }
};
thread.h
#include "logger.h"
class ThreadPool {
public:
ThreadPool(int numThreads, Logger& logger) : logger(logger), stop(false) {
for (int i = 0; i < numThreads; ++i) {
threads.emplace_back([this]() { ThreadFunction(); });
}
}
~ThreadPool() {
{
std::unique_lock<std::mutex> lock(mutex);
stop = true;
}
condition.notify_all();
for (std::thread& thread : threads) {
thread.join();
}
}
void Enqueue(std::function<void()> task) {
{
std::unique_lock<std::mutex> lock(mutex);
tasks.push(task);
}
condition.notify_one();
}
private:
void ThreadFunction() {
while (true) {
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(mutex);
condition.wait(lock, [this]() { return stop || !tasks.empty(); });
if (stop && tasks.empty()) {
return;
}
task = tasks.front();
tasks.pop();
}
try {
logger.Log(LogLevel::DEBUG, __FILE__, __LINE__, "Task started.");
task();
logger.Log(LogLevel::DEBUG, __FILE__, __LINE__, "Task completed.");
}
catch (const std::exception& ex) {
logger.Log(LogLevel::DEBUG, __FILE__, __LINE__, "Task error: " + std::string(ex.what()));
}
}
}
private:
Logger& logger; // Reference to your Logger instance
std::vector<std::thread> threads;
std::queue<std::function<void()>> tasks;
std::mutex mutex;
std::condition_variable condition;
bool stop;
};
socket.h
#include "logger.h"
class Socket
{
public:
Socket(Logger& logger) : logger(logger) {}
bool Create()
{
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0)
{
logger.Log(LogLevel::WARNING, __FILE__, __LINE__, "WSAStartup failed");
return false;
}
m_socket = socket(AF_INET, SOCK_STREAM, 0);
if (m_socket == INVALID_SOCKET)
{
logger.Log(LogLevel::WARNING, __FILE__, __LINE__, "Failed to create socket: " + std::to_string(WSAGetLastError()));
return false;
}
return true;
}
bool Bind(int port)
{
sockaddr_in hint;
hint.sin_family = AF_INET;
hint.sin_port = htons(port);
hint.sin_addr.s_addr = INADDR_ANY;
return bind(m_socket, (sockaddr*)&hint, sizeof(hint)) != SOCKET_ERROR;
}
bool Listen()
{
return listen(m_socket, SOMAXCONN) != SOCKET_ERROR;
}
bool Accept(Socket& clientSocket)
{
SOCKET client = accept(m_socket, nullptr, nullptr);
if (client != INVALID_SOCKET)
{
clientSocket.m_socket = client;
return true;
}
return false;
}
bool Connect(const char* ipAddress, int port)
{
sockaddr_in hint;
hint.sin_family = AF_INET;
hint.sin_port = htons(port);
if (inet_pton(AF_INET, ipAddress, &hint.sin_addr) <= 0)
{
// Handle error, unable to convert IP address
return false;
}
return connect(m_socket, (sockaddr*)&hint, sizeof(hint)) != SOCKET_ERROR;
}
int Send(const char* data, int dataSize)
{
return send(m_socket, data, dataSize, 0);
}
int Receive(char* buffer, int bufferSize)
{
return recv(m_socket, buffer, bufferSize, 0);
}
void Close()
{
if (m_socket != INVALID_SOCKET)
{
closesocket(m_socket);
m_socket = INVALID_SOCKET;
logger.Log(LogLevel::DEBUG, __FILE__, __LINE__, "Socket closed.");
}
}
bool SendHttpRequest(const std::string& url, int port, const std::string& httpRequest, std::string& httpResponse)
{
if (!Create())
{
return false;
}
if (!Connect(url.c_str(), port))
{
return false;
}
if (Send(httpRequest.c_str(), httpRequest.size()) != static_cast<int>(httpRequest.size()))
{
return false;
}
const int bufferSize = 1024;
char recvBuffer[bufferSize];
httpResponse.clear();
int bytesRead = 0;
do
{
bytesRead = Receive(recvBuffer, bufferSize);
if (bytesRead > 0)
{
httpResponse.append(recvBuffer, bytesRead);
}
} while (bytesRead > 0);
return true;
}
private:
SOCKET m_socket;
Logger& logger; // Reference to the Logger instance
};
tests_multisock.cpp
#include "logger.h"
#include "socket.h"
#include "thread.h"
#include <gtest/gtest.h>
class MultiClientCommunicationTest : public testing::Test {
protected:
Logger logger;
void SetUp() override {}
void TearDown() override {}
};
/// Starts a server whose ThreadPool runs one Accept+Send task per expected
/// connection, and kClients client threads that each connect, read the
/// greeting and check it.
///
/// Results are recorded in per-thread slots and asserted on the main thread
/// after every thread is joined: ASSERT_* inside a helper thread only returns
/// from that lambda, and gtest assertions are not thread-safe on platforms
/// without pthreads (e.g. Windows).
/// NOTE: if a client never manages to connect, its matching Accept task never
/// returns and serverThread.join() blocks; the test runner's timeout is the
/// backstop for that case.
TEST_F(MultiClientCommunicationTest, MultipleClients) {
    static constexpr int kClients = 50;
    static constexpr int kPort = 12345;
    static constexpr char kMessage[] = "Hello from server!";
    // The terminating NUL is transmitted too, as in the original protocol.
    static constexpr int kMessageSize = static_cast<int>(sizeof(kMessage));

    Logger serverLogger(LogLevel::DEBUG, "server_log.txt");
    Logger clientLogger(LogLevel::DEBUG, "client_log.txt");

    // Clients must not connect before the server is listening.
    std::mutex readyMutex;
    std::condition_variable readyCondition;
    bool serverReady = false;     // server finished setup (successfully or not)
    bool serverListening = false; // setup succeeded

    // One slot per task/thread, each written by exactly one thread and read only
    // after all threads are joined. vector<char>, not vector<bool>, whose packed
    // bits cannot safely be written from different threads.
    std::vector<char> serverSent(kClients, 0);
    std::vector<char> clientReceived(kClients, 0);

    // Start a server thread
    std::thread serverThread([&]() {
        Socket serverSocket(serverLogger);
        const bool listening = serverSocket.Create() && serverSocket.Bind(kPort) && serverSocket.Listen();
        {
            std::lock_guard<std::mutex> lock(readyMutex);
            serverReady = true;
            serverListening = listening;
        }
        readyCondition.notify_all();
        if (!listening) {
            serverSocket.Close();
            return;
        }
        serverLogger.Log(LogLevel::DEBUG, __FILE__, __LINE__, "Server listening");
        {
            // Scoped so the pool's destructor (which waits for every queued task)
            // runs before the listening socket is closed.
            ThreadPool threadPool(kClients, serverLogger);
            for (int i = 0; i < kClients; ++i) {
                threadPool.Enqueue([&serverSocket, &serverLogger, &serverSent, i]() {
                    Socket clientSocket(serverLogger);
                    if (!serverSocket.Accept(clientSocket)) {
                        return;
                    }
                    serverLogger.Log(LogLevel::DEBUG, __FILE__, __LINE__, "Server accepted client connection");
                    if (clientSocket.Send(kMessage, kMessageSize) == kMessageSize) {
                        serverSent[i] = 1;
                        serverLogger.Log(LogLevel::DEBUG, __FILE__, __LINE__, "Server sent message to client");
                    }
                    clientSocket.Close();
                });
            }
        }
        serverSocket.Close();
    });

    // Start client threads
    std::vector<std::thread> clientThreads;
    clientThreads.reserve(kClients);
    for (int i = 0; i < kClients; ++i) {
        clientThreads.emplace_back([&, i]() {
            {
                std::unique_lock<std::mutex> lock(readyMutex);
                readyCondition.wait(lock, [&]() { return serverReady; });
                if (!serverListening) {
                    return;
                }
            }
            Socket clientSocket(clientLogger);
            if (!clientSocket.Create() || !clientSocket.Connect("127.0.0.1", kPort)) {
                clientSocket.Close();
                return;
            }
            clientLogger.Log(LogLevel::DEBUG, __FILE__, __LINE__, "Client socket connected to server");
            // TCP is a byte stream: keep reading until the whole message has
            // arrived or the connection fails/closes.
            std::string received;
            char recvBuffer[1024];
            while (static_cast<int>(received.size()) < kMessageSize) {
                const int bytesRead = clientSocket.Receive(recvBuffer, static_cast<int>(sizeof(recvBuffer)));
                if (bytesRead <= 0) {
                    break;
                }
                received.append(recvBuffer, bytesRead);
            }
            clientSocket.Close();
            if (received == std::string(kMessage, kMessageSize)) {
                clientLogger.Log(LogLevel::DEBUG, __FILE__, __LINE__, "Client received message from server");
                clientReceived[i] = 1;
            }
        });
    }

    // Wait for client and server threads to finish before asserting anything.
    for (std::thread& thread : clientThreads) {
        thread.join();
    }
    serverThread.join();

    ASSERT_TRUE(serverListening) << "server could not create/bind/listen on port " << kPort;
    for (int i = 0; i < kClients; ++i) {
        EXPECT_TRUE(clientReceived[i] != 0) << "client " << i << " did not receive the greeting";
        EXPECT_TRUE(serverSent[i] != 0) << "server task " << i << " failed to accept or send";
    }
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}