MyCaffe  1.11.8.27
Deep learning software for Windows C# programmers.
LBFGSSolver.cs
1using MyCaffe.basecode;
2using MyCaffe.common;
3using MyCaffe.db.image;
4using MyCaffe.param;
5using System;
6using System.Collections.Generic;
7using System.Diagnostics;
8using System.Linq;
9using System.Text;
10using System.Threading;
11using System.Threading.Tasks;
12
namespace MyCaffe.solvers
{
    /// <summary>
    /// Optimizes the parameters of a Net using L-BFGS. This implementation is based on minFunc.
    /// </summary>
    /// <remarks>
    /// Every learnable parameter whose lr multiplier is non-zero is treated as part of one flat
    /// vector of length m_nN. The most recent (s, y) correction pairs are stored in a circular
    /// buffer with m_param.lbgfs_corrections slots. lbfgs_history_indices enumerates the live
    /// slots from m_nStart around to m_nEnd. m_nEnd is the slot most recently written, and
    /// m_nEnd == -1 means no pair has been stored yet.
    /// </remarks>
    /// <typeparam name="T">Specifies the base type of the blob data.</typeparam>
    public class LBFGSSolver<T> : Solver<T>
    {
        Blob<T> m_blobGradientsPrev;    // previous flattened gradients; overwritten with y in UpdateHistory.
        Blob<T> m_blobGradients;        // current flattened gradients (filled by CollectGradients).
        Blob<T> m_blobDirection;        // search direction; negated in UpdateHistory to form s.
        BlobCollection<T> m_colBlobHistoryS = new BlobCollection<T>();
        BlobCollection<T> m_colBlobHistoryY = new BlobCollection<T>();
        List<double> m_rgRhoHistory = new List<double>();   // 1 / (y . s) for each buffer slot.
        int m_nStart;
        int m_nEnd;
        int m_nN;
        double m_dfH0;
        double m_dfStep;
        T m_tOne;
        T m_tMinusOne;

        /// <summary>
        /// The LBFGSSolver constructor.
        /// </summary>
        /// <param name="cuda">Specifies the instance of CudaDnn used by the Solver.</param>
        /// <param name="log">Specifies the Log for output.</param>
        /// <param name="p">Specifies the SolverParameter.</param>
        /// <param name="evtCancel">Specifies a CancelEvent passed to the base Solver.</param>
        /// <param name="evtForceSnapshot">Specifies an AutoResetEvent passed to the base Solver.</param>
        /// <param name="evtForceTest">Specifies an AutoResetEvent passed to the base Solver.</param>
        /// <param name="imgDb">Specifies the image database passed to the base Solver.</param>
        /// <param name="persist">Specifies the persistence object used to save and load solver state.</param>
        /// <param name="nSolverCount">Specifies the number of solvers (passed to the base Solver).</param>
        /// <param name="nSolverRank">Specifies the rank of this solver (passed to the base Solver).</param>
        /// <param name="shareNet">Optionally, specifies a Net to share (passed to the base Solver).</param>
        /// <param name="getws">Optionally, specifies a get-workspace callback (passed to the base Solver).</param>
        /// <param name="setws">Optionally, specifies a set-workspace callback (passed to the base Solver).</param>
        public LBFGSSolver(CudaDnn<T> cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
            : base(cuda, log, p, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws)
        {
            m_tOne = (T)Convert.ChangeType(1, typeof(T));
            m_tMinusOne = (T)Convert.ChangeType(-1, typeof(T));
            PreSolve();
        }

        /// <summary>
        /// Releases all resources (GPU and Host) used by the Solver.
        /// </summary>
        protected override void dispose()
        {
            disposeWorkBlobs();

            if (m_colBlobHistoryY != null)
            {
                m_colBlobHistoryY.Dispose();
                m_colBlobHistoryY = null;
            }

            if (m_colBlobHistoryS != null)
            {
                m_colBlobHistoryS.Dispose();
                m_colBlobHistoryS = null;
            }

            base.dispose();
        }

        /// <summary>
        /// Disposes the gradient, previous-gradient and direction work blobs, if allocated.
        /// </summary>
        private void disposeWorkBlobs()
        {
            if (m_blobGradients != null)
            {
                m_blobGradients.Dispose();
                m_blobGradients = null;
            }

            if (m_blobGradientsPrev != null)
            {
                m_blobGradientsPrev.Dispose();
                m_blobGradientsPrev = null;
            }

            if (m_blobDirection != null)
            {
                m_blobDirection.Dispose();
                m_blobDirection = null;
            }
        }

        /// <summary>
        /// Runs the pre-solve, which prepares the Solver to start solving.
        /// </summary>
        /// <remarks>
        /// Counts the learnable values with a non-zero lr multiplier (m_nN). It then allocates
        /// the flat work blobs and the (s, y, rho) history buffer, and resets the buffer to empty.
        /// On failure, all partially allocated blobs are released and the exception is rethrown.
        /// </remarks>
        public void PreSolve()
        {
            try
            {
                BlobCollection<T> net_params = m_net.learnable_parameters;

                m_nN = 0;

                for (int i = 0; i < net_params.Count; i++)
                {
                    if (m_net.params_lr[i] != 0)
                        m_nN += net_params[i].count();
                }

                // Nothing to do, all learnable parameters have lr_mult = 0
                if (m_nN == 0)
                    return;

                m_log.CHECK(m_param.lbgfs_corrections > 0, "The L-BFGS solver requires lbgfs_corrections > 0.");

                List<int> rgShape = new List<int>() { m_nN };
                m_colBlobHistoryS.Clear(true);
                m_colBlobHistoryY.Clear(true);
                m_rgRhoHistory.Clear();
                m_nStart = 0;
                m_nEnd = -1;   // empty history.

                // Release work blobs left over from a previous PreSolve call before reallocating.
                disposeWorkBlobs();

                m_blobGradients = new Blob<T>(m_cuda, m_log, rgShape, false);
                m_blobGradients.Name = "gradients";
                m_blobGradientsPrev = new Blob<T>(m_cuda, m_log, rgShape, false);
                m_blobGradientsPrev.Name = "gradients prev";
                m_blobDirection = new Blob<T>(m_cuda, m_log, rgShape, false);
                m_blobDirection.Name = "direction";

                for (int i = 0; i < m_param.lbgfs_corrections; i++)
                {
                    m_colBlobHistoryS.Add(new Blob<T>(m_cuda, m_log, rgShape, false));
                    m_colBlobHistoryY.Add(new Blob<T>(m_cuda, m_log, rgShape, false));
                    m_rgRhoHistory.Add(0);
                }
            }
            catch (Exception)
            {
                m_colBlobHistoryS.Clear(true);
                m_colBlobHistoryY.Clear(true);
                m_rgRhoHistory.Clear();
                disposeWorkBlobs();

                // 'throw;' preserves the original stack trace.
                throw;
            }
        }

        /// <summary>
        /// Apply the gradients to the network.
        /// </summary>
        /// <remarks>
        /// Runs the L-BFGS pipeline: CollectGradients, UpdateHistory, ComputeInitialHessianApprox,
        /// ComputeDirection, ComputeStep, UpdateNet. m_nIter is then incremented.
        /// </remarks>
        /// <param name="nIterationOverride">Not used by this solver.</param>
        /// <returns>Always returns 0.</returns>
        public override double ApplyUpdate(int nIterationOverride = -1)
        {
            if (m_nN == 0)
            {
                // No parameter is being learned; clear the diffs so nothing stale is applied.
                for (int i = 0; i < m_net.learnable_parameters.Count; i++)
                {
                    m_net.learnable_parameters[i].SetDiff(0);
                }

                // Keep the iteration counter advancing even though no weights change.
                m_nIter++;
                return 0;
            }

            m_log.CHECK(is_root_solver, "You can only apply the LBFGS Solver updates on the root solver.");

            CollectGradients();
            UpdateHistory();
            ComputeInitialHessianApprox();
            ComputeDirection();
            ComputeStep();
            UpdateNet();

            // Increment the internal iter_ counter -- its value should always indicate
            // the number of times the weights have been updated.
            m_nIter++;

            return 0;
        }

        /// <summary>
        /// Collect the gradients from the network learnable parameters.
        /// </summary>
        /// <remarks>
        /// On every iteration but the first, the current gradients are saved into
        /// m_blobGradientsPrev first. The diffs of all parameters with a non-zero lr multiplier
        /// are then concatenated into m_blobGradients.
        /// </remarks>
        public virtual void CollectGradients()
        {
            BlobCollection<T> net_params = m_net.learnable_parameters;

            if (m_nIter != 0)
                m_cuda.copy(m_nN, m_blobGradients.gpu_data, m_blobGradientsPrev.mutable_gpu_data);

            int nDstOffset = 0;
            for (int i = 0; i < net_params.Count; i++)
            {
                if (m_net.params_lr[i] != 0)
                {
                    m_cuda.copy(net_params[i].count(), net_params[i].gpu_diff, m_blobGradients.mutable_gpu_data, 0, nDstOffset);
                    nDstOffset += net_params[i].count();
                }
            }
        }

        /// <summary>
        /// Update the history values with the gradients and direction.
        /// </summary>
        /// <remarks>
        /// Forms s = -direction (in place) and y = g - g_prev (in place in m_blobGradientsPrev).
        /// When y . s is below 1e-10 the pair is discarded and the history is left unchanged.
        /// Otherwise the pair and rho = 1 / (y . s) are written to the next circular-buffer slot.
        /// NOTE(review): s excludes the learning-rate factor applied in UpdateNet, so it equals
        /// the actual parameter change only when the effective lr is 1 -- confirm this is intended.
        /// </remarks>
        public virtual void UpdateHistory()
        {
            if (m_nIter == 0)
                return;

            m_cuda.scal(m_nN, m_tMinusOne, m_blobDirection.mutable_gpu_data); // s
            m_cuda.axpby(m_nN, m_tOne, m_blobGradients.gpu_data, m_tMinusOne, m_blobGradientsPrev.mutable_gpu_data); // y
            T fYs = m_cuda.dot(m_nN, m_blobDirection.gpu_data, m_blobGradientsPrev.gpu_data);
            double dfYs = Utility.ConvertVal<T>(fYs);

            if (dfYs < 1e-10)
            {
                // Keep m_nEnd as-is: if no pair was stored yet it stays -1 (empty history).
                // Setting it to 0 here would make later steps read an unwritten slot with rho = 0.
                m_log.WriteLine("WARNING: Skipping L-BFGS update.");
                return;
            }

            int nCorrections = m_param.lbgfs_corrections;

            m_nEnd += 1;

            if (m_nEnd < nCorrections)
            {
                // Once the buffer has wrapped, the oldest slot advances along with the newest.
                if (m_nStart != 0)
                {
                    m_nStart += 1;

                    if (m_nStart == nCorrections)
                        m_nStart = 0;
                }
            }
            else
            {
                // Wrap around: overwrite slot 0; the oldest live entry is now slot 1.
                m_nStart = 1;
                m_nEnd = 0;
            }

            m_cuda.copy(m_nN, m_blobDirection.gpu_data, m_colBlobHistoryS[m_nEnd].mutable_gpu_data);
            m_cuda.copy(m_nN, m_blobGradientsPrev.gpu_data, m_colBlobHistoryY[m_nEnd].mutable_gpu_data);
            m_rgRhoHistory[m_nEnd] = 1.0 / dfYs;
        }

        /// <summary>
        /// Compute the initial Hessian approximation.
        /// </summary>
        /// <remarks>
        /// Sets m_dfH0 = (y . s) / (y . y) using the newest stored pair. It does nothing on the
        /// first iteration or while the history is empty.
        /// </remarks>
        public virtual void ComputeInitialHessianApprox()
        {
            if (m_nIter == 0 || m_nEnd < 0)
                return;

            T fh0 = m_cuda.dot(m_nN, m_colBlobHistoryY[m_nEnd].gpu_data, m_colBlobHistoryY[m_nEnd].gpu_data);
            double dfH0 = Utility.ConvertVal<T>(fh0);

            m_dfH0 = 1.0 / m_rgRhoHistory[m_nEnd] / dfH0;
        }

        /// <summary>
        /// Returns the live circular-buffer slots in order: [nStart, nMax) followed by [0, nEnd].
        /// When nStart is 0 only [0, nEnd] is returned (empty when nEnd is -1).
        /// </summary>
        /// <param name="nStart">Specifies the start slot.</param>
        /// <param name="nEnd">Specifies the end slot (inclusive), or -1 when empty.</param>
        /// <param name="nMax">Specifies the number of slots in the buffer.</param>
        /// <returns>The list of slot indices.</returns>
        private List<int> lbfgs_history_indices(int nStart, int nEnd, int nMax)
        {
            List<int> rgIndices = new List<int>();

            if (nStart != 0)
            {
                for (int i = nStart; i < nMax; i++)
                {
                    rgIndices.Add(i);
                }
            }

            for (int i = 0; i <= nEnd; i++)
            {
                rgIndices.Add(i);
            }

            return rgIndices;
        }

        /// <summary>
        /// Compute the direction.
        /// </summary>
        /// <remarks>
        /// Starts from the current gradient. When history is available, it applies the L-BFGS
        /// two-loop recursion: the first loop walks the slots in reverse order, the direction is
        /// then scaled by m_dfH0, and the second loop walks the slots in forward order.
        /// </remarks>
        public virtual void ComputeDirection()
        {
            m_cuda.copy(m_nN, m_blobGradients.gpu_data, m_blobDirection.mutable_gpu_data);

            // Without any stored pair the direction is simply the gradient.
            if (m_nIter == 0 || m_nEnd < 0)
                return;

            List<int> rgIndices = lbfgs_history_indices(m_nStart, m_nEnd, m_param.lbgfs_corrections);

            // Indexed by buffer slot, so size it by the number of slots.
            double[] rgAlpha = new double[m_param.lbgfs_corrections];

            for (int i = rgIndices.Count - 1; i >= 0; i--)
            {
                int nIdx = rgIndices[i];

                T fAlpha = m_cuda.dot(m_nN, m_colBlobHistoryS[nIdx].gpu_data, m_blobDirection.gpu_data);
                rgAlpha[nIdx] = Utility.ConvertVal<T>(fAlpha) * m_rgRhoHistory[nIdx];

                m_cuda.axpy(m_nN, -rgAlpha[nIdx], m_colBlobHistoryY[nIdx].gpu_data, m_blobDirection.mutable_gpu_data);
            }

            m_cuda.scal(m_nN, m_dfH0, m_blobDirection.mutable_gpu_data);

            for (int i = 0; i < rgIndices.Count; i++)
            {
                int nIdx = rgIndices[i];

                T fBeta = m_cuda.dot(m_nN, m_colBlobHistoryY[nIdx].gpu_data, m_blobDirection.gpu_data);
                double dfBeta = Utility.ConvertVal<T>(fBeta) * m_rgRhoHistory[nIdx];

                m_cuda.axpy(m_nN, rgAlpha[nIdx] - dfBeta, m_colBlobHistoryS[nIdx].gpu_data, m_blobDirection.mutable_gpu_data);
            }
        }

        /// <summary>
        /// Compute the step. This implementation always uses a fixed step of 1.0 (no line search).
        /// </summary>
        public virtual void ComputeStep()
        {
            m_dfStep = 1.0;
        }

        /// <summary>
        /// Update the network.
        /// </summary>
        /// <remarks>
        /// Scales the direction by the step. Each learned parameter's slice of the direction is
        /// then written into its diff, scaled by lr_mult * base_lr (or copied unchanged when that
        /// product is 1). Parameters with a zero lr multiplier get a zero diff. Finally,
        /// m_net.Update() is called.
        /// </remarks>
        public virtual void UpdateNet()
        {
            m_cuda.scal(m_nN, m_dfStep, m_blobDirection.mutable_gpu_data);

            BlobCollection<T> net_params = m_net.learnable_parameters;

            int nOffset = 0;
            for (int i = 0; i < net_params.Count; i++)
            {
                int nCount = net_params[i].count();

                if (m_net.params_lr[i] != 0)
                {
                    double dfLr = m_net.params_lr[i].GetValueOrDefault(1.0) * m_param.base_lr;

                    if (dfLr != 1.0)
                    {
                        // Use the full effective rate (lr_mult * base_lr), not lr_mult alone.
                        T fLr = (T)Convert.ChangeType(dfLr, typeof(T));
                        m_cuda.scale(nCount, fLr, m_blobDirection.gpu_data, net_params[i].mutable_gpu_diff, nOffset, 0);
                    }
                    else
                    {
                        // Unit rate: the direction must still replace the raw gradient in the diff.
                        m_cuda.copy(nCount, m_blobDirection.gpu_data, net_params[i].mutable_gpu_diff, nOffset, 0);
                    }

                    nOffset += nCount;
                }
                else
                {
                    net_params[i].SetDiff(0);
                }
            }

            m_net.Update();
        }

        /// <summary>
        /// Restore a previously saved solver state.
        /// </summary>
        /// <remarks>
        /// SnapshotSolverState saves the history in lbfgs_history_indices order. The i'th saved
        /// entry is therefore restored into slot rgIndices[i]: S from s_history, Y from history.
        /// </remarks>
        /// <param name="rgState">Specifies the state saved by SnapshotSolverState.</param>
        protected override void RestoreSolverState(byte[] rgState)
        {
            SolverState state = m_persist.LoadSolverState(rgState, m_param.type);

            m_nIter = state.iter;
            m_nStart = state.start;
            m_nEnd = state.end;

            // With no learned parameters no work blobs were allocated; nothing else to restore.
            if (m_nN == 0)
                return;

            List<int> rgIndices = lbfgs_history_indices(m_nStart, m_nEnd, m_param.lbgfs_corrections);

            for (int i = 0; i < rgIndices.Count; i++)
            {
                int nIdx = rgIndices[i];

                m_colBlobHistoryS[nIdx].FromProto(state.s_history[i]);
                m_colBlobHistoryY[nIdx].FromProto(state.history[i]);
                m_rgRhoHistory[nIdx] = state.rho_history[i];
            }

            m_blobGradients.FromProto(state.gradients);
            m_blobDirection.FromProto(state.direction);
        }

        /// <summary>
        /// Save the solver state.
        /// </summary>
        /// <remarks>
        /// Saves the iteration, the buffer bounds, the live (s, y, rho) history in
        /// lbfgs_history_indices order, the current gradients and the direction.
        /// </remarks>
        /// <returns>The serialized solver state.</returns>
        protected override byte[] SnapshotSolverState()
        {
            SolverState state = new SolverState();

            state.iter = m_nIter;
            state.start = m_nStart;
            state.end = m_nEnd;

            // The work blobs only exist when at least one parameter is learned.
            if (m_nN > 0)
            {
                List<int> rgIndices = lbfgs_history_indices(m_nStart, m_nEnd, m_param.lbgfs_corrections);

                for (int i = 0; i < rgIndices.Count; i++)
                {
                    int nIdx = rgIndices[i];

                    state.s_history.Add(m_colBlobHistoryS[nIdx].ToProto());
                    state.history.Add(m_colBlobHistoryY[nIdx].ToProto());
                    state.rho_history.Add(m_rgRhoHistory[nIdx]);
                }

                state.gradients = m_blobGradients.ToProto();
                state.direction = m_blobDirection.ToProto();
            }

            return m_persist.SaveSolverState(state, m_param.type);
        }
    }
}
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
Definition: CancelEvent.cs:17
The Log class provides general output in text form.
Definition: Log.cs:13
void CHECK(bool b, string str)
Test a flag for true.
Definition: Log.cs:227
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
Definition: Log.cs:80
The Utility class provides general utility functions.
Definition: Utility.cs:35
static List< int > Create(int nCount, int nStart, int nInc)
Create a new List and fill it with values starting with start and incrementing by inc.
Definition: Utility.cs:683
The BlobCollection contains a list of Blobs.
void Dispose()
Release all resources used by the collection and its Blobs.
void Add(Blob< T > b)
Add a new Blob to the collection.
void SetDiff(double df)
Set all blob diff to the value specified.
int Count
Returns the number of items in the collection.
void Clear(bool bDispose=false)
Remove all items from the collection.
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:25
long mutable_gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1240
void FromProto(BlobProto bp, bool bReshape=true)
Create a new Blob from a given BlobProto.
Definition: Blob.cs:1342
BlobProto ToProto(bool bWriteDiff=false)
Writes the Blob to a new BlobProto.
Definition: Blob.cs:1416
string Name
Get/set the name of the Blob.
Definition: Blob.cs:1907
virtual void Dispose(bool bDisposing)
Releases all resources used by the Blob (including both GPU and Host).
Definition: Blob.cs:368
long gpu_data
Returns the data GPU handle used by the CudaDnn connection.
Definition: Blob.cs:1232
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
Definition: CudaDnn.cs:851
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
Definition: Net.cs:23
The SolverParameter is a parameter for the solver, specifying the train and test networks.
int lbgfs_corrections
Specifies the number of L-BFGS corrections used with the L-BFGS solver.
double base_lr
The base learning rate.
SolverType type
Specifies the solver type.
The SolverState specifies the state of a given solver.
Definition: SolverState.cs:14
int end
Specifies the end used by L-BFGS
Definition: SolverState.cs:55
BlobProto gradients
Gradients used with L-BFGS state.
Definition: SolverState.cs:82
List< double > rho_history
rho history used with L-BFGS state.
Definition: SolverState.cs:109
int iter
The current iteration.
Definition: SolverState.cs:37
List< BlobProto > history
The history for SGD solvers.
Definition: SolverState.cs:64
int start
Specifies the start used by L-BFGS
Definition: SolverState.cs:46
int current_step
The current step for learning rate.
Definition: SolverState.cs:73
List< BlobProto > s_history
S history used with L-BFGS state.
Definition: SolverState.cs:100
BlobProto direction
Direction used with L-BFGS state.
Definition: SolverState.cs:91
Optimizes the parameters of a Net using L-BFGS. This implementation is based on minFunc,...
Definition: LBFGSSolver.cs:26
virtual void CollectGradients()
Collect the gradients from the network learnable parameters.
Definition: LBFGSSolver.cs:214
override void dispose()
Releases all resources (GPU and Host) used by the Solver.
Definition: LBFGSSolver.cs:70
virtual void UpdateNet()
Update the network.
Definition: LBFGSSolver.cs:370
virtual void ComputeStep()
Compute the step.
Definition: LBFGSSolver.cs:362
override double ApplyUpdate(int nIterationOverride=-1)
Apply the gradients to the network.
Definition: LBFGSSolver.cs:183
void PreSolve()
Runs the pre-solve, which prepares the Solver to start solving.
Definition: LBFGSSolver.cs:108
virtual void UpdateHistory()
Update the history values with the gradients and direction.
Definition: LBFGSSolver.cs:235
virtual void ComputeInitialHessianApprox()
Compute the initial Hessian approximation.
Definition: LBFGSSolver.cs:280
LBFGSSolver(CudaDnn< T > cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
The LBFGSSolver constructor.
Definition: LBFGSSolver.cs:58
virtual void ComputeDirection()
Compute the direction.
Definition: LBFGSSolver.cs:323
override byte[] SnapshotSolverState()
Save the solver state.
Definition: LBFGSSolver.cs:434
override void RestoreSolverState(byte[] rgState)
Restore a previously saved solver state.
Definition: LBFGSSolver.cs:406
An interface for classes that perform optimization on Nets
Definition: Solver.cs:28
SolverParameter m_param
Specifies the SolverParameter that defines how the Solver operates.
Definition: Solver.cs:40
CudaDnn< T > m_cuda
Specifies the instance of CudaDnn used by the Solver that provides a connection to Cuda.
Definition: Solver.cs:32
bool? is_root_solver
Returns whether or not this is the root solver.
Definition: Solver.cs:1274
int m_nIter
Specifies the current iteration.
Definition: Solver.cs:52
IXPersist< T > m_persist
Specifies the persistence object used to save weight and solver states.
Definition: Solver.cs:85
Net< T > m_net
Specifies the training Net.
Definition: Solver.cs:44
int m_nCurrentStep
Specifies the current step.
Definition: Solver.cs:56
Log m_log
Specifies the Log for output.
Definition: Solver.cs:36
The IXImageDatabaseBase interface defines the general interface to the in-memory image database.
Definition: Interfaces.cs:415
The IXPersist interface is used by the CaffeControl to load and save weights.
Definition: Interfaces.cs:175
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
The MyCaffe.db.image namespace contains all image database related classes.
Definition: Database.cs:18
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12