SHOGUN  v3.2.0
LogDetEstimator.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Soumyajit De
8  */
9 
10 #include <shogun/lib/common.h>
11 #include <shogun/lib/SGVector.h>
12 #include <shogun/lib/SGMatrix.h>
20 
21 namespace shogun
22 {
23 
// Default constructor. NOTE(review): the signature line (listing line 24) is
// missing from this extract. Initializes the CSGObject base and delegates to
// init(), which sets the three member pointers to NULL and registers them.
25  : CSGObject()
26 {
27  init();
28 }
29 
// Constructor taking the trace sampler, the log operator function and the
// computation engine. NOTE(review): the first signature line (listing line 30)
// is missing from this extract; it presumably declares `trace_sampler`.
//
// init() runs first so the members are NULL and registered before they are
// assigned. Each passed pointer is then stored in its member and SG_REF'ed.
31  COperatorFunction<float64_t>* operator_log,
32  CIndependentComputationEngine* computation_engine)
33  : CSGObject()
34 {
35  init();
36 
37  m_trace_sampler=trace_sampler;
38  SG_REF(m_trace_sampler);
39 
40  m_operator_log=operator_log;
41  SG_REF(m_operator_log);
42 
43  m_computation_engine=computation_engine;
44  SG_REF(m_computation_engine);
45 }
46 
// Shared initialization used by both constructors: resets the three member
// pointers to NULL and registers each of them via SG_ADD under the names
// "trace_sampler", "operator_log" and "computation_engine", all flagged
// MS_NOT_AVAILABLE.
47 void CLogDetEstimator::init()
48 {
49  m_trace_sampler=NULL;
50  m_operator_log=NULL;
51  m_computation_engine=NULL;
52 
53  SG_ADD((CSGObject**)&m_trace_sampler, "trace_sampler",
54  "Trace sampler for the log operator", MS_NOT_AVAILABLE);
55 
56  SG_ADD((CSGObject**)&m_operator_log, "operator_log",
57  "The log operator function", MS_NOT_AVAILABLE);
58 
59  SG_ADD((CSGObject**)&m_computation_engine, "computation_engine",
60  "The computation engine for the jobs", MS_NOT_AVAILABLE);
61 }
62 
// NOTE(review): the signature line (listing line 63) is missing from this
// extract; this is presumably the destructor. It SG_UNREFs the three members
// that the three-argument constructor SG_REF'ed.
64 {
65  SG_UNREF(m_trace_sampler);
66  SG_UNREF(m_operator_log);
67  SG_UNREF(m_computation_engine);
68 }
69 
// Computes `num_estimates` log-det estimates and returns them in a vector of
// that length. NOTE(review): the signature line (listing line 70) is missing
// from this extract; the symbol index at the end of the page lists it as
// `SGVector< float64_t > sample(index_t num_estimates)`.
//
// Flow: precompute on the operator function and on the trace sampler, check
// that their dimensions agree, submit jobs for every (estimate, trace sample)
// pair while collecting the returned aggregators, wait on the computation
// engine, then finalize every aggregator and add its result to the entry of
// the estimate it belongs to.
71 {
72  SG_DEBUG("Entering\n");
73  SG_INFO("Computing %d log-det estimates\n", num_estimates);
74 
75  REQUIRE(m_operator_log, "Operator function is NULL\n");
76  // call the precompute of operator function to compute the prerequisites
77  m_operator_log->precompute();
78 
79  REQUIRE(m_trace_sampler, "Trace sampler is NULL\n");
80  // call the precompute of the sampler
81  m_trace_sampler->precompute();
82 
83  REQUIRE(m_operator_log->get_operator()->get_dimension()\
84  ==m_trace_sampler->get_dimension(),
85  "Mismatch in dimensions of the operator and trace-sampler, %d vs %d!\n",
86  m_operator_log->get_operator()->get_dimension(),
87  m_trace_sampler->get_dimension());
88 
89  // for storing the aggregators that submit_jobs return
90  CDynamicObjectArray* aggregators=new CDynamicObjectArray();
91  index_t num_trace_samples=m_trace_sampler->get_num_samples();
92 
93  // aggregators are appended estimate by estimate, and within one estimate
94  // trace sample by trace sample; the aggregation loop below relies on this
95  // order to map a flat index back to its estimate
93  for (index_t i=0; i<num_estimates; ++i)
94  {
95  for (index_t j=0; j<num_trace_samples; ++j)
96  {
97  SG_INFO("Computing log-determinant trace sample %d/%d\n", j,
98  num_trace_samples);
99 
100  SG_DEBUG("Creating job for estimate %d, trace sample %d/%d\n", i, j,
101  num_trace_samples);
102  // get the trace sampler vector
103  SGVector<float64_t> s=m_trace_sampler->sample(j);
104  // create jobs with the sample vector and store the aggregator
105  CJobResultAggregator* agg=m_operator_log->submit_jobs(s);
106  aggregators->append_element(agg);
107  // the local reference is released right after appending; presumably
108  // the array holds its own reference -- TODO confirm
107  SG_UNREF(agg);
108  }
109  }
110 
111  // NOTE(review): the engine is only checked here, after jobs were already
112  // submitted and `aggregators` was heap-allocated above -- confirm whether
113  // a failing REQUIRE at this point leaks `aggregators`
111  REQUIRE(m_computation_engine, "Computation engine is NULL\n");
112 
113  // wait for all the jobs to be completed
114  SG_INFO("Waiting for jobs to finish\n");
115  m_computation_engine->wait_for_all();
116  SG_INFO("All jobs finished, aggregating results\n");
117 
118  // the samples vector which stores the estimates with averaging
119  SGVector<float64_t> samples(num_estimates);
120  samples.zero();
121 
122  // use the aggregators to find the final result
123  // use the same order as job submission to combine results
124  int32_t num_aggregates=aggregators->get_num_elements();
125  // idx_row counts trace samples within the current estimate,
126  // idx_col is the index of the current estimate in `samples`
125  index_t idx_row=0;
126  index_t idx_col=0;
127  for (int32_t i=0; i<num_aggregates; ++i)
128  {
129  // this cast is safe due to above way of building the array
130  CJobResultAggregator* agg=dynamic_cast<CJobResultAggregator*>
131  (aggregators->get_element(i));
132  ASSERT(agg);
133 
134  // call finalize on all the aggregators, cast is safe again
135  agg->finalize();
136  // NOTE(review): listing line 136 is missing from this extract; it
137  // presumably declares `r` as a dynamic_cast of the final result
137  (agg->get_final_result());
138  ASSERT(r);
139 
140  // iterate through indices, group results in the same way as jobs
141  // results are summed per estimate; no division by num_trace_samples
142  // is performed in this function
141  samples[idx_col]+=r->get_result();
142  idx_row++;
143  if (idx_row>=num_trace_samples)
144  {
145  idx_row=0;
146  idx_col++;
147  }
148 
149  // releases the reference obtained through get_element() above
150  // (presumably get_element() adds one -- TODO confirm)
149  SG_UNREF(agg);
150  }
151 
152  // clear all aggregators
153  SG_UNREF(aggregators)
154 
155  SG_INFO("Finished computing %d log-det estimates\n", num_estimates);
156 
157  SG_DEBUG("Leaving\n");
158  return samples;
159 }
160 
// Variant of sample() that keeps every individual trace-sample result instead
// of accumulating them per estimate. NOTE(review): the first signature line
// (listing line 161) is missing from this extract; the symbol index at the
// end of the page lists the function as
// `SGMatrix< float64_t > sample_without_averaging(index_t num_estimates)`.
//
// Returns a matrix with num_trace_samples rows and num_estimates columns;
// entry (row, col) holds the result for trace sample `row` of estimate `col`.
// NOTE(review): unlike sample(), no operator/trace-sampler dimension check is
// performed here -- confirm whether that is intended.
162  index_t num_estimates)
163 {
164  SG_DEBUG("Entering...\n")
165 
166  REQUIRE(m_operator_log, "Operator function is NULL\n");
167  // call the precompute of operator function to compute all prerequisites
168  m_operator_log->precompute();
169 
170  REQUIRE(m_trace_sampler, "Trace sampler is NULL\n");
171  // call the precompute of the sampler
172  m_trace_sampler->precompute();
173 
174  // for storing the aggregators that submit_jobs return
175  CDynamicObjectArray aggregators;
176  index_t num_trace_samples=m_trace_sampler->get_num_samples();
177 
178  // aggregators are appended estimate by estimate, trace sample by trace
179  // sample; the flat index i used below is therefore
180  // estimate*num_trace_samples + trace_sample
178  for (index_t i=0; i<num_estimates; ++i)
179  {
180  for (index_t j=0; j<num_trace_samples; ++j)
181  {
182  // get the trace sampler vector
183  SGVector<float64_t> s=m_trace_sampler->sample(j);
184  // create jobs with the sample vector and store the aggregator
185  CJobResultAggregator* agg=m_operator_log->submit_jobs(s);
186  aggregators.append_element(agg);
187  SG_UNREF(agg);
188  }
189  }
190 
191  // NOTE(review): the engine is only checked after the jobs were submitted
191  REQUIRE(m_computation_engine, "Computation engine is NULL\n");
192  // wait for all the jobs to be completed
193  m_computation_engine->wait_for_all();
194 
195  // the samples matrix which stores the estimates without averaging
196  // dimension: number of trace samples x number of log-det estimates
197  SGMatrix<float64_t> samples(num_trace_samples, num_estimates);
198 
199  // use the aggregators to find the final result
200  int32_t num_aggregates=aggregators.get_num_elements();
201  for (int32_t i=0; i<num_aggregates; ++i)
202  {
203  CJobResultAggregator* agg=dynamic_cast<CJobResultAggregator*>
204  (aggregators.get_element(i));
205  if (!agg)
206  SG_ERROR("Element is not CJobResultAggregator type!\n");
207 
208  // call finalize on all the aggregators
209  agg->finalize();
210  // NOTE(review): listing line 210 is missing from this extract; it
211  // presumably declares `r` as a dynamic_cast of the final result
211  (agg->get_final_result());
212  if (!r)
213  SG_ERROR("Result is not CScalarResult type!\n");
214 
215  // its important that we don't just unref the result here
216  // map the flat submission index back to (trace sample, estimate)
216  index_t idx_row=i%num_trace_samples;
217  index_t idx_col=i/num_trace_samples;
218  samples(idx_row, idx_col)=r->get_result();
219  SG_UNREF(agg);
220  }
221 
222  // clear all aggregators
223  aggregators.clear_array();
224 
225  SG_DEBUG("Leaving\n")
226  return samples;
227 }
228 
229 }
230 
Base class that stores the result of an independent job when the result is a scalar.
#define SG_INFO(...)
Definition: SGIO.h:120
SGVector< float64_t > sample(index_t num_estimates)
int32_t index_t
Definition: common.h:60
const T get_result() const
Definition: ScalarResult.h:66
#define SG_UNREF(x)
Definition: SGRefObject.h:35
virtual SGVector< float64_t > sample(index_t idx) const =0
CSGObject * get_element(int32_t index) const
#define SG_ERROR(...)
Definition: SGIO.h:131
#define REQUIRE(x,...)
Definition: SGIO.h:208
SGMatrix< float64_t > sample_without_averaging(index_t num_estimates)
CLinearOperator< T > * get_operator() const
virtual void precompute()=0
#define ASSERT(x)
Definition: SGIO.h:203
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:102
const index_t get_dimension() const
virtual const index_t get_dimension() const
Definition: TraceSampler.h:77
#define SG_REF(x)
Definition: SGRefObject.h:34
virtual CJobResultAggregator * submit_jobs(SGVector< T > sample)=0
Dynamic array class for CSGObject pointers that creates an array that can be used like a list or an a...
virtual void precompute()=0
Abstract base class that provides an interface for computing an aggregation of the job results of in...
CJobResult * get_final_result() const
#define SG_DEBUG(...)
Definition: SGIO.h:109
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:16
Abstract base class for solving multiple independent instances of CIndependentJob. It has one method, submit_job, which may add the job to an internal queue and might block if there is not yet space in the queue. After jobs are submitted, their results might not yet be ready; wait_for_all waits until all jobs are completed and must be called to guarantee that all jobs are finished.
Abstract template base class that provides an interface for sampling the trace of a linear operator u...
Definition: TraceSampler.h:23
#define SG_ADD(...)
Definition: SGObject.h:71
virtual const index_t get_num_samples() const
Definition: TraceSampler.h:71
bool append_element(CSGObject *e)

SHOGUN Machine Learning Toolbox - Documentation