lua-users home
lua-l archive

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


Hi,

I want to share a patch that optimizes generic for-in loops over pairs and ipairs. It is inspired by LuaJIT's ITERN optimization and by a similar optimization in Luau.

It works by checking for the pairs and ipairs cases in the OP_TFORPREP opcode and then reusing the to-be-closed slot: it holds a traversal index in the pairs case and a marker in the ipairs case. OP_TFORCALL can then cheaply check for the index or marker and take the pairs or ipairs fast path, falling back to the default path in all other cases. Furthermore, the patch is not observable (apart from the speedup): debug.getlocal returns nil instead of the special values, and the loop is deoptimized if any of its state variables is modified with debug.setlocal.

With the patch, for-in loops over pairs and ipairs run almost as fast as numeric for loops. For a simple sum loop, pairs and ipairs become around 4x faster; for a simple table-clone loop, the speedup drops to around 2x.
However, the default (non-pairs/ipairs) path became slightly slower.

Maybe someone is interested in such a patch.

Regards,
Xmilia
diff --git a/lbaselib.c b/lbaselib.c
index 1d60c9de..bf999467 100644
--- a/lbaselib.c
+++ b/lbaselib.c
@@ -264,7 +264,8 @@ static int luaB_type (lua_State *L) {
 }
 
 
-static int luaB_next (lua_State *L) {
+LUAI_FUNC int luaB_next (lua_State *L);  /* no longer static: lvm.c compares iterators against it */
+int luaB_next (lua_State *L) {
   luaL_checktype(L, 1, LUA_TTABLE);
   lua_settop(L, 2);  /* create a 2nd argument if there isn't one */
   if (lua_next(L, 1))
@@ -299,7 +300,9 @@ static int luaB_pairs (lua_State *L) {
 /*
 ** Traversal function for 'ipairs'
 */
-static int ipairsaux (lua_State *L) {
+#define ipairsaux luaB_ipairsaux  /* external name gets the luaB_ prefix */
+LUAI_FUNC int ipairsaux (lua_State *L);  /* no longer static: lvm.c compares against luaB_ipairsaux */
+int ipairsaux (lua_State *L) {
   lua_Integer i = luaL_checkinteger(L, 2);
   i = luaL_intop(+, i, 1);
   lua_pushinteger(L, i);
diff --git a/ldebug.c b/ldebug.c
index fa15eaf6..7349c91c 100644
--- a/ldebug.c
+++ b/ldebug.c
@@ -230,7 +230,10 @@ LUA_API const char *lua_getlocal (lua_State *L, const lua_Debug *ar, int n) {
     StkId pos = NULL;  /* to avoid warnings */
     name = luaG_findlocal(L, ar->i_ci, n, &pos);
     if (name) {
-      setobjs2s(L, L->top, pos);
+      if (luai_unlikely(ttype(s2v(pos)) == LUA_TITER))  /* internal for-in fast-path marker? */
+        setnilvalue(s2v(L->top));  /* never expose it; the slot held nil before the marker was set */
+      else
+        setobjs2s(L, L->top, pos);
       api_incr_top(L);
     }
   }
@@ -245,8 +248,16 @@ LUA_API const char *lua_setlocal (lua_State *L, const lua_Debug *ar, int n) {
   lua_lock(L);
   name = luaG_findlocal(L, ar->i_ci, n, &pos);
   if (name) {
+    StkId to = pos + 4;  /* exclusive limit of the marker scan below: slots pos+1 .. pos+3 */
     setobjs2s(L, pos, L->top - 1);
     L->top--;  /* pop value */
+    if (to > L->top) to = L->top;  /* do not scan past the stack top */
+    while(++pos < to) {  /* changing a loop state slot (ra..ra+2) must clear the marker at ra+3 */
+      if (luai_unlikely(ttype(s2v(pos)) == LUA_TITER)) {
+        setnilvalue(s2v(pos));  /* deoptimize: OP_TFORCALL then takes the generic path */
+        break;
+      }
+    }
   }
   lua_unlock(L);
   return name;
diff --git a/lobject.h b/lobject.h
index 77cc606f..51edfecf 100644
--- a/lobject.h
+++ b/lobject.h
@@ -22,6 +22,7 @@
 #define LUA_TUPVAL	LUA_NUMTYPES  /* upvalues */
 #define LUA_TPROTO	(LUA_NUMTYPES+1)  /* function prototypes */
 #define LUA_TDEADKEY	(LUA_NUMTYPES+2)  /* removed keys in tables */
+#define LUA_TITER  (LUA_NUMTYPES+3) /* internal marker type for pairs/ipairs for-in fast paths */
 
 
 
@@ -52,6 +53,7 @@ typedef union Value {
   lua_CFunction f; /* light C functions */
   lua_Integer i;   /* integer numbers */
   lua_Number n;    /* float numbers */
+  unsigned int it; /* resume index of the 'pairs' fast path (array part, then hash part) */
   /* not used, but may avoid warnings for uninitialized value */
   lu_byte ub;
 } Value;
@@ -765,6 +767,11 @@ typedef struct Table {
 #define setdeadkey(node)	(keytt(node) = LUA_TDEADKEY)
 #define keyisdead(node)		(keytt(node) == LUA_TDEADKEY)
 
+
+/* Markers stored in a generic 'for' loop's closing slot (ra + 3) to select a fast path */
+#define LUA_VITER  makevariant(LUA_TITER, 0)  /* 'pairs' (next) loop; index kept in 'it' */
+#define LUA_VITERI  makevariant(LUA_TITER, 1)  /* 'ipairs' loop; index kept in the control variable */
+
 /* }================================================================== */
 
 
diff --git a/lvm.c b/lvm.c
index 614df055..252f906a 100644
--- a/lvm.c
+++ b/lvm.c
@@ -1144,6 +1144,8 @@ void luaV_finishOp (lua_State *L) {
 #define vmcase(l)	case l:
 #define vmbreak		break
 
+LUAI_FUNC int luaB_next (lua_State *L);  /* from lbaselib.c; NOTE(review): couples the VM core to the base library */
+LUAI_FUNC int luaB_ipairsaux (lua_State *L);  /* from lbaselib.c ('ipairsaux' there) */
 
 void luaV_execute (lua_State *L, CallInfo *ci) {
   LClosure *cl;
@@ -1809,11 +1811,24 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
       }
       vmcase(OP_TFORPREP) {
        StkId ra = RA(i);
-        /* create to-be-closed upvalue (if needed) */
-        halfProtect(luaF_newtbcupval(L, ra + 3));
-        pc += GETARG_Bx(i);
-        i = *(pc++);  /* go to next instruction */
+        const Instruction* callpc = pc + GETARG_Bx(i);
+        i = *callpc;
         lua_assert(GET_OPCODE(i) == OP_TFORCALL && ra == RA(i));
+        if (ttypetag(s2v(ra)) == LUA_VLCF && ttistable(s2v(ra + 1)) && ttisnil(s2v(ra + 3)) && !trap && (GETARG_C(i) == 1 || GETARG_C(i) == 2)) {  /* light C iterator on a table, nil closing value, 'trap' off, 1-2 results */
+          if (fvalue(s2v(ra)) == luaB_next && ttisnil(s2v(ra + 2))) {  /* 'pairs': mark slot, start at index 0 */
+            settt_(s2v(ra + 3), LUA_VITER);
+            val_(s2v(ra + 3)).it = 0;
+          } else if (fvalue(s2v(ra)) == luaB_ipairsaux && ttisinteger(s2v(ra + 2))) {  /* 'ipairs': index stays in ra + 2 */
+            settt_(s2v(ra + 3), LUA_VITERI);
+          } else {  /* some other light C iterator */
+            /* create to-be-closed upvalue (if needed) */
+            halfProtect(luaF_newtbcupval(L, ra + 3));
+          }
+        } else {  /* not eligible for a fast path */
+          /* create to-be-closed upvalue (if needed) */
+          halfProtect(luaF_newtbcupval(L, ra + 3));
+        }
+        pc = callpc + 1;
         goto l_tforcall;
       }
       vmcase(OP_TFORCALL) {
@@ -1824,6 +1839,60 @@ void luaV_execute (lua_State *L, CallInfo *ci) {
            to-be-closed variable. The call will use the stack after
            these values (starting at 'ra + 4')
         */
+        if (luai_likely(ttypetag(s2v(ra + 3)) == LUA_VITER)) {  /* 'pairs' fast path */
+          if (luai_likely(!trap)) {  /* otherwise drop the marker below and use the generic call */
+            Table *t = hvalue(s2v(ra + 1));
+            unsigned int idx = val_(s2v(ra + 3)).it;
+            unsigned int asize = luaH_realasize(t);
+
+            i = *(pc++);  /* go to next instruction */
+            lua_assert(GET_OPCODE(i) == OP_TFORLOOP && ra == RA(i));
+
+            for (; idx < asize; idx++) {  /* try first array part */
+              if (luai_likely(!isempty(&t->array[idx]))) {  /* a non-empty entry? */
+                setivalue(s2v(ra + 4), idx + 1);
+                setobj2s(L, ra + 5, &t->array[idx]);
+                goto l_tforcall_found;
+              }
+            }
+            for (idx -= asize; cast_int(idx) < sizenode(t); idx++) {  /* hash part */
+              Node *n = gnode(t, idx);
+              if (luai_likely(!isempty(gval(n)))) {  /* a non-empty entry? */
+                getnodekey(L, s2v(ra + 4), n);
+                setobj2s(L, ra + 5, gval(n));
+                idx += asize;
+                goto l_tforcall_found;
+              }
+            }
+            vmbreak;  /* traversal done: 'pc' is already past OP_TFORLOOP, so the loop exits */
+           l_tforcall_found:  /* 'idx' is the combined (array + hash) index of the entry found */
+            val_(s2v(ra + 3)).it = idx + 1;  /* next call resumes after this entry */
+            setobjs2s(L, ra + 2, ra + 4);  /* save control variable */
+            pc -= GETARG_Bx(i);  /* jump back */
+            vmbreak;
+          }
+          setnilvalue(s2v(ra + 3));  /* 'trap' set: drop marker; ra + 2 holds the last key for 'next' */
+        } else if (luai_likely(ttypetag(s2v(ra + 3)) == LUA_VITERI)) {  /* 'ipairs' fast path */
+          if (luai_likely(!trap)) {
+            /* No type checks needed: OP_TFORPREP verified them, and debug.setlocal clears the LUA_VITERI marker if it changes a loop state variable. */
+            Table *t = hvalue(s2v(ra + 1));
+            lua_Integer n = ivalue(s2v(ra + 2));
+            const TValue *slot;
+            n = intop(+, n, 1);
+            slot = luai_likely(l_castS2U(n) - 1 < t->alimit) ? &t->array[n - 1] : luaH_getint(t, n);  /* raw lookup */
+            if (luai_likely(!isempty(slot))) {  /* empty slot: fall through to the generic call below */
+              setobj2s(L, ra + 5, slot);
+              chgivalue(s2v(ra + 2), n);
+              i = *(pc++);  /* go to next instruction */
+              lua_assert(GET_OPCODE(i) == OP_TFORLOOP && ra == RA(i));
+              setobjs2s(L, ra + 4, ra + 2);  /* save control variable */
+              pc -= GETARG_Bx(i);  /* jump back */
+              vmbreak;
+            }
+          } else {
+            setnilvalue(s2v(ra + 3));  /* 'trap' set: drop marker and use the generic call */
+          }
+        }
         /* push function, state, and control variable */
         memcpy(ra + 4, ra, 3 * sizeof(*ra));
         L->top = ra + 4 + 3;