RealNVP network

For the definition of this network and the concepts of normalizing flows, please refer to this RealNVP blog post: https://lilianweng.github.io/lil-log/2018/10/13/flow-based-deep-generative-models.html, and the PyTorch notebook: https://github.com/GiggleLiu/marburg/blob/master/solutions/realnvp.ipynb

using NiLang, NiLang.AD
using LinearAlgebra
using DelimitedFiles
using Plots

Include the optimizer; you can find it in the Adam.jl file under the examples/ folder.

include(NiLang.project_relative_path("examples", "Adam.jl"))
gclip! (generic function with 1 method)

Model definition

First, define the single layer transformation and its behavior under GVar - the gradient wrapper.

# One RealNVP coupling layer: the parameters of its two small networks (a
# "transform" network and a "scaling" network) together with the scratch
# buffers those networks write their hidden activations into.
# The field order is the positional-constructor order relied on by `random_realnvp`.
struct RealNVPLayer{T}
    # transform network
    W1::Matrix{T}   # first affine map, applied to the conditioning half `x` in `onelayer!`
    b1::Vector{T}   # bias of the first affine map
    W2::Matrix{T}   # second affine map; its output is accumulated into `y!`
    b2::Vector{T}   # bias of the second affine map
    y1::Vector{T}   # scratch: hidden pre-activation (not a trainable parameter)
    y1a::Vector{T}  # scratch: hidden activation after the ReLU (not a trainable parameter)

    # scaling network
    sW1::Matrix{T}   # first affine map of the scaling network
    sb1::Vector{T}   # bias of the first affine map
    sW2::Matrix{T}   # second affine map; its output is the per-component log-scale
    sb2::Vector{T}   # bias of the second affine map
    sy1::Vector{T}   # scratch: hidden pre-activation (not a trainable parameter)
    sy1a::Vector{T}  # scratch: hidden activation after the ReLU (not a trainable parameter)
end

"""
    collect_params!(out, layer::RealNVPLayer)

Flatten the trainable arrays of `layer` into `out`, in the fixed order
`W1, b1, W2, b2, sW1, sb1, sW2, sb2`, and return `out`.

The scratch buffers (`y1`, `y1a`, `sy1`, `sy1a`) are not copied. `out` must have room
for at least `nparameters(layer)` entries.
"""
function collect_params!(out, layer::RealNVPLayer)
    offset = 0
    for name in (:W1, :b1, :W2, :b2, :sW1, :sb1, :sW2, :sb2)
        param = getfield(layer, name)
        stop = offset + length(param)
        # Matrices are flattened column-major by `vec`.
        out[(offset + 1):stop] .= vec(param)
        offset = stop
    end
    return out
end

"""
    dispatch_params!(layer::RealNVPLayer, out)

Copy the flat parameter vector `out` back into the trainable arrays of `layer`, in the
fixed order `W1, b1, W2, b2, sW1, sb1, sW2, sb2` (the same layout that
`collect_params!` produces). The scratch buffers of `layer` are left untouched.

Returns `out`.
"""
function dispatch_params!(layer::RealNVPLayer, out)
    k = 0
    for field in (:W1, :b1, :W2, :b2, :sW1, :sb1, :sW2, :sb2)
        v = getfield(layer, field)
        nv = length(v)
        # `vec(v)` shares memory with `v`, so the broadcast writes into the layer itself.
        # Reading `out` through a view avoids allocating a temporary copy of the slice.
        vec(v) .= view(out, k+1:k+nv)
        k += nv
    end
    return out
end

"""
    nparameters(n::RealNVPLayer)

Return the number of trainable scalars in one layer: the combined length of
`W1, b1, W2, b2, sW1, sb1, sW2, sb2`. Scratch buffers are not counted.
"""
function nparameters(n::RealNVPLayer)
    total = 0
    for name in (:W1, :b1, :W2, :b2, :sW1, :sb1, :sW2, :sb2)
        total += length(getfield(n, name))
    end
    return total
end
nparameters (generic function with 1 method)

Then we define the network and how to access its parameters.

# A RealNVP network is an ordered list of coupling layers.
const RealNVP{T} = Vector{RealNVPLayer{T}}

"""
    nparameters(n::RealNVP)

Return the total number of trainable scalars over all layers of the network.
An empty network has 0 parameters.
"""
# `init=0` is required for the empty case: without it, `sum(f, itr)` over an empty
# collection throws because it cannot infer a zero for the mapped values.
nparameters(n::RealNVP) = sum(nparameters, n; init=0)

"""
    collect_params(n::RealNVP{T}) where T

Return a freshly allocated `Vector{T}` holding every trainable parameter of the
network, layer after layer, each layer laid out as `collect_params!` writes it.
"""
function collect_params(n::RealNVP{T}) where T
    flat = zeros(T, nparameters(n))
    offset = 0
    for layer in n
        stop = offset + nparameters(layer)
        # Each layer fills its own contiguous window of the flat vector.
        collect_params!(view(flat, (offset + 1):stop), layer)
        offset = stop
    end
    return flat
end

"""
    dispatch_params!(network::RealNVP, out)

Write the flat parameter vector `out` back into the layers of `network`, consuming
`nparameters(layer)` consecutive entries per layer in order. Returns `network`.
"""
function dispatch_params!(network::RealNVP, out)
    offset = 0
    for layer in network
        stop = offset + nparameters(layer)
        # Hand each layer the window of `out` that belongs to it.
        dispatch_params!(layer, view(out, (offset + 1):stop))
        offset = stop
    end
    return network
end

"""
    random_realnvp(nparams::Int, nhidden::Int, nhidden_s::Int, nlayer::Int; scale=0.1)

Convenience method: build a randomly initialised RealNVP network with element type
`Float64` by forwarding to the typed `random_realnvp(Float64, ...)` method.
"""
random_realnvp(nparams::Int, nhidden::Int, nhidden_s::Int, nlayer::Int; scale=0.1) =
    random_realnvp(Float64, nparams, nhidden, nhidden_s, nlayer; scale=scale)

"""
    random_realnvp(::Type{T}, nparams, nhidden, nhidden_s, nlayer; scale=0.1) where T

Build a network of `nlayer` coupling layers with element type `T`.

Each layer acts on half of the variables (`nparams ÷ 2`). Its transform network has
`nhidden` hidden units and its scaling network `nhidden_s`. All weights and biases are
drawn from `randn` and multiplied by `scale`.

The pre-activation scratch buffers `y1` and `sy1` are allocated once and shared by every
layer; the post-activation buffers `y1a` and `sy1a` are allocated per layer.
"""
function random_realnvp(::Type{T}, nparams::Int, nhidden::Int, nhidden_s::Int, nlayer::Int; scale=0.1) where T
    half = nparams ÷ 2
    s = T(scale)
    shared_y1 = zeros(T, nhidden)
    shared_sy1 = zeros(T, nhidden_s)
    layers = RealNVPLayer{T}[]
    for _ in 1:nlayer
        # Draw in the same order as the struct fields so seeded runs are reproducible.
        W1 = randn(T, nhidden, half) * s
        b1 = randn(T, nhidden) * s
        W2 = randn(T, half, nhidden) * s
        b2 = randn(T, half) * s
        sW1 = randn(T, nhidden_s, half) * s
        sb1 = randn(T, nhidden_s) * s
        sW2 = randn(T, half, nhidden_s) * s
        sb2 = randn(T, half) * s
        push!(layers, RealNVPLayer(
            W1, b1, W2, b2, shared_y1, zero(shared_y1),
            sW1, sb1, sW2, sb2, shared_sy1, zero(shared_sy1),
        ))
    end
    return layers
end
random_realnvp (generic function with 2 methods)

Loss function

In each layer, we use the information in x to update y!. During the computation, we use two vector-type ancillas, y1 and y1a; both of them can be uncomputed at the end of the function.

# Apply one coupling layer reversibly: `y!` is rescaled element-wise and shifted,
# both by amounts computed from `x`; `x` itself is only read.
#
# Arguments
# - `x`:            conditioning half of the variables (input to both small networks).
# - `layer`:        parameters and scratch buffers of this coupling layer.
# - `y!`:           half of the variables that is updated in place.
# - `logjacobian!`: accumulator; the per-component log-scale is added to it.
# - `islast`:       when true, the log-scale is passed through `tanh` before use.
#
# NOTE(review): `i_affine!` comes from NiLang and is not defined in this file;
# presumably `i_affine!(out, W, b, v)` accumulates `W*v + b` into `out` — confirm.
@i function onelayer!(x::AbstractVector{T}, layer::RealNVPLayer{T},
                y!::AbstractVector{T}, logjacobian!::T; islast) where T
    # Everything in this `@routine` block is intermediate state; it is undone by the
    # `~@routine` near the end of the function (NiLang compute/uncompute pattern).
    @routine @invcheckoff begin
        # scale network
        scale ← zero(y!)
        # NOTE(review): `ytemp2` is allocated here but not referenced again below.
        ytemp2 ← zero(y!)
        i_affine!(layer.sy1, layer.sW1, layer.sb1, x)
        # relu: copy only the positive pre-activations into `sy1a`
        @inbounds for i=1:length(layer.sy1)
            if (layer.sy1[i] > 0, ~)
                layer.sy1a[i] += layer.sy1[i]
            end
        end
        i_affine!(scale, layer.sW2, layer.sb2, layer.sy1a)

        # transform network
        i_affine!(layer.y1, layer.W1, layer.b1, x)
        # relu
        @inbounds for i=1:length(layer.y1)
            if (layer.y1[i] > 0, ~)
                layer.y1a[i] += layer.y1[i]
            end
        end
    end
    # inplace multiply exp of scale! -- dangerous
    # (presumably "dangerous" because the uncompute of `temp` below is only exact up to
    # floating-point rounding while invertibility checks are switched off — confirm)
    @inbounds @invcheckoff for i=1:length(scale)
        # Per-component temporaries: the log-scale actually used and its exponential.
        @routine begin
            expscale ← zero(T)
            tanhscale ← zero(T)
            if (islast, ~)
                tanhscale += tanh(scale[i])
            else
                tanhscale += scale[i]
            end
            expscale += exp(tanhscale)
        end
        # The log-scale of each component is accumulated into the log-Jacobian.
        logjacobian! += tanhscale
        # inplace multiply!!!
        # temp = y*e; after the SWAP `y![i]` holds y*e and `temp` holds the old y;
        # subtracting (y*e)/e brings `temp` back to (approximately) zero for deallocation.
        temp ← zero(T)
        temp += y![i] * expscale
        SWAP(temp, y![i])
        temp -= y![i] / expscale
        temp → zero(T)
        ~@routine
    end

    # affine the transform layer
    # (the output of the transform network is accumulated into the already rescaled `y!`)
    i_affine!(y!, layer.W2, layer.b2, layer.y1a)
    # Undo the first `@routine` block so the scratch buffers are returned to their
    # initial state.
    ~@routine
    # clean up accumulated rounding error, since this memory is reused.
    @safe layer.y1 .= zero(T)
    @safe layer.sy1 .= zero(T)
end

A RealNVP network always transforms its inputs reversibly. We update one half of x! at a time, so that the input and output memory do not clash.

# Run `x!` through every layer of `network` in order, updating it in place and
# accumulating each layer's log-scale terms into `logjacobian!`.
# Layers alternate which half of `x!` is updated: odd-numbered layers read the first
# half and update the second, even-numbered layers do the opposite.
@i function realnvp!(x!::AbstractVector{T}, network::RealNVP{T}, logjacobian!) where T
    @invcheckoff for i=1:length(network)
        np ← length(x!)
        if (i%2==0, ~)
            # even layer: second half conditions, first half is updated
            @inbounds onelayer!(x! |> subarray(np÷2+1:np), network[i], x! |> subarray(1:np÷2), logjacobian!; islast=i==length(network))
        else
            # odd layer: first half conditions, second half is updated
            @inbounds onelayer!(x! |> subarray(1:np÷2), network[i], x! |> subarray(np÷2+1:np), logjacobian!; islast=i==length(network))
        end
    end
end

How to obtain the log-probability of a data point.

# Accumulate the model log-probability of one sample into `out!`.
# `x!` is mapped through the inverse of `realnvp!` in place, with `out!` passed as the
# log-Jacobian accumulator; then `0.5 * x![i]^2` is subtracted for every component of
# the transformed sample. No constant normalisation term is added.
# `out!` is accumulated into, not reset, so callers can sum over many samples.
@i function logp!(out!::T, x!::AbstractVector{T}, network::RealNVP{T}) where T
    # inverse direction of the flow
    (~realnvp!)(x!, network, out!)
    @invcheckoff for i = 1:length(x!)
        # `xsq` is a temporary that is uncomputed right after use.
        @routine begin
            xsq ← zero(T)
            @inbounds xsq += x![i]^2
        end
        out! -= 0.5 * xsq
        ~@routine
    end
end

The negative-log-likelihood loss function

# Negative log-likelihood averaged over the samples stored as columns of `xs!`.
# - `cum!` accumulates the log-probabilities of all columns (via `logp!`).
# - `out!` has the mean of `cum!` over the number of columns subtracted from it.
# The columns of `xs!` are transformed in place by `logp!`.
@i function nll_loss!(out!::T, cum!::T, xs!::Matrix{T}, network::RealNVP{T}) where T
    @invcheckoff for i=1:size(xs!, 2)
        @inbounds logp!(cum!, xs! |> subarray(:,i), network)
    end
    # divide by the sample count; the minus sign turns log-likelihood into a loss
    out! -= cum!/(@const size(xs!, 2))
end

Training

"""
    train(x_data, model; num_epochs = 800)

Fit `model` to `x_data` by minimising `nll_loss!` and return the (mutated) `model`.

# Arguments
- `x_data`: data matrix with one sample per column; a copy is passed to `nll_loss!`
  each epoch, so `x_data` itself is not modified.
- `model`: the network to train; its parameters are overwritten in place.

# Keywords
- `num_epochs`: number of optimisation steps.

Every 50 epochs (starting with the first) the loss is printed and the plot returned by
`showmodel` is displayed.
"""
function train(x_data, model; num_epochs = 800)
    params = collect_params(model)
    # `Adam` and `update!` come from the included Adam.jl.
    optimizer = Adam(; lr=0.01)
    for epoch = 1:num_epochs
        # Forward pass; the reversible function returns all of its (updated) arguments.
        loss, a, b, c = nll_loss!(0.0, 0.0, copy(x_data), model)
        if epoch % 50 == 1
            println("epoch = $epoch, loss = $loss")
            display(showmodel(x_data, model))
        end
        # Backward pass: run the inverse program on gradient-wrapped values, seeding the
        # gradient of the loss with 1.
        _, _, _, gmodel = (~nll_loss!)(GVar(loss, 1.0), GVar(a), GVar(b), GVar(c))
        # Extract the gradient vector once and reuse it for the optimiser step.
        g = grad.(collect_params(gmodel))
        update!(params, g, optimizer)
        dispatch_params!(model, params)
    end
    return model
end

"""
    showmodel(x_data, model; nsamples=2000)

Return a scatter plot comparing the data with samples generated by `model`.

The first two rows of `x_data` are plotted for up to `nsamples` columns (fewer if the
data set is smaller). Then `nsamples` standard-normal points in 2 dimensions are pushed
through `realnvp!` and overlaid on the same axes. Both axes are fixed to `(-5, 5)`.
"""
function showmodel(x_data, model; nsamples=2000)
    # Never index past the end of the data set.
    ndata = min(nsamples, size(x_data, 2))
    scatter(x_data[1,1:ndata], x_data[2,1:ndata]; xlims=(-5,5), ylims=(-5,5))
    zs = randn(2, nsamples)
    for i=1:nsamples
        # Transforms column `i` of `zs` in place; the log-Jacobian is not needed here.
        realnvp!(view(zs, :, i), model, 0.0)
    end
    return scatter!(zs[1,:], zs[2,:])
end
showmodel (generic function with 1 method)

You can find the training data in the examples/ folder.

# Load the training set. The trailing `'` transposes the table read by `readdlm`, so
# samples end up as columns, which is the layout `nll_loss!` iterates over.
x_data = Matrix(readdlm(NiLang.project_relative_path("examples", "train.dat"))')

# Fix the RNG seed so the random initialisation below is reproducible.
import Random; Random.seed!(22)
# 4 coupling layers, with 10 hidden units in both the transform and the scaling network.
model = random_realnvp(Float64, size(x_data, 1), 10, 10, 4; scale=0.1)
4-element Vector{Main.RealNVPLayer{Float64}}:
 Main.RealNVPLayer{Float64}([0.10365283655905022; 0.12832890997547838; … ; 0.10237515285351262; 0.04561278390444176;;], [0.011932671083999709, -0.019461834760040552, 0.09841074615805545, -0.049932719417439035, 0.046770637156686626, 0.094738950442398, -0.07072722954885365, -0.14442322991493425, -0.04050736246766045, 0.0161302098597473], [-0.09076531278487276 -0.08893617376065133 … 0.17688969929809442 0.03332683434777908], [-0.05466360224358683], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.04778285245528021; 0.002073152425026603; … ; 0.08114267515467105; 0.03446269262358759;;], [0.015323637584496678, -0.030552214751628316, 0.013764716144273565, 0.3100695796150575, -0.261192567636749, 0.2497349959460083, 0.028594287515553797, 0.027736168629672472, 0.09212377303866616, -0.023160843492401398], [-0.07950664144707173 0.06607646105462771 … -0.06372020403545663 -0.01468114678613847], [-0.05369268305611946], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
 Main.RealNVPLayer{Float64}([-0.01994222941233885; -0.12354790135192664; … ; 0.0914829134828038; 0.05468284490890261;;], [0.099984884595698, 0.035003360565331286, -0.01215996827091547, -0.012953387260835123, 0.07587204372806554, -0.023394565668740614, -0.0102036220459518, 0.06005046391996989, -0.14314572987011717, 0.13861740098537687], [0.10258920416705185 -0.009778532190417643 … 0.0669491035775922 -0.17052131152319536], [0.03905158369614853], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-0.08079652822924399; -0.012018109547847596; … ; -0.0030313328377010503; -0.1035614822693211;;], [0.001980495005703741, -0.21924945243644256, 0.03722989450292816, -0.28291833693177126, 0.135309838967398, -0.021263453632266387, -0.11186250420913675, -0.08763936642038139, 0.06871598760276267, -0.0008703720158541838], [0.04464510169137048 -0.05871256825097563 … 0.13781815631646493 0.08481594857843244], [-0.029644007854268524], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
 Main.RealNVPLayer{Float64}([-0.0011229834966656473; -0.018484150618418422; … ; -0.07431024618656085; -0.14150021432790136;;], [0.039187136830781594, -0.12125143109625228, -0.07256413563852397, 0.19987054507999624, -0.10242295290772746, -0.14911384317565934, 0.03058692040925597, -0.02203129176327245, 0.08975909617625784, 0.030514846351154208], [-0.021822218622933848 0.08063911073297139 … 0.055913203994493704 0.12560872073648785], [0.03129287825829668], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [-0.07575678587966743; -0.027325439006003274; … ; 0.0043884950547045595; 0.15422607557662982;;], [0.0672283066154055, -0.028755388897846946, -0.028725846779153314, -0.15447792000314187, 0.2355159994529047, 0.04468972794111743, -0.041014372945593236, -0.06755430048865957, 0.05191188641330835, 0.05681891408866492], [0.08576393997374578 -0.010008951785075104 … 0.04834859122916412 -0.10713472938089504], [-0.08648560763345085], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
 Main.RealNVPLayer{Float64}([-0.04095467882435368; 0.03961049214080149; … ; 0.17806833898907104; -0.05059189752404996;;], [0.06365690411873613, -0.0032488405448314554, -0.1309632918320647, -0.15380654594004503, -0.004480789767369362, -0.047857677435183175, 0.0024269400370712747, 0.0674366094545587, -0.05536370041022595, 0.128201780421072], [0.035556705200683984 -0.05831003161535983 … 0.14506785217571846 0.03172703860767502], [-0.12556349867365843], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.09467245402868087; 0.032875343863597126; … ; 0.08410268678096289; -0.24550565506652466;;], [-0.05526664064416595, -0.16639791200381898, 0.002874855601853682, 0.046690716337888495, -0.10520642349013601, -0.060176638380842196, -0.15050928292324414, 0.1691216093670418, 0.0021759240220606576, -0.08163114668408883], [0.0238197672353124 -0.060492452534179286 … -0.25991082553365347 0.005147083829832955], [-0.15655622524638355], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

Before training, the distribution looks like this (figure not included in this text).

# Train for 800 epochs; `train` updates `model` in place and returns it.
model = train(x_data, model; num_epochs=800)

After training, the distribution looks like this (figure not included in this text).


This page was generated using Literate.jl.