diff --git a/src/script_lua.c b/src/script_lua.c index d7332cf86..2873159d4 100644 --- a/src/script_lua.c +++ b/src/script_lua.c @@ -655,55 +655,19 @@ static void luaReplyToRedisReply(client *c, client* script_client, lua_State *lu * Lua redis.* functions implementations. * ------------------------------------------------------------------------- */ -#define LUA_CMD_OBJCACHE_SIZE 32 -#define LUA_CMD_OBJCACHE_MAX_LEN 64 -static int luaRedisGenericCommand(lua_State *lua, int raise_error) { - int j, argc = lua_gettop(lua); - scriptRunCtx* rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); - if (!rctx) { - luaPushError(lua, "redis.call/pcall can only be called inside a script invocation"); - return luaRaiseError(lua); - } - sds err = NULL; - client* c = rctx->c; - sds reply; - - /* Cached across calls. */ - static robj **argv = NULL; - static int argv_size = 0; - static robj *cached_objects[LUA_CMD_OBJCACHE_SIZE]; - static size_t cached_objects_len[LUA_CMD_OBJCACHE_SIZE]; - static int inuse = 0; /* Recursive calls detection. */ - - /* By using Lua debug hooks it is possible to trigger a recursive call - * to luaRedisGenericCommand(), which normally should never happen. - * To make this function reentrant is futile and makes it slower, but - * we should at least detect such a misuse, and abort. */ - if (inuse) { - char *recursion_warning = - "luaRedisGenericCommand() recursive call detected. " - "Are you doing funny stuff with Lua debug hooks?"; - serverLog(LL_WARNING,"%s",recursion_warning); - luaPushError(lua,recursion_warning); - return 1; - } - inuse++; - +static robj **luaArgsToRedisArgv(lua_State *lua, int *argc) { + int j; /* Require at least one argument */ - if (argc == 0) { - luaPushError(lua, - "Please specify at least one argument for redis.call()"); - inuse--; - return raise_error ? luaRaiseError(lua) : 1; + *argc = lua_gettop(lua); + if (*argc == 0) { + luaPushError(lua, "Please specify at least one argument for this redis lib call"); + return NULL; } /* Build the arguments vector */ - if (argv_size < argc) { - argv = zrealloc(argv,sizeof(robj*)*argc); - argv_size = argc; - } + robj **argv = zcalloc(sizeof(robj*) * *argc); - for (j = 0; j < argc; j++) { + for (j = 0; j < *argc; j++) { char *obj_s; size_t obj_len; char dbuf[64]; @@ -720,38 +684,62 @@ static int luaRedisGenericCommand(lua_State *lua, int raise_error) { if (obj_s == NULL) break; /* Not a string. */ } - /* Try to use a cached object. */ - if (j < LUA_CMD_OBJCACHE_SIZE && cached_objects[j] && - cached_objects_len[j] >= obj_len) - { - sds s = cached_objects[j]->ptr; - argv[j] = cached_objects[j]; - cached_objects[j] = NULL; - memcpy(s,obj_s,obj_len+1); - sdssetlen(s, obj_len); - } else { - argv[j] = createStringObject(obj_s, obj_len); - } + argv[j] = createStringObject(obj_s, obj_len); } + /* Pop all arguments from the stack, we do not need them anymore + * and this way we guaranty we will have room on the stack for the result. */ + lua_pop(lua, *argc); + /* Check if one of the arguments passed by the Lua script * is not a string or an integer (lua_isstring() return true for * integers as well). */ - if (j != argc) { + if (j != *argc) { j--; while (j >= 0) { decrRefCount(argv[j]); j--; } - luaPushError(lua, - "Lua redis() command arguments must be strings or integers"); - inuse--; + zfree(argv); + luaPushError(lua, "Lua redis lib command arguments must be strings or integers"); + return NULL; + } + + return argv; +} + +static int luaRedisGenericCommand(lua_State *lua, int raise_error) { + int j; + scriptRunCtx* rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); + if (!rctx) { + luaPushError(lua, "redis.call/pcall can only be called inside a script invocation"); + return luaRaiseError(lua); + } + sds err = NULL; + client* c = rctx->c; + sds reply; + + int argc; + robj **argv = luaArgsToRedisArgv(lua, &argc); + if (argv == NULL) { return raise_error ? luaRaiseError(lua) : 1; } - /* Pop all arguments from the stack, we do not need them anymore - * and this way we guaranty we will have room on the stack for the result. */ - lua_pop(lua, argc); + static int inuse = 0; /* Recursive calls detection. */ + + /* By using Lua debug hooks it is possible to trigger a recursive call + * to luaRedisGenericCommand(), which normally should never happen. + * To make this function reentrant is futile and makes it slower, but + * we should at least detect such a misuse, and abort. */ + if (inuse) { + char *recursion_warning = + "luaRedisGenericCommand() recursive call detected. " + "Are you doing funny stuff with Lua debug hooks?"; + serverLog(LL_WARNING,"%s",recursion_warning); + luaPushError(lua,recursion_warning); + return 1; + } + inuse++; /* Log the command if debugging is active. */ if (ldbIsEnabled()) { @@ -769,7 +757,6 @@ static int luaRedisGenericCommand(lua_State *lua, int raise_error) { ldbLog(cmdlog); } - scriptCall(rctx, argv, argc, &err); if (err) { luaPushError(lua, err); @@ -810,45 +797,16 @@ static int luaRedisGenericCommand(lua_State *lua, int raise_error) { cleanup: /* Clean up. Command code may have changed argv/argc so we use the * argv/argc of the client instead of the local variables. */ - for (j = 0; j < c->argc; j++) { - robj *o = c->argv[j]; - - /* Try to cache the object in the cached_objects array. - * The object must be small, SDS-encoded, and with refcount = 1 - * (we must be the only owner) for us to cache it. */ - if (j < LUA_CMD_OBJCACHE_SIZE && - o->refcount == 1 && - (o->encoding == OBJ_ENCODING_RAW || - o->encoding == OBJ_ENCODING_EMBSTR) && - sdslen(o->ptr) <= LUA_CMD_OBJCACHE_MAX_LEN) - { - sds s = o->ptr; - if (cached_objects[j]) decrRefCount(cached_objects[j]); - cached_objects[j] = o; - cached_objects_len[j] = sdsalloc(s); - } else { - decrRefCount(o); - } - } - - if (c->argv != argv) { - zfree(c->argv); - argv = NULL; - argv_size = 0; - } - + freeClientArgv(c); c->user = NULL; - c->argv = NULL; - c->argc = 0; + inuse--; if (raise_error) { /* If we are here we should have an error in the stack, in the * form of a table with an "err" field. Extract the string to * return the plain error. */ - inuse--; return luaRaiseError(lua); } - inuse--; return 1; } @@ -939,6 +897,46 @@ static int luaRedisSetReplCommand(lua_State *lua) { return 0; } +/* redis.acl_check_cmd() + * + * Checks ACL permissions for given command for the current user. */ +static int luaRedisAclCheckCmdPermissionsCommand(lua_State *lua) { + scriptRunCtx* rctx = luaGetFromRegistry(lua, REGISTRY_RUN_CTX_NAME); + if (!rctx) { + lua_pushstring(lua, "redis.acl_check_cmd can only be called inside a script invocation"); + return lua_error(lua); + } + int raise_error = 0; + + int argc; + robj **argv = luaArgsToRedisArgv(lua, &argc); + + /* Require at least one argument */ + if (argv == NULL) return lua_error(lua); + + /* Find command */ + struct redisCommand *cmd; + if ((cmd = lookupCommand(argv, argc)) == NULL) { + lua_pushstring(lua, "Invalid command passed to redis.acl_check_cmd()"); + raise_error = 1; + } else { + int keyidxptr; + if (ACLCheckAllUserCommandPerm(rctx->original_client->user, cmd, argv, argc, &keyidxptr) != ACL_OK) { + lua_pushboolean(lua, 0); + } else { + lua_pushboolean(lua, 1); + } + } + + while (argc--) decrRefCount(argv[argc]); + zfree(argv); + if (raise_error) + return lua_error(lua); + else + return 1; +} + + /* redis.log() */ static int luaLogCommand(lua_State *lua) { int j, argc = lua_gettop(lua); @@ -1251,8 +1249,13 @@ void luaRegisterRedisAPI(lua_State* lua) { lua_pushstring(lua,"REPL_ALL"); lua_pushnumber(lua,PROPAGATE_AOF|PROPAGATE_REPL); - lua_settable(lua,-3); + + /* redis.acl_check_cmd */ + lua_pushstring(lua,"acl_check_cmd"); + lua_pushcfunction(lua,luaRedisAclCheckCmdPermissionsCommand); + lua_settable(lua,-3); + /* Finally set the table as 'redis' global var. */ lua_setglobal(lua,REDIS_API_NAME); diff --git a/src/server.h b/src/server.h index 14c571a5e..0b1904d4e 100644 --- a/src/server.h +++ b/src/server.h @@ -2378,6 +2378,7 @@ int beforeNextClient(client *c); void clearClientConnectionState(client *c); void resetClient(client *c); void freeClientOriginalArgv(client *c); +void freeClientArgv(client *c); void sendReplyToClient(connection *conn); void *addReplyDeferredLen(client *c); void setDeferredArrayLen(client *c, void *node, long length); diff --git a/tests/unit/scripting.tcl b/tests/unit/scripting.tcl index f16555350..93db6d071 100644 --- a/tests/unit/scripting.tcl +++ b/tests/unit/scripting.tcl @@ -701,6 +701,32 @@ start_server {tags {"scripting"}} { return redis.call("EXISTS", "key") } 0] 0 } + + test "Script ACL check" { + r acl setuser bob on {>123} {+@scripting} {+set} {~x*} + assert_equal [r auth bob 123] {OK} + + # Check permission granted + assert_equal [run_script { + return redis.acl_check_cmd('set','xx',1) + } 1 xx] 1 + + # Check permission denied unauthorised command + assert_equal [run_script { + return redis.acl_check_cmd('hset','xx','f',1) + } 1 xx] {} + + # Check permission denied unauthorised key + # Note: we don't pass the "yy" key as an argument to the script so key acl checks won't block the script + assert_equal [run_script { + return redis.acl_check_cmd('set','yy',1) + } 0] {} + + # Check error due to invalid command + assert_error {ERR *Invalid command passed to redis.acl_check_cmd()} {run_script { + return redis.acl_check_cmd('invalid-cmd','arg') + } 0} + } } # Start a new server since the last test in this stanza will kill the