diff --git a/src/eval.c b/src/eval.c index 4894b83d3..1be0d42f2 100644 --- a/src/eval.c +++ b/src/eval.c @@ -269,7 +269,7 @@ void scriptingInit(int setup) { /* Lua beginners often don't use "local", this is likely to introduce * subtle bugs in their code. To prevent problems we protect accesses * to global variables. */ - luaEnableGlobalsProtection(lua, 1); + luaEnableGlobalsProtection(lua); lctx.lua = lua; } @@ -378,35 +378,20 @@ sds luaCreateFunction(client *c, robj *body) { sdsfreesplitres(parts, numparts); } - /* Build the lua function to be loaded */ - sds funcdef = sdsempty(); - funcdef = sdscat(funcdef,"function "); - funcdef = sdscatlen(funcdef,funcname,42); - funcdef = sdscatlen(funcdef,"() ",3); /* Note that in case of a shebang line we skip it but keep the line feed to conserve the user's line numbers */ - funcdef = sdscatlen(funcdef,(char*)body->ptr + shebang_len,sdslen(body->ptr) - shebang_len); - funcdef = sdscatlen(funcdef,"\nend",4); - - if (luaL_loadbuffer(lctx.lua,funcdef,sdslen(funcdef),"@user_script")) { + if (luaL_loadbuffer(lctx.lua,(char*)body->ptr + shebang_len,sdslen(body->ptr) - shebang_len,"@user_script")) { if (c != NULL) { addReplyErrorFormat(c, "Error compiling script (new function): %s", lua_tostring(lctx.lua,-1)); } lua_pop(lctx.lua,1); - sdsfree(funcdef); return NULL; } - sdsfree(funcdef); - if (lua_pcall(lctx.lua,0,0,0)) { - if (c != NULL) { - addReplyErrorFormat(c,"Error running script (new function): %s", - lua_tostring(lctx.lua,-1)); - } - lua_pop(lctx.lua,1); - return NULL; - } + serverAssert(lua_isfunction(lctx.lua, -1)); + + lua_setfield(lctx.lua, LUA_REGISTRYINDEX, funcname); /* We also save a SHA1 -> Original script map in a dictionary * so that we can replicate / write in the AOF all the @@ -479,7 +464,7 @@ void evalGenericCommand(client *c, int evalsha) { lua_getglobal(lua, "__redis__err__handler"); /* Try to lookup the Lua function */ - lua_getglobal(lua, funcname); + lua_getfield(lua, LUA_REGISTRYINDEX, funcname); if (lua_isnil(lua,-1)) { lua_pop(lua,1); /* remove the nil from the stack */ /* Function not defined... let's define it if we have the @@ -497,7 +482,7 @@ void evalGenericCommand(client *c, int evalsha) { return; } /* Now the following is guaranteed to return non nil */ - lua_getglobal(lua, funcname); + lua_getfield(lua, LUA_REGISTRYINDEX, funcname); serverAssert(!lua_isnil(lua,-1)); } diff --git a/src/script_lua.c b/src/script_lua.c index 9a08a7e47..4e1f17649 100644 --- a/src/script_lua.c +++ b/src/script_lua.c @@ -1144,7 +1144,7 @@ sds luaGetStringSds(lua_State *lua, int index) { * On Legacy Lua (eval) we need to check 'w ~= \"main\"' otherwise we will not be able * to create the global 'function ()' variable. On Functions Lua engine we do not use * this trick so it's not needed. */ -void luaEnableGlobalsProtection(lua_State *lua, int is_eval) { +void luaEnableGlobalsProtection(lua_State *lua) { char *s[32]; sds code = sdsempty(); int j = 0; @@ -1157,7 +1157,7 @@ void luaEnableGlobalsProtection(lua_State *lua, int is_eval) { s[j++]="mt.__newindex = function (t, n, v)\n"; s[j++]=" if dbg.getinfo(2) then\n"; s[j++]=" local w = dbg.getinfo(2, \"S\").what\n"; - s[j++]= is_eval ? " if w ~= \"main\" and w ~= \"C\" then\n" : " if w ~= \"C\" then\n"; + s[j++]=" if w ~= \"C\" then\n"; s[j++]=" error(\"Script attempted to create global variable '\"..tostring(n)..\"'\", 2)\n"; s[j++]=" end\n"; s[j++]=" end\n"; diff --git a/src/script_lua.h b/src/script_lua.h index 5a4533784..f39c01744 100644 --- a/src/script_lua.h +++ b/src/script_lua.h @@ -67,7 +67,7 @@ typedef struct errorInfo { void luaRegisterRedisAPI(lua_State* lua); sds luaGetStringSds(lua_State *lua, int index); -void luaEnableGlobalsProtection(lua_State *lua, int is_eval); +void luaEnableGlobalsProtection(lua_State *lua); void luaRegisterGlobalProtectionFunction(lua_State *lua); void luaSetGlobalProtection(lua_State *lua); void luaRegisterLogFunction(lua_State* lua);