Skip to content

[feat] add tcp store for rendezvous #410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
feat: tcp store success
  • Loading branch information
konnase committed Feb 13, 2025
commit ab9b40eff8e10c0857b9600679573bb0d367ff9a
300 changes: 200 additions & 100 deletions gloo/rendezvous/tcp_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
#include "gloo/common/logging.h"

#define BUFFER_SIZE 1024
#define ACTION_SIZE 3
#define SIZE_OF_SIZE 16
#define RESPONSE_SIZE 2

const std::string POST_ACTION_SET = "set";
const std::string POST_ACTION_GET = "get";
const std::string NOT_FOUND = "NOT_FOUND";
Expand All @@ -43,13 +47,12 @@ namespace gloo
TCPStore::TCPStore(const std::string &hostname, int port, int world_size, bool is_master, int timeout)
: hostname_(hostname),
host_ip_(host_to_ip(hostname)),
port_(port),
port_(static_cast<uint16_t>(port)),
world_size_(world_size),
is_master_(is_master),
timeout_(timeout),
data_({})
{
uint16_t PORT = static_cast<uint16_t>(port);
std::cout << "hostname: " << hostname_ << ", " << host_ip_ << ", port: " << port << ", world_size: " << world_size
<< ", is_master: " << is_master << std::endl;
if (is_master)
Expand All @@ -66,7 +69,7 @@ namespace gloo
struct sockaddr_in address;
address.sin_family = AF_INET;
address.sin_addr.s_addr = INADDR_ANY; // 监听所有的网络接口
address.sin_port = htons(PORT);
address.sin_port = htons(port_);

// 绑定 socket 到地址
if (bind(server_fd, (struct sockaddr *)&address, sizeof(address)) < 0)
Expand All @@ -84,110 +87,115 @@ namespace gloo

std::thread(&TCPStore::accept_func, this).detach();
}
else
}

// accept_func, as it appears in a rendered diff: lines from the previous and the
// new revision of this function are interleaved here without +/- markers.
// NOTE(review): read the statements individually; as shown, this span is not a
// single compilable function body (two signatures, duplicated declarations).
//
// The visible statements accept a connection on server_fd, read an action token
// of ACTION_SIZE bytes, and branch on POST_ACTION_SET / POST_ACTION_GET to store
// into or look up from data_ while holding mtx, replying on the same socket.
void TCPStore::accept_func()
{

// Server loop: keep accepting client connections
while (true)
{
// Create the socket
server_fd = socket(AF_INET, SOCK_STREAM, 0);
if (server_fd == -1)
// Accept a client connection
int new_socket;
struct sockaddr_in client_address;
socklen_t addr_len = sizeof(client_address);
new_socket = accept(server_fd, (struct sockaddr *)&client_address, &addr_len);
if (new_socket < 0)
{
auto err = std::string("Socket creation failed: ") + strerror(errno);
auto err = std::string("Accept client connection failed: ") + strerror(errno);
GLOO_THROW(err);
}

// Fill in the server address
server_address.sin_family = AF_INET;
server_address.sin_port = htons(PORT);
std::cout << "Connection established with client." << std::endl;

// Convert the IP address from its text form to binary form
if (inet_pton(AF_INET, host_ip_.c_str(), &server_address.sin_addr) <= 0)
// Read the client's message
// The extra zero-initialised byte keeps act_buffer NUL-terminated for the
// std::string conversion below.
// NOTE(review): none of the action/size/key/data reads below test valread
// before the buffer is used.
char act_buffer[ACTION_SIZE + 1] = {0};
int valread = read(new_socket, act_buffer, ACTION_SIZE);
std::string action = std::string(act_buffer);
if (action == POST_ACTION_SET)
{
auto err = std::string("Invalid address: ") + strerror(errno);
GLOO_THROW(err);
}
std::cout << "Set request received." << std::endl;

// Connect to the server
if (connect(server_fd, (struct sockaddr *)&server_address, sizeof(server_address)) < 0)
{
auto err = std::string("Connection to server failed: ") + strerror(errno);
GLOO_THROW(err);
}
}
}
// read key size
// The length arrives as a fixed-width field of SIZE_OF_SIZE characters and is
// parsed with atoi.
char key_size_buffer[SIZE_OF_SIZE + 1] = {0};
int valread = read(new_socket, key_size_buffer, SIZE_OF_SIZE);
int key_size = atoi(key_size_buffer);
std::cout << "key size: " << key_size << std::endl;

void TCPStore::accept_func()
{
// Accept a client connection
int new_socket;
struct sockaddr_in client_address;
socklen_t addr_len = sizeof(client_address);
new_socket = accept(server_fd, (struct sockaddr *)&client_address, &addr_len);
if (new_socket < 0)
{
auto err = std::string("Accept client connection failed: ") + strerror(errno);
GLOO_THROW(err);
}
// read key
// NOTE(review): the array length is a runtime value (a compiler extension in
// C++), and key_size comes from the peer with no range check visible here.
char key_buffer[key_size + 1] = {0};
valread = read(new_socket, key_buffer, key_size);
std::string key = std::string(key_buffer);
std::cout << "key: " << key << std::endl;

std::cout << "Connection established with client." << std::endl;
// read data size
char data_size_buffer[SIZE_OF_SIZE + 1] = {0};
valread = read(new_socket, data_size_buffer, SIZE_OF_SIZE);
int data_size = atoi(data_size_buffer);
std::cout << "data size: " << data_size << std::endl;

// Server loop: keep accepting client connections
while (true)
{
// Read the client's message
char buffer[BUFFER_SIZE] = {0};
int valread = read(new_socket, buffer, BUFFER_SIZE);
if (valread > 0)
// read data
// NOTE(review): runtime-sized array again; data_size is taken from the peer.
char data_buffer[data_size + 1] = {0};
valread = read(new_socket, data_buffer, data_size);
std::string value = std::string(data_buffer);
// value_vec copies exactly data_size bytes, so it is not cut at an embedded
// NUL the way the std::string built from data_buffer above is.
std::vector<char> value_vec(data_buffer, data_buffer + data_size);
std::cout << "data_buffer: <" << data_buffer << ">" << std::endl;
std::cout << "value read: " << valread << "value: <" << value << ">" << std::endl;

// Store the value while holding mtx
mtx.lock();
data_[key] = value_vec;
mtx.unlock();

// Send the response to the client
const char *response = "OK";
send(new_socket, response, strlen(response), 0);
// std::cout << "Response sent to client." << std::endl;
}
else if (action == POST_ACTION_GET)
{
std::string buffer_str = std::string(buffer);
std::vector<std::string> buffer_split = str_split(buffer_str, ':');
if (buffer_split.size() < 2)
std::cout << "Get request received." << std::endl;
// read key size
char key_size_buffer[SIZE_OF_SIZE + 1] = {0};
int valread = read(new_socket, key_size_buffer, SIZE_OF_SIZE);
int key_size = atoi(key_size_buffer);

// read key
char key_buffer[key_size + 1] = {0};
valread = read(new_socket, key_buffer, key_size);
std::string key = std::string(key_buffer);
std::cout << "get key: " << key << std::endl;

bool found = false;
std::vector<char> value = {};

// Look the key up while holding mtx; value is copied out so the reply can be
// sent after the lock is released.
mtx.lock();
if (data_.find(key) != data_.end())
{
GLOO_THROW("Invalid message format, must be formated as [action]:[key]:[value] or [action]:[key]!");
found = true;
value = data_[key];
}
mtx.unlock();

std::string action = buffer_split[0];
if (action == POST_ACTION_SET)
if (found)
{
std::string key = buffer_split[1];
std::string value = buffer_split[2];
std::vector<char> value_vec(value.begin(), value.end());
mtx.lock();
data_[key] = value_vec;
mtx.unlock();

// Send the response to the client
const char *response = "OK";
send(new_socket, response, strlen(response), 0);
// std::cout << "Response sent to client." << std::endl;
}
else if (action == POST_ACTION_GET)
{
std::string key = buffer_split[1];
bool found = false;
std::vector<char> value = {};

mtx.lock();
if (data_.find(key) != data_.end())
{
found = true;
value = data_[key];
}
mtx.unlock();

std::string value_str(value.begin(), value.end());
value_str = found ? value_str : NOT_FOUND;
const char *response = value_str.c_str();
send(new_socket, response, strlen(response), 0);
// Send the stored bytes using the vector's own size rather than strlen
send(new_socket, value.data(), value.size(), 0);
}
else
{
// Send the response to the client
const char *response = "OK";
const char *response = NOT_FOUND.c_str();
send(new_socket, response, strlen(response), 0);
// std::cout << "Response sent to client." << std::endl;
}
}
else
{
// Send the response to the client
const char *response = "OK";
send(new_socket, response, strlen(response), 0);
std::cout << "Response sent to client." << std::endl;
}

close(new_socket);
}
close(new_socket);
}

void TCPStore::set(const std::string &key, const std::vector<char> &data)
Expand All @@ -200,16 +208,68 @@ namespace gloo
}
else
{
// 向服务器发送消息
std::string key_with_data = POST_ACTION_SET + ":" + key + ":" + std::string(data.begin(), data.end());
const char *message = key_with_data.c_str();
send(server_fd, message, strlen(message), 0);
// std::cout << "Message sent to server." << std::endl;
// 创建 socket
int new_server_fd = socket(AF_INET, SOCK_STREAM, 0);
if (new_server_fd == -1)
{
auto err = std::string("Socket creation failed: ") + strerror(errno);
GLOO_THROW(err);
}

// 设置服务器地址信息
server_address.sin_family = AF_INET;
server_address.sin_port = htons(port_);

// 将 IP 地址从文本转换为二进制形式
if (inet_pton(AF_INET, host_ip_.c_str(), &server_address.sin_addr) <= 0)
{
auto err = std::string("Invalid address: ") + strerror(errno);
GLOO_THROW(err);
}

// 连接服务器
if (connect(new_server_fd, (struct sockaddr *)&server_address, sizeof(server_address)) < 0)
{
auto err = std::string("Connection to server failed: ") + strerror(errno);
GLOO_THROW(err);
}

// send action
std::string act_data = POST_ACTION_SET;
const char *message = act_data.c_str();
send(new_server_fd, message, strlen(message), 0);

// send key size
size_t len = key.length();
std::string len_str = std::to_string(len);
len_str = std::string(SIZE_OF_SIZE - len_str.length(), '0') + len_str;
message = len_str.c_str();
send(new_server_fd, message, strlen(message), 0);
std::cout << "key size: " << len_str << std::endl;

// send key
message = key.c_str();
send(new_server_fd, message, strlen(message), 0);
std::cout << "key: " << key << std::endl;

// send data size
len = data.size();
len_str = std::to_string(len);
len_str = std::string(SIZE_OF_SIZE - len_str.length(), '0') + len_str;
message = len_str.c_str();
send(new_server_fd, message, strlen(message), 0);
std::cout << "data size: " << len_str << std::endl;

// send data
void *data_ptr = static_cast<void *>(const_cast<char *>(data.data()));
send(new_server_fd, data_ptr, len, 0);

// 读取服务器响应
char buffer[BUFFER_SIZE] = {0};
int valread = read(server_fd, buffer, BUFFER_SIZE);
// std::cout << "Server response: " << buffer << std::endl;
char buffer[RESPONSE_SIZE] = {0};
int valread = read(new_server_fd, buffer, RESPONSE_SIZE);
std::cout << key << " set request, server response: " << buffer << std::endl;

close(new_server_fd);
}
}

Expand All @@ -234,26 +294,65 @@ namespace gloo
}
else
{
// 向服务器发送消息
std::string key_with_data = POST_ACTION_GET + ":" + key;
const char *message = key_with_data.c_str();
send(server_fd, message, strlen(message), 0);
// 创建 socket
int new_server_fd = socket(AF_INET, SOCK_STREAM, 0);
if (new_server_fd == -1)
{
auto err = std::string("Socket creation failed: ") + strerror(errno);
GLOO_THROW(err);
}

// 设置服务器地址信息
server_address.sin_family = AF_INET;
server_address.sin_port = htons(port_);

// 将 IP 地址从文本转换为二进制形式
if (inet_pton(AF_INET, host_ip_.c_str(), &server_address.sin_addr) <= 0)
{
auto err = std::string("Invalid address: ") + strerror(errno);
GLOO_THROW(err);
}

// 连接服务器
if (connect(new_server_fd, (struct sockaddr *)&server_address, sizeof(server_address)) < 0)
{
auto err = std::string("Connection to server failed: ") + strerror(errno);
GLOO_THROW(err);
}

// send action
std::string act_data = POST_ACTION_GET;
const char *message = act_data.c_str();
send(new_server_fd, message, strlen(message), 0);
// std::cout << "Message sent to server." << std::endl;

// send key size
size_t len = key.length();
std::string len_str = std::to_string(len);
len_str = std::string(SIZE_OF_SIZE - len_str.length(), '0') + len_str;
message = len_str.c_str();
send(new_server_fd, message, strlen(message), 0);

// send key
message = key.c_str();
send(new_server_fd, message, strlen(message), 0);

// 读取服务器响应
char buffer[BUFFER_SIZE] = {0};
int valread = read(server_fd, buffer, BUFFER_SIZE);
int valread = read(new_server_fd, buffer, BUFFER_SIZE);
if (valread > 0)
{
std::string buffer_str = std::string(buffer);
// std::cout << "Server response: " << buffer_str << std::endl;
std::cout << key << " get request, server response: " << buffer_str << std::endl;

return std::vector<char>(buffer_str.begin(), buffer_str.end());
return std::vector<char>(buffer, buffer + valread);
}
else
{
GLOO_THROW("Server response failed!");
}

close(new_server_fd);
}
}

Expand All @@ -268,7 +367,7 @@ namespace gloo
{
auto data = get(key);
std::string buffer_str(data.begin(), data.end());
// std::cout << "key: " << key << ", data: <" << buffer_str << ">" << std::endl;
std::cout << "key: " << key << ", data: <" << buffer_str << ">" << std::endl;
if (buffer_str == NOT_FOUND)
{
return false;
Expand All @@ -288,6 +387,7 @@ namespace gloo
}
/* sleep override */
std::this_thread::sleep_for(std::chrono::milliseconds(10));
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
}
}

Expand Down
Loading