Method Chaining Wrapper


At times we would like to add custom methods to built-in types like strings and functions, particularly when using method chaining [1][2]:

("  test  "):trim():repeatchars(2):upper() --> "TTEESSTT"

(function(x,y) return x*y end):curry(2) --> (function(y) return 2*y end)

We can do this with the debug library (debug.setmetatable) [5]. The downside is that each built-in type has a single, common metatable. Modifying this metatable causes a global side-effect, which is a potential source of conflict between independently maintained modules in a program. Functions in the debug library are often discouraged in regular code for good reason. Many people avoid injecting into these global metatables, while others find it too convenient to avoid [3][6][ExtensionProposal]. Some have even asked why objects of built-in types don't have their own metatables [7].

...
debug.setmetatable("", string_mt)
debug.setmetatable(function()end, function_mt)
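Here is a minimal sketch of the global-metatable approach. It assumes a standard Lua 5.1 to 5.4 interpreter, where `getmetatable("").__index` is the `string` table. The `trim` and `curry` bodies are illustrative and do not come from the original. Both changes are visible to every module in the process, which is the side effect described above.

```lua
-- Strings already share one metatable whose __index is the string table,
-- so adding a field there makes the method visible to every string value.
getmetatable("").__index.trim = function(s)
  return (s:match("^%s*(.-)%s*$"))
end

-- Functions have no metatable by default; debug.setmetatable installs one
-- that is shared by all function values (a global side effect).
local function_mt = { __index = {} }
function function_mt.__index.curry(f, x)
  -- illustrative helper: fix the first argument of f to x
  return function(...) return f(x, ...) end
end
debug.setmetatable(function() end, function_mt)

local trimmed = ("  test  "):trim()
local double = (function(x, y) return x * y end):curry(2)
```

Because the string metatable's `__index` is the `string` table itself, the first assignment is equivalent to defining `string.trim`.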

We could instead use just standalone functions:

(repeatchars(trim("test"), 2)):upper()

curry(function(x,y) return x*y end, 2)

This is the simplest solution, and simple solutions are often good ones. Nevertheless, mixing method calls with standalone global functions can feel discordant, and the nesting reverses the reading order: the first operation applied (trim) appears innermost.
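To make the standalone style concrete, here is a minimal sketch of `trim`, `repeatchars`, and `curry` as plain local functions. These definitions are hypothetical, and `repeatchars` uses `string.gsub` rather than the loop shown later on this page. Notice how the calls nest and must be read inside-out.

```lua
local function trim(s)
  return (s:match("^%s*(.-)%s*$"))
end

-- Repeat each character of s n times, e.g. ("ab", 2) -> "aabb".
local function repeatchars(s, n)
  return (s:gsub(".", function(c) return c:rep(n) end))
end

-- Fix the first argument of a function (partial application).
local function curry(f, x)
  return function(...) return f(x, ...) end
end

local r = (repeatchars(trim("  test  "), 2)):upper()
local double = curry(function(x, y) return x * y end, 2)
```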

One way to avoid touching the global metatables is to wrap the object inside our own class, perform operations on the wrapper in a method call chain, and then unwrap the object.

Examples would look like this:

S"  test  ":trim():repeatchars(2):upper()() --> TTEESSTT

S"  TEST  ":trim():lower():find('e')() --> 2 2

The S function wraps the given object in a wrapper object. Each method call in the chain operates on the wrapped value in place, replacing it with the method's results. Finally, the wrapper is unwrapped with a plain function call ().
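The design depends on Lua's method-call sugar: `w:m(a)` means `w.m(w, a)`. If `__index` returns the wrapper itself, then `w.m` evaluates to `w`, and the call reaches `__call` with the wrapper as its first argument. A bare `w()` does not pass the wrapper. Here is a minimal sketch of that detection. The names are illustrative and are not part of the implementation on this page.

```lua
local log = {}

local mt = {}
mt.__index = function(self, k)
  log[#log + 1] = "index " .. k
  return self                -- w.m evaluates to the wrapper itself
end
mt.__call = function(self, first, ...)
  if first == self then      -- came from w:m(...), i.e. w.m(w, ...)
    log[#log + 1] = "method call, " .. select("#", ...) .. " arg(s)"
    return self
  end
  log[#log + 1] = "plain call" -- came from w()
end

local w = setmetatable({}, mt)
w:m(1, 2)
w()
```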

For functions that return a single value, an alternative way to unwrap is to use the unary minus:

-S"  test  ":trim():repeatchars(2):upper() --> TTEESSTT
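This works because calls and method calls bind more tightly than unary operators in Lua. As a result, `-S"  test  ":trim():upper()` parses as `-(S"  test  ":trim():upper())`, and an `__unm` metamethod yields exactly one value. Here is a small sketch with a hypothetical `Box` class:

```lua
-- Unary minus applies to the whole call chain, not just the first term.
local n = -("abc"):len()          -- parsed as -(("abc"):len())

-- Hypothetical wrapper with a chainable method and an __unm unwrapper.
local Box = {}
Box.__index = Box
function Box.new(v) return setmetatable({ v = v }, Box) end
function Box:upper() self.v = self.v:upper(); return self end
Box.__unm = function(self) return self.v end

local s = -Box.new("abc"):upper() -- parsed as -(Box.new("abc"):upper())
```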

To define S in terms of a table of string functions stringx, we can use this code:

local stringx = {}
for k,v in pairs(string) do stringx[k] = v end
function stringx.trim(self)
  return self:match('^%s*(.-)%s*$')
end
function stringx.repeatchars(self, n)
  local ts = {}
  for i=1,#self do
    local c = self:sub(i,i)
    for j=1,n do ts[#ts+1] = c end
  end
  return table.concat(ts)
end

local S = buildchainwrapbuilder(stringx)

The buildchainwrapbuilder function is general and implements our design pattern:

-- (c) 2009 David Manura. Licensed under the same terms as Lua (MIT license).
-- version 20090430
local select = select
local setmetatable = setmetatable
local unpack = unpack
local rawget = rawget

-- http://lua-users.org/wiki/CodeGeneration
local function memoize(func)
  return setmetatable({}, {
    __index = function(self, k) local v = func(k); self[k] = v; return v end,
    __call = function(self, k) return self[k] end
  })
end

-- unique IDs (avoid name clashes with wrapped object)
local N = {}
local VALS = memoize(function() return {} end)
local VAL = VALS[1]
local PREV = {}

local function mypack(ow, ...)
  local n = select('#', ...)
  for i=1,n do ow[VALS[i]] = select(i, ...) end
  for i=n+1,ow[N] do ow[VALS[i]] = nil end
  ow[N] = n
end

local function myunpack(ow, i)
  i = i or 1
  if i <= ow[N] then
    return rawget(ow, VALS[i]), myunpack(ow, i+1)
  end
end
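`mypack` and `myunpack` exist because a method may return several values, including trailing `nil`s, and that count must survive the round trip through the wrapper. The same idea appears in its more common form below, using `select('#', ...)` and an explicit `n` field. This is a sketch and is not part of the original code. Lua 5.2+ also offers `table.pack` for the first half.

```lua
local unpack = table.unpack or unpack -- Lua 5.2+ moved unpack into table

-- Record the exact number of values, including trailing nils.
local function pack(...)
  return { n = select("#", ...), ... }
end

-- Return exactly t.n values, so trailing nils are reproduced.
local function unpack_n(t)
  return unpack(t, 1, t.n)
end

local p = pack(2, nil, nil)
```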

local function buildchainwrapbuilder(t)
  local mt = {}
  function mt:__index(k)
    local val = rawget(self, VAL)
    self[PREV] = val -- store in case of method call
    mypack(self, t[k])
    return self
  end
  function mt:__call(...)
    if (...) == self then -- method call
      local val = rawget(self, VAL)
      local prev = rawget(self, PREV)
      self[PREV] = nil
      mypack(self, val(prev, select(2,...)))
      return self
    else
      return myunpack(self, 1, self[N])
    end
  end
  function mt:__unm() return rawget(self, VAL) end

  local function build(o)
    return setmetatable({[VAL]=o,[N]=1}, mt)
  end
  return build
end

local function chainwrap(o, t)
  return buildchainwrapbuilder(t)(o)
end

Test suite:

-- simple examples
assert(-S"AA":lower() == "aa")
assert(-S"AB":lower():reverse() == "ba")
assert(-S"  test  ":trim():repeatchars(2):upper() == "TTEESSTT")
assert(S"  test  ":trim():repeatchars(2):upper()() == "TTEESSTT")

-- basics
assert(S""() == "")
assert(S"a"() == "a")
assert(-S"a" == "a")
assert(S(nil)() == nil)
assert(S"a":byte()() == 97)
local a,b,c = S"TEST":lower():find('e')()
assert(a==2 and b==2 and c==nil)
assert(-S"TEST":lower():find('e') == 2)

-- potentially tricky cases
assert(S"".__index() == nil)
assert(S"".__call() == nil)
assert(S""[1]() == nil)
stringx[1] = 'c'
assert(S"a"[1]() == 'c')
assert(S"a"[1]:upper()() == 'C')
stringx[1] = 'd'
assert(S"a"[1]() == 'd') -- uncached
assert(S"a".lower() == string.lower)

-- improve error messages?
--assert(S(nil):z() == nil)

print 'DONE'

The above implementation has these qualities and assumptions:

There are alternative ways we could have expressed the chaining:

S{"  test  ", "trim", {"repeatchars",2}, "upper"}

S("  test  ", "trim | repeatchars(2) | upper")

but this looks less conventional. (Note: the second argument in the last line is written in point-free style [4].)
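For comparison, the table-based alternative needs no metatables at all. Below is a minimal sketch. `chainspec` is a hypothetical name, and unlike the `S{...}` form above it takes the method library as a separate argument. The step format (a method name, or a table holding a name plus arguments) follows the example.

```lua
local unpack = table.unpack or unpack

-- Minimal string library extension for the example.
local stringx = setmetatable({}, { __index = string })
function stringx.trim(s) return (s:match("^%s*(.-)%s*$")) end
function stringx.repeatchars(s, n)
  return (s:gsub(".", function(c) return c:rep(n) end))
end

-- spec[1] is the initial object; each later entry is a method name
-- or a table {name, arg1, arg2, ...}. Methods are looked up in lib.
local function chainspec(lib, spec)
  local x = spec[1]
  for i = 2, #spec do
    local step = spec[i]
    if type(step) == "table" then
      x = lib[step[1]](x, unpack(step, 2))
    else
      x = lib[step](x)
    end
  end
  return x
end

local result = chainspec(stringx, {"  test  ", "trim", {"repeatchars", 2}, "upper"})
```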

Method Chaining Wrapper Take #2 - Object at End of Chain

We could instead express the call chain like this:

chain(stringx):trim():repeatchars(5):upper()('  test   ')

where the object operated on is placed at the very end. This reduces the chance of forgetting to unpack, and it allows separation and reuse:

f = chain(stringx):trim():repeatchars(5):upper()
print ( f('  test  ') )
print ( f('  again  ') )

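There are various ways to implement this (functional, CodeGeneration, and VM). Here we take the last approach: the chain records its operations as a list of opcodes, which are interpreted when the chain is finally called with an object.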

-- method call chaining, take #2
-- (c) 2009 David Manura. Licensed under the same terms as Lua (MIT license).
-- version 20090501

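-- Lua 5.1 has a global unpack; Lua 5.2+ provides table.unpack
local unpack = table.unpack or unpack

-- unique IDs to avoid name conflict
local OPS = {}
local INDEX = {}
local METHOD = {}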

-- table insert, allowing trailing nils
local function myinsert(t, v)
  local n = t.n + 1; t.n = n
  t[n] = v
end

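local function eval(ops, x)
  --print('DEBUG:', unpack(ops,1,ops.n))
  local t = ops.t  -- table that method names are looked up in
  local prev       -- value the next method will be applied to
  local n = ops.n
  local i=1; while i <= n do
    if ops[i] == INDEX then
      -- INDEX k: look up method k, saving x in case of a method call
      local k = ops[i+1]
      prev = x
      x = t[k]
      i = i + 2
    elseif ops[i] == METHOD then
      -- METHOD narg a1..aN: call the looked-up method on prev
      local narg = ops[i+1]
      x = x(prev, unpack(ops, i+2, i+1+narg))
      i = i + 2 + narg
    else
      error('chain: unknown opcode ' .. tostring(ops[i]))
    end
  end
  return x
end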

local mt = {}
function mt:__index(k)
  local ops = self[OPS]
  myinsert(ops, INDEX)
  myinsert(ops, k)
  return self
end

function mt:__call(x, ...)
  local ops = self[OPS]
  if x == self then -- method call
    myinsert(ops, METHOD)
    local n = select('#', ...)
    myinsert(ops, n)
    for i=1,n do
      myinsert(ops, (select(i, ...)))
    end
    return self
  else
    return eval(ops, x)
  end
end

local function chain(t)
  return setmetatable({[OPS]={n=0,t=t}}, mt)
end

Rudimentary test code:

local stringx = {}
for k,v in pairs(string) do stringx[k] = v end
function stringx.trim(self)
  return self:match('^%s*(.-)%s*$')
end
function stringx.repeatchars(self, n)
  local ts = {}
  for i=1,#self do
    local c = self:sub(i,i)
    for j=1,n do ts[#ts+1] = c end
  end
  return table.concat(ts)
end

local C = chain
assert(C(stringx):trim():repeatchars(2):upper()("  test  ") == 'TTEESSTT')
local f = C(stringx):trim():repeatchars(2):upper()
assert(f"  test  " == 'TTEESSTT')
assert(f"  again  " == 'AAGGAAIINN')
print 'DONE'
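The text above also mentions a functional alternative to the VM approach. Below is a minimal sketch of one such variant. It assumes each method returns a single value, and `fchain` is a hypothetical name. Each method call composes a new closure and returns a new proxy, so a partially built chain can be extended in different directions. In the VM version above, by contrast, every call appends to the same `ops` list.

```lua
local select = select
local unpack = table.unpack or unpack

local stringx = setmetatable({}, { __index = string })
function stringx.trim(s) return (s:match("^%s*(.-)%s*$")) end

-- Build a proxy whose pending pipeline is the function f (identity at first).
local function fchain(lib, f)
  f = f or function(x) return x end
  local name               -- method name captured by the last index
  local proxy, mt = {}, {}
  mt.__index = function(_, k) name = k; return proxy end
  mt.__call = function(_, first, ...)
    if first == proxy then  -- method call: compose one more step
      local k, args = name, { n = select("#", ...), ... }
      return fchain(lib, function(x)
        return lib[k](f(x), unpack(args, 1, args.n))
      end)
    end
    return f(first)         -- plain call: run the pipeline on the object
  end
  return setmetatable(proxy, mt)
end

local base = fchain(stringx):trim()
local up, low = base:upper(), base:lower()
```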

Method Chaining Wrapper Take #3 - Lexical injecting with scope-aware metatable

An alternate idea is to modify the string metatable so that the extensions to the string methods are visible only within a lexical scope. The following is not perfect (e.g. functions nested inside that scope do not see the extensions), but it is a start. Example:

-- test example libraries
local stringx = {}
function stringx.trim(self)  return self:match('^%s*(%S*)%s*$') end
local stringxx = {}
function stringxx.trim(self) return self:match('^%s?(.-)%s?$') end

-- test example
function test2(s)
  assert(s.trim == nil)
  scoped_string_methods(stringxx)
  assert(s:trim() == ' 123 ')
end
function test(s)
  scoped_string_methods(stringx)
  assert(s:trim() == '123')
  test2(s)
  assert(s:trim() == '123')
end
local s = '  123  '
assert(s.trim == nil)
test(s)
assert(s.trim == nil)
print 'DONE'

The function scoped_string_methods associates the given function table with the currently executing function. Any string indexing performed by that function looks up the key in the given table first and falls back to the standard string table.

The above uses this framework code:

-- framework
local mt = debug.getmetatable('')
local scope = setmetatable({}, {__mode = 'k'}) -- weak keys: entries for collected functions can be dropped
function mt.__index(s, k)
  local f = debug.getinfo(2, 'f').func
  return scope[f] and scope[f][k] or string[k]
end
local function scoped_string_methods(t)
  local f = debug.getinfo(2, 'f').func
  scope[f] = t
end
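The following sketch illustrates the nested-function limitation mentioned above. The lookup is keyed on the function object that performs the indexing, so a closure created inside a scoped function does not see the extensions. The sketch is self-contained: it repeats the framework, uses weak keys for `scope`, and adds a hypothetical `shout` method. Like the original, it replaces the global string metatable's `__index`.

```lua
-- Framework as above, with a weak-keyed scope table.
local mt = debug.getmetatable("")
local scope = setmetatable({}, { __mode = "k" })
function mt.__index(s, k)
  local f = debug.getinfo(2, "f").func  -- the function doing the indexing
  return scope[f] and scope[f][k] or string[k]
end
local function scoped_string_methods(t)
  scope[debug.getinfo(2, "f").func] = t
end

local ext = { shout = function(s) return s:upper() .. "!" end }

local function outer(s)
  scoped_string_methods(ext)
  local direct = s:shout()                 -- indexed by outer: found
  local inner = function() return s.shout end
  return direct, inner()                   -- indexed by inner: not found
end

local direct, nested = outer("hi")
```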

Method Chaining Wrapper Take #4 - Lexical injecting with Metalua

We can do something similar more robustly with Metalua, which rewrites the indexing expressions at compile time. An example is below.

-{extension "lexicalindex"}

-- test example libraries
local stringx = {}
function stringx.trim(self)  return self:match('^%s*(%S*)%s*$') end

local function f(o,k)
  if type(o) == 'string' then
    return stringx[k] or string[k]
  end
  return o[k]
end

local function test(s)
  assert(s.trim == nil)
  lexicalindex f
  assert(s.trim ~= nil)
  assert(s:trim():upper() == 'TEST')
end
local s = '  test  '
assert(s.trim == nil)
test(s)
assert(s.trim == nil)

print 'DONE'

The syntax extension introduces a new keyword lexicalindex that specifies a function to be called whenever a value is to be indexed inside the current scope.

Here is what the corresponding plain Lua source looks like:

--- $ ./build/bin/metalua -S vs.lua
--- Source From "@vs.lua": ---
local function __li_invoke (__li_index, o, name, ...)
   return __li_index (o, name) (o, ...)
end

local stringx = { }

function stringx:trim ()
   return self:match "^%s*(%S*)%s*$"
end

local function f (o, k)
   if type (o) == "string" then
      return stringx[k] or string[k]
   end
   return o[k]
end

local function test (s)
   assert (s.trim == nil)
   local __li_index = f
   assert (__li_index (s, "trim") ~= nil)
   assert (__li_invoke (__li_index, __li_invoke (__li_index, s, "trim"), "upper") == "TEST")
end

local s = "  test  "

assert (s.trim == nil)

test (s)

assert (s.trim == nil)

print "DONE"

The lexicalindex Metalua extension is implemented as follows:

-- lexical index in scope iff depth > 0
local depth = 0

-- transform indexing expressions
mlp.expr.transformers:add(function(ast)
  if depth > 0 then
    if ast.tag == 'Index' then
      return +{__li_index(-{ast[1]}, -{ast[2]})}
    elseif ast.tag == 'Invoke' then
      return `Call{`Id'__li_invoke', `Id'__li_index', unpack(ast)}
    end
  end
end)

-- monitor scoping depth
mlp.block.transformers:add(function(ast)
  for _,ast2 in ipairs(ast) do
    if ast2.is_lexicalindex then
      depth = depth - 1; break
    end
  end
end)

-- handle new "lexicalindex" statement
mlp.lexer:add'lexicalindex'
mlp.stat:add{'lexicalindex', mlp.expr, builder=function(x)
  local e = unpack(x)
  local ast_out = +{stat: local __li_index = -{e}}
  ast_out.is_lexicalindex = true
  depth = depth + 1
  return ast_out
end}

-- utility function
-- (note: o must be indexed exactly once to preserve behavior)
return +{block:
  local function __li_invoke(__li_index, o, name, ...)
    return __li_index(o, name)(o, ...)
  end
}

--DavidManura

See Also


Last edited December 9, 2009 1:38 am GMT