lua-users home
lua-l archive

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


"Johansen, Vagn" wrote:
> > >gsub() has a problem with \000
> >
> > >From the manual: "A pattern cannot contain embedded zeros.
> > Use %z instead."
> 
> That is a bit surprising. But OK, I guess it simplifies the implementation.

It's not that hard.  A patch is attached.  As far as I can see, there is
only one place left that does not handle \0 correctly: the "%s" sequence
in format().  I will look at that later.

Ciao, ET.


PS: The patch is pretty fresh.  It compiles and passes some simple tests.
No guarantee ;-)
diff -ruN lua-4.0/include/lualib.h lua-4.0a/include/lualib.h
--- lua-4.0/include/lualib.h	Fri Oct 27 18:15:53 2000
+++ lua-4.0a/include/lualib.h	Fri Mar  2 15:19:20 2001
@@ -28,7 +28,7 @@
 
 /* Auxiliary functions (private) */
 
-const char *luaI_classend (lua_State *L, const char *p);
+const char *luaI_classend (lua_State *L, const char *p, const char *p_end);
 int luaI_singlematch (int c, const char *p, const char *ep);
 
 #endif
diff -ruN lua-4.0/src/lib/lstrlib.c lua-4.0a/src/lib/lstrlib.c
--- lua-4.0/src/lib/lstrlib.c	Fri Oct 27 18:15:53 2000
+++ lua-4.0a/src/lib/lstrlib.c	Fri Mar  2 16:05:29 2001
@@ -41,7 +41,7 @@
   if (end > (long)l) end = l;
   if (start <= end)
     lua_pushlstring(L, s+start-1, end-start+1);
-  else lua_pushstring(L, "");
+  else lua_pushlstring(L, "", 0);
   return 1;
 }
 
@@ -122,7 +122,8 @@
 
 
 struct Capture {
-  const char *src_end;  /* end ('\0') of source string */
+  const char *src_end;  /* end of source string */
+  const char *pat_end;  /* end of pattern string */
   int level;  /* total number of captures (finished or unfinished) */
   struct {
     const char *init;
@@ -132,7 +133,6 @@
 
 
 #define ESC		'%'
-#define SPECIALS	"^$*+?.([%-"
 
 
 static int check_capture (lua_State *L, int l, struct Capture *cap) {
@@ -152,18 +152,20 @@
 }
 
 
-const char *luaI_classend (lua_State *L, const char *p) {
+const char *luaI_classend (lua_State *L, const char *p, const char *p_end) {  /* returns the position just past the pattern item starting at p; p_end is the explicit end of the pattern */
   switch (*p++) {
     case ESC:
-      if (*p == '\0') lua_error(L, "malformed pattern (ends with `%')");
-      return p+1;
+      if (p < p_end)  /* a character must follow the '%' inside the pattern */
+        return p+1;
+      lua_error(L, "malformed pattern (ends with `%')");  /* presumably does not return; otherwise control would fall into the next case -- confirm */
     case '[':
-      if (*p == '^') p++;
-      do {  /* look for a ']' */
-        if (*p == '\0') lua_error(L, "malformed pattern (missing `]')");
-        if (*(p++) == ESC && *p != '\0') p++;  /* skip escapes (e.g. '%]') */
-      } while (*p != ']');
-      return p+1;
+      while (p < p_end) {  /* look for the closing ']'; NOTE(review): a ']' directly after '[' or '[^' ends the set here, whereas the pre-patch loop always consumed that first character -- confirm intended */
+        if (*p == ']')
+          return p+1;
+        if (*p++ == ESC && p < p_end)  /* skip escapes (e.g. '%]') without running past p_end */
+          p++;
+      }
+      lua_error(L, "malformed pattern (missing `]')");  /* reached p_end without finding ']' */
     default:
       return p;
   }
@@ -202,7 +204,7 @@
       if (match_class(c, (unsigned char)*p))
         return sig;
     }
-    else if ((*(p+1) == '-') && (p+2 < endclass)) {
+    else if (p+2 < endclass && *(p+1) == '-') {
       p+=2;
       if ((int)(unsigned char)*(p-2) <= c && c <= (int)(unsigned char)*p)
         return sig;
@@ -234,7 +236,7 @@
 
 static const char *matchbalance (lua_State *L, const char *s, const char *p,
                                  struct Capture *cap) {
-  if (*p == 0 || *(p+1) == 0)
+  if (p+1 >= cap->pat_end)
     lua_error(L, "unbalanced pattern");
   if (*s != *p) return NULL;
   else {
@@ -318,52 +320,55 @@
 
 static const char *match (lua_State *L, const char *s, const char *p,
                           struct Capture *cap) {
-  init: /* using goto's to optimize tail recursion */
-  switch (*p) {
-    case '(':  /* start capture */
-      return start_capture(L, s, p, cap);
-    case ')':  /* end capture */
-      return end_capture(L, s, p, cap);
-    case ESC:  /* may be %[0-9] or %b */
-      if (isdigit((unsigned char)(*(p+1)))) {  /* capture? */
-        s = match_capture(L, s, *(p+1), cap);
-        if (s == NULL) return NULL;
-        p+=2; goto init;  /* else return match(L, s, p+2, cap) */
-      }
-      else if (*(p+1) == 'b') {  /* balanced string? */
-        s = matchbalance(L, s, p+2, cap);
-        if (s == NULL) return NULL;
-        p+=4; goto init;  /* else return match(L, s, p+4, cap); */
-      }
-      else goto dflt;  /* case default */
-    case '\0':  /* end of pattern */
-      return s;  /* match succeeded */
-    case '$':
-      if (*(p+1) == '\0')  /* is the '$' the last char in pattern? */
-        return (s == cap->src_end) ? s : NULL;  /* check end of string */
-      else goto dflt;
-    default: dflt: {  /* it is a pattern item */
-      const char *ep = luaI_classend(L, p);  /* points to what is next */
-      int m = s<cap->src_end && luaI_singlematch((unsigned char)*s, p, ep);
-      switch (*ep) {
-        case '?': {  /* optional */
-          const char *res;
-          if (m && ((res=match(L, s+1, ep+1, cap)) != NULL))
-            return res;
-          p=ep+1; goto init;  /* else return match(L, s, ep+1, cap); */
-        }
-        case '*':  /* 0 or more repetitions */
-          return max_expand(L, s, p, ep, cap);
-        case '+':  /* 1 or more repetitions */
-          return (m ? max_expand(L, s+1, p, ep, cap) : NULL);
-        case '-':  /* 0 or more repetitions (minimum) */
-          return min_expand(L, s, p, ep, cap);
-        default:
-          if (!m) return NULL;
-          s++; p=ep; goto init;  /* else return match(L, s+1, ep, cap); */
+  while (p < cap->pat_end) {  /* walk the pattern up to its explicit end; a '\0' byte is handled like any other pattern item */
+    const char *ep;
+    int m;
+    switch (*p) {
+      case '(':  /* start capture */
+        return start_capture(L, s, p, cap);
+      case ')':  /* end capture */
+        return end_capture(L, s, p, cap);
+      case ESC:  /* may be %[0-9] or %b */
+        if (p+1 < cap->pat_end)  /* only look at p[1] when it is inside the pattern; the else below pairs with the inner if */
+          if (isdigit((unsigned char)(*(p+1)))) {  /* capture? */
+            s = match_capture(L, s, *(p+1), cap);
+            if (s == NULL) return NULL;
+            p+=2; continue;  /* else return match(L, s, p+2, cap) */
+          }
+          else if (*(p+1) == 'b') {  /* balanced string? */
+            s = matchbalance(L, s, p+2, cap);
+            if (s == NULL) return NULL;
+            p+=4; continue;  /* else return match(L, s, p+4, cap); */
+          }
+        break;  /* any other escape is treated as a pattern item below */
+      case '$':	/* end of string (only special at end of pattern) */
+        if (p+1 == cap->pat_end)
+          return (s == cap->src_end) ? s : NULL;
+        break;
+    }
+    /* it is a pattern item */
+    ep = luaI_classend(L, p, cap->pat_end);  /* points to what is next */
+    m = s<cap->src_end && luaI_singlematch((unsigned char)*s, p, ep);
+    switch (*ep) {  /* NOTE(review): ep may equal cap->pat_end, so this reads the byte at pat_end; assumes that byte is readable (e.g. a terminating '\0') -- confirm */
+      case '?': {  /* optional */
+        const char *res;
+        if (m && ((res=match(L, s+1, ep+1, cap)) != NULL))
+          return res;
+        p=ep+1; continue;  /* else return match(L, s, ep+1, cap); */
       }
+      case '*':  /* 0 or more repetitions */
+        return max_expand(L, s, p, ep, cap);
+      case '+':  /* 1 or more repetitions */
+        return m ? max_expand(L, s+1, p, ep, cap) : NULL;
+      case '-':  /* 0 or more repetitions (minimum) */
+        return min_expand(L, s, p, ep, cap);
+      default:
+        if (!m) return NULL;
+        s++; p=ep; continue;  /* else return match(L, s+1, ep, cap); */
     }
+    /* not reached */
   }
+  return s;  /* whole pattern consumed: match succeeded, s is the end of the match */
 }
 
 
@@ -390,6 +395,17 @@
 }
 
 
+static int hasspecials (const char *p, size_t l) {  /* returns 1 if any of the l bytes at p is a pattern special character, else 0 */
+  while (l--)  /* length-driven scan: does not stop at an embedded '\0' */
+    switch (*p++) {
+      case '^': case '$': case '*': case '+': case '-':
+      case '?': case '.': case '(': case '[': case '%':
+        return 1;
+    }
+  return 0;  /* no special character found in the l bytes */
+}
+
+
 static int push_captures (lua_State *L, struct Capture *cap) {
   int i;
   luaL_checkstack(L, cap->level, "too many captures");
@@ -409,8 +425,8 @@
   long init = posrelat(luaL_opt_long(L, 3, 1), l1) - 1;
   struct Capture cap;
   luaL_arg_check(L, 0 <= init && (size_t)init <= l1, 3, "out of range");
-  if (lua_gettop(L) > 3 ||  /* extra argument? */
-      strpbrk(p, SPECIALS) == NULL) {  /* or no special characters? */
+  if (lua_gettop(L) > 3 || !hasspecials(p, l2)) {
+    /* extra argument or no specials characters */
     const char *s2 = lmemfind(s+init, l1-init, p, l2);
     if (s2) {
       lua_pushnumber(L, s2-s+1);
@@ -422,6 +438,7 @@
     int anchor = (*p == '^') ? (p++, 1) : 0;
     const char *s1=s+init;
     cap.src_end = s+l1;
+    cap.pat_end = p+l2;
     do {
       const char *res;
       cap.level = 0;
@@ -470,9 +487,9 @@
 
 
 static int str_gsub (lua_State *L) {
-  size_t srcl;
+  size_t srcl, patl;
   const char *src = luaL_check_lstr(L, 1, &srcl);
-  const char *p = luaL_check_string(L, 2);
+  const char *p = luaL_check_lstr(L, 2, &patl);
   int max_s = luaL_opt_int(L, 4, srcl+1);
   int anchor = (*p == '^') ? (p++, 1) : 0;
   int n = 0;
@@ -483,6 +500,7 @@
     3, "string or function expected");
   luaL_buffinit(L, &b);
   cap.src_end = src+srcl;
+  cap.pat_end = p+patl;
   while (n < max_s) {
     const char *e;
     cap.level = 0;
@@ -541,6 +559,7 @@
     else if (*++strfrmt == '%')
       luaL_putchar(&b, *strfrmt++);  /* %% */
     else { /* format item */
+      static const char fmt_spec[] = "[-+ #0]*(%d*)%.?(%d*)";
       struct Capture cap;
       char form[MAX_FORMAT];  /* to store the format ('%...') */
       char buff[MAX_ITEM];  /* to store the formatted item */
@@ -551,9 +570,10 @@
         initf += 2;  /* skip the 'n$' */
       }
       arg++;
-      cap.src_end = strfrmt+strlen(strfrmt)+1;
+      cap.src_end = strfrmt+strlen(strfrmt);
+      cap.pat_end = fmt_spec+sizeof(fmt_spec)-1;
       cap.level = 0;
-      strfrmt = match(L, initf, "[-+ #0]*(%d*)%.?(%d*)", &cap);
+      strfrmt = match(L, initf, fmt_spec, &cap);
       if (cap.capture[0].len > 2 || cap.capture[1].len > 2 ||  /* < 100? */
           strfrmt-initf > MAX_FORMAT-2)
         lua_error(L, "invalid format (width or precision too long)");