lua-users home
lua-l archive

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


Hi,

I've rewritten the RKF45 ODE integrator in vector form. It is now slightly
simpler, and I was hoping it would also be faster, because I no longer pass
a pointer to a member of an FFI structure.

The results are still accurate, but the program is about as slow as
before.

I've attached the two files. Please note that I've added a line that
loads the 'libgslcblas-0' library through the FFI. This should work on
Windows with GSL Shell or plain LuaJIT, provided that you have installed
the library (it is included with the GSL library).

I hope someone can help, because I'm stuck at the moment...

Francesco
-- Math helpers cached as upvalues; they are used in the error estimate
-- inside rkf45_evolve's stepping loop.
local abs, max, min = math.abs, math.max, math.min

local ffi = require "ffi"

-- Size in bytes of one state vector (two doubles); every ffi.copy of a
-- state / stage vector in this module copies exactly this many bytes.
local vecsize = 2 * ffi.sizeof('double')

-- GSL's CBLAS, loaded through the FFI. This used to be an (accidental)
-- global: every cblas.cblas_daxpy call in the stepping loop then went
-- through a _G lookup, and the module leaked a `cblas` name into _G.
-- NOTE(review): 'libgslcblas-0' looks like the Windows DLL name; other
-- platforms presumably need a different library name — confirm.
local cblas = ffi.load('libgslcblas-0')

-- odevec_state: current time t, current step size h, state y and its
-- derivative dydt (both of length 2). cblas_daxpy computes Y += ALPHA * X.
ffi.cdef[[
  typedef struct {
    double t;
    double h;
    double y[2];
    double dydt[2];
  } odevec_state;

  void cblas_daxpy (const int N, const double ALPHA,
		    const double * X, const int INCX,
		    double * Y, const int INCY);
]]

-- ctype handle for the integrator state, resolved once at load time.
local odevec_state_t = ffi.typeof('odevec_state')

--- Allocate a new odevec_state cdata.
-- rkf45_evolve reads s.t, s.h, s.y and s.dydt, which ode_init fills in.
-- @return a fresh odevec_state
local function ode_new()
   local state = odevec_state_t()
   return state
end

--- Prepare an integrator state for stepping.
-- Copies the initial vector y (vecsize bytes) into s.y, evaluates the
-- right-hand side once so that s.dydt = f(t0, s.y), and records the start
-- time and initial step size.
-- @param s  odevec_state to initialise
-- @param t0 initial time
-- @param h0 initial step size
-- @param f  right-hand side, called as f(t, y, dydt)
-- @param y  initial state vector (two doubles)
local function ode_init(s, t0, h0, f, y)
   local state_y = s.y
   ffi.copy(state_y, y, vecsize)
   -- rkf45_evolve uses s.dydt as its first stage k1, so prime it here.
   f(t0, state_y, s.dydt)
   s.t, s.h = t0, h0
end

--- Propose a new step size from the scaled error ratio of the last attempt.
-- @param rmax largest per-component |error| / tolerance ratio
-- @param h    step size that produced rmax
-- @return the proposed step size, and -1 (shrink: reject and retry),
--         1 (grow) or 0 (keep h unchanged)
local function hadjust(rmax, h)
   local SAFETY = 0.9

   if rmax > 1.1 then
      -- Too inaccurate: shrink, but never below a fifth of h.
      return max(0.2, SAFETY / rmax^(1/5)) * h, -1
   end

   if rmax < 0.5 then
      -- Comfortably accurate: grow by a factor clamped to [1, 5].
      return max(1, min(SAFETY / rmax^(1/6), 5)) * h, 1
   end

   return h, 0
end

-- Scratch vectors for rkf45_evolve: ws_y holds the candidate state and
-- ws_k1..ws_k6 the six stage derivatives. They are allocated once, at
-- module load, and shared by every state stepped through this module.
local vec2_t = ffi.typeof('double[2]')
local ws_y  = vec2_t()
local ws_k1 = vec2_t()
local ws_k2 = vec2_t()
local ws_k3 = vec2_t()
local ws_k4 = vec2_t()
local ws_k5 = vec2_t()
local ws_k6 = vec2_t()

--- Advance the integrator by one adaptive RKF45 step towards t1.
-- Tries a step of size s.h, clipped so it does not pass t1. While the
-- scaled error estimate is too large, the step is shrunk via hadjust and
-- retried from the same starting point. The accepted time, state and
-- derivative, and the suggested next step size, are written back into s.
-- @param s  odevec_state prepared by ode_init (s.dydt holds f(s.t, s.y))
-- @param f  right-hand side, called as f(t, y, dydt); writes dy/dt into dydt
-- @param t1 end of the integration interval; the step never goes past it
-- @return the step size actually taken, or 0 if s.t is already at/after t1
local function rkf45_evolve(s, f, t1)
   local t, h = s.t, s.h
   local hadj, inc

   -- Clip the step so it lands on t1 instead of overshooting it.
   local final_step = false
   if t + h > t1 then
      h = t1 - t
      final_step = true
   end

   -- Already at (or past) t1. Previously the loop was skipped, hadj stayed
   -- nil, f was re-evaluated on a stale ws_y and `s.h = hadj` then raised.
   if h <= 0 then
      return 0
   end

   -- k1 = f(t, y), cached by ode_init or by the previous accepted step.
   ffi.copy(ws_k1, s.dydt, vecsize)

   -- Coefficients below are the RKF45 (Fehlberg) tableau written as decimals;
   -- each stage builds ws_y = y + h * sum(a_ij * k_j) with cblas_daxpy.
   while true do
      -- k2
      ffi.copy(ws_y, s.y, vecsize)
      cblas.cblas_daxpy(2, h * 0.25, ws_k1, 1, ws_y, 1)
      f(t + 0.25 * h, ws_y, ws_k2)

      -- k3
      ffi.copy(ws_y, s.y, vecsize)
      cblas.cblas_daxpy(2, h * 0.09375, ws_k1, 1, ws_y, 1)
      cblas.cblas_daxpy(2, h * 0.28125, ws_k2, 1, ws_y, 1)
      f(t + 0.375 * h, ws_y, ws_k3)

      -- k4
      ffi.copy(ws_y, s.y, vecsize)
      cblas.cblas_daxpy(2, h * 0.87938097405553, ws_k1, 1, ws_y, 1)
      cblas.cblas_daxpy(2, h * -3.2771961766045, ws_k2, 1, ws_y, 1)
      cblas.cblas_daxpy(2, h * 3.3208921256259, ws_k3, 1, ws_y, 1)
      f(t + 0.92307692307692 * h, ws_y, ws_k4)

      -- k5
      ffi.copy(ws_y, s.y, vecsize)
      cblas.cblas_daxpy(2, h * 2.0324074074074, ws_k1, 1, ws_y, 1)
      cblas.cblas_daxpy(2, h * -8, ws_k2, 1, ws_y, 1)
      cblas.cblas_daxpy(2, h * 7.1734892787524, ws_k3, 1, ws_y, 1)
      cblas.cblas_daxpy(2, h * -0.20589668615984, ws_k4, 1, ws_y, 1)
      f(t + h, ws_y, ws_k5)

      -- k6 (k2 is not needed past this point, so k6 could reuse its buffer)
      ffi.copy(ws_y, s.y, vecsize)
      cblas.cblas_daxpy(2, h * -0.2962962962963, ws_k1, 1, ws_y, 1)
      cblas.cblas_daxpy(2, h * 2, ws_k2, 1, ws_y, 1)
      cblas.cblas_daxpy(2, h * -1.3816764132554, ws_k3, 1, ws_y, 1)
      cblas.cblas_daxpy(2, h * 0.45297270955166, ws_k4, 1, ws_y, 1)
      cblas.cblas_daxpy(2, h * -0.275, ws_k5, 1, ws_y, 1)
      f(t + 0.5 * h, ws_y, ws_k6)

      -- Candidate solution at t + h.
      ffi.copy(ws_y, s.y, vecsize)
      cblas.cblas_daxpy(2, h * 0.11851851851852, ws_k1, 1, ws_y, 1)
      cblas.cblas_daxpy(2, h * 0.51898635477583, ws_k3, 1, ws_y, 1)
      cblas.cblas_daxpy(2, h * 0.50613149034202, ws_k4, 1, ws_y, 1)
      cblas.cblas_daxpy(2, h * -0.18, ws_k5, 1, ws_y, 1)
      cblas.cblas_daxpy(2, h * 0.036363636363636, ws_k6, 1, ws_y, 1)

      -- Largest per-component ratio of the embedded error estimate to the
      -- tolerance. NOTE(review): d0 looks like a template expansion of
      -- eps_abs + eps_rel * |y| with eps_abs = 1e-6 and eps_rel = 0.
      local rmax = 0
      for i = 0, 1 do
         local yerr = h * (0.0027777777777778 * ws_k1[i]
                           - 0.029941520467836 * ws_k3[i]
                           - 0.029199893673578 * ws_k4[i]
                           + 0.02 * ws_k5[i]
                           + 0.036363636363636 * ws_k6[i])
         local d0 = 0 * (1 * abs(ws_y[i])) + 1e-006
         rmax = max(abs(yerr) / abs(d0), rmax)
      end

      hadj, inc = hadjust(rmax, h)
      if inc >= 0 then break end

      -- Rejected: retry from the same (t, y, k1) with the smaller step, which
      -- no longer reaches t1. Looping until acceptance (rather than the old
      -- `while h > 0`) ensures a rejected candidate is never written back.
      h = hadj
      final_step = false
   end

   -- On the clipped final step report exactly t1; t + (t1 - t) can differ
   -- from t1 by rounding and could leave callers a spurious tiny step.
   local t_new = final_step and t1 or t + h

   -- dy/dt at the accepted point is k1 of the next step; f writes it straight
   -- into s.dydt (as ode_init does) instead of via a scratch buffer + copy.
   f(t_new, ws_y, s.dydt)
   ffi.copy(s.y, ws_y, vecsize)

   s.t = t_new
   s.h = hadj

   return h
end

-- Public API: allocate a state, initialise it, advance it one adaptive step.
return {
   new    = ode_new,
   init   = ode_init,
   evolve = rkf45_evolve,
}

Attachment: rkf45vec.lua.in
Description: Binary data