ReverseDiffSource.jl

Reverse-mode automatic differentiation from an expression or a function

This package provides a function rdiff() that generates valid Julia code for calculating derivatives, up to any order, of a user-supplied expression or generic function. Install with Pkg.add("ReverseDiffSource"). See the package documentation for further details and examples.

This form of automatic differentiation operates at the source level: it takes either an expression or a generic function and outputs Julia code that calculates the derivatives (as an expression or a function, respectively). Compared to other automatic differentiation methods, it does not rely on method overloading or new types and should, in principle, produce fast code.
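To make "source level" concrete, here is a minimal toy sketch of code that takes a Julia expression and returns a new Julia expression for its derivative. It is not the package's algorithm (it applies symbolic rules directly rather than reverse-mode accumulation, and does no simplification), and `toy_diff` is a hypothetical name invented for this illustration.

```julia
# Toy illustration only: NOT how ReverseDiffSource works internally.
# `toy_diff` is a hypothetical helper that maps an expression to a new expression.
toy_diff(ex::Number, v::Symbol) = 0
toy_diff(ex::Symbol, v::Symbol) = ex == v ? 1 : 0

function toy_diff(ex::Expr, v::Symbol)
    ex.head == :call || error("unsupported expression: $ex")
    op, args = ex.args[1], ex.args[2:end]
    if op == :+
        # sum rule
        return Expr(:call, :+, [toy_diff(a, v) for a in args]...)
    elseif op == :* && length(args) == 2
        # product rule
        a, b = args
        return :($(toy_diff(a, v)) * $b + $a * $(toy_diff(b, v)))
    elseif op == :^ && args[2] isa Number
        # power rule with a constant exponent, plus chain rule
        a, n = args
        return :($n * $a ^ $(n - 1) * $(toy_diff(a, v)))
    elseif op == :sin
        # chain rule for sin
        return :(cos($(args[1])) * $(toy_diff(args[1], v)))
    end
    error("unsupported operation: $op")
end

dex = toy_diff(:(x^3 + sin(x)), :x)      # the result is itself Julia source code
val = eval(:(let x = 2.0; $dex end))     # evaluate the generated code at x = 2.0
```

The point of the sketch is only that both input and output are ordinary Julia code, so no special number types or overloaded methods are involved when the derivative is evaluated.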

Usage examples:

• derivative of x³

julia> rdiff( :(x^3) , x=Float64)  # 'x=Float64' tells rdiff the type of x
:(begin
    (x^3,3 * x^2.0)  # the expression evaluates to a tuple of (value, derivative)
end)
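As a quick sanity check, the generated expression can be evaluated at a sample point and compared with a central finite difference. This sketch copies the expression from the output above instead of calling rdiff, so it runs without the package; the point `x0` and step `h` are arbitrary choices.

```julia
# Expression copied from the rdiff output above
ex = :((x^3, 3 * x^2.0))

x0 = 2.0
# Bind x locally and evaluate the expression: returns (value, derivative)
val, der = eval(:(let x = $x0; $ex end))

# Central finite-difference approximation of d/dx x^3 at x0
h = 1e-6
fd = ((x0 + h)^3 - (x0 - h)^3) / (2h)

@assert val == 8.0
@assert isapprox(der, fd; rtol = 1e-6)
```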

• first 10 derivatives of sin(x) (notice the simplifications)

julia> rdiff( :(sin(x)) , order=10, x=Float64)  # derivatives up to order 10
:(begin
    _tmp1 = sin(x)
    _tmp2 = cos(x)
    _tmp3 = -_tmp1
    _tmp4 = -_tmp2
    _tmp5 = -_tmp3
    (_tmp1,_tmp2,_tmp3,_tmp4,_tmp5,_tmp2,_tmp3,_tmp4,_tmp5,_tmp2,_tmp3)
end)
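The reuse of temporaries reflects the four-step cycle of the derivatives of sin: the k-th derivative equals sin(x + kπ/2). The sketch below re-types the generated statements by hand (it does not call rdiff) and checks the resulting tuple against that identity at an arbitrary sample point.

```julia
x = 0.7   # arbitrary sample point

# Statements copied from the generated code above
_tmp1 = sin(x)
_tmp2 = cos(x)
_tmp3 = -_tmp1
_tmp4 = -_tmp2
_tmp5 = -_tmp3
derivs = (_tmp1,_tmp2,_tmp3,_tmp4,_tmp5,_tmp2,_tmp3,_tmp4,_tmp5,_tmp2,_tmp3)

# Closed form: the k-th derivative of sin(x) is sin(x + k*pi/2), for k = 0..10
expected = [sin(x + k * pi / 2) for k in 0:10]

@assert length(derivs) == 11
@assert all(isapprox.(collect(derivs), expected; atol = 1e-12))
```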

• works on functions too

julia> rosenbrock(x) = (1 - x[1])^2 + 100(x[2] - x[1]^2)^2      # function to be differentiated, x is a vector
julia> rosen2 = rdiff(rosenbrock, (Vector{Float64},), order=2)  # derivatives up to order 2
(anonymous function)
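For reference, the first and second derivatives of the two-variable Rosenbrock function, f(x) = (1 - x[1])^2 + 100(x[2] - x[1]^2)^2, are easy to derive by hand, which gives something to compare a generated function against. The sketch below does not call rosen2, since its return layout is not shown here; `rosen_grad` and `rosen_hess` are hypothetical names for the hand-derived formulas.

```julia
# Two-variable Rosenbrock function (x is a vector of length 2)
rosenbrock(x) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2

# Hand-derived gradient (not generated by rdiff)
rosen_grad(x) = [-2 * (1 - x[1]) - 400 * x[1] * (x[2] - x[1]^2),
                 200 * (x[2] - x[1]^2)]

# Hand-derived Hessian (not generated by rdiff)
function rosen_hess(x)
    h11 = 2 - 400 * x[2] + 1200 * x[1]^2
    h12 = -400 * x[1]
    return [h11 h12; h12 200.0]
end

xmin = [1.0, 1.0]            # the global minimum of the function
@assert rosenbrock(xmin) == 0.0
@assert rosen_grad(xmin) == [0.0, 0.0]
@assert rosen_hess(xmin) == [802.0 -400.0; -400.0 200.0]
```

At the minimum the value and gradient vanish and the Hessian is positive definite, so these three quantities make a convenient spot check for any order-2 derivative code.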

• gradient calculation of a 3 hidden layer neural network for backpropagation

# w1-w3 are the hidden layer weight matrices, x1 the input vector
function ann(w1, w2, w3, x1)
    x2 = w1 * x1
    x2 = log(1. + exp(x2))   # soft RELU unit
    x3 = w2 * x2
    x3 = log(1. + exp(x3))   # soft RELU unit
    x4 = w3 * x3
    1. / (1. + exp(-x4))     # sigmoid output
end

w1, w2, w3 = randn(10,10), randn(10,10), randn(1,10)
x1 = randn(10)
dann = rdiff(ann, (Matrix{Float64}, Matrix{Float64}, Matrix{Float64}, Vector{Float64}))
dann(w1, w2, w3, x1) # network output + gradient on w1, w2, w3 and x1

02/06/2014
