# Symbolic Differentiation

This started out as a little exercise for the PenlightLibraries, but became sufficiently obsessive to warrant a more thorough implementation.

The first step in symbolic algebra is defining a representation. Getting expressions into a suitable form is actually straightforward; no parsing is needed, since Lua does that for us. The `pl.func` library does all the hard work: it redefines the arithmetic operations to work on placeholder expressions (PEs), which are Lua expressions involving dummy variables called placeholders. `pl.func` defines standard placeholders for arguments called `_1`, `_2`, etc., but the `Var` function will create new ones of our choosing:

```
require 'pl' -- makes utils (and the other Penlight modules) available as globals
utils.import 'pl.func'
a,b,c,d = Var 'a,b,c,d'
print(a+b+c+d)
```

This does indeed print the expression in a readable form. PE operator expressions are stored as combinations of tables looking like `{op='+',x,y}`, which have an associated metatable defining metamethods like `__add` and so forth. Because Lua's `+` is left-associative, the tree for `a+b+c+d` nests to the left, with `a+b` at the deepest level.

It is irritating to draw these diagrams, so a better notation is Lisp-style S-expressions:

```
1: (+ (+ (+ a b) c) d)
```

However, with the various manipulations we will perform, this canonical form is not the only possible representation of `a+b+c+d`:

```
2: (+ a (+ b (+ c d)))
3: (+ (+ a b) (+ c d))
```

Now, experience shows that this leads to madness. Instead, it's easier to go for the canonical Lisp representation:

```
4: (+ a b c d)
```

Many operations become straightforward once this is in place, for instance comparing with `(+ a c b d)` is just a matter of doing a 'compare with no order' on the arguments. Displaying PEs in this form is straightforward. `isPE` simply checks the expression to see if it is a placeholder expression, by looking at the metatable. PEs with `op=='X'` are placeholder variables, so the rest must be expression nodes.

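```
function sexpr (e)
    if isPE(e) then
        if e.op ~= 'X' then
            -- expression node: render the operator followed by its arguments
            local args = tablex.imap(sexpr,e)
            return '('..e.op..' '..table.concat(args,' ')..')'
        else
            -- placeholder variable: just its name
            return e.repr
        end
    else
        return tostring(e) -- plain Lua value, e.g. a number
    end
end
```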
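
The 'compare with no order' idea mentioned above can be sketched in plain Lua. This is only an illustration: `same_args_unordered` is a made-up name, not a Penlight function, and the arguments are plain strings rather than PEs.

```lua
-- Hypothetical helper (not part of Penlight): compare two argument
-- lists while ignoring order, treating them as multisets.
function same_args_unordered(xs, ys)
    if #xs ~= #ys then return false end
    local counts = {}
    for _, x in ipairs(xs) do
        counts[x] = (counts[x] or 0) + 1
    end
    for _, y in ipairs(ys) do
        local n = counts[y]
        if not n or n == 0 then return false end
        counts[y] = n - 1
    end
    return true
end

print(same_args_unordered({'a','b','c','d'}, {'a','c','b','d'}))  --> true
print(same_args_unordered({'a','b','b'}, {'a','a','b'}))          --> false
```

Counting occurrences rather than sorting means repeated arguments such as `(+ x x)` are handled correctly.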

The first task is to balance the expressions, which converts representations 1-3 into 4.

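```
function balance (e)
    if isPE(e) and e.op ~= 'X' then
        local op,args = e.op
        if op == '+' or op == '*' then
            -- flatten nested chains of the same operator into one argument list
            args = rcollect(e)
        else
            -- other operators: just balance each subexpression
            args = imap(balance,e)
        end
        -- copy the new arguments back in-place
        for i = 1,#args do
            e[i] = args[i]
        end
    end
    return e
end
```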

For the non-commutative operators, the idea is just to balance all the subexpressions by mapping `balance` over the array part of the PE, which is the argument list. These are then copied back in-place. The non-trivial part is dealing with + and *, where it is necessary to collect all the arguments from expression trees looking like 1,2 or 3 and convert them into the fourth form.

```
function tcollect (op,e,ls)
    if isPE(e) and e.op == op then
        for i = 1,#e do
            tcollect(op,e[i],ls)
        end
    else
        ls:append(e)
        return
    end
end

function rcollect (e)
    local res = List()
    tcollect(e.op,e,res)
    return res
end
```

This recursively goes down same-operator chains (the `(+ (+ ...)` mentioned earlier) and collects the arguments, flattening them into n-ary + or * expressions.
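
To see the flattening in isolation, here is a small sketch that uses plain tables in place of PEs. The names `node` and `flatten` are my own stand-ins for illustration, not Penlight functions; the recursion has the same shape as `tcollect`.

```lua
-- Plain tables stand in for PEs; `node` and `flatten` are
-- illustrative names only.
function node(op, ...) return {op = op, ...} end

function flatten(op, e, out)
    if type(e) == 'table' and e.op == op then
        -- same operator: descend and collect the arguments
        for i = 1, #e do flatten(op, e[i], out) end
    else
        -- a leaf, or a node with a different operator: keep as one argument
        out[#out + 1] = e
    end
    return out
end

local form1 = node('+', node('+', node('+', 'a', 'b'), 'c'), 'd')
local form2 = node('+', 'a', node('+', 'b', node('+', 'c', 'd')))
local form3 = node('+', node('+', 'a', 'b'), node('+', 'c', 'd'))

print(table.concat(flatten('+', form1, {}), ' '))  --> a b c d
print(table.concat(flatten('+', form2, {}), ' '))  --> a b c d
print(table.concat(flatten('+', form3, {}), ' '))  --> a b c d
```

All three tree shapes give the same argument list, which is the point of form 4.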

Here is a useful function, which follows the same recursive pattern:

```
-- does this PE contain a reference to x?
function references (e,x)
    if isPE(e) then
        if e.op == 'X' then return x.repr == e.repr
        else
            return find_if(e,references,x)
        end
    else
        return false
    end
end
```

Here are functions to create n-ary products and sums:

```
function muli (args) return PE{op='*',unpack(args)} end
function addi (args) return PE{op='+',unpack(args)} end
```

With this in place, the basic differentiation rules are not difficult. Firstly, only consider subexpressions which do contain the variable:

```
function diff (e,x)
    if isPE(e) and references(e,x) then
        local op = e.op
        if op == 'X' then
            return 1
        else
            local a,b = e[1],e[2]
            if op == '+' then -- differentiation is linear
                local args = imap(diff,e,x)
                return balance(addi(args))
            elseif op == '*' then -- product rule
                local res,d,ee = {}
                for i = 1,#e do
                    d = fold(diff(e[i],x))
                    if d ~= 0 then
                        ee = {unpack(e)} -- make a copy
                        ee[i] = d
                        append(res,balance(muli(ee)))
                    end
                end
                if #res > 1 then return addi(res)
                else return res[1] end
            elseif op == '^' and isnumber(b) then -- power rule (base taken to be x itself)
                return b*x^(b-1)
            end
        end
    else
        return 0
    end
end
```

The derivative of a sum of expressions is the sum of the derivatives. Again, `imap` does the job of applying the function recursively over the subexpressions. After constructing the result, we re-balance for luck.

The product rule is given here in its general form, with an explicit check for terms which work out to zero - that is the job of `fold`, which will be discussed next.

```
(uvw...)' = u'vw... + uv'w... + uvw'... + ...
```

And finally, the simple power rule. Note how the result can be expressed in a straightforward fashion, since all these operators are acting on PEs.

In fact, all these rules are certainly clearer if you use form 1, binary + and *! But then simplification becomes unbearable. And simplification ('folding') is the tricky one to get right. `fold` is a longish function, so I will deal with it in sections:

```
local op = e.op
local addmul = op == '*' or op == '+'
-- first fold all arguments
local args = imap(fold,e)
if not addmul and not find_if(args,isPE) then
    -- no placeholders in these args, we can fold the expression.
    local opfn = optable[op]
    if opfn then
        return opfn(unpack(args))
    else
        return '?'
    end
```

The first `if` is looking for a case where a subexpression has no symbols, i.e. it is something like `2*5` or `10^2`; in this case, the constant can be completely folded. `optable` (defined in `pl.operator`) gives a mapping between the operator names and the actual function implementing them.
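
As a rough illustration of that lookup, here is a standalone sketch. The `ops` table is a hand-written stand-in for `optable` (so the example does not need Penlight), and `fold_constant` is a hypothetical name; it only handles binary operators on plain numbers.

```lua
-- `ops` is a hand-written stand-in for pl.operator's optable.
ops = {
    ['+'] = function(a, b) return a + b end,
    ['-'] = function(a, b) return a - b end,
    ['*'] = function(a, b) return a * b end,
    ['/'] = function(a, b) return a / b end,
    ['^'] = function(a, b) return a ^ b end,
}

-- Hypothetical helper: evaluate op on two numbers, or return nil
-- when the operator is unknown or an argument is not a number.
function fold_constant(op, a, b)
    local fn = ops[op]
    if fn and type(a) == 'number' and type(b) == 'number' then
        return fn(a, b)
    end
    return nil
end

print(fold_constant('*', 2, 5))    --> 10
print(fold_constant('+', 'x', 5))  --> nil
```

Returning `nil` for anything symbolic leaves the caller free to rebuild the expression node unchanged.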

```
elseif op == '^' then
    if args[2] == 1 then return args[1] end -- identity
    if args[2] == 0 then return 1 end
end
return PE{op=op,unpack(args)}
```

This clause is clearing up expressions like `x^1` and `y^0` which naturally arise from the power rule in `diff`. Once `args` has been processed, the expression can be put together again.

The bulk of this routine handles the awkward twins, + and *.

```
-- split the args into two classes, PE args and non-PE args.
local classes = List.partition(args,isPE)
local pe,npe = classes[true],classes[false]
```

`List.partition` takes a list and a function, which takes one argument and returns a single value. The result is a table where the keys are the returned values, and the values are lists of those elements for which the function returned that value. So:

```
List{1,2,3,4}:partition(function(x) return x > 2 end)
--> {false={1,2},true={3,4}}
List{'one',math.sin,10,20,{1,2}}:partition(type)
--> {function={function: 00369110},string={one},number={10,20},table={{1,2}}}
```

(Mathematically, these lists are the equivalence classes, and the table returned by `partition` is the quotient set.)
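
The behaviour described above is easy to reproduce in plain Lua. This is a sketch of the idea, not Penlight's actual implementation; `partition` here is a free function I wrote for the example, and it assumes the classifying function never returns `nil` (which cannot be a table key).

```lua
-- Sketch of partitioning: group list elements by the value that
-- `classify` returns for each of them.
function partition(list, classify)
    local classes = {}
    for _, v in ipairs(list) do
        local key = classify(v)
        local bucket = classes[key]
        if not bucket then
            bucket = {}
            classes[key] = bucket
        end
        bucket[#bucket + 1] = v
    end
    return classes
end

local classes = partition({1, 'x', 2, 'y', 3},
    function(v) return type(v) == 'number' end)
print(table.concat(classes[true], ','))   --> 1,2,3
print(table.concat(classes[false], ','))  --> x,y
```

Note that a class with no members is simply absent from the result, so a caller has to allow for `classes[true]` or `classes[false]` being `nil`.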

In this case, we want to separate the non-symbolic arguments from the symbolic arguments; order does not matter. The non-symbolic arguments `npe` can be folded into a single constant. At this point, operator identity rules can kick in, so that `(* 0 x)` collapses to `0` and `(+ 0 x)` simplifies to just `x`.
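
Here is one way those identity rules could look, as a standalone sketch rather than the code in `fold` itself. `simplify_nary` is a hypothetical name; it takes the numeric and symbolic arguments already separated, represents symbols as plain strings, and returns either a number, a single symbol, or an S-expression string.

```lua
-- Hypothetical sketch: fold the numeric arguments of an n-ary + or *
-- into one constant, then apply the identity and zero rules.
function simplify_nary(op, nums, syms)
    local identity = (op == '*') and 1 or 0
    local k = identity
    for _, n in ipairs(nums) do
        if op == '*' then k = k * n else k = k + n end
    end
    if op == '*' and k == 0 then return 0 end   -- (* 0 x) -> 0
    local args = {}
    if k ~= identity then args[1] = k end       -- drop (+ 0 ...) and (* 1 ...)
    for _, s in ipairs(syms) do args[#args + 1] = s end
    if #args == 0 then return identity end
    if #args == 1 then return args[1] end
    return '(' .. op .. ' ' .. table.concat(args, ' ') .. ')'
end

print(simplify_nary('+', {0}, {'x'}))          --> x
print(simplify_nary('*', {0, 3}, {'x'}))       --> 0
print(simplify_nary('*', {2, 3}, {'a', 'x'}))  --> (* 6 a x)
```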

The final simplification is replacing repeated values, so that `(* x x)` should become `(^ x 2)` and `(+ x x x)` should become `(* x 3)`. `count_map` from `pl.tablex` will do the job. It is given a list-like table and a function which defines equivalence, and returns a map from the values to the number of their occurrences, so that ` count_map{'a','b','a'} ` is ` {a=2,b=1} `.
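
A plain-Lua sketch of this last step, again with strings standing in for symbols. Both function names are made up for the example, and the counting is done by hand rather than with `count_map` so that the block runs without Penlight.

```lua
-- Count occurrences, remembering first-seen order so the output is stable.
function count_occurrences(list)
    local counts, order = {}, {}
    for _, v in ipairs(list) do
        if not counts[v] then
            counts[v] = 0
            order[#order + 1] = v
        end
        counts[v] = counts[v] + 1
    end
    return counts, order
end

-- Replace repeated symbols: powers under *, multiples under +.
function collapse_repeats(op, syms)
    local counts, order = count_occurrences(syms)
    local out = {}
    for _, s in ipairs(order) do
        local n = counts[s]
        if n == 1 then
            out[#out + 1] = s
        elseif op == '*' then
            out[#out + 1] = '(^ ' .. s .. ' ' .. n .. ')'
        else
            out[#out + 1] = '(* ' .. s .. ' ' .. n .. ')'
        end
    end
    return out
end

print(table.concat(collapse_repeats('*', {'x', 'x'}), ' '))       --> (^ x 2)
print(table.concat(collapse_repeats('+', {'x', 'x', 'x'}), ' '))  --> (* x 3)
```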

Given this test function:

```
function testdiff (e)
    balance(e)
    e = diff(e,x)
    balance(e)
    print('+ ',e)
    e = fold(e)
    print('- ',e)
end
```

and these cases:

```
testdiff(x^2+1)
testdiff(3*x^2)
testdiff(x^2 + 2*x^3)
testdiff(x^2 + 2*a*x^3 + x^4)
testdiff(2*a*x^3)
testdiff(x*x*x)
```

we get these results, showing why something like `fold` is so necessary to process the result of `diff`.

```
+ 	2 * x ^ 1 + 0
- 	2 * x
+ 	3 * 2 * x
- 	6 * x
+ 	2 * x ^ 1 + 2 * 3 * x ^ 2
- 	2 * x + 6 * x ^ 2
+ 	2 * x ^ 1 + 2 * a * 3 * x ^ 2 + 4 * x ^ 3
- 	6 * a * x ^ 2 + 4 * x ^ 3 + 2 * x
+ 	2 * a * 3 * x ^ 2
- 	6 * a * x ^ 2
+ 	1 * x * x + x * 1 * x + x * x * 1
- 	x ^ 2 * 3
```