[Date Prev][Date Next][Thread Prev][Thread Next]
[Date Index]
[Thread Index]
- Subject: [patch] string.match() recursion depth limit
- From: Tim Starling <tstarling@...>
- Date: Sun, 08 Jul 2012 20:51:25 +1000
Pattern matching in Lua uses recursion to implement several features.
The recursion depth is not limited, so user code can exhaust the C
stack and cause a segfault.
Please consider applying the attached patch, which limits the
recursion depth to LUAI_MAXCALLS.
The coding style is chosen mainly to change as few lines as possible,
so that the patch is easier to review. If you like, I can
implement the depth decrement with "goto cleanup" or a wrapper
function, or I can manage the depth on the stack by adding a depth
parameter to match() and all the functions it calls. I can also
introduce a separate configuration macro for this limit, rather than
reusing LUAI_MAXCALLS, if that's desired.
-- Tim Starling
--- lua-5.1.5~/src/lstrlib.c 2010-05-15 01:34:19.000000000 +1000
+++ lua-5.1.5/src/lstrlib.c 2012-07-08 14:56:25.236038945 +1000
@@ -176,6 +176,7 @@
const char *init;
ptrdiff_t len;
} capture[LUA_MAXCAPTURES];
+ int depth; /* the current recursion depth of match() */
} MatchState;
@@ -361,24 +362,34 @@
else return NULL;
}
+#define MATCH_RETURN(r) { \
+ const char * result = (r); \
+ --ms->depth; \
+ return result; \
+}
static const char *match (MatchState *ms, const char *s, const char *p) {
+ if (++ms->depth > LUAI_MAXCALLS) {
+ luaL_error(ms->L, "recursion depth limit exceeded");
+ }
+
init: /* using goto's to optimize tail recursion */
switch (*p) {
case '(': { /* start capture */
- if (*(p+1) == ')') /* position capture? */
- return start_capture(ms, s, p+2, CAP_POSITION);
- else
- return start_capture(ms, s, p+1, CAP_UNFINISHED);
+ if (*(p+1) == ')') { /* position capture? */
+ MATCH_RETURN(start_capture(ms, s, p+2, CAP_POSITION));
+ } else {
+ MATCH_RETURN(start_capture(ms, s, p+1, CAP_UNFINISHED));
+ }
}
case ')': { /* end capture */
- return end_capture(ms, s, p+1);
+ MATCH_RETURN(end_capture(ms, s, p+1));
}
case L_ESC: {
switch (*(p+1)) {
case 'b': { /* balanced string? */
s = matchbalance(ms, s, p+2);
- if (s == NULL) return NULL;
+ if (s == NULL) MATCH_RETURN(NULL);
p+=4; goto init; /* else return match(ms, s, p+4); */
}
case 'f': { /* frontier? */
@@ -390,13 +401,13 @@
ep = classend(ms, p); /* points to what is next */
previous = (s == ms->src_init) ? '\0' : *(s-1);
if (matchbracketclass(uchar(previous), p, ep-1) ||
- !matchbracketclass(uchar(*s), p, ep-1)) return NULL;
+ !matchbracketclass(uchar(*s), p, ep-1)) MATCH_RETURN(NULL);
p=ep; goto init; /* else return match(ms, s, ep); */
}
default: {
if (isdigit(uchar(*(p+1)))) { /* capture results (%0-%9)? */
s = match_capture(ms, s, uchar(*(p+1)));
- if (s == NULL) return NULL;
+ if (s == NULL) MATCH_RETURN(NULL);
p+=2; goto init; /* else return match(ms, s, p+2) */
}
goto dflt; /* case default */
@@ -404,12 +415,12 @@
}
}
case '\0': { /* end of pattern */
- return s; /* match succeeded */
+ MATCH_RETURN(s); /* match succeeded */
}
case '$': {
- if (*(p+1) == '\0') /* is the `$' the last char in pattern? */
- return (s == ms->src_end) ? s : NULL; /* check end of string */
- else goto dflt;
+ if (*(p+1) == '\0') /* is the `$' the last char in pattern? */
+ MATCH_RETURN((s == ms->src_end) ? s : NULL); /* check end of string */
+ goto dflt;
}
default: dflt: { /* it is a pattern item */
const char *ep = classend(ms, p); /* points to what is next */
@@ -418,20 +429,20 @@
case '?': { /* optional */
const char *res;
if (m && ((res=match(ms, s+1, ep+1)) != NULL))
- return res;
+ MATCH_RETURN(res);
p=ep+1; goto init; /* else return match(ms, s, ep+1); */
}
case '*': { /* 0 or more repetitions */
- return max_expand(ms, s, p, ep);
+ MATCH_RETURN(max_expand(ms, s, p, ep));
}
case '+': { /* 1 or more repetitions */
- return (m ? max_expand(ms, s+1, p, ep) : NULL);
+ MATCH_RETURN(m ? max_expand(ms, s+1, p, ep) : NULL);
}
case '-': { /* 0 or more repetitions (minimum) */
- return min_expand(ms, s, p, ep);
+ MATCH_RETURN(min_expand(ms, s, p, ep));
}
default: {
- if (!m) return NULL;
+ if (!m) MATCH_RETURN(NULL);
s++; p=ep; goto init; /* else return match(ms, s+1, ep); */
}
}
@@ -439,7 +450,7 @@
}
}
-
+#undef MATCH_RETURN
static const char *lmemfind (const char *s1, size_t l1,
const char *s2, size_t l2) {
@@ -658,6 +669,7 @@
ms.L = L;
ms.src_init = src;
ms.src_end = src+srcl;
+ ms.depth = 0;
while (n < max_s) {
const char *e;
ms.level = 0;