10#ifndef SFM_CASCADE_HASHING_HEADER
11#define SFM_CASCADE_HASHING_HEADER
/*
 * Cascade hashing options. NOTE(review): only the fields survived in this
 * listing; the enclosing struct header and its doc comments are missing.
 */
/** Number of bucket groups: each group has its own secondary projection
 * matrix and assigns one bucket id to every descriptor. */
35 uint8_t num_bucket_groups = 6;
/** Bits per bucket id (= rows of each secondary projection matrix). Ids are
 * accumulated in a uint16_t, so bits beyond 16 are shifted out. */
38 uint8_t num_bucket_bits = 8;
/** Lower bound on candidates: collect_top_ranked_candidates compares the
 * candidate count against it, apparently after each Hamming distance level
 * (presumably to stop early). */
41 uint16_t min_num_candidates = 6;
/** Upper bound on candidates per query feature; also sizes the scratch
 * buffer handed to the nearest-neighbor search in oneway_match. */
44 uint16_t max_num_candidates = 10;
/*
 * NOTE(review): this listing dropped many lines of the original header.
 * The three lines below are only the opening of the pairwise_match
 * declaration; its remaining parameters are missing.
 */
54 void pairwise_match (
int view_1_id,
int view_2_id,
/*
 * Parameter fragment of another matching declaration whose name line is
 * missing here. It is presumably twoway_match, whose definition further
 * down takes the same set and descriptor parameters. TODO confirm.
 */
66 LocalData
const& set_1, LocalData
const& set_2,
67 D
const& set_1_descs, D
const& set_2_descs,
/*
 * Tail of a const member declaration ending in (result, cashash_opts). This
 * is presumably oneway_match, whose definition resizes *result to the size
 * of set_1 (fill value -1) and writes accepted matches as indices into
 * set_2.
 */
76 LocalData
const& set_1, LocalData
const& set_2,
77 D
const& set_1_descs, D
const& set_2_descs,
78 std::vector<int>* result,
Options const& cashash_opts)
const;
/* NOTE(review): presumably the fields of the ProjMats<T> template used by
 * the typedefs below; the struct header is missing from this listing. */
/** Primary projection matrix: T::dim rows. Each row contributes one bit
 * (sign of the dot product) to a descriptor's packed compact hash. */
96 std::vector<T> prim_proj_mat;
/** Secondary projection matrices: one matrix of num_bucket_bits rows per
 * bucket group, used to compute the bucket ids. */
103 std::vector<std::vector<T>> sec_proj_mats;
/** Projection matrices for the SURF descriptors (math::Vec64f). */
106 typedef ProjMats<math::Vec64f> ProjMatsSurf;
/** Projection matrices for the SIFT descriptors (math::Vec128f). */
107 typedef ProjMats<math::Vec128f> ProjMatsSift;
/** Generates the projection matrices from the options; presumably for both
 * the SIFT and the SURF variants (body not part of this listing). */
118 void generate_proj_matrices (Options
const& cashash_opts);
/** Fills prim_hash with T::dim rows and sec_hash with num_bucket_groups x
 * num_bucket_bits rows of N(0, 1) samples. The samples are drawn from a
 * std::mt19937 seeded with 0, so the result is reproducible. */
124 template <
typename T>
125 void generate_proj_matrices (std::vector<T>* prim_hash,
126 std::vector<std::vector<T>>* sec_hash, Options
const& cashash_opts);
/** Indices of the features that fall into one bucket. */
using BucketFeatureIDs = std::vector<std::size_t>;
/** All buckets of a single bucket group, indexed by bucket id. */
using BucketGroupFeatures = std::vector<BucketFeatureIDs>;
/** Buckets of every bucket group, indexed [group][bucket id]. */
using BucketGroupsFeatures = std::vector<BucketGroupFeatures>;

/** Bucket id of every descriptor within one bucket group. */
using BucketIDs = std::vector<uint16_t>;
/** Bucket ids of every group, indexed [group][descriptor]. */
using BucketGroupsBuckets = std::vector<BucketIDs>;
/* NOTE(review): presumably the LocalData members (the struct header is
 * missing from this listing); oneway_match reads them as
 * set_1.comp_hash_data, set_1.bucket_grps_bucket_ids, etc. */
/** Packed compact hashes: T::dim / 64 uint64_t words per descriptor;
 * descriptor i starts at index i * (T::dim / 64). */
146 std::vector<uint64_t> comp_hash_data;
/** Bucket id of each descriptor in each bucket group, indexed
 * [group][descriptor]. */
153 BucketGroupsBuckets bucket_grps_bucket_ids;
/** Descriptor indices that fall into each bucket, indexed
 * [group][bucket id]. */
160 BucketGroupsFeatures bucket_grps_feature_ids;
/** Presumably fills the SIFT and SURF local data from the zero-mean
 * descriptors, using the global projection matrices (body not part of this
 * listing). */
164 void compute (LocalData* ld_sift, LocalData* ld_surf,
165 std::vector<math::Vec128f>
const& sift_zero_mean_descs,
166 std::vector<math::Vec64f>
const& surf_zero_mean_descs,
167 GlobalData
const& cashash_global_data, Options
const& cashash_opts);
/* NOTE(review): truncated compute_avg_descriptors declaration; its
 * remaining parameters are missing from this listing. */
170 void compute_avg_descriptors (ProcessedFeatureSets
const& pfs,
/* NOTE(review): truncated compute_zero_mean_descs declaration. It outputs
 * SIFT/SURF descriptor vectors, presumably by subtracting the averages from
 * compute_avg_descriptors; its remaining parameters are missing from this
 * listing. */
174 void compute_zero_mean_descs(
175 std::vector<math::Vec128f>* sift_zero_mean_descs,
176 std::vector<math::Vec64f>* surf_zero_mean_descs,
177 SiftDescriptors
const& sift_descs, SurfDescriptors
const& surf_descs,
/** Projects every zero-mean descriptor onto prim_proj_mat to build its
 * packed compact hash (comp_hash_data). It also projects each descriptor
 * onto every group's secondary projections to build its bucket ids
 * (bucket_grps_bucket_ids). */
181 template <
typename T>
182 void compute_cascade_hashes (std::vector<T>
const& zero_mean_descs,
183 std::vector<uint64_t>* comp_hash_data,
184 BucketGroupsBuckets* bucket_grps_bucket_ids,
185 std::vector<T>
const& prim_proj_mat,
186 std::vector<std::vector<T>>
const& sec_proj_mats, Options
const& cashash_opts);
/** Fills bucket_grps_feature_ids from the per-descriptor bucket ids;
 * presumably appends each descriptor index to its bucket in every group
 * (body not part of this listing). */
189 void build_buckets (BucketGroupsFeatures* bucket_grps_feature_ids,
190 BucketGroupsBuckets
const& bucket_grps_bucket_ids,
size_t num_descs,
191 Options
const& cashash_opts);
/** For query feature feature_id, visits its bucket in every group.
 * Each not-yet-seen feature of the other set (tracked in data_index_used)
 * is appended to grouped_features[d]. Here d is presumably the Hamming
 * distance between the query hash (comp_hash_data1) and the candidate hash
 * (comp_hash_data2). DIMHASH is the compact hash length in bits. */
198 template <
int DIMHASH>
199 void collect_features_from_buckets (
200 std::vector<std::vector<uint32_t>>* grouped_features,
201 size_t feature_id, std::vector<bool>* data_index_used,
202 BucketGroupsBuckets
const& bucket_grps_bucket_ids,
203 BucketGroupsFeatures
const& bucket_grps_feature_ids,
204 uint64_t
const* comp_hash_data1,
205 std::vector<uint64_t>
const& comp_hash_data2)
const;
/** Appends candidate ids from grouped_features to top_candidates, in
 * increasing Hamming distance (0..dim_hash_data). Collection is bounded by
 * max_num_candidates and checked against min_num_candidates. */
211 void collect_top_ranked_candidates (std::vector<uint32_t>* top_candidates,
212 std::vector<std::vector<uint32_t>>
const& grouped_features,
213 uint8_t dim_hash_data, uint16_t min_num_candidates,
214 uint16_t max_num_candidates)
const;
/** Projection matrices (GlobalData provides generate_proj_matrices);
 * presumably shared by all views. */
217 GlobalData global_data;
/** Cascade hashing data for SIFT and SURF descriptors.
 * NOTE(review): presumably one entry per view — confirm in the .cc file. */
218 std::vector<LocalData> local_data_sift;
219 std::vector<LocalData> local_data_surf;
/** Options used for hashing and matching. */
220 Options cashash_opts;
226void CascadeHashing::GlobalData::generate_proj_matrices (
227 std::vector<T>* prim_hash, std::vector<std::vector<T>>* sec_hash,
228 Options
const& cashash_opts)
230 int const dim_desc = T::dim;
231 int const dim_hash_data = dim_desc;
233 prim_hash->resize(dim_hash_data);
236 std::mt19937 prng_mt(0);
237 std::normal_distribution<> dis(0, 1);
240 for (uint8_t i = 0; i < dim_hash_data; i++)
241 for (uint8_t j = 0; j < dim_desc; j++)
242 (*prim_hash)[i][j] = dis(prng_mt);
244 uint8_t
const num_bucket_groups = cashash_opts.num_bucket_groups;
245 uint8_t
const num_bucket_bits = cashash_opts.num_bucket_bits;
247 sec_hash->resize(num_bucket_groups, std::vector<T>(num_bucket_bits));
250 for (uint8_t group_idx = 0; group_idx < num_bucket_groups; group_idx++)
251 for (uint8_t i = 0; i < num_bucket_bits; i++)
252 for (uint8_t j = 0; j < dim_desc; j++)
253 (*sec_hash)[group_idx][i][j] = dis(prng_mt);
260CascadeHashing::compute_cascade_hashes (std::vector<T>
const& zero_mean_descs,
261 std::vector<uint64_t>* comp_hash_data,
262 BucketGroupsBuckets* bucket_grps_bucket_ids,
263 std::vector<T>
const& prim_proj_mat,
264 std::vector<std::vector<T>>
const& sec_proj_mats,
265 Options
const& cashash_opts)
267 int const dim_desc = T::dim;
268 int const dim_hash_data = dim_desc;
269 uint8_t
const dim_comp_hash_data = dim_hash_data / 64;
270 uint8_t
const num_bucket_bits = cashash_opts.num_bucket_bits;
271 uint8_t
const num_bucket_grps = cashash_opts.num_bucket_groups;
274 size_t num_descs = zero_mean_descs.size();
275 comp_hash_data->resize(num_descs * dim_comp_hash_data);
276 bucket_grps_bucket_ids->resize(num_bucket_grps, std::vector<uint16_t>(num_descs));
278 for (
size_t i = 0; i < num_descs; i++)
280 T
const& desc = zero_mean_descs[i];
283 for (uint8_t j = 0; j < dim_comp_hash_data; j++)
285 uint64_t comp_hash = 0;
286 uint8_t data_start = j * 64;
287 uint8_t data_end = (j + 1) * 64;
288 for (uint8_t k = data_start; k < data_end; k++)
290 T
const& proj_vec = prim_proj_mat[k];
291 float sum = desc.dot(proj_vec);
292 comp_hash = (comp_hash << 1) | (sum > 0.0f);
294 (*comp_hash_data)[i * dim_comp_hash_data + j] = comp_hash;
298 for (uint8_t grp_idx = 0; grp_idx < num_bucket_grps; grp_idx++)
300 uint16_t bucket_id = 0;
301 for (uint8_t bit_idx = 0; bit_idx < num_bucket_bits; bit_idx++)
303 T
const& proj_vec = sec_proj_mats[grp_idx][bit_idx];
304 float sum = desc.dot(proj_vec);
305 bucket_id = (bucket_id << 1) | (sum > 0.0f);
308 (*bucket_grps_bucket_ids)[grp_idx][i] = bucket_id;
/**
 * Matches in both directions via oneway_match:
 *  - set_1 against set_2, into matches->matches_1_2;
 *  - set_2 against set_1, into matches->matches_2_1.
 * NOTE(review): the template header, return type, braces and the last
 * argument line of each call are missing from this listing.
 */
315CascadeHashing::twoway_match (Matching::Options
const& matching_opts,
316 LocalData
const& set_1, LocalData
const& set_2,
317 D
const& set_1_descs, D
const& set_2_descs,
318 Matching::Result* matches, Options
const& cashash_opts)
const
// Direction set_1 -> set_2.
320 oneway_match(matching_opts, set_1, set_2, set_1_descs, set_2_descs,
321 &matches->matches_1_2,
// Direction set_2 -> set_1 (sets and descriptors swapped).
323 oneway_match(matching_opts, set_2, set_1, set_2_descs, set_1_descs,
324 &matches->matches_2_1,
/**
 * Matches every descriptor of set_1 against set_2.
 *
 * For each query descriptor the steps are:
 *  1. Gather candidates sharing a bucket in any bucket group.
 *  2. Rank the candidates by compact-hash Hamming distance.
 *  3. Run a nearest-neighbor search over the top candidates' descriptors.
 *  4. Accept the nearest neighbor only if it passes the distance threshold
 *     and the Lowe ratio test.
 * *result is resized to set_1's size with fill value -1, and accepted
 * matches are written as indices into set_2.
 * NOTE(review): the template header, return type and braces are missing
 * from this listing.
 */
330CascadeHashing::oneway_match (Matching::Options
const& matching_opts,
331 LocalData
const& set_1, LocalData
const& set_2,
332 D
const& set_1_descs, D
const& set_2_descs,
333 std::vector<int>* result, Options
const& cashash_opts)
const
335 typedef typename D::value_type V;
336 typedef typename V::ValueType T;
338 size_t set_1_size = set_1_descs.size();
339 size_t set_2_size = set_2_descs.size();
// Nothing to match if either set is empty.
// NOTE(review): the statement guarded by this if (presumably `return;`) is
// missing from this listing.
341 if (set_1_size == 0 || set_2_size == 0)
// Thresholds are squared (MATH_POW2) before being compared with the
// nearest-neighbor distances; presumably those distances are squared.
344 float const square_lowe_thres =
MATH_POW2(matching_opts.lowe_ratio_threshold);
345 float const square_dist_thres =
MATH_POW2(matching_opts.distance_threshold);
347 uint16_t
const min_num_candidates = cashash_opts.min_num_candidates;
348 uint16_t
const max_num_candidates = cashash_opts.max_num_candidates;
// Compact hash length in bits equals the descriptor length.
// NOTE(review): uint8_t limits this to descriptors of < 256 dimensions.
349 int const descriptor_length = V::dim;
350 uint8_t
const dim_hash_data = descriptor_length;
// Number of 64-bit words per compact hash.
351 uint32_t
const dim_comp_hash_data = dim_hash_data / 64;
353 result->resize(set_1_size, -1);
// Marks set_2 features already collected for the current query.
354 std::vector<bool> data_index_used(set_2_size);
// One candidate list per possible Hamming distance 0..dim_hash_data.
355 std::vector<std::vector<uint32_t> > grouped_features(dim_hash_data + 1);
356 std::vector<uint32_t> top_candidates;
358 top_candidates.reserve(max_num_candidates);
360 std::unique_ptr<T> tmp(
new T[max_num_candidates * descriptor_length]);
361 NearestNeighbor<T> nn;
362 nn.set_elements(tmp.get());
363 nn.set_element_dimensions(descriptor_length);
// Process each query descriptor of set_1 independently.
365 for (
size_t i = 0; i < set_1_size; i++)
// Reset the per-query state.
367 std::fill(data_index_used.begin(), data_index_used.end(),
false);
369 for (
size_t j = 0; j < grouped_features.size(); j++)
370 grouped_features[j].clear();
372 top_candidates.clear();
// Gather set_2 features sharing a bucket with query i, grouped by Hamming
// distance. NOTE(review): the first call arguments (listing lines 376-378)
// are missing here.
375 collect_features_from_buckets<dim_hash_data>(
379 set_1.bucket_grps_bucket_ids,
380 set_2.bucket_grps_feature_ids,
381 &set_1.comp_hash_data[i * dim_comp_hash_data],
382 set_2.comp_hash_data);
// Select the closest candidates by Hamming distance.
// NOTE(review): the call arguments are missing from this listing.
385 collect_top_ranked_candidates(
// Copy each candidate's descriptor into its row of the scratch buffer.
// NOTE(review): the memcpy size argument is missing from this listing.
393 for (
size_t j = 0; j < top_candidates.size(); j++)
395 uint32_t candidate_id = top_candidates[j];
396 std::memcpy(tmp.get() + j * descriptor_length,
397 set_2_descs[candidate_id].begin(),
// Nearest-neighbor search of the query among the copied candidates.
// NOTE(review): if top_candidates is empty, confirm that NearestNeighbor
// reports a distance that the threshold below rejects. Otherwise
// top_candidates[index_1st_best] is indexed out of range.
401 typename NearestNeighbor<T>::Result nn_result;
402 nn.set_num_elements(top_candidates.size());
403 nn.find(set_1_descs[i].begin(), &nn_result);
// Absolute distance threshold, then the Lowe ratio test. The rejecting
// statements (presumably `continue;`) are missing from this listing.
405 if (nn_result.dist_1st_best > square_dist_thres)
408 if (
static_cast<float>(nn_result.dist_1st_best)
409 /
static_cast<float>(nn_result.dist_2nd_best)
// Accepted: store the match as an index into set_2.
413 result->at(i) = top_candidates[nn_result.index_1st_best];
/**
 * Collects the candidate features of the other set for query feature_id.
 *
 * For every bucket group, the features in the query's bucket that are not
 * yet marked in data_index_used are appended to
 * (*grouped_features)[hamming_dist] and then marked.
 * comp_hash_data1 points at the query's DIMHASH / 64 hash words, and
 * comp_hash_data2 holds the other set's hashes.
 * NOTE(review): the return type and braces are missing from this listing.
 */
417template <
int DIMHASH>
419CascadeHashing::collect_features_from_buckets (
420 std::vector<std::vector<uint32_t>>* grouped_features,
421 size_t feature_id, std::vector<bool>* data_index_used,
422 BucketGroupsBuckets
const& bucket_grps_bucket_ids,
423 BucketGroupsFeatures
const& bucket_grps_feature_ids,
424 uint64_t
const* comp_hash_data1,
425 std::vector<uint64_t>
const& comp_hash_data2)
const
// Visit the query's bucket in every bucket group.
427 size_t num_bucket_grps = bucket_grps_bucket_ids.size();
428 for (
size_t grp_idx = 0; grp_idx < num_bucket_grps; grp_idx++)
430 uint8_t bucket_id = bucket_grps_bucket_ids[grp_idx][feature_id];
431 BucketFeatureIDs
const& bucket_feature_ids = bucket_grps_feature_ids[grp_idx][bucket_id];
// Consider every feature in the query's bucket for this group.
433 for (
size_t j = 0; j < bucket_feature_ids.size(); j++)
435 size_t candidate_id = bucket_feature_ids[j];
// Skip features already collected through another bucket group.
// NOTE(review): the statement guarded by this if (presumably `continue;`)
// is missing from this listing.
436 if ((*data_index_used)[candidate_id])
// The candidate's compact hash starts at word candidate_id * DIMHASH / 64.
439 uint64_t
const *ptr2 = &comp_hash_data2[candidate_id * DIMHASH / 64];
// NOTE(review): the body of the k-loop (listing line 442) is missing. It
// presumably adds popcount(comp_hash_data1[k] ^ ptr2[k]) to hamming_dist.
440 int hamming_dist = 0;
441 for (
int k = 0; k < (DIMHASH / 64); k++)
// Bucket the candidate by its Hamming distance and mark it as collected.
444 (*grouped_features)[hamming_dist].emplace_back(candidate_id);
445 (*data_index_used)[candidate_id] =
true;
/**
 * Appends candidate ids to top_candidates, walking grouped_features from
 * Hamming distance 0 up to dim_hash_data (grouped_features has
 * dim_hash_data + 1 entries). The candidate count is checked against
 * max_num_candidates after every append, and against min_num_candidates
 * after each distance level.
 * NOTE(review): the return type, braces and the statements guarded by both
 * ifs (presumably `break;`) are missing from this listing.
 */
451CascadeHashing::collect_top_ranked_candidates (
452 std::vector<uint32_t>* top_candidates,
453 std::vector<std::vector<uint32_t>>
const& grouped_features,
454 uint8_t dim_hash_data, uint16_t min_num_candidates,
455 uint16_t max_num_candidates)
const
// Closest hashes first.
457 for (uint16_t hash_dist = 0; hash_dist <= dim_hash_data; hash_dist++)
459 for (
size_t j = 0; j < grouped_features[hash_dist].size(); j++)
461 uint32_t idx2 = grouped_features[hash_dist][j];
462 top_candidates->emplace_back(idx2);
// Hard cap on the number of candidates.
464 if (top_candidates->size() >= max_num_candidates)
// Enough candidates after finishing this distance level.
468 if (top_candidates->size() >= min_num_candidates)
std::size_t constexpr popcount(T const x)
Returns the number of one bits of an integer.
std::vector< Viewport > ViewportList
The list of all viewports considered for bundling.
#define SFM_NAMESPACE_END
#define SFM_NAMESPACE_BEGIN
Feature matching options.
Feature matching result reported as two lists, each with indices in the other set.