Abstract accept handler

Abstract accept handler for socket&TLS, and add helper function
'connAcceptHandler' to get accept handler by specified type.

Also move acceptTcpHandler into socket.c, and move
acceptTLSHandler into tls.c.

Signed-off-by: zhenwei pi <pizhenwei@bytedance.com>
This commit is contained in:
zhenwei pi 2022-07-27 11:47:50 +08:00
parent 41fff55d52
commit 0ae02ce95b
7 changed files with 62 additions and 50 deletions

View File

@ -2430,7 +2430,7 @@ static int updateHZ(const char **err) {
} }
static int updatePort(const char **err) { static int updatePort(const char **err) {
if (changeListenPort(server.port, &server.ipfd, acceptTcpHandler) == C_ERR) { if (changeListenPort(server.port, &server.ipfd, connAcceptHandler(CONN_TYPE_SOCKET)) == C_ERR) {
*err = "Unable to listen on this port. Check server logs."; *err = "Unable to listen on this port. Check server logs.";
return 0; return 0;
} }
@ -2591,7 +2591,7 @@ static int applyTLSPort(const char **err) {
return 0; return 0;
} }
if (changeListenPort(server.tls_port, &server.tlsfd, acceptTLSHandler) == C_ERR) { if (changeListenPort(server.tls_port, &server.tlsfd, connAcceptHandler(CONN_TYPE_TLS)) == C_ERR) {
*err = "Unable to listen on this port. Check server logs."; *err = "Unable to listen on this port. Check server logs.";
return 0; return 0;
} }

View File

@ -36,8 +36,11 @@
#include <string.h> #include <string.h>
#include <sys/uio.h> #include <sys/uio.h>
#include "ae.h"
#define CONN_INFO_LEN 32 #define CONN_INFO_LEN 32
#define CONN_ADDR_STR_LEN 128 /* Similar to INET6_ADDRSTRLEN, hoping to handle other protocols. */ #define CONN_ADDR_STR_LEN 128 /* Similar to INET6_ADDRSTRLEN, hoping to handle other protocols. */
#define MAX_ACCEPTS_PER_CALL 1000
struct aeEventLoop; struct aeEventLoop;
typedef struct connection connection; typedef struct connection connection;
@ -71,6 +74,7 @@ typedef struct ConnectionType {
/* ae & accept & listen & error & address handler */ /* ae & accept & listen & error & address handler */
void (*ae_handler)(struct aeEventLoop *el, int fd, void *clientData, int mask); void (*ae_handler)(struct aeEventLoop *el, int fd, void *clientData, int mask);
aeFileProc *accept_handler;
int (*addr)(connection *conn, char *ip, size_t ip_len, int *port, int remote); int (*addr)(connection *conn, char *ip, size_t ip_len, int *port, int remote);
/* create/close connection */ /* create/close connection */
@ -381,6 +385,14 @@ int connTypeHasPendingData(void);
/* walk all the connection types and process pending data for each connection type */ /* walk all the connection types and process pending data for each connection type */
int connTypeProcessPendingData(void); int connTypeProcessPendingData(void);
/* Get accept_handler of a connection type */
static inline aeFileProc *connAcceptHandler(int type) {
ConnectionType *ct = connectionByType(type);
if (ct)
return ct->accept_handler;
return NULL;
}
int RedisRegisterConnectionTypeSocket(); int RedisRegisterConnectionTypeSocket();
int RedisRegisterConnectionTypeTLS(); int RedisRegisterConnectionTypeTLS();

View File

@ -1255,8 +1255,7 @@ void clientAcceptHandler(connection *conn) {
c); c);
} }
#define MAX_ACCEPTS_PER_CALL 1000 void acceptCommonHandler(connection *conn, int flags, char *ip) {
static void acceptCommonHandler(connection *conn, int flags, char *ip) {
client *c; client *c;
char conninfo[100]; char conninfo[100];
UNUSED(ip); UNUSED(ip);
@ -1328,46 +1327,6 @@ static void acceptCommonHandler(connection *conn, int flags, char *ip) {
} }
} }
void acceptTcpHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
int cport, cfd, max = MAX_ACCEPTS_PER_CALL;
char cip[NET_IP_STR_LEN];
UNUSED(el);
UNUSED(mask);
UNUSED(privdata);
while(max--) {
cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport);
if (cfd == ANET_ERR) {
if (errno != EWOULDBLOCK)
serverLog(LL_WARNING,
"Accepting client connection: %s", server.neterr);
return;
}
serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport);
acceptCommonHandler(connCreateAccepted(CONN_TYPE_SOCKET, cfd, NULL),0,cip);
}
}
void acceptTLSHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
int cport, cfd, max = MAX_ACCEPTS_PER_CALL;
char cip[NET_IP_STR_LEN];
UNUSED(el);
UNUSED(mask);
UNUSED(privdata);
while(max--) {
cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport);
if (cfd == ANET_ERR) {
if (errno != EWOULDBLOCK)
serverLog(LL_WARNING,
"Accepting client connection: %s", server.neterr);
return;
}
serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport);
acceptCommonHandler(connCreateAccepted(CONN_TYPE_TLS, cfd, &server.tls_auth_clients),0,cip);
}
}
void acceptUnixHandler(aeEventLoop *el, int fd, void *privdata, int mask) { void acceptUnixHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
int cfd, max = MAX_ACCEPTS_PER_CALL; int cfd, max = MAX_ACCEPTS_PER_CALL;
UNUSED(el); UNUSED(el);

View File

@ -2586,10 +2586,10 @@ void initServer(void) {
/* Create an event handler for accepting new connections in TCP and Unix /* Create an event handler for accepting new connections in TCP and Unix
* domain sockets. */ * domain sockets. */
if (createSocketAcceptHandler(&server.ipfd, acceptTcpHandler) != C_OK) { if (createSocketAcceptHandler(&server.ipfd, connAcceptHandler(CONN_TYPE_SOCKET)) != C_OK) {
serverPanic("Unrecoverable error creating TCP socket accept handler."); serverPanic("Unrecoverable error creating TCP socket accept handler.");
} }
if (createSocketAcceptHandler(&server.tlsfd, acceptTLSHandler) != C_OK) { if (createSocketAcceptHandler(&server.tlsfd, connAcceptHandler(CONN_TYPE_TLS)) != C_OK) {
serverPanic("Unrecoverable error creating TLS socket accept handler."); serverPanic("Unrecoverable error creating TLS socket accept handler.");
} }
if (createSocketAcceptHandler(&server.sofd, acceptUnixHandler) != C_OK) { if (createSocketAcceptHandler(&server.sofd, acceptUnixHandler) != C_OK) {
@ -6282,10 +6282,10 @@ int changeBindAddr(void) {
} }
/* Create TCP and TLS event handlers */ /* Create TCP and TLS event handlers */
if (createSocketAcceptHandler(&server.ipfd, acceptTcpHandler) != C_OK) { if (createSocketAcceptHandler(&server.ipfd, connAcceptHandler(CONN_TYPE_SOCKET)) != C_OK) {
serverPanic("Unrecoverable error creating TCP socket accept handler."); serverPanic("Unrecoverable error creating TCP socket accept handler.");
} }
if (createSocketAcceptHandler(&server.tlsfd, acceptTLSHandler) != C_OK) { if (createSocketAcceptHandler(&server.tlsfd, connAcceptHandler(CONN_TYPE_TLS)) != C_OK) {
serverPanic("Unrecoverable error creating TLS socket accept handler."); serverPanic("Unrecoverable error creating TLS socket accept handler.");
} }

View File

@ -2460,8 +2460,7 @@ void setDeferredSetLen(client *c, void *node, long length);
void setDeferredAttributeLen(client *c, void *node, long length); void setDeferredAttributeLen(client *c, void *node, long length);
void setDeferredPushLen(client *c, void *node, long length); void setDeferredPushLen(client *c, void *node, long length);
int processInputBuffer(client *c); int processInputBuffer(client *c);
void acceptTcpHandler(aeEventLoop *el, int fd, void *privdata, int mask); void acceptCommonHandler(connection *conn, int flags, char *ip);
void acceptTLSHandler(aeEventLoop *el, int fd, void *privdata, int mask);
void acceptUnixHandler(aeEventLoop *el, int fd, void *privdata, int mask); void acceptUnixHandler(aeEventLoop *el, int fd, void *privdata, int mask);
void readQueryFromClient(connection *conn); void readQueryFromClient(connection *conn);
int prepareClientToWrite(client *c); int prepareClientToWrite(client *c);

View File

@ -301,6 +301,26 @@ static void connSocketEventHandler(struct aeEventLoop *el, int fd, void *clientD
} }
} }
static void connSocketAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
int cport, cfd, max = MAX_ACCEPTS_PER_CALL;
char cip[NET_IP_STR_LEN];
UNUSED(el);
UNUSED(mask);
UNUSED(privdata);
while(max--) {
cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport);
if (cfd == ANET_ERR) {
if (errno != EWOULDBLOCK)
serverLog(LL_WARNING,
"Accepting client connection: %s", server.neterr);
return;
}
serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport);
acceptCommonHandler(connCreateAcceptedSocket(cfd, NULL),0,cip);
}
}
static int connSocketAddr(connection *conn, char *ip, size_t ip_len, int *port, int remote) { static int connSocketAddr(connection *conn, char *ip, size_t ip_len, int *port, int remote) {
if (anetFdToString(conn->fd, ip, ip_len, port, remote) == 0) if (anetFdToString(conn->fd, ip, ip_len, port, remote) == 0)
return C_OK; return C_OK;
@ -360,6 +380,7 @@ static ConnectionType CT_Socket = {
/* ae & accept & listen & error & address handler */ /* ae & accept & listen & error & address handler */
.ae_handler = connSocketEventHandler, .ae_handler = connSocketEventHandler,
.accept_handler = connSocketAcceptHandler,
.addr = connSocketAddr, .addr = connSocketAddr,
/* create/close connection */ /* create/close connection */

View File

@ -719,6 +719,26 @@ static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, in
tlsHandleEvent(conn, mask); tlsHandleEvent(conn, mask);
} }
static void tlsAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
int cport, cfd, max = MAX_ACCEPTS_PER_CALL;
char cip[NET_IP_STR_LEN];
UNUSED(el);
UNUSED(mask);
UNUSED(privdata);
while(max--) {
cfd = anetTcpAccept(server.neterr, fd, cip, sizeof(cip), &cport);
if (cfd == ANET_ERR) {
if (errno != EWOULDBLOCK)
serverLog(LL_WARNING,
"Accepting client connection: %s", server.neterr);
return;
}
serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport);
acceptCommonHandler(connCreateAcceptedTLS(cfd, &server.tls_auth_clients),0,cip);
}
}
static int connTLSAddr(connection *conn, char *ip, size_t ip_len, int *port, int remote) { static int connTLSAddr(connection *conn, char *ip, size_t ip_len, int *port, int remote) {
return anetFdToString(conn->fd, ip, ip_len, port, remote); return anetFdToString(conn->fd, ip, ip_len, port, remote);
} }
@ -1082,6 +1102,7 @@ static ConnectionType CT_TLS = {
/* ae & accept & listen & error & address handler */ /* ae & accept & listen & error & address handler */
.ae_handler = tlsEventHandler, .ae_handler = tlsEventHandler,
.accept_handler = tlsAcceptHandler,
.addr = connTLSAddr, .addr = connTLSAddr,
/* create/close connection */ /* create/close connection */