lua-users home
lua-l archive

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


"Reentrant" here means executing a function call by restarting the current interpreter loop (luaV_execute) on the new frame, instead of calling luaV_execute recursively. (To be honest, I don't like using the term this way, but the code already uses it; see CIST_REENTRY in lstate.h.)

This patch makes calls to metamethods written in Lua reenter the execution loop. This avoids nesting calls to luaV_execute, which would otherwise trigger a "C stack overflow" error on deeply recursive metamethods. I ran into this while following a recent thread. (http://lua-users.org/lists/lua-l/2010-09/msg00644.html)

I didn't have to change much, because the most critical part (finishing an operation after the call returns) was already implemented in luaV_finishOp to support yieldable metamethods.

Very lightly tested at this point.

local debug = require "debug"
-- Raw (metamethod-free) sequence length. Lua 5.2+ provides it as the global
-- `rawlen`; table.getn is the Lua 5.1 spelling (5.1 never applies __len to
-- tables) and is absent from stock Lua 5.2 and later, so prefer `rawlen`.
local rawlen = rawlen or table.getn
local MT = {
  -- Deep compare tables: equal when both have the same raw length and their
  -- elements are pairwise equal. Nested MT tables recurse back into this
  -- metamethod through the `~=` comparison below.
  __eq = function(a, b)
    if rawlen(a) == rawlen(b) then
      for i = 1, rawlen(a) do
        if a[i] ~= b[i] then return false end
      end
      return true
    end
    return false
  end,
  -- Completely frivolous: whichever operand is a table has each of its
  -- elements concatenated with the other operand, and the other operand is
  -- emitted once more on the outer side, e.g. {'x'} .. '-' gives 'x--' and
  -- '-' .. {'x'} gives '--x'. Nested tables recurse through this metamethod.
  __concat = function(a, b)
    local s = {}
    if type(a) == 'table' then
      for i = 1, rawlen(a) do
        s[#s + 1] = a[i] .. b
      end
      s[#s + 1] = b
    else
      s[#s + 1] = a
      for i = 1, rawlen(b) do
        s[#s + 1] = a .. b[i]
      end
    end
    return table.concat(s)
  end}
-- Build two independent but structurally identical chains: 200 levels of
-- single-element MT tables wrapped around the string '*'.
local T, U = '*', '*'
for _ = 1, 200 do
  T, U = setmetatable({T}, MT), setmetatable({U}, MT)
end
-- The chains are distinct objects, so this compares them level by level
-- through MT.__eq.
print(T == U)
print(T .. '-' .. T) -- Fun fact: this always produces symmetrical output
-- Sum(1..n), computed recursively through the length operator.
-- NOTE: debug.setmetatable(1, ...) installs this metatable for *every*
-- number, so `#x` is valid for any numeric x from here on.
-- Inputs below 1 (zero, negatives, or a non-integer that steps below 1)
-- yield the empty sum 0 instead of recursing without bound.
debug.setmetatable(1, {
    __len = function(n)
      if n < 1 then return 0 end
      if n == 1 then return 1 end
      return n + #(n-1)
    end
  })
-- Find more palindromes: print each i whose Sum(1..i), computed through the
-- number __len metamethod, reads the same forwards and backwards.
for i = 5, 500 do
  -- Compute the (deeply recursive) sum once and reuse it for printing.
  local s = tostring(#i)
  if s == s:reverse() then
    print(i, s)
  end
end

This patch and the sample code are in the public domain.

--
- tom
telliamed@whoopdedo.org
diff -urN lua-5.2.0-work4-orig/src/lvm.c lua-5.2.0-work4/src/lvm.c
--- lua-5.2.0-work4-orig/src/lvm.c	2010-06-30 10:11:17 -0400
+++ lua-5.2.0-work4/src/lvm.c	2010-09-22 02:23:13 -0400
@@ -81,8 +81,8 @@
 }
 
 
-static void callTM (lua_State *L, const TValue *f, const TValue *p1,
-                    const TValue *p2, TValue *p3, int hasres) {
+static int callTM (lua_State *L, const TValue *f, const TValue *p1,
+                    const TValue *p2, TValue *p3, int hasres, int reent) {
   ptrdiff_t result = savestack(L, p3);
   setobj2s(L, L->top++, f);  /* push function */
   setobj2s(L, L->top++, p1);  /* 1st argument */
@@ -91,15 +91,22 @@
     setobj2s(L, L->top++, p3);  /* 3rd argument */
   luaD_checkstack(L, 0);
   /* metamethod may yield only when called from Lua code */
-  luaD_call(L, L->top - (4 - hasres), hasres, isLua(L->ci));
+  if (reent) {
+    if (!luaD_precall(L, L->top - (4 - hasres), hasres))
+      return 0;
+    luaC_checkGC(L);
+  }
+  else
+    luaD_call(L, L->top - (4 - hasres), hasres, isLua(L->ci));
   if (hasres) {  /* if has result, move it to its place */
     p3 = restorestack(L, result);
     setobjs2s(L, p3, --L->top);
   }
+  return 1;
 }
 
 
-void luaV_gettable (lua_State *L, const TValue *t, TValue *key, StkId val) {
+int luaV_gettable_ (lua_State *L, const TValue *t, TValue *key, StkId val, int reent) {
   int loop;
   for (loop = 0; loop < MAXTAGLOOP; loop++) {
     const TValue *tm;
@@ -109,23 +116,23 @@
       if (!ttisnil(res) ||  /* result is not nil? */
           (tm = fasttm(L, h->metatable, TM_INDEX)) == NULL) { /* or no TM? */
         setobj2s(L, val, res);
-        return;
+        return 1;
       }
       /* else will try the tag method */
     }
     else if (ttisnil(tm = luaT_gettmbyobj(L, t, TM_INDEX)))
       luaG_typeerror(L, t, "index");
     if (ttisfunction(tm)) {
-      callTM(L, tm, t, key, val, 1);
-      return;
+      return callTM(L, tm, t, key, val, 1, reent);
     }
     t = tm;  /* else repeat with 'tm' */
   }
   luaG_runerror(L, "loop in gettable");
+  return 1;
 }
 
 
-void luaV_settable (lua_State *L, const TValue *t, TValue *key, StkId val) {
+int luaV_settable_ (lua_State *L, const TValue *t, TValue *key, StkId val, int reent) {
   int loop;
   TValue temp;
   for (loop = 0; loop < MAXTAGLOOP; loop++) {
@@ -137,33 +144,31 @@
           (tm = fasttm(L, h->metatable, TM_NEWINDEX)) == NULL) { /* or no TM? */
         setobj2t(L, oldval, val);
         luaC_barrierback(L, obj2gco(h), val);
-        return;
+        return 1;
       }
       /* else will try the tag method */
     }
     else if (ttisnil(tm = luaT_gettmbyobj(L, t, TM_NEWINDEX)))
       luaG_typeerror(L, t, "index");
     if (ttisfunction(tm)) {
-      callTM(L, tm, t, key, val, 0);
-      return;
+      return callTM(L, tm, t, key, val, 0, reent);
     }
     /* else repeat with 'tm' */
     setobj(L, &temp, tm);  /* avoid pointing inside table (may rehash) */
     t = &temp;
   }
   luaG_runerror(L, "loop in settable");
+  return 1;
 }
 
-
-static int call_binTM (lua_State *L, const TValue *p1, const TValue *p2,
-                       StkId res, TMS event) {
+static const TValue *get_binTM (lua_State *L, const TValue *p1, const TValue *p2,
+                       TMS event) {
   const TValue *tm = luaT_gettmbyobj(L, p1, event);  /* try first operand */
   if (ttisnil(tm))
     tm = luaT_gettmbyobj(L, p2, event);  /* try second operand */
-  if (ttisnil(tm)) return 0;
+  if (ttisnil(tm)) return NULL;
   if (event == TM_UNM) p2 = luaO_nilobject;
-  callTM(L, tm, p1, p2, res, 1);
-  return 1;
+  return tm;
 }
 
 
@@ -183,10 +188,11 @@
 
 static int call_orderTM (lua_State *L, const TValue *p1, const TValue *p2,
                          TMS event) {
-  if (!call_binTM(L, p1, p2, L->top, event))
+  const TValue *tm = get_binTM(L, p1, p2, event);
+  if (!tm)
     return -1;  /* no metamethod */
-  else
-    return !l_isfalse(L->top);
+  callTM(L, tm, p1, p2, L->top, 1, 0);
+  return !l_isfalse(L->top);
 }
 
 
@@ -212,33 +218,49 @@
 }
 
 
-int luaV_lessthan (lua_State *L, const TValue *l, const TValue *r) {
-  int res;
+int luaV_lessthan_ (lua_State *L, const TValue *l, const TValue *r, int reent) {
   if (ttisnumber(l) && ttisnumber(r))
     return luai_numlt(L, nvalue(l), nvalue(r));
   else if (ttisstring(l) && ttisstring(r))
     return l_strcmp(rawtsvalue(l), rawtsvalue(r)) < 0;
-  else if ((res = call_orderTM(L, l, r, TM_LT)) != -1)
-    return res;
+  else {
+    const TValue *tm = get_binTM(L, l, r, TM_LT);
+    if (tm != NULL) {
+      if (callTM(L, tm, l, r, L->top, 1, reent))
+        return !l_isfalse(L->top);
+      else
+        return -1;
+    }
+  }
   return luaG_ordererror(L, l, r);
 }
 
 
-int luaV_lessequal (lua_State *L, const TValue *l, const TValue *r) {
-  int res;
+int luaV_lessequal_ (lua_State *L, const TValue *l, const TValue *r, int reent) {
   if (ttisnumber(l) && ttisnumber(r))
     return luai_numle(L, nvalue(l), nvalue(r));
   else if (ttisstring(l) && ttisstring(r))
     return l_strcmp(rawtsvalue(l), rawtsvalue(r)) <= 0;
-  else if ((res = call_orderTM(L, l, r, TM_LE)) != -1)  /* first try `le' */
-    return res;
-  else if ((res = call_orderTM(L, r, l, TM_LT)) != -1)  /* else try `lt' */
-    return !res;
+  else {
+    const TValue *tm = get_binTM(L, l, r, TM_LE);  /* first try `le' */
+    if (tm != NULL) {
+      if (callTM(L, tm, l, r, L->top, 1, reent))
+        return !l_isfalse(L->top);
+      else
+        return -1;
+    }
+    else if ((tm = get_binTM(L, r, l, TM_LT)) != NULL) {  /* else try `lt' */
+      if (callTM(L, tm, r, l, L->top, 1, reent))
+        return l_isfalse(L->top);
+      else
+        return -1;
+    }
+  }
   return luaG_ordererror(L, l, r);
 }
 
 
-int luaV_equalval_ (lua_State *L, const TValue *t1, const TValue *t2) {
+int luaV_equalval_ (lua_State *L, const TValue *t1, const TValue *t2, int reent) {
   const TValue *tm;
   lua_assert(ttype(t1) == ttype(t2));
   switch (ttype(t1)) {
@@ -261,19 +283,25 @@
     default: return gcvalue(t1) == gcvalue(t2);
   }
   if (tm == NULL) return 0;  /* no TM? */
-  callTM(L, tm, t1, t2, L->top, 1);  /* call TM */
-  return !l_isfalse(L->top);
+  if (callTM(L, tm, t1, t2, L->top, 1, reent))  /* call TM */
+    return !l_isfalse(L->top);
+  else
+    return -1;
 }
 
 
-void luaV_concat (lua_State *L, int total) {
+int luaV_concat_ (lua_State *L, int total, int reent) {
   lua_assert(total >= 2);
   do {
     StkId top = L->top;
     int n = 2;  /* number of elements handled in this pass (at least 2) */
     if (!(ttisstring(top-2) || ttisnumber(top-2)) || !tostring(L, top-1)) {
-      if (!call_binTM(L, top-2, top-1, top-2, TM_CONCAT))
+      const TValue *tm = get_binTM(L, top-2, top-1, TM_CONCAT);
+      if (!tm)
         luaG_concaterror(L, top-2, top-1);
+      else
+        if (!callTM(L, tm, top-2, top-1, top-2, 1, reent))
+          return 0;
     }
     else if (tsvalue(top-1)->len == 0)  /* second operand is empty? */
       (void)tostring(L, top - 2);  /* result is first operand */
@@ -303,10 +331,11 @@
     total -= n-1;  /* got 'n' strings to create 1 new */
     L->top -= n-1;  /* poped 'n' strings and pushed one */
   } while (total > 1);  /* repeat until only 1 result left */
+  return 1;
 }
 
 
-void luaV_objlen (lua_State *L, StkId ra, const TValue *rb) {
+int luaV_objlen_ (lua_State *L, StkId ra, const TValue *rb, int reent) {
   const TValue *tm;
   switch (ttype(rb)) {
     case LUA_TTABLE: {
@@ -314,11 +343,11 @@
       tm = fasttm(L, h->metatable, TM_LEN);
       if (tm) break;  /* metamethod? break switch to call it */
       setnvalue(ra, cast_num(luaH_getn(h)));  /* else primitive len */
-      return;
+      return 1;
     }
     case LUA_TSTRING: {
       setnvalue(ra, cast_num(tsvalue(rb)->len));
-      return;
+      return 1;
     }
     default: {  /* try metamethod */
       tm = luaT_gettmbyobj(L, rb, TM_LEN);
@@ -327,12 +356,12 @@
       break;
     }
   }
-  callTM(L, tm, rb, luaO_nilobject, ra, 1);
+  return callTM(L, tm, rb, luaO_nilobject, ra, 1, reent);
 }
 
 
-void luaV_arith (lua_State *L, StkId ra, const TValue *rb,
-                 const TValue *rc, TMS op) {
+int luaV_arith_ (lua_State *L, StkId ra, const TValue *rb,
+                 const TValue *rc, TMS op, int reent) {
   TValue tempb, tempc;
   const TValue *b, *c;
   if ((b = luaV_tonumber(rb, &tempb)) != NULL &&
@@ -340,8 +369,14 @@
     lua_Number res = luaO_arith(op - TM_ADD + LUA_OPADD, nvalue(b), nvalue(c));
     setnvalue(ra, res);
   }
-  else if (!call_binTM(L, rb, rc, ra, op))
-    luaG_aritherror(L, rb, rc);
+  else {
+    const TValue *tm = get_binTM(L, rb, rc, op);
+    if (!tm)
+      luaG_aritherror(L, rb, rc);
+    else
+      return callTM(L, tm, rb, rc, ra, 1, reent);
+  }
+  return 1;
 }
 
 
@@ -393,7 +428,7 @@
 /*
 ** finish execution of an opcode interrupted by an yield
 */
-void luaV_finishOp (lua_State *L) {
+int luaV_finishOp_ (lua_State *L, int reent) {
   CallInfo *ci = L->ci;
   StkId base = ci->u.l.base;
   Instruction inst = *(ci->u.l.savedpc - 1);  /* interrupted instruction */
@@ -425,7 +460,8 @@
       setobj2s(L, top - 2, top);  /* put TM result in proper position */
       if (total > 1) {  /* are there elements to concat? */
         L->top = top - 1;  /* top is one after last element (at top-2) */
-        luaV_concat(L, total);  /* concat them (may yield again) */
+        if (!luaV_concat_(L, total, reent))  /* concat them (may yield again) */
+          return 0;
       }
       /* move final result to final position */
       setobj2s(L, ci->u.l.base + GETARG_A(inst), L->top - 1);
@@ -446,6 +482,7 @@
       break;
     default: lua_assert(0);
   }
+  return 1;
 }
 
 
@@ -470,6 +507,11 @@
 
 
 #define Protect(x)	{ {x;}; base = ci->u.l.base; }
+/* restart luaV_execute over new Lua function */
+#define Reentrant(x)	if(x){  \
+          ci = L->ci; \
+          ci->callstatus |= CIST_REENTRY; \
+          goto newframe; }
 
 #define checkGC(L)	Protect(luaC_checkGC(L); luai_threadyield(L);)
 
@@ -481,7 +523,7 @@
           lua_Number nb = nvalue(rb), nc = nvalue(rc); \
           setnvalue(ra, op(L, nb, nc)); \
         } \
-        else { Protect(luaV_arith(L, ra, rb, rc, tm)); } }
+        else { Protect(Reentrant(!luaV_arith_(L, ra, rb, rc, tm, 1))); } }
 
 
 #define vmdispatch(o)	switch(o)
@@ -534,14 +576,14 @@
       )
       vmcase(OP_GETTABUP,
         int b = GETARG_B(i);
-        Protect(luaV_gettable(L, cl->upvals[b]->v, RKC(i), ra));
+        Protect(Reentrant(!luaV_gettable_(L, cl->upvals[b]->v, RKC(i), ra, 1)));
       )
       vmcase(OP_GETTABLE,
-        Protect(luaV_gettable(L, RB(i), RKC(i), ra));
+        Protect(Reentrant(!luaV_gettable_(L, RB(i), RKC(i), ra, 1)));
       )
       vmcase(OP_SETTABUP,
         int a = GETARG_A(i);
-        Protect(luaV_settable(L, cl->upvals[a]->v, RKB(i), RKC(i)));
+        Protect(Reentrant(!luaV_settable_(L, cl->upvals[a]->v, RKB(i), RKC(i), 1)));
       )
       vmcase(OP_SETUPVAL,
         UpVal *uv = cl->upvals[GETARG_B(i)];
@@ -549,7 +591,7 @@
         luaC_barrier(L, uv, ra);
       )
       vmcase(OP_SETTABLE,
-        Protect(luaV_settable(L, ra, RKB(i), RKC(i)));
+        Protect(Reentrant(!luaV_settable_(L, ra, RKB(i), RKC(i), 1)));
       )
       vmcase(OP_NEWTABLE,
         int b = GETARG_B(i);
@@ -563,7 +605,7 @@
       vmcase(OP_SELF,
         StkId rb = RB(i);
         setobjs2s(L, ra+1, rb);
-        Protect(luaV_gettable(L, rb, RKC(i), ra));
+        Protect(Reentrant(!luaV_gettable_(L, rb, RKC(i), ra, 1)));
       )
       vmcase(OP_ADD,
         arith_op(luai_numadd, TM_ADD);
@@ -590,7 +632,7 @@
           setnvalue(ra, luai_numunm(L, nb));
         }
         else {
-          Protect(luaV_arith(L, ra, rb, rb, TM_UNM));
+          Protect(Reentrant(!luaV_arith_(L, ra, rb, rb, TM_UNM, 1)));
         }
       )
       vmcase(OP_NOT,
@@ -598,13 +640,13 @@
         setbvalue(ra, res);
       )
       vmcase(OP_LEN,
-        Protect(luaV_objlen(L, ra, RB(i)));
+        Protect(Reentrant(!luaV_objlen_(L, ra, RB(i), 1)));
       )
       vmcase(OP_CONCAT,
         int b = GETARG_B(i);
         int c = GETARG_C(i);
         L->top = base + c + 1;  /* mark the end of concat operands */
-        Protect(luaV_concat(L, c-b+1); checkGC(L);)
+        Protect(Reentrant(!luaV_concat_(L, c-b+1, 1)); checkGC(L);)
         L->top = ci->top;  /* restore top */
         setobjs2s(L, RA(i), base+b);
       )
@@ -615,21 +657,27 @@
         TValue *rb = RKB(i);
         TValue *rc = RKC(i);
         Protect(
-          if (equalobj(L, rb, rc) == GETARG_A(i))
+          int res = (ttype(rb) == ttype(rc)) ? luaV_equalval_(L, rb, rc, 1) : 0;
+          Reentrant(res==-1);
+          if (res == GETARG_A(i))
             dojump(GETARG_sBx(*ci->u.l.savedpc));
         )
         ci->u.l.savedpc++;
       )
       vmcase(OP_LT,
         Protect(
-          if (luaV_lessthan(L, RKB(i), RKC(i)) == GETARG_A(i))
+          int res = luaV_lessthan_(L, RKB(i), RKC(i), 1);
+          Reentrant(res==-1);
+          if (res == GETARG_A(i))
             dojump(GETARG_sBx(*ci->u.l.savedpc));
         )
         ci->u.l.savedpc++;
       )
       vmcase(OP_LE,
         Protect(
-          if (luaV_lessequal(L, RKB(i), RKC(i)) == GETARG_A(i))
+          int res = luaV_lessequal_(L, RKB(i), RKC(i), 1);
+          Reentrant(res==-1);
+          if (res == GETARG_A(i))
             dojump(GETARG_sBx(*ci->u.l.savedpc));
         )
         ci->u.l.savedpc++;
@@ -699,9 +747,8 @@
           return;  /* external invocation: return */
         else {  /* invocation via reentry: continue execution */
           ci = L->ci;
-          if (b) L->top = ci->top;
           lua_assert(isLua(ci));
-          lua_assert(GET_OPCODE(*((ci)->u.l.savedpc - 1)) == OP_CALL);
+		  Reentrant(!luaV_finishOp_(L, 1));
           goto newframe;  /* restart luaV_execute over new Lua function */
         }
       )
diff -urN  lua-5.2.0-work4-orig/src/lvm.h lua-5.2.0-work4/src/lvm.h
--- lua-5.2.0-work4-orig/src/lvm.h	2009-12-17 11:20:01 -0500
+++ lua-5.2.0-work4/src/lvm.h	2010-09-22 01:57:07 -0400
@@ -18,25 +18,35 @@
 #define tonumber(o,n)	(ttisnumber(o) || (((o) = luaV_tonumber(o,n)) != NULL))
 
 #define equalobj(L,o1,o2) \
-	(ttype(o1) == ttype(o2) && luaV_equalval_(L, o1, o2))
+	(ttype(o1) == ttype(o2) && luaV_equalval_(L, o1, o2, 0))
 
+#define luaV_lessthan(L,o1,o2)	luaV_lessthan_(L,o1,o2,0)
+#define luaV_lessequal(L,o1,o2)	luaV_lessequal_(L,o1,o2,0)
+
+#define luaV_gettable(L,t,key,val)	luaV_gettable_(L,t,key,val,0)
+#define luaV_settable(L,t,key,val)	luaV_settable_(L,t,key,val,0)
+#define luaV_arith(L,o1,o2,res,op)	luaV_arith_(L,o1,o2,res,op,0)
+#define luaV_objlen(L,o,res)	luaV_objlen_(L,o,res,0)
+
+#define luaV_finishOp(L)	luaV_finishOp_(L,0)
+#define luaV_concat(L,n)	luaV_concat_(L,n,0)
 
 /* not to called directly */
-LUAI_FUNC int luaV_equalval_ (lua_State *L, const TValue *t1, const TValue *t2);
+LUAI_FUNC int luaV_equalval_ (lua_State *L, const TValue *t1, const TValue *t2, int reent);
 
-LUAI_FUNC int luaV_lessthan (lua_State *L, const TValue *l, const TValue *r);
-LUAI_FUNC int luaV_lessequal (lua_State *L, const TValue *l, const TValue *r);
+LUAI_FUNC int luaV_lessthan_ (lua_State *L, const TValue *l, const TValue *r, int reent);
+LUAI_FUNC int luaV_lessequal_ (lua_State *L, const TValue *l, const TValue *r, int reent);
 LUAI_FUNC const TValue *luaV_tonumber (const TValue *obj, TValue *n);
 LUAI_FUNC int luaV_tostring (lua_State *L, StkId obj);
-LUAI_FUNC void luaV_gettable (lua_State *L, const TValue *t, TValue *key,
-                                            StkId val);
-LUAI_FUNC void luaV_settable (lua_State *L, const TValue *t, TValue *key,
-                                            StkId val);
-LUAI_FUNC void luaV_finishOp (lua_State *L);
+LUAI_FUNC int luaV_gettable_ (lua_State *L, const TValue *t, TValue *key,
+                                            StkId val, int reent);
+LUAI_FUNC int luaV_settable_ (lua_State *L, const TValue *t, TValue *key,
+                                            StkId val, int reent);
+LUAI_FUNC int luaV_finishOp_ (lua_State *L, int reent);
 LUAI_FUNC void luaV_execute (lua_State *L);
-LUAI_FUNC void luaV_concat (lua_State *L, int total);
-LUAI_FUNC void luaV_arith (lua_State *L, StkId ra, const TValue *rb,
-                           const TValue *rc, TMS op);
-LUAI_FUNC void luaV_objlen (lua_State *L, StkId ra, const TValue *rb);
+LUAI_FUNC int luaV_concat_ (lua_State *L, int total, int reent);
+LUAI_FUNC int luaV_arith_ (lua_State *L, StkId ra, const TValue *rb,
+                           const TValue *rc, TMS op, int reent);
+LUAI_FUNC int luaV_objlen_ (lua_State *L, StkId ra, const TValue *rb, int reent);
 
 #endif