lua-users home
lua-l archive

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


Hi all,

I'm glad to announce that, thanks to the work of Lesley, GSL Shell now
has an implementation of the VEGAS algorithm for Monte Carlo
integration. For those who don't already know it, the algorithm
numerically integrates a function over a multi-dimensional domain,
using a random number generator to sample the integration domain.

The algorithm is completely written in Lua using the FFI extension of
LuaJIT2. The template module is also used to improve performance by
generating the code on the fly for a given number of dimensions.

The VEGAS algorithm is already in the "gsl-shell-2" branch of GSL
Shell. In terms of performance it is quite close to optimized C, but C
is still faster by a factor of ~1.7. Here are some basic figures:

LuaJIT2 -jon:            real	0m5.935s
LuaJIT2 -joff:            (too long to finish, I didn't wait)
C (gcc -O2):             real	0m3.574s


The algorithm also works with vanilla LuaJIT2: attached is a patch for
LuaJIT2 beta9 or the git HEAD, so anyone interested can play with it.
The patch also includes the benchmark files for the Lua and C versions.
Please note that the GSL library is needed only to test the
implementation with the TAUS2 random number generator. The algorithm
works perfectly without the GSL library too: you can just use the
math.random function built into LuaJIT2.

Actually, if math.random is used the algorithm is faster than C, but
we don't clearly understand why (note that the C benchmark still uses
the GSL TAUS2 generator, so that comparison is not like for like).

Best regards,
Francesco
From 580622f1c711459c8540b55c18d1f24599aa6667 Mon Sep 17 00:00:00 2001
From: Francesco Abbate <francesco.bbt@gmail.com>
Date: Sun, 22 Jan 2012 01:07:36 +0100
Subject: [PATCH] Implement VEGAS algorithm for Montecarlo integration

---
 rng.lua           |   93 +++++++++++++++
 template.lua      |  108 +++++++++++++++++
 vegas-bench.c     |   63 ++++++++++
 vegas-bench.lua   |   30 +++++
 vegas-defs.lua.in |  339 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 vegas.lua         |   61 ++++++++++
 6 files changed, 694 insertions(+), 0 deletions(-)
 create mode 100644 rng.lua
 create mode 100644 template.lua
 create mode 100644 vegas-bench.c
 create mode 100644 vegas-bench.lua
 create mode 100644 vegas-defs.lua.in
 create mode 100644 vegas.lua

diff --git a/rng.lua b/rng.lua
new file mode 100644
index 0000000..2f1a43c
--- /dev/null
+++ b/rng.lua
@@ -0,0 +1,93 @@
+
+local ffi = require 'ffi'
+
+ffi.cdef[[
+typedef struct
+  {
+    const char *name;
+    unsigned long int max;
+    unsigned long int min;
+    size_t size;
+    void (*set) (void *state, unsigned long int seed);
+    unsigned long int (*get) (void *state);
+    double (*get_double) (void *state);
+  }
+gsl_rng_type;
+
+typedef struct
+  {
+    const gsl_rng_type * type;
+    void *state;
+  }
+gsl_rng;
+
+const gsl_rng_type ** gsl_rng_types_setup(void);
+
+const gsl_rng_type *gsl_rng_default;
+unsigned long int gsl_rng_default_seed;
+
+gsl_rng *gsl_rng_alloc (const gsl_rng_type * T);
+int gsl_rng_memcpy (gsl_rng * dest, const gsl_rng * src);
+gsl_rng *gsl_rng_clone (const gsl_rng * r);
+
+void gsl_rng_free (gsl_rng * r);
+
+void gsl_rng_set (const gsl_rng * r, unsigned long int seed);
+unsigned long int gsl_rng_max (const gsl_rng * r);
+unsigned long int gsl_rng_min (const gsl_rng * r);
+const char *gsl_rng_name (const gsl_rng * r);
+
+const gsl_rng_type * gsl_rng_env_setup (void);
+
+unsigned long int gsl_rng_get (const gsl_rng * r);
+double gsl_rng_uniform (const gsl_rng * r);
+double gsl_rng_uniform_pos (const gsl_rng * r);
+unsigned long int gsl_rng_uniform_int (const gsl_rng * r, unsigned long int n);
+]]
+
+local gsl = ffi.load('libgsl')
+
+-- Methods available on gsl_rng cdata: r:get(), r:getint(n), r:set(seed).
+local rng_methods = {}
+rng_methods.get    = gsl.gsl_rng_uniform
+rng_methods.getint = gsl.gsl_rng_uniform_int
+rng_methods.set    = gsl.gsl_rng_set
+
+-- Bind the methods to the gsl_rng struct type, so that gsl_rng pointers
+-- returned by gsl_rng_alloc dispatch r:get() etc. to the GSL functions.
+local rng_mt = { __index = rng_methods }
+local rng_type = ffi.typeof('gsl_rng')
+ffi.metatype(rng_type, rng_mt)
+
+-- Map a generator name (e.g. 'taus2') to its gsl_rng_type; nil -> default.
+local function rng_type_lookup(s)
+   if not s then
+      -- gsl_rng_default is only initialised by gsl_rng_env_setup() (which also
+      -- honours GSL_RNG_TYPE); reading the variable before that can yield NULL.
+      return gsl.gsl_rng_env_setup()
+   end
+   local ts = gsl.gsl_rng_types_setup() -- NULL-terminated list of types
+   while ts[0] ~= nil do
+      local t = ts[0]
+      if ffi.string(t.name) == s then return t end
+      ts = ts + 1
+   end
+   error('unknown generator type: ' .. s)
+end
+
+local function new(s) -- allocate a generator by name (nil: GSL default); the GC frees it
+   local r = gsl.gsl_rng_alloc(rng_type_lookup(s))
+   return ffi.gc(r, gsl.gsl_rng_free)
+end
+
+local function list() -- names of all generator types known to GSL
+   local names, ts = {}, gsl.gsl_rng_types_setup()
+   local i = 0
+   while ts[i] ~= nil do -- the type array ends with a NULL entry
+      names[i + 1] = ffi.string(ts[i].name)
+      i = i + 1
+   end
+   return names
+end
+
+return { list = list, new = new } -- public API: rng.new(name), rng.list()
diff --git a/template.lua b/template.lua
new file mode 100644
index 0000000..7ddf8a9
--- /dev/null
+++ b/template.lua
@@ -0,0 +1,108 @@
+--
+-- A Lua preprocessor for template code specialization.
+-- Adapted by Steve Donovan, based on original code of Rici Lake.
+--
+
+local M = {}
+-- module table; process and load are attached to it at the end of the file
+-------------------------------------------------------------------------------
+-- Compile template `chunk` into a generator function(_put, ...) whose environment is built from `defs`.
+local function preprocess(chunk, name, defs)
+   -- Append "%q..(expr or '').." pieces for each $-expression in a literal text chunk.
+   local function parseDollarParen(pieces, chunk)
+      local append, format = table.insert, string.format
+      local s = 1
+      for term, executed, e in chunk:gmatch("()$(%b())()") do
+         append(pieces,
+                format("%q..(%s or '')..", chunk:sub(s, term - 1), executed))
+         s = e
+      end
+      append(pieces, format("%q", chunk:sub(s)))
+   end
+   -- Lines starting with '#' are copied as Lua code; other text is emitted via _put.
+   local function parseHashLines(chunk)
+      local append = table.insert
+      local pieces, s, args = chunk:find("^\n*#ARGS%s*(%b())[ \t]*\n")
+      if not args or args:find("^%(%s*%)$") then
+         pieces, s = {"return function(_put) ", n = 1}, s or 1
+      else
+         pieces = {"return function(_put, ", args:sub(2), n = 2}
+      end
+      while true do
+         local ss, e, lua = chunk:find("^#+([^\n]*\n?)", s)
+         if not e then
+            ss, e, lua = chunk:find("\n#+([^\n]*\n?)", s)
+            append(pieces, "_put(")
+            parseDollarParen(pieces, chunk:sub(s, ss))
+            append(pieces, ")")
+            if not e then break end
+         end
+         append(pieces, lua)
+         s = e + 1
+      end
+      append(pieces, " end")
+      return table.concat(pieces)
+   end
+   -- Nested includes pass the enclosing environment back in as defs._self.
+   local ppenv
+   if defs._self then
+      ppenv = defs._self
+   else
+      ppenv = {string= string, table= table, tonumber= tonumber, template= M}
+      for k, v in pairs(defs) do ppenv[k] = v end
+      ppenv._self = ppenv
+      local include = function(filename)
+         return M.process(filename, ppenv)
+      end
+      setfenv(include, ppenv)
+      ppenv.include = include
+   end
+
+   local code = parseHashLines(chunk)
+   local fcode, err = loadstring(code, name)
+   -- surface template syntax errors here instead of silently returning nil
+   if not fcode then error('error preprocessing template: ' .. err) end
+   setfenv(fcode, ppenv)
+   return fcode()
+end
+
+local function read_file(filename) -- whole file contents; raises if it cannot be opened
+   local fh = io.open(filename)
+   if fh == nil then
+      error(string.format('error opening template file %s', filename))
+   end
+   local text = fh:read('*a')
+   fh:close()
+   return text
+end
+
+-- Expand template file `<name>.lua.in` with `defs`; returns the generated source.
+local function process(name, defs)
+   local src = read_file(name .. '.lua.in')
+   local generate = preprocess(src, 'template_gen', defs)
+   local out = {}
+   local function put(piece) out[#out + 1] = piece end
+   generate(put)
+   return table.concat(out)
+end
+
+local function template_error(code, filename, err) -- dump generated code, then raise err
+   local log = io.open('log-out.lua', 'w')
+   if log then -- best effort: an unwritable log must not mask the real error
+      log:write(code); log:close(); print('output code log in "log-out.lua"')
+   end
+   error('error loading ' .. filename .. ':' .. err)
+end
+
+local function load(filename, defs) -- expand the template, compile it and run it
+   local source = process(filename, defs)
+   local chunk, msg = loadstring(source, filename)
+   if chunk == nil then template_error(source, filename, msg) end
+   return chunk()
+end
+
+-- public API (M.process is also what a template's include() calls)
+M.load, M.process = load, process
+
+return M
+
diff --git a/vegas-bench.c b/vegas-bench.c
new file mode 100644
index 0000000..3641533
--- /dev/null
+++ b/vegas-bench.c
@@ -0,0 +1,63 @@
+#include <gsl/gsl_rng.h>
+#include <gsl/gsl_monte_vegas.h>
+#include <stdlib.h>
+#include <gsl/gsl_math.h>
+static const double exact = 30720.; /* sum_{i=1..9} i * (2^3/3) * 2^8 */
+
+/* Integrand: sum of i * x_i^2 for i = 1..9 (dim and params are ignored). */
+double f (double *x, size_t dim, void *params)
+{
+  (void) dim; (void) params;
+  return 1.*x[0]*x[0] + 2.*x[1]*x[1] + 3.*x[2]*x[2] + 4.*x[3]*x[3]
+       + 5.*x[4]*x[4] + 6.*x[5]*x[5] + 7.*x[6]*x[6] + 8.*x[7]*x[7] + 9.*x[8]*x[8];
+}
+/* Print estimate, sigma, deviation from `exact`, and i (runs needed). */
+void
+display_results (const char *title, double result, double error, int i)
+{
+  printf ("%s ==================\n", title);
+  printf ("result = % .6f\nsigma  = % .6f\nexact  = % .6f\n",
+          result, error, exact);
+  printf ("error  = % .6f = %.2g sigma\n", result - exact,
+          fabs (result - exact) / error);
+  printf ("i      = % d\n", i);
+}
+/* Same setup as vegas-bench.lua: 9 dimensions on [0,2], 1e6*dim calls. */
+int
+main (void)
+{
+  double res, err;
+  double a = 0.;
+  double b = 2.;
+  int dim = 9;
+  double xl[9] = { a, a, a, a, a, a, a, a, a };
+  double xu[9] = { b, b, b, b, b, b, b, b, b };
+  gsl_monte_function G = { &f, dim, 0 };
+  size_t calls = 1e6 * dim;
+  int i = 0;
+
+  gsl_rng_env_setup ();
+  gsl_rng *r = gsl_rng_alloc (gsl_rng_taus2);
+  gsl_rng_set (r, 30776);   /* fixed seed for reproducible timings */
+
+  gsl_monte_vegas_state *s = gsl_monte_vegas_alloc (dim);
+
+  /* Warm-up: adapt the grid with 1e4 calls per iteration. */
+  gsl_monte_vegas_integrate (&G, xl, xu, dim, 1e4, r, s, &res, &err);
+
+  /* Repeat until chi^2/dof is within 0.5 of 1.  GSL performs 5 iterations
+     per call by default, so calls/5 per iteration totals `calls` per run. */
+  do
+    {
+      gsl_monte_vegas_integrate (&G, xl, xu, dim, calls / 5, r, s,
+                                 &res, &err);
+      i = i + 1;
+    }
+  while (fabs (gsl_monte_vegas_chisq (s) - 1.0) > 0.5);
+
+  display_results ("vegas final", res, err, i);
+
+  gsl_monte_vegas_free (s);
+  gsl_rng_free (r);
+  return EXIT_SUCCESS;
+}
diff --git a/vegas-bench.lua b/vegas-bench.lua
new file mode 100644
index 0000000..4e08e6b
--- /dev/null
+++ b/vegas-bench.lua
@@ -0,0 +1,30 @@
+
+local rng = require 'rng'
+
+local monte_vegas = dofile('vegas.lua')
+
+-- Integrate sum_i i*x_i^2 over [0,2]^n with VEGAS and print a report.
+local function testdim(n)
+  local lo, hi = 0, 2
+  local exact = n*(n+1)/2 * (hi^3 - lo^3)/3 * (hi-lo)^(n-1)
+  local terms, lower, upper = {}, {}, {}
+  for i = 1, n do
+    terms[i] = string.format("%s*x[%s]^2", i, i)
+    lower[i], upper[i] = lo, hi
+  end
+  local expr = table.concat(terms, "+")
+  io.write("Integrating ", expr, "\nExact integral = ", exact, "\n")
+  local integrand = loadstring("return function(x) return " .. expr .. " end")()
+  local result, sigma, runs = monte_vegas(integrand, lower, upper, 1e6*n, rng.new('taus2'))
+  local report = [[
+==================
+result = %.6f
+sigma  = %.6f
+exact  = %.6f
+error  = %.6f = %.2g sigma
+i      = %d
+]]
+  io.write(report:format(result, sigma, exact, result - exact, math.abs(result - exact)/sigma, runs))
+end
+
+testdim(9) -- same dimension as vegas-bench.c
diff --git a/vegas-defs.lua.in b/vegas-defs.lua.in
new file mode 100644
index 0000000..9731b4c
--- /dev/null
+++ b/vegas-defs.lua.in
@@ -0,0 +1,339 @@
+-- vegas-defs.lua.in
+--
+-- Copyright (C) 2012 Lesley De Cruz
+--
+-- This program is free software; you can redistribute it and/or modify
+-- it under the terms of the GNU General Public License as published by
+-- the Free Software Foundation; either version 3 of the License, or (at
+-- your option) any later version.
+--
+-- This program is distributed in the hope that it will be useful, but
+-- WITHOUT ANY WARRANTY; without even the implied warranty of
+-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+-- General Public License for more details.
+--
+-- You should have received a copy of the GNU General Public License
+-- along with this program; if not, write to the Free Software
+-- Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+--
+-- This is an implementation of the adaptive Monte-Carlo algorithm "VEGAS"
+-- of G. P. Lepage, originally described in J. Comp. Phys. 27, 192(1978).
+-- The current version of the algorithm was described in the Cornell
+-- preprint CLNS-80/447 of March, 1980.
+--
+-- Adapted from GSL, version 1.15
+--
+-- Original author: Michael J. Booth, 1996
+-- Modified by: Brian Gough, 12/2000
+-- Adapted for LuaJIT2 by Lesley De Cruz and Francesco Abbate, 2012
+
+local ffi = require 'ffi'
+
+local floor, min, max = math.floor, math.min, math.max
+local modf, sqrt, log  = math.modf, math.sqrt, math.log
+
+-- Sample point: a Lua table (indexed from 1) because it is passed to the user's f
+local x = {} -- evaluate the function at x (# dim)
+-- Per-axis state: 0-based FFI arrays sized by the template parameters N and K.
+local dx  = ffi.new('double[$(N)]') -- ranges, delta x (# dim)
+local box = ffi.new('int[$(N)]')    -- current box coordinates (integer) (# dim)
+local bin = ffi.new('int[$(N)]')    -- current bin coordinates (integer) (# dim)
+
+-- distribution: accumulated squared function values per axis and bin (# dim*bins)
+local d  = ffi.new('double[$(N)][$(K)]')
+
+-- bin boundaries, i.e. grid (# dim*(bins+1)); xi[i][0] = 0, xi[i][bins] = 1
+local xi  = ffi.new('double[$(N)][$(K+1)]')
+local xin = ffi.new('double[$(K+1)]') -- scratch row for new boundaries (resize/refine)
+
+-- per-bin weights computed by refine() for one axis at a time (# bins)
+local weight = ffi.new('double[$(K)]')
+
+-- Scalar state, shared as upvalues by the functions below.
+local bins = 1          -- number of bins per axis (at most K)
+local boxes = 0         -- number of boxes per axis (set by rebin_stage2)
+local volume = 1        -- volume of the integration domain (product of dx)
+
+-- control variables
+local alpha = 1.5         -- grid stiffness (for rebinning), typically between
+                        -- 1 and 2 (higher is more adaptive, 0 is rigid)
+local mode = $(MODE_IMPORTANCE) -- sampling mode; rebin_stage2 may switch it
+local iterations = 5 -- iterations per integrate() call
+
+-- intermediate results for an iteration
+local result = 0 -- integral estimate of the last iteration
+local sigma = 0  -- its standard deviation
+
+-- intermediate results for an integrate(...) run
+-- (clear_stage1 resets wtd_int_sum, sum_wgts, chi_sum, it_num, samples, chisq)
+-- weighted sum of integrals of each iteration (numerator)
+local wtd_int_sum = 0
+local sum_wgts = 0        -- sum of weights (denominator)
+local chi_sum = 0         -- weighted sum of squared integrals computed this run
+local it_num = 1          -- current iteration
+local it_start = 1        -- start iteration for this run
+local samples = 0         -- number of integrals computed this run
+local chisq = 0           -- chi^2 per degree of freedom for this run's integrals
+local calls_per_box = 2 -- set by rebin_stage2 (at least 2)
+local jac -- Jacobian volume * bins^N / calls, set by rebin_stage2
+
+--- Initialise a fresh grid: a single bin per axis spanning [a_i, b_i].
+--- NB: a and b are Lua tables indexed from 1 (bounds per dimension).
+local function init(a, b)
+    volume, bins = 1, 1 -- like GSL init_grid, so re-initialising is safe
+    for i=0, $(N-1) do
+        assert(a[i+1] < b[i+1],"lower bound should be smaller than upper bound")
+        dx[i] = b[i+1] - a[i+1]
+        volume = volume * dx[i]
+        xi[i][0], xi[i][1] = 0, 1
+    end
+end
+
+--- Zero the distribution d and restart the box counter at {0,...,0}.
+local function reset_val_and_box()
+    ffi.fill(d, ffi.sizeof(d))
+    ffi.fill(box, ffi.sizeof(box))
+end
+
+-- Advance box[] to the next box like an odometer, last axis fastest:
+-- {0,..,0,1}, ..., {0,..,0,boxes-1}, {0,..,1,0}, ... (the '#' loop is unrolled).
+-- Returns true once every coordinate has wrapped back to 0 (all boxes visited).
+local function boxes_traversed()
+#   for i= N-1,0,-1 do
+        box[$(i)] = ( (box[$(i)] + 1) % boxes)
+        if box[$(i)] ~= 0 then return false end
+#   end
+    return true
+end
+
+-- Draw a random point in the current box and map it through the grid xi:
+-- fills x_out[1..N] ("a" and x_out are indexed from 1), records bin[] and returns the product of bin widths
+local function random_point(a, x_out, rget)
+    local vol = 1
+#   for i=0, N-1 do
+    do
+        -- box[i] + rget() is the position in box units, z the position in
+        -- bin units, shifted by 1 so that modf yields the upper grid index k.
+        local z = (( box[$(i)] + rget() ) / boxes ) * bins + 1
+        local k, loc = modf(z) -- int: bin index and fract: location inside bin
+        bin[$(i)] = k-1 -- 0-based bin index, read by accumulate_distribution
+        local bin_width = xi[$(i)][k] - xi[$(i)][k-1]
+        local y = xi[$(i)][k-1] + loc * bin_width
+        x_out[$(i+1)] = a[$(i+1)] + y * dx[$(i)] -- map the [0,1] grid coordinate onto [a, a+dx]
+        vol = vol * bin_width
+    end
+#   end
+    return vol
+end
+
+-- Add fsq (fval^2, or the per-box f_sq_sum in stratified mode) to d for every
+-- axis, at the bin recorded by the last random_point call; refine() reads d.
+local function accumulate_distribution(fsq)
+#   for i=0, N-1 do
+    do
+        local bin_$(i) = bin[$(i)]
+        d[$(i)][bin_$(i)] = d[$(i)][bin_$(i)] + fsq
+    end
+#   end
+end
+
+--- Forget the estimates accumulated by previous integrate() calls
+--- (weighted sums, chi^2, sample and iteration counters) while keeping
+--- the adapted grid xi. Corresponds to GSL stages 0 and 1.
+local function clear_stage1()
+    -- weighted-average accumulators
+    wtd_int_sum, sum_wgts, chi_sum = 0, 0, 0
+    -- counters and goodness-of-fit
+    it_num, samples = 1, 0
+    chisq = 0
+end
+
+-- Re-split each axis into req_bins bins, interpolating linearly inside old bins.
+local function resize(req_bins)
+    local pts_per_bin = bins / req_bins -- old bins per new bin
+    for i=0, $(N-1) do
+          xin[0] = 0
+          local xold,xnew,dw,j=0,0,0,1
+          for k=1, bins do
+              dw = dw + 1
+              xold, xnew =  xnew, xi[i][k]
+              while dw > pts_per_bin do
+                  dw = dw - pts_per_bin
+                  xin[j] = xnew - (xnew - xold) * dw
+                  j = j + 1
+              end
+          end
+          ffi.copy(xi[i], xin, j * $(SIZE_OF_DOUBLE))
+          xi[i][req_bins] = 1
+          -- xi[i][0..req_bins] now holds the resized grid for axis i
+    end
+    bins = req_bins
+end
+
+-- Adapt the grid xi to the distribution d accumulated in the last iteration.
+local function refine()
+    for i=0, $(N-1) do
+          -- implements gs[i][j] = (gs[i][j-1]+gs[i][j]+gs[i][j+1])/3
+          local oldg,newg = d[i][0],d[i][1]
+          -- smooth d[i] (2-point average at the ends) and total it for axis i
+          local grid_tot_i = (oldg + newg) / 2
+          d[i][0] = grid_tot_i
+          for j=1,bins-2 do
+              oldg, newg, d[i][j] = newg, d[i][j+1],(oldg + newg + d[i][j+1]) / 3
+              grid_tot_i = grid_tot_i + d[i][j]
+          end
+          d[i][bins - 1] = (oldg + newg) / 2
+          grid_tot_i = grid_tot_i + d[i][bins - 1]
+
+          local tot_weight = 0
+          for j=0, bins - 1 do
+              weight[j] = 0
+              if d[i][j] > 0 then
+                  local invwt = grid_tot_i / d[i][j] -- kind of "inverse weight"
+                  -- damped change: ((r - 1) / (r * ln r))^alpha with r = invwt
+                  weight[j] = ((invwt - 1) / (invwt* log(invwt)))^alpha
+              end
+              tot_weight = tot_weight + weight[j]
+          end
+
+          -- now determine the new bin boundaries (equal weight per new bin)
+          local pts_per_bin = tot_weight / bins
+          if pts_per_bin ~= 0 then -- don't update grid if tot_weight==0
+              xin[0] = 0
+              local xold,xnew,dw,j = 0,0,0,1
+              for k=0, bins - 1 do
+                  dw = dw + weight[k]
+                  xold, xnew = xnew, xi[i][k+1]
+                  while dw > pts_per_bin do
+                      dw = dw - pts_per_bin
+                      xin[j] = xnew - (xnew - xold) * dw / weight[k]
+                      j = j + 1
+                  end
+              end
+              ffi.copy(xi[i], xin, j * $(SIZE_OF_DOUBLE)) -- boundaries 0..j-1
+              xi[i][bins] = 1
+	  end
+    end
+end
+
+-- Choose bins, boxes, calls per box and the Jacobian for `calls` calls per
+-- iteration (stratified sampling once 2*boxes >= K), and resize the current
+-- grid if the number of bins changes.
+-- Corresponds to GSL stages 0, 1 and 2.
+local function rebin_stage2(calls)
+    local new_bins = $(K)
+    boxes = 1
+    if mode ~= $(MODE_IMPORTANCE_ONLY) then
+        -- shooting for 2 calls/box; keep >= 1 box, else tot_boxes = 0 below
+        boxes = max(floor((calls/2)^(1/$(N))), 1)
+        mode = $(MODE_IMPORTANCE)
+        if 2*boxes >= $(K) then
+            -- if there are too many boxes, we switch to stratified sampling
+            local box_per_bin = max(floor(boxes/$(K)),1)
+            new_bins = min(floor(boxes/box_per_bin), $(K))
+            boxes = box_per_bin * new_bins
+            mode = $(MODE_STRATIFIED)
+        end
+    end
+
+    local tot_boxes= boxes^$(N)
+    calls_per_box = max(floor(calls/tot_boxes),2)
+    calls = calls_per_box * tot_boxes
+    -- x-space volume / avg number of calls per bin
+    jac = volume * new_bins^$(N) / calls
+    -- If the number of bins changes from the previous invocation, bins
+    -- are expanded or contracted accordingly, while preserving bin
+    -- density
+    if new_bins ~= bins then
+        resize(new_bins)
+    end
+end
+
+--- Run `iterations` VEGAS iterations with the current grid, refining it after
+-- each; "a" is a table indexed from 1. Returns the cumulative estimate and its sigma.
+local function integrate(f, a, rget)
+    it_start = it_num
+    local cum_int, cum_sig = 0, 0
+    for it= 1, iterations do
+        local intgrl = 0 -- integral for this iteration
+        local tss = 0 -- total squared sum
+
+        it_num = it_start + it
+        reset_val_and_box()
+
+        repeat -- one pass per box
+            local m,q = 0,0 -- first and second moment
+            local f_sq_sum = 0
+            for k=1,calls_per_box do
+                 local bin_vol = random_point(a, x, rget)
+                 local fval = jac * bin_vol * f(x)
+
+                 -- Welford update of the mean m and sum of squared deviations q
+                 local delta = fval - m -- not `d`: that would shadow the distribution array
+                 m = m + delta / k
+                 q = q + delta*delta * ((k-1)/k)
+                 if mode ~= $(MODE_STRATIFIED) then
+                     accumulate_distribution(fval*fval)
+                 end
+
+             end
+
+             intgrl = intgrl + m * calls_per_box
+             f_sq_sum = q * calls_per_box
+             tss = tss + f_sq_sum
+             if mode == $(MODE_STRATIFIED) then
+                 accumulate_distribution(f_sq_sum)
+             end
+        until boxes_traversed()
+
+        -- Compute final results for this iteration
+        -- Determine variance and weight (1/var, or the mean weight so far)
+        local var, wgt = tss / (calls_per_box - 1), 0
+
+        if var > 0 then
+            wgt = 1 / var
+        elseif sum_wgts > 0 then
+            wgt = sum_wgts / samples
+        end
+        result = intgrl
+        sigma = sqrt(var)
+
+        if wgt > 0 then
+            local old_sum_wgts = sum_wgts
+            local prev_mean = (sum_wgts > 0) and (wtd_int_sum / sum_wgts) or 0
+            local dev = intgrl - prev_mean
+
+            -- update stats
+            samples = samples + 1
+            sum_wgts = sum_wgts + wgt
+            wtd_int_sum = wtd_int_sum + intgrl * wgt
+            chi_sum = chi_sum + intgrl * intgrl * wgt
+            cum_int = wtd_int_sum / sum_wgts
+            cum_sig = sqrt(1 / sum_wgts)
+
+            if samples == 1 then
+                chisq = 0
+            else -- running chi^2 per degree of freedom
+                chisq = chisq * (samples - 2)
+                chisq = chisq + (wgt / (1 + (wgt / old_sum_wgts))) * dev * dev
+                chisq = chisq / (samples - 1)
+            end
+        else -- no usable weight yet (zero variance so far): plain running mean
+            cum_int = cum_int + (intgrl - cum_int) / it
+            cum_sig = 0
+        end
+
+        refine()
+    end
+
+    return cum_int, cum_sig
+end
+
+-- Module interface. iterations and chisq are returned through getters so
+-- callers always observe the current value of these upvalues (integrate
+-- updates chisq in place).
+local api = { init = init, integrate = integrate }
+api.clear_stage1, api.rebin_stage2 = clear_stage1, rebin_stage2
+function api.iterations() return iterations end
+function api.chisq() return chisq end
+return api
diff --git a/vegas.lua b/vegas.lua
new file mode 100644
index 0000000..6e92971
--- /dev/null
+++ b/vegas.lua
@@ -0,0 +1,61 @@
+local template = require 'template'
+local ffi = require 'ffi'
+
+local abs, random = math.abs, math.random
+
+local spec = { -- template definitions for vegas-defs.lua.in; N (dimension) is supplied by monte_vegas
+   K = 50, -- maximum number of bins per axis (sizes the grid arrays)
+   SIZE_OF_INT = ffi.sizeof('int'), -- byte sizes used with ffi.fill / ffi.copy
+   SIZE_OF_DOUBLE = ffi.sizeof('double'),
+   MODE_IMPORTANCE = 1, -- sampling modes: importance sampling,
+   MODE_IMPORTANCE_ONLY = 2, -- importance only (a single box, no stratification),
+   MODE_STRATIFIED = 3, -- stratified sampling
+}
+
+--- Perform VEGAS Monte Carlo integration of f over the box [a, b].
+-- @param f function of an n-dimensional point x (Lua table indexed from 1)
+-- @param a lower bound vector (table indexed from 1)
+-- @param b upper bound vector (same length as a, with a[i] < b[i])
+-- @param calls function calls per run (optional, default 5e5), adjusted to fit the grid
+-- @param r generator whose r:get() returns uniform numbers in [0,1) (optional, default math.random)
+-- @param chidev tolerance (optional, default 0.5): runs repeat until
+--        |chi^2 - 1| < chidev, or until a run reports sigma == 0 (zero variance)
+-- @return result the result of the integration
+-- @return sigma the estimated error or standard deviation
+-- @return num_int the number of runs required to calculate the integral
+-- @return run function to compute the integral again via run(calls)
+local function monte_vegas(f, a, b, calls, r, chidev)
+  calls = calls or 5e5
+  local rget_call = r and r.get
+  local rget = r and (function() return rget_call(r) end) or random
+  chidev = chidev or 0.5
+  local dim = #a
+  assert(dim==#b,"number of dimensions of lower and upper bounds differ")
+  assert(dim > 0, "at least one dimension is required")
+  -- specialise the template for dim without mutating the shared spec table
+  local defs = { N = dim }
+  for k, v in pairs(spec) do defs[k] = v end
+  local state = template.load('vegas-defs', defs)
+  state.init(a, b)
+  -- warm-up: adapt the grid with 1e4 calls per iteration; estimate discarded
+  state.clear_stage1()
+  state.rebin_stage2(1e4)
+  local result, sigma = state.integrate(f, a, rget)
+  local n
+  -- full runs: fresh estimates each time, but the adapted grid is kept
+  local run = function (c)
+    calls = c or calls
+    n = 0
+    repeat
+      state.clear_stage1()
+      state.rebin_stage2(calls / state.iterations()) -- calls split over the iterations
+      result, sigma = state.integrate(f, a, rget)
+      n = n + 1
+    until sigma == 0 or abs(state.chisq() - 1) < chidev
+    return result, sigma, n
+  end
+  result, sigma, n = run(calls)
+  return result, sigma, n, run
+end
+
+return monte_vegas -- the module value is the integration function itself
-- 
1.7.5.4