From 2bd13cf0eb9b2369ade0ec495a2d9e2c2a3be680 Mon Sep 17 00:00:00 2001 From: Yossi Gottlieb Date: Sun, 5 Jun 2016 10:03:34 +0300 Subject: [PATCH 1/2] Allow passing arguments to modules on load. --- src/config.c | 17 +++++++++++++++-- src/module.c | 34 ++++++++++++++++++++++++---------- src/server.h | 8 +++++++- 3 files changed, 46 insertions(+), 13 deletions(-) diff --git a/src/config.c b/src/config.c index c72f0aeb2..8b5a6f218 100644 --- a/src/config.c +++ b/src/config.c @@ -153,6 +153,19 @@ void resetServerSaveParams(void) { server.saveparamslen = 0; } +void queueLoadModule(sds path, sds *argv, int argc) +{ + struct loadmodule *loadmod = zmalloc(sizeof(struct loadmodule)+sizeof(sds)*argc); + int i; + + loadmod->path = sdsnew(path); + loadmod->argc = argc; + for (i = 0; i < argc; i++) { + loadmod->argv[i] = sdsnew(argv[i]); + } + listAddNodeTail(server.loadmodule_queue,loadmod); +} + void loadServerConfigFromString(char *config) { char *err = NULL; int linenum = 0, totlines, i; @@ -632,8 +645,8 @@ void loadServerConfigFromString(char *config) { "Allowed values: 'upstart', 'systemd', 'auto', or 'no'"; goto loaderr; } - } else if (!strcasecmp(argv[0],"loadmodule") && argc == 2) { - listAddNodeTail(server.loadmodule_queue,sdsnew(argv[1])); + } else if (!strcasecmp(argv[0],"loadmodule") && argc >= 2) { + queueLoadModule(argv[1],&argv[2],argc-2); } else if (!strcasecmp(argv[0],"sentinel")) { /* argc == 1 is handled by main() as we need to enter the sentinel * mode ASAP. */ diff --git a/src/module.c b/src/module.c index 0a16b9408..27e041b50 100644 --- a/src/module.c +++ b/src/module.c @@ -2897,11 +2897,11 @@ void moduleLoadFromQueue(void) { listRewind(server.loadmodule_queue,&li); while((ln = listNext(&li))) { - sds modulepath = ln->value; - if (moduleLoad(modulepath) == C_ERR) { + struct loadmodule *loadmod = ln->value; + if (moduleLoad(loadmod->path,(void **)loadmod->argv,loadmod->argc) == C_ERR) { serverLog(LL_WARNING, "Can't load module from %s: server aborting", - modulepath); + loadmod->path); exit(1); } } @@ -2915,8 +2915,8 @@ void moduleFreeModuleStructure(struct RedisModule *module) { /* Load a module and initialize it. On success C_OK is returned, otherwise * C_ERR is returned. */ -int moduleLoad(const char *path) { - int (*onload)(void *); +int moduleLoad(const char *path, void **module_argv, int module_argc) { + int (*onload)(void *, void **, int); void *handle; RedisModuleCtx ctx = REDISMODULE_CTX_INIT; @@ -2925,14 +2925,14 @@ int moduleLoad(const char *path) { serverLog(LL_WARNING, "Module %s failed to load: %s", path, dlerror()); return C_ERR; } - onload = (int (*)(void *))(unsigned long) dlsym(handle,"RedisModule_OnLoad"); + onload = (int (*)(void *, void **, int))(unsigned long) dlsym(handle,"RedisModule_OnLoad"); if (onload == NULL) { serverLog(LL_WARNING, "Module %s does not export RedisModule_OnLoad() " "symbol. Module not loaded.",path); return C_ERR; } - if (onload((void*)&ctx) == REDISMODULE_ERR) { + if (onload((void*)&ctx,module_argv,module_argc) == REDISMODULE_ERR) { if (ctx.module) moduleFreeModuleStructure(ctx.module); dlclose(handle); serverLog(LL_WARNING, @@ -3006,16 +3006,30 @@ int moduleUnload(sds name) { /* Redis MODULE command. * - * MODULE LOAD */ + * MODULE LOAD [args...] */ void moduleCommand(client *c) { char *subcmd = c->argv[1]->ptr; - if (!strcasecmp(subcmd,"load") && c->argc == 3) { - if (moduleLoad(c->argv[2]->ptr) == C_OK) + if (!strcasecmp(subcmd,"load") && c->argc >= 3) { + sds *argv = NULL; + int argc = 0; + int i; + + if (c->argc > 3) { + argc = c->argc - 3; + argv = zmalloc(sizeof(sds)*argc); + for (i=0; iargv[i+3]->ptr; + } + } + + if (moduleLoad(c->argv[2]->ptr,(void **)argv,argc) == C_OK) addReply(c,shared.ok); else addReplyError(c, "Error loading the extension. Please check the server logs."); + if (argv) + zfree(argv); } else if (!strcasecmp(subcmd,"unload") && c->argc == 3) { if (moduleUnload(c->argv[2]->ptr) == C_OK) addReply(c,shared.ok); diff --git a/src/server.h b/src/server.h index e5e4ea236..a16d1a4ec 100644 --- a/src/server.h +++ b/src/server.h @@ -683,6 +683,12 @@ struct saveparam { int changes; }; +struct loadmodule { + sds path; + int argc; + sds argv[]; +}; + struct sharedObjectsStruct { robj *crlf, *ok, *err, *emptybulk, *czero, *cone, *cnegone, *pong, *space, *colon, *nullbulk, *nullmultibulk, *queued, @@ -1156,7 +1162,7 @@ extern dictType modulesDictType; /* Modules */ void moduleInitModulesSystem(void); -int moduleLoad(const char *path); +int moduleLoad(const char *path, void **argv, int argc); void moduleLoadFromQueue(void); int *moduleGetCommandKeysViaAPI(struct redisCommand *cmd, robj **argv, int argc, int *numkeys); moduleType *moduleTypeLookupModuleByID(uint64_t id); From cc58f11ccc295cbe6b96eb47e4c01627ca718252 Mon Sep 17 00:00:00 2001 From: Yossi Gottlieb Date: Sun, 5 Jun 2016 13:18:24 +0300 Subject: [PATCH 2/2] Use RedisModuleString for OnLoad argv. --- src/config.c | 4 ++-- src/module.c | 10 ++-------- src/server.h | 2 +- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/config.c b/src/config.c index 8b5a6f218..05d0257cc 100644 --- a/src/config.c +++ b/src/config.c @@ -155,13 +155,13 @@ void resetServerSaveParams(void) { void queueLoadModule(sds path, sds *argv, int argc) { - struct loadmodule *loadmod = zmalloc(sizeof(struct loadmodule)+sizeof(sds)*argc); + struct loadmodule *loadmod = zmalloc(sizeof(struct loadmodule)+sizeof(robj*)*argc); int i; loadmod->path = sdsnew(path); loadmod->argc = argc; for (i = 0; i < argc; i++) { - loadmod->argv[i] = sdsnew(argv[i]); + loadmod->argv[i] = createStringObject(argv[i],sdslen(argv[i])); } listAddNodeTail(server.loadmodule_queue,loadmod); } diff --git a/src/module.c b/src/module.c index 27e041b50..8f45cf48d 100644 --- a/src/module.c +++ b/src/module.c @@ -3011,16 +3011,12 @@ void moduleCommand(client *c) { char *subcmd = c->argv[1]->ptr; if (!strcasecmp(subcmd,"load") && c->argc >= 3) { - sds *argv = NULL; + robj **argv = NULL; int argc = 0; - int i; if (c->argc > 3) { argc = c->argc - 3; - argv = zmalloc(sizeof(sds)*argc); - for (i=0; iargv[i+3]->ptr; - } + argv = &c->argv[3]; } if (moduleLoad(c->argv[2]->ptr,(void **)argv,argc) == C_OK) @@ -3028,8 +3024,6 @@ void moduleCommand(client *c) { else addReplyError(c, "Error loading the extension. Please check the server logs."); - if (argv) - zfree(argv); } else if (!strcasecmp(subcmd,"unload") && c->argc == 3) { if (moduleUnload(c->argv[2]->ptr) == C_OK) addReply(c,shared.ok); diff --git a/src/server.h b/src/server.h index a16d1a4ec..82bee10a8 100644 --- a/src/server.h +++ b/src/server.h @@ -686,7 +686,7 @@ struct saveparam { struct loadmodule { sds path; int argc; - sds argv[]; + robj *argv[]; }; struct sharedObjectsStruct {