MyCaffe  1.11.7.7
Deep learning software for Windows C# programmers.
Solver.cs
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Threading;
6using System.IO;
7using System.Diagnostics;
8using System.Collections;
9using MyCaffe.basecode;
10using MyCaffe.db.image;
11using MyCaffe.common;
12using MyCaffe.param;
13
17namespace MyCaffe.solvers
18{
27 public abstract class Solver<T> : IDisposable
28 {
32 protected CudaDnn<T> m_cuda;
36 protected Log m_log;
44 protected Net<T> m_net;
48 protected List<Net<T>> m_rgTestNets = new List<Net<T>>();
52 protected int m_nIter;
56 protected int m_nCurrentStep;
60 protected List<double> m_rgLosses = new List<double>();
61 AutoResetEvent m_evtCompleted = new AutoResetEvent(false);
62 bool m_bEnableTest = true;
63 bool m_bEnableBlobDebugging = false;
64 bool m_bEnableBreakOnNan = false;
65 bool m_bEnableDetailedNanDetection = false;
66 bool m_bEnableSingleStep = false;
67
68 double m_dfSmoothedLoss = 0;
69 CancelEvent m_evtCancel;
70 AutoResetEvent m_evtForceSnapshot;
71 AutoResetEvent m_evtForceTest;
75 protected int m_nSolverCount = 1;
79 protected int m_nSolverRank = 0;
87 protected double m_dfLearningRateOverride = 0;
88 double m_dfLastAccuracy = 0;
89 double m_dfLastError = double.MaxValue;
90 double m_dfBestAccuracy = 0;
91 double m_dfBestError = double.MaxValue;
92 IXImageDatabaseBase m_db = null;
93 int m_nTrainingIterationOverride = -1;
94 int m_nTestingIterationOverride = -1;
95 object m_tag = null;
96 bool m_bWeightsUpdated = false;
97 static object m_syncGetRi = new object();
98 Blob<T> m_blobBatchInputData = null;
99 double m_dfAverageTestTime = 0;
100 SNAPSHOT_WEIGHT_UPDATE_METHOD m_snapshotWeightUpdatemMethod = SNAPSHOT_WEIGHT_UPDATE_METHOD.FAVOR_ACCURACY;
101 int m_nTrainingTimeLimitInMinutes = 0;
102 long m_hWorkspaceData = 0; // shared among the layers and nets, only grows in size.
103 ulong m_lWorkspaceSize = 0;
104 bool m_bFirstNanError = true;
105 List<double> m_rgAverageAccuracyWindow = null;
106
110 public event EventHandler OnStart;
114 public event EventHandler OnAborted;
118 public event EventHandler<GradientsReadyArgs> OnGradientsReady;
122 public event EventHandler<SnapshotArgs> OnSnapshot;
126 public event EventHandler<TrainingIterationArgs<T>> OnTrainingIteration;
130 public event EventHandler<TestingIterationArgs<T>> OnTestingIteration;
134 public event EventHandler<TestResultArgs<T>> OnTestResults;
138 public event EventHandler<TestArgs> OnTest;
142 public event EventHandler OnTestStart;
147 public event EventHandler<CustomForwardBackArgs<T>> OnCustomForwardBack;
151 public event EventHandler<WorkspaceArgs> OnGetWorkspace;
155 public event EventHandler<WorkspaceArgs> OnSetWorkspace;
156
/// <summary>
/// The Solver constructor.
/// </summary>
/// <param name="cuda">Specifies the CudaDnn connection to Cuda used by the solver and its nets.</param>
/// <param name="log">Specifies the output Log.</param>
/// <param name="p">Specifies the SolverParameter used to initialize the solver.</param>
/// <param name="evtCancel">Specifies the CancelEvent used to cancel training/testing.</param>
/// <param name="evtForceSnapshot">Specifies the event used to force a snapshot (polled via forceSnapshot).</param>
/// <param name="evtForceTest">Specifies the event used to force a test cycle (polled via forceTest).</param>
/// <param name="imgDb">Specifies the image database interface stored in m_db and passed to the nets.</param>
/// <param name="persist">Specifies the persistence interface used when saving/loading weights and state.</param>
/// <param name="nSolverCount">Optionally specifies the number of solvers participating (default = 1).</param>
/// <param name="nSolverRank">Optionally specifies this solver's rank (default = 0, the root solver).</param>
/// <param name="shareNet">Optionally specifies a net to share weights with (passed through to Init).</param>
/// <param name="getws">Optionally specifies an external workspace 'get' handler wired to OnGetWorkspace.</param>
/// <param name="setws">Optionally specifies an external workspace 'set' handler wired to OnSetWorkspace.</param>
public Solver(CudaDnn<T> cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
{
    m_cuda = cuda;
    m_log = log;
    m_evtCancel = evtCancel;
    m_evtForceSnapshot = evtForceSnapshot;
    m_evtForceTest = evtForceTest;

    if (m_log.IsEnabled)
        // NOTE(review): the statement guarded by this 'if' is missing from this
        // listing (one source line was dropped); as the text stands the 'if'
        // guards the assignment below -- confirm against the complete source file.

    m_db = imgDb;
    m_persist = persist;
    m_nSolverCount = nSolverCount;
    m_nSolverRank = nSolverRank;

    // Wire up optional external workspace management handlers.
    if (getws != null)
        OnGetWorkspace += new EventHandler<WorkspaceArgs>(getws);

    if (setws != null)
        OnSetWorkspace += new EventHandler<WorkspaceArgs>(setws);

    // Pre-fill the accuracy averaging window with zeros so the window always
    // holds exactly 'accuracy_average_window' samples.
    if (p.accuracy_average_window > 0)
    {
        m_rgAverageAccuracyWindow = new List<double>();
        for (int i = 0; i < p.accuracy_average_window; i++)
        {
            m_rgAverageAccuracyWindow.Add(0);
        }
    }

    Init(p, shareNet);
}
206
/// <summary>
/// Releases all resources used by the solver by delegating to the protected
/// virtual <c>dispose</c> method (which derived solvers may extend).
/// </summary>
public void Dispose() => dispose();
214
219 {
220 get { return m_dfLearningRateOverride; }
221 set { m_dfLearningRateOverride = value; }
222 }
223
229 {
230 int nTimingCount = 0;
231 double dfTotalTime = 0;
232 return fireOnTrainingIterationEvent(false, 0, 0, ref nTimingCount, ref dfTotalTime);
233 }
234
/// <summary>
/// Fire the OnTrainingIteration event (root solver only), optionally collecting
/// blob debug information and NaN diagnostics first.
/// </summary>
/// <param name="bFwdPassNanFree">Specifies whether the forward pass completed free of NaN's; used to label which pass a NaN was found after.</param>
/// <param name="dfLoss">Specifies the raw loss of the iteration.</param>
/// <param name="dfLastLearningRate">Specifies the learning rate used on the iteration.</param>
/// <param name="nTimingCount">Specifies the accumulated timing sample count; reset to 0 after the event fires.</param>
/// <param name="dfTotalTime">Specifies the accumulated time total; averaged over nTimingCount for the event, then reset to 0.</param>
/// <returns>Returns false when a NaN was detected (training should stop); otherwise true.</returns>
private bool fireOnTrainingIterationEvent(bool bFwdPassNanFree, double dfLoss, double dfLastLearningRate, ref int nTimingCount, ref double dfTotalTime)
{
    if (is_root_solver && OnTrainingIteration != null)
    {
        string strFirstNanBlob = null;
        DebugInformation<T> dbgInfo = null;

        // When blob debugging is enabled, collect debug information for the event
        // and (optionally) scan it for NaN's.
        if (m_bEnableBlobDebugging)
        {
            dbgInfo = TrainingNet.GetDebugInformation(m_bEnableDetailedNanDetection);

            if (m_bEnableBreakOnNan && dbgInfo != null)
            {
                string strType;
                strFirstNanBlob = dbgInfo.DetectFirstNaN(out strType);

                if (strFirstNanBlob != null)
                {
                    string strPass = (!bFwdPassNanFree) ? "Forward" : "Backward";
                    m_log.WriteLine("First NaN detected in the '" + strType + "' of blob '" + strFirstNanBlob + "' after " + strPass + " pass.");

                    string strTypeLast;
                    string strLastNanBlob = dbgInfo.DetectLastNaN(out strTypeLast);

                    // Only report the last NaN when it differs from the first.
                    if (strLastNanBlob != strFirstNanBlob && strType != strTypeLast)
                        m_log.WriteLine("Last NaN detected in the '" + strTypeLast + "' of blob '" + strLastNanBlob + "' after " + strPass + " pass.");
                }
            }
        }

        // Average the accumulated timing over its sample count for reporting.
        double dfTime = (nTimingCount > 0) ? (dfTotalTime / nTimingCount) : 0;
        OnTrainingIteration(this, new TrainingIterationArgs<T>(m_nIter, m_dfLastAccuracy, dfLoss, m_dfSmoothedLoss, m_dfBestError, m_bWeightsUpdated, m_net.ActiveLabelCounts, m_net.LabelQueryHitPercents, m_net.LabelQueryEpochs, m_net.BoostQueryHitPercents, dfLastLearningRate, dfTime, dbgInfo));
        dfTotalTime = 0;
        nTimingCount = 0;

        // A detected NaN stops the training loop.
        if (strFirstNanBlob != null)
        {
            m_log.WriteLine("Training is now stopping at iteration " + m_nIter.ToString("N0") + " as the first NaN has been detected ('" + strFirstNanBlob + "').");
            return false;
        }
    }

    return true;
}
279
284 {
285 get { return m_nTrainingTimeLimitInMinutes; }
286 set { m_nTrainingTimeLimitInMinutes = value; }
287 }
288
293 {
294 get { return m_snapshotWeightUpdatemMethod; }
295 set { m_snapshotWeightUpdatemMethod = value; }
296 }
297
302 {
303 get { return m_db; }
304 }
305
/// <summary>
/// Releases all resources owned by the solver: the training net, all test nets,
/// the batch input blob and any GPU workspace memory held by the solver.
/// </summary>
protected virtual void dispose()
{
    m_net?.Dispose();
    m_net = null;

    foreach (Net<T> testNet in m_rgTestNets)
        testNet.Dispose();

    m_rgTestNets.Clear();

    m_blobBatchInputData?.Dispose();
    m_blobBatchInputData = null;

    // Release the shared workspace memory and zero the bookkeeping so a later
    // request re-allocates from scratch.
    if (m_hWorkspaceData != 0)
    {
        m_cuda.FreeMemory(m_hWorkspaceData);
        m_hWorkspaceData = 0;
        m_lWorkspaceSize = 0;
    }
}
337
/// <summary>
/// Get/set whether or not the solver runs its testing cycle (enabled by default).
/// </summary>
public bool EnableTesting
{
    get => m_bEnableTest;
    set => m_bEnableTest = value;
}
346
351 {
352 get { return m_bEnableBlobDebugging; }
353 set { m_bEnableBlobDebugging = value; }
354 }
355
363 {
364 get { return TrainingNet.EnableLayerDebugging; }
365 set { TrainingNet.EnableLayerDebugging = value; }
366 }
367
372 {
373 get { return m_bEnableBreakOnNan; }
374 set { m_bEnableBreakOnNan = value; }
375 }
376
385 {
386 get { return m_bEnableDetailedNanDetection; }
387 set { m_bEnableDetailedNanDetection = value; }
388 }
389
394 {
395 get { return m_bEnableSingleStep; }
396 set { m_bEnableSingleStep = value; }
397 }
398
/// <summary>
/// Get/set the flag noting that the weights have been updated.
/// </summary>
public bool WeightsUpdated
{
    get => m_bWeightsUpdated;
    set => m_bWeightsUpdated = value;
}
407
/// <summary>
/// Get/set an arbitrary user object associated with the solver.
/// </summary>
public object Tag
{
    get => m_tag;
    set => m_tag = value;
}
416
421 {
422 get
423 {
424 if (m_rgTestNets.Count == 0)
425 return null;
426
427 return m_rgTestNets[0];
428 }
429 }
430
435 {
436 get { return m_net; }
437 }
438
/// <summary>
/// Initialize the solver from its parameters: scaffolds the training and testing
/// nets and resets the iteration counters.
/// </summary>
/// <param name="p">Specifies the SolverParameter to initialize from (stored in m_param).</param>
/// <param name="shareNet">Optionally specifies a net whose weights are shared with the training net.</param>
public void Init(SolverParameter p, Net<T> shareNet = null)
{
    m_log.WriteLine("Initializing solver from parameters: " + p.DebugString());
    m_param = p;
    m_log.CHECK_GE(m_param.average_loss, 1, "Average loss should be non-negative and >= 1.0.");

    if (m_param.random_seed >= 0)
        // NOTE(review): the statement guarded by this 'if' (presumably seeding the
        // random number generator) is missing from this listing; as the text stands
        // the 'if' guards InitTrainNet below -- confirm against the complete source.

    // Scaffolding code.
    InitTrainNet(shareNet);
    InitTestNets();

    if (is_root_solver)
        m_log.WriteLine("Solver scaffolding done.");

    Reset();
}
462
/// <summary>
/// Reset the iteration and step counters back to zero.
/// </summary>
public void Reset()
{
    m_nCurrentStep = 0;
    m_nIter = 0;
}
471
/// <summary>
/// Initialize the training net from either train_net_param or net_param
/// (exactly one must be specified), applying the TRAIN-phase NetState.
/// </summary>
/// <param name="shareNet">Optionally specifies a net to share weights with.</param>
/// <exception cref="Exception">Thrown, wrapping the original message, when initialization fails.</exception>
protected void InitTrainNet(Net<T> shareNet = null)
{
    try
    {
        // Exactly one of net_param / train_net_param may define the training net.
        int num_train_nets = ((m_param.net_param != null) ? 1 : 0) + ((m_param.train_net_param != null) ? 1 : 0);
        string field_names = "net_param, train_net_param";
        m_log.CHECK_GE(num_train_nets, 1, "SolverParameter must specify a train net using one of these fields: " + field_names);
        m_log.CHECK_LE(num_train_nets, 1, "SolverParameter must not contain more than one of these fields specifying a train_net: " + field_names);
        NetParameter net_param = null;

        if (m_param.train_net_param != null)
        {
            m_log.WriteLine("Creating training net specified in train_net_param.");
            net_param = m_param.train_net_param.Clone(true);
        }

        if (m_param.net_param != null)
        {
            m_log.WriteLine("Creating training net specified in net_param.");
            net_param = m_param.net_param.Clone(true);
        }

        // Set the correct NetState. We start with the solver defaults (lowest
        // precedence); then, merge in any NetState specified by the net_param itself;
        // finally, merge in any NetState specified by the train-state (highest
        // precedence).
        NetState net_state = new NetState();
        net_state.phase = Phase.TRAIN;
        net_state.MergeFrom(net_param.state);
        net_state.MergeFrom(m_param.train_state);
        net_param.state = net_state;
        net_param.solver_count = m_nSolverCount;
        net_param.solver_rank = m_nSolverRank;
        // The solver's workspace handlers are shared with the net so all nets can
        // use one (growing) workspace allocation.
        m_net = new Net<T>(m_cuda, m_log, net_param, m_evtCancel, m_db, Phase.NONE, m_evtCompleted, shareNet, net_OnGetWorkspace, net_OnSetWorkspace);
        m_net.OnGetIteration += net_OnGetIteration;
    }
    catch (Exception excpt)
    {
        throw new Exception("Initializing Training Net: " + excpt.Message);
    }
}
517
/// <summary>
/// Handles workspace-size requests from the nets. When an external OnSetWorkspace
/// handler is attached the request is forwarded to it; otherwise the solver's
/// shared workspace is grown to the requested size.
/// </summary>
/// <param name="sender">Specifies the sender of the request.</param>
/// <param name="e">Specifies the workspace arguments holding the requested size.</param>
private void net_OnSetWorkspace(object sender, WorkspaceArgs e)
{
    // An attached external handler takes full ownership of workspace management.
    if (OnSetWorkspace != null)
    {
        OnSetWorkspace(sender, e);
        return;
    }

    // The shared workspace only grows; smaller requests are already satisfied.
    if (e.Size <= m_lWorkspaceSize)
        return;

    m_lWorkspaceSize = e.Size;
    // Ghost memory is disabled around the free/alloc and reset afterward --
    // see CudaDnn.DisableGhostMemory/ResetGhostMemory for the exact semantics.
    m_cuda.DisableGhostMemory();

    if (m_hWorkspaceData != 0)
        m_cuda.FreeMemory(m_hWorkspaceData);

    m_hWorkspaceData = m_cuda.AllocMemory((long)m_lWorkspaceSize);
    m_cuda.ResetGhostMemory();
}
538
/// <summary>
/// Handles workspace queries from the nets: forwards to an attached external
/// OnGetWorkspace handler when present, otherwise returns the solver's shared
/// workspace handle and size.
/// </summary>
/// <param name="sender">Specifies the sender of the query.</param>
/// <param name="e">Specifies the workspace arguments filled in with the handle and size.</param>
private void net_OnGetWorkspace(object sender, WorkspaceArgs e)
{
    EventHandler<WorkspaceArgs> external = OnGetWorkspace;

    if (external != null)
    {
        external(sender, e);
    }
    else
    {
        e.Data = m_hWorkspaceData;
        e.Size = m_lWorkspaceSize;
    }
}
550
/// <summary>
/// Supplies the current training iteration to a requesting net.
/// </summary>
/// <param name="sender">Specifies the sender of the request.</param>
/// <param name="e">Specifies the arguments that receive the TRAIN-phase iteration.</param>
private void net_OnGetIteration(object sender, GetIterationArgs e) => e.SetIteration(Phase.TRAIN, m_nIter);
555
/// <summary>
/// Initialize the testing nets from test_net_param entries and/or the generic
/// net_param, one test net per configured test_iter entry, each with the
/// TEST-phase NetState applied.
/// </summary>
/// <exception cref="Exception">Thrown, wrapping the original message, when initialization fails.</exception>
protected void InitTestNets()
{
    try
    {
        int num_generic_nets = ((m_param.net_param != null) ? 1 : 0);
        int num_test_net_params = m_param.test_net_param.Count;
        int num_test_nets = num_test_net_params;

        // With a generic net extra test_iter entries spawn extra instances of it,
        // so test_iter may exceed the explicit test net count; otherwise the
        // counts must match exactly.
        if (num_generic_nets > 0)
            m_log.CHECK_GE(m_param.test_iter.Count, num_test_nets, "test_iter must be specified fore each test network.");
        else
            m_log.CHECK_EQ(m_param.test_iter.Count, num_test_nets, "test_iter must be specified fore each test network.");

        // If we have a generic net (specified by net or net_param, rather than
        // test_net or test_net_param), we may have an unlimited number of actual
        // test networks -- the actual number is given by the number of remaining
        // test_iters after any test nets specified by test_net_param and/or test_net
        // are evaluated.
        int num_generic_net_instances = m_param.test_iter.Count - num_test_nets;
        int num_test_net_instances = num_test_nets + num_generic_net_instances;

        if (m_param.test_state.Count > 0)
            m_log.CHECK_EQ(m_param.test_state.Count, num_test_net_instances, "test_state must be unspecified or specified once per test net.");

        if (num_test_net_instances > 0)
            m_log.CHECK_GT(m_param.test_interval, 0, "The test interval must be greater than zero.");

        // Collect a (source label, cloned NetParameter) pair for every instance.
        List<string> sources = new List<string>();
        List<NetParameter> net_params = new List<NetParameter>();

        for (int i = 0; i < num_test_net_params; i++)
        {
            sources.Add("test_net_param");
            net_params.Add(m_param.test_net_param[i].Clone());
        }

        int remaining_test_nets = m_param.test_iter.Count - num_test_net_params;

        if (m_param.net_param != null)
        {
            for (int i = 0; i < remaining_test_nets; i++)
            {
                sources.Add("net_param");
                net_params.Add(m_param.net_param.Clone());
            }
        }

        m_rgTestNets = new List<Net<T>>();

        for (int i = 0; i < num_test_net_instances; i++)
        {
            // Set the correct NetState. We start with the solver defaults (lowest
            // precedence); then, merge in any NetState specified by the net_param
            // itself; finally, merge in any NetState specified by the test_state
            // (highest precedence).
            NetState net_state = new NetState();
            net_state.phase = Phase.TEST;
            net_state.MergeFrom(net_params[i].state);

            if (m_param.test_state.Count > 0)
                net_state.MergeFrom(m_param.test_state[i]);

            net_params[i].state = net_state;

            m_log.WriteLine("Creating test net (#" + i.ToString() + ") specified by " + sources[i], true);
            // Test nets share the training net (passed as TrainingNet) and the
            // solver's workspace handlers.
            Net<T> net = new Net<T>(m_cuda, m_log, net_params[i], m_evtCancel, m_db, Phase.NONE, null, TrainingNet, net_OnGetWorkspace, net_OnSetWorkspace);

            m_rgTestNets.Add(net);
            m_rgTestNets[i].set_debug_info(m_param.debug_info);
        }
    }
    catch (Exception excpt)
    {
        throw new Exception("Initializing Testing Nets: " + excpt.Message);
    }
}
635
640 {
641 get { return m_cuda; }
642 }
643
/// <summary>
/// Returns the active label counts as reported by the training net.
/// </summary>
public string ActiveLabelCounts => m_net.ActiveLabelCounts;
651
656 {
657 get { return m_net.LabelQueryHitPercents; }
658 }
659
/// <summary>
/// Returns the label query epochs as reported by the training net.
/// </summary>
public string LabelQueryEpochs => m_net.LabelQueryEpochs;
667
672 {
673 get { return m_nIter; }
674 }
675
680 {
681 get { return m_param.max_iter; }
682 }
683
688 {
689 get
690 {
691 int nIters = m_param.max_iter - m_nIter;
692
693 if (m_nTrainingIterationOverride > 0)
694 nIters = m_nTrainingIterationOverride;
695
696 return nIters;
697 }
698 }
699
704 {
705 get
706 {
707 int nIters = (m_param.test_iter.Count == 0) ? 0 : m_param.test_iter[0];
708
709 if (m_nTestingIterationOverride > 0)
710 nIters = m_nTestingIterationOverride;
711
712 return nIters;
713 }
714 }
715
/// <summary>
/// Run the main solving loop: optionally restores weights/state, trains via Step,
/// snapshots as configured, then runs a final display/test pass.
/// </summary>
/// <param name="nIterationOverride">Optionally specifies the number of iterations to run (values &lt;= 0 use TrainingIterations).</param>
/// <param name="rgWeights">Optionally specifies weight bytes to restore before solving.</param>
/// <param name="rgState">Optionally specifies solver-state bytes to restore before solving.</param>
/// <param name="step">Optionally specifies a single-step mode (NONE by default).</param>
public virtual void Solve(int nIterationOverride = -1, byte[] rgWeights = null, byte[] rgState = null, TRAIN_STEP step = TRAIN_STEP.NONE)
{
    m_log.CHECK(is_root_solver, "Solve is only supported by the root solver.");
    m_log.WriteLine("Solving " + m_net.name);
    // NOTE(review): "Learing" is a typo inside this runtime log string; left
    // unchanged here since this pass only adds documentation.
    m_log.WriteLine("Learing Rate Policy: " + m_param.lr_policy);

    if (rgWeights != null || rgState != null)
        Restore(rgWeights, rgState);

    // For a network that is trained by the solver, no bottom or top vecs
    // should be given, and we will just provide dummy vecs.
    int start_iter = m_nIter;

    if (nIterationOverride <= 0)
        nIterationOverride = TrainingIterations;

    // Step returns false when training stopped early (e.g. NaN detected).
    if (!Step(nIterationOverride, step))
        return;

    // If we haven't already, save a snapshot after optimization, unless
    // overriden by setting snapshot_after_train = false.
    if (step == TRAIN_STEP.NONE && (m_param.snapshot_after_train && (m_param.snapshot == 0 || (m_nIter % m_param.snapshot) != 0)))
        Snapshot(false, true);
    else if (m_net.learnable_parameters.SnapshotRequested(true))
        Snapshot(true, false);

    if (m_evtCancel.WaitOne(0))
    {
        m_log.WriteLine("Optimization stopped early.");
        return;
    }

    // After the optimization is done, run an additional train and test pass to
    // display the train and test loss/outputs if appropriate (based on the
    // display and test_interval settings, respectively). Unlike in the rest of
    // training, for the train net we only run a forward pass as we've already
    // updated the parameters 'max_iter' times -- this final pass is only done to
    // display the loss, which is computed in the forward pass.
    if (m_param.display > 0 && (m_nIter % m_param.display) == 0)
    {
        double dfLoss;
        m_net.Forward(out dfLoss);

        UpdateSmoothedLoss(dfLoss, start_iter);
        m_log.WriteLine("Iteration " + m_nIter + ", loss = " + m_dfSmoothedLoss.ToString());
    }

    // NOTE(review): the condition that guarded this final-test block (most likely
    // a test_interval check) is missing from this listing; as the text stands this
    // is a bare block that always runs -- confirm against the complete source file.
    {
        if (m_bEnableTest)
            TestAll();
    }

    m_log.WriteLine("Optimization done.");

    // Release the batch-input scratch blob used while solving.
    if (m_blobBatchInputData != null)
    {
        m_blobBatchInputData.Dispose();
        m_blobBatchInputData = null;
    }
}
785
/// <summary>
/// Run one or more training iterations: forward/backward passes (accumulated
/// over iter_size), smoothed-loss bookkeeping, periodic testing, display,
/// snapshots, and weight updates via ApplyUpdate.
/// </summary>
/// <param name="nIters">Specifies the number of iterations to run.</param>
/// <param name="step">Optionally specifies a single-step mode (NONE by default).</param>
/// <param name="bZeroDiffs">Optionally specifies to zero the parameter diffs at the start of each iteration (default = true).</param>
/// <param name="bApplyUpdates">Optionally specifies to apply weight updates; when false the loop breaks after one iteration (default = true).</param>
/// <param name="bDisableOutput">Optionally specifies to disable display output (default = false).</param>
/// <param name="bDisableProgress">Optionally specifies to disable progress reporting on the log (default = false).</param>
/// <param name="dfLossOverride">Optionally overrides the computed loss value.</param>
/// <param name="bAllowSnapshot">Optionally allows snapshots even when stepping (default = null/false).</param>
/// <returns>Returns true on normal completion or time-limit stop; false when training stopped because a NaN was detected.</returns>
public bool Step(int nIters, TRAIN_STEP step = TRAIN_STEP.NONE, bool bZeroDiffs = true, bool bApplyUpdates = true, bool bDisableOutput = false, bool bDisableProgress = false, double? dfLossOverride = null, bool? bAllowSnapshot = null)
{
    Exception err = null;

    try
    {
        BlobCollection<T> colBottom = new BlobCollection<T>();
        int start_iter = m_nIter;
        int stop_iter = m_nIter + nIters;

        m_rgLosses.Clear();
        m_dfSmoothedLoss = 0;

        // Break on first NaN is a debugging tool
        // that causes the network to stop training
        // right after a NaN is discovered either
        // just after the forward pass or just
        // after the backward pass.
        m_net.EnableBreakOnFirstNaN = m_bEnableBreakOnNan && m_bEnableBlobDebugging;
        m_net.EnableDetailedNanDetection = m_bEnableDetailedNanDetection & m_bEnableBlobDebugging;

        Stopwatch sw = new Stopwatch();
        sw.Start();

        Stopwatch swTimeout = new Stopwatch();
        swTimeout.Start();

        while (m_nIter < stop_iter && !m_evtCompleted.WaitOne(0))
        {
            // zero-init the params.
            if (bZeroDiffs)
                m_net.ClearParamDiffs();

            if (OnStart != null)
                OnStart(this, new EventArgs());

            // Run the test cycle when forced or on the test interval.
            if (step == TRAIN_STEP.NONE && (forceTest ||
                (m_param.test_interval > 0 &&
                (m_nIter % m_param.test_interval) == 0 &&
            // NOTE(review): the tail of this condition is missing from this listing
            // (one source line was dropped), leaving the parentheses unbalanced --
            // confirm against the complete source file.
            {
                if (m_bEnableTest && is_root_solver)
                    m_dfLastAccuracy = TestAll();

                // Break out of the while loop because a stop was requested while testing.
                if (m_evtCancel.WaitOne(0))
                    break;
            }

            // on_start currently not used, so no event added.
            bool bDisplay1 = (is_root_solver && m_param.display > 0 && (m_nIter % m_param.display) == 0 && !bDisableOutput) ? true : false;
            m_net.set_debug_info(bDisplay1 && m_param.debug_info);

            // accumulate the loss and gradient
            double dfLoss = 0;
            double dfLossTotal = 0;
            int nIterCount = 0;

            Stopwatch swTiming = new Stopwatch();
            double dfTotalTime = 0;
            int nTimingCount = 0;
            bool bFwdPassNanFree = true;

            // Accumulate gradients over iter_size forward/backward passes.
            for (int i = 0; i < m_param.iter_size; i++)
            {
                double dfLocalLoss;

                swTiming.Restart();

                if (OnCustomForwardBack != null)
                {
                    // NOTE(review): the construction of 'args' (a CustomForwardBackArgs
                    // instance) is missing from this listing (one source line was
                    // dropped) -- confirm against the complete source file.
                    OnCustomForwardBack(this, args);
                    bFwdPassNanFree = args.FwdPassNanFree;
                    dfLocalLoss = args.LocalLoss;
                }
                else
                {
                    bFwdPassNanFree = m_net.ForwardBackward(colBottom, out dfLocalLoss, step);
                }

                // Report (once) when the local loss is invalid, but keep going.
                if (double.IsNaN(dfLocalLoss) || double.IsInfinity(dfLocalLoss))
                {
                    if (m_bFirstNanError)
                    {
                        m_log.WriteError(new Exception("The local loss at iteration " + m_nIter.ToString() + " is invalid (NAN or INFINITY)!"));
                        m_bFirstNanError = false;
                    }
                }

                dfLossTotal += dfLocalLoss;
                swTiming.Stop();

                dfTotalTime += swTiming.Elapsed.TotalMilliseconds;
                nTimingCount++;
                nIterCount++;

                // Stop accumulating when a NaN was detected in the pass.
                if (!bFwdPassNanFree)
                    break;
            }

            dfLoss = dfLossTotal / nIterCount;
            dfLoss = dfLossOverride.GetValueOrDefault(dfLoss);

            // average the loss across iterations for smoothed reporting
            UpdateSmoothedLoss(dfLoss, start_iter);

            bool bDisplay = false;
            if (!bDisplay1 && sw.ElapsedMilliseconds > 2000 && !bDisableOutput)
            {
                bDisplay = true;
                m_bFirstNanError = true;
                sw.Restart();
            }

            // Display the per-output training scores.
            if (bDisplay && bDisplay1)
            {
                m_log.WriteLine("Iteration " + m_nIter.ToString() + ", loss = " + m_dfSmoothedLoss.ToString());

                BlobCollection<T> colResult = m_net.output_blobs;
                int score_index = 0;

                if (is_root_solver)
                {
                    for (int j = 0; j < colResult.Count; j++)
                    {
                        double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());
                        int nIdx = m_net.output_blob_indices[j];
                        string output_name = m_net.blob_names[nIdx];
                        double loss_weight = m_net.blob_loss_weights[nIdx];
                        double dfTotalLossWeight = 0;
                        int nResultCount = colResult[j].count();

                        for (int k = 0; k < nResultCount; k++)
                        {
                            // NOTE(review): the 'if' condition matching the 'else' below is
                            // missing from this listing (one source line was dropped),
                            // leaving this span syntactically incomplete -- confirm against
                            // the complete source file.
                            {
                                string strOut = "";

                                if (loss_weight != 0)
                                    strOut += " (* " + loss_weight.ToString() + " = " + (loss_weight * result_vec[k]).ToString() + " loss)";

                                m_log.WriteLine(" Train net output #" + score_index.ToString() + ": " + output_name + " = " + result_vec[k].ToString() + strOut);
                                score_index++;
                            }
                            else
                            {
                                dfTotalLossWeight += loss_weight * result_vec[k];
                            }
                        }

                        // NOTE(review): the 'if' condition that guarded this averaging
                        // block is missing from this listing -- confirm against the
                        // complete source file.
                        {
                            double dfAverage = dfTotalLossWeight / nResultCount;
                            m_log.WriteLine(" Average weighted score = " + dfAverage.ToString() + " for '" + output_name + "' - averaged over " + nResultCount.ToString("N0") + " results.");
                        }
                    }
                }
            }

            // NOTE(review): the statement guarded by this 'if' (presumably the
            // OnGradientsReady invocation) is missing from this listing -- as the
            // text stands the 'if' has no valid embedded statement; confirm against
            // the complete source file.
            if (OnGradientsReady != null && bFwdPassNanFree)

            double dfLastLearningRate = 0;

            if (step != TRAIN_STEP.FORWARD && bApplyUpdates)
                dfLastLearningRate = ApplyUpdate(m_nIter);

            if (m_evtCancel.WaitOne(0))
                break;

            if (!bDisableProgress)
                m_log.Progress = (double)m_nIter / (double)stop_iter;

            bool bSnapshotTaken = false;
            bool bForceSnapshot = forceSnapshot;

            // Snapshot when forced, on the snapshot interval, or when accuracy improved.
            if ((step == TRAIN_STEP.NONE || bAllowSnapshot.GetValueOrDefault(false)) && (is_root_solver && bFwdPassNanFree &&
                (bForceSnapshot ||
                (m_param.snapshot > 0 && (m_nIter % m_param.snapshot) == 0) ||
                (m_dfLastAccuracy > m_dfBestAccuracy))))
            {
                bSnapshotTaken = true;
                Snapshot(bForceSnapshot, ((m_param.snapshot > 0 && (m_nIter % m_param.snapshot) == 0)) ? true : false);

                if (m_dfLastAccuracy > m_dfBestAccuracy)
                    m_dfBestAccuracy = m_dfLastAccuracy;
            }

            //-------------------------------------
            //  Call the training iteration event
            //  on the root solver.
            //-------------------------------------
            fireOnTrainingIterationEvent(bFwdPassNanFree, dfLoss, dfLastLearningRate, ref nTimingCount, ref dfTotalTime);

            //-------------------------------------
            //  If single stepping, stop the solver.
            //-------------------------------------
            if (step != TRAIN_STEP.NONE || m_bEnableSingleStep)
            {
                if (step == TRAIN_STEP.BOTH)
                {
                    if (!bDisableOutput)
                        m_log.WriteLine("Single step (both) triggered - solving stopped after a single forward/backward pass.");
                }
                else if (step == TRAIN_STEP.FORWARD)
                {
                    if (!bDisableOutput)
                        m_log.WriteLine("Single step (forward) triggered - solving stopped after a single forward pass.");
                }
                else if (step == TRAIN_STEP.BACKWARD)
                {
                    if (!bDisableOutput)
                        m_log.WriteLine("Single step (backward) triggered - solving stopped after a single backward pass.");
                }
                else
                {
                    // When single stepping, force the snapshot so as to allow
                    // debugging the net visually.
                    if (!bSnapshotTaken)
                        Snapshot(true, false);
                }
                break;
            }

            //-------------------------------------
            //  If a time-limit has been imposed
            //  and we have exceeded it, stop
            //  training.
            //-------------------------------------
            if (m_nTrainingTimeLimitInMinutes > 0 && swTimeout.Elapsed.TotalMinutes > m_nTrainingTimeLimitInMinutes)
            {
                m_log.WriteLine("A training time-limit of " + m_nTrainingTimeLimitInMinutes.ToString("N0") + " minutes has been exceeded - training will now stop.");
                return true;
            }

            if (!bApplyUpdates)
                break;
        }

        return true;
    }
    catch (Exception excpt)
    {
        err = excpt;
        // NOTE(review): 'throw excpt;' resets the stack trace; a bare 'throw;'
        // would preserve it (left unchanged in this documentation-only pass).
        throw excpt;
    }
    finally
    {
        // Notify listeners when the loop ended due to an error or cancellation.
        if (err != null || m_evtCancel.WaitOne(0))
        {
            if (OnAborted != null)
                OnAborted(this, new EventArgs());
        }
    }
}
1054
/// <summary>
/// Load previously saved weights into the training net and optionally restore
/// the solver state.
/// </summary>
/// <param name="rgWeights">Specifies the weight bytes to load.</param>
/// <param name="rgState">Optionally specifies the solver state bytes to restore; skipped when null.</param>
/// <param name="strSkipBlobTypes">Optionally specifies blob types to skip when loading the weights.</param>
public void Restore(byte[] rgWeights, byte[] rgState, string strSkipBlobTypes = null)
{
    m_net.LoadWeights(rgWeights, m_persist, null, null, strSkipBlobTypes);

    if (rgState == null)
        return;

    m_log.WriteLine("Restoring previous solver state from restore state...");
    RestoreSolverState(rgState);
}
1071
/// <summary>
/// Take a snapshot of the current solver state by raising the OnSnapshot event
/// with freshly built SnapshotArgs (root solver only).
/// </summary>
/// <param name="bForced">Specifies that the snapshot is forced (bypasses the DISABLED weight-update method).</param>
/// <param name="bScheduled">Specifies that the snapshot is a scheduled one (e.g. on the snapshot interval).</param>
/// <param name="bUpdateDatabase">Optionally specifies to update the database (default = true).</param>
public void Snapshot(bool bForced, bool bScheduled, bool bUpdateDatabase = true)
{
    m_log.WriteLine("Starting snap shot...");
    m_log.CHECK(is_root_solver, "Snapshot only supported on the root solver.");

    // With no subscriber there is nothing to snapshot to.
    if (OnSnapshot == null)
        return;

    // Snapshots may be disabled via the weight-update method unless forced.
    if (m_snapshotWeightUpdatemMethod == SNAPSHOT_WEIGHT_UPDATE_METHOD.DISABLED && !bForced)
    {
        m_log.WriteLine("WARNING: Snapshot UPDATE_METHOD = DISABLED.");
        return;
    }

    SnapshotArgs args = GetSnapshotArgs(null, null, m_dfLastAccuracy, m_dfLastError, m_nIter, m_snapshotWeightUpdatemMethod);
    args.Forced = bForced;
    args.Scheduled = bScheduled;
    args.UpdateDatabase = bUpdateDatabase;

    OnSnapshot(this, args);
    m_log.WriteLine("Snapshot completed.");
}
1101
/// <summary>
/// Lazily supplies the current training-net weights to a SnapshotArgs request.
/// </summary>
/// <param name="sender">Specifies the sender of the request.</param>
/// <param name="e">Specifies the arguments that receive the saved weight bytes.</param>
private void args_OnGetWeights(object sender, GetBytesArgs e)
{
    // No weights can be produced until the training net exists.
    if (m_net == null)
        return;

    e.Data = m_net.SaveWeights(m_persist, m_param.snapshot_diff);
}
1107
/// <summary>
/// Supplies the current solver state to a SnapshotArgs request.
/// </summary>
/// <param name="sender">Specifies the sender of the request.</param>
/// <param name="e">Specifies the arguments intended to receive the state bytes.</param>
private void args_OnGetState(object sender, GetBytesArgs e)
{
    // NOTE(review): the body of this handler is missing from this listing (one
    // source line was dropped); it presumably captures the solver state into
    // e.Data -- confirm against the complete source file.
}
1112
/// <summary>
/// Build a SnapshotArgs instance describing a snapshot, wiring the lazy
/// state/weight providers so the bytes are only produced when requested.
/// </summary>
/// <param name="rgState">Specifies the solver state bytes (may be filled lazily via OnGetState).</param>
/// <param name="rgWeights">Specifies the weight bytes (may be filled lazily via OnGetWeights).</param>
/// <param name="dfAccuracy">Specifies the accuracy; a value of 0 is clamped to 0.0001.</param>
/// <param name="dfError">Specifies the error.</param>
/// <param name="nIteration">Specifies the iteration of the snapshot.</param>
/// <param name="wtUpdt">Specifies the snapshot weight-update method.</param>
/// <returns>The populated SnapshotArgs is returned.</returns>
public SnapshotArgs GetSnapshotArgs(byte[] rgState, byte[] rgWeights, double dfAccuracy, double dfError, int nIteration, SNAPSHOT_WEIGHT_UPDATE_METHOD wtUpdt)
{
    // Avoid a zero accuracy so downstream comparisons have a non-zero value.
    if (dfAccuracy == 0)
        dfAccuracy = 0.0001;

    SnapshotArgs args = new SnapshotArgs(rgState, rgWeights, dfAccuracy, dfError, nIteration, wtUpdt);

    // NOTE(review): two source lines are missing from this listing here (likely
    // additional args property assignments) -- confirm against the complete source.
    args.SingleStep = m_bEnableSingleStep;
    args.OnGetState += args_OnGetState;
    args.OnGetWeights += args_OnGetWeights;

    return args;
}
1138
1143 {
1144 get { return m_nTrainingIterationOverride; }
1145 set { m_nTrainingIterationOverride = value; }
1146 }
1147
1152 {
1153 get { return m_nTestingIterationOverride; }
1154 set { m_nTestingIterationOverride = value; }
1155 }
1156
/// <summary>
/// Returns the auto-reset event used to signal completion (checked on each pass
/// of the training loop in Step).
/// </summary>
public AutoResetEvent CompletedEvent => m_evtCompleted;
1164
1169 {
1170 get { return m_evtCancel; }
1171 }
1172
/// <summary>
/// Returns the smoothed loss maintained by the solver.
/// </summary>
public double smoothed_loss => m_dfSmoothedLoss;
1180
1185 {
1186 get { return m_param; }
1187 }
1188
1193 {
1194 get { return m_net; }
1195 }
1196
/// <summary>
/// Returns the list of testing nets used by the solver.
/// </summary>
public List<Net<T>> test_nets => m_rgTestNets;
1204
/// <summary>
/// Returns the current training iteration.
/// </summary>
public int iter => m_nIter;
1212
1217 {
1218 get { return m_param.type; }
1219 }
1220
/// <summary>
/// Returns true when a snapshot has been requested via the force-snapshot event;
/// false when no event is attached. Reading this consumes the AutoResetEvent signal.
/// </summary>
protected bool forceSnapshot
{
    get { return m_evtForceSnapshot != null && m_evtForceSnapshot.WaitOne(0); }
}
1234
/// <summary>
/// Returns true when a test cycle has been requested via the force-test event;
/// false when no event is attached. Reading this consumes the AutoResetEvent signal.
/// </summary>
public bool forceTest
{
    get { return m_evtForceTest != null && m_evtForceTest.WaitOne(0); }
}
1248
/// <summary>
/// Returns the number of solvers participating in the session.
/// </summary>
public int solver_count => m_nSolverCount;
1256
/// <summary>
/// Returns this solver's rank among the participating solvers.
/// </summary>
public int solver_rank => m_nSolverRank;
1264
/// <summary>
/// Returns whether this is the root solver (rank 0). Solving, snapshots and the
/// training-iteration event are restricted to the root solver.
/// </summary>
public bool is_root_solver
{
    // Simplified from the redundant '(m_nSolverRank == 0) ? true : false'.
    get { return m_nSolverRank == 0; }
}
1275
/// <summary>
/// Run a test cycle over all test nets (or, when none exist, a single fallback
/// test pass via the OnTest handler or testOne) and return the resulting accuracy.
/// </summary>
/// <param name="nIterationOverride">Optionally specifies the number of testing iterations to run (-1 uses the solver settings).</param>
/// <returns>The accuracy (optionally averaged over the accuracy window) is returned, or 0 when canceled.</returns>
public double TestAll(int nIterationOverride = -1)
{
    double dfTotalAccuracy = 0;
    double dfTotalTime = 0;
    int nTotalCount = 0;

    for (int test_net_id = 0; test_net_id < m_rgTestNets.Count; test_net_id++)
    {
        // Abort immediately when a cancel was requested.
        if (m_evtCancel.WaitOne(0))
            return 0;

        // An attached OnTest handler overrides the built-in test pass.
        if (OnTest != null)
        {
            TestArgs args = new TestArgs(nIterationOverride, test_net_id);
            OnTest(this, args);
            dfTotalAccuracy += args.Accuracy;
        }
        else
            dfTotalAccuracy += testOne(nIterationOverride, test_net_id);

        dfTotalTime += m_dfAverageTestTime;
        nTotalCount++;
    }

    // With no test nets configured, still run one test pass.
    if (m_rgTestNets.Count == 0)
    {
        if (OnTest != null)
        {
            TestArgs args = new TestArgs(nIterationOverride, 0);
            OnTest(this, args);
            dfTotalAccuracy += args.Accuracy;
        }
        else
            dfTotalAccuracy += testOne(nIterationOverride, 0);
    }

    // BUGFIX: previously the no-test-net case discarded the accuracy computed by
    // the fallback pass above and always reported 0; dividing by at least one
    // test pass preserves it while keeping the multi-net average unchanged.
    double dfAccuracy = dfTotalAccuracy / Math.Max(1, m_rgTestNets.Count);

    // Smooth the reported accuracy over the configured averaging window.
    if (m_rgAverageAccuracyWindow != null)
    {
        m_rgAverageAccuracyWindow.Add(dfAccuracy);
        m_rgAverageAccuracyWindow.RemoveAt(0);
        dfAccuracy = m_rgAverageAccuracyWindow.Average();
    }

    if (OnTestingIteration != null)
    {
        double dfTime = (nTotalCount > 0) ? dfTotalTime / nTotalCount : 0;
        OnTestingIteration(this, new TestingIterationArgs<T>(m_nIter, dfAccuracy, dfTime));
    }

    return dfAccuracy;
}
1338
/// <summary>
/// Run a single test pass using the evaluation type configured in the solver parameter.
/// </summary>
/// <param name="nIterationOverride">Optionally specifies the number of testing iterations (-1 uses the solver settings).</param>
/// <param name="nTestNetId">Optionally specifies which test net to run (default = 0).</param>
/// <returns>The accuracy of the test pass is returned.</returns>
private double testOne(int nIterationOverride = -1, int nTestNetId = 0)
{
    // SSD-style detection evaluation; everything else falls back to classification.
    if (m_param.eval_type == SolverParameter.EvaluationType.DETECTION)
        return TestDetection(nIterationOverride, nTestNetId);

    return TestClassification(nIterationOverride, nTestNetId);
}
1352
1359 public double TestDetection(int nIterationOverride = -1, int nTestNetId = 0)
1360 {
1361 Stopwatch sw = new Stopwatch();
1362 BBoxUtility<T> bboxUtil = new BBoxUtility<T>(m_cuda, m_log);
1363
1364 try
1365 {
1366 if (is_root_solver)
1367 m_log.WriteLine("Iteration " + m_nIter.ToString() + ", Testing net (#" + nTestNetId.ToString() + ")");
1368
1369 Net<T> test_net = m_net;
1370
1371 if (m_rgTestNets.Count > nTestNetId)
1372 {
1373 m_log.CHECK(m_rgTestNets[nTestNetId] != null, "The test net at " + nTestNetId.ToString() + " is null!");
1374 m_rgTestNets[nTestNetId].ShareTrainedLayersWith(m_net);
1375 test_net = m_rgTestNets[nTestNetId];
1376 }
1377
1378 Dictionary<int, Dictionary<int, List<Tuple<float, int>>>> rgAllTruePos = new Dictionary<int, Dictionary<int, List<Tuple<float, int>>>>();
1379 Dictionary<int, Dictionary<int, List<Tuple<float, int>>>> rgAllFalsePos = new Dictionary<int, Dictionary<int, List<Tuple<float, int>>>>();
1380 Dictionary<int, Dictionary<int, int>> rgAllNumPos = new Dictionary<int, Dictionary<int, int>>();
1381
1382 double dfLoss = 0;
1383
1384 if (nIterationOverride <= 0)
1385 nIterationOverride = TestingIterations;
1386
1387 int nIter = nIterationOverride;
1388 sw.Start();
1389
1390 for (int i = 0; i < nIter; i++)
1391 {
1392 // Check to see if stoppage of testing/training has been requested.
1393 if (m_evtCancel.WaitOne(0))
1394 break;
1395
1396 if (OnTestStart != null)
1397 OnTestStart(this, new EventArgs());
1398
1399 double iter_loss;
1400 BlobCollection<T> colResult = test_net.Forward(out iter_loss);
1401
1403 dfLoss += iter_loss;
1404
1405 for (int j = 0; j < colResult.Count; j++)
1406 {
1407 m_log.CHECK_EQ(colResult[j].width, 5, "The width must be = 5 for SSD.");
1408 double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());
1409 int num_det = colResult[j].height;
1410
1411 for (int k = 0; k < num_det; k++)
1412 {
1413 int item_id = (int)result_vec[k * 5];
1414 int nLabel = (int)result_vec[k * 5 + 1];
1415
1416 // Special row for storing number of positives for a label.
1417 if (item_id == -1)
1418 {
1419 if (!rgAllNumPos.ContainsKey(j))
1420 rgAllNumPos.Add(j, new Dictionary<int, int>());
1421
1422 if (!rgAllNumPos[j].ContainsKey(nLabel))
1423 rgAllNumPos[j].Add(nLabel, (int)result_vec[k * 5 + 2]);
1424 else
1425 rgAllNumPos[j][nLabel] += (int)result_vec[k * 5 + 2];
1426 }
1427 // Normal row storing detection status.
1428 else
1429 {
1430 float fScore = (float)result_vec[k * 5 + 2];
1431 int tp = (int)result_vec[k * 5 + 3];
1432 int fp = (int)result_vec[k * 5 + 4];
1433
1434 // Ignore such case, which happens when a detection bbox is matched to
1435 // a difficult gt bbox and we don't evaluate on difficult gt bbox.
1436 if (tp == 0 && fp == 0)
1437 continue;
1438
1439 if (!rgAllTruePos.ContainsKey(j))
1440 rgAllTruePos.Add(j, new Dictionary<int, List<Tuple<float, int>>>());
1441
1442 if (!rgAllTruePos[j].ContainsKey(nLabel))
1443 rgAllTruePos[j].Add(nLabel, new List<Tuple<float, int>>());
1444
1445 if (!rgAllFalsePos.ContainsKey(j))
1446 rgAllFalsePos.Add(j, new Dictionary<int, List<Tuple<float, int>>>());
1447
1448 if (!rgAllFalsePos[j].ContainsKey(nLabel))
1449 rgAllFalsePos[j].Add(nLabel, new List<Tuple<float, int>>());
1450
1451 rgAllTruePos[j][nLabel].Add(new Tuple<float, int>(fScore, tp));
1452 rgAllFalsePos[j][nLabel].Add(new Tuple<float, int>(fScore, fp));
1453 }
1454 }
1455 }
1456
1457 if (sw.Elapsed.TotalMilliseconds > 1000)
1458 {
1459 m_log.Progress = (double)i / (double)nIter;
1460 m_log.WriteLine("Testing at " + m_log.Progress.ToString("P") + " " + i.ToString() + " of " + nIter.ToString() + "...");
1461 sw.Restart();
1462 }
1463 }
1464
1465 if (m_evtCancel.WaitOne(0))
1466 {
1467 m_log.WriteLine("Test interrupted.");
1468 return 0;
1469 }
1470
1472 {
1473 dfLoss /= m_param.test_iter[nTestNetId];
1474 m_log.WriteLine("Test loss: " + dfLoss.ToString());
1475 }
1476
1477 float fTotalmAP = 0;
1478 for (int i = 0; i < rgAllTruePos.Count; i++)
1479 {
1480 if (!rgAllTruePos.ContainsKey(i))
1481 m_log.FAIL("Missing output_blob true_pos: " + i.ToString());
1482
1483 Dictionary<int, List<Tuple<float, int>>> rgTruePos = rgAllTruePos[i];
1484
1485 if (!rgAllFalsePos.ContainsKey(i))
1486 m_log.FAIL("Missing output_blob false_pos: " + i.ToString());
1487
1488 Dictionary<int, List<Tuple<float, int>>> rgFalsePos = rgAllFalsePos[i];
1489
1490 if (!rgAllNumPos.ContainsKey(i))
1491 m_log.FAIL("Missing output_blob num_pos: " + i.ToString());
1492
1493 Dictionary<int, int> rgNumPos = rgAllNumPos[i];
1494
1495 Dictionary<int, float> rgAPs = new Dictionary<int, float>();
1496 float fmAP = 0.0f;
1497
1498 // Sort true_pos and false_pos with descending scores.
1499 foreach (KeyValuePair<int, int> kv in rgNumPos)
1500 {
1501 int nLabel = kv.Key;
1502 int nLabelNumPos = kv.Value;
1503
1504 if (!rgTruePos.ContainsKey(nLabel))
1505 {
1506 m_log.WriteLine("WARNING: Missing true_pos for label: " + nLabel.ToString() + "!");
1507 continue;
1508 }
1509 List<Tuple<float, int>> rgLabelTruePos = rgTruePos[nLabel];
1510
1511 if (!rgFalsePos.ContainsKey(nLabel))
1512 {
1513 m_log.WriteLine("WARNING: Missing false_pos for label: " + nLabel.ToString() + "!");
1514 continue;
1515 }
1516 List<Tuple<float, int>> rgLabelFalsePos = rgFalsePos[nLabel];
1517
1518 List<float> rgPrec;
1519 List<float> rgRec;
1520 float fAp = bboxUtil.ComputeAP(rgLabelTruePos, nLabelNumPos, rgLabelFalsePos, m_param.ap_version, out rgPrec, out rgRec);
1521
1522 if (!rgAPs.ContainsKey(nLabel))
1523 rgAPs.Add(nLabel, fAp);
1524 else
1525 rgAPs[nLabel] = fAp;
1526
1527 fmAP += fAp;
1528
1530 m_log.WriteLine("class " + nLabel.ToString() + ": " + fAp.ToString());
1531 }
1532
1533 fmAP /= rgNumPos.Count;
1534
1535 int nOutputBlobIdx = test_net.output_blob_indices[i];
1536 string strOutputName = test_net.blob_names[nOutputBlobIdx];
1537
1538 m_log.WriteLine(" Test net output #" + i.ToString() + ": " + strOutputName + " = " + fmAP.ToString());
1539 fTotalmAP += fmAP;
1540 }
1541
1542 return fTotalmAP / rgAllTruePos.Count;
1543 }
1544 catch (Exception excpt)
1545 {
1546 throw excpt;
1547 }
1548 finally
1549 {
1550 bboxUtil.Dispose();
1551 }
1552 }
1553
        /// <summary>
        /// Run a classification test on a given test Net by running it through its testing iterations.
        /// </summary>
        /// <param name="nIterationOverride">Optionally, specifies a testing iteration override (default = -1, which uses the TestingIterations setting).</param>
        /// <param name="nTestNetId">Optionally, specifies the index of the test Net to run (default = 0).</param>
        /// <returns>The final test score (accuracy) is returned.</returns>
        public double TestClassification(int nIterationOverride = -1, int nTestNetId = 0)
        {
            // Only the root solver reports status, and only on 'display' interval boundaries.
            bool bDisplay = (is_root_solver && m_param.display > 0 && (m_nIter % m_param.display) == 0) ? true : false;

            if (bDisplay)
                m_log.WriteLine("Iteration " + m_nIter.ToString() + ", Testing net (#" + nTestNetId.ToString() + ")");

            // Default to the training net, then swap in the requested test net
            // (sharing the trained layer weights) when one exists.
            Net<T> test_net = m_net;

            if (m_rgTestNets.Count > nTestNetId)
            {
                m_log.CHECK(m_rgTestNets[nTestNetId] != null, "The test net at " + nTestNetId.ToString() + " is null!");
                m_rgTestNets[nTestNetId].ShareTrainedLayersWith(m_net);
                test_net = m_rgTestNets[nTestNetId];
            }

            List<double> test_score = new List<double>();
            List<int> test_score_output_id = new List<int>();
            double dfLoss = 0;

            if (nIterationOverride <= 0)
                nIterationOverride = TestingIterations;

            int nIter = nIterationOverride;

            Stopwatch sw = new Stopwatch();
            sw.Start();

            double dfTotalTiming = 0;
            int nTestCount = 0;
            int nAccuracyIdx = 0;
            int nMinRank = int.MaxValue;
            bool bAccuracyValid = false;
            Stopwatch swTiming = new Stopwatch();

            for (int i = 0; i < nIter; i++)
            {
                // Check to see if stoppage of testing/training has been requested.
                if (m_evtCancel.WaitOne(0))
                    break;

                if (OnTestStart != null)
                    OnTestStart(this, new EventArgs());

                swTiming.Restart();

                double iter_loss;
                BlobCollection<T> colResult = test_net.Forward(out iter_loss);

                dfLoss += iter_loss;

                // Give the event subscriber first chance at scoring the results; when it
                // supplies a valid accuracy, that value is used directly.
                TestResultArgs<T> args = new TestResultArgs<T>(colResult);
                if (OnTestResults != null)
                {
                    OnTestResults(this, args);
                    if (args.AccuracyValid)
                    {
                        test_score.Add(args.Accuracy);
                        test_score_output_id.Add(1);
                        bAccuracyValid = true;
                    }
                }

                if (!args.AccuracyValid)
                {
                    // First iteration: record every output value and locate the ACCURACY
                    // blob with the lowest rank tag; its mean becomes the final score.
                    if (i == 0)
                    {
                        for (int j = 0; j < colResult.Count; j++)
                        {
                            double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());

                            for (int k = 0; k < colResult[j].count(); k++)
                            {
                                test_score.Add(result_vec[k]);
                                test_score_output_id.Add(j);
                            }

                            if (colResult[j].type == BLOB_TYPE.ACCURACY)
                            {
                                int nRank = (int)getNumber(colResult[j].Tag, 0);
                                if (nRank < nMinRank)
                                {
                                    nMinRank = nRank;
                                    nAccuracyIdx = j;
                                }
                            }
                        }
                    }
                    // Subsequent iterations: accumulate into the running totals.
                    else
                    {
                        int idx = 0;

                        for (int j = 0; j < colResult.Count; j++)
                        {
                            double[] result_vec = Utility.ConvertVec<T>(colResult[j].update_cpu_data());

                            for (int k = 0; k < colResult[j].count(); k++)
                            {
                                test_score[idx] += result_vec[k];
                                idx++;
                            }
                        }
                    }
                }

                swTiming.Stop();
                dfTotalTiming += swTiming.Elapsed.TotalMilliseconds;
                nTestCount++;

                // Emit progress at most every two seconds.
                if (sw.ElapsedMilliseconds > 2000)
                {
                    double dfPct = (double)i / (double)nIter;

                    if (bDisplay)
                    {
                        m_log.Progress = dfPct;
                        m_log.WriteLine("Testing '" + test_net.name + "' at " + dfPct.ToString("P"));
                    }

                    sw.Restart();
                }
            }

            m_dfAverageTestTime = (nTestCount > 0) ? dfTotalTiming / nTestCount : 0;

            if (m_evtCancel.WaitOne(0))
            {
                m_log.WriteLine("Test interrupted.");
                return 0;
            }

            // NOTE(review): this bare block appears to have lost its guard (likely
            // 'if (m_param.test_compute_loss)') during extraction - confirm against the
            // original source before assuming it runs unconditionally.
            {
                dfLoss /= m_param.test_iter[nTestNetId];
                m_log.WriteLine("Test loss: " + dfLoss.ToString());
            }

            double dfFinalScore = 0;

            if (bAccuracyValid)
            {
                // Event-supplied accuracies: average them (each entry added 1 to the id list).
                dfFinalScore = test_score.Sum();
                int nTotal = test_score_output_id.Sum();
                dfFinalScore /= nTotal;
            }
            else
            {
                for (int i = 0; i < test_score.Count; i++)
                {
                    int nIdxTestScore = test_score_output_id[i];
                    int output_blob_index = test_net.output_blob_indices[nIdxTestScore];
                    string output_name = test_net.blob_names[output_blob_index];
                    double loss_weight = test_net.blob_loss_weights[output_blob_index];
                    double dfMeanScore = test_score[i] / nIter;
                    string strOut = "";

                    if (bDisplay)
                    {
                        if (loss_weight != 0)
                            strOut += " (* " + loss_weight.ToString() + " = " + (loss_weight * dfMeanScore).ToString() + " loss)";

                        m_log.WriteLine(" Test net output #" + i.ToString() + ": " + output_name + " = " + dfMeanScore.ToString() + strOut);
                    }

                    // The accuracy blob selected on the first iteration provides the final score.
                    if (i == nAccuracyIdx)
                        dfFinalScore = dfMeanScore;
                }
            }

            if (test_score.Count == 0)
                return 0;

            return dfFinalScore;
        }
1735
1736 private double getNumber(object value, double dfDefault)
1737 {
1738 if (value == null)
1739 return dfDefault;
1740
1741 if (value is sbyte)
1742 return (double)(sbyte)value;
1743
1744 if (value is byte)
1745 return (double)(byte)value;
1746
1747 if (value is short)
1748 return (double)(short)value;
1749
1750 if (value is ushort)
1751 return (double)(ushort)value;
1752
1753 if (value is int)
1754 return (double)(int)value;
1755
1756 if (value is uint)
1757 return (double)(uint)value;
1758
1759 if (value is long)
1760 return (double)(long)value;
1761
1762 if (value is ulong)
1763 return (double)(ulong)value;
1764
1765 if (value is float)
1766 return (double)(float)value;
1767
1768 if (value is double)
1769 return (double)value;
1770
1771 if (value is decimal)
1772 return (double)(decimal)value;
1773
1774 return dfDefault;
1775 }
1776
1783 public void UpdateSmoothedLoss(double dfLoss, int nStartIter, int nAverageLoss = 0)
1784 {
1785 if (nAverageLoss == 0)
1786 nAverageLoss = m_param.average_loss;
1787
1788 if (m_rgLosses.Count < nAverageLoss)
1789 {
1790 m_rgLosses.Add(dfLoss);
1791 int nCount = m_rgLosses.Count;
1792 m_dfSmoothedLoss = (m_dfSmoothedLoss * (nCount - 1) + dfLoss) / nCount;
1793 }
1794 else
1795 {
1796 int nIdx = (m_nIter - nStartIter) % nAverageLoss;
1797 m_dfSmoothedLoss += (dfLoss - m_rgLosses[nIdx]) / nAverageLoss;
1798 m_rgLosses[nIdx] = dfLoss;
1799 }
1800
1801 if (m_bWeightsUpdated)
1802 {
1803 m_dfSmoothedLoss = dfLoss;
1804 m_bWeightsUpdated = false;
1805 }
1806
1807 m_dfLastError = m_dfSmoothedLoss;
1808
1809 if (m_dfLastError < m_dfBestError)
1810 m_dfBestError = m_dfLastError;
1811 }
1812
        /// <summary>
        /// Make and apply the update value for the current iteration.
        /// </summary>
        /// <param name="nIterationOverride">Optionally, specifies an iteration override (default = -1).</param>
        /// <returns>A double value from the derived solver's update (presumably the learning rate used - confirm with the specific solver implementation).</returns>
        public abstract double ApplyUpdate(int nIterationOverride = -1);

        /// <summary>
        /// Save the current solver state.
        /// </summary>
        /// <returns>The solver state is returned as an array of bytes.</returns>
        protected abstract byte[] SnapshotSolverState();

        /// <summary>
        /// Restore a previously saved solver state.
        /// </summary>
        /// <param name="rgState">Specifies the state, as bytes, to restore.</param>
        protected abstract void RestoreSolverState(byte[] rgState);
1828
1846 public static SGDSolver<T> Create(CudaDnn<T> cuda, Log log, ProjectEx p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
1847 {
1848 SolverParameter solverParam = null;
1849
1850 if (p.SolverDescription != null)
1851 {
1852 RawProto protoSolver = RawProto.Parse(p.SolverDescription);
1853 solverParam = SolverParameter.FromProto(protoSolver);
1854 }
1855 else
1856 {
1857 solverParam = new param.SolverParameter();
1858 }
1859
1860 if (solverParam.net_param == null)
1861 {
1862 RawProto protoModel = RawProto.Parse(p.ModelDescription);
1863 solverParam.net_param = NetParameter.FromProto(protoModel);
1864 solverParam.net_param.ProjectID = p.ID;
1865 }
1866
1867 return Create(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1868 }
1869
1887 public static SGDSolver<T> Create(CudaDnn<T> cuda, Log log, SolverParameter solverParam, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist<T> persist, int nSolverCount = 1, int nSolverRank = 0, Net<T> shareNet = null, onGetWorkspace getws = null, onSetWorkspace setws = null)
1888 {
1889 SGDSolver<T> solver = null;
1890
1891 switch (solverParam.type)
1892 {
1893 case SolverParameter.SolverType.SGD:
1894 solver = new SGDSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1895 break;
1896
1897 case SolverParameter.SolverType.NESTEROV:
1898 solver = new NesterovSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1899 break;
1900
1901 case SolverParameter.SolverType.ADAGRAD:
1902 solver = new AdaGradSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1903 break;
1904
1905 case SolverParameter.SolverType.ADADELTA:
1906 solver = new AdaDeltaSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1907 break;
1908
1909 case SolverParameter.SolverType.ADAM:
1910 solver = new AdamSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1911 break;
1912
1913 case SolverParameter.SolverType.RMSPROP:
1914 solver = new RmsPropSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1915 break;
1916
1917 default:
1918 throw new NotImplementedException("The solver " + solverParam.type.ToString() + " is not implemented yet!");
1919 }
1920
1921 return solver;
1922 }
1923 }
1924
1925#pragma warning disable 1591
1926
1927 public class OutputCollection
1928 {
1929 OutputDataCollection m_rgError = new OutputDataCollection();
1930 OutputDataCollection m_rgAccuracy = new OutputDataCollection();
1931
1932 public OutputCollection()
1933 {
1934 }
1935
1936 public OutputDataCollection Errors
1937 {
1938 get { return m_rgError; }
1939 }
1940
1941 public OutputDataCollection Accuracies
1942 {
1943 get { return m_rgAccuracy; }
1944 }
1945 }
1946
1947 public class OutputDataCollection : IEnumerable<OutputData>
1948 {
1949 List<OutputData> m_rgData = new List<OutputData>();
1950
1951 public OutputDataCollection()
1952 {
1953 }
1954
1955 public List<OutputData> Data
1956 {
1957 get { return m_rgData; }
1958 }
1959
1960 public int Count
1961 {
1962 get { return m_rgData.Count; }
1963 }
1964
1965 public OutputData this[int nIdx]
1966 {
1967 get { return m_rgData[nIdx]; }
1968 set { m_rgData[nIdx] = value; }
1969 }
1970
1971 public void Add(int nTotal, string strName, int nIdx, double dfVal)
1972 {
1973 OutputData data = Find(strName);
1974
1975 if (data == null)
1976 {
1977 data = new OutputData(strName, nIdx);
1978 m_rgData.Add(data);
1979 }
1980
1981 data.Add(nTotal, dfVal);
1982 }
1983
1984 public OutputData Find(string strName)
1985 {
1986 foreach (OutputData data in m_rgData)
1987 {
1988 if (data.Name == strName)
1989 return data;
1990 }
1991
1992 return null;
1993 }
1994
1995 public IEnumerator<OutputData> GetEnumerator()
1996 {
1997 return m_rgData.GetEnumerator();
1998 }
1999
2000 IEnumerator IEnumerable.GetEnumerator()
2001 {
2002 return m_rgData.GetEnumerator();
2003 }
2004 }
2005
2006 public class OutputData
2007 {
2008 string m_strName;
2009 double m_dfValue = 0;
2010 int m_nIdx;
2011
2012 public OutputData(string strName, int nIdx)
2013 {
2014 m_strName = strName;
2015 m_nIdx = nIdx;
2016 }
2017
2018 public int Index
2019 {
2020 get { return m_nIdx; }
2021 }
2022
2023 public string Name
2024 {
2025 get { return m_strName; }
2026 }
2027
2028 public double Value
2029 {
2030 get { return m_dfValue; }
2031 set { m_dfValue = value; }
2032 }
2033
2034 public void Add(int nTotal, double dfVal)
2035 {
2036 double dfRatio = 1.0 / (double)nTotal;
2037 m_dfValue = (m_dfValue * (1.0 - dfRatio)) + (dfRatio * dfVal);
2038 }
2039 }
2040
2041#pragma warning restore 1591
2042}
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
Definition: CancelEvent.cs:17
bool WaitOne(int nMs=int.MaxValue)
Waits for the signal state to occur.
Definition: CancelEvent.cs:290
The Log class provides general output in text form.
Definition: Log.cs:13
void CHECK(bool b, string str)
Test a flag for true.
Definition: Log.cs:227
bool IsEnabled
Returns whether or not the Log is enabled.
Definition: Log.cs:50
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
Definition: Log.cs:80
bool Enable
Enables/disables the Log. When disabled, the Log does not output any data.
Definition: Log.cs:42
void FAIL(string str)
Causes a failure which throws an exception with the descriptive text.
Definition: Log.cs:394
double Progress
Get/set the progress associated with the Log.
Definition: Log.cs:147
void CHECK_EQ(double df1, double df2, string str)
Test whether one number is equal to another.
Definition: Log.cs:239
void WriteError(Exception e)
Write an error as output.
Definition: Log.cs:130
void CHECK_GT(double df1, double df2, string str)
Test whether one number is greater than another.
Definition: Log.cs:299
void CHECK_LE(double df1, double df2, string str)
Test whether one number is less than or equal to another.
Definition: Log.cs:263
void CHECK_GE(double df1, double df2, string str)
Test whether one number is greater than or equal to another.
Definition: Log.cs:287
The ProjectEx class manages a project containing the solver description, model description,...
Definition: ProjectEx.cs:15
string? SolverDescription
Get/set the solver description script used by the Project.
Definition: ProjectEx.cs:710
int ID
Returns the ID of the Project in the database.
Definition: ProjectEx.cs:517
string? ModelDescription
Get/set the model description script used by the Project.
Definition: ProjectEx.cs:741
The RawProto class is used to parse and output Google prototxt file data.
Definition: RawProto.cs:17
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
Definition: RawProto.cs:306
The Utility class provides general utility functions.
Definition: Utility.cs:35
The BBox class processes the NormalizedBBox data used with SSD.
Definition: BBoxUtility.cs:22
void Dispose()
Clean up all resources.
Definition: BBoxUtility.cs:43
float ComputeAP(List< Tuple< float, int > > rgTp, int nNumPos, List< Tuple< float, int > > rgFp, ApVersion apVersion, out List< float > rgPrec, out List< float > rgRec)
Compute the average precision given true positive and false positive vectors.
Definition: BBoxUtility.cs:69
The BlobCollection contains a list of Blobs.
void Add(Blob< T > b)
Add a new Blob to the collection.
int Count
Returns the number of items in the collection.
The Blob is the main holder of data that moves through the Layers of the Net.
Definition: Blob.cs:24
virtual void Dispose(bool bDisposing)
Releases all resources used by the Blob (including both GPU and Host).
Definition: Blob.cs:340
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
Definition: CudaDnn.cs:829
The CustomForwardBackArgs provide the arguments to the OnCustomForwardBack event within the Solver St...
Definition: EventArgs.cs:609
double LocalLoss
Get/set the local loss of the pass.
Definition: EventArgs.cs:655
bool FwdPassNanFree
Get/set whether or a NAN was detected in the forward pass.
Definition: EventArgs.cs:646
The DebugInformation contains information used to help debug the Layers of a Net while it is training...
string DetectFirstNaN(out string strType)
Searches for the first NaN within any of the Layers.
string DetectLastNaN(out string strType)
Searches for the last NaN within any of the Layers.
The GetBytesArgs is passed along to the SnapshotArgs::OnGetWeights and SnapshotArgs::OnGetState event...
Definition: EventArgs.cs:392
byte[] Data
Get/set the data as an array of bytes.
Definition: EventArgs.cs:406
The GetIterationArgs is bubbled up to the solver when a layer needs to know the current training ...
Definition: EventArgs.cs:748
void SetIteration(Phase p, int nIteration)
The SetIteration method is used to set the iteration and the phase.
Definition: EventArgs.cs:764
The GradientsReadyArgs is sent to the Solver::OnGradientsReady event which fires at the end of each S...
Definition: EventArgs.cs:734
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
Definition: Net.cs:23
BlobCollection< T > Forward()
Run forward with the input Blob's already fed separately.
Definition: Net.cs:1406
List< string > blob_names
Returns the blob names.
Definition: Net.cs:1943
List< double > blob_loss_weights
Returns the collection of blob loss weights.
Definition: Net.cs:2025
string name
Returns the network name.
Definition: Net.cs:1927
List< int > output_blob_indices
Returns a list of the output Blob indexes.
Definition: Net.cs:2173
The SnapshotArgs is sent to the Solver::OnSnapshot event which fires each time the Solver::Snapshot m...
Definition: EventArgs.cs:416
bool Forced
Get/set whether or not the snapshot was forced or not.
Definition: EventArgs.cs:580
bool SingleStep
Get/set the Solver single step.
Definition: EventArgs.cs:571
bool IncludeWeights
Get/set whether or not to include the weights in the snapshot.
Definition: EventArgs.cs:553
bool Scheduled
Get/set whether or not the snapshot is a regular scheduled snapshot (e.g. not an improved accuracy or...
Definition: EventArgs.cs:589
bool IncludeState
Get/set whether or not to include the Solver state in the snapshot.
Definition: EventArgs.cs:562
EventHandler< GetBytesArgs > OnGetState
Specifies the OnGetState event which fires when the SnapshotArgs::UpdateState method is called.
Definition: EventArgs.cs:444
bool UpdateDatabase
Get/set whether or not to update the database (default = true).
Definition: EventArgs.cs:598
EventHandler< GetBytesArgs > OnGetWeights
Specifies the OnGetWeights event which fires when the SnapshotArgs::UpdateWeights method is called.
Definition: EventArgs.cs:437
The TestArgs are passed to the Solver::OnTest event.
Definition: EventArgs.cs:169
double Accuracy
Get/set the accuracy for the test run. When overriding the testing, the override should set the accur...
Definition: EventArgs.cs:205
The TestResultArgs are passed to the Solver::OnTestResults event.
Definition: EventArgs.cs:116
bool AccuracyValid
Get/set the accuracy valid flag. When not valid, the OnTestResults event is ignored.
Definition: EventArgs.cs:156
double Accuracy
Get/set the accuracy. The recipient of this event should set this value.
Definition: EventArgs.cs:143
Specifies the TestingIterationArgs sent to the Solver::OnTestingIteration, which is called at the end...
Definition: EventArgs.cs:216
The TrainingIterationArgs is sent to the Solver::OnTrainingIteration event that fires at the end of a...
Definition: EventArgs.cs:264
The WorkspaceArgs are passed to both the Layer::OnSetWorkspace and Layer::OnGetWorkspace events.
Definition: EventArgs.cs:17
long Data
Get/set the handle to workspace data in GPU memory.
Definition: EventArgs.cs:36
ulong Size
Get/set the size of the workspace memory (in bytes).
Definition: EventArgs.cs:45
The Database class manages the actual connection to the physical database using Entity Framework from...
Definition: Database.cs:23
Specifies the parameters use to create a Net
Definition: NetParameter.cs:16
static NetParameter FromProto(RawProto rp)
Parse a RawProto into a new instance of the parameter.
NetState state
The current 'state' of the network, including the phase, level and stage. Some layers may be included...
int ProjectID
Specifies the ID of the project that created this net param (if any).
Definition: NetParameter.cs:78
int solver_rank
Specifies the rank of the solver using this network.
int solver_count
Specifies the number of solvers used in a multi-gpu training session.
NetParameter Clone(bool bCloneLayers=true, int? nSolverCount=null, int? nSolverRank=null)
Creates a new copy of this instance of the parameter.
Specifies the NetState which includes the phase, level and stage for which a given Net is to run unde...
Definition: NetState.cs:17
Phase phase
Specifies the Phase of the NetState.
Definition: NetState.cs:61
void MergeFrom(NetState ns)
Merges another NetState with this instance.
Definition: NetState.cs:96
The SolverParameter is a parameter for the solver, specifying the train and test networks.
int max_iter
The maximum number of iterations.
List< int > test_iter
The number of iterations for each test.
NetParameter net_param
Inline train net param, possibly combined with one or more test nets.
bool debug_info
If true, print information about the state of the net that may help with debugging learning problems.
NetParameter train_net_param
Inline train net param, possibly combined with one or more test nets.
List< NetState > test_state
The states for the train/test nets. Must be unspecified or specified once per net.
SolverType
Defines the type of solver.
string lr_policy
The learning rate decay policy.
static SolverParameter FromProto(RawProto rp)
Parses a new SolverParameter from a RawProto.
ApVersion ap_version
Specifies the AP Version to use for average precision when using Single-Shot Detection (SSD) - (defau...
long random_seed
If non-negative, the seed with which the Solver will initialize the caffe random number generator – u...
int average_loss
Display the loss averaged over the last average_loss iterations.
int test_interval
The number of iterations between two testing phases.
bool output_average_results
Specifies to average loss results before they are output - this can be faster when there are a lot of...
int iter_size
Accumulate gradients over 'iter_size' x 'batch_size' instances.
string DebugString()
Returns a debug string for the SolverParameter.
EvaluationType
Defines the evaluation method used in the SSD algorithm.
bool snapshot_after_train
If false, don't save a snapshot after training finishes.
bool snapshot_include_weights
Specifies whether or not the snapshot includes the trained weights. The default = true.
bool test_compute_loss
Test the compute loss.
SolverParameter()
The SolverParameter constructor.
EvaluationType eval_type
Specifies the evaluation type to use when using Single-Shot Detection (SSD) - (default = NONE,...
bool test_initialization
If true, run an initial test pass before the first iteration, ensuring memory availability and printi...
List< NetParameter > test_net_param
Inline test net params.
int display
The number of iterations between displaying info. If display = 0, no info will be displayed.
bool snapshot_diff
Whether to snapshot diff in the results or not. Snapshotting diff will help debugging but the final p...
bool snapshot_include_state
Specifies whether or not the snapshot includes the solver state. The default = false....
bool show_per_class_result
Specifies whether or not to display results per class when using Single-Shot Detection (SSD) - (defau...
int accuracy_average_window
Specifies the window over which to average the accuracies (default = 0 which ignores averaging).
int snapshot
Specifies the snapshot interval.
SolverType type
Specifies the solver type.
NetState train_state
The states for the train/test nets. Must be unspecified or specified once per net.
Use AdaDelta Solver which has gradient based optimization like SGD.
Use AdaGrad Solver based optimization like SGD that tries to find rarely seen features.
Use Adam Solver which uses gradient based optimization like SGD that includes 'adaptive momentum esti...
Definition: AdamSolver.cs:22
Use Nesterov's accelerated gradient Solver, which is similar to SGD, but the error gradient is comput...
Use RmsProp Solver which uses gradient based optimization like SGD.
Stochastic Gradient Descent solver with momentum updates weights by a linear combination of the negat...
Definition: SGDSolver.cs:22
An interface for classes that perform optimization on Nets
Definition: Solver.cs:28
List< Net< T > > m_rgTestNets
Specifies the testing Nets.
Definition: Solver.cs:48
int TrainingIterations
Returns the current training iterations remaining.
Definition: Solver.cs:688
void InitTestNets()
Initializes the Net used by the Solver for testing.
Definition: Solver.cs:559
EventHandler< CustomForwardBackArgs< T > > OnCustomForwardBack
The OnCustomForwardBack allows for overriding the forward/backward operations within the solver.
Definition: Solver.cs:147
int m_nSolverCount
Specifies the Solver count in a multi-GPU training session.
Definition: Solver.cs:75
void Dispose()
Discards the resources (GPU and Host) used by this Solver.
Definition: Solver.cs:210
SolverParameter m_param
Specifies the SolverParameter that defines how the Solver operates.
Definition: Solver.cs:40
EventHandler< TrainingIterationArgs< T > > OnTrainingIteration
The OnTrainingIteration event fires at the end of each training iteration.
Definition: Solver.cs:126
List< double > m_rgLosses
Specifies the Losses used to calculate the smoothed Loss.
Definition: Solver.cs:60
abstract byte[] SnapshotSolverState()
Save the current solver state.
double smoothed_loss
Returns the smoothed loss.
Definition: Solver.cs:1177
void Restore(byte[] rgWeights, byte[] rgState, string strSkipBlobTypes=null)
The restore method simply calls the RestoreSolverState method of the inherited class.
Definition: Solver.cs:1061
static SGDSolver< T > Create(CudaDnn< T > cuda, Log log, SolverParameter solverParam, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
Create a new Solver based on the project containing the SolverParameter.
Definition: Solver.cs:1887
int iter
Returns the current training iteration.
Definition: Solver.cs:1209
CudaDnn< T > m_cuda
Specifies the instance of CudaDnn used by the Solver that provides a connection to Cuda.
Definition: Solver.cs:32
void Snapshot(bool bForced, bool bScheduled, bool bUpdateDatabase=true)
The snapshot function implements the basic snapshotting utility that stores the learned net....
Definition: Solver.cs:1079
int MaximumIteration
Returns the maximum training iterations.
Definition: Solver.cs:680
EventHandler< SnapshotArgs > OnSnapshot
The OnSnapshot event fires when the Solver detects that a snapshot is needed.
Definition: Solver.cs:122
bool EnableBlobDebugging
When enabled, the OnTrainingIteration event is set extra debugging information describing the state o...
Definition: Solver.cs:351
SolverParameter.SolverType type
Returns the type of solver.
Definition: Solver.cs:1217
Net< T > net
Returns the main training Net.
Definition: Solver.cs:1193
bool ForceOnTrainingIterationEvent()
Force an OnTrainingIterationEvent to fire.
Definition: Solver.cs:228
object Tag
Returns a generic tag associated with the Solver.
Definition: Solver.cs:412
double TestDetection(int nIterationOverride=-1, int nTestNetId=0)
Run an SSD detection test on a given test Net by running it through its iterations.
Definition: Solver.cs:1359
bool? is_root_solver
Returns whether or not this is the root solver.
Definition: Solver.cs:1272
double LearningRateOverride
Get/set the learning rate override. When 0, this setting is ignored.
Definition: Solver.cs:219
bool EnableTesting
When enabled, the training cycle calls TestAll periodically based on the SolverParameter....
Definition: Solver.cs:342
int m_nIter
Specifies the current iteration.
Definition: Solver.cs:52
Net< T > TrainingNet
Returns the training Net used by the solver.
Definition: Solver.cs:435
double m_dfLearningRateOverride
Optionally, specifies a learning rate override (default = 0, which ignores this setting).
Definition: Solver.cs:87
EventHandler< TestingIterationArgs< T > > OnTestingIteration
The OnTestingIteration event fires at the end of each testing iteration.
Definition: Solver.cs:130
void InitTrainNet(Net< T > shareNet=null)
Initializes the Net used by the solver for training.
Definition: Solver.cs:476
abstract void RestoreSolverState(byte[] rgState)
Restore a solver state.
void UpdateSmoothedLoss(double dfLoss, int nStartIter, int nAverageLoss=0)
Update the averaged loss value.
Definition: Solver.cs:1783
void Init(SolverParameter p, Net< T > shareNet=null)
Initializes the Solver.
Definition: Solver.cs:444
bool EnableBreakOnFirstNaN
When enabled (requires EnableBlobDebugging = true), the Solver immediately stop training upon detecti...
Definition: Solver.cs:372
int solver_count
Returns the solver count in a multi-GPU session.
Definition: Solver.cs:1253
SolverParameter parameter
Returns the SolverParameter used.
Definition: Solver.cs:1185
bool forceSnapshot
Returns whether or not a snapshot has been forced.
Definition: Solver.cs:1225
SNAPSHOT_WEIGHT_UPDATE_METHOD SnapshotWeightUpdateMethod
Get/set the snapshot weight update method.
Definition: Solver.cs:293
EventHandler OnAborted
The OnAborted event fires after aborting a training cycle.
Definition: Solver.cs:114
List< Net< T > > test_nets
Returns the testing Nets.
Definition: Solver.cs:1201
static SGDSolver< T > Create(CudaDnn< T > cuda, Log log, ProjectEx p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
Create a new Solver based on the project containing the SolverParameter.
Definition: Solver.cs:1846
IXPersist< T > m_persist
Specifies the persistence object used to save weight and solver states.
Definition: Solver.cs:83
int TrainingTimeLimitInMinutes
Get/set the training time limit in minutes. When set to 0, no time limit is imposed on training.
Definition: Solver.cs:284
EventHandler< WorkspaceArgs > OnGetWorkspace
Specifies the OnGetWorkspace event that fires when the getWorkspace() function is called by a layer t...
Definition: Solver.cs:151
double TestClassification(int nIterationOverride=-1, int nTestNetId=0)
Run a test on a given test Net by running it through its iterations.
Definition: Solver.cs:1560
void Reset()
Reset the iterations of the net.
Definition: Solver.cs:466
bool Step(int nIters, TRAIN_STEP step=TRAIN_STEP.NONE, bool bZeroDiffs=true, bool bApplyUpdates=true, bool bDisableOutput=false, bool bDisableProgress=false, double? dfLossOverride=null, bool? bAllowSnapshot=null)
Steps a set of iterations through a training cycle.
Definition: Solver.cs:798
double TestAll(int nIterationOverride=-1)
Run a TestAll by running all test Nets.
Definition: Solver.cs:1285
string LabelQueryEpochs
Return the label query epochs for the active datasource.
Definition: Solver.cs:664
EventHandler< TestResultArgs< T > > OnTestResults
When specified, the OnTestResults event fires after each single test run. The recipient is responsibl...
Definition: Solver.cs:134
EventHandler< GradientsReadyArgs > OnGradientsReady
The OnGradientsReady event fires after the gradients of a Solver are ready for distribution to other ...
Definition: Solver.cs:118
EventHandler< WorkspaceArgs > OnSetWorkspace
Specifies the OnSetWorkspace event that fires when the setWorkspace() function is called by a layer t...
Definition: Solver.cs:155
int? TestingIterations
Returns the current testing iterations remaining.
Definition: Solver.cs:704
bool forceTest
Returns whether or not a test has been forced.
Definition: Solver.cs:1239
int TrainingIterationOverride
Get/set the training iteration override.
Definition: Solver.cs:1143
EventHandler OnTestStart
The OnTestStart event fires at the start of each testing iteration.
Definition: Solver.cs:142
Net< T > m_net
Specifies the training Net.
Definition: Solver.cs:44
bool WeightsUpdated
Get/set when the weights have been updated.
Definition: Solver.cs:403
int m_nCurrentStep
Specifies the current step.
Definition: Solver.cs:56
int solver_rank
Returns this Solver's rank in a multi-GPU session.
Definition: Solver.cs:1261
bool EnableDetailedNanDetection
When enabled (requires EnableBlobDebugging = true), the detailed Nan (and Infinity) detection is pero...
Definition: Solver.cs:385
Log m_log
Specifies the Log for output.
Definition: Solver.cs:36
SnapshotArgs GetSnapshotArgs(byte[] rgState, byte[] rgWeights, double dfAccuracy, double dfError, int nIteration, SNAPSHOT_WEIGHT_UPDATE_METHOD wtUpdt)
The GetSnapshotArgs method fills out a snapshot args structure.
Definition: Solver.cs:1123
virtual void dispose()
Override that allows discarding of resources (GPU and Host) used by this Solver.
Definition: Solver.cs:309
EventHandler< TestArgs > OnTest
When specified, the OnTest event fires during a TestAll and overrides the call to Test.
Definition: Solver.cs:138
int TestingIterationOverride
Get/set the testing iteration override.
Definition: Solver.cs:1152
virtual void Solve(int nIterationOverride=-1, byte[] rgWeights=null, byte[] rgState=null, TRAIN_STEP step=TRAIN_STEP.NONE)
The main entry of the solver function. In default, iter will be zero. Pass in a non-zero iter number ...
Definition: Solver.cs:724
EventHandler OnStart
The OnStart event fires at the start of each training iteration.
Definition: Solver.cs:110
string ActiveLabelCounts
Returns a string describing the labels detected in the training along with the % that each label has ...
Definition: Solver.cs:648
AutoResetEvent CompletedEvent
Returns an auto reset event that is set upon training completion.
Definition: Solver.cs:1161
abstract double ApplyUpdate(int nIterationOverride=-1)
Make and apply the update value for the current iteration.
CudaDnn< T > Cuda
Returns the CudaDnn instance used by the Solver.
Definition: Solver.cs:640
bool EnableSingleStep
When enabled (requires EnableBlobDebugging = true), the Solver only runs one training cycle.
Definition: Solver.cs:394
int m_nSolverRank
Specifies the Solver rank of this solver, where rank == 0 is the root Solver.
Definition: Solver.cs:79
string LabelQueryHitPercents
Return the label query hit percentages for the active datasource.
Definition: Solver.cs:656
Net< T > TestingNet
Returns the testing Net used by the solver.
Definition: Solver.cs:421
bool EnableLayerDebugging
Enable/disable layer debugging which causes each layer to check for NAN/INF on each forward/backward ...
Definition: Solver.cs:363
Solver(CudaDnn< T > cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
The Solver constructor.
Definition: Solver.cs:173
int CurrentIteration
Returns the current training iteration.
Definition: Solver.cs:672
The IXImageDatabaseBase interface defines the general interface to the in-memory image database.
Definition: Interfaces.cs:415
The IXPersist interface is used by the CaffeControl to load and save weights.
Definition: Interfaces.cs:175
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Definition: Annotation.cs:12
Phase
Defines the Phase under which to run a Net.
Definition: Interfaces.cs:42
SNAPSHOT_WEIGHT_UPDATE_METHOD
Defines the snapshot weight update method.
Definition: Interfaces.cs:162
The MyCaffe.common namespace contains common MyCaffe classes.
Definition: BatchInput.cs:8
BLOB_TYPE
Defines the type of data held by a given Blob.
Definition: Interfaces.cs:62
TRAIN_STEP
Defines the training stepping method (if any).
Definition: Interfaces.cs:119
The MyCaffe.db.image namespace contains all image database related classes.
Definition: Database.cs:18
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...
Definition: Annotation.cs:12