MVE - Multi-View Environment mve-devel
Loading...
Searching...
No Matches
cascade_hashing.h
Go to the documentation of this file.
1/*
2 * Copyright (C) 2016, Andre Schulz
3 * TU Darmstadt - Graphics, Capture and Massively Parallel Computing
4 * All rights reserved.
5 *
6 * This software may be modified and distributed under the terms
7 * of the BSD 3-Clause license. See the LICENSE.txt file for details.
8 */
9
10#ifndef SFM_CASCADE_HASHING_HEADER
11#define SFM_CASCADE_HASHING_HEADER
12
13#include <cstring>
14#include <iostream>
15#include <random>
16
17#include "math/functions.h"
18#include "math/vector.h"
19#include "sfm/defines.h"
21#include "sfm/matching.h"
22#include "sfm/sift.h"
23#include "sfm/surf.h"
24#include "util/system.h"
25#include "util/timer.h"
26
28
30{
31public:
/**
 * Tuning parameters of the cascade hashing matcher.
 */
32 struct Options
33 {
/**
 * Number of bucket groups. One set of secondary projection vectors is
 * generated per group, and every descriptor receives one bucket ID
 * per group.
 */
35 uint8_t num_bucket_groups = 6;
36
/**
 * Number of bits of a bucket ID, i.e. the number of secondary
 * projection vectors per bucket group. Bucket IDs are stored as
 * uint16_t values.
 */
38 uint8_t num_bucket_bits = 8;
39
/**
 * Candidate collection proceeds to the next larger Hamming distance
 * only while fewer than this many candidates have been gathered.
 */
41 uint16_t min_num_candidates = 6;
42
/**
 * Upper limit on the number of candidates gathered per query feature;
 * also determines the size of the descriptor scratch buffer used
 * during matching.
 */
44 uint16_t max_num_candidates = 10;
45 };
46
51 void init (bundler::ViewportList* viewports) override;
52
54 void pairwise_match (int view_1_id, int view_2_id,
55 Matching::Result* result) const override;
56
57private:
58 struct LocalData;
59
64 template <typename D>
65 void twoway_match (Matching::Options const& matching_opts,
66 LocalData const& set_1, LocalData const& set_2,
67 D const& set_1_descs, D const& set_2_descs,
68 Matching::Result* matches, Options const& cashash_opts) const;
69
74 template <typename D>
75 void oneway_match (Matching::Options const& matching_opts,
76 LocalData const& set_1, LocalData const& set_2,
77 D const& set_1_descs, D const& set_2_descs,
78 std::vector<int>* result, Options const& cashash_opts) const;
79
/**
 * Projection matrices for descriptors of vector type T.
 */
85 template <typename T>
86 struct ProjMats
87 {
/**
 * Primary projection matrix: one projection vector per hash bit.
 * The sign of the dot product between a descriptor and the k-th
 * vector yields the k-th bit of the descriptor's hash code.
 */
96 std::vector<T> prim_proj_mat;
97
/**
 * Secondary projection matrices, indexed as [bucket group][bit].
 * The sign of the dot product between a descriptor and vector
 * [g][b] yields bit b of the descriptor's bucket ID in group g.
 */
103 std::vector<std::vector<T>> sec_proj_mats;
104 };
105
106 typedef ProjMats<math::Vec64f> ProjMatsSurf;
107 typedef ProjMats<math::Vec128f> ProjMatsSift;
108
/**
 * Projection matrices for both supported descriptor types
 * (128-dimensional SIFT and 64-dimensional SURF float vectors).
 */
115 class GlobalData
116 {
117 public:
/*
 * NOTE(review): defined outside this header; presumably fills both
 * the SIFT and the SURF matrices via the private overload -- confirm.
 */
118 void generate_proj_matrices (Options const& cashash_opts);
119
/** Projection matrices for SIFT descriptors. */
120 ProjMatsSift sift;
/** Projection matrices for SURF descriptors. */
121 ProjMatsSurf surf;
122
123 private:
/**
 * Fills the given primary and secondary projection matrices with
 * normally distributed values from a PRNG seeded with 0.
 */
124 template <typename T>
125 void generate_proj_matrices (std::vector<T>* prim_hash,
126 std::vector<std::vector<T>>* sec_hash, Options const& cashash_opts);
127 };
128
129 typedef std::vector<std::size_t> BucketFeatureIDs;
130 typedef std::vector<BucketFeatureIDs> BucketGroupFeatures;
131 typedef std::vector<BucketGroupFeatures> BucketGroupsFeatures;
132
133 typedef std::vector<uint16_t> BucketIDs;
134 typedef std::vector<BucketIDs> BucketGroupsBuckets;
135
/**
 * Hashing results for one set of descriptors.
 */
137 struct LocalData
138 {
/**
 * Packed hash codes. Every descriptor occupies (dim / 64) consecutive
 * 64-bit words, so the words of descriptor i start at index
 * i * (dim / 64).
 */
146 std::vector<uint64_t> comp_hash_data;
147
/**
 * Bucket ID of every descriptor in every bucket group, indexed as
 * [bucket group][descriptor index].
 */
153 BucketGroupsBuckets bucket_grps_bucket_ids;
154
/**
 * IDs of the descriptors contained in every bucket, indexed as
 * [bucket group][bucket ID].
 */
160 BucketGroupsFeatures bucket_grps_feature_ids;
161 };
162
164 void compute (LocalData* ld_sift, LocalData* ld_surf,
165 std::vector<math::Vec128f> const& sift_zero_mean_descs,
166 std::vector<math::Vec64f> const& surf_zero_mean_descs,
167 GlobalData const& cashash_global_data, Options const& cashash_opts);
168
170 void compute_avg_descriptors (ProcessedFeatureSets const& pfs,
171 math::Vec128f* sift_avg, math::Vec64f* surf_avg);
172
174 void compute_zero_mean_descs(
175 std::vector<math::Vec128f>* sift_zero_mean_descs,
176 std::vector<math::Vec64f>* surf_zero_mean_descs,
177 SiftDescriptors const& sift_descs, SurfDescriptors const& surf_descs,
178 math::Vec128f const& sift_avg, math::Vec64f const& surf_avg);
179
181 template <typename T>
182 void compute_cascade_hashes (std::vector<T> const& zero_mean_descs,
183 std::vector<uint64_t>* comp_hash_data,
184 BucketGroupsBuckets* bucket_grps_bucket_ids,
185 std::vector<T> const& prim_proj_mat,
186 std::vector<std::vector<T>> const& sec_proj_mats, Options const& cashash_opts);
187
189 void build_buckets (BucketGroupsFeatures* bucket_grps_feature_ids,
190 BucketGroupsBuckets const& bucket_grps_bucket_ids, size_t num_descs,
191 Options const& cashash_opts);
192
198 template <int DIMHASH>
199 void collect_features_from_buckets (
200 std::vector<std::vector<uint32_t>>* grouped_features,
201 size_t feature_id, std::vector<bool>* data_index_used,
202 BucketGroupsBuckets const& bucket_grps_bucket_ids,
203 BucketGroupsFeatures const& bucket_grps_feature_ids,
204 uint64_t const* comp_hash_data1,
205 std::vector<uint64_t> const& comp_hash_data2) const;
206
211 void collect_top_ranked_candidates (std::vector<uint32_t>* top_candidates,
212 std::vector<std::vector<uint32_t>> const& grouped_features,
213 uint8_t dim_hash_data, uint16_t min_num_candidates,
214 uint16_t max_num_candidates) const;
215
216private:
217 GlobalData global_data;
218 std::vector<LocalData> local_data_sift;
219 std::vector<LocalData> local_data_surf;
220 Options cashash_opts;
221};
222
223/* ---------------------------------------------------------------- */
224
225template <typename T>
226void CascadeHashing::GlobalData::generate_proj_matrices (
227 std::vector<T>* prim_hash, std::vector<std::vector<T>>* sec_hash,
228 Options const& cashash_opts)
229{
230 int const dim_desc = T::dim;
231 int const dim_hash_data = dim_desc;
232
233 prim_hash->resize(dim_hash_data);
234
235 /* Setup PRNG. */
236 std::mt19937 prng_mt(0);
237 std::normal_distribution<> dis(0, 1);
238
239 /* Generate values for primary hashing function. */
240 for (uint8_t i = 0; i < dim_hash_data; i++)
241 for (uint8_t j = 0; j < dim_desc; j++)
242 (*prim_hash)[i][j] = dis(prng_mt);
243
244 uint8_t const num_bucket_groups = cashash_opts.num_bucket_groups;
245 uint8_t const num_bucket_bits = cashash_opts.num_bucket_bits;
246
247 sec_hash->resize(num_bucket_groups, std::vector<T>(num_bucket_bits));
248
249 /* Generate values for secondary hashing function. */
250 for (uint8_t group_idx = 0; group_idx < num_bucket_groups; group_idx++)
251 for (uint8_t i = 0; i < num_bucket_bits; i++)
252 for (uint8_t j = 0; j < dim_desc; j++)
253 (*sec_hash)[group_idx][i][j] = dis(prng_mt);
254}
255
256/* ---------------------------------------------------------------- */
257
258template <typename T>
259void
260CascadeHashing::compute_cascade_hashes (std::vector<T> const& zero_mean_descs,
261 std::vector<uint64_t>* comp_hash_data,
262 BucketGroupsBuckets* bucket_grps_bucket_ids,
263 std::vector<T> const& prim_proj_mat,
264 std::vector<std::vector<T>> const& sec_proj_mats,
265 Options const& cashash_opts)
266{
267 int const dim_desc = T::dim;
268 int const dim_hash_data = dim_desc;
269 uint8_t const dim_comp_hash_data = dim_hash_data / 64;
270 uint8_t const num_bucket_bits = cashash_opts.num_bucket_bits;
271 uint8_t const num_bucket_grps = cashash_opts.num_bucket_groups;
272
273 /* Allocate memory. */
274 size_t num_descs = zero_mean_descs.size();
275 comp_hash_data->resize(num_descs * dim_comp_hash_data);
276 bucket_grps_bucket_ids->resize(num_bucket_grps, std::vector<uint16_t>(num_descs));
277
278 for (size_t i = 0; i < num_descs; i++)
279 {
280 T const& desc = zero_mean_descs[i];
281
282 /* Compute hash code. */
283 for (uint8_t j = 0; j < dim_comp_hash_data; j++)
284 {
285 uint64_t comp_hash = 0;
286 uint8_t data_start = j * 64;
287 uint8_t data_end = (j + 1) * 64;
288 for (uint8_t k = data_start; k < data_end; k++)
289 {
290 T const& proj_vec = prim_proj_mat[k];
291 float sum = desc.dot(proj_vec);
292 comp_hash = (comp_hash << 1) | (sum > 0.0f);
293 }
294 (*comp_hash_data)[i * dim_comp_hash_data + j] = comp_hash;
295 }
296
297 /* Determine the descriptor's bucket index for each bucket group. */
298 for (uint8_t grp_idx = 0; grp_idx < num_bucket_grps; grp_idx++)
299 {
300 uint16_t bucket_id = 0;
301 for (uint8_t bit_idx = 0; bit_idx < num_bucket_bits; bit_idx++)
302 {
303 T const& proj_vec = sec_proj_mats[grp_idx][bit_idx];
304 float sum = desc.dot(proj_vec);
305 bucket_id = (bucket_id << 1) | (sum > 0.0f);
306 }
307
308 (*bucket_grps_bucket_ids)[grp_idx][i] = bucket_id;
309 }
310 }
311}
312
313template <typename D>
314void
315CascadeHashing::twoway_match (Matching::Options const& matching_opts,
316 LocalData const& set_1, LocalData const& set_2,
317 D const& set_1_descs, D const& set_2_descs,
318 Matching::Result* matches, Options const& cashash_opts) const
319{
320 oneway_match(matching_opts, set_1, set_2, set_1_descs, set_2_descs,
321 &matches->matches_1_2,
322 cashash_opts);
323 oneway_match(matching_opts, set_2, set_1, set_2_descs, set_1_descs,
324 &matches->matches_2_1,
325 cashash_opts);
326}
327
328template <typename D>
329void
330CascadeHashing::oneway_match (Matching::Options const& matching_opts,
331 LocalData const& set_1, LocalData const& set_2,
332 D const& set_1_descs, D const& set_2_descs,
333 std::vector<int>* result, Options const& cashash_opts) const
334{
335 typedef typename D::value_type V;
336 typedef typename V::ValueType T;
337
338 size_t set_1_size = set_1_descs.size();
339 size_t set_2_size = set_2_descs.size();
340
341 if (set_1_size == 0 || set_2_size == 0)
342 return;
343
344 float const square_lowe_thres = MATH_POW2(matching_opts.lowe_ratio_threshold);
345 float const square_dist_thres = MATH_POW2(matching_opts.distance_threshold);
346
347 uint16_t const min_num_candidates = cashash_opts.min_num_candidates;
348 uint16_t const max_num_candidates = cashash_opts.max_num_candidates;
349 int const descriptor_length = V::dim;
350 uint8_t const dim_hash_data = descriptor_length;
351 uint32_t const dim_comp_hash_data = dim_hash_data / 64;
352
353 result->resize(set_1_size, -1);
354 std::vector<bool> data_index_used(set_2_size);
355 std::vector<std::vector<uint32_t> > grouped_features(dim_hash_data + 1);
356 std::vector<uint32_t> top_candidates;
357
358 top_candidates.reserve(max_num_candidates);
359
360 std::unique_ptr<T> tmp(new T[max_num_candidates * descriptor_length]);
361 NearestNeighbor<T> nn;
362 nn.set_elements(tmp.get());
363 nn.set_element_dimensions(descriptor_length);
364
365 for (size_t i = 0; i < set_1_size; i++)
366 {
367 std::fill(data_index_used.begin(), data_index_used.end(), false);
368
369 for (size_t j = 0; j < grouped_features.size(); j++)
370 grouped_features[j].clear();
371
372 top_candidates.clear();
373
374 /* Fetch candidate features from the buckets in each group. */
375 collect_features_from_buckets<dim_hash_data>(
376 &grouped_features,
377 i,
378 &data_index_used,
379 set_1.bucket_grps_bucket_ids,
380 set_2.bucket_grps_feature_ids,
381 &set_1.comp_hash_data[i * dim_comp_hash_data],
382 set_2.comp_hash_data);
383
384 /* Add closest candidates by Hamming distance to top_candidates vector. */
385 collect_top_ranked_candidates(
386 &top_candidates,
387 grouped_features,
388 dim_hash_data,
389 min_num_candidates,
390 max_num_candidates);
391
392 /* Copy top candidates' descriptors into a contiguous array. */
393 for (size_t j = 0; j < top_candidates.size(); j++)
394 {
395 uint32_t candidate_id = top_candidates[j];
396 std::memcpy(tmp.get() + j * descriptor_length,
397 set_2_descs[candidate_id].begin(),
398 sizeof(V));
399 }
400
401 typename NearestNeighbor<T>::Result nn_result;
402 nn.set_num_elements(top_candidates.size());
403 nn.find(set_1_descs[i].begin(), &nn_result);
404
405 if (nn_result.dist_1st_best > square_dist_thres)
406 continue;
407
408 if (static_cast<float>(nn_result.dist_1st_best)
409 / static_cast<float>(nn_result.dist_2nd_best)
410 > square_lowe_thres)
411 continue;
412
413 result->at(i) = top_candidates[nn_result.index_1st_best];
414 }
415}
416
417template <int DIMHASH>
418void
419CascadeHashing::collect_features_from_buckets (
420 std::vector<std::vector<uint32_t>>* grouped_features,
421 size_t feature_id, std::vector<bool>* data_index_used,
422 BucketGroupsBuckets const& bucket_grps_bucket_ids,
423 BucketGroupsFeatures const& bucket_grps_feature_ids,
424 uint64_t const* comp_hash_data1,
425 std::vector<uint64_t> const& comp_hash_data2) const
426{
427 size_t num_bucket_grps = bucket_grps_bucket_ids.size();
428 for (size_t grp_idx = 0; grp_idx < num_bucket_grps; grp_idx++)
429 {
430 uint8_t bucket_id = bucket_grps_bucket_ids[grp_idx][feature_id];
431 BucketFeatureIDs const& bucket_feature_ids = bucket_grps_feature_ids[grp_idx][bucket_id];
432
433 for (size_t j = 0; j < bucket_feature_ids.size(); j++)
434 {
435 size_t candidate_id = bucket_feature_ids[j];
436 if ((*data_index_used)[candidate_id])
437 continue;
438
439 uint64_t const *ptr2 = &comp_hash_data2[candidate_id * DIMHASH / 64];
440 int hamming_dist = 0;
441 for (int k = 0; k < (DIMHASH / 64); k++)
442 hamming_dist += math::popcount(comp_hash_data1[k] ^ ptr2[k]);
443
444 (*grouped_features)[hamming_dist].emplace_back(candidate_id);
445 (*data_index_used)[candidate_id] = true;
446 }
447 }
448}
449
450inline void
451CascadeHashing::collect_top_ranked_candidates (
452 std::vector<uint32_t>* top_candidates,
453 std::vector<std::vector<uint32_t>> const& grouped_features,
454 uint8_t dim_hash_data, uint16_t min_num_candidates,
455 uint16_t max_num_candidates) const
456{
457 for (uint16_t hash_dist = 0; hash_dist <= dim_hash_data; hash_dist++)
458 {
459 for (size_t j = 0; j < grouped_features[hash_dist].size(); j++)
460 {
461 uint32_t idx2 = grouped_features[hash_dist][j];
462 top_candidates->emplace_back(idx2);
463
464 if (top_candidates->size() >= max_num_candidates)
465 break;
466 }
467
468 if (top_candidates->size() >= min_num_candidates)
469 break;
470 }
471}
472
474
475#endif /* SFM_CASCADE_HASHING_HEADER */
#define MATH_POW2(x)
Definition defines.h:68
std::size_t constexpr popcount(T const x)
Returns the number of one bits of an integer.
Definition functions.h:166
std::vector< Viewport > ViewportList
The list of all viewports considered for bundling.
#define SFM_NAMESPACE_END
Definition defines.h:14
#define SFM_NAMESPACE_BEGIN
Definition defines.h:13
Feature matching options.
Definition matching.h:30
Feature matching result reported as two lists, each with indices in the other set.
Definition matching.h:57