[Date Prev][Date Next][Thread Prev][Thread Next]
[Date Index]
[Thread Index]
- Subject: Re: LuaJIT2 performance for number crunching
- From: Francesco Abbate <francesco.bbt@...>
- Date: Tue, 22 Feb 2011 11:19:02 +0100
Hi,
I've modified the implementation of the rkf45 ode integrator in
vectorial form, now it is slightly simpler and I was hoping faster
because I dont pass a pointer to a member instance of an ffi
structure.
The results are still accurate but the program it is slow like before
or similar.
I add in attachment the two files. Please note that I've added a line
to load with ffi the 'libgslcblas-0' library. This should work on
windows with gsl-shell or plain luajit provided that you have
installed the library (included with gsl library).
I hope to get some help because I'm stuck for the moment...
Francesco
local abs, max, min = math.abs, math.max, math.min
local ffi = require "ffi"
local vecsize = 2 * ffi.sizeof('double')
cblas = ffi.load('libgslcblas-0')
ffi.cdef[[
typedef struct {
double t;
double h;
double y[2];
double dydt[2];
} odevec_state;
void cblas_daxpy (const int N, const double ALPHA,
const double * X, const int INCX,
double * Y, const int INCY);
]]
local function ode_new()
return ffi.new('odevec_state')
end
local function ode_init(s, t0, h0, f, y)
ffi.copy(s.y, y, vecsize)
f(t0, s.y, s.dydt)
s.t = t0
s.h = h0
end
local function hadjust(rmax, h)
local S = 0.9
if rmax > 1.1 then
local r = S / rmax^(1/5)
r = max(0.2, r)
return r * h, -1
elseif rmax < 0.5 then
local r = S / rmax^(1/(5+1))
r = max(1, min(r, 5))
return r * h, 1
end
return h, 0
end
local ws_y = ffi.new('double[2]')
local ws_k1 = ffi.new('double[2]')
local ws_k2 = ffi.new('double[2]')
local ws_k3 = ffi.new('double[2]')
local ws_k4 = ffi.new('double[2]')
local ws_k5 = ffi.new('double[2]')
local ws_k6 = ffi.new('double[2]')
local function rkf45_evolve(s, f, t1)
local t, h = s.t, s.h
local hadj, inc
ffi.copy (ws_k1, s.dydt, vecsize)
if t + h > t1 then h = t1 - t end
while h > 0 do
ffi.copy (ws_y, s.y, vecsize)
local rmax = 0
do
cblas.cblas_daxpy (2, h * 0.25, ws_k1, 1, ws_y, 1)
-- k2 step
f(t + 0.25 * h, ws_y, ws_k2)
ffi.copy (ws_y, s.y, vecsize)
cblas.cblas_daxpy (2, h * 0.09375, ws_k1, 1, ws_y, 1)
cblas.cblas_daxpy (2, h * 0.28125, ws_k2, 1, ws_y, 1)
-- k3 step
f(t + 0.375 * h, ws_y, ws_k3)
ffi.copy (ws_y, s.y, vecsize)
cblas.cblas_daxpy (2, h * 0.87938097405553, ws_k1, 1, ws_y, 1)
cblas.cblas_daxpy (2, h * -3.2771961766045, ws_k2, 1, ws_y, 1)
cblas.cblas_daxpy (2, h * 3.3208921256259, ws_k3, 1, ws_y, 1)
-- k4 step
f(t + 0.92307692307692 * h, ws_y, ws_k4)
ffi.copy (ws_y, s.y, vecsize)
cblas.cblas_daxpy (2, h * 2.0324074074074, ws_k1, 1, ws_y, 1)
cblas.cblas_daxpy (2, h * -8, ws_k2, 1, ws_y, 1)
cblas.cblas_daxpy (2, h * 7.1734892787524, ws_k3, 1, ws_y, 1)
cblas.cblas_daxpy (2, h * -0.20589668615984, ws_k4, 1, ws_y, 1)
-- k5 step
f(t + 1 * h, ws_y, ws_k5)
ffi.copy (ws_y, s.y, vecsize)
cblas.cblas_daxpy (2, h * -0.2962962962963, ws_k1, 1, ws_y, 1)
cblas.cblas_daxpy (2, h * 2, ws_k2, 1, ws_y, 1)
cblas.cblas_daxpy (2, h * -1.3816764132554, ws_k3, 1, ws_y, 1)
cblas.cblas_daxpy (2, h * 0.45297270955166, ws_k4, 1, ws_y, 1)
cblas.cblas_daxpy (2, h * -0.275, ws_k5, 1, ws_y, 1)
-- k6 step and final sum
-- since k2 is no more used we could use k2 to store k6
f(t + 0.5 * h, ws_y, ws_k6)
ffi.copy (ws_y, s.y, vecsize)
cblas.cblas_daxpy (2, h * 0.11851851851852, ws_k1, 1, ws_y, 1)
cblas.cblas_daxpy (2, h * 0.51898635477583, ws_k3, 1, ws_y, 1)
cblas.cblas_daxpy (2, h * 0.50613149034202, ws_k4, 1, ws_y, 1)
cblas.cblas_daxpy (2, h * -0.18, ws_k5, 1, ws_y, 1)
cblas.cblas_daxpy (2, h * 0.036363636363636, ws_k6, 1, ws_y, 1)
local yerr, r, d0
yerr = h * (0.0027777777777778 * ws_k1[0] + -0.029941520467836 * ws_k3[0] + -0.029199893673578 * ws_k4[0] + 0.02 * ws_k5[0] + 0.036363636363636 * ws_k6[0])
d0 = 0 * (1 * abs(ws_y[0])) + 1e-006
r = abs(yerr) / abs(d0)
rmax = max(r, rmax)
yerr = h * (0.0027777777777778 * ws_k1[1] + -0.029941520467836 * ws_k3[1] + -0.029199893673578 * ws_k4[1] + 0.02 * ws_k5[1] + 0.036363636363636 * ws_k6[1])
d0 = 0 * (1 * abs(ws_y[1])) + 1e-006
r = abs(yerr) / abs(d0)
rmax = max(r, rmax)
end
hadj, inc = hadjust(rmax, h)
if inc >= 0 then break end
h = hadj
end
f(t + h, ws_y, ws_k2)
ffi.copy (s.dydt, ws_k2, vecsize)
ffi.copy (s.y, ws_y, vecsize)
s.t = t + h
s.h = hadj
return h
end
return {new= ode_new, init= ode_init, evolve= rkf45_evolve}
Attachment:
rkf45vec.lua.in
Description: Binary data