Module API: fix missing RM_CLIENTINFO_FLAG_SSL. (#7666)

The `REDISMODULE_CLIENTINFO_FLAG_SSL` flag was already a part of the `RedisModuleClientInfo` structure but was not implemented.
This commit is contained in:
Yossi Gottlieb 2020-08-17 17:46:54 +03:00 committed by GitHub
parent fb2a94af3f
commit 64c360c515
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 82 additions and 1 deletions

View File

@ -329,6 +329,11 @@ static ssize_t connSocketSyncReadLine(connection *conn, char *ptr, ssize_t size,
return syncReadLine(conn->fd, ptr, size, timeout);
}
static int connSocketGetType(connection *conn) {
(void) conn;
return CONN_TYPE_SOCKET;
}
ConnectionType CT_Socket = {
.ae_handler = connSocketEventHandler,
@ -343,7 +348,8 @@ ConnectionType CT_Socket = {
.blocking_connect = connSocketBlockingConnect,
.sync_write = connSocketSyncWrite,
.sync_read = connSocketSyncRead,
.sync_readline = connSocketSyncReadLine
.sync_readline = connSocketSyncReadLine,
.get_type = connSocketGetType
};

View File

@ -48,6 +48,9 @@ typedef enum {
#define CONN_FLAG_CLOSE_SCHEDULED (1<<0) /* Closed scheduled by a handler */
#define CONN_FLAG_WRITE_BARRIER (1<<1) /* Write barrier requested */
#define CONN_TYPE_SOCKET 1
#define CONN_TYPE_TLS 2
typedef void (*ConnectionCallbackFunc)(struct connection *conn);
typedef struct ConnectionType {
@ -64,6 +67,7 @@ typedef struct ConnectionType {
ssize_t (*sync_write)(struct connection *conn, char *ptr, ssize_t size, long long timeout);
ssize_t (*sync_read)(struct connection *conn, char *ptr, ssize_t size, long long timeout);
ssize_t (*sync_readline)(struct connection *conn, char *ptr, ssize_t size, long long timeout);
int (*get_type)(struct connection *conn);
} ConnectionType;
struct connection {
@ -194,6 +198,11 @@ static inline ssize_t connSyncReadLine(connection *conn, char *ptr, ssize_t size
return conn->type->sync_readline(conn, ptr, size, timeout);
}
/* Return CONN_TYPE_* for the specified connection */
static inline int connGetType(connection *conn) {
return conn->type->get_type(conn);
}
connection *connCreateSocket();
connection *connCreateAcceptedSocket(int fd);

View File

@ -1753,6 +1753,8 @@ int modulePopulateClientInfoStructure(void *ci, client *client, int structver) {
ci1->flags |= REDISMODULE_CLIENTINFO_FLAG_TRACKING;
if (client->flags & CLIENT_BLOCKED)
ci1->flags |= REDISMODULE_CLIENTINFO_FLAG_BLOCKED;
if (connGetType(client->conn) == CONN_TYPE_TLS)
ci1->flags |= REDISMODULE_CLIENTINFO_FLAG_SSL;
int port;
connPeerToString(client->conn,ci1->addr,sizeof(ci1->addr),&port);

View File

@ -823,6 +823,12 @@ exit:
return nread;
}
static int connTLSGetType(connection *conn_) {
(void) conn_;
return CONN_TYPE_TLS;
}
ConnectionType CT_TLS = {
.ae_handler = tlsEventHandler,
.accept = connTLSAccept,
@ -837,6 +843,7 @@ ConnectionType CT_TLS = {
.sync_write = connTLSSyncWrite,
.sync_read = connTLSSyncRead,
.sync_readline = connTLSSyncReadLine,
.get_type = connTLSGetType
};
int tlsHasPendingData() {

View File

@ -195,6 +195,42 @@ int test_setlfu(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
return REDISMODULE_OK;
}
int test_clientinfo(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
{
(void) argv;
(void) argc;
RedisModuleClientInfo ci = { .version = REDISMODULE_CLIENTINFO_VERSION };
if (RedisModule_GetClientInfoById(&ci, RedisModule_GetClientId(ctx)) == REDISMODULE_ERR) {
RedisModule_ReplyWithError(ctx, "failed to get client info");
return REDISMODULE_OK;
}
RedisModule_ReplyWithArray(ctx, 10);
char flags[512];
snprintf(flags, sizeof(flags) - 1, "%s:%s:%s:%s:%s:%s",
ci.flags & REDISMODULE_CLIENTINFO_FLAG_SSL ? "ssl" : "",
ci.flags & REDISMODULE_CLIENTINFO_FLAG_PUBSUB ? "pubsub" : "",
ci.flags & REDISMODULE_CLIENTINFO_FLAG_BLOCKED ? "blocked" : "",
ci.flags & REDISMODULE_CLIENTINFO_FLAG_TRACKING ? "tracking" : "",
ci.flags & REDISMODULE_CLIENTINFO_FLAG_UNIXSOCKET ? "unixsocket" : "",
ci.flags & REDISMODULE_CLIENTINFO_FLAG_MULTI ? "multi" : "");
RedisModule_ReplyWithCString(ctx, "flags");
RedisModule_ReplyWithCString(ctx, flags);
RedisModule_ReplyWithCString(ctx, "id");
RedisModule_ReplyWithLongLong(ctx, ci.id);
RedisModule_ReplyWithCString(ctx, "addr");
RedisModule_ReplyWithCString(ctx, ci.addr);
RedisModule_ReplyWithCString(ctx, "port");
RedisModule_ReplyWithLongLong(ctx, ci.port);
RedisModule_ReplyWithCString(ctx, "db");
RedisModule_ReplyWithLongLong(ctx, ci.db);
return REDISMODULE_OK;
}
int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc) {
REDISMODULE_NOT_USED(argv);
REDISMODULE_NOT_USED(argc);
@ -221,6 +257,8 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, RedisModuleString **argv, int argc)
return REDISMODULE_ERR;
if (RedisModule_CreateCommand(ctx,"test.getlfu", test_getlfu,"",0,0,0) == REDISMODULE_ERR)
return REDISMODULE_ERR;
if (RedisModule_CreateCommand(ctx,"test.clientinfo", test_clientinfo,"",0,0,0) == REDISMODULE_ERR)
return REDISMODULE_ERR;
return REDISMODULE_OK;
}

View File

@ -67,4 +67,23 @@ start_server {tags {"modules"}} {
assert { $was_set == 0 }
}
test {test module clientinfo api} {
# Test basic sanity and SSL flag
set info [r test.clientinfo]
set ssl_flag [expr $::tls ? {"ssl:"} : {":"}]
assert { [dict get $info db] == 9 }
assert { [dict get $info flags] == "${ssl_flag}::::" }
# Test MULTI flag
r multi
r test.clientinfo
set info [lindex [r exec] 0]
assert { [dict get $info flags] == "${ssl_flag}::::multi" }
# Test TRACKING flag
r client tracking on
set info [r test.clientinfo]
assert { [dict get $info flags] == "${ssl_flag}::tracking::" }
}
}