using ConvOptDL
using ConvOptDL.Utils
using StatsBase
using LinearAlgebra: I
using ArgParse
using Serialization
using Pipe: @pipe
using Flux: update!, trainmode!
using Statistics: mean
using Logging: @info
using Dates: now
import Flux
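# Meta-training script for ConvOptDL: a ResNet-12 backbone embeds support and
# query samples, a multi-class (Crammer-Singer) SVM head is fit per task by
# solving a batched QP, and the query loss is backpropagated through the
# solver to update the backbone.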
function loss_log_softmax(logits, one_hot_vec)
    # `logsoftmax` is numerically stabler than `log.(softmax(...))`
    log_prob = Flux.logsoftmax(logits, dims = 1)
    loss = @pipe -(log_prob .* one_hot_vec) |> sum(_, dims = 1)
    mean(loss)
end
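# Collapse the last two dimensions (samples per task, tasks) into one batch
# dimension and convert to Float32, so all tasks share a single forward pass.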
_format_batch(batch) = reshape(Float32.(batch), size(batch)[1:end-2]..., :)
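# One meta-training step on `batch`: compute the meta loss inside
# `Flux.gradient`, apply the optimizer update, and return the loss value.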
function train!(loss, model, batch, opt)
    # https://fluxml.ai/Flux.jl/stable/training/training/#Custom-Training-loops-1
    ps = Flux.params(model)
    # smoothed one-hot encoding of the query labels;
    # `onehot` is not differentiable, so it stays outside the gradient block
    query_onehot = @pipe ConvOptDL.Utils.onehot(batch.query_labels) |>
          reshape(_, batch.support_n_ways, :) |>
          _ .* (1 - 5f-2) .+ 5f-2 * (1 .- _) ./ (batch.support_n_ways - 1f0)
    support_onehot = ConvOptDL.Utils.onehot(batch.support_labels)
    local meta_loss
    gs = Flux.gradient(ps) do
        # embeddings: (`feat_dim`, `n_ways`*`k_shots`, `num_tasks`)
        embed_query = reshape(
            model(_format_batch(batch.query_samples)),
            :,
            batch.n_query,
            size(batch),
        )
        embed_support = reshape(
            model(_format_batch(batch.support_samples)),
            :,
            batch.n_support,
            size(batch),
        )
        # build the Crammer-Singer multi-class SVM dual as a batched QP
        Q, p, G, h, A, b = crammer_svm(embed_support, support_onehot, batch)
        # solve the QP for the dual variables (per-class support coefficients)
        X, λ, ν = solve_qp_batch(Q, p, G, h, A, b)
        α = @pipe X |>
              reshape(_, 1, batch.n_support, batch.support_n_ways, size(batch)) |>
              repeat(_, outer = (batch.n_query, 1, 1, 1)) |>
              permutedims(_, (3, 2, 1, 4))
        # logits: query-support Gram matrix weighted by the SVM coefficients,
        # one copy of the Gram matrix per class
        logits = @pipe ConvOptDL.Utils.gram_matrix(embed_query, embed_support) |>
              reshape(_, 1, size(_)...) |>
              repeat(_, outer = (batch.support_n_ways, 1, 1, 1)) |>
              _ .* α |>
              sum(_, dims = 2) |>
              dropdims(_, dims = 2) |>
              reshape(_, batch.support_n_ways, :)
        meta_loss = loss(logits, query_onehot)
        return meta_loss
    end
    update!(opt, ps, gs)
    return meta_loss
end
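# Command-line options; `data_file` is a serialized few-shot dataset loaded
# via `FewShotDataLoader`, and `-o` is where the trained model is written.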
function parse_opts()
    s = ArgParseSettings("Meta Convex Optimization Learning training")
    @add_arg_table! s begin
        "--batch-size"
            arg_type = Int64
            metavar = "INT"
            default = 8
        "--batches-per-episode"
            arg_type = Int64
            metavar = "INT"
            default = 200
        "--log-period"
            help = "print log for every given number of batches"
            arg_type = Int64
            metavar = "INT"
            default = 25
        "--num-episodes"
            arg_type = Int64
            metavar = "INT"
            default = 50
        "--learning-rate"
            arg_type = Float32
            default = 1f-1
        "-o"
            help = "output model file"
            arg_type = String
            default = "model.jls"
        "data_file"
            arg_type = String
            default = "./test/test_data.jls"
    end
    args = parse_args(s)
    return args
end
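# Entry point: runs only when the file is executed as a script, e.g.
# `julia --project train.jl ./test/test_data.jls`, not when `include`d.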
if nameof(@__MODULE__) == :Main
    args = parse_opts()
    @info "Training Options:" args
    batch_size = args["batch-size"]
    batches_per_episode = args["batches-per-episode"]
    log_period = args["log-period"]
    num_episodes = args["num-episodes"]
    data_file = args["data_file"]
    out_model_file = args["o"]
    lr = args["learning-rate"]
    model_name, _ = splitext(out_model_file)
    dloader = FewShotDataLoader(data_file)
    opt = Flux.Optimise.Momentum(lr)
    model = resnet12()
    trainmode!(model)
    record = Dict()
    for episode = 1:num_episodes
        @info "episode $(episode) start" now()
        meta_losses = []
        total_time = 0
        for i = 1:batches_per_episode
            time_start = time()
            batch = sample(dloader, batch_size, support_n_ways = 5, support_k_shots = 2)
            meta_loss = train!(loss_log_softmax, model, batch, opt)
            time_end = time()
            push!(meta_losses, meta_loss)
            total_time += (time_end - time_start)
            if i % log_period == 0
                @info "[$(now())] last $(log_period) meta losses" meta_losses[end-log_period+1:end] mean(meta_losses)
                @info "[$(now())] average execution time" mean_exec_time = (total_time / i)
            end
        end
        @info "episode $(episode) end" now()
        record[episode] = meta_losses
        serialize("train_$(model_name)_meta_losses_$(episode).jls", meta_losses)
    end
    serialize("all_train_$(model_name)_meta_losses.jls", record)
    serialize(out_model_file, model)
end