[Date Prev][Date Next][Thread Prev][Thread Next]
[Date Index]
[Thread Index]
- Subject: X.509 PKIX certificate parser using lpeg.Cmt (was Re: Elegant design for creating error messages in LPEG parser)
- From: William Ahern <william@...>
- Date: Thu, 18 Apr 2019 02:29:05 -0700
On Thu, Apr 04, 2019 at 01:46:14AM +0400, joy mondal wrote:
> Just a side question.
>
> Under what design consideration do we use 'Cmt(patt,f) ' vs 'patt/f' ?
>
I meant to reply when you first posed the question but couldn't get around
to it.
The following is a basic PKIX X.509 certificate parser. I've refactored my
working code into a single file. It's a work in progress and not fully
generalized. The most obvious hole is that it doesn't parse extensions, yet.
X.509 certificates are usually encoded using ASN.1 DER (technically, CER).
DER is a type of TLV (tag-length-value) encoding. To parse symbols you need
to decode the length to know the extent of the symbol--the tag itself
doesn't tell you how long it will be, and the length can be arbitrarily
large. Formally I think such encodings are a type of context-sensitive
grammar. lpeg.Cmt is used below to match the initial tag and jump to regular
Lua code. The regular Lua code can then unpack the length, consume the
substring (possibly recursively invoking LPeg again), and tell LPeg how far
ahead to jump to resume parsing. It would be impractical if not impossible
to do this without Cmt (i.e. using pure PEGs, which can only handle
context-free grammars).
The LPeg patterns begin around line 570. You can use it like:
local lpeg = require"lpeg"
local pkix = require"pkix"
local pem = require"pkix.pem"
local raw = io.stdin:read"a"
local der = pem.match(raw) or raw
pkix.dump(io.stdout, lpeg.match(pkix.Certificate, der))
-- ==========================================================================
-- pkix.lua - Lua PKIX library
-- --------------------------------------------------------------------------
-- Copyright (c) 2018-2019 William Ahern
--
-- Permission is hereby granted, free of charge, to any person obtaining a
-- copy of this software and associated documentation files (the
-- "Software"), to deal in the Software without restriction, including
-- without limitation the rights to use, copy, modify, merge, publish,
-- distribute, sublicense, and/or sell copies of the Software, and to permit
-- persons to whom the Software is furnished to do so, subject to the
-- following conditions:
--
-- The above copyright notice and this permission notice shall be included
-- in all copies or substantial portions of the Software.
--
-- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
-- OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
-- NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
-- DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
-- OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
-- USE OR OTHER DEALINGS IN THE SOFTWARE.
-- ==========================================================================
local lpeg = require"lpeg"
local pkix = {}
local sfind = string.find
local sformat = string.format
local spack = string.pack
local sunpack = string.unpack
local type = type
local P, S, R = lpeg.P, lpeg.S, lpeg.R
local C, Cc, Cf, Cg, Cmt, Cp, Ct = lpeg.C, lpeg.Cc, lpeg.Cf, lpeg.Cg, lpeg.Cmt, lpeg.Cp, lpeg.Ct
-- pkix.base64 - small Base64 codec using the standard alphabet with "="
-- padding. Registered in package.preload so require"pkix.base64" resolves
-- without a separate file.
package.preload["pkix.base64"] = function ()
  local base64 = {}

  local byte = string.byte
  local char = string.char
  local rep = string.rep
  -- local alias so the module does not depend on upvalues of the enclosing file
  local sformat = string.format

  local function ord(s)
    return byte(s, 1)
  end

  -- for mapping 6-bit integer value to radix character ordinal value
  local charmap = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"

  -- for mapping radix character ordinal value to 6-bit integer
  local bytemap = {}
  for i=1,#charmap do
    bytemap[byte(charmap, i)] = i - 1
  end

  -- base64.encode(data[, width])
  --
  -- Encodes the binary string data. When width is a positive integer a
  -- newline is inserted after every width output characters (never at the
  -- very end). Returns the encoded text.
  function base64.encode(data, width)
    local function tr(ch)
      return byte(charmap, 1 + ch)
    end

    -- pack 1-3 octets into an integer; also returns the number of bits packed
    local function s2i(s)
      local x, y, z = byte(s, 1, #s)
      local i = x
      local n = 8

      if y then
        i = (i << 8) | y
        n = n + 8
      end

      if z then
        i = (i << 8) | z
        n = n + 8
      end

      return i, n
    end

    -- turn n bits of i into 4 output characters, padding with "="
    local function i2s(i, n)
      local a, b, c, d

      if n == 24 then
        a = tr((i >> 18) & 0x3f)
        b = tr((i >> 12) & 0x3f)
        c = tr((i >> 6) & 0x3f)
        d = tr((i >> 0) & 0x3f)
      elseif n == 16 then
        a = tr((i >> 10) & 0x3f)
        b = tr((i >> 4) & 0x3f)
        c = tr((i << 2) & 0x3f)
        d = ord"="
      elseif n == 8 then
        a = tr((i >> 2) & 0x3f)
        b = tr((i << 4) & 0x3f)
        c, d = ord"=", ord"="
      else
        error(sformat("invalid # of bits (%d)", n))
      end

      return char(a, b, c, d)
    end

    local text = data:gsub("..?.?", function (s)
      return i2s(s2i(s))
    end)

    -- a non-positive width would build the pattern "()" (a position
    -- capture) and corrupt the output, so it is treated as "no wrapping"
    if width and width > 0 then
      text = text:gsub(sformat("(%s)", rep(".", width)), "%1\n")
      text = text:byte(-1) == 0x0A and text:sub(1, -2) or text
    end

    return text
  end

  -- base64.decode(text)
  --
  -- Decodes Base64 text and returns the binary string. Characters outside
  -- the alphabet (line breaks, blanks, ...) are ignored, so the output of
  -- base64.encode(data, width) decodes back to data. Raises an error when
  -- a group of four characters carries fewer than 12 bits of data.
  function base64.decode(text)
    -- accumulate the 6-bit values of up to 4 characters; "=" contributes nothing
    local function s2i(s)
      local a, b, c, d = byte(s, 1, #s)

      local function pushchar(c, i, n)
        local b = bytemap[c]
        if b then
          return ((i << 6) | b), n + 6
        else
          return i, n
        end
      end

      return pushchar(d, pushchar(c, pushchar(b, pushchar(a, 0, 0))))
    end

    local function i2s(i, n)
      if n == 24 then
        return char(((i >> 16) & 0xff), ((i >> 8) & 0xff), ((i >> 0) & 0xff))
      elseif n == 18 then
        return char(((i >> 10) & 0xff), ((i >> 2) & 0xff))
      elseif n == 12 then
        return char((i >> 4) & 0xff)
      else
        error(sformat("invalid # of bits (%d)", n))
      end
    end

    -- Strip everything that is neither an alphabet character nor padding
    -- BEFORE grouping; otherwise a single newline shifts every following
    -- group of four and the decoded data is garbage. Padding is kept so
    -- that padded groups stay aligned.
    local clean = text:gsub("[^A-Za-z0-9+/=]", "")

    return (clean:gsub("..?.?.?", function (s)
      return i2s(s2i(s))
    end))
  end

  return base64
end -- pkix.base64
-- pkix.pem - locate and decode PEM ("-----BEGIN label-----") blocks inside a
-- string. Registered in package.preload so require"pkix.pem" resolves
-- without a separate file. Decoding is delegated to pkix.base64.
package.preload["pkix.pem"] = function ()
  local pem = {}

  -- Iterator over the lines of s starting at byte offset init (default 1).
  -- Each step yields the line without its newline and the offset of the
  -- following line. Stops when no further progress is possible.
  local function lines(s, init)
    local cursor = init or 1

    return function ()
      -- the trailing () is a position capture, so pos is already a number
      local ln, pos = string.match(s, "([^\n]*)\n?()", cursor)

      if ln and pos > cursor then
        cursor = pos
        return ln, cursor
      end
    end
  end

  -- Returns got when the Lua pattern expected (default: anything) matches
  -- the WHOLE of got, nil otherwise (also when got is nil).
  local function testlabel(got, expected)
    if not got then return end

    local i, j = got:find(expected or ".*")
    return i == 1 and j == #got and got or nil
  end

  -- core PEM parser
  --
  -- pem.find(s[, label[, init]])
  --
  -- Searches s from offset init for the first PEM block whose BEGIN label
  -- fully matches the Lua pattern label (any label when omitted). Returns
  -- the offset of the BEGIN line, the offset just past the END line, the
  -- decoded data and the BEGIN label; returns nothing when no complete
  -- block is found.
  --
  -- Labels may contain upper-case letters, digits and spaces so that the
  -- standard RFC 7468 labels "X509 CRL" and "PKCS7" are recognized.
  --
  -- TODO: Maybe add option for parsing headers. Note that RFC 7468 prohibits
  -- generation of encapsulated headers.
  function pem.find(s, label, init)
    local cursor = init or 1
    local beginlabel, endlabel
    local beginpos, endpos
    local data = {}

    for ln, pos in lines(s, cursor) do
      beginlabel = testlabel(ln:match"^%s*-----BEGIN ([A-Z0-9 ]*)-----%s*$", label)

      if beginlabel then
        -- cursor still points at the start of the BEGIN line here
        beginpos = cursor
        cursor = pos
        break
      else
        cursor = pos
      end
    end

    if not beginlabel then return end

    for ln, pos in lines(s, cursor) do
      cursor = pos
      endpos = cursor
      endlabel = ln:match"^%s*-----END ([A-Z0-9 ]*)-----%s*$"
      if endlabel then break end

      -- keep only Base64 alphabet and padding characters of the body line
      data[#data + 1] = ln:gsub("[^ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789%+%/%=]+", "")
    end

    if not endlabel then return end

    return beginpos, endpos, require"pkix.base64".decode(table.concat(data)), beginlabel
  end

  -- pem.gmatch(s[, label])
  --
  -- Iterator returning the decoded data and label of each successive PEM
  -- block in s whose label matches.
  function pem.gmatch(s, label)
    local cursor = 1

    return function ()
      local i, j, data, beginlabel = pem.find(s, label, cursor)

      if data then
        cursor = j
        return data, beginlabel
      end
    end
  end

  -- pem.match(s[, label[, init]])
  --
  -- Like pem.find but returns only the decoded data and the label.
  function pem.match(s, label, init)
    local i, j, data, beginlabel = pem.find(s, label, init)
    if i then return data, beginlabel end
  end

  return pem
end
-- pkix.hexdump
do
  local push = table.insert

  -- Build the string.format templates for a row holding n (1..16) octets:
  -- x formats the hex column, c the character column. Both templates have
  -- the same width for every n so that short final rows stay aligned.
  local function newfmt(n)
    local n1 = math.min(n, 8)                   -- octets in the left half
    local n9 = math.min(math.max(n - 8, 0), 8)  -- octets in the right half
    local x = {}
    local c = {}

    push(x, string.sub(string.rep(" %.2x", n1), 2))
    -- every missing octet leaves a hole as wide as " %.2x" prints (3
    -- columns); padding with fewer columns shifts the character column
    push(x, string.rep(" ", 3 * (8 - n1)))
    push(x, " ")
    push(x, string.rep(" %.2x", n9))
    push(x, string.rep(" ", 3 * (8 - n9)))

    push(c, string.rep("%c", n1))
    push(c, string.rep(" ", 8 - n1))
    push(c, string.rep("%c", n9))
    push(c, string.rep(" ", 8 - n9))

    return { x = table.concat(x), c = table.concat(c) }
  end

  -- templates indexed by the number of octets in the row
  local fmt = {}
  for i=1,16 do
    fmt[i] = newfmt(i)
  end

  -- Map each octet value to itself when it lies strictly between 32 and
  -- 127, otherwise to 46 (".").
  local function glyph(ch, ...)
    if ch then
      ch = (ch > 32 and ch < 127 and ch) or 46 -- .
      return ch, glyph(...)
    end
  end

  -- pkix.hexdump(v[, how])
  --
  -- Returns a multi-line dump of tostring(v), 16 octets per row. how
  -- selects the columns: "a" offset, "x" hex octets, "c" characters
  -- (default "axc"). An empty input yields an empty string.
  function pkix.hexdump(v, how)
    local s = tostring(v)
    local b = {}

    how = how or "axc"
    local a_flag = string.find(how, "a")
    local x_flag = string.find(how, "x")
    local c_flag = string.find(how, "c")

    for i=1,#s,16 do
      local n = math.min(#s - i, 15) + 1
      local f = fmt[n]

      if a_flag then
        push(b, sformat("%.4x ", i - 1))
      end

      if x_flag then
        push(b, sformat(f.x, string.byte(s, i, i + 15)))

        if c_flag then
          push(b, " ")
        end
      end

      if c_flag then
        push(b, "|")
        push(b, sformat(f.c, glyph(string.byte(s, i, i + 15))))
        push(b, "|")
      end

      push(b, "\n")
    end

    return table.concat(b)
  end
end
-- pkix.dump
do
  -- Strict ordering usable for keys of any type: keys of different types
  -- are ordered by type name, numbers and strings by value, everything
  -- else by its tostring() form. Without it table.sort raises "attempt
  -- to compare" as soon as a table mixes e.g. integer and string keys.
  local function keyorder(a, b)
    local ta, tb = type(a), type(b)

    if ta ~= tb then
      return ta < tb
    elseif ta == "number" or ta == "string" then
      return a < b
    else
      return tostring(a) < tostring(b)
    end
  end

  -- Like pairs(t) but visits the keys in sorted order.
  local function spairs(t)
    local keys = {}
    for k in pairs(t) do
      keys[#keys + 1] = k
    end
    table.sort(keys, keyorder)

    local i = 0
    return function ()
      if i < #keys then
        i = i + 1
        local k = keys[i]
        return k, t[k]
      end
    end
  end

  -- Recursive worker: writes each key on its own line and the value one
  -- level deeper. Non-table values containing non-printable octets are
  -- rendered with pkix.hexdump.
  local function dump(fh, v, depth, tab)
    local tabs = string.rep(tab, depth)

    if type(v) == "table" then
      for k, child in spairs(v) do
        local kt = type(k)
        -- write() only accepts strings and numbers
        fh:write(tabs, (kt == "string" or kt == "number") and k or tostring(k), "\n")
        dump(fh, child, depth + 1, tab)
      end
    else
      local s = tostring(v)
      if s:find("[^%g ]") then
        s = pkix.hexdump(s)
      end
      -- parentheses keep only the substituted string, not gsub's count
      fh:write(tabs, (s:gsub("\n", "\n" .. tabs)), "\n")
    end
  end

  -- pkix.dump(fh, v[, depth[, tab]])
  --
  -- Writes an indented, key-sorted rendering of v to the file handle fh.
  -- depth is the initial indentation level (default 0), tab the string
  -- repeated once per level (default " ").
  function pkix.dump(fh, v, depth, tab)
    return dump(fh, v, depth or 0, tab or " ")
  end
end
-- pkix.viz
do
  -- escape table: every one-octet string maps to "\ddd" (3 decimal digits)
  local dec = {}
  for i=0,255 do
    local ch = string.char(i)
    dec[ch] = sformat("\\%.3d", i)
  end

  -- pkix.viz(s)
  --
  -- Returns tostring(s) with every octet that is neither a printable
  -- character (%g) nor a space replaced by its decimal escape, for use in
  -- diagnostics. Exactly one value is returned: the parentheses discard
  -- the substitution count that gsub returns as a second result, which
  -- would otherwise leak into callers such as print() or table.insert().
  function pkix.viz(s)
    return (tostring(s):gsub("[^%g ]", dec))
  end
end
-- Count the leading zero bits of the low octet of i (8 when that octet is
-- zero). Bits above the low octet are ignored.
local function clz8(i)
  local octet = i & 0xFF
  if octet == 0 then
    return 8
  end

  -- walk a single-bit probe down from the most significant bit
  local count = 0
  local probe = 0x80
  while (octet & probe) == 0 do
    count = count + 1
    probe = probe >> 1
  end

  return count
end
-- Count the leading zero bits of the low 16 bits of i, delegating each
-- octet to clz8.
local function clz16(i)
  if (i & 0xFF00) ~= 0 then
    -- a bit is set in the high octet; the answer lies there
    return clz8(i >> 8)
  end

  return 8 + clz8(i)
end
-- Count the leading zero bits of the low 32 bits of i, delegating each
-- 16-bit half to clz16.
local function clz32(i)
  if (i & 0xFFFF0000) ~= 0 then
    -- a bit is set in the upper half; the answer lies there
    return clz16(i >> 16)
  end

  return 16 + clz16(i)
end
-- Count the leading zero bits of the 64-bit integer i, delegating each
-- 32-bit half to clz32.
local function clz64(i)
  if (i & 0xFFFFFFFF00000000) ~= 0 then
    -- a bit is set in the upper half; >> is a logical shift, so the
    -- upper half arrives zero-extended
    return clz32(i >> 32)
  end

  return 32 + clz32(i)
end
-- Count the leading zero bits of the octet string s read as a big-endian
-- number; an empty or all-zero string yields 8 bits per octet.
local function clzbn(s)
  local first = sfind(s, "[^\x00]")

  if first then
    -- whole zero octets before the first non-zero one, plus its own
    -- leading zero bits
    return ((first - 1) * 8) + clz8(s:byte(first))
  end

  return #s * 8
end
-- Return the table stored at t[k]; when t[k] is not a table, store a
-- fresh empty table there (with rawset) and return that instead.
local function getsubtable(t, k)
  local sub = t[k]

  if type(sub) ~= "table" then
    sub = {}
    rawset(t, k, sub)
  end

  return sub
end
-- Apply f to each of the first n values that follow, in order, and return
-- the results. Every call except the last contributes exactly one result;
-- the last call is in tail position and returns all of its results.
-- Returns nothing when n is less than 1.
local function mapn(f, n, x, ...)
  if n == 1 then
    return f(x)
  elseif n > 1 then
    return f(x), mapn(f, n - 1, ...)
  end
end

-- Apply f to every argument (nils included) and return the results.
local function map(f, ...)
  local count = select("#", ...)
  return mapn(f, count, ...)
end
-- Encode n as DER definite-length octets: a single octet for 0..127 (short
-- form), otherwise 0x80|k followed by the k-octet big-endian value using
-- the fewest octets possible (long form). Lengths up to 2^32-1 are
-- supported, the same limit unpacklength accepts; anything larger raises
-- "length too large".
local function packlength(n)
  if n <= 127 then
    return spack("B", n)
  elseif n > 0xFFFFFFFF then
    error("length too large")
  end

  -- smallest number of octets able to hold n
  local digits = 1
  while (n >> (8 * digits)) ~= 0 do
    digits = digits + 1
  end

  return spack(">BI" .. digits, 0x80 | digits, n)
end
-- Decode the DER definite-length octets found in s at offset pos.
--
-- Returns the length and the offset of the first content octet. On
-- malformed input returns nil and a message instead of raising: missing or
-- truncated length octets, the indefinite form (0x80), lengths of 2^32 or
-- more, and non-minimal ("overlong") encodings are all rejected.
local function unpacklength(s, pos)
  local n = s:byte(pos)
  if not n then
    return nil, "truncated DER encoding (missing length octet)"
  end
  pos = pos + 1

  if 0x80 == (n & 0x80) then
    -- long form: low 7 bits give the number of length octets that follow
    local digits = n & 0x7F

    if digits == 0 then
      return nil, "indefinite-length encoding not permitted in DER"
    elseif pos + digits - 1 > #s then
      return nil, sformat("truncated DER encoding (length needs %d octets, %d available)", digits, #s - pos + 1)
    end

    local bits = digits * 8
    local nlz = clzbn(s:sub(pos, pos + digits - 1))

    if bits - nlz > 32 then
      return nil, sformat("DER definite-length value too large (got length gte 2^%d, expected lte 2^32-1)", bits - nlz)
    elseif nlz >= 8 then
      -- a leading zero octet means fewer octets would have sufficed
      return nil, sformat("overlong DER length encoding (got %d octets, expected %d)", digits, (bits - nlz + 7) // 8)
    end

    n, pos = sunpack(sformat(">I%d", digits), s, pos)

    if n < 128 then
      return nil, sformat("overlong DER length encoding (length %d encoded long form, expected short form)", n)
    end
  end

  return n, pos
end
-- Encode a dotted-decimal OID string as the content octets of a DER
-- OBJECT IDENTIFIER.
--
-- The first two arcs X.Y are merged into the single subidentifier 40*X+Y,
-- which - like every other subidentifier - is written base-128, most
-- significant group first, with bit 8 set on all but the last octet.
-- Writing that first subidentifier as one raw octet would be wrong as
-- soon as it reaches 128 (e.g. "2.100.3" must encode as 81 34 03).
--
-- Raises an error when oid has fewer than two numeric components.
local function packoid(oid)
  local value = {}

  for v in oid:gmatch("%d+") do
    value[#value + 1] = tonumber(v)
  end

  if #value < 2 then
    error("OID needs at least two components (" .. tostring(oid) .. ")")
  end

  local octet = {}

  -- append subidentifier n in base-128
  local function pushsubid(n)
    local t = {}

    -- encode LSB to MSB
    t[1] = n & 0x7f
    n = n >> 7
    while n > 0 do
      t[#t + 1] = 0x80 | (n & 0x7f)
      n = n >> 7
    end

    -- reverse, MSB to LSB
    for i=#t,1,-1 do
      octet[#octet + 1] = t[i]
    end
  end

  pushsubid(40 * value[1] + value[2])

  for i=3,#value do
    pushsubid(value[i])
  end

  return string.char(table.unpack(octet))
end
-- Decode the content octets of a DER OBJECT IDENTIFIER into dotted-decimal
-- text.
--
-- Subidentifiers are base-128 groups whose last octet has bit 8 clear. The
-- first subidentifier carries the first two arcs as 40*X+Y: values below
-- 80 split into X = 0 or 1 and Y < 40, everything from 80 upwards belongs
-- to X = 2 (so Y may exceed 39 and the subidentifier may span several
-- octets).
--
-- Returns the text, or nil and a message for empty input or input whose
-- final subidentifier is cut short.
local function unpackoid(s)
  if #s == 0 then
    return nil, "empty OBJECT IDENTIFIER"
  end

  local value = {}
  local n = 0

  for i=1,#s do
    local b = s:byte(i)

    n = n << 7
    n = n | (0x7f & b)

    if 0x80 ~= (b & 0x80) then
      if #value == 0 then
        -- first subidentifier: split into the two leading arcs
        if n < 80 then
          value[1] = n // 40
          value[2] = n % 40
        else
          value[1] = 2
          value[2] = n - 80
        end
      else
        value[#value + 1] = n
      end
      n = 0
    end
  end

  -- a continuation bit on the very last octet leaves a dangling group
  if 0x80 == (s:byte(#s) & 0x80) then
    return nil, "truncated OBJECT IDENTIFIER subidentifier"
  end

  return table.concat(value, ".")
end
-- pkix.oid2txt, pkix.txt2oid
--
-- OID short and long name mapping
--
do
  local namemap = {} -- name (or alias) -> dotted OID
  local oidmap = {}  -- dotted OID -> list of names, primary name first

  -- Each line is "name value" where value is either a dotted OID or the
  -- primary name of an entry defined on an earlier line (an alias).
  for name,oid in string.gmatch([[
commonName 2.5.4.3
countryName 2.5.4.6
distinguishedName 2.5.4.49
id-ecPublicKey 1.2.840.10045.2.1
id-X25519 1.3.101.110
id-X448 1.3.101.111
id-Ed25519 1.3.101.112
id-Ed448 1.3.101.113
localityName 2.5.4.7
organizationName 2.5.4.10
organizationalUnitName 2.5.4.11
rsaEncryption 1.2.840.113549.1.1.1
secp192r1 1.2.840.10045.3.1.1
secp256r1 1.2.840.10045.3.1.7
secp384r1 1.3.132.0.34
secp521r1 1.3.132.0.35
sha256WithRSAEncryption 1.2.840.113549.1.1.11
sha384WithRSAEncryption 1.2.840.113549.1.1.12
sha512WithRSAEncryption 1.2.840.113549.1.1.13
stateOrProvinceName 2.5.4.8
C countryName
CN commonName
L localityName
O organizationName
OU organizationalUnitName
ST stateOrProvinceName
prime192v1 secp192r1
prime256v1 secp256r1
]], "[ \t]*([^ \t\n]+)[ \t]+([^ \t\n]+)[ \t]*\n") do
    if not oid:match"^%d[%d%.]*$" then
      local primaryname = oid
      oid = namemap[primaryname] or error(sformat("alias %s defined before primary name %s", name, primaryname))
    end

    namemap[name] = oid

    local t = getsubtable(oidmap, oid)
    t[#t + 1] = name
  end

  -- True when oid is a string in dotted-decimal form.
  --
  -- NOTE: an OID will always have at least two integer components
  local function isliteral(oid)
    -- The second test rejects empty components such as "1..2". It is a
    -- plain-text search, so the needle must be the two literal dots:
    -- with plain=true the text "%.%." is searched verbatim and never
    -- occurs in a string that passed the first test.
    return type(oid) == "string"
      and oid:find"^%d+%.[%d%.]*%d$"
      and not oid:find("..", 1, true)
  end

  -- pkix.txt2oid(name)
  --
  -- Returns the dotted OID for a known name or alias; a string that is
  -- already a dotted OID is returned unchanged. Raises "unknown OID name"
  -- for anything else, including nil and non-string values.
  function pkix.txt2oid(name)
    if namemap[name] then
      return namemap[name]
    elseif isliteral(name) then
      return name
    else
      error(sformat("unknown OID name (%s)", tostring(name or "?")))
    end
  end

  -- pkix.oid2txt(oid)
  --
  -- Returns the primary name registered for the dotted OID, or the OID
  -- itself when it is not in the table.
  function pkix.oid2txt(oid)
    return oidmap[oid] and oidmap[oid][1] or oid
  end
end
-- Cinner(identifier[, patt])
--
-- Build a pattern for one DER TLV: identifier matches the tag, then a
-- match-time capture decodes the length, extracts the content octets and
-- resumes matching right after them.
--
-- patt decides what is captured from the content:
--   * an LPeg pattern - it must match the ENTIRE content; its captures
--     become the captures of the TLV
--   * a function      - called with the content; its results are the captures
--   * nil             - the content string itself is the capture
--
-- The TLV fails to match when the first value produced for the content is
-- nil or false. A malformed length, or a declared length that runs past
-- the end of the subject, raises an error.
local function Cinner(identifier, patt)
  local innermatch

  if lpeg.type(patt) then
    -- built once instead of on every match; -P(1) anchors at end of content
    local anchored = patt * -P(1)
    innermatch = function (s)
      return lpeg.match(anchored, s)
    end
  elseif type(patt) == "function" then
    innermatch = patt
  elseif patt == nil then
    innermatch = function (s)
      return s
    end
  else
    error(sformat("bad argument #2 (expected pattern, function, or nil, got %s)", type(patt)), 2)
  end

  -- turn the results of innermatch into what a match-time capture expects
  local function finish(pos, v, ...)
    if v then
      return pos, v, ...
    else
      return false
    end
  end

  return Cmt(identifier, function (s, pos)
    local n, first = assert(unpacklength(s, pos))
    local last = first + n - 1

    -- string.sub would silently clip a truncated value and the position
    -- returned below would lie outside the subject
    if last > #s then
      error(sformat("truncated DER value (declared %d octets, %d available)", n, #s - first + 1))
    end

    return finish(last + 1, innermatch(s:sub(first, last)))
  end)
end
-- ANY: matches one TLV regardless of its (single-octet) tag and captures
-- the raw content octets. Raises an error on a malformed length or when
-- the declared length runs past the end of the subject.
local ANY = Cmt(P(1), function (s, pos)
  local n, first = assert(unpacklength(s, pos))
  local last = first + n - 1

  -- string.sub would silently clip a truncated value and the position
  -- returned below would lie outside the subject
  if last > #s then
    error(sformat("truncated DER value (declared %d octets, %d available)", n, #s - first + 1))
  end

  return last + 1, s:sub(first, last)
end)
-- INTEGER (tag 0x02): captures the raw content octets.
local INTEGER = Cinner(P"\x02")

-- BIT STRING (tag 0x03): captures the content without its first octet,
-- which holds the number of padding bits and must be zero.
local BIT_STRING = Cinner(P"\x03", function (content)
  if content:byte(1) ~= 0 then
    -- we only support DER; level 0 keeps the message free of position info
    error("BIT STRING not octet aligned", 0)
  end

  return content:sub(2)
end)

-- OCTET STRING (tag 0x04): captures the raw content octets.
local OCTET_STRING = Cinner(P"\x04")

-- NULL (tag 0x05, length 0): matches the two octets, captures nothing.
local NULL = P"\x05\x00"
-- OBJECT_IDENTIFIER([oid])
--
-- Without an argument: matches any OBJECT IDENTIFIER TLV (tag 0x06) and
-- captures its dotted-decimal text. With a name or dotted OID: matches
-- exactly the encoding of that OID and captures the argument as given.
local function OBJECT_IDENTIFIER(oid)
  if not oid then
    return Cinner(P"\x06", function (s)
      return assert(unpackoid(s))
    end)
  end

  -- the complete TLV is known in advance, so match it as a literal string
  local content = packoid(pkix.txt2oid(oid))
  return P("\x06" .. packlength(#content) .. content) * Cc(oid)
end

local OID = OBJECT_IDENTIFIER
-- SEQUENCE(patt): TLV with tag 0x30 whose content is handled by patt (see
-- Cinner for the accepted kinds of patt).
local function SEQUENCE(patt)
  return Cinner(P"\x30", patt)
end

-- SET_OF(patt): TLV with tag 0x31 whose content is zero or more
-- repetitions of patt; their captures are collected into one table.
local function SET_OF(patt)
  local members = Ct(patt^0)
  return Cinner(P"\x31", members)
end
-- UTCTime (tag 0x17): validates the content and captures it unchanged.
-- Raises an error when the value is malformed, out of range, or uses hour
-- 24 with non-zero minutes or seconds.
local UTCTIME = Cinner(P"\x17", function (s)
  -- X.690 requires
  -- 1) second precision
  -- 2) no local time differential (always UTC)
  local fields = "^(%d%d)(%d%d)(%d%d)(%d%d)(%d%d)(%d%d)Z$"
  local YY, MM, DD, hh, mm, ss = map(tonumber, s:match(fields))

  if not YY then
    error(sformat("malformed UTCTIME (%s)", pkix.viz(s)))
  end

  local inrange = MM >= 1 and MM <= 12
    and DD >= 1 and DD <= 31
    and hh <= 24 and mm <= 59 and ss <= 59

  if not inrange then
    error(sformat("out of range UTCTIME value (%s)", pkix.viz(s)))
  end

  -- hour 24 is only tolerated as exactly 24:00:00
  if hh == 24 and (mm ~= 0 or ss ~= 0) then
    error(sformat("incorrect UTCTIME midnight representation (%s)", pkix.viz(s)))
  end

  return s
end)
-- Version: a TLV tagged 0xA0 whose entire content is one INTEGER; the
-- capture is that INTEGER's raw content octets.
local Version = Cinner("\xA0", INTEGER)
-- The serial number is captured as the raw content octets of an INTEGER.
local CertificateSerialNumber = INTEGER
-- algorithm OIDs
-- Each pattern matches exactly the encoding of the named OID and captures
-- the name passed to OID().
local rsaEncryption = OID"rsaEncryption"
local sha256WithRSAEncryption = OID"sha256WithRSAEncryption"
local sha384WithRSAEncryption = OID"sha384WithRSAEncryption"
local sha512WithRSAEncryption = OID"sha512WithRSAEncryption"
local id_ecPublicKey = OID"id-ecPublicKey"
local secp256r1 = OID"secp256r1"
local secp384r1 = OID"secp384r1"
local secp521r1 = OID"secp521r1"
local id_X25519 = OID"id-X25519"
local id_X448 = OID"id-X448"
local id_Ed25519 = OID"id-Ed25519"
local id_Ed448 = OID"id-Ed448"
-- RSAPublicKey: SEQUENCE of two INTEGERs, captured as a table with the
-- fields modulus and publicExponent (raw content octets).
local RSAPublicKey = SEQUENCE(Ct(Cg(INTEGER, "modulus") * Cg(INTEGER, "publicExponent")))
-- Only these three named curves are accepted; secp192r1 is in the OID
-- table but is not part of this choice.
local namedCurve = secp256r1 + secp384r1 + secp521r1
-- ECParameters: a named group "namedCurve" holding the matched curve name.
local ECParameters = Cg(namedCurve, "namedCurve")
do
  -- pkix.P_AlgorithmIdentifier(algorithm[, parameters])
  --
  -- Returns a pattern for a SEQUENCE consisting of algorithm optionally
  -- followed by parameters. The capture is a table with the field
  -- "algorithm" and, when parameters is given, a field "parameters"
  -- holding a table of whatever parameters captured.
  function pkix.P_AlgorithmIdentifier(algorithm, parameters)
    local params = P""
    if parameters then
      params = Cg(Ct(parameters), "parameters")
    end

    return SEQUENCE(Ct(Cg(algorithm, "algorithm") * params))
  end

  -- RSA keys and RSA signatures: parameters must be NULL
  local rsa_algorithms =
    rsaEncryption +
    sha256WithRSAEncryption +
    sha384WithRSAEncryption +
    sha512WithRSAEncryption
  local rsa = pkix.P_AlgorithmIdentifier(rsa_algorithms, NULL)

  -- EC keys: parameters name the curve
  local elliptic = pkix.P_AlgorithmIdentifier(id_ecPublicKey, ECParameters)

  -- Edwards/Montgomery curves: no parameters
  local edwards_algorithms = id_X25519 + id_X448 + id_Ed25519 + id_Ed448
  local edwards = pkix.P_AlgorithmIdentifier(edwards_algorithms)

  -- Ordered choice of the known forms; the last alternative accepts any
  -- other SEQUENCE and captures each of its content octets separately.
  pkix.AlgorithmIdentifier = rsa + elliptic + edwards + SEQUENCE(C(P(1))^0)
end
-- The subjectPublicKey BIT STRING of an RSA key wraps a DER RSAPublicKey.
-- Parse it at match time and capture the resulting table; the match fails
-- when the wrapped data is not an RSAPublicKey.
local subjectPublicKey_RSAPublicKey = Cmt(BIT_STRING, function (_, pos, bits)
  local key = lpeg.match(RSAPublicKey, bits)

  if not key then
    return false
  end

  return pos, key
end)

-- RSA: algorithm with NULL parameters, key decoded into modulus/publicExponent
local SubjectPublicKeyInfo_rsaEncryption = SEQUENCE(Ct(
  Cg(pkix.P_AlgorithmIdentifier(rsaEncryption, NULL), "algorithm") *
  Cg(subjectPublicKey_RSAPublicKey, "subjectPublicKey")
))

-- EC: algorithm parameters name the curve, key kept as raw octets
local SubjectPublicKeyInfo_id_ecPublicKey = SEQUENCE(Ct(
  Cg(pkix.P_AlgorithmIdentifier(id_ecPublicKey, ECParameters), "algorithm") *
  Cg(BIT_STRING, "subjectPublicKey")
))

-- X25519/X448/Ed25519/Ed448: no parameters, key kept as raw octets
local SubjectPublicKeyInfo_edwardsCurve = SEQUENCE(Ct(
  Cg(pkix.P_AlgorithmIdentifier(id_X25519 + id_X448 + id_Ed25519 + id_Ed448), "algorithm") *
  Cg(BIT_STRING, "subjectPublicKey")
))

-- Ordered choice over the supported key types.
pkix.SubjectPublicKeyInfo =
  SubjectPublicKeyInfo_rsaEncryption +
  SubjectPublicKeyInfo_id_ecPublicKey +
  SubjectPublicKeyInfo_edwardsCurve
-- NOTE: RDNSequence is an ordered list of RelativeDistinguishedName (RDN),
-- which is itself an unordered list of AttributeTypeAndValue (AV) pairs.
-- In practice an RDN list contains only a single AV pair, and when
-- visualized the RDN list is elided, so the extra tree depth is not
-- obvious. This is quite unintuitive.
--
-- The attribute type is captured as its short or long name when the OID is
-- in the name table, otherwise as the dotted OID (see pkix.oid2txt).
local AttributeType = OBJECT_IDENTIFIER() / pkix.oid2txt
-- The value may be any TLV; its raw content octets are captured and the tag
-- is discarded.
local AttributeValue = ANY
-- Captured as a table with the fields "type" and "value".
local AttributeTypeAndValue = SEQUENCE(Ct(Cg(AttributeType, "type") * Cg(AttributeValue, "value")))
-- Captured as an array of AttributeTypeAndValue tables.
local RelativeDistinguishedName = SET_OF(AttributeTypeAndValue)
-- Captured as an array of RelativeDistinguishedName arrays.
local RDNSequence = SEQUENCE(Ct(RelativeDistinguishedName^0))
-- Captured as a table with the single field "rdnSequence".
local Name = Ct(Cg(RDNSequence, "rdnSequence"))
-- Validity: SEQUENCE of notBefore and notAfter, captured as a table with
-- those two fields holding the time strings as encoded.
--
-- Each time is either a UTCTime or a GeneralizedTime: RFC 5280 (4.1.2.5)
-- requires UTCTime for dates through 2049 and GeneralizedTime for 2050 and
-- later, so accepting UTCTime alone rejects long-lived certificates.
local Validity
do
  -- GeneralizedTime (tag 0x18) in the form YYYYMMDDHHMMSSZ; validated
  -- with the same range rules as UTCTIME and captured unchanged.
  local GENERALIZEDTIME = Cinner(P"\x18", function (s)
    local YYYY, MM, DD, hh, mm, ss = map(tonumber, s:match"^(%d%d%d%d)(%d%d)(%d%d)(%d%d)(%d%d)(%d%d)Z$")

    if not YYYY then
      error(sformat("malformed GeneralizedTime (%s)", pkix.viz(s)))
    elseif MM == 0 or MM > 12 or DD == 0 or DD > 31 or hh > 24 or mm > 59 or ss > 59 then
      error(sformat("out of range GeneralizedTime value (%s)", pkix.viz(s)))
    elseif hh == 24 and not (mm == 0 and ss == 0) then
      error(sformat("incorrect GeneralizedTime midnight representation (%s)", pkix.viz(s)))
    end

    return s
  end)

  -- UTCTIME is tried first, so certificates that parsed before still
  -- produce identical captures
  local Time = UTCTIME + GENERALIZEDTIME

  Validity = SEQUENCE(Ct(
    Cg(Time, "notBefore") *
    Cg(Time, "notAfter")
  ))
end
-- TBSCertificate: captured as a table with one field per component.
--
-- version is declared "[0] EXPLICIT Version DEFAULT v1" and DER omits
-- components that have their default value, so a v1 certificate carries no
-- version TLV at all. When it is absent the field is filled with "\x00",
-- the INTEGER content octets for v1, so the table keeps the same shape for
-- all versions.
--
-- Everything after subjectPublicKeyInfo (unique identifiers, extensions)
-- is not parsed yet and ends up as raw octets in "trash".
local TBSCertificate = SEQUENCE(Ct(
  Cg(Version + Cc"\x00", "version") *
  Cg(CertificateSerialNumber, "serialNumber") *
  Cg(pkix.AlgorithmIdentifier, "signature") *
  Cg(Name, "issuer") *
  Cg(Validity, "validity") *
  Cg(Name, "subject") *
  Cg(pkix.SubjectPublicKeyInfo, "subjectPublicKeyInfo") *
  Cg(P(1)^0, "trash")
))
-- The signature value is captured as the BIT STRING content without its
-- leading padding-count octet.
local Signature = BIT_STRING
-- pkix.Certificate: top-level pattern. Use lpeg.match(pkix.Certificate, der)
-- on a DER string; the capture is a table with the fields tbsCertificate,
-- signatureAlgorithm and signature.
pkix.Certificate = SEQUENCE(Ct(Cg(TBSCertificate, "tbsCertificate") * Cg(pkix.AlgorithmIdentifier, "signatureAlgorithm") * Cg(Signature, "signature")))
-- module table: hexdump, dump, viz, txt2oid, oid2txt and the exported patterns
return pkix