2using System.Collections.Generic;
27 public abstract class Solver<T> : IDisposable
61 AutoResetEvent m_evtCompleted =
new AutoResetEvent(
false);
62 bool m_bEnableTest =
true;
63 bool m_bEnableBlobDebugging =
false;
64 bool m_bEnableBreakOnNan =
false;
65 bool m_bEnableDetailedNanDetection =
false;
66 bool m_bEnableSingleStep =
false;
72 AutoResetEvent m_evtForceSnapshot;
73 AutoResetEvent m_evtForceTest;
90 double m_dfLastAccuracy = 0;
91 double m_dfLastError =
double.MaxValue;
92 double m_dfBestAccuracy = 0;
93 double m_dfBestError =
double.MaxValue;
95 int m_nTrainingIterationOverride = -1;
96 int m_nTestingIterationOverride = -1;
98 bool m_bWeightsUpdated =
false;
99 static object m_syncGetRi =
new object();
100 Blob<T> m_blobBatchInputData =
null;
101 double m_dfAverageTestTime = 0;
103 int m_nTrainingTimeLimitInMinutes = 0;
104 long m_hWorkspaceData = 0;
105 ulong m_lWorkspaceSize = 0;
106 bool m_bFirstNanError =
true;
107 List<double> m_rgAverageAccuracyWindow =
null;
140 public event EventHandler<TestArgs>
OnTest;
175 public Solver(
CudaDnn<T> cuda,
Log log,
SolverParameter p,
CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest,
IXImageDatabaseBase imgDb,
IXPersist<T> persist,
int nSolverCount = 1,
int nSolverRank = 0,
Net<T> shareNet =
null, onGetWorkspace getws =
null, onSetWorkspace setws =
null)
179 m_evtCancel = evtCancel;
180 m_evtForceSnapshot = evtForceSnapshot;
181 m_evtForceTest = evtForceTest;
199 m_rgAverageAccuracyWindow =
new List<double>();
202 m_rgAverageAccuracyWindow.Add(0);
232 int nTimingCount = 0;
233 double dfTotalTime = 0;
234 return fireOnTrainingIterationEvent(
false, 0, 0, ref nTimingCount, ref dfTotalTime);
237 private bool fireOnTrainingIterationEvent(
bool bFwdPassNanFree,
double dfLoss,
double dfLastLearningRate, ref
int nTimingCount, ref
double dfTotalTime)
241 string strFirstNanBlob =
null;
244 if (m_bEnableBlobDebugging)
246 dbgInfo =
TrainingNet.GetDebugInformation(m_bEnableDetailedNanDetection);
248 if (m_bEnableBreakOnNan && dbgInfo !=
null)
253 if (strFirstNanBlob !=
null)
255 string strPass = (!bFwdPassNanFree) ?
"Forward" :
"Backward";
256 m_log.
WriteLine(
"First NaN detected in the '" + strType +
"' of blob '" + strFirstNanBlob +
"' after " + strPass +
" pass.");
259 string strLastNanBlob = dbgInfo.
DetectLastNaN(out strTypeLast);
261 if (strLastNanBlob != strFirstNanBlob && strType != strTypeLast)
262 m_log.
WriteLine(
"Last NaN detected in the '" + strTypeLast +
"' of blob '" + strLastNanBlob +
"' after " + strPass +
" pass.");
267 double dfTime = (nTimingCount > 0) ? (dfTotalTime / nTimingCount) : 0;
268 OnTrainingIteration(
this,
new TrainingIterationArgs<T>(
m_nIter, m_dfLastAccuracy, dfLoss,
m_dfSmoothedLoss, m_dfBestError, m_bWeightsUpdated,
m_net.ActiveLabelCounts,
m_net.LabelQueryHitPercents,
m_net.LabelQueryEpochs,
m_net.BoostQueryHitPercents, dfLastLearningRate, dfTime, dbgInfo));
272 if (strFirstNanBlob !=
null)
274 m_log.
WriteLine(
"Training is now stopping at iteration " +
m_nIter.ToString(
"N0") +
" as the first NaN has been detected ('" + strFirstNanBlob +
"').");
287 get {
return m_nTrainingTimeLimitInMinutes; }
288 set { m_nTrainingTimeLimitInMinutes = value; }
296 get {
return m_snapshotWeightUpdatemMethod; }
297 set { m_snapshotWeightUpdatemMethod = value; }
326 if (m_blobBatchInputData !=
null)
328 m_blobBatchInputData.
Dispose();
329 m_blobBatchInputData =
null;
332 if (m_hWorkspaceData != 0)
334 m_cuda.FreeMemory(m_hWorkspaceData);
335 m_hWorkspaceData = 0;
336 m_lWorkspaceSize = 0;
345 get {
return m_bEnableTest; }
346 set { m_bEnableTest = value; }
354 get {
return m_bEnableBlobDebugging; }
355 set { m_bEnableBlobDebugging = value; }
375 get {
return m_bEnableBreakOnNan; }
376 set { m_bEnableBreakOnNan = value; }
388 get {
return m_bEnableDetailedNanDetection; }
389 set { m_bEnableDetailedNanDetection = value; }
397 get {
return m_bEnableSingleStep; }
398 set { m_bEnableSingleStep = value; }
406 get {
return m_bWeightsUpdated; }
407 set { m_bWeightsUpdated = value; }
415 get {
return m_tag; }
416 set { m_tag = value; }
438 get {
return m_net; }
483 string field_names =
"net_param, train_net_param";
484 m_log.
CHECK_GE(num_train_nets, 1,
"SolverParameter must specify a train net using one of these fields: " + field_names);
485 m_log.
CHECK_LE(num_train_nets, 1,
"SolverParameter must not contain more than one of these fields specifying a train_net: " + field_names);
490 m_log.
WriteLine(
"Creating training net specified in train_net_param.");
508 net_param.
state = net_state;
511 m_net =
new Net<T>(
m_cuda,
m_log, net_param, m_evtCancel, m_db,
Phase.NONE, m_evtCompleted, shareNet, net_OnGetWorkspace, net_OnSetWorkspace);
512 m_net.OnGetIteration += net_OnGetIteration;
514 catch(Exception excpt)
516 throw new Exception(
"Initializing Training Net: " + excpt.Message);
520 private void net_OnSetWorkspace(
object sender,
WorkspaceArgs e)
528 if (e.
Size <= m_lWorkspaceSize)
531 m_lWorkspaceSize = e.
Size;
532 m_cuda.DisableGhostMemory();
534 if (m_hWorkspaceData != 0)
535 m_cuda.FreeMemory(m_hWorkspaceData);
537 m_hWorkspaceData =
m_cuda.AllocMemory((
long)m_lWorkspaceSize);
538 m_cuda.ResetGhostMemory();
541 private void net_OnGetWorkspace(
object sender,
WorkspaceArgs e)
549 e.
Data = m_hWorkspaceData;
550 e.
Size = m_lWorkspaceSize;
567 int num_test_nets = num_test_net_params;
569 if (num_generic_nets > 0)
580 int num_test_net_instances = num_test_nets + num_generic_net_instances;
585 if (num_test_net_instances > 0)
588 List<string> sources =
new List<string>();
589 List<NetParameter> net_params =
new List<NetParameter>();
591 for (
int i = 0; i < num_test_net_params; i++)
593 sources.Add(
"test_net_param");
601 for (
int i = 0; i < remaining_test_nets; i++)
603 sources.Add(
"net_param");
610 for (
int i = 0; i < num_test_net_instances; i++)
618 net_state.
MergeFrom(net_params[i].state);
623 net_params[i].state = net_state;
625 m_log.
WriteLine(
"Creating test net (#" + i.ToString() +
") specified by " + sources[i],
true);
632 catch (Exception excpt)
634 throw new Exception(
"Initializing Testing Nets: " + excpt.Message);
651 get {
return m_net.ActiveLabelCounts; }
659 get {
return m_net.LabelQueryHitPercents; }
667 get {
return m_net.LabelQueryEpochs; }
695 if (m_nTrainingIterationOverride > 0)
696 nIters = m_nTrainingIterationOverride;
711 if (m_nTestingIterationOverride > 0)
712 nIters = m_nTestingIterationOverride;
726 public virtual void Solve(
int nIterationOverride = -1,
byte[] rgWeights =
null,
byte[] rgState =
null,
TRAIN_STEP step =
TRAIN_STEP.NONE)
732 if (rgWeights !=
null || rgState !=
null)
739 if (nIterationOverride <= 0)
742 if (!
Step(nIterationOverride, step))
749 else if (
m_net.learnable_parameters.SnapshotRequested(
true))
767 m_net.Forward(out dfLoss);
781 if (m_blobBatchInputData !=
null)
783 m_blobBatchInputData.
Dispose();
784 m_blobBatchInputData =
null;
800 public bool Step(
int nIters,
TRAIN_STEP step =
TRAIN_STEP.NONE,
bool bZeroDiffs =
true,
bool bApplyUpdates =
true,
bool bDisableOutput =
false,
bool bDisableProgress =
false,
double? dfLossOverride =
null,
bool? bAllowSnapshot =
null)
802 Exception err =
null;
808 int stop_iter =
m_nIter + nIters;
818 m_net.EnableBreakOnFirstNaN = m_bEnableBreakOnNan && m_bEnableBlobDebugging;
819 m_net.EnableDetailedNanDetection = m_bEnableDetailedNanDetection & m_bEnableBlobDebugging;
821 Stopwatch sw =
new Stopwatch();
824 Stopwatch swTimeout =
new Stopwatch();
827 while (
m_nIter < stop_iter && !m_evtCompleted.WaitOne(0))
831 m_net.ClearParamDiffs();
834 OnStart(
this,
new EventArgs());
855 double dfLossTotal = 0;
858 Stopwatch swTiming =
new Stopwatch();
859 double dfTotalTime = 0;
860 int nTimingCount = 0;
861 bool bFwdPassNanFree =
true;
878 bFwdPassNanFree =
m_net.ForwardBackward(colBottom, out dfLocalLoss, step);
881 if (
double.IsNaN(dfLocalLoss) ||
double.IsInfinity(dfLocalLoss))
883 if (m_bFirstNanError)
885 m_log.
WriteError(
new Exception(
"The local loss at iteration " +
m_nIter.ToString() +
" is invalid (NAN or INFINITY)!"));
886 m_bFirstNanError =
false;
890 dfLossTotal += dfLocalLoss;
893 dfTotalTime += swTiming.Elapsed.TotalMilliseconds;
897 if (!bFwdPassNanFree)
901 dfLoss = dfLossTotal / nIterCount;
902 dfLoss = dfLossOverride.GetValueOrDefault(dfLoss);
907 bool bDisplay =
false;
908 if (!bDisplay1 && sw.ElapsedMilliseconds > 2000 && !bDisableOutput)
911 m_bFirstNanError =
true;
915 if (bDisplay && bDisplay1)
924 for (
int j = 0; j < colResult.
Count; j++)
926 double[] result_vec =
Utility.ConvertVec<T>(colResult[j].update_cpu_data());
927 int nIdx =
m_net.output_blob_indices[j];
928 string output_name =
m_net.blob_names[nIdx];
929 double loss_weight =
m_net.blob_loss_weights[nIdx];
930 double dfTotalLossWeight = 0;
931 int nResultCount = colResult[j].count();
933 for (
int k = 0; k < nResultCount; k++)
939 if (loss_weight != 0)
940 strOut +=
" (* " + loss_weight.ToString() +
" = " + (loss_weight * result_vec[k]).ToString() +
" loss)";
942 m_log.
WriteLine(
" Train net output #" + score_index.ToString() +
": " + output_name +
" = " + result_vec[k].ToString() + strOut);
947 dfTotalLossWeight += loss_weight * result_vec[k];
953 double dfAverage = dfTotalLossWeight / nResultCount;
954 m_log.
WriteLine(
" Average weighted score = " + dfAverage.ToString() +
" for '" + output_name +
"' - averaged over " + nResultCount.ToString(
"N0") +
" results.");
963 double dfLastLearningRate = 0;
965 if (step !=
TRAIN_STEP.FORWARD && bApplyUpdates)
971 if (!bDisableProgress)
974 bool bSnapshotTaken =
false;
980 (m_dfLastAccuracy > m_dfBestAccuracy))))
982 bSnapshotTaken =
true;
985 if (m_dfLastAccuracy > m_dfBestAccuracy)
986 m_dfBestAccuracy = m_dfLastAccuracy;
993 fireOnTrainingIterationEvent(bFwdPassNanFree, dfLoss, dfLastLearningRate, ref nTimingCount, ref dfTotalTime);
998 if (step !=
TRAIN_STEP.NONE || m_bEnableSingleStep)
1002 if (!bDisableOutput)
1003 m_log.
WriteLine(
"Single step (both) triggered - solving stopped after a single forward/backward pass.");
1007 if (!bDisableOutput)
1008 m_log.
WriteLine(
"Single step (forward) triggered - solving stopped after a single forward pass.");
1012 if (!bDisableOutput)
1013 m_log.
WriteLine(
"Single step (backward) triggered - solving stopped after a single backward pass.");
1019 if (!bSnapshotTaken)
1030 if (m_nTrainingTimeLimitInMinutes > 0 && swTimeout.Elapsed.TotalMinutes > m_nTrainingTimeLimitInMinutes)
1032 m_log.
WriteLine(
"A training time-limit of " + m_nTrainingTimeLimitInMinutes.ToString(
"N0") +
" minutes has been exceeded - training will now stop.");
1042 catch (Exception excpt)
1049 if (err !=
null || m_evtCancel.
WaitOne(0))
1063 public void Restore(
byte[] rgWeights,
byte[] rgState,
string strSkipBlobTypes =
null)
1065 m_net.LoadWeights(rgWeights,
m_persist,
null,
null, strSkipBlobTypes);
1067 if (rgState !=
null)
1069 m_log.
WriteLine(
"Restoring previous solver state from restore state...");
1081 public void Snapshot(
bool bForced,
bool bScheduled,
bool bUpdateDatabase =
true)
1104 private void args_OnGetWeights(
object sender,
GetBytesArgs e)
1110 private void args_OnGetState(
object sender,
GetBytesArgs e)
1127 if (dfAccuracy == 0)
1128 dfAccuracy = 0.0001;
1146 get {
return m_nTrainingIterationOverride; }
1147 set { m_nTrainingIterationOverride = value; }
1155 get {
return m_nTestingIterationOverride; }
1156 set { m_nTestingIterationOverride = value; }
1164 get {
return m_evtCompleted; }
1172 get {
return m_evtCancel; }
1196 get {
return m_net; }
1230 if (m_evtForceSnapshot ==
null)
1233 return m_evtForceSnapshot.WaitOne(0);
1244 if (m_evtForceTest ==
null)
1247 return m_evtForceTest.WaitOne(0);
1287 public double TestAll(
int nIterationOverride = -1)
1289 double dfTotalAccuracy = 0;
1290 double dfTotalTime = 0;
1291 int nTotalCount = 0;
1293 for (
int test_net_id = 0; test_net_id <
m_rgTestNets.Count; test_net_id++)
1305 dfTotalAccuracy += testOne(nIterationOverride, test_net_id);
1307 dfTotalTime += m_dfAverageTestTime;
1320 dfTotalAccuracy += testOne(nIterationOverride, 0);
1325 if (m_rgAverageAccuracyWindow !=
null)
1327 m_rgAverageAccuracyWindow.Add(dfAccuracy);
1328 m_rgAverageAccuracyWindow.RemoveAt(0);
1329 dfAccuracy = m_rgAverageAccuracyWindow.Average();
1334 double dfTime = (nTotalCount > 0) ? dfTotalTime / nTotalCount : 0;
1341 private double testOne(
int nIterationOverride = -1,
int nTestNetId = 0)
1363 Stopwatch sw =
new Stopwatch();
1369 m_log.
WriteLine(
"Iteration " +
m_nIter.ToString() +
", Testing net (#" + nTestNetId.ToString() +
")");
1380 Dictionary<int, Dictionary<int, List<Tuple<float, int>>>> rgAllTruePos =
new Dictionary<int, Dictionary<int, List<Tuple<float, int>>>>();
1381 Dictionary<int, Dictionary<int, List<Tuple<float, int>>>> rgAllFalsePos =
new Dictionary<int, Dictionary<int, List<Tuple<float, int>>>>();
1382 Dictionary<int, Dictionary<int, int>> rgAllNumPos =
new Dictionary<int, Dictionary<int, int>>();
1386 if (nIterationOverride <= 0)
1389 int nIter = nIterationOverride;
1392 for (
int i = 0; i < nIter; i++)
1405 dfLoss += iter_loss;
1407 for (
int j = 0; j < colResult.
Count; j++)
1409 m_log.
CHECK_EQ(colResult[j].width, 5,
"The width must be = 5 for SSD.");
1410 double[] result_vec =
Utility.ConvertVec<T>(colResult[j].update_cpu_data());
1411 int num_det = colResult[j].height;
1413 for (
int k = 0; k < num_det; k++)
1415 int item_id = (int)result_vec[k * 5];
1416 int nLabel = (int)result_vec[k * 5 + 1];
1421 if (!rgAllNumPos.ContainsKey(j))
1422 rgAllNumPos.
Add(j,
new Dictionary<int, int>());
1424 if (!rgAllNumPos[j].ContainsKey(nLabel))
1425 rgAllNumPos[j].Add(nLabel, (
int)result_vec[k * 5 + 2]);
1427 rgAllNumPos[j][nLabel] += (int)result_vec[k * 5 + 2];
1432 float fScore = (float)result_vec[k * 5 + 2];
1433 int tp = (int)result_vec[k * 5 + 3];
1434 int fp = (int)result_vec[k * 5 + 4];
1438 if (tp == 0 && fp == 0)
1441 if (!rgAllTruePos.ContainsKey(j))
1442 rgAllTruePos.Add(j,
new Dictionary<
int, List<Tuple<float, int>>>());
1444 if (!rgAllTruePos[j].ContainsKey(nLabel))
1445 rgAllTruePos[j].Add(nLabel,
new List<Tuple<float, int>>());
1447 if (!rgAllFalsePos.ContainsKey(j))
1448 rgAllFalsePos.Add(j,
new Dictionary<
int, List<Tuple<float, int>>>());
1450 if (!rgAllFalsePos[j].ContainsKey(nLabel))
1451 rgAllFalsePos[j].Add(nLabel,
new List<Tuple<float, int>>());
1453 rgAllTruePos[j][nLabel].Add(
new Tuple<float, int>(fScore, tp));
1454 rgAllFalsePos[j][nLabel].Add(
new Tuple<float, int>(fScore, fp));
1459 if (sw.Elapsed.TotalMilliseconds > 1000)
1479 float fTotalmAP = 0;
1480 for (
int i = 0; i < rgAllTruePos.Count; i++)
1482 if (!rgAllTruePos.ContainsKey(i))
1483 m_log.
FAIL(
"Missing output_blob true_pos: " + i.ToString());
1485 Dictionary<int, List<Tuple<float, int>>> rgTruePos = rgAllTruePos[i];
1487 if (!rgAllFalsePos.ContainsKey(i))
1488 m_log.
FAIL(
"Missing output_blob false_pos: " + i.ToString());
1490 Dictionary<int, List<Tuple<float, int>>> rgFalsePos = rgAllFalsePos[i];
1492 if (!rgAllNumPos.ContainsKey(i))
1493 m_log.
FAIL(
"Missing output_blob num_pos: " + i.ToString());
1495 Dictionary<int, int> rgNumPos = rgAllNumPos[i];
1497 Dictionary<int, float> rgAPs =
new Dictionary<int, float>();
1501 foreach (KeyValuePair<int, int> kv
in rgNumPos)
1503 int nLabel = kv.Key;
1504 int nLabelNumPos = kv.Value;
1506 if (!rgTruePos.ContainsKey(nLabel))
1508 m_log.
WriteLine(
"WARNING: Missing true_pos for label: " + nLabel.ToString() +
"!");
1511 List<Tuple<float, int>> rgLabelTruePos = rgTruePos[nLabel];
1513 if (!rgFalsePos.ContainsKey(nLabel))
1515 m_log.
WriteLine(
"WARNING: Missing false_pos for label: " + nLabel.ToString() +
"!");
1518 List<Tuple<float, int>> rgLabelFalsePos = rgFalsePos[nLabel];
1524 if (!rgAPs.ContainsKey(nLabel))
1525 rgAPs.Add(nLabel, fAp);
1527 rgAPs[nLabel] = fAp;
1532 m_log.
WriteLine(
"class " + nLabel.ToString() +
": " + fAp.ToString());
1535 fmAP /= rgNumPos.Count;
1538 string strOutputName = test_net.
blob_names[nOutputBlobIdx];
1540 m_log.
WriteLine(
" Test net output #" + i.ToString() +
": " + strOutputName +
" = " + fmAP.ToString());
1544 return fTotalmAP / rgAllTruePos.Count;
1546 catch (Exception excpt)
1567 m_log.
WriteLine(
"Iteration " +
m_nIter.ToString() +
", Testing net (#" + nTestNetId.ToString() +
")");
1578 List<double> test_score =
new List<double>();
1579 List<int> test_score_output_id =
new List<int>();
1582 if (nIterationOverride <= 0)
1585 int nIter = nIterationOverride;
1587 Stopwatch sw =
new Stopwatch();
1590 double dfTotalTiming = 0;
1592 int nAccuracyIdx = 0;
1593 int nMinRank =
int.MaxValue;
1594 bool bAccuracyValid =
false;
1595 Stopwatch swTiming =
new Stopwatch();
1597 for (
int i = 0; i < nIter; i++)
1612 dfLoss += iter_loss;
1621 test_score_output_id.Add(1);
1622 bAccuracyValid =
true;
1630 for (
int j = 0; j < colResult.
Count; j++)
1632 double[] result_vec =
Utility.ConvertVec<T>(colResult[j].update_cpu_data());
1634 for (
int k = 0; k < colResult[j].count(); k++)
1636 test_score.Add(result_vec[k]);
1637 test_score_output_id.Add(j);
1642 int nRank = (int)getNumber(colResult[j].
Tag, 0);
1643 if (nRank < nMinRank)
1655 for (
int j = 0; j < colResult.
Count; j++)
1657 double[] result_vec =
Utility.ConvertVec<T>(colResult[j].update_cpu_data());
1659 for (
int k = 0; k < colResult[j].count(); k++)
1661 test_score[idx] += result_vec[k];
1669 dfTotalTiming += swTiming.Elapsed.TotalMilliseconds;
1672 if (sw.ElapsedMilliseconds > 2000)
1674 double dfPct = (double)i / (
double)nIter;
1686 m_dfAverageTestTime = (nTestCount > 0) ? dfTotalTiming / nTestCount : 0;
1700 double dfFinalScore = 0;
1704 dfFinalScore = test_score.Sum();
1705 int nTotal = test_score_output_id.Sum();
1706 dfFinalScore /= nTotal;
1710 for (
int i = 0; i < test_score.Count; i++)
1712 int nIdxTestScore = test_score_output_id[i];
1714 string output_name = test_net.
blob_names[output_blob_index];
1716 double dfMeanScore = test_score[i] / nIter;
1721 if (loss_weight != 0)
1722 strOut +=
" (* " + loss_weight.ToString() +
" = " + (loss_weight * dfMeanScore).ToString() +
" loss)";
1724 m_log.
WriteLine(
" Test net output #" + i.ToString() +
": " + output_name +
" = " + dfMeanScore.ToString() + strOut);
1727 if (i == nAccuracyIdx)
1728 dfFinalScore = dfMeanScore;
1732 if (test_score.Count == 0)
1735 return dfFinalScore;
1738 private double getNumber(
object value,
double dfDefault)
1744 return (
double)(sbyte)value;
1747 return (
double)(byte)value;
1750 return (
double)(short)value;
1752 if (value is ushort)
1753 return (
double)(ushort)value;
1756 return (
double)(int)value;
1759 return (
double)(uint)value;
1762 return (
double)(long)value;
1765 return (
double)(ulong)value;
1768 return (
double)(float)value;
1770 if (value is
double)
1771 return (
double)value;
1773 if (value is decimal)
1774 return (
double)(decimal)value;
1787 if (nAverageLoss == 0)
1798 int nIdx = (
m_nIter - nStartIter) % nAverageLoss;
1803 if (m_bWeightsUpdated)
1806 m_bWeightsUpdated =
false;
1811 if (m_dfLastError < m_dfBestError)
1812 m_dfBestError = m_dfLastError;
1848 public static SGDSolver<T> Create(
CudaDnn<T> cuda,
Log log,
ProjectEx p,
CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest,
IXImageDatabaseBase imgDb,
IXPersist<T> persist,
int nSolverCount = 1,
int nSolverRank = 0,
Net<T> shareNet =
null, onGetWorkspace getws =
null, onSetWorkspace setws =
null)
1869 return Create(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1889 public static SGDSolver<T> Create(
CudaDnn<T> cuda,
Log log,
SolverParameter solverParam,
CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest,
IXImageDatabaseBase imgDb,
IXPersist<T> persist,
int nSolverCount = 1,
int nSolverRank = 0,
Net<T> shareNet =
null, onGetWorkspace getws =
null, onSetWorkspace setws =
null)
1893 switch (solverParam.
type)
1896 solver =
new SGDSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1900 solver =
new NesterovSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1904 solver =
new AdaGradSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1908 solver =
new AdaDeltaSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1912 solver =
new AdamSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1916 solver =
new AdamWSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1920 solver =
new RmsPropSolver<T>(cuda, log, solverParam, evtCancel, evtForceSnapshot, evtForceTest, imgDb, persist, nSolverCount, nSolverRank, shareNet, getws, setws);
1924 throw new NotImplementedException(
"The solver " + solverParam.
type.ToString() +
" is not implemented yet!");
1931#pragma warning disable 1591
1933 public class OutputCollection
1935 OutputDataCollection m_rgError =
new OutputDataCollection();
1936 OutputDataCollection m_rgAccuracy =
new OutputDataCollection();
1938 public OutputCollection()
1942 public OutputDataCollection Errors
1944 get {
return m_rgError; }
1947 public OutputDataCollection Accuracies
1949 get {
return m_rgAccuracy; }
1953 public class OutputDataCollection : IEnumerable<OutputData>
1955 List<OutputData> m_rgData =
new List<OutputData>();
1957 public OutputDataCollection()
1961 public List<OutputData> Data
1963 get {
return m_rgData; }
1968 get {
return m_rgData.Count; }
1971 public OutputData
this[
int nIdx]
1973 get {
return m_rgData[nIdx]; }
1974 set { m_rgData[nIdx] = value; }
1977 public void Add(
int nTotal,
string strName,
int nIdx,
double dfVal)
1979 OutputData data = Find(strName);
1983 data =
new OutputData(strName, nIdx);
1987 data.Add(nTotal, dfVal);
1990 public OutputData Find(
string strName)
1992 foreach (OutputData data
in m_rgData)
1994 if (data.Name == strName)
2001 public IEnumerator<OutputData> GetEnumerator()
2003 return m_rgData.GetEnumerator();
2006 IEnumerator IEnumerable.GetEnumerator()
2008 return m_rgData.GetEnumerator();
2012 public class OutputData
2015 double m_dfValue = 0;
2018 public OutputData(
string strName,
int nIdx)
2020 m_strName = strName;
2026 get {
return m_nIdx; }
2031 get {
return m_strName; }
2036 get {
return m_dfValue; }
2037 set { m_dfValue = value; }
2040 public void Add(
int nTotal,
double dfVal)
2042 double dfRatio = 1.0 / (double)nTotal;
2043 m_dfValue = (m_dfValue * (1.0 - dfRatio)) + (dfRatio * dfVal);
2047#pragma warning restore 1591
The CancelEvent provides an extension to the manual cancel event that allows for overriding the manua...
bool WaitOne(int nMs=int.MaxValue)
Waits for the signal state to occur.
The Log class provides general output in text form.
void CHECK(bool b, string str)
Test a flag for true.
bool IsEnabled
Returns whether or not the Log is enabled.
void WriteLine(string str, bool bOverrideEnabled=false, bool bHeader=false, bool bError=false, bool bDisable=false)
Write a line of output.
bool Enable
Enables/disables the Log. When disabled, the Log does not output any data.
void FAIL(string str)
Causes a failure which throws an exception with the descriptive text.
double Progress
Get/set the progress associated with the Log.
void CHECK_EQ(double df1, double df2, string str)
Test whether one number is equal to another.
void WriteError(Exception e)
Write an error as output.
void CHECK_GT(double df1, double df2, string str)
Test whether one number is greater than another.
void CHECK_LE(double df1, double df2, string str)
Test whether one number is less than or equal to another.
void CHECK_GE(double df1, double df2, string str)
Test whether one number is greater than or equal to another.
The ProjectEx class manages a project containing the solver description, model description,...
string? SolverDescription
Get/set the solver description script used by the Project.
int ID
Returns the ID of the Project in the database.
string? ModelDescription
Get/set the model description script used by the Project.
The RawProto class is used to parse and output Google prototxt file data.
static RawProto Parse(string str)
Parses a prototxt and places it in a new RawProto.
The Utility class provides general utility functions.
The BBox class processes the NormalizedBBox data used with SSD.
void Dispose()
Clean up all resources.
float ComputeAP(List< Tuple< float, int > > rgTp, int nNumPos, List< Tuple< float, int > > rgFp, ApVersion apVersion, out List< float > rgPrec, out List< float > rgRec)
Compute the average precision given true positive and false positive vectors.
The BlobCollection contains a list of Blobs.
void Add(Blob< T > b)
Add a new Blob to the collection.
int Count
Returns the number of items in the collection.
The Blob is the main holder of data that moves through the Layers of the Net.
virtual void Dispose(bool bDisposing)
Releases all resources used by the Blob (including both GPU and Host).
The CudaDnn object is the main interface to the Low-Level Cuda C++ DLL.
The CustomForwardBackArgs provide the arguments to the OnCustomForwardBack event within the Solver St...
double LocalLoss
Get/set the local loss of the pass.
bool FwdPassNanFree
Get/set whether or a NAN was detected in the forward pass.
The GetBytesArgs is passed along to the SnapshotArgs::OnGetWeights and SnapshotArgs::OnGetState event...
byte[] Data
Get/set the data as an array of bytes.
The GetIterationArgs is bubbled up to the solver when a layer needs to know the current training ...
void SetIteration(Phase p, int nIteration)
The SetIteration method is used to set the iteration and the phase.
The GradientsReadyArgs is sent to the Solver::OnGradientsReady event which fires at the end of each S...
Connects Layers together into a directed acyclic graph (DAG) specified by a NetParameter
BlobCollection< T > Forward()
Run forward with the input Blob's already fed separately.
List< string > blob_names
Returns the blob names.
List< double > blob_loss_weights
Returns the collection of blob loss weights.
string name
Returns the network name.
List< int > output_blob_indices
Returns a list of the output Blob indexes.
The SnapshotArgs is sent to the Solver::OnSnapshot event which fires each time the Solver::Snapshot m...
bool Forced
Get/set whether or not the snapshot was forced or not.
bool SingleStep
Get/set the Solver single step.
bool IncludeWeights
Get/set whether or not to include the weights in the snapshot.
bool Scheduled
Get/set whether or not the snapshot is a regular scheduled snapshot (e.g. not an improved accuracy or...
bool IncludeState
Get/set whether or not to include the Solver state in the snapshot.
EventHandler< GetBytesArgs > OnGetState
Specifies the OnGetState event which fires when the SnapshotArgs::UpdateState method is called.
bool UpdateDatabase
Get/set whether or not to update the database (default = true).
EventHandler< GetBytesArgs > OnGetWeights
Specifies the OnGetWeights event which fires when the SnapshotArgs::UpdateWeights method is called.
The TestArgs are passed to the Solver::OnTest event.
double Accuracy
Get/set the accuracy for the test run. When overriding the testing, the override should set the accur...
The TestResultArgs are passed to the Solver::OnTestResults event.
bool AccuracyValid
Get/set the accuracy valid flag. When not valid, the OnTestResults event is ignored.
double Accuracy
Get/set the accuracy. The recipient of this event should set this value.
Specifies the TestingIterationArgs sent to the Solver::OnTestingIteration, which is called at the end...
The TrainingIterationArgs is sent to the Solver::OnTrainingIteration event that fires at the end of a...
The WorkspaceArgs are passed to both the Layer::OnSetWorkspace and Layer::OnGetWorkspace events.
long Data
Get/set the handle to workspace data in GPU memory.
ulong Size
Get/set the size of the workspace memory (in bytes).
The Database class manages the actual connection to the physical database using Entity Frameworks from...
Specifies the parameters use to create a Net
static NetParameter FromProto(RawProto rp)
Parse a RawProto into a new instance of the parameter.
NetState state
The current 'state' of the network, including the phase, level and stage. Some layers may be included...
int ProjectID
Specifies the ID of the project that created this net param (if any).
int solver_rank
Specifies the rank of the solver using this network.
int solver_count
Specifies the number of solvers used in a multi-gpu training session.
NetParameter Clone(bool bCloneLayers=true, int? nSolverCount=null, int? nSolverRank=null)
Creates a new copy of this instance of the parameter.
Specifies the NetState which includes the phase, level and stage for which a given Net is to run unde...
Phase phase
Specifies the Phase of the NetState.
void MergeFrom(NetState ns)
Merges another NetState with this instance.
The SolverParameter is a parameter for the solver, specifying the train and test networks.
int max_iter
The maximum number of iterations.
List< int > test_iter
The number of iterations for each test.
NetParameter net_param
Inline train net param, possibly combined with one or more test nets.
bool debug_info
If true, print information about the state of the net that may help with debugging learning problems.
NetParameter train_net_param
Inline train net param, possibly combined with one or more test nets.
List< NetState > test_state
The states for the train/test nets. Must be unspecified or specified once per net.
SolverType
Defines the type of solver.
string lr_policy
The learning rate decay policy.
static SolverParameter FromProto(RawProto rp)
Parses a new SolverParameter from a RawProto.
ApVersion ap_version
Specifies the AP Version to use for average precision when using Single-Shot Detection (SSD) - (defau...
long random_seed
If non-negative, the seed with which the Solver will initialize the caffe random number generator – u...
int average_loss
Display the loss averaged over the last average_loss iterations.
int test_interval
The number of iterations between two testing phases.
bool output_average_results
Specifies to average loss results before they are output - this can be faster when there are a lot of...
int iter_size
Accumulate gradients over 'iter_size' x 'batch_size' instances.
string DebugString()
Returns a debug string for the SolverParameter.
EvaluationType
Defines the evaluation method used in the SSD algorithm.
bool snapshot_after_train
If false, don't save a snapshot after training finishes.
bool snapshot_include_weights
Specifies whether or not the snapshot includes the trained weights. The default = true.
bool test_compute_loss
Test the compute loss.
SolverParameter()
The SolverParameter constructor.
EvaluationType eval_type
Specifies the evaluation type to use when using Single-Shot Detection (SSD) - (default = NONE,...
bool test_initialization
If true, run an initial test pass before the first iteration, ensuring memory availability and printi...
List< NetParameter > test_net_param
Inline test net params.
int display
The number of iterations between displaying info. If display = 0, no info will be displayed.
bool snapshot_diff
Whether to snapshot diff in the results or not. Snapshotting diff will help debugging but the final p...
bool snapshot_include_state
Specifies whether or not the snapshot includes the solver state. The default = false....
bool show_per_class_result
Specifies whether or not to display results per class when using Single-Shot Detection (SSD) - (defau...
int accuracy_average_window
Specifies the window over which to average the accuracies (default = 0 which ignores averaging).
int snapshot
Specifies the snapshot interval.
SolverType type
Specifies the solver type.
NetState train_state
The states for the train/test nets. Must be unspecified or specified once per net.
Use AdaDelta Solver which has gradient based optimization like SGD.
Use AdaGrad Solver based optimization like SGD that tries to find rarely seen features.
Use Adam Solver which uses gradient based optimization like SGD that includes 'adaptive momentum esti...
Use AdamW Solver which uses gradient based optimization like Adam with a decoupled weight decay.
Use Nesterov's accelerated gradient Solver, which is similar to SGD, but the error gradient is comput...
Use RmsProp Solver which uses gradient based optimization like SGD.
Stochastic Gradient Descent solver with momentum updates weights by a linear combination of the negat...
An interface for classes that perform optimization on Nets
List< Net< T > > m_rgTestNets
Specifies the testing Nets.
int TrainingIterations
Returns the current training iterations remaining.
void InitTestNets()
Initializes the Net used by the Solver for testing.
EventHandler< CustomForwardBackArgs< T > > OnCustomForwardBack
The OnCustomForwardBack allows for overriding the forward/backward operations within the solver.
int m_nSolverCount
Specifies the Solver count in a multi-GPU training session.
void Dispose()
Discards the resources (GPU and Host) used by this Solver.
double m_dfSmoothedLoss
Specifies the smoothed loss protected for derived classes to use.
SolverParameter m_param
Specifies the SolverParameter that defines how the Solver operates.
EventHandler< TrainingIterationArgs< T > > OnTrainingIteration
The OnTrainingIteration event fires at the end of each training iteration.
List< double > m_rgLosses
Specifies the Losses used to calculate the smoothed Loss.
abstract byte[] SnapshotSolverState()
Save the current solver state.
double smoothed_loss
Returns the smoothed loss.
void Restore(byte[] rgWeights, byte[] rgState, string strSkipBlobTypes=null)
The restore method simply calls the RestoreSolverState method of the inherited class.
static SGDSolver< T > Create(CudaDnn< T > cuda, Log log, SolverParameter solverParam, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
Create a new Solver based on the project containing the SolverParameter.
int iter
Returns the current training iteration.
CudaDnn< T > m_cuda
Specifies the instance of CudaDnn used by the Solver that provides a connection to Cuda.
void Snapshot(bool bForced, bool bScheduled, bool bUpdateDatabase=true)
The snapshot function implements the basic snapshotting utility that stores the learned net....
int MaximumIteration
Returns the maximum training iterations.
EventHandler< SnapshotArgs > OnSnapshot
The OnSnapshot event fires when the Solver detects that a snapshot is needed.
bool EnableBlobDebugging
When enabled, the OnTrainingIteration event is sent extra debugging information describing the state o...
SolverParameter.SolverType type
Returns the type of solver.
Net< T > net
Returns the main training Net.
bool ForceOnTrainingIterationEvent()
Force an OnTrainingIterationEvent to fire.
object Tag
Returns a generic tag associated with the Solver.
double TestDetection(int nIterationOverride=-1, int nTestNetId=0)
Run an SSD detection test on a given test Net by running it through its iterations.
bool? is_root_solver
Returns whether or not this is the root solver.
double LearningRateOverride
Get/set the learning rate override. When 0, this setting is ignored.
bool EnableTesting
When enabled, the training cycle calls TestAll periodically based on the SolverParameter....
int m_nIter
Specifies the current iteration.
Net< T > TrainingNet
Returns the training Net used by the solver.
double m_dfLearningRateOverride
Optionally, specifies a learning rate override (default = 0, which ignores this setting).
EventHandler< TestingIterationArgs< T > > OnTestingIteration
The OnTestingIteration event fires at the end of each testing iteration.
void InitTrainNet(Net< T > shareNet=null)
Initializes the Net used by the solver for training.
abstract void RestoreSolverState(byte[] rgState)
Restore a solver state.
void UpdateSmoothedLoss(double dfLoss, int nStartIter, int nAverageLoss=0)
Update the averaged loss value.
void Init(SolverParameter p, Net< T > shareNet=null)
Initializes the Solver.
bool EnableBreakOnFirstNaN
When enabled (requires EnableBlobDebugging = true), the Solver immediately stops training upon detecti...
int solver_count
Returns the solver count in a multi-GPU session.
SolverParameter parameter
Returns the SolverParameter used.
bool forceSnapshot
Returns whether or not a snapshot has been forced.
SNAPSHOT_WEIGHT_UPDATE_METHOD SnapshotWeightUpdateMethod
Get/set the snapshot weight update method.
EventHandler OnAborted
The OnAborted event fires after aborting a training cycle.
List< Net< T > > test_nets
Returns the testing Nets.
static SGDSolver< T > Create(CudaDnn< T > cuda, Log log, ProjectEx p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
Create a new Solver based on the project containing the SolverParameter.
IXPersist< T > m_persist
Specifies the persistence object used to save weight and solver states.
int TrainingTimeLimitInMinutes
Get/set the training time limit in minutes. When set to 0, no time limit is imposed on training.
EventHandler< WorkspaceArgs > OnGetWorkspace
Specifies the OnGetWorkspace event that fires when the getWorkspace() function is called by a layer t...
double TestClassification(int nIterationOverride=-1, int nTestNetId=0)
Run a test on a given test Net by running it through its iterations.
void Reset()
Reset the iterations of the net.
bool Step(int nIters, TRAIN_STEP step=TRAIN_STEP.NONE, bool bZeroDiffs=true, bool bApplyUpdates=true, bool bDisableOutput=false, bool bDisableProgress=false, double? dfLossOverride=null, bool? bAllowSnapshot=null)
Steps a set of iterations through a training cycle.
double TestAll(int nIterationOverride=-1)
Run a TestAll by running all test Nets.
string LabelQueryEpochs
Return the label query epochs for the active datasource.
EventHandler< TestResultArgs< T > > OnTestResults
When specified, the OnTestResults event fires after each single test run. The recipient is responsibl...
EventHandler< GradientsReadyArgs > OnGradientsReady
The OnGradientsReady event fires after the gradients of a Solver are ready for distribution to other ...
EventHandler< WorkspaceArgs > OnSetWorkspace
Specifies the OnSetWorkspace event that fires when the setWorkspace() function is called by a layer t...
int? TestingIterations
Returns the current testing iterations remaining.
bool forceTest
Returns whether or not a test has been forced.
int TrainingIterationOverride
Get/set the training iteration override.
EventHandler OnTestStart
The OnTestStart event fires at the start of each testing iteration.
Net< T > m_net
Specifies the training Net.
bool WeightsUpdated
Get/set when the weights have been updated.
int m_nCurrentStep
Specifies the current step.
int solver_rank
Returns this Solver's rank in a multi-GPU session.
bool EnableDetailedNanDetection
When enabled (requires EnableBlobDebugging = true), the detailed NaN (and Infinity) detection is performed...
Log m_log
Specifies the Log for output.
SnapshotArgs GetSnapshotArgs(byte[] rgState, byte[] rgWeights, double dfAccuracy, double dfError, int nIteration, SNAPSHOT_WEIGHT_UPDATE_METHOD wtUpdt)
The GetSnapshotArgs method fills out a snapshot args structure.
virtual void dispose()
Override that allows discarding of resources (GPU and Host) used by this Solver.
EventHandler< TestArgs > OnTest
When specified, the OnTest event fires during a TestAll and overrides the call to Test.
int TestingIterationOverride
Get/set the testing iteration override.
virtual void Solve(int nIterationOverride=-1, byte[] rgWeights=null, byte[] rgState=null, TRAIN_STEP step=TRAIN_STEP.NONE)
The main entry of the solver function. In default, iter will be zero. Pass in a non-zero iter number ...
EventHandler OnStart
The OnStart event fires at the start of each training iteration.
string ActiveLabelCounts
Returns a string describing the labels detected in the training along with the % that each label has ...
AutoResetEvent CompletedEvent
Returns an auto reset event that is set upon training completion.
abstract double ApplyUpdate(int nIterationOverride=-1)
Make and apply the update value for the current iteration.
CudaDnn< T > Cuda
Returns the CudaDnn instance used by the Solver.
bool EnableSingleStep
When enabled (requires EnableBlobDebugging = true), the Solver only runs one training cycle.
int m_nSolverRank
Specifies the Solver rank of this solver, where rank == 0 is the root Solver.
string LabelQueryHitPercents
Return the label query hit percentages for the active datasource.
Net< T > TestingNet
Returns the testing Net used by the solver.
bool EnableLayerDebugging
Enable/disable layer debugging which causes each layer to check for NAN/INF on each forward/backward ...
Solver(CudaDnn< T > cuda, Log log, SolverParameter p, CancelEvent evtCancel, AutoResetEvent evtForceSnapshot, AutoResetEvent evtForceTest, IXImageDatabaseBase imgDb, IXPersist< T > persist, int nSolverCount=1, int nSolverRank=0, Net< T > shareNet=null, onGetWorkspace getws=null, onSetWorkspace setws=null)
The Solver constructor.
int CurrentIteration
Returns the current training iteration.
The IXImageDatabaseBase interface defines the general interface to the in-memory image database.
The IXPersist interface is used by the CaffeControl to load and save weights.
The MyCaffe.basecode contains all generic types used throughout MyCaffe.
Phase
Defines the Phase under which to run a Net.
SNAPSHOT_WEIGHT_UPDATE_METHOD
Defines the snapshot weight update method.
The MyCaffe.common namespace contains common MyCaffe classes.
BLOB_TYPE
Defines the type of data held by a given Blob.
TRAIN_STEP
Defines the training stepping method (if any).
The MyCaffe.db.image namespace contains all image database related classes.
The MyCaffe.param namespace contains parameters used to create models.
The MyCaffe.solvers namespace contains all solver classes, including the base Solver.
The MyCaffe namespace contains the main body of MyCaffe code that closely tracks the C++ Caffe open-...