lua-users home
lua-l archive

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


On 16 February 2011 11:15, steve donovan <steve.j.donovan@gmail.com> wrote:
> On Wed, Feb 16, 2011 at 1:04 PM, T T <t34www@googlemail.com> wrote:
>> of code I would like to write or work with.  One would hope that such
>> a (common?) case as unrolling small loops could be done automatically
>> for the programmer, rather than require writing special templating
>> solutions.
>
> One issue is that we do not have this:
>
> const n = 3
>
> With that kind of guarantee, unrolling becomes possible.

OK, I went out of my way and templated Francesco's original code in a
simple way to see if I could do any better; the code is attached
(rk4-unroll4.lua).  Basically, I unrolled the loops for dimensions up
to 4 like this:

  if dim > 4 then
    for i = 1,dim do
      y0[i] = y[i]
    end
  else
    if dim >= 1 then
      y0[1] = y[1]
    end
    if dim >= 2 then
      y0[2] = y[2]
    end
    if dim >= 3 then
      y0[3] = y[3]
    end
    if dim >= 4 then
      y0[4] = y[4]
    end
  end

Note that this works for any dimension (larger dimensions simply fall
back to the plain loop), so it is not specialized and doesn't require
run-time code generation.

Timings with luajit -Omaxsnap=300:

  dim=2    time=0.25 sec
  dim=3    time=0.29 sec
  dim=4    time=0.41 sec
  dim=5    time=1.02 sec
  dim=6    time=1.23 sec

Hey, now that's not too bad, is it?  Mike Pall's code posted earlier
(rk_2d.lua) runs in 0.23 sec on my machine.  That's darn close to what
I've got for dim=2.

> One approach to this is to use an integrated preprocessor that uses
> the token-filter patch. That would make our 'n' to be a macro which is
> expanded as '3' wherever used.

Why not let the machine do it based on the runtime value of 'n' (with
some safe guards to fall back on the interpreter if it changes)?

> But the situation that Francesco is dealing with is template
> specialization, as a C++ programmer would understand it.  That
> requires re-compiling different cases.

I'm not convinced that this is the best way forward for stuff like simple loops.

Cheers,

Tomek

-- Storage back-end for every numeric vector in this benchmark:
--   'FFI' -> LuaJIT FFI variable-length double arrays
--   'GSL' -> n x 1 objects built by a global `new` (not defined in this file)
--   anything else -> plain Lua tables filled with zeros
use = 'Lua'

-- Pick the vector constructor for `backend`; returns a function n -> vector.
local function select_darray(backend)
  if backend == 'FFI' then
    ffi = require 'ffi'
    return ffi.typeof("double[?]")
  end
  if backend == 'GSL' then
    return function(n) return new(n, 1) end
  end
  return function(n)
    local zeros = {}
    for i = 1, n do
      zeros[i] = 0
    end
    return zeros
  end
end

darray = select_darray(use)

--- Fixed-step classical Runge-Kutta (4th order) integrator with a
-- step-doubling error estimate.
--
-- Work vectors are allocated with `darray(n + 1)` and only indices 1..dim
-- are used (presumably so the same code also fits 0-based FFI arrays).
-- The per-component work is done by the helpers below, which unroll by
-- hand for dim <= 4 and fall back to a plain loop otherwise.
rk4 = {}

-- dst[i] = src[i] for i = 1..dim.
local function vcopy(dst, src, dim)
  if dim > 4 then
    for i = 1, dim do
      dst[i] = src[i]
    end
  else
    if dim >= 1 then dst[1] = src[1] end
    if dim >= 2 then dst[2] = src[2] end
    if dim >= 3 then dst[3] = src[3] end
    if dim >= 4 then dst[4] = src[4] end
  end
end

-- One intermediate RK4 stage, for i = 1..dim:
--   y[i]    = y[i]  + a * k[i]   (accumulate the weighted slope)
--   ytmp[i] = y0[i] + b * k[i]   (point for the next derivative evaluation)
local function vstage(y, ytmp, y0, k, a, b, dim)
  if dim > 4 then
    for i = 1, dim do
      y[i] = y[i] + a * k[i]; ytmp[i] = y0[i] + b * k[i]
    end
  else
    if dim >= 1 then y[1] = y[1] + a * k[1]; ytmp[1] = y0[1] + b * k[1] end
    if dim >= 2 then y[2] = y[2] + a * k[2]; ytmp[2] = y0[2] + b * k[2] end
    if dim >= 3 then y[3] = y[3] + a * k[3]; ytmp[3] = y0[3] + b * k[3] end
    if dim >= 4 then y[4] = y[4] + a * k[4]; ytmp[4] = y0[4] + b * k[4] end
  end
end

-- Last RK4 stage: y[i] = y[i] + a * k[i] for i = 1..dim.
local function vaxpy(y, k, a, dim)
  if dim > 4 then
    for i = 1, dim do
      y[i] = y[i] + a * k[i]
    end
  else
    if dim >= 1 then y[1] = y[1] + a * k[1] end
    if dim >= 2 then y[2] = y[2] + a * k[2] end
    if dim >= 3 then y[3] = y[3] + a * k[3] end
    if dim >= 4 then y[4] = y[4] + a * k[4] end
  end
end

-- Step-doubling error estimate, for i = 1..dim:
--   yerr[i] = 4 * (y_two[i] - y_one[i]) / 15
local function verror(yerr, y_two, y_one, dim)
  if dim > 4 then
    for i = 1, dim do
      yerr[i] = 4 * (y_two[i] - y_one[i]) / 15
    end
  else
    if dim >= 1 then yerr[1] = 4 * (y_two[1] - y_one[1]) / 15 end
    if dim >= 2 then yerr[2] = 4 * (y_two[2] - y_one[2]) / 15 end
    if dim >= 3 then yerr[3] = 4 * (y_two[3] - y_one[3]) / 15 end
    if dim >= 4 then yerr[4] = 4 * (y_two[4] - y_one[4]) / 15 end
  end
end

--- Allocate the work space for an n-dimensional system.
-- @param n  system dimension
-- @return state table with work vectors k, k1, y0, ytmp, y_onestep and dim = n
function rk4.new(n)
  return {
    k         = darray(n+1),
    k1        = darray(n+1),
    y0        = darray(n+1),
    ytmp      = darray(n+1),
    y_onestep = darray(n+1),
    dim       = n,
  }
end

--- Advance y by one classical RK4 step of size h.
-- On entry state.y0 must hold the starting point, state.k the derivative
-- f(t, y0), and y the same values as y0 (rk4.apply arranges this).
-- y is updated in place; state.ytmp and state.k are overwritten.
-- @param y      vector to advance
-- @param state  work space from rk4.new
-- @param h      step size
-- @param t      time at the start of the step
-- @param sys    table whose field f(t, y, dydt) writes the derivative into dydt
function rk4.step(y, state, h, t, sys)
  local dim, f = state.dim, sys.f
  local y0, ytmp, k = state.y0, state.ytmp, state.k
  local half_h = 0.5 * h

  -- k1 (already in k): y += h/6 k1, ytmp = y0 + h/2 k1
  vstage(y, ytmp, y0, k, h / 6, half_h, dim)

  -- k2 = f(t + h/2, y0 + h/2 k1): y += h/3 k2, ytmp = y0 + h/2 k2
  f(t + half_h, ytmp, k)
  vstage(y, ytmp, y0, k, h / 3, half_h, dim)

  -- k3 = f(t + h/2, y0 + h/2 k2): y += h/3 k3, ytmp = y0 + h k3
  -- The full step h is required here: k4 is evaluated at the END of the
  -- step, so using h/2 would drop the method to first order.
  f(t + half_h, ytmp, k)
  vstage(y, ytmp, y0, k, h / 3, h, dim)

  -- k4 = f(t + h, y0 + h k3): y += h/6 k4
  f(t + h, ytmp, k)
  vaxpy(y, k, h / 6, dim)
end

--- Advance y from t to t + h and estimate the local error by step doubling.
-- y receives the result of two RK4 half steps; yerr is derived from the
-- difference to a single full step (kept in state.y_onestep).
-- @param state     work space from rk4.new
-- @param t         current time
-- @param h         step size
-- @param y         state vector, advanced in place
-- @param yerr      receives the error estimate
-- @param dydt_in   optional derivative f(t, y); computed here when nil
-- @param dydt_out  optional; receives f(t + h, y_new) when given
-- @param sys       table whose field f(t, y, dydt) writes the derivative into dydt
function rk4.apply(state, t, h, y, yerr, dydt_in, dydt_out, sys)
  local f, dim = sys.f, state.dim
  local k, k1, y0, y_onestep = state.k, state.k1, state.y0, state.y_onestep
  local half_h = h/2

  -- Starting point and its derivative.
  vcopy(y0, y, dim)
  if dydt_in then
    vcopy(k, dydt_in, dim)
  else
    f(t, y0, k)
  end

  -- Save first-point derivatives: both passes below start from them.
  vcopy(k1, k, dim)

  -- First traverse h with one step (result in y_onestep).
  vcopy(y_onestep, y, dim)
  rk4.step(y_onestep, state, h, t, sys)

  -- Then with two steps of half length (result in y).
  vcopy(k, k1, dim)
  rk4.step(y, state, half_h, t, sys)

  -- Derivative at the midpoint for the second half step.
  f(t + half_h, y, k)

  -- Save original y0 to k1 for possible failures (there is no failure path
  -- here, and nothing in this function reads k1 afterwards).
  vcopy(k1, y0, dim)

  -- The second half step starts from the midpoint.
  vcopy(y0, y, dim)
  rk4.step(y, state, half_h, t + half_h, sys)

  -- Derivatives at output.
  if dydt_out then f(t + h, y, dydt_out) end

  -- Error estimation
  --
  --   yerr = C * 0.5 * | y(onestep) - y(twosteps) | / (2^order - 1)
  --
  --   constant C is approximately 8.0 to ensure 90% of samples lie within
  --   the error (assuming a gaussian distribution with prior p(sigma)=1/sigma.)
  verror(yerr, y, y_onestep, dim)
end

--- Right-hand side of the 2-D test system
--     p' = -q - p^2
--     q' = 2p - q^3
-- Reads y[1], y[2] and writes dydt[1], dydt[2]; t is not used.
function f_ode1(t, y, dydt)
   local p = y[1]
   local q = y[2]
   local dp = -q - p^2
   local dq = 2*p - q^3
   dydt[1] = dp
   dydt[2] = dq
end

-- Integration interval [t0, t1] and fixed step size h0 used by do_rk.
t0 = 0
t1 = 200
h0 = 0.001

--- Integrate f_ode1 from t0 to t1 with fixed step h0, starting at (p0, q0).
-- Prints "t p q" roughly every `sample` time units when sample is given,
-- and always prints the final point.
-- @param p0      initial value of y[1]
-- @param q0      initial value of y[2]
-- @param sample  optional print interval (nil disables intermediate output)
-- @param dim     system dimension for rk4.new (default 2). f_ode1 always
--                reads and writes components 1 and 2, so dim must be >= 2;
--                a smaller dim gives wrong results with Lua tables and writes
--                past the end of FFI arrays sized dim + 1.
function do_rk(p0, q0, sample, dim)
  dim = dim or 2
  if type(dim) ~= 'number' or dim < 2 then
    error("do_rk: dim must be a number >= 2 (f_ode1 is a 2-D system), got "
          .. tostring(dim), 2)
  end

  local state = rk4.new(dim)
  local y, dydt, yerr = darray(dim+1), darray(dim+1), darray(dim+1)
  local sys = {f = f_ode1}

  y[1], y[2] = p0, q0

  local t = t0
  local tsamp = t0
  -- The first step has no derivative yet: let apply compute it.
  rk4.apply(state, t, h0, y, yerr, nil, dydt, sys)
  t = t + h0
  while t < t1 do
     -- Reuse the end-of-step derivative as the next step's input.
     rk4.apply(state, t, h0, y, yerr, dydt, dydt, sys)
     t = t + h0
     if sample and t - tsamp > sample then
        print(t, y[1], y[2])
        tsamp = t
     end
  end
  print(t, y[1], y[2])
end

-- Benchmark driver: integrate the same initial condition 10 times.
-- The dimension comes from the `dim` environment variable (default 2) and
-- is parsed once. `sample` is read as a global; it is not assigned in the
-- code shown here, so it is nil (no intermediate output) unless set externally.
local dim = tonumber(os.getenv('dim') or 2)
if not dim then
  error("environment variable 'dim' must be a number, got "
        .. tostring(os.getenv('dim')))
end

for run = 1, 10 do
  local th = math.pi/4 -- *(run-1)/5
  local p0, q0 = math.cos(th), math.sin(th)
  do_rk(p0, q0, sample, dim)
end