Fully abstract connection type

Abstract common interface of connection type, so Redis can hide the
implementation and uplayer only calls connection API without macro.

               uplayer
                  |
           connection layer
             /          \
          socket        TLS

Currently, for both socket and TLS, all the methods of connection type
are declared as static functions.

It's possible to build TLS(even socket) as a shared library, and Redis
loads it dynamically in the next step.

Also add helper function connTypeOfCluster() and
connTypeOfReplication() to simplify the code:
link->conn = server.tls_cluster ? connCreateTLS() : connCreateSocket();
-> link->conn = connCreate(connTypeOfCluster());

Signed-off-by: zhenwei pi <pizhenwei@bytedance.com>
This commit is contained in:
zhenwei pi 2022-07-27 10:46:31 +08:00
parent c4c02f8036
commit 1234e3a562
7 changed files with 72 additions and 43 deletions

View File

@ -119,6 +119,14 @@ dictType clusterNodesBlackListDictType = {
NULL /* allow to expand */ NULL /* allow to expand */
}; };
static int connTypeOfCluster() {
if (server.tls_cluster) {
return CONN_TYPE_TLS;
}
return CONN_TYPE_SOCKET;
}
/* ----------------------------------------------------------------------------- /* -----------------------------------------------------------------------------
* Initialization * Initialization
* -------------------------------------------------------------------------- */ * -------------------------------------------------------------------------- */
@ -865,6 +873,7 @@ void clusterAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
int cport, cfd; int cport, cfd;
int max = MAX_CLUSTER_ACCEPTS_PER_CALL; int max = MAX_CLUSTER_ACCEPTS_PER_CALL;
char cip[NET_IP_STR_LEN]; char cip[NET_IP_STR_LEN];
int require_auth = TLS_CLIENT_AUTH_YES;
UNUSED(el); UNUSED(el);
UNUSED(mask); UNUSED(mask);
UNUSED(privdata); UNUSED(privdata);
@ -882,8 +891,7 @@ void clusterAcceptHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
return; return;
} }
connection *conn = server.tls_cluster ? connection *conn = connCreateAccepted(connTypeOfCluster(), cfd, &require_auth);
connCreateAcceptedTLS(cfd, TLS_CLIENT_AUTH_YES) : connCreateAcceptedSocket(cfd);
/* Make sure connection is not in an error state */ /* Make sure connection is not in an error state */
if (connGetState(conn) != CONN_STATE_ACCEPTING) { if (connGetState(conn) != CONN_STATE_ACCEPTING) {
@ -3969,7 +3977,7 @@ static int clusterNodeCronHandleReconnect(clusterNode *node, mstime_t handshake_
if (node->link == NULL) { if (node->link == NULL) {
clusterLink *link = createClusterLink(node); clusterLink *link = createClusterLink(node);
link->conn = server.tls_cluster ? connCreateTLS() : connCreateSocket(); link->conn = connCreate(connTypeOfCluster());
connSetPrivateData(link->conn, link); connSetPrivateData(link->conn, link);
if (connConnect(link->conn, node->ip, node->cport, server.bind_source_addr, if (connConnect(link->conn, node->ip, node->cport, server.bind_source_addr,
clusterLinkConnectHandler) == -1) { clusterLinkConnectHandler) == -1) {
@ -6175,8 +6183,8 @@ migrateCachedSocket* migrateGetSocket(client *c, robj *host, robj *port, long ti
dictDelete(server.migrate_cached_sockets,dictGetKey(de)); dictDelete(server.migrate_cached_sockets,dictGetKey(de));
} }
/* Create the socket */ /* Create the connection */
conn = server.tls_cluster ? connCreateTLS() : connCreateSocket(); conn = connCreate(connTypeOfCluster());
if (connBlockingConnect(conn, host->ptr, atoi(port->ptr), timeout) if (connBlockingConnect(conn, host->ptr, atoi(port->ptr), timeout)
!= C_OK) { != C_OK) {
addReplyError(c,"-IOERR error or timeout connecting to the client"); addReplyError(c,"-IOERR error or timeout connecting to the client");

View File

@ -152,3 +152,18 @@ void *connTypeGetClientCtx(int type) {
return NULL; return NULL;
} }
connection *connCreate(int type) {
ConnectionType *ct = connectionByType(type);
serverAssert(ct && ct->conn_create);
return ct->conn_create();
}
connection *connCreateAccepted(int type, int fd, void *priv) {
ConnectionType *ct = connectionByType(type);
serverAssert(ct && ct->conn_create_accepted);
return ct->conn_create_accepted(fd, priv);
}

View File

@ -74,6 +74,8 @@ typedef struct ConnectionType {
int (*addr)(connection *conn, char *ip, size_t ip_len, int *port, int remote); int (*addr)(connection *conn, char *ip, size_t ip_len, int *port, int remote);
/* create/close connection */ /* create/close connection */
connection* (*conn_create)(void);
connection* (*conn_create_accepted)(int fd, void *priv);
void (*close)(struct connection *conn); void (*close)(struct connection *conn);
/* connect & accept */ /* connect & accept */
@ -290,12 +292,6 @@ static inline int connAddrSockName(connection *conn, char *ip, size_t ip_len, in
return connAddr(conn, ip, ip_len, port, 0); return connAddr(conn, ip, ip_len, port, 0);
} }
connection *connCreateSocket();
connection *connCreateAcceptedSocket(int fd);
connection *connCreateTLS();
connection *connCreateAcceptedTLS(int fd, int require_auth);
static inline int connGetState(connection *conn) { static inline int connGetState(connection *conn) {
return conn->state; return conn->state;
} }
@ -358,6 +354,16 @@ int connTypeInitialize();
/* Register a connection type into redis connection framework */ /* Register a connection type into redis connection framework */
int connTypeRegister(ConnectionType *ct); int connTypeRegister(ConnectionType *ct);
/* Lookup a connection type by index */
ConnectionType *connectionByType(int type);
/* Create a connection of specified type */
connection *connCreate(int type);
/* Create a accepted connection of specified type.
* @priv is connection type specified argument */
connection *connCreateAccepted(int type, int fd, void *priv);
/* Configure a connection type. A typical case is to configure TLS. /* Configure a connection type. A typical case is to configure TLS.
* @priv is connection type specified, * @priv is connection type specified,
* @reconfigure is boolean type to specify if overwrite the original config */ * @reconfigure is boolean type to specify if overwrite the original config */

View File

@ -1344,7 +1344,7 @@ void acceptTcpHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
return; return;
} }
serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport); serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport);
acceptCommonHandler(connCreateAcceptedSocket(cfd),0,cip); acceptCommonHandler(connCreateAccepted(CONN_TYPE_SOCKET, cfd, NULL),0,cip);
} }
} }
@ -1364,7 +1364,7 @@ void acceptTLSHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
return; return;
} }
serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport); serverLog(LL_VERBOSE,"Accepted %s:%d", cip, cport);
acceptCommonHandler(connCreateAcceptedTLS(cfd, server.tls_auth_clients),0,cip); acceptCommonHandler(connCreateAccepted(CONN_TYPE_TLS, cfd, &server.tls_auth_clients),0,cip);
} }
} }
@ -1383,7 +1383,7 @@ void acceptUnixHandler(aeEventLoop *el, int fd, void *privdata, int mask) {
return; return;
} }
serverLog(LL_VERBOSE,"Accepted connection to %s", server.unixsocket); serverLog(LL_VERBOSE,"Accepted connection to %s", server.unixsocket);
acceptCommonHandler(connCreateAcceptedSocket(cfd),CLIENT_UNIX_SOCKET,NULL); acceptCommonHandler(connCreateAccepted(CONN_TYPE_SOCKET, cfd, NULL),CLIENT_UNIX_SOCKET,NULL);
} }
} }

View File

@ -55,6 +55,13 @@ int cancelReplicationHandshake(int reconnect);
int RDBGeneratedByReplication = 0; int RDBGeneratedByReplication = 0;
/* --------------------------- Utility functions ---------------------------- */ /* --------------------------- Utility functions ---------------------------- */
static int connTypeOfReplication() {
if (server.tls_replication) {
return CONN_TYPE_TLS;
}
return CONN_TYPE_SOCKET;
}
/* Return the pointer to a string representing the slave ip:listening_port /* Return the pointer to a string representing the slave ip:listening_port
* pair. Mostly useful for logging, since we want to log a slave using its * pair. Mostly useful for logging, since we want to log a slave using its
@ -2864,7 +2871,7 @@ write_error: /* Handle sendCommand() errors. */
} }
int connectWithMaster(void) { int connectWithMaster(void) {
server.repl_transfer_s = server.tls_replication ? connCreateTLS() : connCreateSocket(); server.repl_transfer_s = connCreate(connTypeOfReplication());
if (connConnect(server.repl_transfer_s, server.masterhost, server.masterport, if (connConnect(server.repl_transfer_s, server.masterhost, server.masterport,
server.bind_source_addr, syncWithMaster) == C_ERR) { server.bind_source_addr, syncWithMaster) == C_ERR) {
serverLog(LL_WARNING,"Unable to connect to MASTER: %s", serverLog(LL_WARNING,"Unable to connect to MASTER: %s",

View File

@ -49,7 +49,7 @@
* depending on the implementation (for TCP they are; for TLS they aren't). * depending on the implementation (for TCP they are; for TLS they aren't).
*/ */
ConnectionType CT_Socket; static ConnectionType CT_Socket;
/* When a connection is created we must know its type already, but the /* When a connection is created we must know its type already, but the
* underlying socket may or may not exist: * underlying socket may or may not exist:
@ -74,7 +74,7 @@ ConnectionType CT_Socket;
* be embedded in different structs, not just client. * be embedded in different structs, not just client.
*/ */
connection *connCreateSocket() { static connection *connCreateSocket(void) {
connection *conn = zcalloc(sizeof(connection)); connection *conn = zcalloc(sizeof(connection));
conn->type = &CT_Socket; conn->type = &CT_Socket;
conn->fd = -1; conn->fd = -1;
@ -92,7 +92,8 @@ connection *connCreateSocket() {
* is not in an error state (which is not possible for a socket connection, * is not in an error state (which is not possible for a socket connection,
* but could but possible with other protocols). * but could but possible with other protocols).
*/ */
connection *connCreateAcceptedSocket(int fd) { static connection *connCreateAcceptedSocket(int fd, void *priv) {
UNUSED(priv);
connection *conn = connCreateSocket(); connection *conn = connCreateSocket();
conn->fd = fd; conn->fd = fd;
conn->state = CONN_STATE_ACCEPTING; conn->state = CONN_STATE_ACCEPTING;
@ -348,7 +349,7 @@ static int connSocketGetType(connection *conn) {
return CONN_TYPE_SOCKET; return CONN_TYPE_SOCKET;
} }
ConnectionType CT_Socket = { static ConnectionType CT_Socket = {
/* connection type */ /* connection type */
.get_type = connSocketGetType, .get_type = connSocketGetType,
@ -362,6 +363,8 @@ ConnectionType CT_Socket = {
.addr = connSocketAddr, .addr = connSocketAddr,
/* create/close connection */ /* create/close connection */
.conn_create = connCreateSocket,
.conn_create_accepted = connCreateAcceptedSocket,
.close = connSocketClose, .close = connSocketClose,
/* connect & accept */ /* connect & accept */

View File

@ -56,8 +56,6 @@
#define REDIS_TLS_PROTO_DEFAULT (REDIS_TLS_PROTO_TLSv1_2) #define REDIS_TLS_PROTO_DEFAULT (REDIS_TLS_PROTO_TLSv1_2)
#endif #endif
extern ConnectionType CT_Socket;
static SSL_CTX *redis_tls_ctx = NULL; static SSL_CTX *redis_tls_ctx = NULL;
static SSL_CTX *redis_tls_client_ctx = NULL; static SSL_CTX *redis_tls_client_ctx = NULL;
@ -421,7 +419,7 @@ error:
#define TLSCONN_DEBUG(fmt, ...) #define TLSCONN_DEBUG(fmt, ...)
#endif #endif
ConnectionType CT_TLS; static ConnectionType CT_TLS;
/* Normal socket connections have a simple events/handler correlation. /* Normal socket connections have a simple events/handler correlation.
* *
@ -466,7 +464,7 @@ static connection *createTLSConnection(int client_side) {
return (connection *) conn; return (connection *) conn;
} }
connection *connCreateTLS(void) { static connection *connCreateTLS(void) {
return createTLSConnection(1); return createTLSConnection(1);
} }
@ -487,7 +485,8 @@ static void updateTLSError(tls_connection *conn) {
* Callers should use connGetState() and verify the created connection * Callers should use connGetState() and verify the created connection
* is not in an error state. * is not in an error state.
*/ */
connection *connCreateAcceptedTLS(int fd, int require_auth) { static connection *connCreateAcceptedTLS(int fd, void *priv) {
int require_auth = *(int *)priv;
tls_connection *conn = (tls_connection *) createTLSConnection(0); tls_connection *conn = (tls_connection *) createTLSConnection(0);
conn->c.fd = fd; conn->c.fd = fd;
conn->c.state = CONN_STATE_ACCEPTING; conn->c.state = CONN_STATE_ACCEPTING;
@ -550,7 +549,7 @@ static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType *
return 0; return 0;
} }
void registerSSLEvent(tls_connection *conn, WantIOType want) { static void registerSSLEvent(tls_connection *conn, WantIOType want) {
int mask = aeGetFileEvents(server.el, conn->c.fd); int mask = aeGetFileEvents(server.el, conn->c.fd);
switch (want) { switch (want) {
@ -570,7 +569,7 @@ void registerSSLEvent(tls_connection *conn, WantIOType want) {
} }
} }
void updateSSLEvent(tls_connection *conn) { static void updateSSLEvent(tls_connection *conn) {
int mask = aeGetFileEvents(server.el, conn->c.fd); int mask = aeGetFileEvents(server.el, conn->c.fd);
int need_read = conn->c.read_handler || (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ); int need_read = conn->c.read_handler || (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ);
int need_write = conn->c.write_handler || (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE); int need_write = conn->c.write_handler || (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE);
@ -744,7 +743,7 @@ static void connTLSClose(connection *conn_) {
conn->pending_list_node = NULL; conn->pending_list_node = NULL;
} }
CT_Socket.close(conn_); connectionByType(CONN_TYPE_SOCKET)->close(conn_);
} }
static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) { static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) {
@ -783,7 +782,7 @@ static int connTLSConnect(connection *conn_, const char *addr, int port, const c
ERR_clear_error(); ERR_clear_error();
/* Initiate Socket connection first */ /* Initiate Socket connection first */
if (CT_Socket.connect(conn_, addr, port, src_addr, connect_handler) == C_ERR) return C_ERR; if (connectionByType(CONN_TYPE_SOCKET)->connect(conn_, addr, port, src_addr, connect_handler) == C_ERR) return C_ERR;
/* Return now, once the socket is connected we'll initiate /* Return now, once the socket is connected we'll initiate
* TLS connection from the event handler. * TLS connection from the event handler.
@ -911,7 +910,7 @@ static const char *connTLSGetLastError(connection *conn_) {
return NULL; return NULL;
} }
int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier) { static int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier) {
conn->write_handler = func; conn->write_handler = func;
if (barrier) if (barrier)
conn->flags |= CONN_FLAG_WRITE_BARRIER; conn->flags |= CONN_FLAG_WRITE_BARRIER;
@ -921,7 +920,7 @@ int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int ba
return C_OK; return C_OK;
} }
int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func) { static int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func) {
conn->read_handler = func; conn->read_handler = func;
updateSSLEvent((tls_connection *) conn); updateSSLEvent((tls_connection *) conn);
return C_OK; return C_OK;
@ -946,7 +945,7 @@ static int connTLSBlockingConnect(connection *conn_, const char *addr, int port,
if (conn->c.state != CONN_STATE_NONE) return C_ERR; if (conn->c.state != CONN_STATE_NONE) return C_ERR;
/* Initiate socket blocking connect first */ /* Initiate socket blocking connect first */
if (CT_Socket.blocking_connect(conn_, addr, port, timeout) == C_ERR) return C_ERR; if (connectionByType(CONN_TYPE_SOCKET)->blocking_connect(conn_, addr, port, timeout) == C_ERR) return C_ERR;
/* Initiate TLS connection now. We set up a send/recv timeout on the socket, /* Initiate TLS connection now. We set up a send/recv timeout on the socket,
* which means the specified timeout will not be enforced accurately. */ * which means the specified timeout will not be enforced accurately. */
@ -1072,7 +1071,7 @@ static void *tlsGetClientCtx(void) {
return redis_tls_client_ctx; return redis_tls_client_ctx;
} }
ConnectionType CT_TLS = { static ConnectionType CT_TLS = {
/* connection type */ /* connection type */
.get_type = connTLSGetType, .get_type = connTLSGetType,
@ -1086,6 +1085,8 @@ ConnectionType CT_TLS = {
.addr = connTLSAddr, .addr = connTLSAddr,
/* create/close connection */ /* create/close connection */
.conn_create = connCreateTLS,
.conn_create_accepted = connCreateAcceptedTLS,
.close = connTLSClose, .close = connTLSClose,
/* connect & accept */ /* connect & accept */
@ -1126,15 +1127,4 @@ int RedisRegisterConnectionTypeTLS()
return C_ERR; return C_ERR;
} }
connection *connCreateTLS(void) {
return NULL;
}
connection *connCreateAcceptedTLS(int fd, int require_auth) {
UNUSED(fd);
UNUSED(require_auth);
return NULL;
}
#endif #endif