lua-users home
lua-l archive

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


Lua's pattern matcher (match() in lstrlib.c) uses recursion to implement
several features. The recursion depth is not limited, so a Lua script can
exhaust the C stack and crash the host process with a segfault.

Please consider applying the attached patch, which limits the recursion
depth of match() to LUAI_MAXCALLS and raises a Lua error ("recursion
depth limit exceeded") when that limit is exceeded.

The patch is written to change as few lines as possible, to make it
easier to review. If you prefer, I can instead implement the depth
decrement with a "goto cleanup" exit path or a wrapper function, or
track the depth on the stack by adding a depth parameter to match() and
every function it calls. I can also add a separate configuration macro
for this limit instead of reusing LUAI_MAXCALLS, if that's desired.

-- Tim Starling
--- lua-5.1.5~/src/lstrlib.c	2010-05-15 01:34:19.000000000 +1000
+++ lua-5.1.5/src/lstrlib.c	2012-07-08 14:56:25.236038945 +1000
@@ -176,6 +176,7 @@
     const char *init;
     ptrdiff_t len;
   } capture[LUA_MAXCAPTURES];
+  int depth; /* the current recursion depth of match() */
 } MatchState;
 
 
@@ -361,24 +362,34 @@
   else return NULL;
 }
 
+#define MATCH_RETURN(r) do { /* leave match(): pop one depth level, return r */ \
+  const char * result = (r); /* evaluate r while this frame still counts */ \
+  --ms->depth; \
+  return result; \
+} while (0) /* do/while(0): "MATCH_RETURN(x);" is one statement, safe before else */
 
 static const char *match (MatchState *ms, const char *s, const char *p) {
+  if (++ms->depth > LUAI_MAXCALLS) {
+    luaL_error(ms->L, "recursion depth limit exceeded");
+  }
+
   init: /* using goto's to optimize tail recursion */
   switch (*p) {
     case '(': {  /* start capture */
-      if (*(p+1) == ')')  /* position capture? */
-        return start_capture(ms, s, p+2, CAP_POSITION);
-      else
-        return start_capture(ms, s, p+1, CAP_UNFINISHED);
+      if (*(p+1) == ')') { /* position capture? */
+        MATCH_RETURN(start_capture(ms, s, p+2, CAP_POSITION));
+      } else {
+        MATCH_RETURN(start_capture(ms, s, p+1, CAP_UNFINISHED));
+      }
     }
     case ')': {  /* end capture */
-      return end_capture(ms, s, p+1);
+      MATCH_RETURN(end_capture(ms, s, p+1));
     }
     case L_ESC: {
       switch (*(p+1)) {
         case 'b': {  /* balanced string? */
           s = matchbalance(ms, s, p+2);
-          if (s == NULL) return NULL;
+          if (s == NULL) MATCH_RETURN(NULL);
           p+=4; goto init;  /* else return match(ms, s, p+4); */
         }
         case 'f': {  /* frontier? */
@@ -390,13 +401,13 @@
           ep = classend(ms, p);  /* points to what is next */
           previous = (s == ms->src_init) ? '\0' : *(s-1);
           if (matchbracketclass(uchar(previous), p, ep-1) ||
-             !matchbracketclass(uchar(*s), p, ep-1)) return NULL;
+             !matchbracketclass(uchar(*s), p, ep-1)) MATCH_RETURN(NULL);
           p=ep; goto init;  /* else return match(ms, s, ep); */
         }
         default: {
           if (isdigit(uchar(*(p+1)))) {  /* capture results (%0-%9)? */
             s = match_capture(ms, s, uchar(*(p+1)));
-            if (s == NULL) return NULL;
+            if (s == NULL) MATCH_RETURN(NULL);
             p+=2; goto init;  /* else return match(ms, s, p+2) */
           }
           goto dflt;  /* case default */
@@ -404,12 +415,12 @@
       }
     }
     case '\0': {  /* end of pattern */
-      return s;  /* match succeeded */
+      MATCH_RETURN(s);  /* match succeeded */
     }
     case '$': {
-      if (*(p+1) == '\0')  /* is the `$' the last char in pattern? */
-        return (s == ms->src_end) ? s : NULL;  /* check end of string */
-      else goto dflt;
+      if (*(p+1) == '\0') /* is the `$' the last char in pattern? */
+        MATCH_RETURN((s == ms->src_end) ? s : NULL);  /* check end of string */
+      goto dflt;
     }
     default: dflt: {  /* it is a pattern item */
       const char *ep = classend(ms, p);  /* points to what is next */
@@ -418,20 +429,20 @@
         case '?': {  /* optional */
           const char *res;
           if (m && ((res=match(ms, s+1, ep+1)) != NULL))
-            return res;
+            MATCH_RETURN(res);
           p=ep+1; goto init;  /* else return match(ms, s, ep+1); */
         }
         case '*': {  /* 0 or more repetitions */
-          return max_expand(ms, s, p, ep);
+          MATCH_RETURN(max_expand(ms, s, p, ep));
         }
         case '+': {  /* 1 or more repetitions */
-          return (m ? max_expand(ms, s+1, p, ep) : NULL);
+          MATCH_RETURN(m ? max_expand(ms, s+1, p, ep) : NULL);
         }
         case '-': {  /* 0 or more repetitions (minimum) */
-          return min_expand(ms, s, p, ep);
+          MATCH_RETURN(min_expand(ms, s, p, ep));
         }
         default: {
-          if (!m) return NULL;
+          if (!m) MATCH_RETURN(NULL);
           s++; p=ep; goto init;  /* else return match(ms, s+1, ep); */
         }
       }
@@ -439,7 +450,7 @@
   }
 }
 
-
+#undef MATCH_RETURN
 
 static const char *lmemfind (const char *s1, size_t l1,
                                const char *s2, size_t l2) {
@@ -658,6 +669,7 @@
   ms.L = L;
   ms.src_init = src;
   ms.src_end = src+srcl;
+  ms.depth = 0;
   while (n < max_s) {
     const char *e;
     ms.level = 0;