Abstract accept handler

Abstract accept handler for socket&TLS, and add helper function
'connAcceptHandler' to get accept handler by specified type.

Also move acceptTcpHandler into socket.c, and move
acceptTLSHandler into tls.c.

Signed-off-by: zhenwei pi <pizhenwei@bytedance.com>
This commit is contained in:
zhenwei pi 2022-07-27 11:47:50 +08:00
parent 41fff55d52
commit 0ae02ce95b
7 changed files with 62 additions and 50 deletions

View File

@ -2430,7 +2430,7 @@ static int updateHZ(const char **err) {
}
static int updatePort(const char **err) {
if (changeListenPort(server.port, &server.ipfd, acceptTcpHandler) == C_ERR) {
if (changeListenPort(server.port, &server.ipfd, connAcceptHandler(CONN_TYPE_SOCKET)) == C_ERR) {
*err = "Unable to listen on this port. Check server logs.";
return 0;
}
@ -2591,7 +2591,7 @@ static int applyTLSPort(const char **err) {
return 0;
}
if (changeListenPort(server.tls_port, &server.tlsfd, acceptTLSHandler) == C_ERR) {
if (changeListenPort(server.tls_port, &server.tlsfd, connAcceptHandler(CONN_TYPE_TLS)) == C_ERR) {
*err = "Unable to listen on this port. Check server logs.";
return 0;
}

View File

@ -36,8 +36,11 @@
#include <string.h>
#include <sys/uio.h>
#include "ae.h"
#define CONN_INFO_LEN 32
#define CONN_ADDR_STR_LEN 128 /* Similar to INET6_ADDRSTRLEN, hoping to handle other protocols. */
#define MAX_ACCEPTS_PER_CALL 1000
struct aeEventLoop;
typedef struct connection connection;
@ -71,6 +74,7 @@ typedef struct ConnectionType {
/* ae & accept & listen & error & address handler */
void (*ae_handler)(struct aeEventLoop *el, int fd, void *clientData, int mask);
aeFileProc *accept_handler;
int (*addr)(connection *conn, char *ip, size_t ip_len, int *port, int remote);
/* create/close connection */
@ -381,6 +385,14 @@ int connTypeHasPendingData(void);
/* walk all the connection types and process pending data for each connection type */
int connTypeProcessPendingData(void);
/* Get accept_handler of a connection type */
static inline aeFileProc *connAcceptHandler(int type) {
ConnectionType *ct = connectionByType(type);
if (ct)
return ct->accept_handler;
return NULL;
}
int RedisRegisterConnectionTypeSocket();
int RedisRegisterConnectionTypeTLS();

View File

@ -1255,8 +1255,7 @@ void clientAcceptHandler(connection *conn) {
c);
}
#define MAX_ACCEPTS_PER_CALL 1000
static void acceptCommonHandler(connection *conn, int flags, char *ip) {
void acceptCommonHandler(connection *conn, int flags, char *ip) {
client *c;
char conninfo[100];
UNUSED(ip);
@ -1328,46 +1327,6 @@ static void acceptCommonHandler(connection *conn, int flags, char *ip) {
}
}
void acceptTcpHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
int cport, cfd, max = MAX_ACCEPTS_PER_CALL;
char cip[NET_IP_STR_LEN];
UNUSED(el);
UNUSED(mask);
UNUSED(privdata);
while(max--) {
cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport);
if (cfd == ANET_ERR) {
if (errno != EWOULDBLOCK)
serverLog(LL_WARNING,
"Accepting client connection: %s", server.neterr);
return;
}
serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport);
acceptCommonHandler(connCreateAccepted(CONN_TYPE_SOCKET, cfd, NULL),0,cip);
}
}
void acceptTLSHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
int cport, cfd, max = MAX_ACCEPTS_PER_CALL;
char cip[NET_IP_STR_LEN];
UNUSED(el);
UNUSED(mask);
UNUSED(privdata);
while(max--) {
cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport);
if (cfd == ANET_ERR) {
if (errno != EWOULDBLOCK)
serverLog(LL_WARNING,
"Accepting client connection: %s", server.neterr);
return;
}
serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport);
acceptCommonHandler(connCreateAccepted(CONN_TYPE_TLS, cfd, &server.tls_auth_clients),0,cip);
}
}
void acceptUnixHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
int cfd, max = MAX_ACCEPTS_PER_CALL;
UNUSED(el);

View File

@ -2586,10 +2586,10 @@ void initServer(void) {
/* Create an event handler for accepting new connections in TCP and Unix
* domain sockets. */
if (createSocketAcceptHandler(&server.ipfd, acceptTcpHandler) != C_OK) {
if (createSocketAcceptHandler(&server.ipfd, connAcceptHandler(CONN_TYPE_SOCKET)) != C_OK) {
serverPanic("Unrecoverable error creating TCP socket accept handler.");
}
if (createSocketAcceptHandler(&server.tlsfd, acceptTLSHandler) != C_OK) {
if (createSocketAcceptHandler(&server.tlsfd, connAcceptHandler(CONN_TYPE_TLS)) != C_OK) {
serverPanic("Unrecoverable error creating TLS socket accept handler.");
}
if (createSocketAcceptHandler(&server.sofd, acceptUnixHandler) != C_OK) {
@ -6282,10 +6282,10 @@ int changeBindAddr(void) {
}
/* Create TCP and TLS event handlers */
if (createSocketAcceptHandler(&server.ipfd, acceptTcpHandler) != C_OK) {
if (createSocketAcceptHandler(&server.ipfd, connAcceptHandler(CONN_TYPE_SOCKET)) != C_OK) {
serverPanic("Unrecoverable error creating TCP socket accept handler.");
}
if (createSocketAcceptHandler(&server.tlsfd, acceptTLSHandler) != C_OK) {
if (createSocketAcceptHandler(&server.tlsfd, connAcceptHandler(CONN_TYPE_TLS)) != C_OK) {
serverPanic("Unrecoverable error creating TLS socket accept handler.");
}

View File

@ -2460,8 +2460,7 @@ void setDeferredSetLen(client *c, void *node, long length);
void setDeferredAttributeLen(client *c, void *node, long length);
void setDeferredPushLen(client *c, void *node, long length);
int processInputBuffer(client *c);
void acceptTcpHandler(aeEventLoop *el, int fd, void *privdata, int mask);
void acceptTLSHandler(aeEventLoop *el, int fd, void *privdata, int mask);
void acceptCommonHandler(connection *conn, int flags, char *ip);
void acceptUnixHandler(aeEventLoop *el, int fd, void *privdata, int mask);
void readQueryFromClient(connection *conn);
int prepareClientToWrite(client *c);

View File

@ -301,6 +301,26 @@ static void connSocketEventHandler(struct aeEventLoop *el, int fd, void *clientD
}
}
static void connSocketAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
int cport, cfd, max = MAX_ACCEPTS_PER_CALL;
char cip[NET_IP_STR_LEN];
UNUSED(el);
UNUSED(mask);
UNUSED(privdata);
while(max--) {
cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport);
if (cfd == ANET_ERR) {
if (errno != EWOULDBLOCK)
serverLog(LL_WARNING,
"Accepting client connection: %s", server.neterr);
return;
}
serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport);
acceptCommonHandler(connCreateAcceptedSocket(cfd, NULL),0,cip);
}
}
static int connSocketAddr(connection *conn, char *ip, size_t ip_len, int *port, int remote) {
if (anetFdToString(conn->fd, ip, ip_len, port, remote) == 0)
return C_OK;
@ -360,6 +380,7 @@ static ConnectionType CT_Socket = {
/* ae & accept & listen & error & address handler */
.ae_handler = connSocketEventHandler,
.accept_handler = connSocketAcceptHandler,
.addr = connSocketAddr,
/* create/close connection */

View File

@ -719,6 +719,26 @@ static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, in
tlsHandleEvent(conn, mask);
}
static void tlsAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
int cport, cfd, max = MAX_ACCEPTS_PER_CALL;
char cip[NET_IP_STR_LEN];
UNUSED(el);
UNUSED(mask);
UNUSED(privdata);
while(max--) {
cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport);
if (cfd == ANET_ERR) {
if (errno != EWOULDBLOCK)
serverLog(LL_WARNING,
"Accepting client connection: %s", server.neterr);
return;
}
serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport);
acceptCommonHandler(connCreateAcceptedTLS(cfd, &server.tls_auth_clients),0,cip);
}
}
static int connTLSAddr(connection *conn, char *ip, size_t ip_len, int *port, int remote) {
return anetFdToString(conn->fd, ip, ip_len, port, remote);
}
@ -1082,6 +1102,7 @@ static ConnectionType CT_TLS = {
/* ae & accept & listen & error & address handler */
.ae_handler = tlsEventHandler,
.accept_handler = tlsAcceptHandler,
.addr = connTLSAddr,
/* create/close connection */