RealNVP network
For the definition of this network and the concepts of normalizing flows, please refer to this RealNVP blog post: https://lilianweng.github.io/lil-log/2018/10/13/flow-based-deep-generative-models.html, and the PyTorch notebook: https://github.com/GiggleLiu/marburg/blob/master/solutions/realnvp.ipynb
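The main fact we need from those references is the change-of-variables formula: if the network maps a latent vector z, drawn from a standard normal prior p_z, to a data point x, then log p(x) = log p_z(z) - log|det(∂x/∂z)|. The log-Jacobian term is exactly what the logjacobian! variable accumulates in the reversible programs below.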
using NiLang, NiLang.AD
using LinearAlgebra
using DelimitedFiles
using Plots
Include the optimizer; you can find it in the Adam.jl file under the examples/ folder.
include(NiLang.project_relative_path("examples", "Adam.jl"))
gclip! (generic function with 1 method)
Model definition
First, define the single-layer transformation and its behavior under GVar, the gradient wrapper.
struct RealNVPLayer{T}
# transform network
W1::Matrix{T}
b1::Vector{T}
W2::Matrix{T}
b2::Vector{T}
y1::Vector{T} # ancilla: pre-activation output of the first affine layer
y1a::Vector{T} # ancilla: the same after ReLU
# scaling network
sW1::Matrix{T}
sb1::Vector{T}
sW2::Matrix{T}
sb2::Vector{T}
sy1::Vector{T} # ancilla: pre-activation output of the scaling network
sy1a::Vector{T} # ancilla: the same after ReLU
end
"""collect parameters in the `layer` into a vector `out`."""
function collect_params!(out, layer::RealNVPLayer)
k=0
for field in [:W1, :b1, :W2, :b2, :sW1, :sb1, :sW2, :sb2]
v = getfield(layer, field)
nv = length(v)
out[k+1:k+nv] .= vec(v)
k += nv
end
return out
end
"""dispatch vectorized parameters `out` into the `layer`."""
function dispatch_params!(layer::RealNVPLayer, out)
k=0
for field in [:W1, :b1, :W2, :b2, :sW1, :sb1, :sW2, :sb2]
v = getfield(layer, field)
nv = length(v)
vec(v) .= out[k+1:k+nv]
k += nv
end
return out
end
function nparameters(n::RealNVPLayer)
sum(x->length(getfield(n, x)), [:W1, :b1, :W2, :b2, :sW1, :sb1, :sW2, :sb2])
end
nparameters (generic function with 1 method)
Then, we define the network and how to access its parameters.
const RealNVP{T} = Vector{RealNVPLayer{T}}
nparameters(n::RealNVP) = sum(nparameters, n)
function collect_params(n::RealNVP{T}) where T
out = zeros(T, nparameters(n))
k = 0
for layer in n
np = nparameters(layer)
collect_params!(view(out, k+1:k+np), layer)
k += np
end
return out
end
function dispatch_params!(network::RealNVP, out)
k = 0
for layer in network
np = nparameters(layer)
dispatch_params!(layer, view(out, k+1:k+np))
k += np
end
return network
end
function random_realnvp(nparams::Int, nhidden::Int, nhidden_s::Int, nlayer::Int; scale=0.1)
random_realnvp(Float64, nparams, nhidden, nhidden_s, nlayer; scale=scale)
end
function random_realnvp(::Type{T}, nparams::Int, nhidden::Int, nhidden_s::Int, nlayer::Int; scale=0.1) where T
nin = nparams÷2
scale = T(scale)
y1 = zeros(T, nhidden)
sy1 = zeros(T, nhidden_s)
RealNVPLayer{T}[RealNVPLayer(
randn(T, nhidden, nin)*scale, randn(T, nhidden)*scale,
randn(T, nin, nhidden)*scale, randn(T, nin)*scale, y1, zero(y1),
randn(T, nhidden_s, nin)*scale, randn(T, nhidden_s)*scale,
randn(T, nin, nhidden_s)*scale, randn(T, nin)*scale, sy1, zero(sy1),
) for _ = 1:nlayer]
end
random_realnvp (generic function with 2 methods)
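As a quick sanity check (this snippet is not part of the original example; the names m, pvec, m2 are just illustrative), the parameter vector produced by collect_params can be dispatched back into another network of the same shape:
m = random_realnvp(Float64, 2, 10, 10, 2)
pvec = collect_params(m)
@assert length(pvec) == nparameters(m)
m2 = random_realnvp(Float64, 2, 10, 10, 2)
dispatch_params!(m2, pvec)
@assert collect_params(m2) ≈ pvec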
Loss function
In each layer, we use the information in x
to update y!. During the computation we use two vector-type ancillas, y1
and y1a; both of them can be uncomputed at the end of the function.
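In equations, a single layer leaves x untouched and updates y! ← y! ⊙ exp(s(x)) + t(x), where t(x) = W2·relu(W1·x + b1) + b2 is the transform network, s(x) = sW2·relu(sW1·x + sb1) + sb2 is the scaling network (passed through tanh in the last layer), and the layer's contribution to logjacobian! is the sum of these scales.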
@i function onelayer!(x::AbstractVector{T}, layer::RealNVPLayer{T},
y!::AbstractVector{T}, logjacobian!::T; islast) where T
@routine @invcheckoff begin
# scale network
scale ← zero(y!)
ytemp2 ← zero(y!)
i_affine!(layer.sy1, layer.sW1, layer.sb1, x)
@inbounds for i=1:length(layer.sy1)
if (layer.sy1[i] > 0, ~)
layer.sy1a[i] += layer.sy1[i]
end
end
i_affine!(scale, layer.sW2, layer.sb2, layer.sy1a)
# transform network
i_affine!(layer.y1, layer.W1, layer.b1, x)
# relu
@inbounds for i=1:length(layer.y1)
if (layer.y1[i] > 0, ~)
layer.y1a[i] += layer.y1[i]
end
end
end
# inplace multiply exp of scale! -- dangerous
@inbounds @invcheckoff for i=1:length(scale)
@routine begin
expscale ← zero(T)
tanhscale ← zero(T)
if (islast, ~)
tanhscale += tanh(scale[i])
else
tanhscale += scale[i]
end
expscale += exp(tanhscale)
end
logjacobian! += tanhscale
# inplace multiply!!!
temp ← zero(T)
temp += y![i] * expscale
SWAP(temp, y![i])
temp -= y![i] / expscale
temp → zero(T)
~@routine
end
# affine the transform layer
i_affine!(y!, layer.W2, layer.b2, layer.y1a)
~@routine
# clean up accumulated rounding error, since this memory is reused.
@safe layer.y1 .= zero(T)
@safe layer.sy1 .= zero(T)
end
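A note on the "dangerous" in-place multiplication above: multiplying y![i] by expscale reversibly is done through an ancilla, temp += y![i]*expscale; SWAP(temp, y![i]); temp -= y![i]/expscale, after which temp is only approximately zero, so it is discarded under @invcheckoff. Because floating-point uncomputation is not exact, the reused buffers layer.y1 and layer.sy1 are explicitly reset to zero with @safe at the end of the function.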
A RealNVP network always transforms its inputs reversibly. We update one half of x!
at a time, so that the input and output memory spaces do not clash.
@i function realnvp!(x!::AbstractVector{T}, network::RealNVP{T}, logjacobian!) where T
@invcheckoff for i=1:length(network)
np ← length(x!)
if (i%2==0, ~)
@inbounds onelayer!(x! |> subarray(np÷2+1:np), network[i], x! |> subarray(1:np÷2), logjacobian!; islast=i==length(network))
else
@inbounds onelayer!(x! |> subarray(1:np÷2), network[i], x! |> subarray(np÷2+1:np), logjacobian!; islast=i==length(network))
end
end
end
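Since realnvp! is a reversible program, the inverse map comes for free; a minimal sketch (the network m and the other names are illustrative, not taken from the original text):
m = random_realnvp(Float64, 2, 10, 10, 4)
z = randn(2)
x, _, logj = realnvp!(copy(z), m, 0.0)    # forward: latent -> data, accumulating the log-Jacobian
z2, _, _ = (~realnvp!)(copy(x), m, logj)  # inverse: data -> latent
@assert z2 ≈ z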
How to obtain the log-probability of a data sample:
@i function logp!(out!::T, x!::AbstractVector{T}, network::RealNVP{T}) where T
(~realnvp!)(x!, network, out!)
@invcheckoff for i = 1:length(x!)
@routine begin
xsq ← zero(T)
@inbounds xsq += x![i]^2
end
out! -= 0.5 * xsq
~@routine
end
end
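For example (again an illustrative sketch rather than part of the original text), the log-probability of a single sample under an untrained two-dimensional network can be obtained as follows; note that logp! omits the constant term of the Gaussian prior:
m = random_realnvp(Float64, 2, 10, 10, 4)
x = randn(2)
lp, z, _ = logp!(0.0, copy(x), m)  # lp holds log p(x); the copy of x is mapped in place to its latent vector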
The negative log-likelihood loss function:
@i function nll_loss!(out!::T, cum!::T, xs!::Matrix{T}, network::RealNVP{T}) where T
@invcheckoff for i=1:size(xs!, 2)
@inbounds logp!(cum!, xs! |> subarray(:,i), network)
end
out! -= cum!/(@const size(xs!, 2))
end
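nll_loss! returns all of its arguments, and the gradient with respect to the network parameters is obtained by running it backwards with a seeded GVar, exactly as the training loop below does; a minimal sketch on random data (names are illustrative):
m = random_realnvp(Float64, 2, 10, 10, 4)
xs = randn(2, 100)
loss, cum, xs2, m2 = nll_loss!(0.0, 0.0, copy(xs), m)
_, _, _, gm = (~nll_loss!)(GVar(loss, 1.0), GVar(cum), GVar(xs2), GVar(m2))
g = grad.(collect_params(gm))  # gradient vector, same layout as collect_params(m)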
Training
function train(x_data, model; num_epochs = 800)
num_vars = size(x_data, 1)
params = collect_params(model)
optimizer = Adam(; lr=0.01)
for epoch = 1:num_epochs
loss, a, b, c = nll_loss!(0.0, 0.0, copy(x_data), model)
if epoch % 50 == 1
println("epoch = $epoch, loss = $loss")
display(showmodel(x_data, model))
end
_, _, _, gmodel = (~nll_loss!)(GVar(loss, 1.0), GVar(a), GVar(b), GVar(c))
g = grad.(collect_params(gmodel))
update!(params, g, optimizer)
dispatch_params!(model, params)
end
return model
end
function showmodel(x_data, model; nsamples=2000)
scatter(x_data[1,1:nsamples], x_data[2,1:nsamples]; xlims=(-5,5), ylims=(-5,5))
zs = randn(2, nsamples)
for i=1:nsamples
realnvp!(view(zs, :, i), model, 0.0)
end
scatter!(zs[1,:], zs[2,:])
end
showmodel (generic function with 1 method)
You can find the training data in the examples/ folder.
x_data = Matrix(readdlm(NiLang.project_relative_path("examples", "train.dat"))')
import Random; Random.seed!(22)
model = random_realnvp(Float64, size(x_data, 1), 10, 10, 4; scale=0.1)
4-element Vector{Main.RealNVPLayer{Float64}}:
Main.RealNVPLayer{Float64}([0.10365283655905022; 0.12832890997547838; … ; 0.10237515285351262; 0.04561278390444176;;], [0.011932671083999709, -0.019461834760040552, 0.09841074615805545, -0.049932719417439035, 0.046770637156686626, 0.094738950442398, -0.07072722954885365, -0.14442322991493425, -0.04050736246766045, 0.0161302098597473], [-0.09076531278487276 -0.08893617376065133 … 0.17688969929809442 0.03332683434777908], [-0.05466360224358683], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04778285245528021; 0.002073152425026603; … ; 0.08114267515467105; 0.03446269262358759;;], [0.015323637584496678, -0.030552214751628316, 0.013764716144273565, 0.3100695796150575, -0.261192567636749, 0.2497349959460083, 0.028594287515553797, 0.027736168629672472, 0.09212377303866616, -0.023160843492401398], [-0.07950664144707173 0.06607646105462771 … -0.06372020403545663 -0.01468114678613847], [-0.05369268305611946], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
Main.RealNVPLayer{Float64}([-0.01994222941233885; -0.12354790135192664; … ; 0.0914829134828038; 0.05468284490890261;;], [0.099984884595698, 0.035003360565331286, -0.01215996827091547, -0.012953387260835123, 0.07587204372806554, -0.023394565668740614, -0.0102036220459518, 0.06005046391996989, -0.14314572987011717, 0.13861740098537687], [0.10258920416705185 -0.009778532190417643 … 0.0669491035775922 -0.17052131152319536], [0.03905158369614853], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-0.08079652822924399; -0.012018109547847596; … ; -0.0030313328377010503; -0.1035614822693211;;], [0.001980495005703741, -0.21924945243644256, 0.03722989450292816, -0.28291833693177126, 0.135309838967398, -0.021263453632266387, -0.11186250420913675, -0.08763936642038139, 0.06871598760276267, -0.0008703720158541838], [0.04464510169137048 -0.05871256825097563 … 0.13781815631646493 0.08481594857843244], [-0.029644007854268524], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
Main.RealNVPLayer{Float64}([-0.0011229834966656473; -0.018484150618418422; … ; -0.07431024618656085; -0.14150021432790136;;], [0.039187136830781594, -0.12125143109625228, -0.07256413563852397, 0.19987054507999624, -0.10242295290772746, -0.14911384317565934, 0.03058692040925597, -0.02203129176327245, 0.08975909617625784, 0.030514846351154208], [-0.021822218622933848 0.08063911073297139 … 0.055913203994493704 0.12560872073648785], [0.03129287825829668], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-0.07575678587966743; -0.027325439006003274; … ; 0.0043884950547045595; 0.15422607557662982;;], [0.0672283066154055, -0.028755388897846946, -0.028725846779153314, -0.15447792000314187, 0.2355159994529047, 0.04468972794111743, -0.041014372945593236, -0.06755430048865957, 0.05191188641330835, 0.05681891408866492], [0.08576393997374578 -0.010008951785075104 … 0.04834859122916412 -0.10713472938089504], [-0.08648560763345085], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
Main.RealNVPLayer{Float64}([-0.04095467882435368; 0.03961049214080149; … ; 0.17806833898907104; -0.05059189752404996;;], [0.06365690411873613, -0.0032488405448314554, -0.1309632918320647, -0.15380654594004503, -0.004480789767369362, -0.047857677435183175, 0.0024269400370712747, 0.0674366094545587, -0.05536370041022595, 0.128201780421072], [0.035556705200683984 -0.05831003161535983 … 0.14506785217571846 0.03172703860767502], [-0.12556349867365843], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09467245402868087; 0.032875343863597126; … ; 0.08410268678096289; -0.24550565506652466;;], [-0.05526664064416595, -0.16639791200381898, 0.002874855601853682, 0.046690716337888495, -0.10520642349013601, -0.060176638380842196, -0.15050928292324414, 0.1691216093670418, 0.0021759240220606576, -0.08163114668408883], [0.0238197672353124 -0.060492452534179286 … -0.25991082553365347 0.005147083829832955], [-0.15655622524638355], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
Before training, the distribution looks like this:
model = train(x_data, model; num_epochs=800)
After training, the distribution looks like this:
This page was generated using Literate.jl.