Symbolic Differentiation

This started out as a little exercise for the PenlightLibraries, but became sufficiently obsessive to warrant a more thorough implementation.

The first step in symbolic algebra is defining a representation. Getting expressions into a suitable form is actually straightforward; no parsing of expressions is needed, since we have Lua to do that for us. The pl.func library does all the hard work; it redefines the arithmetic operations to work on placeholder expressions (PEs), which are Lua expressions involving dummy variables called placeholders. pl.func defines standard placeholders for arguments called _1, _2, etc., but the Var function will create new ones of our choosing:

local utils = require 'pl.utils'
utils.import 'pl.func'
a,b,c,d = Var 'a,b,c,d'
print(a+b+c+d)

Which will indeed print out the expression in a readable form. PE operator expressions are stored as combinations of tables looking like {op='+',x,y}, which have an associated metatable that defines metamethods like __add and so forth. Because Lua's + is left-associative, a+b+c+d is stored as a left-leaning tree: the root + node holds the subtree for a+b+c as its first argument and d as its second.

Drawing such trees is tedious, so a better notation is Lisp-style S-expressions:

1: (+ (+ (+ a b) c) d)
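
To make the table representation concrete, here is a minimal sketch that does not use pl.func at all. The names Node, var and show are made up for this illustration; the real library is more elaborate, so treat this only as a model of the {op='+',x,y} idea.

```lua
-- Hypothetical stand-in for placeholder expressions; not the pl.func API.
Node = {}

function var(name)
  -- leaf nodes use op == 'X', as in the article
  return setmetatable({op = 'X', repr = name}, Node)
end

local function binop(op)
  return function(x, y)
    -- operands live in the array part, the operator in the 'op' field
    return setmetatable({op = op, x, y}, Node)
  end
end

Node.__add = binop('+')
Node.__mul = binop('*')

function show(e)
  if getmetatable(e) ~= Node then return tostring(e) end
  if e.op == 'X' then return e.repr end
  local parts = {}
  for i = 1, #e do parts[i] = show(e[i]) end
  return '(' .. e.op .. ' ' .. table.concat(parts, ' ') .. ')'
end

local a, b, c, d = var('a'), var('b'), var('c'), var('d')
print(show(a + b + c + d))  --> (+ (+ (+ a b) c) d)
```

Lua evaluates a+b first, so the metamethod is called three times and the nesting leans to the left.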

However, with the various manipulations we will perform, this canonical form is not the only possible representation of a+b+c+d:

2: (+ a (+ b (+ c d)))
3: (+ (+ a b) (+ c d))

Now, experience shows that this leads to madness. Instead, it's easier to go for the canonical Lisp representation:

4: (+ a b c d)

Many operations become straightforward once this is in place, for instance comparing with (+ a c b d) is just a matter of doing a 'compare with no order' on the arguments. Displaying PEs in this form is straightforward. isPE simply checks the expression to see if it is a placeholder expression, by looking at the metatable. PEs with op=='X' are placeholder variables, so the rest must be expression nodes.

function sexpr (e)
    if isPE(e) then
        if e.op ~= 'X' then
            local args = tablex.imap(sexpr,e)
            return '('..e.op..' '..table.concat(args,' ')..')'
        else
            return e.repr
        end
    else
        return tostring(e)
    end
end
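
The 'compare with no order' mentioned above can be sketched as a multiset comparison. This is only an illustration under a simplifying assumption: the arguments are plain strings or numbers, so they can be used directly as table keys. The helper name same_args is hypothetical.

```lua
-- Do two flat argument lists hold the same items, ignoring order?
-- Assumes items are strings or numbers (usable as table keys).
function same_args(xs, ys)
  if #xs ~= #ys then return false end
  local counts = {}
  for _, x in ipairs(xs) do
    counts[x] = (counts[x] or 0) + 1
  end
  for _, y in ipairs(ys) do
    local n = counts[y]
    if not n or n == 0 then return false end
    counts[y] = n - 1
  end
  return true
end

print(same_args({'a','b','c','d'}, {'a','c','b','d'}))  --> true
print(same_args({'a','b','b'}, {'a','a','b'}))          --> false
```

Comparing nested subexpressions would need a recursive equality test in place of the table lookup.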

The first task is to balance the expressions, which converts representations 1-3 into 4.

function balance (e)
	if isPE(e) and e.op ~= 'X' then
	  local op,args = e.op
	  if op == '+' or op == '*' then
		args = rcollect(e)
	  else
		args = imap(balance,e)
	  end
	  for i = 1,#args do
		e[i] = args[i]
	  end
	end
	return e
end

For the non-commutative operators, the idea is just to balance all the subexpressions by mapping balance over the array part of the PE, which is the argument list. These are then copied back in-place. The non-trivial part is dealing with + and *, where it is necessary to collect all the arguments from expression trees looking like 1,2 or 3 and convert them into the fourth form.

function tcollect (op,e,ls)
    if isPE(e) and e.op == op then
        for i = 1,#e do
            tcollect(op,e[i],ls)
        end
    else
        ls:append(e)
        return
    end
end

function rcollect (e)
    local res = List()
    tcollect(e.op,e,res)
    return res
end

This recursively goes down same-operator chains (the (+ (+ ...) mentioned earlier) and collects the arguments, flattening them into n-ary + or * expressions.
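
The same flattening can be tried out on plain tables, without Penlight. In this sketch, nodes are ordinary tables of the form {op=..., args...} and leaves are strings; flatten is a hypothetical name playing the role of tcollect.

```lua
-- Collect the leaves of a same-operator chain into 'out'.
function flatten(op, e, out)
  if type(e) == 'table' and e.op == op then
    for i = 1, #e do flatten(op, e[i], out) end
  else
    out[#out + 1] = e  -- a leaf, or a node with a different operator
  end
  return out
end

local form1 = {op='+', {op='+', {op='+', 'a', 'b'}, 'c'}, 'd'}
local form2 = {op='+', 'a', {op='+', 'b', {op='+', 'c', 'd'}}}
local form3 = {op='+', {op='+', 'a', 'b'}, {op='+', 'c', 'd'}}
for _, e in ipairs{form1, form2, form3} do
  print(table.concat(flatten('+', e, {}), ' '))  --> a b c d
end
```

All three shapes give the same argument list, which is the point of the n-ary form.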

Here is a useful function, which follows the same recursive pattern:

-- does this PE contain a reference to x?
function references (e,x)
    if isPE(e) then
        if e.op == 'X' then
            return x.repr == e.repr
        else
            return find_if(e,references,x)
        end
    else
        return false
    end
end

Here are functions to create n-ary products and sums:

function muli (args) return PE{op='*',unpack(args)} end
function addi (args) return PE{op='+',unpack(args)} end

With this in place, the basic differentiation rules are not difficult. Firstly, only consider subexpressions which do contain the variable:

function diff (e,x)
    if isPE(e) and references(e,x) then
        local op = e.op
        if op == 'X' then
            return 1
        else
            local a,b = e[1],e[2]
            if op == '+' then -- differentiation is linear
                local args = imap(diff,e,x)
                return balance(addi(args))
            elseif op == '*' then -- product rule
                local res,d,ee = {}
                for i = 1,#e do
                    d = fold(diff(e[i],x))
                    if d ~= 0 then
                        ee = {unpack(e)} -- make a copy
                        ee[i] = d
                        append(res,balance(muli(ee)))
                    end
                end
                if #res > 1 then return addi(res)
                else return res[1] end
            elseif op == '^' and isnumber(b) then -- power rule
                -- note: assumes the base is x itself; there is no chain rule here
                return b*x^(b-1)
            end
        end
    else
        return 0
    end
end

The derivative of a sum of expressions is the sum of the derivatives. Again, imap does the job of applying the function recursively over the subexpressions. After constructing the result, we re-balance for luck.

The product rule is given here in its general form, with an explicit check for terms which work out to zero - that is the job of fold, which will be discussed next.

(uvw..)' = u'vw.. + uv'w... + uvw'... + ...
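
As an illustration of the pattern in the product-rule loop of diff, this sketch builds the terms of the general product rule as strings: for each factor in turn, copy the factor list and replace that one factor with its derivative. The function name and the use of a trailing ' to mark a derivative are just conventions for this example.

```lua
-- Build the terms of (u1*u2*...*un)' as a string.
function product_rule_terms(factors)
  local terms = {}
  for i = 1, #factors do
    local copy = {}
    for j = 1, #factors do
      -- differentiate only the i-th factor, keep the others as they are
      copy[j] = (i == j) and (factors[j] .. "'") or factors[j]
    end
    terms[i] = table.concat(copy, '*')
  end
  return table.concat(terms, ' + ')
end

print(product_rule_terms{'u', 'v', 'w'})  --> u'*v*w + u*v'*w + u*v*w'
```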

And finally, the simple power rule. Note how the result can be expressed in a straightforward fashion, since all these operators are acting on PEs.

In fact, all these rules are certainly clearer if you use form 1, binary + and *! But then simplification becomes unbearable. And simplification ('folding') is the tricky one to get right. fold is a longish function, so I will deal with it in sections:

local op = e.op
local addmul = op == '*' or op == '+'
-- first fold all arguments
local args = imap(fold,e)
if not addmul and not find_if(args,isPE) then
  -- no placeholders in these args, we can fold the expression.
  local opfn = optable[op]
  if opfn then
    return opfn(unpack(args))
  else
   return '?'
  end
elseif addmul then

The first if is looking for a case where a subexpression has no symbols, i.e. it is something like 2*5 or 10^2; in this case, the constant can be completely folded. optable (defined in pl.operator) gives a mapping between the operator names and the actual function implementing them.
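
Here is a rough, self-contained sketch of that constant-folding step. It assumes binary nodes stored as plain tables with strings for symbols, and it uses a hand-written table called ops as a stand-in for optable; neither fold_const nor ops is part of Penlight.

```lua
-- Hypothetical stand-in for an operator table: name -> function.
ops = {
  ['+'] = function(x, y) return x + y end,
  ['-'] = function(x, y) return x - y end,
  ['*'] = function(x, y) return x * y end,
  ['/'] = function(x, y) return x / y end,
  ['^'] = function(x, y) return x ^ y end,
}

function fold_const(e)
  if type(e) ~= 'table' then return e end  -- a number or a symbol name
  local args, all_numbers = {}, true
  for i = 1, #e do
    args[i] = fold_const(e[i])             -- fold the arguments first
    if type(args[i]) ~= 'number' then all_numbers = false end
  end
  local fn = ops[e.op]
  if all_numbers and fn then
    return fn(args[1], args[2])            -- no symbols left: evaluate
  end
  local node = {op = e.op}                 -- otherwise rebuild the node
  for i = 1, #args do node[i] = args[i] end
  return node
end

print(fold_const{op='*', 2, 5})  --> 10
local r = fold_const{op='*', 'x', {op='+', 1, 2}}
print(r.op, r[1], r[2])          --> *	x	3
```

A node with any symbolic argument is kept, but its constant subexpressions are still folded.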

elseif op == '^' then
  if args[2] == 1 then return args[1] end -- identity
  if args[2] == 0 then return 1 end
end
return PE{op=op,unpack(args)}

This clause is clearing up expressions like x^1 and y^0 which naturally arise from the power rule in diff. Once args has been processed, the expression can be put together again.

The bulk of this routine handles the awkward twins, + and *.

-- split the args into two classes, PE args and non-PE args.
local classes = List.partition(args,isPE)
local pe,npe = classes[true],classes[false]

List.partition takes a list and a function, which takes one argument and returns a single value. The result is a table where the keys are the returned values, and the values are lists of those elements where the function returned that value. So:

List{1,2,3,4}:partition(function(x) return x > 2 end)
--> {false={1,2},true={3,4}}
List{'one',math.sin,10,20,{1,2}}:partition(type)
--> {function={function: 00369110},string={one},number={10,20},table={{1,2}}}

(Mathematically, these are referred to as equivalence classes and partition would be called the quotient set)
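
For readers without Penlight to hand, the idea behind partitioning fits in a few lines of plain Lua. This is a sketch of the behaviour described above, not the library's source; it assumes the classifying function never returns nil, since nil cannot be a table key.

```lua
-- Group the elements of 'list' by the value 'classify' returns for each.
function partition(list, classify)
  local classes = {}
  for _, v in ipairs(list) do
    local key = classify(v)
    local bucket = classes[key]
    if not bucket then
      bucket = {}
      classes[key] = bucket
    end
    bucket[#bucket + 1] = v
  end
  return classes
end

local by_type = partition({'x', 3, 'y', 4}, type)
print(table.concat(by_type.string, ','))  --> x,y
print(table.concat(by_type.number, ','))  --> 3,4
```

With a predicate such as isPE the keys are simply true and false, which is how the two classes are picked out in fold.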

In this case, we want to separate the non-symbolic arguments from the symbolic arguments; order does not matter. The non-symbolic arguments npe can be folded into a single constant. At this point, operator identity rules can kick in, so that we can drop (* 0 x) and simplify (+ 0 x) to be just x.
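
A minimal sketch of these identity rules, assuming a flat argument list in which numbers are the non-symbolic arguments and strings stand for symbolic ones. The name fold_addmul and the choice to put the constant first are decisions made for this example only.

```lua
-- Fold the numeric arguments of an n-ary '+' or '*' and apply identities.
function fold_addmul(op, args)
  local identity = (op == '*') and 1 or 0
  local const, symbols = identity, {}
  for _, v in ipairs(args) do
    if type(v) == 'number' then
      if op == '*' then const = const * v else const = const + v end
    else
      symbols[#symbols + 1] = v
    end
  end
  if op == '*' and const == 0 then return 0 end  -- (* 0 x) vanishes
  if #symbols == 0 then return const end         -- nothing symbolic left
  if const ~= identity then table.insert(symbols, 1, const) end
  if #symbols == 1 then return symbols[1] end    -- (+ 0 x) is just x
  local node = {op = op}
  for i, v in ipairs(symbols) do node[i] = v end
  return node
end

print(fold_addmul('+', {0, 'x'}))  --> x
print(fold_addmul('*', {0, 'x'}))  --> 0
local r = fold_addmul('*', {2, 'a', 3, 'x'})
print(r[1], r[2], r[3])            --> 6	a	x
```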

The final simplification is replacing repeated values, so that (* x x) should become (^ x 2) and (+ x x x) should become (* x 3). count_map from pl.tablex will do the job. It is given a list-like table and a function which defines equivalence, and returns a map from the values to the number of their occurrences, so that count_map{'a','b','a'} is {a=2,b=1} .
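
The counting step can be sketched in plain Lua as follows. Unlike the description above, this version takes no equivalence function and relies on ordinary table-key equality, so it only works when the arguments are strings or numbers; both function names are hypothetical, and the results are written as S-expression strings to keep the example short.

```lua
-- Count occurrences, remembering the order of first appearance.
function count_values(list)
  local counts, order = {}, {}
  for _, v in ipairs(list) do
    if not counts[v] then
      counts[v] = 0
      order[#order + 1] = v
    end
    counts[v] = counts[v] + 1
  end
  return counts, order
end

-- Collapse repeats: (* x x) becomes (^ x 2), (+ x x x) becomes (* x 3).
function collapse_repeats(op, args)
  local counts, order = count_values(args)
  local out = {}
  for _, v in ipairs(order) do
    local n = counts[v]
    if n == 1 then
      out[#out + 1] = v
    elseif op == '*' then
      out[#out + 1] = '(^ ' .. v .. ' ' .. n .. ')'
    else
      out[#out + 1] = '(* ' .. v .. ' ' .. n .. ')'
    end
  end
  return out
end

print(table.concat(collapse_repeats('*', {'x', 'x'}), ' '))       --> (^ x 2)
print(table.concat(collapse_repeats('+', {'x', 'x', 'x'}), ' '))  --> (* x 3)
```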

Given this test function:

function testdiff (e)
  balance(e)
  e = diff(e,x)
  balance(e)
  print('+ ',e)
  e = fold(e)
  print('- ',e)
end

and these cases:

testdiff(x^2+1)
testdiff(3*x^2)
testdiff(x^2 + 2*x^3)
testdiff(x^2 + 2*a*x^3 + x^4)
testdiff(2*a*x^3)
testdiff(x*x*x)

we get these results, showing why something like fold is so necessary to process the result of diff.

+ 	2 * x ^ 1 + 0
- 	2 * x
+ 	3 * 2 * x
- 	6 * x
+ 	2 * x ^ 1 + 2 * 3 * x ^ 2
- 	2 * x + 6 * x ^ 2
+ 	2 * x ^ 1 + 2 * a * 3 * x ^ 2 + 4 * x ^ 3
- 	6 * a * x ^ 2 + 4 * x ^ 3 + 2 * x
+ 	2 * a * 3 * x ^ 2
- 	6 * a * x ^ 2
+ 	1 * x * x + x * 1 * x + x * x * 1
- 	x ^ 2 * 3

https://github.com/stevedonovan/Penlight/blob/master/examples/symbols.lua

https://github.com/stevedonovan/Penlight/blob/master/examples/test-symbols.lua

SteveDonovan

Last edited July 4, 2012 10:41 am GMT