lua-l archive

Francesco Abbate wrote:
> At this point a remark is needed about the implementation of the
> algorithm. It works for systems of any dimension, and it works by
> using an array of doubles to store the computed values and
> derivatives. As in this example we have only a 2-element
> vector, it would have been easy to unpack the components to get more
> efficient Lua code. The idea is that we don't want
> to do that because we want an algorithm that can work for ODE
> systems of any dimension, from 1 to N, where N is possibly a big
> integer.
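For reference, the dimension-generic approach described in the quote can be sketched as below. This is a hypothetical reconstruction (the name rk4_step_generic and its argument order are illustrative, not from the original code): it performs the same four stages as rk4.step in the attachment, but loops over all n components, so for n = 2 every loop runs only two iterations.

```lua
-- Hypothetical sketch: one RK4 step for a system of any dimension n,
-- using the same stage order as rk4.step in the attachment.
-- y, y0, ytmp and k are arrays indexed 1..n; f(t, y, dydt) fills dydt.
-- On entry y0 must hold a copy of y and k must hold f(t, y0).
local function rk4_step_generic(n, y, y0, ytmp, k, h, t, f)
  for i = 1, n do -- k1 stage
    y[i] = y[i] + h / 6 * k[i]
    ytmp[i] = y0[i] + 0.5 * h * k[i]
  end
  f(t + 0.5 * h, ytmp, k)
  for i = 1, n do -- k2 stage
    y[i] = y[i] + h / 3 * k[i]
    ytmp[i] = y0[i] + 0.5 * h * k[i]
  end
  f(t + 0.5 * h, ytmp, k)
  for i = 1, n do -- k3 stage
    y[i] = y[i] + h / 3 * k[i]
    ytmp[i] = y0[i] + h * k[i]
  end
  f(t + h, ytmp, k)
  for i = 1, n do -- k4 stage
    y[i] = y[i] + h / 6 * k[i]
  end
end

-- Usage: dy/dt = -y with n = 1, one step of h = 0.1 starting from y = 1.
local function decay(t, yy, dydt) dydt[1] = -yy[1] end
local y, y0, ytmp, k = {1}, {1}, {}, {}
decay(0, y0, k)  -- k must hold f(t, y0) before the step
rk4_step_generic(1, y, y0, ytmp, k, 0.1, 0, decay)
print(y[1])      -- should be very close to math.exp(-0.1)
```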

Well, that's the main problem. LuaJIT is not tuned to deal with
tons of loops that run only 2 iterations. It unrolls them, but
there's a limit to how much it unrolls, and this code hits it.

In C++ one would use templates for that purpose. This instantiates
a copy of the whole code for each specific number of dimensions.
That's not as wasteful as it sounds, since you probably only ever
use a finite set of dimensions, e.g. dim=2 and dim=3.

In Lua we can do the same: specialize the Lua code at runtime for
the number of dimensions, memoize the code in a table and dispatch
based on the dimension. I.e. string.format + loadstring +
memoization table.
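A minimal sketch of that recipe, assuming plain Lua tables indexed 1..n and the same stage sequence as rk4.step in the attachment. The helpers unroll, gen_step_source and get_step are hypothetical names. loadstring is the Lua 5.1/LuaJIT function; the sketch falls back to load on Lua 5.2+. Each dimension is generated and compiled once, then fetched from the cache, so the compiled step contains no short inner loops.

```lua
-- loadstring in Lua 5.1/LuaJIT; load accepts source strings in Lua 5.2+.
local compile = loadstring or load

-- Repeat a statement template for i = 1..n; "#" stands for the index.
local function unroll(n, template)
  local lines = {}
  for i = 1, n do
    lines[i] = (template:gsub("#", tostring(i)))
  end
  return table.concat(lines, "\n  ")
end

-- Build the source of an RK4 step fully unrolled for dimension n.
local function gen_step_source(n)
  return string.format([[
return function(y, y0, ytmp, k, h, t, f)
  %s
  f(t + 0.5 * h, ytmp, k)
  %s
  f(t + 0.5 * h, ytmp, k)
  %s
  f(t + h, ytmp, k)
  %s
end]],
    unroll(n, "y[#] = y[#] + h / 6 * k[#]; ytmp[#] = y0[#] + 0.5 * h * k[#]"),
    unroll(n, "y[#] = y[#] + h / 3 * k[#]; ytmp[#] = y0[#] + 0.5 * h * k[#]"),
    unroll(n, "y[#] = y[#] + h / 3 * k[#]; ytmp[#] = y0[#] + h * k[#]"),
    unroll(n, "y[#] = y[#] + h / 6 * k[#]"))
end

-- Memoization table: dimension -> compiled step function.
local step_cache = {}

local function get_step(n)
  local step = step_cache[n]
  if not step then
    local chunk = assert(compile(gen_step_source(n), "=rk4_step_" .. n))
    step = chunk()
    step_cache[n] = step
  end
  return step
end

-- Usage: two decoupled components of dy/dt = -y, one step of h = 0.1.
local function f(t, yy, dydt) dydt[1] = -yy[1]; dydt[2] = -yy[2] end
local step2 = get_step(2)
local y, y0, ytmp, k = {1, 2}, {1, 2}, {}, {}
f(0, y0, k)  -- k must hold f(t, y0) before the step
step2(y, y0, ytmp, k, 0.1, 0, f)
```

Callers stay generic (they only ask for get_step(n)), while each compiled variant is straight-line code for its dimension.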

> Here the results of the benchmark:
> 
> LuaJIT2 with FFI, 0m47.805s
> LuaJIT2 with GSL, 1m20.958s

The tiny Lua function is never compiled. It runs interpreted, and
the FFI overhead is big when the code isn't compiled. But apparently
it's still slightly less than the overhead of C code with userdata.

> LuaJIT2 with Lua tables, 0m9.319s
> Lua 5.1.4, 0m26.644s
> C code with GSL, (compiled with -O2): 0m0.607s

You're lucky it runs even that fast. None of the Lua code is
compiled, due to all of those tiny loops! Have a look with -jv,
which reports each trace LuaJIT starts and aborts.

Just for the sake of the experiment, I've expanded the code for
two dimensions by hand (attached below). This is still problematic
due to all of these loads and stores to y[1], y[2] etc., where
local variables would do. For the real thing you should definitely
expand this to local variables y_1, y_2, etc.
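A sketch of what that could look like for the system in f_ode1, assuming the state is passed and returned as scalars (deriv and rk4_step_2d are hypothetical names). It covers only the plain RK4 step, not the step-doubling error estimate done in rk4.apply. It uses the classical weighted sum h/6 * (k1 + 2 k2 + 2 k3 + k4), which is algebraically the same as the incremental updates of y in rk4.step.

```lua
-- Hypothetical sketch: right-hand side of f_ode1 on scalar locals.
-- The system is autonomous, so no time argument is needed.
local function deriv(p, q)
  return -q - p^2, 2*p - q^3
end

-- One classical RK4 step; all intermediate values live in locals,
-- so there are no array loads or stores at all.
local function rk4_step_2d(p, q, h)
  local k1p, k1q = deriv(p, q)
  local k2p, k2q = deriv(p + 0.5*h*k1p, q + 0.5*h*k1q)
  local k3p, k3q = deriv(p + 0.5*h*k2p, q + 0.5*h*k2q)
  local k4p, k4q = deriv(p + h*k3p, q + h*k3q)
  return p + h/6 * (k1p + 2*k2p + 2*k3p + k4p),
         q + h/6 * (k1q + 2*k2q + 2*k3q + k4q)
end
```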

Since this is really just one giant loop, all those stores and
guards cause too many snapshots, so you'll need to run it with
-Omaxsnap=200 for the plain Lua array version. The FFI causes fewer
guards, so it runs with the default settings.
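If passing the flag on the command line is inconvenient, LuaJIT's jit.opt module, which backs the -O option, can also be called from inside the script. This small sketch (raise_maxsnap is a hypothetical helper) guards the call so the script still runs under plain Lua.

```lua
-- Hypothetical helper: the in-script equivalent of "-Omaxsnap=<n>".
-- Returns true if the option was applied, false otherwise.
local function raise_maxsnap(n)
  if type(jit) == "table" and jit.opt then
    jit.opt.start("maxsnap=" .. n)
    return true
  end
  return false  -- plain Lua (no jit table) or a build without jit.opt
end

local applied = raise_maxsnap(200)
```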

Time   File        Variant
5.68s  rk4.lua     Lua tables
5.68s  rk4.lua     FFI
0.24s  rk4_2d.lua  Lua tables with -Omaxsnap=200
0.23s  rk4_2d.lua  FFI

Now this is ~25x faster than before, which also means it's faster
than the pure C code!

And if you used local variables instead of all those tiny arrays,
performance would be even better.

> On the other hand, I was surprised by the FFI performance, which is
> quite bad. I guess it is due to less effective optimizations by
> the JIT compiler, but slower than plain Lua??

You're comparing apples and oranges: a C program plus the C -> Lua
call overhead plus the (uncompiled) FFI overhead with the overhead
of a pure Lua program.

But this actually does prove my point: rewriting mixed C/Lua code
in pure Lua (+FFI) is the way to go.

--Mike

-- Select the array backend: 'FFI', 'GSL' or anything else for plain Lua tables.
use = 'FFI'

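if use == 'FFI' then
   -- Zero-filled C array of doubles with indices 0..n-1; the code below
   -- allocates n+1 elements so that indices 1..n can be used.
   ffi = require 'ffi'
   darray = ffi.typeof("double[?]")
elseif use == 'GSL' then
   -- n x 1 matrix; relies on a global new(rows, cols), assumed to come from GSL Shell.
   darray = function(n) return new(n, 1) end
else
   -- Plain Lua table.
   darray = function(n) return {} end
end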

rk4 = {}

function rk4.new(n)
  local s = {
    k=         darray(n+1), 
    k1=        darray(n+1),
    y0=        darray(n+1),
    ytmp=      darray(n+1),
    y_onestep= darray(n+1),
    dim = n
  }
  return s
end

function rk4.step(y, state, h, t, sys)
  -- Makes a Runge-Kutta 4th order advance with step size h.
  local f = sys.f

  -- initial values of variables y.
  local y0 = state.y0
  
  -- work space 
  local ytmp = state.ytmp

  -- Runge-Kutta coefficients. Contains values of coefficient k1
  -- in the beginning 
  local k = state.k

  -- k1 step 
  y[1] = y[1] + h / 6 * k[1]
  ytmp[1] = y0[1] + 0.5 * h * k[1]
  y[2] = y[2] + h / 6 * k[2]
  ytmp[2] = y0[2] + 0.5 * h * k[2]
 
  -- k2 step

  f(t + 0.5 * h, ytmp, k)

  y[1] = y[1] + h / 3 * k[1]
  ytmp[1] = y0[1] + 0.5 * h * k[1]
  y[2] = y[2] + h / 3 * k[2]
  ytmp[2] = y0[2] + 0.5 * h * k[2]

  -- k3 step 
  f(t + 0.5 * h, ytmp, k)

  y[1] = y[1] + h / 3 * k[1]
  ytmp[1] = y0[1] + h * k[1]
  y[2] = y[2] + h / 3 * k[2]
  ytmp[2] = y0[2] + h * k[2]

  -- k4 step 
  f(t + h, ytmp, k)

  y[1] = y[1] + h / 6 * k[1]
  y[2] = y[2] + h / 6 * k[2]
end

function rk4.apply(state, t, h, y, yerr, dydt_in, dydt_out, sys)
  local f = sys.f
  local k, k1, y0, y_onestep = state.k, state.k1, state.y0, state.y_onestep

  y0[1] = y[1]
  y0[2] = y[2]

  if dydt_in then 
     k[1] = dydt_in[1]
     k[2] = dydt_in[2]
  else 
     f(t, y0, k)
  end

  -- Error estimation is done by step doubling procedure 
  -- Save first point derivatives
  k1[1] = k[1]
  k1[2] = k[2]

  -- First traverse h with one step (save to y_onestep) 
  y_onestep[1] = y[1]
  y_onestep[2] = y[2]

  rk4.step (y_onestep, state, h, t, sys)

  -- Then with two steps with half step length (save to y) 
  k[1] = k1[1]
  k[2] = k1[2]

  rk4.step(y, state, h/2, t, sys)

  -- Update before second step 
  f(t + h/2, y, k)
  
  -- Save original y0 to k1 for possible failures 
  k1[1] = y0[1]
  k1[2] = y0[2]

  -- Update y0 for second step 
  y0[1] = y[1]
  y0[2] = y[2]

  rk4.step(y, state, h/2, t + h/2, sys)

  -- Derivatives at output
  if dydt_out then f(t + h, y, dydt_out) end
  
  -- Error estimation
  --
  --   yerr = C * 0.5 * | y(onestep) - y(twosteps) | / (2^order - 1)
  --
  --   constant C is approximately 8.0 to ensure 90% of samples lie within
  --   the error (assuming a Gaussian distribution with prior p(sigma)=1/sigma.)
  --   With C = 8 and order = 4 the factor is 8 * 0.5 / 15 = 4/15.

  yerr[1] = 4 * (y[1] - y_onestep[1]) / 15
  yerr[2] = 4 * (y[2] - y_onestep[2]) / 15
end

-- Test system: dp/dt = -q - p^2, dq/dt = 2p - q^3 (result written to dydt).
function f_ode1(t, y, dydt)
   local p, q = y[1], y[2]
   dydt[1] = - q - p^2
   dydt[2] = 2*p - q^3
end

-- Integration interval [t0, t1] and fixed step size h0.
t0, t1, h0 = 0, 200, 0.001

function do_rk(p0, q0, sample)
  local dim = 2
  local state = rk4.new(dim)
  local y, dydt, yerr = darray(dim+1), darray(dim+1), darray(dim+1)
  local sys = {f = f_ode1}

  y[1], y[2] = p0, q0

  local t = t0
  local tsamp = t0
  rk4.apply(state, t, h0, y, yerr, nil, dydt, sys)
  t = t + h0
  while t < t1 do
     rk4.apply(state, t, h0, y, yerr, dydt, dydt, sys)
     t = t + h0
     if sample and t - tsamp > sample then
        print(t, y[1], y[2])
        tsamp = t
     end
  end
  print(t, y[1], y[2])
end

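-- Benchmark loop: integrate from the same initial point 10 times.
-- (Uncommenting "*(k-1)/5" varies the starting angle instead.)
for k=1, 10 do
  local th = math.pi/4 -- *(k-1)/5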
  local p0, q0 = math.cos(th), math.sin(th)
  do_rk(p0, q0)
end