lua-users home
lua-l archive

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


(All code here is 100% untested, but see below for an explanation in plain English.)

---
local with = require "with"

-- `with ^ fn` wraps fn; inside it, `with(...)` scopes a value to one loop.
local f = with ^ function(with, should_fail)
  for _ in with(io.open("/dev/null", "w")) do
    if should_fail then
      error("erroring", 1)
    end
  end
end
---
-- Metatable shared by the module table and every per-call with-stack.
local mt = {}

--- Pop the newest value off `stack` (nil once it is empty).
-- Its signature fits a stateless generic-for iterator, so
-- `for v in tpop, stack, nil do ... end` drains `stack` newest-first.
local function tpop(stack)
  return table.remove(stack)
end
--- `with(...)`: push closable values onto the with-stack `self` and return a
-- generic-for iterator that yields them once, then closes them.
--
-- `for v1, v2 in with(a, b) do ... end` binds v1, v2 to a, b. When the loop
-- asks for its next iteration, the control value (v1) is non-nil. Every value
-- pushed since this call is then popped and closed, newest first. That
-- includes any values left behind by nested with-loops that were exited
-- early. The iterator returns nothing, and the loop ends.
--
-- Values are validated and pushed one at a time. If a later argument is
-- invalid, the earlier ones are already on the stack, where a decorator's
-- cleanup can close them.
-- @param self the with-stack: the table a decorated function receives, or
--   the module table itself
-- @param ... values with a `close` method; nil/false is an error
-- @return iterator function, and the number of values (the loop state)
-- @raise if called with no values, or if a value is nil/false. A string
--   following the nil becomes the message, as with a failed io.open.
local function with_call(self, ...)
  local count = select('#', ...)
  if count == 0 then
    error("with: expected at least one value", 2)
  end
  -- Everything above `base` belongs to this loop, or to nested loops that
  -- were left early, so the second iteration must close it.
  local base = #self
  for i = 1, count do
    local v = select(i, ...)
    if not v then
      -- Mirror `assert(io.open(...))`: a failed open returns nil, message.
      local msg = select(i + 1, ...)
      if type(msg) ~= "string" then
        msg = "with: argument #" .. i .. " is nil or false"
      end
      error(msg, 2)
    end
    self[#self + 1] = v
  end
  return function(_state, done)
    if not done then
      -- First iteration: yield only this call's values, not the whole stack.
      return table.unpack(self, base + 1, base + count)
    end
    -- Second iteration: close everything above `base`, newest first.
    -- Pop each value before closing it so it can never be closed twice.
    while #self > base do
      local v = table.remove(self)
      v:close()
    end
  end, count
end
--- `with ^ f`: wrap `f` so that everything pushed onto its with-stack is
-- closed once the call finishes. This covers a normal return, an early
-- `break`/`return` out of a with-loop, and an error.
-- @param self the module table (unused)
-- @param f function called as `f(w, ...)`, where `w` is a fresh with-stack
-- @return a wrapper that returns xpcall-style results: `true, ...` when `f`
--   succeeds, or `false, err` when `f` raises. It also returns
--   `false, err` when `f` succeeded but closing a leftover value failed.
local with_pow = function(self, f)
  local function close_one(v)
    v:close()
  end

  -- Close and remove everything still on `stack`, newest first. Each close
  -- is protected so that one failing close cannot leak the remaining values.
  -- Returns the first close error, if any.
  local function close_all(stack)
    local first_err
    while #stack > 0 do
      local ok, err = pcall(close_one, table.remove(stack))
      if not ok and first_err == nil then
        first_err = err
      end
    end
    return first_err
  end

  -- xpcall message handler. It must return the error value; a handler that
  -- returns nothing turns every error into nil.
  local function keep_error(err)
    return err
  end

  return function(...)
    local w = setmetatable({}, mt)
    local results = table.pack(xpcall(f, keep_error, w, ...))
    -- Runs after success too: `break`/`return` inside a with-loop skips the
    -- loop's own closing iteration.
    local close_err = close_all(w)
    if results[1] and close_err ~= nil then
      return false, close_err
    end
    return table.unpack(results, 1, results.n)
  end
end
-- Hook up the operators: `with(...)` runs with_call and `with ^ f` runs
-- with_pow. The module value is an empty table carrying these metamethods.
mt.__pow = with_pow
mt.__call = with_call
return setmetatable({}, mt)
---

(Look how small this thing is!)

The way it works is very simple:

for v1, v2, v3 in with(x1, x2, x3) do ... end

After the with() call runs, this effectively becomes

for v1, v2, v3 in unpack_or_cleanup, {x1, x2, x3}, nil do ... end

If you know how generic for loops work: that nil is the initial control value. So the first call is unpack_or_cleanup({x1, x2, x3}, nil), and because the second argument is nil, it unpacks the table.

That fills in v1, v2, v3. Then the next iteration comes around, and v1 is now the control value. It is not nil (with() rejects nil values), so a different branch runs: it closes everything and returns nil, which stops the iteration.

This is how values are closed in the normal (no-error) case.

For errors, we need the "with decorator"[1] (`with ^ function ... end`) to wrap the function in something that uses xpcall. You need the "with decorator" on every function you want to use "with" in (otherwise it may affect unrelated functions, but I haven't tested that).

Oh, and it's nestable. This works (well, it's untested):

---
local with = require "with"

-- Nested with-loops. The error is raised with a real message; error()
-- with no argument would raise nil, and the caller could not tell what
-- failed.
local f = with ^ function(with)
  for x in with(io.open("/dev/null", "w")) do
    for y in with(io.open("/dev/null", "w")) do
      error("boom")
    end
  end
end
---

This will call y's close, then x's close. This also works (also untested):

---
local with = require "with"

-- Each with-loop closes its value when the loop finishes, innermost first,
-- so the prints interleave with the loop bodies.
local f = with ^ function(with)
  for x in with({close = function() print(4) end}) do
    for y in with({close = function() print(2) end}) do
      print(1)
    end
    print(3)
  end
  print(5)
end

-- The example only produces output when the decorated function is called.
f()
---

and, when f is called, it prints 1, 2, 3, 4, 5 in that order.

[1] inspired by https://marc.info/?l=lua-l&m=148745285529221&w=2