lua-users home
lua-l archive

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


Jim Whitehead II wrote:
3. Define a new setfenv() that calls the original, after tracking
which function has been tagged with which environment table, so you
can write...
4. A new getfenv() function which returns cleanEnv by default, and
otherwise does a lookup to see which stack level you need to return.
This is the tricky part...

Here's another implementation handling stack levels. It maintains a whitelist (is_childenv) of known environments created by the safe environment.


--- Run a Lua 5.1 function inside a fresh, otherwise empty environment.
-- The sandbox environment (`safeenv`) contains only `_G` (itself) and
-- replacement `getfenv`/`setfenv` functions. The replacements keep a
-- whitelist (`is_childenv`) of environments created from inside the
-- sandbox. Whenever a lookup would reveal an environment that is not on
-- the whitelist (for example the host's globals, or the environment of a
-- C function or of these wrappers), the sandboxed root function's current
-- environment is returned instead. Only functions whose environment is
-- already whitelisted may be re-targeted with `setfenv`.
--
-- Numeric stack levels are interpreted relative to the sandboxed caller,
-- as with the real functions. Level 0 and negative levels refer to `func`
-- itself rather than to the thread's global environment.
--
-- Requires Lua 5.1 / LuaJIT (`getfenv`/`setfenv` do not exist in 5.2+).
-- @tparam function func the Lua function to sandbox; its environment is
--   replaced in place.
-- @treturn function `func`, now running in the sandbox environment.
-- @raise if `func` is not a function. Without this check a number would
--   be treated as a stack level and re-target `makesafeenv` itself.
function makesafeenv(func)
  -- Localize for isolation from later changes to the host environment.
  local type    = type
  local getfenv = getfenv
  local setfenv = setfenv
  local error   = error

  if type(func) ~= 'function' then
    error('bad argument #1 to makesafeenv (function expected)', 2)
  end

  local safeenv = {}
  -- Environments created by (or handed to) the sandbox. Weak keys let
  -- discarded environments still be garbage-collected.
  local is_childenv = setmetatable({}, {__mode = 'k'})

  function safeenv.getfenv(f)
    -- `+ 1` skips this wrapper's own frame so that levels stay relative
    -- to the sandboxed caller; levels below 1 map to the root function.
    if f == nil then
      f = 2
    elseif type(f) == 'number' then
      f = (f >= 1) and f + 1 or func
    elseif type(f) ~= 'function' then
      error('invalid argument #1', 2)
    end
    local env = getfenv(f)
    -- Never reveal an environment the sandbox did not create.
    if is_childenv[env] then
      return env
    end
    return getfenv(func)
  end

  function safeenv.setfenv(f, newenv)
    if type(f) == 'number' then
      f = (f >= 1) and f + 1 or func
    elseif type(f) ~= 'function' then
      error('invalid argument #1', 2)
    end
    -- Validate before whitelisting, so that nothing but a real table is
    -- ever marked as a trusted child environment.
    if type(newenv) ~= 'table' then
      error('invalid argument #2 (table expected)', 2)
    end
    -- Only functions already running in a sandbox environment may be
    -- re-targeted. This keeps host and C functions out of reach.
    if not is_childenv[getfenv(f)] then
      error("'setfenv' cannot change environment of given object", 2)
    end
    is_childenv[newenv] = true
    return setfenv(f, newenv)
  end

  safeenv._G = safeenv

  is_childenv[safeenv] = true
  setfenv(func, safeenv)

  return func
end

-- TEST
-- Exercises makesafeenv: runs a sandboxed function that checks that the
-- replacement getfenv/setfenv never expose the host globals and that
-- numeric stack levels are resolved relative to the sandboxed caller.
-- Requires Lua 5.1 / LuaJIT (getfenv/setfenv).

-- Capture host functions as upvalues. Inside the sandbox only _G,
-- getfenv and setfenv are visible as globals, so anything else must be
-- reached through these locals.
local print = print
local assert = assert
local pcall = pcall
local G_old = _G
-- `func` is declared before it is assigned so that the sandboxed body can
-- refer to itself as an upvalue (used by the setfenv(func, ...) lines).
local func; func = makesafeenv(function()
  assert(_G ~= G_old)            -- _G now resolves to the sandbox env
  local env = getfenv()
  assert(getfenv( ) == _G)
  assert(getfenv(0) == _G)       -- level 0 maps to the sandboxed root
  assert(getfenv(1) == _G)
  assert(getfenv(2) == _G)       -- host caller: hidden, root env returned
  assert(getfenv(3) == _G)
  -- NOTE(review): assumes level 6 lies beyond the top of the stack when
  -- this file is run as the main chunk; the depth depends on the host.
  assert(not pcall(function() getfenv(4+2) end))
  assert(not pcall(function() getfenv(-1) end) or getfenv(-1) == _G)
  -- Functions living outside the sandbox (the getfenv wrapper itself and
  -- the C function assert) report the sandbox env, not the host globals.
  assert(getfenv(getfenv) == _G)
  assert(getfenv(assert)  == _G)
  -- `__index` here is a plain field, not a metatable, so env2 provides no
  -- fallback lookups. func2 only needs getfenv (found in env2) and the
  -- upvalue assert.
  local env2 = {__index=_G, getfenv=getfenv, setfenv=setfenv}
  local function func2()
    assert(getfenv(0) == env)    -- the sandboxed root function's env
    assert(getfenv(1) == env2)   -- func2's own env
    assert(getfenv(2) == env)    -- its caller (func) still uses env
  end
  setfenv(func2, env2)           -- the sandbox setfenv also whitelists env2
  func2()
  -- Re-target the root function via level 1, level 0 and a direct
  -- reference. After each call, the unqualified getfenv/setfenv globals are
  -- looked up through the newly installed environment.
  setfenv(1,    env2); assert(getfenv() == env2)
  setfenv(0,    env);  assert(getfenv() == env)
  setfenv(0,    env2); assert(getfenv() == env2)
  setfenv(func, env);  assert(getfenv() == env)
  setfenv(func, env2); assert(getfenv() == env2)
end)
func()
print 'PASSED'