lua-users home
lua-l archive

[Date Prev][Date Next][Thread Prev][Thread Next] [Date Index] [Thread Index]


On Thu, Apr 04, 2019 at 01:46:14AM +0400, joy mondal wrote:
> Just a side question.
> 
> Under what design consideration do we use 'Cmt(patt,f) ' vs 'patt/f' ?
> 

I meant to reply when you first posed the question but couldn't get around
to it.

The following is a basic PKIX X.509 certificate parser. I've refactored my
working code into a single file. It's a work in progress and not fully
generalized. The most obvious hole is that it doesn't parse extensions, yet.

X.509 certificates are usually encoded using ASN.1 DER (technically, CER).
DER is a type of TLV (tag-length-value) encoding. To parse symbols you need
to decode the length to know the extent of the symbol--the tag itself
doesn't tell you how long it will be, and the length can be arbitrarily
large. Formally I think such encodings are a type of context-sensitive
grammar. lpeg.Cmt is used below to match the initial tag and jump to regular
Lua code. The regular Lua code can then unpack the length, consume the
substring (possibly recursively invoking LPeg again), and tell LPeg how far
ahead to jump to resume parsing. It would be impractical if not impossible
to do this without Cmt (i.e. using pure PEGs, which can only handle
context-free grammars).

The LPeg patterns begin around line 570. You can use it like:

  local lpeg = require"lpeg"
  local pkix = require"pkix"
  local pem = require"pkix.pem"

  local raw = io.stdin:read"a"
  local der = pem.match(raw) or raw
  pkix.dump(io.stdout, lpeg.match(pkix.Certificate, der))

-- ==========================================================================
-- pkix.lua - Lua PKIX library
-- --------------------------------------------------------------------------
-- Copyright (c) 2018-2019  William Ahern
--
-- Permission is hereby granted, free of charge, to any person obtaining a
-- copy of this software and associated documentation files (the
-- "Software"), to deal in the Software without restriction, including
-- without limitation the rights to use, copy, modify, merge, publish,
-- distribute, sublicense, and/or sell copies of the Software, and to permit
-- persons to whom the Software is furnished to do so, subject to the
-- following conditions:
--
-- The above copyright notice and this permission notice shall be included
-- in all copies or substantial portions of the Software.
--
-- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
-- OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN
-- NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
-- DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
-- OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
-- USE OR OTHER DEALINGS IN THE SOFTWARE.
-- ==========================================================================
local lpeg = require"lpeg"
local pkix = {}

local sfind = string.find
local sformat = string.format
local spack = string.pack
local sunpack = string.unpack
local type = type
local P, S, R = lpeg.P, lpeg.S, lpeg.R
local C, Cc, Cf, Cg, Cmt, Cp, Ct = lpeg.C, lpeg.Cc, lpeg.Cf, lpeg.Cg, lpeg.Cmt, lpeg.Cp, lpeg.Ct

package.preload["pkix.base64"] = function ()
	-- Base64 codec (standard alphabet, "=" padding) used by the PEM reader.
	-- The module table exposes encode() and decode().
	local base64 = {}

	local byte = string.byte
	local char = string.char
	local format = string.format

	-- for mapping 6-bit integer value to radix character
	local charmap = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"

	-- for mapping radix character ordinal value to 6-bit integer
	local bytemap = {}
	for i=1,#charmap do
		bytemap[byte(charmap, i)] = i - 1
	end

	-- radix character for the low 6 bits of v
	local function digit(v)
		local k = (v & 0x3f) + 1
		return charmap:sub(k, k)
	end

	-- Encode a string of octets as base64 text.
	--
	-- data:  string to encode
	-- width: optional line width. When it is a positive number the output
	--        is split into rows of that many characters joined with "\n"
	--        (no trailing newline). nil or a non-positive width disables
	--        wrapping.
	--
	-- Returns the encoded text.
	function base64.encode(data, width)
		local text = data:gsub("..?.?", function (s)
			local x, y, z = byte(s, 1, 3)
			-- left-align the 1 to 3 input octets in a 24-bit group
			local i = (x << 16) | ((y or 0) << 8) | (z or 0)

			return digit(i >> 18) .. digit(i >> 12) ..
				(y and digit(i >> 6) or "=") ..
				(z and digit(i) or "=")
		end)

		width = tonumber(width)
		if width and width > 0 then
			local rows = {}
			for i=1,#text,width do
				rows[#rows + 1] = text:sub(i, i + width - 1)
			end
			text = table.concat(rows, "\n")
		end

		return text
	end

	-- Decode base64 text into a string of octets.
	--
	-- Characters outside the base64 alphabet (line breaks, spaces, ...) are
	-- discarded before decoding, so wrapped output from encode() decodes
	-- cleanly. "=" padding is kept so that it still closes a 4-character
	-- group.
	--
	-- Raises an error when a group carries an impossible number of bits
	-- (for example a single dangling character).
	function base64.decode(text)
		local clean = text:gsub("[^A-Za-z0-9%+/=]", "")

		return (clean:gsub("..?.?.?", function (s)
			local i, n = 0, 0

			for k=1,#s do
				local b = bytemap[byte(s, k)]
				if b then -- "=" contributes no bits
					i = (i << 6) | b
					n = n + 6
				end
			end

			if n == 24 then
				return char((i >> 16) & 0xff, (i >> 8) & 0xff, i & 0xff)
			elseif n == 18 then
				return char((i >> 10) & 0xff, (i >> 2) & 0xff)
			elseif n == 12 then
				return char((i >> 4) & 0xff)
			else
				error(format("invalid # of bits (%d)", n))
			end
		end))
	end

	return base64
end -- pkix.base64

package.preload["pkix.pem"] = function ()
	-- PEM (textual "-----BEGIN x-----" / "-----END x-----" armor) reader.
	local pem = {}

	local BEGIN_LINE = "^%s*-----BEGIN ([A-Z ]*)-----%s*$"
	local END_LINE = "^%s*-----END ([A-Z ]*)-----%s*$"
	local NOT_BASE64 = "[^ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789%+%/%=]+"

	-- Iterate over the lines of s starting at index init (default 1).
	-- Each step yields the line text (without its "\n") and the index of
	-- the first character after the line.
	local function lines(s, init)
		local offset = init or 1
		return function ()
			local text, after = string.match(s, "([^\n]*)\n?()", offset)
			if not text then
				return
			end
			after = tonumber(after)
			if after <= offset then
				return -- no progress: end of input
			end
			offset = after
			return text, after
		end
	end

	-- Return got when the Lua pattern expected (default: anything) matches
	-- the whole of got; otherwise nil.
	local function testlabel(got, expected)
		if not got then return end
		local first, last = got:find(expected or ".*")
		if first == 1 and last == #got then
			return got
		end
		return nil
	end

	-- core PEM parser
	--
	-- Scans s from init (default 1) for the first block whose BEGIN label
	-- fully matches the Lua pattern label (default: any label).
	--
	-- Returns the index of the start of the BEGIN line, the index just past
	-- the END line, the base64-decoded payload and the BEGIN label; returns
	-- nothing when no complete block is found.
	--
	-- TODO: Maybe add option for parsing headers. Note that RFC 7468 prohibits
	-- generation of encapsulated headers.
	function pem.find(s, label, init)
		local cursor = init or 1
		local beginlabel, beginpos

		for ln, pos in lines(s, cursor) do
			local linestart = cursor
			cursor = pos
			beginlabel = testlabel(ln:match(BEGIN_LINE), label)
			if beginlabel then
				beginpos = linestart
				break
			end
		end

		if not beginlabel then return end

		local endlabel, endpos
		local data = {}

		for ln, pos in lines(s, cursor) do
			endpos = pos
			endlabel = ln:match(END_LINE)
			if endlabel then break end
			-- keep only base64 alphabet and padding characters
			local payload = ln:gsub(NOT_BASE64, "")
			data[#data + 1] = payload
		end

		if not endlabel then return end

		local decoded = require"pkix.base64".decode(table.concat(data))
		return beginpos, endpos, decoded, beginlabel
	end

	-- Iterator over every PEM block in s whose label matches; each step
	-- yields the decoded payload and the block's BEGIN label.
	function pem.gmatch(s, label)
		local cursor = 1
		return function ()
			local _, stop, data, beginlabel = pem.find(s, label, cursor)
			if data then
				cursor = stop
				return data, beginlabel
			end
		end
	end

	-- Like pem.find, but returns only the decoded payload and BEGIN label.
	function pem.match(s, label, init)
		local start, _, data, beginlabel = pem.find(s, label, init)
		if start then
			return data, beginlabel
		end
	end

	return pem
end

-- pkix.hexdump
do
	local rep = string.rep
	local concat = table.concat

	-- Build the format templates for a row of n (1..16) octets: x is the
	-- hex column (two groups of up to 8 octets, padded to full width) and
	-- c is the character column (padded to 16 characters).
	local function newfmt(n)
		local left = math.min(n, 8)
		local right = math.min(math.max(n - 8, 0), 8)

		local hex = concat({
			string.sub(rep(" %.2x", left), 2),
			rep("   ", 8 - left),
			" ",
			rep(" %.2x", right),
			rep("   ", 8 - right),
		})

		local chars = concat({
			rep("%c", left),
			rep(" ", 8 - left),
			rep("%c", right),
			rep(" ", 8 - right),
		})

		return { x = hex, c = chars }
	end

	-- one template pair per possible row length
	local fmt = {}
	for width = 1, 16 do
		fmt[width] = newfmt(width)
	end

	-- Map octet values to displayable ones: values strictly between 32 and
	-- 127 are kept, everything else becomes 46 (".").
	local function glyph(...)
		local codes = { ... }
		for k = 1, #codes do
			local ch = codes[k]
			if not (ch > 32 and ch < 127) then
				codes[k] = 46 -- .
			end
		end
		return table.unpack(codes)
	end

	-- Render tostring(v) as a hex dump, 16 octets per line.
	--
	-- how selects the columns (default "axc"): "a" offset, "x" hex octets,
	-- "c" characters between "|" bars.
	--
	-- Returns the dump as a single string, one "\n"-terminated line per row.
	function pkix.hexdump(v, how)
		local s = tostring(v)
		local out = {}

		how = how or "axc"
		local show_offset = string.find(how, "a")
		local show_hex = string.find(how, "x")
		local show_chars = string.find(how, "c")

		for offset = 0, #s - 1, 16 do
			local from, to = offset + 1, offset + 16
			local row = fmt[math.min(#s - offset, 16)]

			if show_offset then
				out[#out + 1] = sformat("%.4x  ", offset)
			end
			if show_hex then
				out[#out + 1] = sformat(row.x, string.byte(s, from, to))
				if show_chars then
					out[#out + 1] = "  "
				end
			end
			if show_chars then
				out[#out + 1] = "|"
				out[#out + 1] = sformat(row.c, glyph(string.byte(s, from, to)))
				out[#out + 1] = "|"
			end
			out[#out + 1] = "\n"
		end

		return concat(out)
	end
end

-- pkix.dump
do
	-- Strict ordering usable on keys of mixed types. Keys of different
	-- types are ordered by type name; numbers and strings compare
	-- natively; anything else compares by its tostring() form. A plain
	-- table.sort(keys) raises an error as soon as a table mixes, say,
	-- integer and string keys.
	local function keyorder(a, b)
		local ta, tb = type(a), type(b)
		if ta ~= tb then
			return ta < tb
		elseif ta == "number" or ta == "string" then
			return a < b
		else
			return tostring(a) < tostring(b)
		end
	end

	-- Like pairs(), but visits the keys in sorted order.
	local function spairs(t)
		local keys = {}
		for k in pairs(t) do
			keys[#keys + 1] = k
		end
		table.sort(keys, keyorder)

		local i = 0
		return function ()
			if i < #keys then
				i = i + 1
				local k = keys[i]
				return k, t[k]
			end
		end
	end

	-- Recursive worker: tables print each key on its own line followed by
	-- the value one level deeper; other values print as text, or as a hex
	-- dump when they contain characters outside "%g" and space.
	local function dump(fh, v, depth, tab)
		local tabs = string.rep(tab, depth)

		if type(v) == "table" then
			for key, value in spairs(v) do
				-- fh:write only accepts strings and numbers
				local label = key
				if type(key) ~= "number" and type(key) ~= "string" then
					label = tostring(key)
				end
				fh:write(tabs, label, "\n")
				dump(fh, value, depth + 1, tab)
			end
		else
			local s = tostring(v)
			if s:find("[^%g ]") then
				s = pkix.hexdump(s)
			end
			-- the gsub call is not last in the list, so only the string is passed
			fh:write(tabs, s:gsub("\n", "\n" .. tabs), "\n")
		end
	end

	-- Write an indented, key-sorted rendering of v to the file handle fh.
	--
	-- depth: initial indentation level (default 0)
	-- tab:   string used for one indentation level (default two spaces)
	function pkix.dump(fh, v, depth, tab)
		return dump(fh, v, depth or 0, tab or "  ")
	end
end

-- pkix.viz
do
	-- replacement table: every octet -> three-digit decimal escape ("\ddd")
	local dec = {}
	for i=0,255 do
		dec[string.char(i)] = sformat("\\%.3d", i)
	end

	-- Return tostring(s) with every character other than "%g" characters
	-- and space replaced by its "\ddd" decimal escape.
	function pkix.viz(s)
		-- The parentheses drop gsub's second result (the substitution
		-- count) so that callers receive exactly one value.
		return (tostring(s):gsub("[^%g ]", dec))
	end
end

-- Count the leading zero bits of the low 8 bits of i (8 when they are all
-- zero). Bits above the low octet are ignored.
local function clz8(i)
	local octet = i & 0xFF
	if octet == 0 then
		return 8
	end

	-- walk a single-bit probe down from the top bit until it hits a set bit
	local count = 0
	local probe = 0x80
	while (octet & probe) == 0 do
		count = count + 1
		probe = probe >> 1
	end

	return count
end

-- Count the leading zero bits of the low 16 bits of i.
local function clz16(i)
	local high = i & 0xFF00
	if high ~= 0 then
		return clz8(i >> 8)
	end
	-- upper octet is empty: 8 zeros plus whatever the lower octet has
	return 8 + clz8(i)
end

-- Count the leading zero bits of the low 32 bits of i.
local function clz32(i)
	local high = i & 0xFFFF0000
	if high ~= 0 then
		return clz16(i >> 16)
	end
	-- upper half is empty: 16 zeros plus whatever the lower half has
	return 16 + clz16(i)
end

-- Count the leading zero bits of the 64-bit integer i.
local function clz64(i)
	local high = i & 0xFFFFFFFF00000000
	if high ~= 0 then
		return clz32(i >> 32)
	end
	-- upper half is empty: 32 zeros plus whatever the lower half has
	return 32 + clz32(i)
end

-- Count the leading zero bits of a big-endian octet string s. An empty or
-- all-zero string yields 8 * #s.
local function clzbn(s)
	local first = sfind(s, "[^\x00]")
	if first then
		-- whole zero octets before the first non-zero one, plus its own
		-- leading zeros
		return 8 * (first - 1) + clz8(s:byte(first))
	end
	return #s * 8
end

-- Return t[k] when it is a table; otherwise store a fresh table at t[k]
-- (with rawset, replacing any non-table value) and return that.
local function getsubtable(t, k)
	local sub = t[k]
	if type(sub) ~= "table" then
		sub = {}
		rawset(t, k, sub)
	end
	return sub
end

-- Apply f to each of the first n values after n and return the results in
-- order. Returns nothing when n is less than 1.
local function mapn(f, n, x, ...)
	if n < 1 then
		return
	end
	if n == 1 then
		return f(x)
	end
	return f(x), mapn(f, n - 1, ...)
end

-- Apply f to every argument (nils included) and return the results.
local function map(f, ...)
	local count = select("#", ...)
	return mapn(f, count, ...)
end

-- Encode n as a DER definite-length field.
--
-- Lengths up to 127 use the one-octet short form. Larger lengths use the
-- long form: one octet holding 0x80 | (number of length octets), followed
-- by the length big-endian in the fewest octets possible.
--
-- Lengths up to 2^32-1 are supported, the same limit unpacklength applies
-- when decoding. Raises an error for a negative or larger length.
local function packlength(n)
	if n < 0 then
		error("length is negative")
	elseif n <= 0x7f then
		return string.char(n)
	elseif n > 0xffffffff then
		error("length too large")
	end

	-- collect base-256 digits, most significant first
	local octets = {}
	repeat
		table.insert(octets, 1, n & 0xff)
		n = n >> 8
	until n == 0

	return string.char(0x80 | #octets, table.unpack(octets))
end

-- Decode the DER length field of s that starts at index pos.
--
-- Returns the length and the index of the first content octet. On
-- malformed input returns nil plus a message (callers wrap the call in
-- assert). Rejected inputs:
--   * no octet at pos, or fewer length octets than announced
--   * the indefinite form (0x80)
--   * lengths needing more than 32 bits
--   * non-minimal encodings (leading zero octets, or the long form used
--     for a length below 128)
local function unpacklength(s, pos)
	local n = s:byte(pos)
	if not n then
		return nil, "truncated DER encoding (missing length octet)"
	end
	pos = pos + 1

	if 0x80 == (n & 0x80) then
		-- long form: the low 7 bits give the number of length octets
		local digits = n & 0x7F
		if digits == 0 then
			return nil, "indefinite-length encoding not permitted in DER"
		end

		local octets = s:sub(pos, pos + digits - 1)
		if #octets < digits then
			return nil, sformat("truncated DER length (got %d octets, expected %d)", #octets, digits)
		end

		local bits = digits * 8
		local nlz = clzbn(octets)
		if bits - nlz > 32 then
			return nil, sformat("DER definite-length value too large (got length gte 2^%d, expected lte 2^32-1)", bits - nlz)
		elseif nlz >= 8 then
			return nil, sformat("overlong DER length encoding (got %d octets, expected %d)", digits, (bits - nlz + 7) // 8)
		end

		n, pos = sunpack(sformat(">I%d", digits), s, pos)
		if n < 128 then
			return nil, sformat("overlong DER length encoding (length %d encoded long form, expected short form)", n)
		end
	end

	return n, pos
end

-- Encode a dotted-decimal OID string as the content octets of a DER
-- OBJECT IDENTIFIER.
--
-- The first two arcs are folded into one sub-identifier (40 * first +
-- second); every sub-identifier, including that first one, is written
-- base-128, most significant group first, with the high bit set on all
-- but its last octet. (Writing the folded value as a single raw octet is
-- wrong whenever it exceeds 127, e.g. for "2.999".)
--
-- Raises an error when the OID has fewer than two arcs, when the first
-- arc is greater than 2, or when the second arc is greater than 39 under
-- a first arc of 0 or 1.
local function packoid(oid)
	local arcs = {}
	for v in oid:gmatch("%d+") do
		arcs[#arcs + 1] = tonumber(v)
	end

	if #arcs < 2 then
		error(string.format("OID needs at least two arcs (%s)", oid))
	end
	local first, second = arcs[1], arcs[2]
	if first > 2 or (first < 2 and second > 39) then
		error(string.format("OID arcs out of range (%s)", oid))
	end

	local octets = {}

	-- append n in base-128
	local function push(n)
		local at = #octets + 1
		-- Each insertion at the same index ends up in front of the
		-- previous one, so inserting LSB first leaves MSB first.
		table.insert(octets, at, n & 0x7f)
		n = n >> 7
		while n > 0 do
			table.insert(octets, at, 0x80 | (n & 0x7f))
			n = n >> 7
		end
	end

	push(40 * first + second)
	for k=3,#arcs do
		push(arcs[k])
	end

	return string.char(table.unpack(octets))
end

-- Decode the content octets of a DER OBJECT IDENTIFIER into a
-- dotted-decimal string.
--
-- Every sub-identifier is base-128 with the high bit marking
-- continuation. The first sub-identifier carries the first two arcs as
-- 40 * first + second, where first is at most 2; it may itself span
-- several octets (e.g. "2.999"), so it is decoded like any other before
-- being split.
--
-- Returns the string, or nil plus a message when s is empty or ends in the
-- middle of a sub-identifier (callers wrap the call in assert).
local function unpackoid(s)
	if #s == 0 then
		return nil, "empty OBJECT IDENTIFIER"
	end

	local subids = {}
	local n = 0
	local partial = false
	for i=1,#s do
		local b = s:byte(i)

		n = (n << 7) | (0x7f & b)

		if 0x80 ~= (b & 0x80) then
			subids[#subids + 1] = n
			n = 0
			partial = false
		else
			partial = true
		end
	end

	if partial then
		return nil, "truncated OBJECT IDENTIFIER (incomplete sub-identifier)"
	end

	-- split the first sub-identifier into the two leading arcs
	local head = subids[1]
	local value
	if head < 40 then
		value = { 0, head }
	elseif head < 80 then
		value = { 1, head - 40 }
	else
		value = { 2, head - 80 }
	end

	for i=2,#subids do
		value[#value + 1] = subids[i]
	end

	return table.concat(value, ".")
end

-- pkix.oid2txt,  pkix.txt2oid
--
-- OID short and long name mapping
--
-- The table below holds one "name value" pair per line. A value is either
-- a dotted-decimal OID or the name of an earlier entry, in which case the
-- new name becomes an alias for that entry's OID.
do
	local namemap = {} -- name -> dotted-decimal OID
	local oidmap = {}  -- dotted-decimal OID -> list of names, in definition order

	for name,oid in string.gmatch([[
		commonName              2.5.4.3
		countryName             2.5.4.6
		distinguishedName       2.5.4.49
		id-ecPublicKey          1.2.840.10045.2.1
		id-X25519               1.3.101.110
		id-X448                 1.3.101.111
		id-Ed25519              1.3.101.112
		id-Ed448                1.3.101.113
		localityName            2.5.4.7
		organizationName        2.5.4.10
		organizationalUnitName  2.5.4.11
	        rsaEncryption           1.2.840.113549.1.1.1
		secp192r1               1.2.840.10045.3.1.1
	        secp256r1               1.2.840.10045.3.1.7
	        secp384r1               1.3.132.0.34
	        secp521r1               1.3.132.0.35
		sha256WithRSAEncryption 1.2.840.113549.1.1.11
		sha384WithRSAEncryption 1.2.840.113549.1.1.12
		sha512WithRSAEncryption 1.2.840.113549.1.1.13
		stateOrProvinceName     2.5.4.8

		C                       countryName
		CN                      commonName
		L                       localityName
		O                       organizationName
		OU                      organizationalUnitName
		ST                      stateOrProvinceName

		prime192v1              secp192r1
		prime256v1              secp256r1
	]], "[ \t]*([^ \t\n]+)[ \t]+([^ \t\n]+)[ \t]*\n") do
		if not oid:match"^%d[%d%.]*$" then
			-- the value is a name: resolve the alias to its OID
			local primaryname = oid
			oid = namemap[primaryname] or error(sformat("alias %s defined before primary name %s", name, primaryname))
		end

		namemap[name] = oid
		local t = getsubtable(oidmap, oid)
		t[#t + 1] = name
	end

	-- True when oid is a string in dotted-decimal form: digits at both
	-- ends, at least one dot, and no empty component.
	--
	-- NOTE: an OID will always have at least two integer components
	local function isliteral(oid)
		-- The ".." search is a plain (non-pattern) find, so the dots must
		-- not be escaped: with plain=true an escaped "%.%." would look for
		-- those four literal characters and never detect an empty arc.
		return type(oid) == "string"
			and oid:find("^%d+%.[%d%.]*%d$") ~= nil
			and not oid:find("..", 1, true)
	end

	-- Translate a known name (or alias) to its dotted-decimal OID. A
	-- string that already is a dotted-decimal OID is returned unchanged.
	-- Raises an error for anything else, including nil and non-strings.
	function pkix.txt2oid(name)
		local oid = namemap[name]
		if oid then
			return oid
		elseif isliteral(name) then
			return name
		else
			error(sformat("unknown OID name (%s)", tostring(name or "?")))
		end
	end

	-- Translate a dotted-decimal OID to the first name defined for it in
	-- the table above; unknown OIDs are returned unchanged.
	function pkix.oid2txt(oid)
		return oidmap[oid] and oidmap[oid][1] or oid
	end
end

-- Build a pattern matching one tag-length-value element.
--
-- identifier: pattern (or string) matching the tag octet(s)
-- patt:       how to interpret the content octets:
--               * an LPeg pattern, which must match the whole content;
--                 its captures become the captures of the element
--               * a function called with the content string; its results
--                 become the captures
--               * nil, to capture the content string itself
--
-- The element fails to match when the inner pattern/function yields a
-- false first result, or when the declared length extends past the end
-- of the subject. Without that bounds check the match-time callback could
-- return a position beyond len(subject)+1, which LPeg rejects by raising
-- "invalid position returned by match-time capture".
--
-- A malformed length field raises an error (via assert).
local function Cinner(identifier, patt)
	local innermatch

	if lpeg.type(patt) then
		-- anchor once here rather than on every match
		local whole = patt * -P(1)
		innermatch = function (s)
			return lpeg.match(whole, s)
		end
	elseif type(patt) == "function" then
		innermatch = patt
	elseif patt == nil then
		innermatch = function (s)
			return s
		end
	else
		error(sformat("bad argument #2 (expected pattern, function, or nil, got %s)", type(patt)), 2)
	end

	-- turn the inner results into the Cmt protocol: new position plus
	-- captures on success, false on failure
	local function resume(pos, v, ...)
		if v then
			return pos, v, ...
		else
			return false
		end
	end

	return Cmt(identifier, function (s, pos)
		local n, first = assert(unpacklength(s, pos))
		local last = first + n - 1
		if last > #s then
			return false -- content is truncated
		end

		return resume(last + 1, innermatch(s:sub(first, last)))
	end)
end

-- ANY: match one element with any single-octet tag and capture its raw
-- content octets. The element fails to match when its declared length
-- extends past the end of the subject; otherwise the callback would return
-- a position outside the range LPeg accepts and raise an error. A
-- malformed length field raises an error (via assert).
local ANY = Cmt(P(1), function (s, pos)
	local n, first = assert(unpacklength(s, pos))
	local last = first + n - 1
	if last > #s then
		return false -- content is truncated
	end
	return last + 1, s:sub(first, last)
end)

-- INTEGER (tag 0x02): captures the raw content octets
local INTEGER = Cinner(P("\x02"))

-- BIT STRING (tag 0x03): captures the content after the leading
-- "unused bits" octet, which must be zero
local BIT_STRING = Cinner(P("\x03"), function (octets)
	local unused = octets:byte(1) -- first octet is number of padding bits
	assert(unused == 0, "BIT STRING not octet aligned") -- we only support DER
	return octets:sub(2)
end)

-- OCTET STRING (tag 0x04): captures the raw content octets
local OCTET_STRING = Cinner(P("\x04"))

-- NULL (tag 0x05, zero length): no capture
local NULL = P("\x05\x00")

-- OBJECT IDENTIFIER (tag 0x06).
--
-- With an argument (a known name or dotted-decimal OID) the pattern
-- matches exactly that OID and captures the argument as given. Without an
-- argument it matches any OID and captures it in dotted-decimal form.
local function OBJECT_IDENTIFIER(oid)
	if not oid then
		return Cinner(P("\x06"), function (octets)
			return assert(unpackoid(octets))
		end)
	end

	local encoded = packoid(pkix.txt2oid(oid))
	local element = sformat("\x06%s%s", packlength(#encoded), encoded)
	return P(element) * Cc(oid)
end
local OID = OBJECT_IDENTIFIER

-- SEQUENCE (tag 0x30) whose content is interpreted by patt (see Cinner).
local function SEQUENCE(patt)
	local tag = P("\x30")
	return Cinner(tag, patt)
end

-- SET OF (tag 0x31): zero or more repetitions of patt, with their captures
-- collected into one table.
local function SET_OF(patt)
	local tag = P("\x31")
	local members = Ct(patt^0)
	return Cinner(tag, members)
end

-- UTCTime (tag 0x17): captures the time string unchanged after checking
-- its shape (YYMMDDhhmmssZ) and field ranges. Raises an error when the
-- value is malformed.
local UTCTIME = Cinner(P("\x17"), function (s)
	-- X.690 requires
	--  1) second precision
	--  2) no local time differential (always UTC)
	local fields = "^(%d%d)(%d%d)(%d%d)(%d%d)(%d%d)(%d%d)Z$"
	local YY, MM, DD, hh, mm, ss = map(tonumber, s:match(fields))

	if not YY then
		error(sformat("malformed UTCTIME (%s)", pkix.viz(s)))
	end

	local inrange = MM >= 1 and MM <= 12
		and DD >= 1 and DD <= 31
		and hh <= 24 and mm <= 59 and ss <= 59
	if not inrange then
		error(sformat("out of range UTCTIME value (%s)", pkix.viz(s)))
	end

	-- hour 24 is only accepted as exactly 24:00:00
	if hh == 24 and (mm ~= 0 or ss ~= 0) then
		error(sformat("incorrect UTCTIME midnight representation (%s)", pkix.viz(s)))
	end

	return s
end)

-- Version: an INTEGER wrapped in an element with tag octet 0xA0
local Version = Cinner(P("\xA0"), INTEGER)
local CertificateSerialNumber = INTEGER

-- algorithm OIDs
local rsaEncryption = OID("rsaEncryption")
local sha256WithRSAEncryption = OID("sha256WithRSAEncryption")
local sha384WithRSAEncryption = OID("sha384WithRSAEncryption")
local sha512WithRSAEncryption = OID("sha512WithRSAEncryption")
local id_ecPublicKey = OID("id-ecPublicKey")
local secp256r1 = OID("secp256r1")
local secp384r1 = OID("secp384r1")
local secp521r1 = OID("secp521r1")
local id_X25519 = OID("id-X25519")
local id_X448 = OID("id-X448")
local id_Ed25519 = OID("id-Ed25519")
local id_Ed448 = OID("id-Ed448")

-- RSAPublicKey: SEQUENCE of two INTEGERs, captured as a table with the
-- fields "modulus" and "publicExponent"
local RSAPublicKey = SEQUENCE(Ct(
	Cg(INTEGER, "modulus") *
	Cg(INTEGER, "publicExponent")
))

-- curves accepted as elliptic-curve parameters
local namedCurve = secp256r1 + secp384r1 + secp521r1
local ECParameters = Cg(namedCurve, "namedCurve")

do
	-- Build a pattern for an AlgorithmIdentifier SEQUENCE.
	--
	-- algorithm:  pattern for the algorithm OID, captured as "algorithm"
	-- parameters: optional pattern for the parameters; its captures are
	--             collected into a table stored as "parameters"
	--
	-- The resulting pattern captures one table.
	function pkix.P_AlgorithmIdentifier(algorithm, parameters)
		local fields = Cg(algorithm, "algorithm")
		if parameters then
			fields = fields * Cg(Ct(parameters), "parameters")
		end
		return SEQUENCE(Ct(fields))
	end

	local AlgorithmIdentifier_rsa = pkix.P_AlgorithmIdentifier(
		(rsaEncryption +
		 sha256WithRSAEncryption +
		 sha384WithRSAEncryption +
		 sha512WithRSAEncryption),
		NULL
	)

	local AlgorithmIdentifier_ellipticCurve =
		pkix.P_AlgorithmIdentifier(id_ecPublicKey, ECParameters)

	local AlgorithmIdentifier_edwardsCurve =
		pkix.P_AlgorithmIdentifier(id_X25519 + id_X448 + id_Ed25519 + id_Ed448)

	-- Known algorithms are tried first. Any other SEQUENCE is accepted and
	-- its raw content captured as ONE string. The repetition has to sit
	-- inside the capture: C(P(1))^0 would produce a separate capture per
	-- octet (and none at all for empty content), so a named group around
	-- it would keep only the first octet.
	pkix.AlgorithmIdentifier =
		AlgorithmIdentifier_rsa +
		AlgorithmIdentifier_ellipticCurve +
		AlgorithmIdentifier_edwardsCurve +
		SEQUENCE(C(P(1)^0))
end

-- RSA subjectPublicKey: a BIT STRING whose content is itself a DER-encoded
-- RSAPublicKey; captures the parsed key table.
local subjectPublicKey_RSAPublicKey = Cmt(BIT_STRING, function (_, pos, bits)
	local key = lpeg.match(RSAPublicKey, bits)
	if not key then
		return false
	end
	return pos, key
end)

local SubjectPublicKeyInfo_rsaEncryption = SEQUENCE(Ct(
	Cg(pkix.P_AlgorithmIdentifier(rsaEncryption, NULL), "algorithm") *
	Cg(subjectPublicKey_RSAPublicKey, "subjectPublicKey")
))

local SubjectPublicKeyInfo_id_ecPublicKey = SEQUENCE(Ct(
	Cg(pkix.P_AlgorithmIdentifier(id_ecPublicKey, ECParameters), "algorithm") *
	Cg(BIT_STRING, "subjectPublicKey")
))

local SubjectPublicKeyInfo_edwardsCurve = SEQUENCE(Ct(
	Cg(pkix.P_AlgorithmIdentifier(id_X25519 + id_X448 + id_Ed25519 + id_Ed448), "algorithm") *
	Cg(BIT_STRING, "subjectPublicKey")
))

-- alternatives are tried in this order
pkix.SubjectPublicKeyInfo =
	SubjectPublicKeyInfo_rsaEncryption +
	SubjectPublicKeyInfo_id_ecPublicKey +
	SubjectPublicKeyInfo_edwardsCurve

-- NOTE: RDNSequence is an ordered list of RelativeDistinguishedNames (RDNs),
-- each of which is itself an unordered list of AttributeTypeAndValue (AV)
-- pairs. In practice an RDN contains only a single AV pair, and when
-- visualized the RDN level is elided, so the extra tree depth is not
-- obvious. This is quite unintuitive.
-- attribute type: any OID, captured as a name where one is known
local AttributeType = OBJECT_IDENTIFIER() / pkix.oid2txt
-- attribute value: raw content of whatever element follows
local AttributeValue = ANY
local AttributeTypeAndValue = SEQUENCE(Ct(
	Cg(AttributeType, "type") *
	Cg(AttributeValue, "value")
))
local RelativeDistinguishedName = SET_OF(AttributeTypeAndValue)
local RDNSequence = SEQUENCE(Ct(RelativeDistinguishedName^0))
-- Name captures a table with the single field "rdnSequence"
local Name = Ct(Cg(RDNSequence, "rdnSequence"))

-- Validity: SEQUENCE of two UTCTime values, captured as a table with the
-- fields "notBefore" and "notAfter"
local Validity
do
	local notBefore = Cg(UTCTIME, "notBefore")
	local notAfter = Cg(UTCTIME, "notAfter")
	Validity = SEQUENCE(Ct(notBefore * notAfter))
end

-- TBSCertificate: the signed portion of a certificate, captured as a table
-- keyed by field name.
--
-- "version" is optional: RFC 5280 declares it as
-- `[0] EXPLICIT Version DEFAULT v1`, so v1 certificates omit the element
-- entirely and the resulting table then has no "version" field.
--
-- Everything after subjectPublicKeyInfo (unique IDs, extensions) is not
-- parsed yet and is captured raw as "trash".
local TBSCertificate = SEQUENCE(Ct(
	Cg(Version, "version")^-1 *
	Cg(CertificateSerialNumber, "serialNumber") *
	Cg(pkix.AlgorithmIdentifier, "signature") *
	Cg(Name, "issuer") *
	Cg(Validity, "validity") *
	Cg(Name, "subject") *
	Cg(pkix.SubjectPublicKeyInfo, "subjectPublicKeyInfo") *
	Cg(P(1)^0, "trash")
))
-- the signature value is a BIT STRING
local Signature = BIT_STRING

-- Certificate: SEQUENCE of the to-be-signed body, the signature algorithm
-- and the signature value, captured as a table keyed by field name
pkix.Certificate = SEQUENCE(Ct(
	Cg(TBSCertificate, "tbsCertificate") *
	Cg(pkix.AlgorithmIdentifier, "signatureAlgorithm") *
	Cg(Signature, "signature")
))

return pkix