lua-users home
lua-l archive

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


Hi all,

here I am again with my Lua numeric algorithm implementation. I've
managed to write the rkf45 ODE integrator in vector form using the
cblas functions to perform vector arithmentic.

The results is good in term of accuracy, I've tested for the same
example as before with N=2 and I obtain the same results. For the
other side the execution is speed is ~ 100x - 150x slower!

I guess the reason is that, once again, for some reason LuaJIT2 refuse
to compile the code but I don't have any clear idea of why that
happens.

In attachment you will find the algorithm in vector form
rkf45vec.lua.in and the result after template preprocessing,
rkf45vec-out.lua. I include also the good version of the template and
the benchmark code.

Please note that you cannot run it with plain luajit2 because this
latter isn't linked with cblas as gsl shell. I guess that this problem
can be easily solved by loading the cblas library but I can give more
help if needed.

I've given a look at the trace and it seems that the root of the
problem is the cblas function that LuaJIT2 doesn't like:

[TRACE --- rkf45vec-out.lua:78 -- NYI: unsupported C function type at
rkf45vec-out.lua:83]

the function incriminated is cblas_daxpy. But I don't really know.

I hope that Mike can save me yet another time! :-)

Francesco

Attachment: rkf45vec.lua.in
Description: Binary data

local abs, max, min = math.abs, math.max, math.min







local ffi = require "ffi"

local vecsize = 2 * ffi.sizeof('double')

ffi.cdef[[
  typedef struct {
    double t;
    double h;
    double y[2];
    double dydt[2];
  } odevec_state;

  typedef struct {
    double y0[2];
    double ytmp[2];
    double k1[2];
    double k2[2];
    double k3[2];
    double k4[2];
    double k5[2];
    double k6[2];
  } ode_workspace;

  void cblas_daxpy (const int N, const double ALPHA,
		    const double * X, const int INCX,
		    double * Y, const int INCY);

  int cblas_idamax (const int N, const double * X, const int INCX);

  void cblas_dscal (const int N, const double ALPHA, double * X, const int INCX);
]]

local function ode_new()
   return ffi.new('odevec_state')
end

local function ode_init(s, t0, h0, f, y)
   ffi.copy(s.y, y, vecsize)
   f(t0, s.y, s.dydt)
   s.t = t0
   s.h = h0
end

local function hadjust(rmax, h)
   local S = 0.9
   if rmax > 1.1 then
      local r = S / rmax^(1/5)
      r = max(0.2, r)
      return r * h, -1
   elseif rmax < 0.5 then
      local r = S / rmax^(1/(5+1))
      r = max(1, min(r, 5))
      return r * h, 1
   end
   return h, 0
end

local ws = ffi.new('ode_workspace')

local function rkf45_evolve(s, f, t1)
   local t, h = s.t, s.h
   local hadj, inc

   ffi.copy (ws.y0, s.y, vecsize)
   ffi.copy (ws.k1, s.dydt, vecsize)

   if t + h > t1 then h = t1 - t end

   while h > 0 do
      local rmax = 0

      do
	 ffi.copy (ws.ytmp, s.y, vecsize)
	 ffi.C.cblas_daxpy (2, h * 0.25, ws.k1, 1, ws.ytmp, 1)

	 -- k2 step
	 f(t + 0.25 * h, ws.ytmp, ws.k2)

	 ffi.copy (ws.ytmp, s.y, vecsize)
	 ffi.C.cblas_daxpy (2, h * 0.09375, ws.k1, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * 0.28125, ws.k2, 1, ws.ytmp, 1)

	 -- k3 step
	 f(t + 0.375 * h, ws.ytmp, ws.k3)

	 ffi.copy (ws.ytmp, s.y, vecsize)
	 ffi.C.cblas_daxpy (2, h * 0.87938097405553, ws.k1, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * -3.2771961766045, ws.k2, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * 3.3208921256259, ws.k3, 1, ws.ytmp, 1)

	 -- k4 step
	 f(t + 0.92307692307692 * h, ws.ytmp, ws.k4)

	 ffi.copy (ws.ytmp, s.y, vecsize)
	 ffi.C.cblas_daxpy (2, h * 2.0324074074074, ws.k1, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * -8, ws.k2, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * 7.1734892787524, ws.k3, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * -0.20589668615984, ws.k4, 1, ws.ytmp, 1)

	 -- k5 step
	 f(t + 1 * h, ws.ytmp, ws.k5)

	 ffi.copy (ws.ytmp, s.y, vecsize)
	 ffi.C.cblas_daxpy (2, h * -0.2962962962963, ws.k1, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * 2, ws.k2, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * -1.3816764132554, ws.k3, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * 0.45297270955166, ws.k4, 1, ws.ytmp, 1)
	 ffi.C.cblas_daxpy (2, h * -0.275, ws.k5, 1, ws.ytmp, 1)

	 -- k6 step and final sum
	 -- since k2 is no more used we could use k2 to store k6
	 f(t + 0.5 * h, ws.ytmp, ws.k6)

	 ffi.C.cblas_daxpy (2, h * 0.11851851851852, ws.k1, 1, s.y, 1)
	 ffi.C.cblas_daxpy (2, h * 0.51898635477583, ws.k3, 1, s.y, 1)
	 ffi.C.cblas_daxpy (2, h * 0.50613149034202, ws.k4, 1, s.y, 1)
	 ffi.C.cblas_daxpy (2, h * -0.18, ws.k5, 1, s.y, 1)
	 ffi.C.cblas_daxpy (2, h * 0.036363636363636, ws.k6, 1, s.y, 1)
 

         local yerr, r, d0
            yerr = h * (0.0027777777777778 * ws.k1[0] + -0.029941520467836 * ws.k3[0] + -0.029199893673578 * ws.k4[0] + 0.02 * ws.k5[0] + 0.036363636363636 * ws.k6[0])
            d0 = 0 * (1 * abs(s.y[0])) + 1e-06
         r = abs(yerr) / abs(d0)
         rmax = max(r, rmax)
            yerr = h * (0.0027777777777778 * ws.k1[1] + -0.029941520467836 * ws.k3[1] + -0.029199893673578 * ws.k4[1] + 0.02 * ws.k5[1] + 0.036363636363636 * ws.k6[1])
            d0 = 0 * (1 * abs(s.y[1])) + 1e-06
         r = abs(yerr) / abs(d0)
         rmax = max(r, rmax)
      end

      hadj, inc = hadjust(rmax, h)
      if inc >= 0 then break end

      ffi.copy(s.y, ws.y0, vecsize)
      h = hadj
   end

      f(t + h, s.y, s.dydt)
   s.t = t + h
   s.h = hadj

   return h
end

return {new= ode_new, init= ode_init, evolve= rkf45_evolve}
--
-- A Lua preprocessor for template code specialization.
-- Adapted by Steve Donovan, based on original code of Rici Lake.
--

local M = {}

-------------------------------------------------------------------------------
local function preprocess(chunk, name, defs)

   local function parseDollarParen(pieces, chunk, s, e)
      local append, format = table.insert, string.format
      local s = 1
      for term, executed, e in chunk:gmatch("()$(%b())()") do
	 append(pieces,
		format("%q..(%s or '')..", chunk:sub(s, term - 1), executed))
	 s = e
      end
      append(pieces, format("%q", chunk:sub(s)))
   end

   local function parseHashLines(chunk)
      local append = table.insert
      local pieces, s, args = chunk:find("^\n*#ARGS%s*(%b())[ \t]*\n")
      if not args or find(args, "^%(%s*%)$") then
	 pieces, s = {"return function(_put) ", n = 1}, s or 1
      else
	 pieces = {"return function(_put, ", args:sub(2), n = 2}
      end
      while true do
	 local ss, e, lua = chunk:find("^#+([^\n]*\n?)", s)
	 if not e then
	    ss, e, lua = chunk:find("\n#+([^\n]*\n?)", s)
	    append(pieces, "_put(")
	    parseDollarParen(pieces, chunk:sub(s, ss))
	    append(pieces, ")")
	    if not e then break end
	 end
	 append(pieces, lua)
	 s = e + 1
      end
      append(pieces, " end")
      return table.concat(pieces)
   end

   local ppenv

   if defs._self then
      ppenv = defs._self
   else
      ppenv = {string= string, table= table, template= M}
      for k, v in pairs(defs) do ppenv[k] = v end
      ppenv._self = ppenv
      local include = function(filename)
			 return M.process(filename, ppenv)
		      end
      setfenv(include, ppenv)
      ppenv.include = include
   end

   local code = parseHashLines(chunk)
   local fcode = loadstring(code, name)
   if fcode then
      setfenv(fcode, ppenv)
      return fcode()
   end
end

local function read_file(filename)
   local f = io.open(filename)
   local content = f:read('*a')
   f:close()
   return content
end

local function process(filename, defs)
   local template = read_file(filename)
   local codegen = preprocess(template, 'ode_codegen', defs)
   local code = {}
   local add = function(s) code[#code+1] = s end
   codegen(add)
   return table.concat(code)
end

local function require(filename)
   local f = loadstring(process(filename .. '.lua.in', {}), 'ode_out')
   if not f then error 'error loading ODE module' end
   return f()
end

local function load(filename, defs)
   local f = loadstring(process(filename, defs), 'ode_out')
   if not f then error 'error loading ODE module' end
   return f()
end

M.process = process
M.require = require
M.load    = load

return M
local template = require 'template'

local ffi = require "ffi"

local ode_spec = {N = 2, eps_abs = 1e-6, eps_rel = 0, a_y = 1, a_dydt = 0}
local ode = template.load('rkf45vec.lua.in', ode_spec)

function f_vanderpol_gen(mu)
   return function(t, y, f) 
	     f[0] =  y[1]
	     f[1] = -y[0] + mu * y[1]  * (1-y[0]^2)
	  end
end

local f = f_vanderpol_gen(10.0)
local s = ode.new()
local y = ffi.new('double[2]', {1, 0})
local t0, t1, h0 = 0, 20000, 0.01
local init, evol = ode.init, ode.evolve
for k=1, 10 do
   init(s, t0, h0, f, y)
   while s.t < t1 do
      evol(s, f, t1)
   end
   print(s.t, s.y[0], s.y[1])
end