Curried Lua


Currying is defined by Wikipedia[1] as follows:

"In computer science, currying is the technique of transforming a function taking multiple arguments into a function that takes a single argument (the first of the arguments to the original function) and returns a new function that takes the remainder of the arguments and returns the result"

You can implement curried functions in all languages that support functions as first-class objects. For example, there's a little [tutorial about curried JavaScript].

Here is a small Lua example of a curried function:

function sum(number) 
  return function(anothernumber) 
    return number + anothernumber
  end
end

local f = sum(5)
print(f(3)) --> 8

-- WalterCruz

Here is another, contributed by [GavinWraith], which accepts any number of numeric arguments, one per call, and returns their total when the chain is terminated with an empty call "()":

function addup(x)
  local sum = 0
  local function f(n)
    if type(n) == "number" then
      sum = sum + n
      return f
    else
      return sum
    end
  end
  return f(x)
end

print(addup (1) (2) (3) ())  --> 6
print(addup (4) (5) (6) ())  --> 15
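One consequence of this design: every call in a chain updates the same `sum` variable, so a partially applied `addup(1)` cannot be reused the way `f` from the first example can. The sketch below repeats `addup` so it runs on its own, shows the effect, and then contrasts it with a variant that returns a fresh closure carrying the running total at each step. The name `adder` is a hypothetical helper for illustration, not part of the original contribution.

```lua
-- GavinWraith's addup, repeated so this snippet runs on its own
function addup(x)
  local sum = 0
  local function f(n)
    if type(n) == "number" then
      sum = sum + n
      return f
    else
      return sum
    end
  end
  return f(x)
end

local partial = addup(1)
print(partial (2) ())  --> 3
print(partial (2) ())  --> 5 (the shared sum kept growing)

-- hypothetical variant: each number yields a new closure over the new total,
-- so partial results stay independent of each other
local function adder(total)
  return function (n)
    if type(n) == "number" then
      return adder(total + n)
    else
      return total
    end
  end
end

local partial2 = adder(0)(1)
print(partial2 (2) ())  --> 3
print(partial2 (2) ())  --> 3
```

The trade-off is one extra closure allocation per argument in exchange for partial results that can safely be shared.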

Although these pre-curried functions are useful, what we would really like is a general-purpose function that can curry any other function. That calls for a "higher-order function": a function that takes functions as arguments (or returns them). The following curry function is one, and it curries a 2-argument function:

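function curry(f)
    return function (x) return function (y) return f(x,y) end end
end

-- math.pow is deprecated since Lua 5.3 (and missing from builds without
-- compatibility options), so use the ^ operator instead.
-- Lua 5.3+ prints these results as floats: 16.0, 8.0, ...
powcurry = curry(function (x, y) return x ^ y end)
print(powcurry (2) (4)) --> 16
pow2 = powcurry(2)
print(pow2(3)) --> 8
print(pow2(4)) --> 16
print(pow2(8)) --> 256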
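Because `curry` fixes arguments from left to right, `powcurry(2)` produces powers of two rather than squares. To fix the second argument instead, one option is to swap the arguments before currying. The `flip` helper below is hypothetical (it isn't defined anywhere on this page), and the block repeats `curry` so it runs on its own:

```lua
function curry(f)
    return function (x) return function (y) return f(x,y) end end
end

-- flip(f): hypothetical helper returning a 2-argument function
-- that calls f with its arguments swapped
local function flip(f)
  return function (x, y) return f(y, x) end
end

local function pow(x, y) return x ^ y end

-- curry(flip(pow))(2) behaves like function (y) return pow(y, 2) end
local square = curry(flip(pow))(2)
print(square(3))  --> 9 (9.0 on Lua 5.3+)
print(square(5))  --> 25 (25.0 on Lua 5.3+)
```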

Going from currying 2 arguments to currying 'n' arguments is a bit more complicated. We need to store an indeterminate number of partial applications, and in Lua 5.1 there is no way to ask a function how many arguments it requires: Lua functions accept any number of arguments, silently dropping extras and filling missing ones with nil. So the curry function has to be told how many single-argument calls to accept before applying the collected arguments to the original function.
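A quick way to see this adjustment in action, using a throwaway three-parameter function (the names `three` and `count` are just for illustration). `select("#", ...)` is the standard way for a vararg function to count what it actually received, including explicit nils:

```lua
local function three(a, b, c)
  return a, b, c
end

print(three(1))           --> 1  nil  nil  (missing arguments become nil)
print(three(1, 2, 3, 4))  --> 1  2  3      (the extra argument is dropped)

local function count(...)
  return select("#", ...)
end

print(count(1, nil, nil)) --> 3
```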

(This code is freely available from http://tinylittlelife.org/?p=249, which also includes a full discussion of how to tackle this problem.)

-- curry(func, num_args) : take a function requiring a tuple for num_args arguments
--                         and turn it into a series of 1-argument functions
-- e.g.: you have a function dosomething(a, b, c)
-- curried_dosomething = curry(dosomething, 3) -- we want to curry 3 arguments
-- curried_dosomething (a1) (b1) (c1) -- returns the result of dosomething(a1, b1, c1)
-- partial_dosomething1 = curried_dosomething (a_value) -- returns a function
-- partial_dosomething2 = partial_dosomething1 (b_value) -- returns a function
-- partial_dosomething2 (c_value) -- returns the result of dosomething(a_value, b_value, c_value)
function curry(func, num_args)

   -- currying 2-argument functions seems to be the most popular application
   num_args = num_args or 2

   -- no sense currying for 1 arg or less
   if num_args <= 1 then return func end

   -- helper takes an argtrace function, and number of arguments remaining to be applied
   local function curry_h(argtrace, n)
      if 0 == n then
         -- kick off argtrace, reverse argument list, and call the original function
         return func(reverse(argtrace()))
      else
         -- "push" argument (by building a wrapper function) and decrement n
         return function (onearg)
                   return curry_h(function () return onearg, argtrace() end, n - 1)
                end
      end
   end

   -- push the terminal case of argtrace into the function first
   return curry_h(function () return end, num_args)

end

-- reverse(...) : take some tuple and return a tuple of elements in reverse order
--
-- e.g. "reverse(1,2,3)" returns 3,2,1
function reverse(...)

   --reverse args by building a function to do it, similar to the unpack() example
   local function reverse_h(acc, v, ...)
      if 0 == select('#', ...) then
         return v, acc()
      else
         return reverse_h(function () return v, acc() end, ...)
      end
   end

   -- initial acc is the end of the list
   return reverse_h(function () return end, ...)
end

The above code is Lua 5.1 compatible.
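The heart of the code above is using closures as a stack: each "push" wraps the previous argtrace in a new function that returns the new argument followed by everything the old one returns. The standalone sketch below (the `push` and `empty` names are illustrative, not from the original) shows that calling the finished trace yields the arguments newest-first, which is why `curry` passes them through `reverse` before calling `func`. Because values travel as multiple returns rather than through a table, nil arguments are preserved.

```lua
-- the terminal case: a trace that returns nothing
local empty = function () return end

-- push(trace, arg): a new trace returning arg, then everything trace returns
local function push(trace, arg)
  return function () return arg, trace() end
end

local trace = push(push(push(empty, "a"), "b"), "c")
print(trace())                            --> c  b  a
print(select("#", push(empty, nil)()))    --> 1 (the nil argument is kept)
```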

Since Lua 5.2 (and LuaJIT 2.0), debug.getinfo can report how many parameters a Lua function declares (the "u" option fills in the nparams field), so we can write a more practical curry that mixes currying and partial application: each call may supply one or several arguments. Here's the code:
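A quick check of what `debug.getinfo(f, "u")` reports, assuming Lua 5.2 or later (or LuaJIT, as noted above); on plain Lua 5.1 the nparams field is absent. C functions such as `math.max`, and Lua functions declared with only `...`, report zero fixed parameters, so the curry below returns them unchanged unless num_args is passed explicitly. The names `three` and `varargs` are illustrative.

```lua
local function three(a, b, c) return a + b + c end
local function varargs(...) return select("#", ...) end

local info = debug.getinfo(three, "u")
print(info.nparams, info.isvararg)            --> 3  false
print(debug.getinfo(varargs, "u").nparams)    --> 0
print(debug.getinfo(math.max, "u").nparams)   --> 0 (C function)
```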

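-- unpack is a global in Lua 5.1 and LuaJIT, but table.unpack since Lua 5.2
local unpack = table.unpack or unpack

function curry(func, num_args)
  num_args = num_args or debug.getinfo(func, "u").nparams
  if num_args < 2 then return func end
  local function helper(argtrace, remaining)
    if remaining < 1 then
      local args = flatten(argtrace)
      return func(unpack(args, 1, args.n))
    else
      return function (...)
        -- record this call's arguments and their count (so nil arguments
        -- survive), linked to the trace of the earlier calls
        local count = select("#", ...)
        return helper({prev = argtrace, n = count, ...}, remaining - count)
      end
    end
  end
  return helper(nil, num_args)
end

-- flatten(trace) : collect the arguments of every recorded call, oldest
-- call first, into one array whose length is stored in field n.
-- Table arguments are passed through untouched rather than flattened.
function flatten(trace)
  local calls = {}
  while trace do
    table.insert(calls, 1, trace)
    trace = trace.prev
  end
  local ret = {n = 0}
  for _, call in ipairs(calls) do
    for i = 1, call.n do
      ret.n = ret.n + 1
      ret[ret.n] = call[i]
    end
  end
  return ret
end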

function multiplyAndAdd (a, b, c) return a * b + c end

curried_multiplyAndAdd = curry(multiplyAndAdd)

multiplyBySevenAndAdd = curried_multiplyAndAdd(7)

multiplySevenByEightAndAdd_v1 = multiplyBySevenAndAdd(8)
multiplySevenByEightAndAdd_v2 = curried_multiplyAndAdd(7, 8)

assert(multiplyAndAdd(7, 8, 9) == multiplySevenByEightAndAdd_v1(9))
assert(multiplyAndAdd(7, 8, 9) == multiplySevenByEightAndAdd_v2(9))
assert(multiplyAndAdd(7, 8, 9) == multiplyBySevenAndAdd(8, 9))
assert(multiplyAndAdd(7, 8, 9) == curried_multiplyAndAdd(7, 8, 9))


Last edited March 27, 2014 2:39 pm GMT