lua-users home
lua-l archive

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


Hi all,

after a good amount of work I've finished writing the ODE
integration routine. The method I've implemented is the embedded
Runge-Kutta-Fehlberg (4, 5) method. This method is already much
better than the simple Runge-Kutta method, and the next step would be
to implement the embedded Runge-Kutta Prince-Dormand (8, 9) method,
as someone suggested.

The implementation I've done in Lua is virtually identical to the one
in the GSL library. I've implemented the same methodology to control
the step size so that the error is limited according to the user
input. The difference is that I don't use vectors: everything is
expanded into local variables using a template preprocessor.

To develop the interface I've further refined the Lua preprocessor
that Steve Donovan made based on Rici Lake's original code snippet.
I've changed the implementation to avoid writing into the global
namespace, and I've also added a function to include other files
during preprocessing. The resulting file is "template.lua".

In order to test the algorithm for both accuracy and performance I've
taken a basic GSL example that shows ODE evolution. I've changed the
integration method to rkf45 (the original example used rk8pd,
Runge-Kutta Prince-Dormand). Then I've increased the integration time
and repeated the whole process 10 times.

The results are just perfect in terms of accuracy: results produced
with LuaJIT2 are the same as those given by the C code.
On the other hand, there seems to be a small problem with speed,
because the performance of LuaJIT2 is in this case below my
expectations. Here is what I've got:

LuaJIT2:
real	0m14.498s
user	0m14.497s
sys	0m0.000s

C code (-O2) with GSL library:
real	0m1.094s
user	0m1.088s
sys	0m0.000s

so the C code in this case is approximately 13 times faster.

I hope I've made a big stupid error in my implementation, because I
was hoping for better results :-)

You will find all the files attached, if someone wants to take a
look. The most important one is the preprocessed file,
"rkf45-out.lua". This file is generated from "rkf45.lua.in" and
"ode-defs.lua.in" using the template module.

Otherwise, if you want to reproduce the example with LuaJIT2 you will
need to add the "math." prefix to math functions like sin, cos, etc.
The reason is that GSL Shell puts all the mathematical functions in the
common namespace. You can easily test everything by taking the luajit2
branch in the GSL Shell git repository.

-- 
Francesco

Attachment: rkf45.lua.in
Description: Binary data





local ffi = require 'ffi'

-- C layout of the integrator state for a 2-component system:
--   t    - current time
--   h    - step size to try on the next call to evolve
--   y    - current state vector (0-based, y[0] and y[1])
--   dydt - derivative f(t, y) at the current point, reused as k1 by evolve
ffi.cdef[[
  typedef struct {
    double t;
    double h;
    double y[2];
    double dydt[2];
  } ode_state;
]]

--- Allocate a new integrator state.
-- @return an `ode_state` FFI struct; call ode_init before evolving it
local function ode_new()
   local state = ffi.new('ode_state')
   return state
end

--- Reset state `s` to the initial condition (t0; y_0, y_1).
-- The derivative at the starting point is evaluated once and cached in
-- s.dydt so the first evolve call can use it as k1.
-- @param s   state to (re)initialise
-- @param t0  initial time
-- @param h0  initial trial step size
-- @param f   right-hand side, f(t, y_0, y_1) -> dy_0, dy_1
-- @param y_0 initial value of the first component
-- @param y_1 initial value of the second component
local function ode_init(s, t0, h0, f, y_0, y_1)
   local y, dydt = s.y, s.dydt
   y[0], y[1] = y_0, y_1
   dydt[0], dydt[1] = f(t0, y_0, y_1)
   s.t, s.h = t0, h0
end

--- Suggest a new step size from the error ratio of the last trial step.
-- Follows GSL's standard step-size control for a 5th-order method with
-- safety factor S = 0.9.
-- @param rmax largest |yerr_i| / D0_i over all components of the trial step
-- @param h    step size that produced rmax
-- @return new step size, and a flag: -1 = reject the step and retry with the
--         smaller size, 1 = accept and grow the step, 0 = accept unchanged
local function hadjust(rmax, h)
   local S = 0.9
   if rmax > 1.1 then
      -- Error too large: shrink, but by no more than a factor of 5.
      local r = S / rmax^(1/5)
      r = math.max(0.2, r)
      return r * h, -1
   elseif rmax < 0.5 then
      -- Error comfortably small: grow by at most a factor of 5, and never
      -- let the safety factor S turn this branch into a decrease.
      -- (rmax == 0 gives r = inf, which the clamp turns into 5.)
      local r = S / rmax^(1/(5+1))
      r = math.max(1, math.min(r, 5))
      return r * h, 1
   end
   return h, 0
end




-- The error estimate (yerr) in rkf45_evolve uses the differences between
-- the fifth- and fourth-order coefficients.




--- Advance `s` by one accepted adaptive Runge-Kutta-Fehlberg (4, 5) step.
--
-- The trial step starts at s.t with size s.h, clipped so the step does not
-- go past t1. A trial whose error ratio is too large (hadjust returns -1) is
-- rejected and retried from the same start point with the smaller size.
-- On acceptance s.y, s.dydt and s.t are advanced, and s.h is set to the
-- step size suggested for the next call.
-- @param s  state from ode_new/ode_init (fields t, h, y[0..1], dydt[0..1])
-- @param f  right-hand side, f(t, y_0, y_1) -> dy_0, dy_1
-- @param t1 end of the integration interval
-- @return the step size actually taken (0 when s.t >= t1)
local function rkf45_evolve(s, f, t1)
   -- Plain Lua/LuaJIT has no global abs/max (GSL Shell provides them).
   local abs, max = math.abs, math.max
   local t, h = s.t, s.h
   local dydt = s.dydt
   local hadj, inc

   -- Already at the end point: nothing to integrate.
   if t >= t1 then return 0 end
   -- A non-positive step would skip the loop below and then call f with
   -- nil state; fail with a clear message instead.
   if not (h > 0) then
      error("rkf45_evolve: step size must be positive", 2)
   end

   -- k1 is the derivative cached at the start point; the start point does
   -- not change between retries, so it is read only once.
   local k1_0, k1_1 = dydt[0], dydt[1]
   local y_0, y_1
   -- Stage inputs; declared local so the hot loop never writes globals.
   local ytmp_0, ytmp_1

   if t + h > t1 then h = t1 - t end

   while h > 0 do
      y_0, y_1 = s.y[0], s.y[1]

      ytmp_0 = y_0 + 0.25 * h * k1_0
      ytmp_1 = y_1 + 0.25 * h * k1_1

      -- k2 step
      local k2_0, k2_1 = f(t + 0.25 * h, ytmp_0, ytmp_1)

      ytmp_0 = y_0 + h * (0.09375 * k1_0 + 0.28125 * k2_0)
      ytmp_1 = y_1 + h * (0.09375 * k1_1 + 0.28125 * k2_1)

      -- k3 step
      local k3_0, k3_1 = f(t + 0.375 * h, ytmp_0, ytmp_1)

      ytmp_0 = y_0 + h * (0.87938097405553 * k1_0 + -3.2771961766045 * k2_0 + 3.3208921256259 * k3_0)
      ytmp_1 = y_1 + h * (0.87938097405553 * k1_1 + -3.2771961766045 * k2_1 + 3.3208921256259 * k3_1)

      -- k4 step
      local k4_0, k4_1 = f(t + 0.92307692307692 * h, ytmp_0, ytmp_1)

      ytmp_0 = y_0 + h * (2.0324074074074 * k1_0 + -8 * k2_0 + 7.1734892787524 * k3_0 + -0.20589668615984 * k4_0)
      ytmp_1 = y_1 + h * (2.0324074074074 * k1_1 + -8 * k2_1 + 7.1734892787524 * k3_1 + -0.20589668615984 * k4_1)

      -- k5 step
      local k5_0, k5_1 = f(t + 1 * h, ytmp_0, ytmp_1)

      ytmp_0 = y_0 + h * (-0.2962962962963 * k1_0 + 2 * k2_0 + -1.3816764132554 * k3_0 + 0.45297270955166 * k4_0 + -0.275 * k5_0)
      ytmp_1 = y_1 + h * (-0.2962962962963 * k1_1 + 2 * k2_1 + -1.3816764132554 * k3_1 + 0.45297270955166 * k4_1 + -0.275 * k5_1)

      -- k6 step and final (fifth-order) sum
      local k6_0, k6_1 = f(t + 0.5 * h, ytmp_0, ytmp_1)

      local di
      di = 0.11851851851852 * k1_0 + 0.51898635477583 * k3_0 + 0.50613149034202 * k4_0 + -0.18 * k5_0 + 0.036363636363636 * k6_0
      y_0 = y_0 + h * di
      di = 0.11851851851852 * k1_1 + 0.51898635477583 * k3_1 + 0.50613149034202 * k4_1 + -0.18 * k5_1 + 0.036363636363636 * k6_1
      y_1 = y_1 + h * di

      -- Error ratio per component: |yerr| / D0 with
      -- D0 = eps_rel * (a_y * |y|) + eps_abs (settings substituted by the
      -- template: eps_rel = 0, a_y = 1, eps_abs = 1e-6).
      local yerr, r, d0
      local rmax = 0

      yerr = h * (0.0027777777777778 * k1_0 + -0.029941520467836 * k3_0 + -0.029199893673578 * k4_0 + 0.02 * k5_0 + 0.036363636363636 * k6_0)
      d0 = 0 * (1 * abs(y_0)) + 1e-06
      r = abs(yerr) / abs(d0)
      rmax = max(r, rmax)
      yerr = h * (0.0027777777777778 * k1_1 + -0.029941520467836 * k3_1 + -0.029199893673578 * k4_1 + 0.02 * k5_1 + 0.036363636363636 * k6_1)
      d0 = 0 * (1 * abs(y_1)) + 1e-06
      r = abs(yerr) / abs(d0)
      rmax = max(r, rmax)

      hadj, inc = hadjust(rmax, h)
      -- inc == -1 means the step is rejected: retry with the smaller size.
      if inc >= 0 then break end

      h = hadj
   end

   -- Cache f at the accepted point; it becomes k1 of the next step.
   dydt[0], dydt[1] = f(t + h, y_0, y_1)

   s.y[0], s.y[1] = y_0, y_1
   s.t = t + h
   s.h = hadj

   return h
end

-- Public interface of the generated integrator module.
return {
   new = ode_new,
   init = ode_init,
   evolve = rkf45_evolve,
}
local template = require 'template'

-- Parameters used to specialise the rkf45 template: system dimension N and
-- the error-control settings (eps_abs, eps_rel, a_y, a_dydt).
local ode_spec = {N = 2, eps_abs = 1e-6, eps_rel = 0, a_y = 1, a_dydt = 0}

-- template.compile returns loadstring's results (nil plus a message on
-- failure); assert so that message is reported instead of "attempt to call nil".
local codegen = assert(template.compile('rkf45.lua.in', ode_spec))
local ode = codegen()

--- Build the Van der Pol right-hand side x' = y, y' = -x + mu*y*(1 - x^2).
local function f_vanderpol_gen(mu)
   return function(t, x, y) return y, -x + mu * y * (1-x^2) end
end

local f = f_vanderpol_gen(10.0)
local s = ode.new()
local x, y = 1, 0
local t0, t1, h0 = 0, 20000, 0.01
local init, evol = ode.init, ode.evolve
-- Repeat the whole integration 10 times (benchmark workload) and print the
-- final state of each run.
for k = 1, 10 do
   init(s, t0, h0, f, x, y)
   while s.t < t1 do
      evol(s, f, t1)
   end
   print(s.t, s.y[0], s.y[1])
end
#include <stdio.h>
#include <gsl/gsl_errno.h>
#include <gsl/gsl_matrix.h>
#include <gsl/gsl_odeiv.h>

/* Right-hand side of the Van der Pol system, with mu read from *params:
     f[0] = y[1]
     f[1] = -y[0] - mu * y[1] * (y[0]^2 - 1)
   The time argument t is not used. */
int
func (double t, const double y[], double f[],
      void *params)
{
  const double mu = *(const double *) params;

  f[0] = y[1];
  f[1] = -y[0] - mu * y[1] * (y[0] * y[0] - 1.0);

  return GSL_SUCCESS;
}

/* Jacobian of func.  dfdy is filled as a row-major 2x2 matrix,
   dfdy[i*2 + j] = d f[i] / d y[j]; dfdt (explicit time derivative) is
   zero because func does not depend on t. */
int
jac (double t, const double y[], double *dfdy,
     double dfdt[], void *params)
{
  const double mu = *(const double *) params;

  /* row 0: d f[0] / d y */
  dfdy[0 * 2 + 0] = 0.0;
  dfdy[0 * 2 + 1] = 1.0;
  /* row 1: d f[1] / d y */
  dfdy[1 * 2 + 0] = -2.0 * mu * y[0] * y[1] - 1.0;
  dfdy[1 * 2 + 1] = -mu * (y[0] * y[0] - 1.0);

  dfdt[0] = 0.0;
  dfdt[1] = 0.0;

  return GSL_SUCCESS;
}

/* Benchmark driver: integrate the system defined by func/jac (mu = 10)
   from t = 0 to t = 20000 with the rkf45 stepper and an absolute error
   bound of 1e-6, repeated 10 times, printing the final state of each run.
   A failed evolve step is reported on stderr and makes main return 1. */
int
main (void)
{
  const gsl_odeiv_step_type * T = gsl_odeiv_step_rkf45;
  int k;
  int failed = 0;

  for (k=0; k < 10; k++)
    {
      gsl_odeiv_step * s = gsl_odeiv_step_alloc (T, 2);
      gsl_odeiv_control * c = gsl_odeiv_control_y_new (1e-6, 0.0);
      gsl_odeiv_evolve * e = gsl_odeiv_evolve_alloc (2);

      double mu = 10;
      gsl_odeiv_system sys = {func, jac, 2, &mu};

      double t = 0.0, t1 = 20000.0;
      double h = 1e-6;
      double y[2] = { 1.0, 0.0 };

      while (t < t1)
        {
          int status = gsl_odeiv_evolve_apply (e, c, s,
                                               &sys,
                                               &t, t1,
                                               &h, y);

          if (status != GSL_SUCCESS)
            {
              /* Do not let a truncated run pass silently as a result. */
              fprintf (stderr, "run %d: evolve failed at t = %g: %s\n",
                       k, t, gsl_strerror (status));
              failed = 1;
              break;
            }
        }

      printf ("%g %g %g\n", t, y[0], y[1]);

      gsl_odeiv_evolve_free (e);
      gsl_odeiv_control_free (c);
      gsl_odeiv_step_free (s);
    }

  return failed ? 1 : 0;
}

Attachment: ode-defs.lua.in
Description: Binary data

--
-- A Lua preprocessor for template code specialization.
-- Adapted by Steve Donovan, based on original code of Rici Lake.
--

-- Module table. preprocess also exposes it to templates as `template`, and
-- `include` calls M.process to expand nested files.
local M = {}

-------------------------------------------------------------------------------
--- Turn template text into a generator function.
-- Lines starting with '#' are Lua code; every other line is output text in
-- which `$(expr)` is replaced by the value of `expr` (or '' when nil).
-- An optional leading `#ARGS(a, b)` line adds parameters after `_put`.
-- @param chunk template source text
-- @param name  chunk name used for the generated Lua code
-- @param defs  definitions visible to the template code; a table that was
--              produced by an earlier call (it has `_self`) is reused as is
-- @return generator function(_put, ...) that emits strings through `_put`
-- @raise the loadstring message when the template yields invalid Lua
local function preprocess(chunk, name, defs)

   -- Append to `pieces` a Lua string expression rebuilding `text`, with
   -- each `$(expr)` turned into `(expr or '')`.
   local function parseDollarParen(pieces, text)
      local format = string.format
      local pos = 1
      for term, executed, after in text:gmatch("()$(%b())()") do
         pieces[#pieces + 1] =
            format("%q..(%s or '')..", text:sub(pos, term - 1), executed)
         pos = after
      end
      pieces[#pieces + 1] = format("%q", text:sub(pos))
   end

   -- Build the source of a chunk returning `function(_put, <ARGS>) ... end`.
   local function parseHashLines(text)
      local pieces
      local _, s, args = text:find("^\n*#ARGS%s*(%b())[ \t]*\n")
      if not args or args:find("^%(%s*%)$") then
         pieces, s = {"return function(_put) "}, s or 1
      else
         -- args is "(a, b)"; dropping the "(" leaves "a, b)", which closes
         -- the parameter list opened here.
         pieces = {"return function(_put, ", args:sub(2)}
      end
      while true do
         local ss, e, lua = text:find("^#+([^\n]*\n?)", s)
         if not e then
            -- Plain text runs up to the next '#' line (or to the end).
            ss, e, lua = text:find("\n#+([^\n]*\n?)", s)
            pieces[#pieces + 1] = "_put("
            parseDollarParen(pieces, text:sub(s, ss))
            pieces[#pieces + 1] = ")"
            if not e then break end
         end
         pieces[#pieces + 1] = lua
         s = e + 1
      end
      pieces[#pieces + 1] = " end"
      return table.concat(pieces)
   end

   local ppenv = defs._self
   if not ppenv then
      ppenv = {string= string, table= table, template= M}
      for k, v in pairs(defs) do ppenv[k] = v end
      ppenv._self = ppenv
      -- Nested templates are processed in this same environment, so they
      -- see the same definitions (and the same `include`).
      ppenv.include = function(filename)
         return M.process(filename, ppenv)
      end
   end

   local code = parseHashLines(chunk)
   local fcode, err = loadstring(code, name)
   if not fcode then
      -- Report the template's syntax error instead of returning nil.
      error(err, 0)
   end
   setfenv(fcode, ppenv)
   return fcode()
end

--- Read the whole contents of a file.
-- @param filename path of the file to read
-- @return the file contents as a string
-- @raise an error naming the file when it cannot be opened
local function read_file(filename)
   -- io.open returns nil plus "<filename>: <reason>"; assert raises that
   -- message instead of a bare "attempt to index nil".
   local f = assert(io.open(filename, 'r'))
   local content = f:read('*a')
   f:close()
   return content
end

--- Expand the template stored in `filename` and return the generated text.
-- @param filename template file to read
-- @param defs     definitions visible to the template code
-- @return the concatenation of everything the template emitted
local function process(filename, defs)
   local source = read_file(filename)
   local generator = preprocess(source, 'ode_codegen', defs)
   local emitted = {}
   generator(function(piece)
      emitted[#emitted + 1] = piece
   end)
   return table.concat(emitted)
end

--- Expand a template file and load the result as a Lua chunk.
-- @return the loaded chunk, or nil plus loadstring's error message
local function compile(filename, defs)
   local generated = process(filename, defs)
   return loadstring(generated, 'ode_out')
end

-- Exported API:
--   compile(filename, defs) -> loaded chunk (or nil, message)
--   process(filename, defs) -> generated source text
M.compile = compile
M.process = process

return M