MyCaffe  1.11.8.27
Deep learning software for Windows C# programmers.
Solver.cs
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Threading;
6using System.IO;
7using System.Diagnostics;
8using System.Collections;
9using MyCaffe.basecode;
10using MyCaffe.db.image;
11using MyCaffe.common;
12using MyCaffe.param;
13
17namespace MyCaffe.solvers
18{
27 public abstract class Solver<T> : IDisposable
28 {
32 protected CudaDnn<T> m_cuda;
36 protected Log m_log;
44 protected Net<T> m_net;
48 protected List<Net<T>> m_rgTestNets = new List<Net<T>>();
52 protected int m_nIter;
56 protected int m_nCurrentStep;
60 protected List<double> m_rgLosses = new List<double>();
61 AutoResetEvent m_evtCompleted = new AutoResetEvent(false);
62 bool m_bEnableTest = true;
63 bool m_bEnableBlobDebugging = false;
64 bool m_bEnableBreakOnNan = false;
65 bool m_bEnableDetailedNanDetection = false;
66 bool m_bEnableSingleStep = false;
70 protected double m_dfSmoothedLoss = 0;
71 CancelEvent m_evtCancel;
72 AutoResetEvent m_evtForceSnapshot;
73 AutoResetEvent m_evtForceTest;
77 protected int m_nSolverCount = 1;
81 protected int m_nSolverRank = 0;
89 protected double m_dfLearningRateOverride = 0;
90 double m_dfLastAccuracy = 0;
91 double m_dfLastError = double.MaxValue;
92 double m_dfBestAccuracy = 0;
93 double m_dfBestError = double.MaxValue;
94 IXImageDatabaseBase m_db = null;
95 int m_nTrainingIterationOverride = -1;
96 int m_nTestingIterationOverride = -1;
97 object m_tag = null;
98 bool m_bWeightsUpdated = false;
99 static object m_syncGetRi = new object();
100 Blob<T> m_blobBatchInputData = null;
101 double m_dfAverageTestTime = 0;
102 SNAPSHOT_WEIGHT_UPDATE_METHOD m_snapshotWeightUpdatemMethod = SNAPSHOT_WEIGHT_UPDATE_METHOD.FAVOR_ACCURACY;
103 int m_nTrainingTimeLimitInMinutes = 0;
104 long m_hWorkspaceData = 0; // shared among the layers and nets, only grows in size.
105 ulong m_lWorkspaceSize = 0;
106 bool m_bFirstNanError = true;
107 List<double> m_rgAverageAccuracyWindow = null;
108
112 public event EventHandler OnStart;
116 public event EventHandler OnAborted;
120 public event EventHandler<GradientsReadyArgs> OnGradientsReady;
124 public event EventHandler<SnapshotArgs> OnSnapshot;
128 public event EventHandler<TrainingIterationArgs<T>> OnTrainingIteration;
132 public event EventHandler<TestingIterationArgs<T>> OnTestingIteration;
136 public event EventHandler<TestResultArgs<T>> OnTestResults;
140 public event EventHandler<TestArgs> OnTest;
144 public event EventHandler OnTestStart;
149 public event EventHandler<CustomForwardBackArgs<T>> OnCustomForwardBack;
153 public event EventHandler<WorkspaceArgs> OnGetWorkspace;
157 public event EventHandler<WorkspaceArgs> OnSetWorkspace;
158
        /// <summary>
        /// The Solver constructor stores the CUDA connection, logging, cancel/force
        /// events and database references, wires up optional workspace handlers,
        /// pre-fills the accuracy averaging window, and initializes the nets.
        /// </summary>
        /// <param name="cuda">Specifies the CudaDnn connection used for all GPU operations.</param>
        /// <param name="log">Specifies the output log.</param>
        /// <param name="p">Specifies the solver parameter used to configure the solver and its nets.</param>
        /// <param name="evtCancel">Specifies the event used to cancel training.</param>
        /// <param name="evtForceSnapshot">Specifies the event used to force a snapshot.</param>
        /// <param name="evtForceTest">Specifies the event used to force a test cycle.</param>
        /// <param name="imgDb">Specifies the in-memory image database (may be null).</param>
        /// <param name="persist">Specifies the object used to persist weights and solver state.</param>
        /// <param name="nSolverCount">Specifies the number of solvers participating in multi-GPU training (default = 1).</param>
        /// <param name="nSolverRank">Specifies this solver's rank; rank 0 is the root solver (default = 0).</param>
        /// <param name="shareNet">Optionally specifies a net whose trained layers are shared with the training net.</param>
        /// <param name="getws">Optional callback used to retrieve shared workspace memory.</param>
        /// <param name="setws">Optional callback used to set shared workspace memory.</param>
        public Solver(CudaDnn<T> cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
        {
            m_cuda = cuda;
            m_log = log;
            m_evtCancel = evtCancel;
            m_evtForceSnapshot = evtForceSnapshot;
            m_evtForceTest = evtForceTest;

            // NOTE(review): the statement that originally formed the body of this
            // 'if' appears to have been stripped from this view of the file;
            // as written, the 'if' now governs the 'm_db' assignment below —
            // confirm against the original source.
            if (m_log.IsEnabled)

            m_db = imgDb;
            m_persist = persist;
            m_nSolverCount = nSolverCount;
            m_nSolverRank = nSolverRank;

            // Route workspace queries to caller-supplied handlers when provided;
            // otherwise the solver manages its own shared workspace memory.
            if (getws != null)
                OnGetWorkspace += new EventHandler<WorkspaceArgs>(getws);

            if (setws != null)
                OnSetWorkspace += new EventHandler<WorkspaceArgs>(setws);

            // Pre-fill the accuracy window with zeros so the running average is
            // defined from the very first test cycle.
            if (p.accuracy_average_window > 0)
            {
                m_rgAverageAccuracyWindow = new List<double>();
                for (int i = 0; i < p.accuracy_average_window; i++)
                {
                    m_rgAverageAccuracyWindow.Add(0);
                }
            }

            Init(p, shareNet);
        }
208
        /// <summary>
        /// Releases all resources (GPU and Host) used by the Solver by delegating
        /// to the protected, overridable dispose() method.
        /// </summary>
        public void Dispose()
        {
            dispose();
        }
216
221 {
222 get { return m_dfLearningRateOverride; }
223 set { m_dfLearningRateOverride = value; }
224 }
225
231 {
232 int nTimingCount = 0;
233 double dfTotalTime = 0;
234 return fireOnTrainingIterationEvent(false, 0, 0, ref nTimingCount, ref dfTotalTime);
235 }
236
        /// <summary>
        /// Fires the OnTrainingIteration event on the root solver, optionally collecting
        /// blob debug information and scanning it for NaN values.
        /// </summary>
        /// <param name="bFwdPassNanFree">Specifies whether the forward pass completed without NaNs.</param>
        /// <param name="dfLoss">Specifies the raw loss of the iteration.</param>
        /// <param name="dfLastLearningRate">Specifies the learning rate used on the last update.</param>
        /// <param name="nTimingCount">Specifies the number of timing samples accumulated; reset to 0 after firing.</param>
        /// <param name="dfTotalTime">Specifies the accumulated iteration time in ms; reset to 0 after firing.</param>
        /// <returns>Returns false when a NaN was detected (training should stop), otherwise true.</returns>
        private bool fireOnTrainingIterationEvent(bool bFwdPassNanFree, double dfLoss, double dfLastLearningRate, ref int nTimingCount, ref double dfTotalTime)
        {
            // Only the root solver reports training progress.
            if (is_root_solver && OnTrainingIteration != null)
            {
                string strFirstNanBlob = null;
                DebugInformation<T> dbgInfo = null;

                if (m_bEnableBlobDebugging)
                {
                    dbgInfo = TrainingNet.GetDebugInformation(m_bEnableDetailedNanDetection);

                    if (m_bEnableBreakOnNan && dbgInfo != null)
                    {
                        string strType;
                        strFirstNanBlob = dbgInfo.DetectFirstNaN(out strType);

                        if (strFirstNanBlob != null)
                        {
                            string strPass = (!bFwdPassNanFree) ? "Forward" : "Backward";
                            m_log.WriteLine("First NaN detected in the '" + strType + "' of blob '" + strFirstNanBlob + "' after " + strPass + " pass.");

                            string strTypeLast;
                            string strLastNanBlob = dbgInfo.DetectLastNaN(out strTypeLast);

                            // Only report the last NaN when it differs from the first.
                            if (strLastNanBlob != strFirstNanBlob && strType != strTypeLast)
                                m_log.WriteLine("Last NaN detected in the '" + strTypeLast + "' of blob '" + strLastNanBlob + "' after " + strPass + " pass.");
                        }
                    }
                }

                // Average the per-iteration timing collected since the last event.
                double dfTime = (nTimingCount > 0) ? (dfTotalTime / nTimingCount) : 0;
                OnTrainingIteration(this, new TrainingIterationArgs<T>(m_nIter, m_dfLastAccuracy, dfLoss, m_dfSmoothedLoss, m_dfBestError, m_bWeightsUpdated, m_net.ActiveLabelCounts, m_net.LabelQueryHitPercents, m_net.LabelQueryEpochs, m_net.BoostQueryHitPercents, dfLastLearningRate, dfTime, dbgInfo));
                dfTotalTime = 0;
                nTimingCount = 0;

                // A detected NaN stops training.
                if (strFirstNanBlob != null)
                {
                    m_log.WriteLine("Training is now stopping at iteration " + m_nIter.ToString("N0") + " as the first NaN has been detected ('" + strFirstNanBlob + "').");
                    return false;
                }
            }

            return true;
        }
281
286 {
287 get { return m_nTrainingTimeLimitInMinutes; }
288 set { m_nTrainingTimeLimitInMinutes = value; }
289 }
290
295 {
296 get { return m_snapshotWeightUpdatemMethod; }
297 set { m_snapshotWeightUpdatemMethod = value; }
298 }
299
304 {
305 get { return m_db; }
306 }
307
311 protected virtual void dispose()
312 {
313 if (m_net != null)
314 {
315 m_net.Dispose();
316 m_net = null;
317 }
318
319 foreach (Net<T> net in m_rgTestNets)
320 {
321 net.Dispose();
322 }
323
324 m_rgTestNets.Clear();
325
326 if (m_blobBatchInputData != null)
327 {
328 m_blobBatchInputData.Dispose();
329 m_blobBatchInputData = null;
330 }
331
332 if (m_hWorkspaceData != 0)
333 {
334 m_cuda.FreeMemory(m_hWorkspaceData);
335 m_hWorkspaceData = 0;
336 m_lWorkspaceSize = 0;
337 }
338 }
339
343 public bool EnableTesting
344 {
345 get { return m_bEnableTest; }
346 set { m_bEnableTest = value; }
347 }
348
353 {
354 get { return m_bEnableBlobDebugging; }
355 set { m_bEnableBlobDebugging = value; }
356 }
357
365 {
366 get { return TrainingNet.EnableLayerDebugging; }
367 set { TrainingNet.EnableLayerDebugging = value; }
368 }
369
374 {
375 get { return m_bEnableBreakOnNan; }
376 set { m_bEnableBreakOnNan = value; }
377 }
378
387 {
388 get { return m_bEnableDetailedNanDetection; }
389 set { m_bEnableDetailedNanDetection = value; }
390 }
391
396 {
397 get { return m_bEnableSingleStep; }
398 set { m_bEnableSingleStep = value; }
399 }
400
404 public bool WeightsUpdated
405 {
406 get { return m_bWeightsUpdated; }
407 set { m_bWeightsUpdated = value; }
408 }
409
413 public object Tag
414 {
415 get { return m_tag; }
416 set { m_tag = value; }
417 }
418
423 {
424 get
425 {
426 if (m_rgTestNets.Count == 0)
427 return null;
428
429 return m_rgTestNets[0];
430 }
431 }
432
437 {
438 get { return m_net; }
439 }
440
        /// <summary>
        /// Initializes the solver from the given parameters: stores them, builds the
        /// training and testing nets, and resets the iteration counters.
        /// </summary>
        /// <param name="p">Specifies the solver parameters.</param>
        /// <param name="shareNet">Optionally specifies a net whose trained layers are shared with the training net.</param>
        public void Init(SolverParameter p, Net<T> shareNet = null)
        {
            m_log.WriteLine("Initializing solver from parameters: " + p.DebugString());
            m_param = p;
            m_log.CHECK_GE(m_param.average_loss, 1, "Average loss should be non-negative and >= 1.0.");

            // NOTE(review): the statement seeding the random number generator that
            // originally formed the body of this 'if' appears to have been stripped
            // from this view of the file; as written the 'if' governs the
            // InitTrainNet call below — confirm against the original source.
            if (m_param.random_seed >= 0)

            // Scaffolding code.
            InitTrainNet(shareNet);
            InitTestNets();

            if (is_root_solver)
                m_log.WriteLine("Solver scaffolding done.");

            Reset();
        }
464
468 public void Reset()
469 {
470 m_nIter = 0;
471 m_nCurrentStep = 0;
472 }
473
478 protected void InitTrainNet(Net<T> shareNet = null)
479 {
480 try
481 {
482 int num_train_nets = ((m_param.net_param != null) ? 1 : 0) + ((m_param.train_net_param != null) ? 1 : 0);
483 string field_names = "net_param, train_net_param";
484 m_log.CHECK_GE(num_train_nets, 1, "SolverParameter must specify a train net using one of these fields: " + field_names);
485 m_log.CHECK_LE(num_train_nets, 1, "SolverParameter must not contain more than one of these fields specifying a train_net: " + field_names);
486 NetParameter net_param = null;
487
488 if (m_param.train_net_param != null)
489 {
490 m_log.WriteLine("Creating training net specified in train_net_param.");
491 net_param = m_param.train_net_param.Clone(true);
492 }
493
494 if (m_param.net_param != null)
495 {
496 m_log.WriteLine("Creating training net specified in net_param.");
497 net_param = m_param.net_param.Clone(true);
498 }
499
500 // Set the correct NetState. We start with the solver defaults (lowest
501 // precedence); then, merge in any NetState specified by the net_param itself;
502 // finally, merge in any NetState specified by the train-state (highest
503 // precedence).
504 NetState net_state = new NetState();
505 net_state.phase = Phase.TRAIN;
506 net_state.MergeFrom(net_param.state);
507 net_state.MergeFrom(m_param.train_state);
508 net_param.state = net_state;
509 net_param.solver_count = m_nSolverCount;
510 net_param.solver_rank = m_nSolverRank;
511 m_net = new Net<T>(m_cuda, m_log, net_param, m_evtCancel, m_db, Phase.NONE, m_evtCompleted, shareNet, net_OnGetWorkspace, net_OnSetWorkspace);
512 m_net.OnGetIteration += net_OnGetIteration;
513 }
514 catch(Exception excpt)
515 {
516 throw new Exception("Initializing Training Net: " + excpt.Message);
517 }
518 }
519
520 private void net_OnSetWorkspace(object sender, WorkspaceArgs e)
521 {
522 if (OnSetWorkspace != null)
523 {
524 OnSetWorkspace(sender, e);
525 return;
526 }
527
528 if (e.Size <= m_lWorkspaceSize)
529 return;
530
531 m_lWorkspaceSize = e.Size;
532 m_cuda.DisableGhostMemory();
533
534 if (m_hWorkspaceData != 0)
535 m_cuda.FreeMemory(m_hWorkspaceData);
536
537 m_hWorkspaceData = m_cuda.AllocMemory((long)m_lWorkspaceSize);
538 m_cuda.ResetGhostMemory();
539 }
540
541 private void net_OnGetWorkspace(object sender, WorkspaceArgs e)
542 {
543 if (OnGetWorkspace != null)
544 {
545 OnGetWorkspace(sender, e);
546 return;
547 }
548
549 e.Data = m_hWorkspaceData;
550 e.Size = m_lWorkspaceSize;
551 }
552
        /// <summary>
        /// Supplies the current training iteration to a net that requests it.
        /// </summary>
        /// <param name="sender">Specifies the net making the request.</param>
        /// <param name="e">Receives the TRAIN phase and the current iteration.</param>
        private void net_OnGetIteration(object sender, GetIterationArgs e)
        {
            e.SetIteration(Phase.TRAIN, m_nIter);
        }
557
561 protected void InitTestNets()
562 {
563 try
564 {
565 int num_generic_nets = ((m_param.net_param != null) ? 1 : 0);
566 int num_test_net_params = m_param.test_net_param.Count;
567 int num_test_nets = num_test_net_params;
568
569 if (num_generic_nets > 0)
570 m_log.CHECK_GE(m_param.test_iter.Count, num_test_nets, "test_iter must be specified fore each test network.");
571 else
572 m_log.CHECK_EQ(m_param.test_iter.Count, num_test_nets, "test_iter must be specified fore each test network.");
573
574 // If we have a generic net (specified by net or net_param, rather than
575 // test_net or test_net_param), we may have an unlimited number of actual
576 // test networks -- the actual number is given by the number of remaining
577 // test_iters after any test nets specified by test_net_param and/or test_net
578 // are evaluated.
579 int num_generic_net_instances = m_param.test_iter.Count - num_test_nets;
580 int num_test_net_instances = num_test_nets + num_generic_net_instances;
581
582 if (m_param.test_state.Count > 0)
583 m_log.CHECK_EQ(m_param.test_state.Count, num_test_net_instances, "test_state must be unspecified or specified once per test net.");
584
585 if (num_test_net_instances > 0)
586 m_log.CHECK_GT(m_param.test_interval, 0, "The test interval must be greater than zero.");
587
588 List<string> sources = new List<string>();
589 List<NetParameter> net_params = new List<NetParameter>();
590
591 for (int i = 0; i < num_test_net_params; i++)
592 {
593 sources.Add("test_net_param");
594 net_params.Add(m_param.test_net_param[i].Clone());
595 }
596
597 int remaining_test_nets = m_param.test_iter.Count - num_test_net_params;
598
599 if (m_param.net_param != null)
600 {
601 for (int i = 0; i < remaining_test_nets; i++)
602 {
603 sources.Add("net_param");
604 net_params.Add(m_param.net_param.Clone());
605 }
606 }
607
608 m_rgTestNets = new List<Net<T>>();
609
610 for (int i = 0; i < num_test_net_instances; i++)
611 {
612 // Set the correct NetState. We start with the solver defaults (lowest
613 // precedence); then, merge in any NetState specified by the net_param
614 // itself; finally, merge in any NetState specified by the test_state
615 // (highest precedence).
616 NetState net_state = new NetState();
617 net_state.phase = Phase.TEST;
618 net_state.MergeFrom(net_params[i].state);
619
620 if (m_param.test_state.Count > 0)
621 net_state.MergeFrom(m_param.test_state[i]);
622
623 net_params[i].state = net_state;
624
625 m_log.WriteLine("Creating test net (#" + i.ToString() + ") specified by " + sources[i], true);
626 Net<T> net = new Net<T>(m_cuda, m_log, net_params[i], m_evtCancel, m_db, Phase.NONE, null, TrainingNet, net_OnGetWorkspace, net_OnSetWorkspace);
627
628 m_rgTestNets.Add(net);
629 m_rgTestNets[i].set_debug_info(m_param.debug_info);
630 }
631 }
632 catch (Exception excpt)
633 {
634 throw new Exception("Initializing Testing Nets: " + excpt.Message);
635 }
636 }
637
642 {
643 get { return m_cuda; }
644 }
645
649 public string ActiveLabelCounts
650 {
651 get { return m_net.ActiveLabelCounts; }
652 }
653
658 {
659 get { return m_net.LabelQueryHitPercents; }
660 }
661
665 public string LabelQueryEpochs
666 {
667 get { return m_net.LabelQueryEpochs; }
668 }
669
674 {
675 get { return m_nIter; }
676 }
677
682 {
683 get { return m_param.max_iter; }
684 }
685
690 {
691 get
692 {
693 int nIters = m_param.max_iter - m_nIter;
694
695 if (m_nTrainingIterationOverride > 0)
696 nIters = m_nTrainingIterationOverride;
697
698 return nIters;
699 }
700 }
701
706 {
707 get
708 {
709 int nIters = (m_param.test_iter.Count == 0) ? 0 : m_param.test_iter[0];
710
711 if (m_nTestingIterationOverride > 0)
712 nIters = m_nTestingIterationOverride;
713
714 return nIters;
715 }
716 }
717
        /// <summary>
        /// Runs the main solver loop on the root solver: optionally restores weights
        /// and state, calls Step() for the full iteration count, takes a final
        /// snapshot when configured, and runs a final display/test pass.
        /// </summary>
        /// <param name="nIterationOverride">Optionally overrides the number of training iterations (default = -1, use TrainingIterations).</param>
        /// <param name="rgWeights">Optionally specifies weight bytes to restore before solving.</param>
        /// <param name="rgState">Optionally specifies solver-state bytes to restore before solving.</param>
        /// <param name="step">Optionally specifies single-stepping behavior (default = NONE).</param>
        public virtual void Solve(int nIterationOverride = -1, byte[] rgWeights = null, byte[] rgState = null, TRAIN_STEP step = TRAIN_STEP.NONE)
        {
            m_log.CHECK(is_root_solver, "Solve is only supported by the root solver.");
            m_log.WriteLine("Solving " + m_net.name);
            // NOTE(review): "Learing" typo below is in the emitted log text; left
            // unchanged here since log consumers may match on it.
            m_log.WriteLine("Learing Rate Policy: " + m_param.lr_policy);

            if (rgWeights != null || rgState != null)
                Restore(rgWeights, rgState);

            // For a network that is trained by the solver, no bottom or top vecs
            // should be given, and we will just provide dummy vecs.
            int start_iter = m_nIter;

            if (nIterationOverride <= 0)
                nIterationOverride = TrainingIterations;

            // Step returns false when training was stopped (e.g. NaN detected).
            if (!Step(nIterationOverride, step))
                return;

            // If we haven't already, save a snapshot after optimization, unless
            // overriden by setting snapshot_after_train = false.
            if (step == TRAIN_STEP.NONE && (m_param.snapshot_after_train && (m_param.snapshot == 0 || (m_nIter % m_param.snapshot) != 0)))
                Snapshot(false, true);
            else if (m_net.learnable_parameters.SnapshotRequested(true))
                Snapshot(true, false);

            if (m_evtCancel.WaitOne(0))
            {
                m_log.WriteLine("Optimization stopped early.");
                return;
            }

            // After the optimization is done, run an additional train and test pass to
            // display the train and test loss/outputs if appropriate (based on the
            // display and test_interval settings, respectively). Unlike in the rest of
            // training, for the train net we only run a forward pass as we've already
            // updated the parameters 'max_iter' times -- this final pass is only done to
            // display the loss, which is computed in the forward pass.
            if (m_param.display > 0 && (m_nIter % m_param.display) == 0)
            {
                double dfLoss;
                m_net.Forward(out dfLoss);

                UpdateSmoothedLoss(dfLoss, start_iter);
                m_log.WriteLine("Iteration " + m_nIter + ", loss = " + m_dfSmoothedLoss.ToString());
            }

            // NOTE(review): the condition that originally guarded this block
            // (presumably a test_interval check) appears to have been stripped from
            // this view of the file, leaving a bare block — confirm against the
            // original source.
            {
                if (m_bEnableTest)
                    TestAll();
            }

            m_log.WriteLine("Optimization done.");

            // Release the cached batch input blob now that solving is complete.
            if (m_blobBatchInputData != null)
            {
                m_blobBatchInputData.Dispose();
                m_blobBatchInputData = null;
            }
        }
787
        /// <summary>
        /// Runs the inner training loop for 'nIters' iterations: forward/backward
        /// passes (accumulated over iter_size), smoothed-loss tracking, periodic
        /// testing, display, snapshotting, and weight updates.
        /// </summary>
        /// <param name="nIters">Specifies the number of iterations to run.</param>
        /// <param name="step">Optionally specifies single-stepping behavior (default = NONE).</param>
        /// <param name="bZeroDiffs">Optionally specifies whether parameter diffs are zeroed at the start of each iteration (default = true).</param>
        /// <param name="bApplyUpdates">Optionally specifies whether ApplyUpdate is called; when false the loop runs a single iteration (default = true).</param>
        /// <param name="bDisableOutput">Optionally disables log output (default = false).</param>
        /// <param name="bDisableProgress">Optionally disables progress reporting (default = false).</param>
        /// <param name="dfLossOverride">Optionally overrides the computed loss (default = null).</param>
        /// <param name="bAllowSnapshot">Optionally allows snapshots even when stepping (default = null).</param>
        /// <returns>Returns true on normal completion, false when a NaN stop was triggered via the iteration event.</returns>
        public bool Step(int nIters, TRAIN_STEP step = TRAIN_STEP.NONE, bool bZeroDiffs = true, bool bApplyUpdates = true, bool bDisableOutput = false, bool bDisableProgress = false, double? dfLossOverride = null, bool? bAllowSnapshot = null)
        {
            Exception err = null;

            try
            {
                BlobCollection<T> colBottom = new BlobCollection<T>();
                int start_iter = m_nIter;
                int stop_iter = m_nIter + nIters;

                m_rgLosses.Clear();
                // NOTE(review): a statement that originally followed the Clear()
                // above appears to have been stripped from this view of the file —
                // confirm against the original source.

                // Break on first NaN is a debugging tool
                // that causes the network to stop training
                // right after a NaN is discovered either
                // just after the forward pass or just
                // after the backward pass.
                m_net.EnableBreakOnFirstNaN = m_bEnableBreakOnNan && m_bEnableBlobDebugging;
                m_net.EnableDetailedNanDetection = m_bEnableDetailedNanDetection & m_bEnableBlobDebugging;

                Stopwatch sw = new Stopwatch();
                sw.Start();

                Stopwatch swTimeout = new Stopwatch();
                swTimeout.Start();

                while (m_nIter < stop_iter && !m_evtCompleted.WaitOne(0))
                {
                    // zero-init the params.
                    if (bZeroDiffs)
                        m_net.ClearParamDiffs();

                    if (OnStart != null)
                        OnStart(this, new EventArgs());

                    // NOTE(review): the tail of this compound test condition appears
                    // to have been stripped from this view of the file (the '&&' is
                    // left dangling) — confirm against the original source.
                    if (step == TRAIN_STEP.NONE && (forceTest ||
                        (m_param.test_interval > 0 &&
                        (m_nIter % m_param.test_interval) == 0 &&
                    {
                        if (m_bEnableTest && is_root_solver)
                            m_dfLastAccuracy = TestAll();

                        // Break out of the while loop because a stop was requested while testing.
                        if (m_evtCancel.WaitOne(0))
                            break;
                    }

                    // on_start currently not used, so no event added.
                    bool bDisplay1 = (is_root_solver && m_param.display > 0 && (m_nIter % m_param.display) == 0 && !bDisableOutput) ? true : false;
                    m_net.set_debug_info(bDisplay1 && m_param.debug_info);

                    // accumulate the loss and gradient
                    double dfLoss = 0;
                    double dfLossTotal = 0;
                    int nIterCount = 0;

                    Stopwatch swTiming = new Stopwatch();
                    double dfTotalTime = 0;
                    int nTimingCount = 0;
                    bool bFwdPassNanFree = true;

                    for (int i = 0; i < m_param.iter_size; i++)
                    {
                        double dfLocalLoss;

                        swTiming.Restart();

                        if (OnCustomForwardBack != null)
                        {
                            // NOTE(review): the construction of 'args' appears to
                            // have been stripped from this view of the file —
                            // confirm against the original source.
                            OnCustomForwardBack(this, args);
                            bFwdPassNanFree = args.FwdPassNanFree;
                            dfLocalLoss = args.LocalLoss;
                        }
                        else
                        {
                            bFwdPassNanFree = m_net.ForwardBackward(colBottom, out dfLocalLoss, step);
                        }

                        // Report (once) an invalid local loss, but keep accumulating.
                        if (double.IsNaN(dfLocalLoss) || double.IsInfinity(dfLocalLoss))
                        {
                            if (m_bFirstNanError)
                            {
                                m_log.WriteError(new Exception("The local loss at iteration " + m_nIter.ToString() + " is invalid (NAN or INFINITY)!"));
                                m_bFirstNanError = false;
                            }
                        }

                        dfLossTotal += dfLocalLoss;
                        swTiming.Stop();

                        dfTotalTime += swTiming.Elapsed.TotalMilliseconds;
                        nTimingCount++;
                        nIterCount++;

                        if (!bFwdPassNanFree)
                            break;
                    }

                    dfLoss = dfLossTotal / nIterCount;
                    dfLoss = dfLossOverride.GetValueOrDefault(dfLoss);

                    // average the loss across iterations for smoothed reporting
                    UpdateSmoothedLoss(dfLoss, start_iter);

                    bool bDisplay = false;
                    if (!bDisplay1 && sw.ElapsedMilliseconds > 2000 && !bDisableOutput)
                    {
                        bDisplay = true;
                        m_bFirstNanError = true;
                        sw.Restart();
                    }

                    if (bDisplay && bDisplay1)
                    {
                        m_log.WriteLine("Iteration " + m_nIter.ToString() + ", loss = " + m_dfSmoothedLoss.ToString());

                        BlobCollection<T> colResult = m_net.output_blobs;
                        int score_index = 0;

                        if (is_root_solver)
                        {
                            for (int j = 0; j < colResult.Count; j++)
                            {
                                double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());
                                int nIdx = m_net.output_blob_indices[j];
                                string output_name = m_net.blob_names[nIdx];
                                double loss_weight = m_net.blob_loss_weights[nIdx];
                                double dfTotalLossWeight = 0;
                                int nResultCount = colResult[j].count();

                                for (int k = 0; k < nResultCount; k++)
                                {
                                    // NOTE(review): the 'if' condition that originally
                                    // selected between this block and the 'else' below
                                    // appears to have been stripped from this view of
                                    // the file — confirm against the original source.
                                    {
                                        string strOut = "";

                                        if (loss_weight != 0)
                                            strOut += " (* " + loss_weight.ToString() + " = " + (loss_weight * result_vec[k]).ToString() + " loss)";

                                        m_log.WriteLine(" Train net output #" + score_index.ToString() + ": " + output_name + " = " + result_vec[k].ToString() + strOut);
                                        score_index++;
                                    }
                                    else
                                    {
                                        dfTotalLossWeight += loss_weight * result_vec[k];
                                    }
                                }

                                // NOTE(review): the condition guarding this averaging
                                // block appears to have been stripped from this view
                                // of the file — confirm against the original source.
                                {
                                    double dfAverage = dfTotalLossWeight / nResultCount;
                                    m_log.WriteLine(" Average weighted score = " + dfAverage.ToString() + " for '" + output_name + "' - averaged over " + nResultCount.ToString("N0") + " results.");
                                }
                            }
                        }
                    }

                    // NOTE(review): the OnGradientsReady invocation that originally
                    // formed the body of this 'if' appears to have been stripped from
                    // this view of the file — confirm against the original source.
                    if (OnGradientsReady != null && bFwdPassNanFree)

                    double dfLastLearningRate = 0;

                    if (step != TRAIN_STEP.FORWARD && bApplyUpdates)
                        dfLastLearningRate = ApplyUpdate(m_nIter);

                    if (m_evtCancel.WaitOne(0))
                        break;

                    if (!bDisableProgress)
                        m_log.Progress = (double)m_nIter / (double)stop_iter;

                    bool bSnapshotTaken = false;
                    bool bForceSnapshot = forceSnapshot;

                    // Snapshot when forced, on schedule, or when accuracy improved.
                    if ((step == TRAIN_STEP.NONE || bAllowSnapshot.GetValueOrDefault(false)) && (is_root_solver && bFwdPassNanFree &&
                        (bForceSnapshot ||
                        (m_param.snapshot > 0 && (m_nIter % m_param.snapshot) == 0) ||
                        (m_dfLastAccuracy > m_dfBestAccuracy))))
                    {
                        bSnapshotTaken = true;
                        Snapshot(bForceSnapshot, ((m_param.snapshot > 0 && (m_nIter % m_param.snapshot) == 0)) ? true : false);

                        if (m_dfLastAccuracy > m_dfBestAccuracy)
                            m_dfBestAccuracy = m_dfLastAccuracy;
                    }

                    //-------------------------------------
                    // Call the training iteration event
                    // on the root solver.
                    //-------------------------------------
                    fireOnTrainingIterationEvent(bFwdPassNanFree, dfLoss, dfLastLearningRate, ref nTimingCount, ref dfTotalTime);

                    //-------------------------------------
                    // If single stepping, stop the solver.
                    //-------------------------------------
                    if (step != TRAIN_STEP.NONE || m_bEnableSingleStep)
                    {
                        if (step == TRAIN_STEP.BOTH)
                        {
                            if (!bDisableOutput)
                                m_log.WriteLine("Single step (both) triggered - solving stopped after a single forward/backward pass.");
                        }
                        else if (step == TRAIN_STEP.FORWARD)
                        {
                            if (!bDisableOutput)
                                m_log.WriteLine("Single step (forward) triggered - solving stopped after a single forward pass.");
                        }
                        else if (step == TRAIN_STEP.BACKWARD)
                        {
                            if (!bDisableOutput)
                                m_log.WriteLine("Single step (backward) triggered - solving stopped after a single backward pass.");
                        }
                        else
                        {
                            // When single stepping, force the snapshot so as to allow
                            // debugging the net visually.
                            if (!bSnapshotTaken)
                                Snapshot(true, false);
                        }
                        break;
                    }

                    //-------------------------------------
                    // If a time-limit has been imposed
                    // and we have exceeded it, stop
                    // training.
                    //-------------------------------------
                    if (m_nTrainingTimeLimitInMinutes > 0 && swTimeout.Elapsed.TotalMinutes > m_nTrainingTimeLimitInMinutes)
                    {
                        m_log.WriteLine("A training time-limit of " + m_nTrainingTimeLimitInMinutes.ToString("N0") + " minutes has been exceeded - training will now stop.");
                        return true;
                    }

                    if (!bApplyUpdates)
                        break;
                }

                return true;
            }
            catch (Exception excpt)
            {
                err = excpt;
                throw excpt;
            }
            finally
            {
                // Fire OnAborted when the loop ended via an error or cancel request.
                if (err != null || m_evtCancel.WaitOne(0))
                {
                    if (OnAborted != null)
                        OnAborted(this, new EventArgs());
                }
            }
        }
1056
1063 public void Restore(byte[] rgWeights, byte[] rgState, string strSkipBlobTypes = null)
1064 {
1065 m_net.LoadWeights(rgWeights, m_persist, null, null, strSkipBlobTypes);
1066
1067 if (rgState != null)
1068 {
1069 m_log.WriteLine("Restoring previous solver state from restore state...");
1070 RestoreSolverState(rgState);
1071 }
1072 }
1073
1081 public void Snapshot(bool bForced, bool bScheduled, bool bUpdateDatabase = true)
1082 {
1083 m_log.WriteLine("Starting snap shot...");
1084 m_log.CHECK(is_root_solver, "Snapshot only supported on the root solver.");
1085
1086 if (OnSnapshot == null)
1087 return;
1088
1089 if (m_snapshotWeightUpdatemMethod == SNAPSHOT_WEIGHT_UPDATE_METHOD.DISABLED && !bForced)
1090 {
1091 m_log.WriteLine("WARNING: Snapshot UPDATE_METHOD = DISABLED.");
1092 return;
1093 }
1094
1095 SnapshotArgs args = GetSnapshotArgs(null, null, m_dfLastAccuracy, m_dfLastError, m_nIter, m_snapshotWeightUpdatemMethod);
1096 args.Forced = bForced;
1097 args.Scheduled = bScheduled;
1098 args.UpdateDatabase = bUpdateDatabase;
1099
1100 OnSnapshot(this, args);
1101 m_log.WriteLine("Snapshot completed.");
1102 }
1103
1104 private void args_OnGetWeights(object sender, GetBytesArgs e)
1105 {
1106 if (m_net != null)
1107 e.Data = m_net.SaveWeights(m_persist, m_param.snapshot_diff);
1108 }
1109
        /// <summary>
        /// Supplies the solver state to a snapshot that requests it.
        /// </summary>
        /// <param name="sender">Specifies the requesting SnapshotArgs.</param>
        /// <param name="e">Receives the serialized state bytes.</param>
        private void args_OnGetState(object sender, GetBytesArgs e)
        {
            // NOTE(review): the body of this handler (presumably assigning the
            // serialized solver state to e.Data) appears to have been stripped
            // from this view of the file — confirm against the original source.
        }
1114
        /// <summary>
        /// Builds a SnapshotArgs describing the snapshot to take, wiring up the
        /// deferred state/weight retrieval handlers.
        /// </summary>
        /// <param name="rgState">Specifies the pre-serialized state bytes (may be null; retrieved on demand via OnGetState).</param>
        /// <param name="rgWeights">Specifies the pre-serialized weight bytes (may be null; retrieved on demand via OnGetWeights).</param>
        /// <param name="dfAccuracy">Specifies the accuracy; a value of 0 is replaced with 0.0001 so comparisons remain meaningful.</param>
        /// <param name="dfError">Specifies the error.</param>
        /// <param name="nIteration">Specifies the iteration of the snapshot.</param>
        /// <param name="wtUpdt">Specifies the snapshot weight update method.</param>
        /// <returns>Returns the configured SnapshotArgs.</returns>
        public SnapshotArgs GetSnapshotArgs(byte[] rgState, byte[] rgWeights, double dfAccuracy, double dfError, int nIteration, SNAPSHOT_WEIGHT_UPDATE_METHOD wtUpdt)
        {
            if (dfAccuracy == 0)
                dfAccuracy = 0.0001;

            SnapshotArgs args = new SnapshotArgs(rgState, rgWeights, dfAccuracy, dfError, nIteration, wtUpdt);

            // NOTE(review): one or two property assignments that originally appeared
            // here seem to have been stripped from this view of the file — confirm
            // against the original source.
            args.SingleStep = m_bEnableSingleStep;
            args.OnGetState += args_OnGetState;
            args.OnGetWeights += args_OnGetWeights;

            return args;
        }
1140
1145 {
1146 get { return m_nTrainingIterationOverride; }
1147 set { m_nTrainingIterationOverride = value; }
1148 }
1149
1154 {
1155 get { return m_nTestingIterationOverride; }
1156 set { m_nTestingIterationOverride = value; }
1157 }
1158
1162 public AutoResetEvent CompletedEvent
1163 {
1164 get { return m_evtCompleted; }
1165 }
1166
1171 {
1172 get { return m_evtCancel; }
1173 }
1174
1178 public double smoothed_loss
1179 {
1180 get { return m_dfSmoothedLoss; }
1181 }
1182
1187 {
1188 get { return m_param; }
1189 }
1190
1195 {
1196 get { return m_net; }
1197 }
1198
1202 public List<Net<T>> test_nets
1203 {
1204 get { return m_rgTestNets; }
1205 }
1206
1210 public int iter
1211 {
1212 get { return m_nIter; }
1213 }
1214
1219 {
1220 get { return m_param.type; }
1221 }
1222
1226 protected bool forceSnapshot
1227 {
1228 get
1229 {
1230 if (m_evtForceSnapshot == null)
1231 return false;
1232
1233 return m_evtForceSnapshot.WaitOne(0);
1234 }
1235 }
1236
1240 public bool forceTest
1241 {
1242 get
1243 {
1244 if (m_evtForceTest == null)
1245 return false;
1246
1247 return m_evtForceTest.WaitOne(0);
1248 }
1249 }
1250
1254 public int solver_count
1255 {
1256 get { return m_nSolverCount; }
1257 }
1258
1262 public int solver_rank
1263 {
1264 get { return m_nSolverRank; }
1265 }
1266
1273 public bool is_root_solver
1274 {
1275 get { return (m_nSolverRank == 0) ? true : false; }
1276 }
1277
        /// <summary>
        /// Runs a full test cycle over all test nets (or the external OnTest handler
        /// when attached), averaging the per-net accuracies and optionally smoothing
        /// the result over the accuracy averaging window, then fires OnTestingIteration.
        /// </summary>
        /// <param name="nIterationOverride">Optionally overrides the number of testing iterations (default = -1, use the solver setting).</param>
        /// <returns>Returns the (optionally window-averaged) accuracy, or 0 when canceled.</returns>
        public double TestAll(int nIterationOverride = -1)
        {
            double dfTotalAccuracy = 0;
            double dfTotalTime = 0;
            int nTotalCount = 0;

            for (int test_net_id = 0; test_net_id < m_rgTestNets.Count; test_net_id++)
            {
                // Abort immediately when a cancel was requested.
                if (m_evtCancel.WaitOne(0))
                    return 0;

                // An external OnTest handler replaces the built-in test cycle.
                if (OnTest != null)
                {
                    TestArgs args = new TestArgs(nIterationOverride, test_net_id);
                    OnTest(this, args);
                    dfTotalAccuracy += args.Accuracy;
                }
                else
                    dfTotalAccuracy += testOne(nIterationOverride, test_net_id);

                dfTotalTime += m_dfAverageTestTime;
                nTotalCount++;
            }

            // With no dedicated test nets, run a single test pass (against the
            // shared/training net via testOne, or the external handler).
            if (m_rgTestNets.Count == 0)
            {
                if (OnTest != null)
                {
                    TestArgs args = new TestArgs(nIterationOverride, 0);
                    OnTest(this, args);
                    dfTotalAccuracy += args.Accuracy;
                }
                else
                    dfTotalAccuracy += testOne(nIterationOverride, 0);
            }

            // NOTE(review): when m_rgTestNets.Count == 0 the accuracy accumulated in
            // the block above is discarded and 0 is used — confirm this is intended.
            double dfAccuracy = (m_rgTestNets.Count > 0) ? dfTotalAccuracy / m_rgTestNets.Count : 0;

            // Smooth the accuracy over the configured averaging window, when enabled.
            if (m_rgAverageAccuracyWindow != null)
            {
                m_rgAverageAccuracyWindow.Add(dfAccuracy);
                m_rgAverageAccuracyWindow.RemoveAt(0);
                dfAccuracy = m_rgAverageAccuracyWindow.Average();
            }

            if (OnTestingIteration != null)
            {
                double dfTime = (nTotalCount > 0) ? dfTotalTime / nTotalCount : 0;
                OnTestingIteration(this, new TestingIterationArgs<T>(m_nIter, dfAccuracy, dfTime));
            }

            return dfAccuracy;
        }
1340
1341 private double testOne(int nIterationOverride = -1, int nTestNetId = 0)
1342 {
1343 switch (m_param.eval_type)
1344 {
1345 // Test SSD Detection
1346 case SolverParameter.EvaluationType.DETECTION:
1347 return TestDetection(nIterationOverride, nTestNetId);
1348
1349 // Perform regular classification Test.
1350 default:
1351 return TestClassification(nIterationOverride, nTestNetId);
1352 }
1353 }
1354
1361 public double TestDetection(int nIterationOverride = -1, int nTestNetId = 0)
1362 {
1363 Stopwatch sw = new Stopwatch();
1364 BBoxUtility<T> bboxUtil = new BBoxUtility<T>(m_cuda, m_log);
1365
1366 try
1367 {
1368 if (is_root_solver)
1369 m_log.WriteLine("Iteration " + m_nIter.ToString() + ", Testing net (#" + nTestNetId.ToString() + ")");
1370
1371 Net<T> test_net = m_net;
1372
1373 if (m_rgTestNets.Count > nTestNetId)
1374 {
1375 m_log.CHECK(m_rgTestNets[nTestNetId] != null, "The test net at " + nTestNetId.ToString() + " is null!");
1376 m_rgTestNets[nTestNetId].ShareTrainedLayersWith(m_net);
1377 test_net = m_rgTestNets[nTestNetId];
1378 }
1379
1380 Dictionary<int, Dictionary<int, List<Tuple<float, int>>>> rgAllTruePos = new Dictionary<int, Dictionary<int, List<Tuple<float, int>>>>();
1381 Dictionary<int, Dictionary<int, List<Tuple<float, int>>>> rgAllFalsePos = new Dictionary<int, Dictionary<int, List<Tuple<float, int>>>>();
1382 Dictionary<int, Dictionary<int, int>> rgAllNumPos = new Dictionary<int, Dictionary<int, int>>();
1383
1384 double dfLoss = 0;
1385
1386 if (nIterationOverride <= 0)
1387 nIterationOverride = TestingIterations;
1388
1389 int nIter = nIterationOverride;
1390 sw.Start();
1391
1392 for (int i = 0; i < nIter; i++)
1393 {
1394 // Check to see if stoppage of testing/training has been requested.
1395 if (m_evtCancel.WaitOne(0))
1396 break;
1397
1398 if (OnTestStart != null)
1399 OnTestStart(this, new EventArgs());
1400
1401 double iter_loss;
1402 BlobCollection<T> colResult = test_net.Forward(out iter_loss);
1403
1405 dfLoss += iter_loss;
1406
1407 for (int j = 0; j < colResult.Count; j++)
1408 {
1409 m_log.CHECK_EQ(colResult[j].width, 5, "The width must be = 5 for SSD.");
1410 double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());
1411 int num_det = colResult[j].height;
1412
1413 for (int k = 0; k < num_det; k++)
1414 {
1415 int item_id = (int)result_vec[k * 5];
1416 int nLabel = (int)result_vec[k * 5 + 1];
1417
1418 // Special row for storing number of positives for a label.
1419 if (item_id == -1)
1420 {
1421 if (!rgAllNumPos.ContainsKey(j))
1422 rgAllNumPos.Add(j, new Dictionary<int, int>());
1423
1424 if (!rgAllNumPos[j].ContainsKey(nLabel))
1425 rgAllNumPos[j].Add(nLabel, (int)result_vec[k * 5 + 2]);
1426 else
1427 rgAllNumPos[j][nLabel] += (int)result_vec[k * 5 + 2];
1428 }
1429 // Normal row storing detection status.
1430 else
1431 {
1432 float fScore = (float)result_vec[k * 5 + 2];
1433 int tp = (int)result_vec[k * 5 + 3];
1434 int fp = (int)result_vec[k * 5 + 4];
1435
1436 // Ignore such case, which happens when a detection bbox is matched to
1437 // a difficult gt bbox and we don't evaluate on difficult gt bbox.
1438 if (tp == 0 && fp == 0)
1439 continue;
1440
1441 if (!rgAllTruePos.ContainsKey(j))
1442 rgAllTruePos.Add(j, new Dictionary<int, List<Tuple<float, int>>>());
1443
1444 if (!rgAllTruePos[j].ContainsKey(nLabel))
1445 rgAllTruePos[j].Add(nLabel, new List<Tuple<float, int>>());
1446
1447 if (!rgAllFalsePos.ContainsKey(j))
1448 rgAllFalsePos.Add(j, new Dictionary<int, List<Tuple<float, int>>>());
1449
1450 if (!rgAllFalsePos[j].ContainsKey(nLabel))
1451 rgAllFalsePos[j].Add(nLabel, new List<Tuple<float, int>>());
1452
1453 rgAllTruePos[j][nLabel].Add(new Tuple<float, int>(fScore, tp));
1454 rgAllFalsePos[j][nLabel].Add(new Tuple<float, int>(fScore, fp));
1455 }
1456 }
1457 }
1458
1459 if (sw.Elapsed.TotalMilliseconds > 1000)
1460 {
1461 m_log.Progress = (double)i / (double)nIter;
1462 m_log.WriteLine("Testing at " + m_log.Progress.ToString("P") + " " + i.ToString() + " of " + nIter.ToString() + "...");
1463 sw.Restart();
1464 }
1465 }
1466
1467 if (m_evtCancel.WaitOne(0))
1468 {
1469 m_log.WriteLine("Test interrupted.");
1470 return 0;
1471 }
1472
1474 {
1475 dfLoss /= m_param.test_iter[nTestNetId];
1476 m_log.WriteLine("Test loss: " + dfLoss.ToString());
1477 }
1478
1479 float fTotalmAP = 0;
1480 for (int i = 0; i < rgAllTruePos.Count; i++)
1481 {
1482 if (!rgAllTruePos.ContainsKey(i))
1483 m_log.FAIL("Missing output_blob true_pos: " + i.ToString());
1484
1485 Dictionary<int, List<Tuple<float, int>>> rgTruePos = rgAllTruePos[i];
1486
1487 if (!rgAllFalsePos.ContainsKey(i))
1488 m_log.FAIL("Missing output_blob false_pos: " + i.ToString());
1489
1490 Dictionary<int, List<Tuple<float, int>>> rgFalsePos = rgAllFalsePos[i];
1491
1492 if (!rgAllNumPos.ContainsKey(i))
1493 m_log.FAIL("Missing output_blob num_pos: " + i.ToString());
1494
1495 Dictionary<int, int> rgNumPos = rgAllNumPos[i];
1496
1497 Dictionary<int, float> rgAPs = new Dictionary<int, float>();
1498 float fmAP = 0.0f;
1499
1500 // Sort true_pos and false_pos with descending scores.
1501 foreach (KeyValuePair<int, int> kv in rgNumPos)
1502 {
1503 int nLabel = kv.Key;
1504 int nLabelNumPos = kv.Value;
1505
1506 if (!rgTruePos.ContainsKey(nLabel))
1507 {
1508 m_log.WriteLine("WARNING: Missing true_pos for label: " + nLabel.ToString() + "!");
1509 continue;
1510 }
1511 List<Tuple<float, int>> rgLabelTruePos = rgTruePos[nLabel];
1512
1513 if (!rgFalsePos.ContainsKey(nLabel))
1514 {
1515 m_log.WriteLine("WARNING: Missing false_pos for label: " + nLabel.ToString() + "!");
1516 continue;
1517 }
1518 List<Tuple<float, int>> rgLabelFalsePos = rgFalsePos[nLabel];
1519
1520 List<float> rgPrec;
1521 List<float> rgRec;
1522 float fAp = bboxUtil.ComputeAP(rgLabelTruePos, nLabelNumPos, rgLabelFalsePos, m_param.ap_version, out rgPrec, out rgRec);
1523
1524 if (!rgAPs.ContainsKey(nLabel))
1525 rgAPs.Add(nLabel, fAp);
1526 else
1527 rgAPs[nLabel] = fAp;
1528
1529 fmAP += fAp;
1530
1532 m_log.WriteLine("class " + nLabel.ToString() + ": " + fAp.ToString());
1533 }
1534
1535 fmAP /= rgNumPos.Count;
1536
1537 int nOutputBlobIdx = test_net.output_blob_indices[i];
1538 string strOutputName = test_net.blob_names[nOutputBlobIdx];
1539
1540 m_log.WriteLine(" Test net output #" + i.ToString() + ": " + strOutputName + " = " + fmAP.ToString());
1541 fTotalmAP += fmAP;
1542 }
1543
1544 return fTotalmAP / rgAllTruePos.Count;
1545 }
1546 catch (Exception excpt)
1547 {
1548 throw excpt;
1549 }
1550 finally
1551 {
1552 bboxUtil.Dispose();
1553 }
1554 }
1555
        /// <summary>
        /// Run a classification test on a given test Net by running it through its testing iterations.
        /// </summary>
        /// <param name="nIterationOverride">Optionally, specifies the number of testing iterations to run. When less than or equal to 0 (the default), the TestingIterations setting is used.</param>
        /// <param name="nTestNetId">Optionally, specifies the index of the test Net to run (default = 0). When no test Net exists at this index, the training Net is used instead.</param>
        /// <returns>The final test score is returned, or 0 when the test is canceled or no scores were collected.</returns>
        public double TestClassification(int nIterationOverride = -1, int nTestNetId = 0)
        {
            // Only the root solver displays status, and only on the configured display interval.
            bool bDisplay = (is_root_solver && m_param.display > 0 && (m_nIter % m_param.display) == 0) ? true : false;

            if (bDisplay)
                m_log.WriteLine("Iteration " + m_nIter.ToString() + ", Testing net (#" + nTestNetId.ToString() + ")");

            // Default to the training Net; when the requested test Net exists, share the
            // trained layer weights with it and use it instead.
            Net<T> test_net = m_net;

            if (m_rgTestNets.Count > nTestNetId)
            {
                m_log.CHECK(m_rgTestNets[nTestNetId] != null, "The test net at " + nTestNetId.ToString() + " is null!");
                m_rgTestNets[nTestNetId].ShareTrainedLayersWith(m_net);
                test_net = m_rgTestNets[nTestNetId];
            }

            // test_score accumulates scores; test_score_output_id records which output blob
            // each score came from (or the constant 1 when the OnTestResults event supplied it).
            List<double> test_score = new List<double>();
            List<int> test_score_output_id = new List<int>();
            double dfLoss = 0;

            if (nIterationOverride <= 0)
                nIterationOverride = TestingIterations;

            int nIter = nIterationOverride;

            Stopwatch sw = new Stopwatch();
            sw.Start();

            double dfTotalTiming = 0;
            int nTestCount = 0;
            int nAccuracyIdx = 0;        // index of the output blob selected as the final accuracy.
            int nMinRank = int.MaxValue; // lowest ACCURACY blob rank seen so far (lower rank wins).
            bool bAccuracyValid = false; // set when the OnTestResults event supplies the accuracy.
            Stopwatch swTiming = new Stopwatch();

            for (int i = 0; i < nIter; i++)
            {
                // Check to see if stoppage of testing/training has been requested.
                if (m_evtCancel.WaitOne(0))
                    break;

                if (OnTestStart != null)
                    OnTestStart(this, new EventArgs());

                swTiming.Restart();

                double iter_loss;
                BlobCollection<T> colResult = test_net.Forward(out iter_loss);

                dfLoss += iter_loss;

                // Give the OnTestResults subscriber first chance at computing the accuracy.
                TestResultArgs<T> args = new TestResultArgs<T>(colResult);
                if (OnTestResults != null)
                {
                    OnTestResults(this, args);
                    if (args.AccuracyValid)
                    {
                        test_score.Add(args.Accuracy);
                        test_score_output_id.Add(1);
                        bAccuracyValid = true;
                    }
                }

                // Otherwise, accumulate the raw output blob values across iterations.
                if (!args.AccuracyValid)
                {
                    if (i == 0)
                    {
                        // First iteration: record each output value with its blob index, and
                        // locate the ACCURACY-typed blob with the lowest rank (read from its Tag).
                        for (int j = 0; j < colResult.Count; j++)
                        {
                            double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());

                            for (int k = 0; k < colResult[j].count(); k++)
                            {
                                test_score.Add(result_vec[k]);
                                test_score_output_id.Add(j);
                            }

                            if (colResult[j].type == BLOB_TYPE.ACCURACY)
                            {
                                int nRank = (int)getNumber(colResult[j].Tag, 0);
                                if (nRank < nMinRank)
                                {
                                    nMinRank = nRank;
                                    nAccuracyIdx = j;
                                }
                            }
                        }
                    }
                    else
                    {
                        // Subsequent iterations: sum values into the slots created on iteration 0.
                        // NOTE(review): assumes the output blob counts do not change between
                        // iterations - confirm for nets with dynamic output sizes.
                        int idx = 0;

                        for (int j = 0; j < colResult.Count; j++)
                        {
                            double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());

                            for (int k = 0; k < colResult[j].count(); k++)
                            {
                                test_score[idx] += result_vec[k];
                                idx++;
                            }
                        }
                    }
                }

                swTiming.Stop();
                dfTotalTiming += swTiming.Elapsed.TotalMilliseconds;
                nTestCount++;

                // Post progress at most every 2 seconds.
                if (sw.ElapsedMilliseconds > 2000)
                {
                    double dfPct = (double)i / (double)nIter;

                    if (bDisplay)
                    {
                        m_log.Progress = dfPct;
                        m_log.WriteLine("Testing '" + test_net.name + "' at " + dfPct.ToString("P"));
                    }

                    sw.Restart();
                }
            }

            m_dfAverageTestTime = (nTestCount > 0) ? dfTotalTiming / nTestCount : 0;

            if (m_evtCancel.WaitOne(0))
            {
                m_log.WriteLine("Test interrupted.");
                return 0;
            }

            // Report the average test loss.
            // NOTE(review): the divisor is the configured test_iter value, not the actual
            // number of iterations run (which may differ when nIterationOverride is used) - confirm.
            {
                dfLoss /= m_param.test_iter[nTestNetId];
                m_log.WriteLine("Test loss: " + dfLoss.ToString());
            }

            double dfFinalScore = 0;

            if (bAccuracyValid)
            {
                // Event-supplied accuracies: average them (each entry added 1 to the id list,
                // so the id sum equals the number of accuracy samples).
                dfFinalScore = test_score.Sum();
                int nTotal = test_score_output_id.Sum();
                dfFinalScore /= nTotal;
            }
            else
            {
                // Blob-accumulated scores: average each over the iteration count, optionally
                // report it, and keep the mean score of the selected accuracy output.
                for (int i = 0; i < test_score.Count; i++)
                {
                    int nIdxTestScore = test_score_output_id[i];
                    int output_blob_index = test_net.output_blob_indices[nIdxTestScore];
                    string output_name = test_net.blob_names[output_blob_index];
                    double loss_weight = test_net.blob_loss_weights[output_blob_index];
                    double dfMeanScore = test_score[i] / nIter;
                    string strOut = "";

                    if (bDisplay)
                    {
                        if (loss_weight != 0)
                            strOut += " (* " + loss_weight.ToString() + " = " + (loss_weight * dfMeanScore).ToString() + " loss)";

                        m_log.WriteLine(" Test net output #" + i.ToString() + ": " + output_name + " = " + dfMeanScore.ToString() + strOut);
                    }

                    if (i == nAccuracyIdx)
                        dfFinalScore = dfMeanScore;
                }
            }

            if (test_score.Count == 0)
                return 0;

            return dfFinalScore;
        }
1737
1738 private double getNumber(object value, double dfDefault)
1739 {
1740 if (value == null)
1741 return dfDefault;
1742
1743 if (value is sbyte)
1744 return (double)(sbyte)value;
1745
1746 if (value is byte)
1747 return (double)(byte)value;
1748
1749 if (value is short)
1750 return (double)(short)value;
1751
1752 if (value is ushort)
1753 return (double)(ushort)value;
1754
1755 if (value is int)
1756 return (double)(int)value;
1757
1758 if (value is uint)
1759 return (double)(uint)value;
1760
1761 if (value is long)
1762 return (double)(long)value;
1763
1764 if (value is ulong)
1765 return (double)(ulong)value;
1766
1767 if (value is float)
1768 return (double)(float)value;
1769
1770 if (value is double)
1771 return (double)value;
1772
1773 if (value is decimal)
1774 return (double)(decimal)value;
1775
1776 return dfDefault;
1777 }
1778
1785 public void UpdateSmoothedLoss(double dfLoss, int nStartIter, int nAverageLoss = 0)
1786 {
1787 if (nAverageLoss == 0)
1788 nAverageLoss = m_param.average_loss;
1789
1790 if (m_rgLosses.Count < nAverageLoss)
1791 {
1792 m_rgLosses.Add(dfLoss);
1793 int nCount = m_rgLosses.Count;
1794 m_dfSmoothedLoss = (m_dfSmoothedLoss * (nCount - 1) + dfLoss) / nCount;
1795 }
1796 else
1797 {
1798 int nIdx = (m_nIter - nStartIter) % nAverageLoss;
1799 m_dfSmoothedLoss += (dfLoss - m_rgLosses[nIdx]) / nAverageLoss;
1800 m_rgLosses[nIdx] = dfLoss;
1801 }
1802
1803 if (m_bWeightsUpdated)
1804 {
1805 m_dfSmoothedLoss = dfLoss;
1806 m_bWeightsUpdated = false;
1807 }
1808
1809 m_dfLastError = m_dfSmoothedLoss;
1810
1811 if (m_dfLastError < m_dfBestError)
1812 m_dfBestError = m_dfLastError;
1813 }
1814
        /// <summary>
        /// Make and apply the parameter update for the current iteration (implemented by the derived solver).
        /// </summary>
        /// <param name="nIterationOverride">Optionally, specifies an iteration override (default = -1).</param>
        /// <returns>The returned value is defined by the derived solver (typically the rate used for the update) - confirm in the derived class.</returns>
        public abstract double ApplyUpdate(int nIterationOverride = -1);

        /// <summary>
        /// Save the current solver state.
        /// </summary>
        /// <returns>The solver state is returned as an array of bytes.</returns>
        protected abstract byte[] SnapshotSolverState();

        /// <summary>
        /// Restore a previously saved solver state.
        /// </summary>
        /// <param name="rgState">Specifies the state bytes, as produced by SnapshotSolverState.</param>
        protected abstract void RestoreSolverState(byte[] rgState);
1830
1848 public static SGDSolver<T> Create(CudaDnn<T> cuda, Log log, ProjectEx p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
1849 {
1850 SolverParameter solverParam = null;
1851
1852 if (p.SolverDescription != null)
1853 {
1854 RawProto protoSolver = RawProto.Parse(p.SolverDescription);
1855 solverParam = SolverParameter.FromProto(protoSolver);
1856 }
1857 else
1858 {
1859 solverParam = new param.SolverParameter();
1860 }
1861
1862 if (solverParam.net_param == null)
1863 {
1864 RawProto protoModel = RawProto.Parse(p.ModelDescription);
1865 solverParam.net_param = NetParameter.FromProto(protoModel);
1866 solverParam.net_param.ProjectID = p.ID;
1867 }
1868
1869 return Create(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1870 }
1871
1889 public static SGDSolver<T> Create(CudaDnn<T> cuda, Log log, SolverParameter solverParam, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
1890 {
1891 SGDSolver<T> solver = null;
1892
1893 switch (solverParam.type)
1894 {
1895 case SolverParameter.SolverType.SGD:
1896 solver = new SGDSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1897 break;
1898
1899 case SolverParameter.SolverType.NESTEROV:
1900 solver = new NesterovSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1901 break;
1902
1903 case SolverParameter.SolverType.ADAGRAD:
1904 solver = new AdaGradSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1905 break;
1906
1907 case SolverParameter.SolverType.ADADELTA:
1908 solver = new AdaDeltaSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1909 break;
1910
1911 case SolverParameter.SolverType.ADAM:
1912 solver = new AdamSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1913 break;
1914
1915 case SolverParameter.SolverType.ADAMW:
1916 solver = new AdamWSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1917 break;
1918
1919 case SolverParameter.SolverType.RMSPROP:
1920 solver = new RmsPropSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1921 break;
1922
1923 default:
1924 throw new NotImplementedException("The solver " + solverParam.type.ToString() + " is not implemented yet!");
1925 }
1926
1927 return solver;
1928 }
1929 }
1930
1931#pragma warning disable 1591
1932
1933 public class OutputCollection
1934 {
1935 OutputDataCollection m_rgError = new OutputDataCollection();
1936 OutputDataCollection m_rgAccuracy = new OutputDataCollection();
1937
1938 public OutputCollection()
1939 {
1940 }
1941
1942 public OutputDataCollection Errors
1943 {
1944 get { return m_rgError; }
1945 }
1946
1947 public OutputDataCollection Accuracies
1948 {
1949 get { return m_rgAccuracy; }
1950 }
1951 }
1952
1953 public class OutputDataCollection : IEnumerable<OutputData>
1954 {
1955 List<OutputData> m_rgData = new List<OutputData>();
1956
1957 public OutputDataCollection()
1958 {
1959 }
1960
1961 public List<OutputData> Data
1962 {
1963 get { return m_rgData; }
1964 }
1965
1966 public int Count
1967 {
1968 get { return m_rgData.Count; }
1969 }
1970
1971 public OutputData this[int nIdx]
1972 {
1973 get { return m_rgData[nIdx]; }
1974 set { m_rgData[nIdx] = value; }
1975 }
1976
1977 public void Add(int nTotal, string strName, int nIdx, double dfVal)
1978 {
1979 OutputData data = Find(strName);
1980
1981 if (data == null)
1982 {
1983 data = new OutputData(strName, nIdx);
1984 m_rgData.Add(data);
1985 }
1986
1987 data.Add(nTotal, dfVal);
1988 }
1989
1990 public OutputData Find(string strName)
1991 {
1992 foreach (OutputData data in m_rgData)
1993 {
1994 if (data.Name == strName)
1995 return data;
1996 }
1997
1998 return null;
1999 }
2000
2001 public IEnumerator<OutputData> GetEnumerator()
2002 {
2003 return m_rgData.GetEnumerator();
2004 }
2005
2006 IEnumerator IEnumerable.GetEnumerator()
2007 {
2008 return m_rgData.GetEnumerator();
2009 }
2010 }
2011
2012 public class OutputData
2013 {
2014 string m_strName;
2015 double m_dfValue = 0;
2016 int m_nIdx;
2017
2018 public OutputData(string strName, int nIdx)
2019 {
2020 m_strName = strName;
2021 m_nIdx = nIdx;
2022 }
2023
2024 public int Index
2025 {
2026 get { return m_nIdx; }
2027 }
2028
2029 public string Name
2030 {
2031 get { return m_strName; }
2032 }
2033
2034 public double Value
2035 {
2036 get { return m_dfValue; }
2037 set { m_dfValue = value; }
2038 }
2039
2040 public void Add(int nTotal, double dfVal)
2041 {
2042 double dfRatio = 1.0 / (double)nTotal;
2043 m_dfValue = (m_dfValue * (1.0 - dfRatio)) + (dfRatio * dfVal);
2044 }
2045 }
2046
2047#pragma warning restore 1591
2048}
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
Definition: CancelEvent.cs:17
bool WaitOne(int nMs=int.MaxValue)
Waits for the signal state to occur.
Definition: CancelEvent.cs:290
The Log class provides general output in text form.
Definition: Log.cs:13
void CHECK(bool b, string str)
Test a flag for true.
Definition: Log.cs:227
bool IsEnabled
Returns whether or not the Log is enabled.
Definition: Log.cs:50
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
Definition: Log.cs:80
bool Enable
Enables/disables the Log. When disabled, the Log does not output any data.
Definition: Log.cs:42
void FAIL(string str)
Causes a failure which throws an exception with the descriptive text.
Definition: Log.cs:394
double Progress
Get/set the progress associated with the Log.
Definition: Log.cs:147
void CHECK_EQ(double df1, double df2, string str)
Test whether one number is equal to another.
Definition: Log.cs:239
void WriteError(Exception e)
Write an error as output.
Definition: Log.cs:130
void CHECK_GT(double df1, double df2, string str)
Test whether one number is greater than another.
Definition: Log.cs:299
void CHECK_LE(double df1, double df2, string str)
Test whether one number is less than or equal to another.
Definition: Log.cs:263
void CHECK_GE(double df1, double df2, string str)
Test whether one number is greater than or equal to another.
Definition: Log.cs:287
The ProjectEx class manages a project containing the solver description, model description,...
Definition: ProjectEx.cs:15
string? SolverDescription
Get/set the solver description script used by the Project.
Definition: ProjectEx.cs:710
int ID
Returns the ID of the Project in the database.
Definition: ProjectEx.cs:517
string? ModelDescription
Get/set the model description script used by the Project.
Definition: ProjectEx.cs:741
The RawProto class is used to parse and output Google prototxt file data.
Definition: RawProto.cs:17
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
Definition: RawProto.cs:306
The Utility class provides general utility functions.
Definition: Utility.cs:35
The BBox class processes the NormalizedBBox data used with SSD.
Definition: BBoxUtility.cs:22
void Dispose()
Clean up all resources.
Definition: BBoxUtility.cs:43
float ComputeAP(List< Tuple< float, int > > rgTp, int nNumPos, List< Tuple< float, int > > rgFp, ApVersion apVersion, out List< float > rgPrec, out List< float > rgRec)
Compute the average precision given true positive and false positive vectors.
Definition: BBoxUtility.cs:69
The BlobCollection contains a list of Blobs.
void Add(Blob< T > b)
Add a new Blob to the collection.
int Count
Returns the number of items in the collection.
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:25
virtual void Dispose(bool bDisposing)
Releases all resources used by the Blob (including both GPU and Host).
Definition: Blob.cs:368
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
Definition: CudaDnn.cs:851
The CustomForwardBackArgs provide the arguments to the OnCustomForwardBack event within the Solver St...
Definition: EventArgs.cs:609
double LocalLoss
Get/set the local loss of the pass.
Definition: EventArgs.cs:655
bool FwdPassNanFree
Get/set whether or a NAN was detected in the forward pass.
Definition: EventArgs.cs:646
The DebugInformation contains information used to help debug the Layers of a Net while it is training...
string DetectFirstNaN(out string strType)
Searches for the first NaN within any of the Layers.
string DetectLastNaN(out string strType)
Searches for the last NaN within any of the Layers.
The GetBytesArgs is passed along to the SnapshotArgs::OnGetWeights and SnapshotArgs::OnGetState event...
Definition: EventArgs.cs:392
byte[] Data
Get/set the data as an array of bytes.
Definition: EventArgs.cs:406
The GetIterationArgs is bubbled up to the solver when a layer needs to know the current training ...
Definition: EventArgs.cs:748
void SetIteration(Phase p, int nIteration)
The SetIteration method is used to set the iteration and the phase.
Definition: EventArgs.cs:764
The GradientsReadyArgs is sent to the Solver::OnGradientsReady event which fires at the end of each S...
Definition: EventArgs.cs:734
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
Definition: Net.cs:23
BlobCollection< T > Forward()
Run forward with the input Blob's already fed separately.
Definition: Net.cs:1406
List< string > blob_names
Returns the blob names.
Definition: Net.cs:1948
List< double > blob_loss_weights
Returns the collection of blob loss weights.
Definition: Net.cs:2030
string name
Returns the network name.
Definition: Net.cs:1932
List< int > output_blob_indices
Returns a list of the output Blob indexes.
Definition: Net.cs:2178
The SnapshotArgs is sent to the Solver::OnSnapshot event which fires each time the Solver::Snapshot m...
Definition: EventArgs.cs:416
bool Forced
Get/set whether or not the snapshot was forced or not.
Definition: EventArgs.cs:580
bool SingleStep
Get/set the Solver single step.
Definition: EventArgs.cs:571
bool IncludeWeights
Get/set whether or not to include the weights in the snapshot.
Definition: EventArgs.cs:553
bool Scheduled
Get/set whether or not the snapshot is a regular scheduled snapshot (e.g. not an improved accuracy or...
Definition: EventArgs.cs:589
bool IncludeState
Get/set whether or not to include the Solver state in the snapshot.
Definition: EventArgs.cs:562
EventHandler< GetBytesArgs > OnGetState
Specifies the OnGetState event which fires when the SnapshotArgs::UpdateState method is called.
Definition: EventArgs.cs:444
bool UpdateDatabase
Get/set whether or not to update the database (default = true).
Definition: EventArgs.cs:598
EventHandler< GetBytesArgs > OnGetWeights
Specifies the OnGetWeights event which fires when the SnapshotArgs::UpdateWeights method is called.
Definition: EventArgs.cs:437
The TestArgs are passed to the Solver::OnTest event.
Definition: EventArgs.cs:169
double Accuracy
Get/set the accuracy for the test run. When overriding the testing, the override should set the accur...
Definition: EventArgs.cs:205
The TestResultArgs are passed to the Solver::OnTestResults event.
Definition: EventArgs.cs:116
bool AccuracyValid
Get/set the accuracy valid flag. When not valid, the OnTestResults event is ignored.
Definition: EventArgs.cs:156
double Accuracy
Get/set the accuracy. The recipient of this event should set this value.
Definition: EventArgs.cs:143
Specifies the TestingIterationArgs sent to the Solver::OnTestingIteration, which is called at the end...
Definition: EventArgs.cs:216
The TrainingIterationArgs is sent to the Solver::OnTrainingIteration event that fires at the end of a...
Definition: EventArgs.cs:264
The WorkspaceArgs are passed to both the Layer::OnSetWorkspace and Layer::OnGetWorkspace events.
Definition: EventArgs.cs:17
long Data
Get/set the handle to workspace data in GPU memory.
Definition: EventArgs.cs:36
ulong Size
Get/set the size of the workspace memory (in bytes).
Definition: EventArgs.cs:45
The Database class manages the actual connection to the physical database using Entity Framework from...
Definition: Database.cs:23
Specifies the parameters use to create a Net
Definition: NetParameter.cs:16
static NetParameter FromProto(RawProto rp)
Parse a RawProto into a new instance of the parameter.
NetState state
The current 'state' of the network, including the phase, level and stage. Some layers may be included...
int ProjectID
Specifies the ID of the project that created this net param (if any).
Definition: NetParameter.cs:78
int solver_rank
Specifies the rank of the solver using this network.
int solver_count
Specifies the number of solvers used in a multi-gpu training session.
NetParameter Clone(bool bCloneLayers=true, int? nSolverCount=null, int? nSolverRank=null)
Creates a new copy of this instance of the parameter.
Specifies the NetState which includes the phase, level and stage for which a given Net is to run unde...
Definition: NetState.cs:17
Phase phase
Specifies the Phase of the NetState.
Definition: NetState.cs:61
void MergeFrom(NetState ns)
Merges another NetState with this instance.
Definition: NetState.cs:96
The SolverParameter is a parameter for the solver, specifying the train and test networks.
int max_iter
The maximum number of iterations.
List< int > test_iter
The number of iterations for each test.
NetParameter net_param
Inline train net param, possibly combined with one or more test nets.
bool debug_info
If true, print information about the state of the net that may help with debugging learning problems.
NetParameter train_net_param
Inline train net param, possibly combined with one or more test nets.
List< NetState > test_state
The states for the train/test nets. Must be unspecified or specified once per net.
SolverType
Defines the type of solver.
string lr_policy
The learning rate decay policy.
static SolverParameter FromProto(RawProto rp)
Parses a new SolverParameter from a RawProto.
ApVersion ap_version
Specifies the AP Version to use for average precision when using Single-Shot Detection (SSD) - (defau...
long random_seed
If non-negative, the seed with which the Solver will initialize the caffe random number generator – u...
int average_loss
Display the loss averaged over the last average_loss iterations.
int test_interval
The number of iterations between two testing phases.
bool output_average_results
Specifies to average loss results before they are output - this can be faster when there are a lot of...
int iter_size
Accumulate gradients over 'iter_size' x 'batch_size' instances.
string DebugString()
Returns a debug string for the SolverParameter.
EvaluationType
Defines the evaluation method used in the SSD algorithm.
bool snapshot_after_train
If false, don't save a snapshot after training finishes.
bool snapshot_include_weights
Specifies whether or not the snapshot includes the trained weights. The default = true.
bool test_compute_loss
Test the compute loss.
SolverParameter()
The SolverParameter constructor.
EvaluationType eval_type
Specifies the evaluation type to use when using Single-Shot Detection (SSD) - (default = NONE,...
bool test_initialization
If true, run an initial test pass before the first iteration, ensuring memory availability and printi...
List< NetParameter > test_net_param
Inline test net params.
int display
The number of iterations between displaying info. If display = 0, no info will be displayed.
bool snapshot_diff
Whether to snapshot diff in the results or not. Snapshotting diff will help debugging but the final p...
bool snapshot_include_state
Specifies whether or not the snapshot includes the solver state. The default = false....
bool show_per_class_result
Specifies whether or not to display results per class when using Single-Shot Detection (SSD) - (defau...
int accuracy_average_window
Specifies the window over which to average the accuracies (default = 0 which ignores averaging).
int snapshot
Specifies the snapshot interval.
SolverType type
Specifies the solver type.
NetState train_state
The states for the train/test nets. Must be unspecified or specified once per net.
Use AdaDelta Solver which has gradient based optimization like SGD.
Use AdaGrad Solver based optimization like SGD that tries to find rarely seen features.
Use Adam Solver which uses gradient based optimization like SGD that includes 'adaptive momentum esti...
Definition: AdamSolver.cs:22
Use AdamW Solver which uses gradient based optimization like Adam with a decoupled weight decay.
Definition: AdamWSolver.cs:23
Use Nesterov's accelerated gradient Solver, which is similar to SGD, but the error gradient is comput...
Use RmsProp Solver which uses gradient based optimization like SGD.
Stochastic Gradient Descent solver with momentum updates weights by a linear combination of the negat...
Definition: SGDSolver.cs:22
An interface for classes that perform optimization on Nets
Definition: Solver.cs:28
List< Net< T > > m_rgTestNets
Specifies the testing Nets.
Definition: Solver.cs:48
int TrainingIterations
Returns the current training iterations remaining.
Definition: Solver.cs:690
void InitTestNets()
Initializes the Net used by the Solver for testing.
Definition: Solver.cs:561
EventHandler< CustomForwardBackArgs< T > > OnCustomForwardBack
The OnCustomForwardBack allows for overriding the forward/backward operations within the solver.
Definition: Solver.cs:149
int m_nSolverCount
Specifies the Solver count in a multi-GPU training session.
Definition: Solver.cs:77
void Dispose()
Discards the resources (GPU and Host) used by this Solver.
Definition: Solver.cs:212
double m_dfSmoothedLoss
Specifies the smoothed loss protected for derived classes to use.
Definition: Solver.cs:70
SolverParameter m_param
Specifies the SolverParameter that defines how the Solver operates.
Definition: Solver.cs:40
EventHandler< TrainingIterationArgs< T > > OnTrainingIteration
The OnTrainingIteration event fires at the end of each training iteration.
Definition: Solver.cs:128
List< double > m_rgLosses
Specifies the Losses used to calculate the smoothed Loss.
Definition: Solver.cs:60
abstract byte[] SnapshotSolverState()
Save the current solver state.
double smoothed_loss
Returns the smoothed loss.
Definition: Solver.cs:1179
void Restore(byte[] rgWeights, byte[] rgState, string strSkipBlobTypes=null)
The restore method simply calls the RestoreSolverState method of the inherited class.
Definition: Solver.cs:1063
static SGDSolver< T > Create(CudaDnn< T > cuda, Log log, SolverParameter solverParam, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
Create a new Solver based on the project containing the SolverParameter.
Definition: Solver.cs:1889
int iter
Returns the current training iteration.
Definition: Solver.cs:1211
CudaDnn< T > m_cuda
Specifies the instance of CudaDnn used by the Solver that provides a connection to Cuda.
Definition: Solver.cs:32
void Snapshot(bool bForced, bool bScheduled, bool bUpdateDatabase=true)
The snapshot function implements the basic snapshotting utility that stores the learned net....
Definition: Solver.cs:1081
int MaximumIteration
Returns the maximum training iterations.
Definition: Solver.cs:682
EventHandler< SnapshotArgs > OnSnapshot
The OnSnapshot event fires when the Solver detects that a snapshot is needed.
Definition: Solver.cs:124
bool EnableBlobDebugging
When enabled, the OnTrainingIteration event is set extra debugging information describing the state o...
Definition: Solver.cs:353
SolverParameter.SolverType type
Returns the type of solver.
Definition: Solver.cs:1219
Net< T > net
Returns the main training Net.
Definition: Solver.cs:1195
bool ForceOnTrainingIterationEvent()
Force an OnTrainingIterationEvent to fire.
Definition: Solver.cs:230
object Tag
Returns a generic tag associated with the Solver.
Definition: Solver.cs:414
double TestDetection(int nIterationOverride=-1, int nTestNetId=0)
Run an SSD detection test on a given test Net by running it through its iterations.
Definition: Solver.cs:1361
bool? is_root_solver
Returns whether or not this is the root solver.
Definition: Solver.cs:1274
double LearningRateOverride
Get/set the learning rate override. When 0, this setting is ignored.
Definition: Solver.cs:221
bool EnableTesting
When enabled, the training cycle calls TestAll periodically based on the SolverParameter....
Definition: Solver.cs:344
int m_nIter
Specifies the current iteration.
Definition: Solver.cs:52
Net< T > TrainingNet
Returns the training Net used by the solver.
Definition: Solver.cs:437
double m_dfLearningRateOverride
Optionally, specifies a learning rate override (default = 0, which ignores this setting).
Definition: Solver.cs:89
EventHandler< TestingIterationArgs< T > > OnTestingIteration
The OnTestingIteration event fires at the end of each testing iteration.
Definition: Solver.cs:132
void InitTrainNet(Net< T > shareNet=null)
Initializes the Net used by the solver for training.
Definition: Solver.cs:478
abstract void RestoreSolverState(byte[] rgState)
Restore a solver state.
void UpdateSmoothedLoss(double dfLoss, int nStartIter, int nAverageLoss=0)
Update the averaged loss value.
Definition: Solver.cs:1785
void Init(SolverParameter p, Net< T > shareNet=null)
Initializes the Solver.
Definition: Solver.cs:446
bool EnableBreakOnFirstNaN
When enabled (requires EnableBlobDebugging = true), the Solver immediately stops training upon detecti...
Definition: Solver.cs:374
int solver_count
Returns the solver count in a multi-GPU session.
Definition: Solver.cs:1255
SolverParameter parameter
Returns the SolverParameter used.
Definition: Solver.cs:1187
bool forceSnapshot
Returns whether or not a snapshot has been forced.
Definition: Solver.cs:1227
SNAPSHOT_WEIGHT_UPDATE_METHOD SnapshotWeightUpdateMethod
Get/set the snapshot weight update method.
Definition: Solver.cs:295
EventHandler OnAborted
The OnAborted event fires after aborting a training cycle.
Definition: Solver.cs:116
List< Net< T > > test_nets
Returns the testing Nets.
Definition: Solver.cs:1203
static SGDSolver< T > Create(CudaDnn< T > cuda, Log log, ProjectEx p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
Create a new Solver based on the project containing the SolverParameter.
Definition: Solver.cs:1848
IXPersist< T > m_persist
Specifies the persistence object used to save weight and solver states.
Definition: Solver.cs:85
int TrainingTimeLimitInMinutes
Get/set the training time limit in minutes. When set to 0, no time limit is imposed on training.
Definition: Solver.cs:286
EventHandler< WorkspaceArgs > OnGetWorkspace
Specifies the OnGetWorkspace event that fires when the getWorkspace() function is called by a layer t...
Definition: Solver.cs:153
double TestClassification(int nIterationOverride=-1, int nTestNetId=0)
Run a test on a given test Net by running it through its iterations.
Definition: Solver.cs:1562
void Reset()
Reset the iterations of the net.
Definition: Solver.cs:468
bool Step(int nIters, TRAIN_STEP step=TRAIN_STEP.NONE, bool bZeroDiffs=true, bool bApplyUpdates=true, bool bDisableOutput=false, bool bDisableProgress=false, double? dfLossOverride=null, bool? bAllowSnapshot=null)
Steps a set of iterations through a training cycle.
Definition: Solver.cs:800
double TestAll(int nIterationOverride=-1)
Run a TestAll by running all test Nets.
Definition: Solver.cs:1287
string LabelQueryEpochs
Return the label query epochs for the active datasource.
Definition: Solver.cs:666
EventHandler< TestResultArgs< T > > OnTestResults
When specified, the OnTestResults event fires after each single test run. The recipient is responsibl...
Definition: Solver.cs:136
EventHandler< GradientsReadyArgs > OnGradientsReady
The OnGradientsReady event fires after the gradients of a Solver are ready for distribution to other ...
Definition: Solver.cs:120
EventHandler< WorkspaceArgs > OnSetWorkspace
Specifies the OnSetWorkspace event that fires when the setWorkspace() function is called by a layer t...
Definition: Solver.cs:157
int? TestingIterations
Returns the current testing iterations remaining.
Definition: Solver.cs:706
bool forceTest
Returns whether or not a test has been forced.
Definition: Solver.cs:1241
int TrainingIterationOverride
Get/set the training iteration override.
Definition: Solver.cs:1145
EventHandler OnTestStart
The OnTestStart event fires at the start of each testing iteration.
Definition: Solver.cs:144
Net< T > m_net
Specifies the training Net.
Definition: Solver.cs:44
bool WeightsUpdated
Get/set when the weights have been updated.
Definition: Solver.cs:405
int m_nCurrentStep
Specifies the current step.
Definition: Solver.cs:56
int solver_rank
Returns this Solver's rank in a multi-GPU session.
Definition: Solver.cs:1263
bool EnableDetailedNanDetection
When enabled (requires EnableBlobDebugging = true), the detailed Nan (and Infinity) detection is performed...
Definition: Solver.cs:387
Log m_log
Specifies the Log for output.
Definition: Solver.cs:36
SnapshotArgs GetSnapshotArgs(byte[] rgState, byte[] rgWeights, double dfAccuracy, double dfError, int nIteration, SNAPSHOT_WEIGHT_UPDATE_METHOD wtUpdt)
The GetSnapshotArgs method fills out a snapshot args structure.
Definition: Solver.cs:1125
virtual void dispose()
Override that allows discarding of resources (GPU and Host) used by this Solver.
Definition: Solver.cs:311
EventHandler< TestArgs > OnTest
When specified, the OnTest event fires during a TestAll and overrides the call to Test.
Definition: Solver.cs:140
int TestingIterationOverride
Get/set the testing iteration override.
Definition: Solver.cs:1154
virtual void Solve(int nIterationOverride=-1, byte[] rgWeights=null, byte[] rgState=null, TRAIN_STEP step=TRAIN_STEP.NONE)
The main entry of the solver function. In default, iter will be zero. Pass in a non-zero iter number ...
Definition: Solver.cs:726
EventHandler OnStart
The OnStart event fires at the start of each training iteration.
Definition: Solver.cs:112
string ActiveLabelCounts
Returns a string describing the labels detected in the training along with the % that each label has ...
Definition: Solver.cs:650
AutoResetEvent CompletedEvent
Returns an auto reset event that is set upon training completion.
Definition: Solver.cs:1163
abstract double ApplyUpdate(int nIterationOverride=-1)
Make and apply the update value for the current iteration.
CudaDnn< T > Cuda
Returns the CudaDnn instance used by the Solver.
Definition: Solver.cs:642
bool EnableSingleStep
When enabled (requires EnableBlobDebugging = true), the Solver only runs one training cycle.
Definition: Solver.cs:396
int m_nSolverRank
Specifies the Solver rank of this solver, where rank == 0 is the root Solver.
Definition: Solver.cs:81
string LabelQueryHitPercents
Return the label query hit percentages for the active datasource.
Definition: Solver.cs:658
Net< T > TestingNet
Returns the testing Net used by the solver.
Definition: Solver.cs:423
bool EnableLayerDebugging
Enable/disable layer debugging which causes each layer to check for NAN/INF on each forward/backward ...
Definition: Solver.cs:365
Solver(CudaDnn< T > cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
The Solver constructor.
Definition: Solver.cs:175
int CurrentIteration
Returns the current training iteration.
Definition: Solver.cs:674
The IXImageDatabaseBase interface defines the general interface to the in-memory image database.
Definition: Interfaces.cs:415
The IXPersist interface is used by the CaffeControl to load and save weights.
Definition: Interfaces.cs:175
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
Phase
Defines the Phase under which to run a Net.
Definition: Interfaces.cs:42
SNAPSHOT_WEIGHT_UPDATE_METHOD
Defines the snapshot weight update method.
Definition: Interfaces.cs:162
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
BLOB_TYPE
Defines the type of data held by a given Blob.
Definition: Interfaces.cs:62
TRAIN_STEP
Defines the training stepping method (if any).
Definition: Interfaces.cs:119
The MyCaffe.db.image namespace contains all image database related classes.
Definition: Database.cs:18
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12