GraphChi  0.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Macros
als_vertex_program.hpp
Go to the documentation of this file.
1 
25 #ifndef ALS_VERTEX_PROGRAM_HPP
26 #define ALS_VERTEX_PROGRAM_HPP
27 
28 
40 #include <Eigen/Dense>
41 
42 //#include <graphlab.hpp>
43 
44 //#include "eigen_serialization.hpp"
45 
46 
47 typedef Eigen::VectorXd vec_type;
48 typedef Eigen::MatrixXd mat_type;
49 
50 
// Per-vertex state: a latent factor vector plus two bookkeeping fields.
// NOTE(review): this listing skips line numbers (63 -> 68, 74 -> 81), so
// members such as a constructor may exist without being shown. Nothing
// visible here initializes nupdates or residual -- confirm against the
// actual header before relying on their initial values.
63 struct vertex_data {
    // Length that randomize() gives to `factor`; static, so shared by all
    // vertex_data instances.
68  static size_t NLATENT;
    // Update counter; apply() further down this file increments it.
70  uint32_t nupdates;
    // apply() further down this file stores the mean absolute change of
    // `factor` here (or 0 when it receives an empty sum).
72  float residual;
73 
    // Latent factor vector (vec_type is Eigen::VectorXd).
74  vec_type factor;
    // Resizes `factor` to NLATENT entries and fills it via Eigen's setRandom().
81  void randomize() { factor.resize(NLATENT); factor.setRandom(); }
    // GraphLab archive save/load hooks, left commented out in this listing.
83  //void save(graphlab::oarchive& arc) const {
84  // arc << nupdates << residual << factor;
85  //}
87  //void load(graphlab::iarchive& arc) {
88  // arc >> nupdates >> residual >> factor;
89  //}
90 }; // end of vertex data
91 
92 
// Role tag carried by each edge. In this file, gather() only uses TRAIN
// edges, and error_aggregator::map scores TRAIN and VALIDATE edges.
// NOTE(review): the enclosing struct header and the `role` member declaration
// are not visible in this listing (line numbers jump 92 -> 107 and 111 -> 114).
107  enum data_role_type { TRAIN, VALIDATE, PREDICT };
108 
    // Observed value stored on the edge.
110  float obs;
111 
114 
    // Builds an edge payload from an observation and a role.
    // Defaults: obs = 0, role = PREDICT.
116  edge_data(float obs = 0, data_role_type role = PREDICT) :
117  obs(obs), role(role) { }
118 
119 }; // end of edge data
120 
121 
127 
128 
135  const graph_type::vertex_type& vertex) {
136  return vertex.id() == edge.source().id()? edge.target() : edge.source();
137 }; // end of get_other_vertex
138 
139 
144  const double pred =
145  edge.source().data().factor.dot(edge.target().data().factor);
146  return (edge.data().obs - pred) * (edge.data().obs - pred);
147 } // end of extract_l2_error
148 
149 
154 // Commented out for graphchi
155 /*
156 inline bool graph_loader(graph_type& graph,
157  const std::string& filename,
158  const std::string& line) {
159  ASSERT_FALSE(line.empty());
160  // Determine the role of the data
161  edge_data::data_role_type role = edge_data::TRAIN;
162  if(boost::ends_with(filename,".validate")) role = edge_data::VALIDATE;
163  else if(boost::ends_with(filename, ".predict")) role = edge_data::PREDICT;
164  // Parse the line
165  std::stringstream strm(line);
166  graph_type::vertex_id_type source_id(-1), target_id(-1);
167  float obs(0);
168  strm >> source_id >> target_id;
169  if(role == edge_data::TRAIN || role == edge_data::VALIDATE) strm >> obs;
170  // Create an edge and add it to the graph
171  graph.add_edge(source_id, target_id+1000000, edge_data(obs, role));
172  return true; // successful load
173 } // end of graph_loader
174 
175 */
176 
177 
// Accumulator passed from gather() to apply(): holds the pieces of a
// least-squares system (XtX, Xy) that apply() later solves.
// NOTE(review): this listing elides lines (207 -> 210 -> 215 and 226 -> 232).
// gather() in this file calls gather_type() with no arguments, so a default
// constructor presumably lives in the elided lines -- confirm.
195 class gather_type {
196 public:
    // Accumulated X * X^T. The visible code only ever writes its upper
    // triangle (or copies another instance's matrix wholesale).
201  mat_type XtX;
202 
    // Accumulated X * y.
206  vec_type Xy;
207 
210 
    // Builds the contribution of a single observation.
    //   X: factor vector of the neighboring vertex
    //   y: observed value on the connecting edge
    // XtX is allocated as X.size() x X.size() and only its upper triangle is
    // assigned from the outer product; the lower triangle is left unset.
    // apply() reads the matrix through selfadjointView<Eigen::Upper>.
215  gather_type(const vec_type& X, const double y) :
216  XtX(X.size(), X.size()), Xy(X.size()) {
217  XtX.triangularView<Eigen::Upper>() = X * X.transpose();
218  Xy = X * y;
219  } // end of constructor for gather type
220 
222 // void save(graphlab::oarchive& arc) const { arc << XtX << Xy; }
223 
225  // void load(graphlab::iarchive& arc) { arc >> XtX >> Xy; }
226 
    // NOTE(review): the signature line of the member below is elided from
    // this listing; the trailing comment identifies it as operator+=, and the
    // body uses a parameter named `other` and returns *this.
    // Behavior of the visible body:
    //   - `other` empty (Xy.size() == 0): nothing is accumulated; ASSERT_EQ
    //     (defined elsewhere) is applied to other.XtX's dimensions.
    //   - this empty, `other` non-empty: copy other's XtX and Xy.
    //   - both non-empty: add other.XtX into the upper triangle of XtX and
    //     add other.Xy to Xy.
232  if(other.Xy.size() == 0) {
233  ASSERT_EQ(other.XtX.rows(), 0);
234  ASSERT_EQ(other.XtX.cols(), 0);
235  } else {
236  if(Xy.size() == 0) {
237  ASSERT_EQ(XtX.rows(), 0);
238  ASSERT_EQ(XtX.cols(), 0);
239  XtX = other.XtX; Xy = other.Xy;
240  } else {
241  XtX.triangularView<Eigen::Upper>() += other.XtX;
242  Xy += other.Xy;
243  }
244  }
245  return *this;
246  } // end of operator+=
247 
248 }; // end of gather type
249 
250 
251 
256  public graphlab::ivertex_program<graph_type, gather_type,
257  graphlab::messages::sum_priority>,
258  public graphlab::IS_POD_TYPE {
259 public:
// Static tuning parameters of the vertex program.
// TOLERANCE and MAX_UPDATES are referenced only inside the commented-out
// body of scatter() in this listing.
261  static double TOLERANCE;
    // Added to every diagonal entry of XtX in apply() (regularization).
262  static double LAMBDA;
263  static size_t MAX_UPDATES;
264 
// Selects the edges to gather over: always graphlab::ALL_EDGES.
// Neither `context` nor `vertex` is consulted.
266  edge_dir_type gather_edges(icontext_type& context,
267  const vertex_type& vertex) const {
268  return graphlab::ALL_EDGES;
269  }; // end of gather_edges
270 
// Produces one edge's contribution to the accumulator.
// For a TRAIN edge: a gather_type built from the factor of the vertex at the
// other end of `edge` (via get_other_vertex) and the edge's observation.
// For any other role: a default-constructed gather_type.
// `context` is not used.
272  gather_type gather(icontext_type& context, const vertex_type& vertex,
273  edge_type& edge) const {
274  if(edge.data().role == edge_data::TRAIN) {
275  const vertex_type other_vertex = get_other_vertex(edge, vertex);
276  return gather_type(other_vertex.data().factor, edge.data().obs);
277  } else return gather_type();
278  } // end of gather function
279 
// Recomputes this vertex's factor from the gathered accumulator `sum`.
// Adds LAMBDA to the diagonal of a local copy of sum.XtX, solves
// (XtX) * factor = Xy, stores the new factor, sets `residual` to the mean
// absolute change of the factor entries, and increments `nupdates`.
// `context` is not used.
281  void apply(icontext_type& context, vertex_type& vertex,
282  const gather_type& sum) {
283  // Get a mutable reference to the vertex data
284  vertex_data& vdata = vertex.data();
285  // An empty sum (Xy of size 0) carries no accumulated data: store a zero
286  // residual, count the update, and return without touching the factor
287  if(sum.Xy.size() == 0) { vdata.residual = 0; ++vdata.nupdates; return; }
    // Work on local copies so the const accumulator is left untouched.
288  mat_type XtX = sum.XtX;
289  vec_type Xy = sum.Xy;
290  // Add regularization
291  for(int i = 0; i < XtX.rows(); ++i) XtX(i,i) += LAMBDA; // /nneighbors;
292  // Solve the least squares problem using eigen ----------------------------
293  const vec_type old_factor = vdata.factor;
    // Only the upper triangle of XtX is read (selfadjointView<Eigen::Upper>);
    // the system is solved through Eigen's ldlt() decomposition.
294  vdata.factor = XtX.selfadjointView<Eigen::Upper>().ldlt().solve(Xy);
295  // Compute the residual: mean absolute change of the factor entries -------
296  vdata.residual = (vdata.factor - old_factor).cwiseAbs().sum() / XtX.rows();
297  ++vdata.nupdates;
298  } // end of apply
299 
// Selects the edges to scatter over: always graphlab::ALL_EDGES.
// Neither `context` nor `vertex` is consulted.
301  edge_dir_type scatter_edges(icontext_type& context,
302  const vertex_type& vertex) const {
303  return graphlab::ALL_EDGES;
304  }; // end of scatter edges
305 
// Scatter step. In this listing the entire body is commented out, so the
// function does nothing. The disabled code computed a priority from the
// prediction error and the vertex residual, and signalled the neighbor when
// priority > TOLERANCE and the neighbor's nupdates < MAX_UPDATES.
307  void scatter(icontext_type& context, const vertex_type& vertex,
308  edge_type& edge) const {
309  /* edge_data& edata = edge.data();
310  if(edata.role == edge_data::TRAIN) {
311  const vertex_type other_vertex = get_other_vertex(edge, vertex);
312  const vertex_data& vdata = vertex.data();
313  const vertex_data& other_vdata = other_vertex.data();
314  const double pred = vdata.factor.dot(other_vdata.factor);
315  const float error = std::fabs(edata.obs - pred);
316  const double priority = (error * vdata.residual);
317  // Reschedule neighbors ------------------------------------------------
318  if( priority > TOLERANCE && other_vdata.nupdates < MAX_UPDATES)
319  context.signal(other_vertex, priority);
320  }*/
321  } // end of scatter function
322 
323 
328  vertex_type& vertex) {
329  if(vertex.num_out_edges() > 0) context.signal(vertex);
330  return graphlab::empty();
331  } // end of signal_left
332 
333 }; // end of als vertex program
334 
335 
// Running totals for the error report. map() adds one edge's squared error
// and a count of 1; finalize() turns the totals into root-mean-square values.
// NOTE(review): the enclosing struct header is elided from this listing.
339  double train_error, validation_error;
340  size_t ntrain, nvalidation;
    // Starts with all sums and counts at zero.
341  error_aggregator() :
342  train_error(0), validation_error(0), ntrain(0), nvalidation(0) { }
// Merges another aggregator into this one by adding its error sums and
// its counts field by field. Returns *this.
343  error_aggregator& operator+=(const error_aggregator& other) {
344  train_error += other.train_error;
345  validation_error += other.validation_error;
346  ntrain += other.ntrain;
347  nvalidation += other.nvalidation;
348  return *this;
349  }
// Builds the aggregator for a single edge.
// TRAIN edge: train_error = extract_l2_error(edge), ntrain = 1.
// VALIDATE edge: validation_error = extract_l2_error(edge), nvalidation = 1.
// Any other role: all fields stay zero. `context` is not used.
350  static error_aggregator map(icontext_type& context, const graph_type::edge_type& edge) {
351  error_aggregator agg;
352  if(edge.data().role == edge_data::TRAIN) {
353  agg.train_error = extract_l2_error(edge); agg.ntrain = 1;
354  } else if(edge.data().role == edge_data::VALIDATE) {
355  agg.validation_error = extract_l2_error(edge); agg.nvalidation = 1;
356  }
357  return agg;
358  }
// Reports the aggregated errors on context.cout() as one tab-separated line:
//   elapsed_seconds <TAB> train RMSE [<TAB> validation RMSE]
// The validation column is written only when nvalidation > 0.
// agg.train_error is overwritten in place with its root-mean-square value;
// agg.validation_error is left as the raw sum (the RMSE goes to a local).
359  static void finalize(icontext_type& context, error_aggregator& agg) {
    // ASSERT_GT is defined elsewhere; it guards the division by ntrain below.
360  ASSERT_GT(agg.ntrain, 0);
361  agg.train_error = std::sqrt(agg.train_error / agg.ntrain);
362  context.cout() << context.elapsed_seconds() << "\t" << agg.train_error;
363  if(agg.nvalidation > 0) {
364  const double validation_error =
365  std::sqrt(agg.validation_error / agg.nvalidation);
366  context.cout() << "\t" << validation_error;
367  }
368  context.cout() << std::endl;
369  }
370 }; // end of error aggregator
371 
372 
// Vertices are not written out: always returns an empty string.
376  std::string save_vertex(const vertex_type& vertex) const {
377  return ""; //nop
378  }
// Formats one edge as "<source id>\t<target id>\t<prediction>\n", where the
// prediction is the dot product of the two endpoint factor vectors.
// The edge's role is not checked here.
379  std::string save_edge(const edge_type& edge) const {
380  std::stringstream strm;
381  const double prediction =
382  edge.source().data().factor.dot(edge.target().data().factor);
383  strm << edge.source().id() << '\t'
384  << edge.target().id() << '\t'
385  << prediction << '\n';
386  return strm.str();
387  }
388 }; // end of prediction_saver
389 
390 
391 
392 
393 #endif