size_t max_elements_{0}; mutable std::atomic<size_t> cur_element_count{0}; // current number of elements size_t size_data_per_element_{0}; size_t size_links_per_element_{0}; mutable std::atomic<size_t> num_deleted_{0}; // number of deleted elements size_t M_{0}; size_t maxM_{0}; size_t maxM0_{0}; size_t ef_construction_{0}; size_t ef_{ 0 };
double mult_{0.0}, revSize_{0.0}; int maxlevel_{0};
bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions
std::mutex deleted_elements_lock; // lock for deleted_elements std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements
/* * Adds point. Updates the point if it is already in the index. * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point */ voidaddPoint(constvoid *data_point, labeltype label, bool replace_deleted = false){ if ((allow_replace_deleted_ == false) && (replace_deleted == true)) { throw std::runtime_error("Replacement of deleted elements is disabled in constructor"); }
// lock all operations with element by label std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label)); if (!replace_deleted) { addPoint(data_point, label, -1); return; } // check if there is vacant place tableint internal_id_replaced; std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock); bool is_vacant_place = !deleted_elements.empty(); if (is_vacant_place) { internal_id_replaced = *deleted_elements.begin(); deleted_elements.erase(internal_id_replaced); } lock_deleted_elements.unlock();
// if there is no vacant place then add or update point // else add point to vacant place if (!is_vacant_place) { addPoint(data_point, label, -1); } else { // we assume that there are no concurrent operations on deleted element labeltype label_replaced = getExternalLabel(internal_id_replaced); setExternalLabel(internal_id_replaced, label);
intmain(){ int dim = 16; // Dimension of the elements int max_elements = 10000; // Maximum number of elements, should be known beforehand int M = 16; // Tightly connected with internal dimensionality of the data // strongly affects the memory consumption int ef_construction = 200; // Controls index search speed/build speed tradeoff
// Initing index hnswlib::L2Space space(dim); hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M, ef_construction);
// Generate random data std::mt19937 rng; rng.seed(47); std::uniform_real_distribution<> distrib_real; float* data = newfloat[dim * max_elements]; for (int i = 0; i < dim * max_elements; i++) { data[i] = distrib_real(rng); }
// Add data to index for (int i = 0; i < max_elements; i++) { alg_hnsw->addPoint(data + i * dim, i); }
可以看到,addPoint 实际传入的参数是随机数据中的第 i 个向量(即从 data + i * dim 开始的 dim 个 float)和 i。由此可以根据命名推测:label 是用于标记向量的、类似于主键的标签。而 replace_deleted 默认为 false,所以我们接着看这段注释:
If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point
// lock all operations with element by label std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label)); if (!replace_deleted) { addPoint(data_point, label, -1); return; }
// check if there is vacant place tableint internal_id_replaced; std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock); bool is_vacant_place = !deleted_elements.empty(); if (is_vacant_place) { internal_id_replaced = *deleted_elements.begin(); deleted_elements.erase(internal_id_replaced); } lock_deleted_elements.unlock();
// if there is no vacant place then add or update point // else add point to vacant place if (!is_vacant_place) { addPoint(data_point, label, -1); } else { // we assume that there are no concurrent operations on deleted element labeltype label_replaced = getExternalLabel(internal_id_replaced); setExternalLabel(internal_id_replaced, label);
tableint addPoint(constvoid *data_point, labeltype label, int level){
tableint cur_c = 0; { // Checking if the element with the same label already exists // if so, updating it *instead* of creating a new element. std::unique_lock <std::mutex> lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search != label_lookup_.end()) { tableint existingInternalId = search->second; if (allow_replace_deleted_) { if (isMarkedDeleted(existingInternalId)) { throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled."); } } lock_table.unlock();
if (isMarkedDeleted(existingInternalId)) { unmarkDeletedInternal(existingInternalId); } updatePoint(data_point, existingInternalId, 1.0);
return existingInternalId; }
if (cur_element_count >= max_elements_) { throw std::runtime_error("The number of elements exceeds the specified limit"); }
// Initialisation of the data and label memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); memcpy(getDataByInternalId(cur_c), data_point, data_size_);
tableint cur_c = 0; { // Checking if the element with the same label already exists // if so, updating it *instead* of creating a new element. std::unique_lock <std::mutex> lock_table(label_lookup_lock); auto search = label_lookup_.find(label); if (search != label_lookup_.end()) { tableint existingInternalId = search->second; if (allow_replace_deleted_) { if (isMarkedDeleted(existingInternalId)) { throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled."); } } lock_table.unlock();
if (isMarkedDeleted(existingInternalId)) { unmarkDeletedInternal(existingInternalId); } updatePoint(data_point, existingInternalId, 1.0);
return existingInternalId; }
if (cur_element_count >= max_elements_) { throw std::runtime_error("The number of elements exceeds the specified limit"); }
VisitedList(int numelements1) { curV = -1; numelements = numelements1; mass = new vl_type[numelements]; }
voidreset(){ curV++; if (curV == 0) { memset(mass, 0, sizeof(vl_type) * numelements); curV++; } } s ~VisitedList() { delete[] mas; } }; /////////////////////////////////////////////////////////// // // Class for multi-threaded pool-management of VisitedLists // /////////////////////////////////////////////////////////
classVisitedListPool { std::deque<VisitedList *> pool; std::mutex poolguard; int numelements;
public:
    // Pre-populates the pool with `initmaxpools` VisitedList objects, each
    // created with `numelements1` entries, pushed onto the front of `pool`.
    VisitedListPool(int initmaxpools, int numelements1) {
        numelements = numelements1;
        for (int i = 0; i < initmaxpools; i++)
            pool.push_front(new VisitedList(numelements));
    }
接着,通过找到的目前候选队列中的最近邻,获取对应节点的相邻节点的数据(也就是获取全部的 link )。如果当前是底层,则使用 get_linklist0 从 data_level0_memory_ 中获取节点的邻节点,否则使用 get_linklist 从 linkLists_ 中的对应层位置取出节点的 link 数据。这里拿到的 data 的第一个 unsigned short int 还是存储了有多少个邻居,因此 size 就是当前遍历的节点的邻居数量。
接下来的这段代码中掺杂了预编译指令。它们在编译期根据编译环境是否支持 SSE 来决定是否编译其中的代码,而不是在运行时检测当前 CPU。SSE 全称为 Streaming SIMD Extensions,是 x86 上的单指令多数据(SIMD)扩展指令集,用于提升计算效率。
size_t Mcurmax = level ? maxM_ : maxM0_; getNeighborsByHeuristic2(top_candidates, M_); if (top_candidates.size() > M_) throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
{ // lock only during the update // because during the addition the lock for cur_c is already acquired std::unique_lock <std::mutex> lock(link_list_locks_[cur_c], std::defer_lock); if (isUpdate) { lock.lock(); } linklistsizeint *ll_cur; if (level == 0) ll_cur = get_linklist0(cur_c); else ll_cur = get_linklist(cur_c, level);
if (*ll_cur && !isUpdate) { throw std::runtime_error("The newly inserted element should have blank link list"); } setListCount(ll_cur, selectedNeighbors.size()); tableint *data = (tableint *) (ll_cur + 1); for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { if (data[idx] && !isUpdate) throw std::runtime_error("Possible memory corruption"); if (level > element_levels_[selectedNeighbors[idx]]) throw std::runtime_error("Trying to make a link on a non-existent level");
if (sz_link_list_other > Mcurmax) throw std::runtime_error("Bad value of sz_link_list_other"); if (selectedNeighbors[idx] == cur_c) throw std::runtime_error("Trying to connect an element to itself"); if (level > element_levels_[selectedNeighbors[idx]]) throw std::runtime_error("Trying to make a link on a non-existent level");
tableint *data = (tableint *) (ll_other + 1);
bool is_cur_c_present = false; if (isUpdate) { for (size_t j = 0; j < sz_link_list_other; j++) { if (data[j] == cur_c) { is_cur_c_present = true; break; } } }
// If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. if (!is_cur_c_present) { if (sz_link_list_other < Mcurmax) { data[sz_link_list_other] = cur_c; setListCount(ll_other, sz_link_list_other + 1); } else { // finding the "weakest" element to replace it with the new one dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), dist_func_param_); // Heuristic: std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates; candidates.emplace(d_max, cur_c);
size_t Mcurmax = level ? maxM_ : maxM0_; getNeighborsByHeuristic2(top_candidates, M_); if (top_candidates.size() > M_) throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
首先是根据 level 在 maxM_ 和 maxM0_ 之间选出 Mcurmax:level 不为 0 时取 maxM_,level 为 0(最底层)时取 maxM0_。这里我们并不理解这两个值的含义,因此还需要阅读示例文件,示例文件的相关内容我列在下面。
// example_search.cpp 中的说明 int M = 16; // Tightly connected with internal dimensionality of the data // strongly affects the memory consumption
示例文件 example_search.cpp 中这段注释翻译过来的意思是:
M 与数据的内在维度(intrinsic dimensionality)紧密相关,并且会显著影响内存消耗。
但是我们还是读不懂这是在约定什么,因此需要读论文,论文中对 M 和 Mmax 的描述如下:
number of established connections M
maximum number of connections for each element per layer Mmax
// 构造函数中与 M 相关的代码 if ( M <= 10000 ) { M_ = M; } else { HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl; HNSWERR << " Cap to 10000 will be applied for the rest of the processing." << std::endl; M_ = 10000; } maxM_ = M_; // 设置为M maxM0_ = M_ * 2; // 设置为2M ef_construction_ = std::max(ef_construction, M_);
while (queue_closest.size()) { if (return_list.size() >= M) break; std::pair<dist_t, tableint> curent_pair = queue_closest.top(); dist_t dist_to_query = -curent_pair.first; // 当前点curent_pair距离查询点的距离 queue_closest.pop(); bool good = true;
for (std::pair<dist_t, tableint> second_pair : return_list) { // 计算return_list中的点second_pair与当前点curent_pair的距离 dist_t curdist = fstdistfunc_(getDataByInternalId(second_pair.second), getDataByInternalId(curent_pair.second), dist_func_param_); if (curdist < dist_to_query) { good = false; break; } } if (good) { return_list.push_back(curent_pair); } }
这里的实现,在论文中是这样描述的:
可以看到最关键的一句话是:
if e is closer to q compared to any element from R
按照字面理解,我想当然地认为这句话的含义是:
e比R中的任何元素都更接近q
我又用几款翻译软件核对了一下,确认自己的理解没什么大问题:
有道翻译:如果 e 相对于集合 R 中的任何元素都更接近于 q,那么 e 就是 R 中与 q 距离最近的元素。百度翻译:如果e比R中的任何元素更接近q。DeepL:如果与 R 中的任何元素相比,e 更接近 q。
但是这样的理解会与算法中的另一句话相矛盾:
e ← extract nearest element from W to q
这里每次都从 W 中取出距离 q 最近的元素作为 e。R 中的元素都是之前已经从 W 中取出的,它们到 q 的距离都不大于当前的 e。那么从第二轮循环开始,上述判断岂不是恒为 false?这个函数到底有什么用?当时读论文时我就没能理解。
但是,读了 hnswlib 实现后,我发现这是一个歧义引发的问题。这句话的正确翻译是:
与 e 到 R 中任意元素的距离相比,e 到 q 的距离更近(即 e 离 q 比离 R 中的任何元素都更近)
代码中的 dist_to_query 是点 e 到待插入点 q 的距离,而循环中的 curdist 是点 e 到 return_list 中某一个元素的距离。只要存在某个 curdist 小于 dist_to_query,e 就会被丢弃;也就是说,只有当 dist_to_query 不大于所有的 curdist 时,e 才会被加入 return_list。这与我们纠正歧义后的理解一致。因此,我们通过阅读实现代码纠正了阅读论文时的错误理解!
总而言之,getNeighborsByHeuristic2 函数的功能我们明白了:它按上述启发式规则,从输入的候选队列中筛选出至多 M 个最合适的邻居。被规则淘汰的候选不会保留,因此结果可能少于 M 个。如果候选数量本身就少于 M,则不做筛选,直接返回。
mutuallyConnectNewElement Part2
那么我们接着读 mutuallyConnectNewElement 。
size_t Mcurmax = level ? maxM_ : maxM0_; getNeighborsByHeuristic2(top_candidates, M_); if (top_candidates.size() > M_) throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");
这段找近邻的代码我们读完了,第一行也理解了。如果当前处理的是第 0 层,Mcurmax 取 maxM0_,即每个节点最多有 2M 个邻居;否则取 maxM_,即最多有 M 个邻居。(这里的 M 是初始化构建索引时设定的 M。)
{ // lock only during the update // because during the addition the lock for cur_c is already acquired std::unique_lock <std::mutex> lock(link_list_locks_[cur_c], std::defer_lock); if (isUpdate) { lock.lock(); } linklistsizeint *ll_cur; if (level == 0) ll_cur = get_linklist0(cur_c); else ll_cur = get_linklist(cur_c, level);
if (*ll_cur && !isUpdate) { throw std::runtime_error("The newly inserted element should have blank link list"); } setListCount(ll_cur, selectedNeighbors.size()); tableint *data = (tableint *) (ll_cur + 1); for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { if (data[idx] && !isUpdate) throw std::runtime_error("Possible memory corruption"); if (level > element_levels_[selectedNeighbors[idx]]) throw std::runtime_error("Trying to make a link on a non-existent level");
if (sz_link_list_other > Mcurmax) throw std::runtime_error("Bad value of sz_link_list_other"); if (selectedNeighbors[idx] == cur_c) throw std::runtime_error("Trying to connect an element to itself"); if (level > element_levels_[selectedNeighbors[idx]]) throw std::runtime_error("Trying to make a link on a non-existent level");
tableint *data = (tableint *) (ll_other + 1);
bool is_cur_c_present = false; if (isUpdate) { for (size_t j = 0; j < sz_link_list_other; j++) { if (data[j] == cur_c) { is_cur_c_present = true; break; } } }
// If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. if (!is_cur_c_present) { if (sz_link_list_other < Mcurmax) { data[sz_link_list_other] = cur_c; setListCount(ll_other, sz_link_list_other + 1); } else { // finding the "weakest" element to replace it with the new one dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), dist_func_param_); // Heuristic: std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates; candidates.emplace(d_max, cur_c);
if (sz_link_list_other > Mcurmax) throw std::runtime_error("Bad value of sz_link_list_other"); if (selectedNeighbors[idx] == cur_c) throw std::runtime_error("Trying to connect an element to itself"); if (level > element_levels_[selectedNeighbors[idx]]) throw std::runtime_error("Trying to make a link on a non-existent level");
通过检查后,遍历这个邻居节点的边,检查其中是否已经包含当前插入的节点。注意:只有在更新已有节点(isUpdate 为 true)时才会做这个检查:
tableint *data = (tableint *) (ll_other + 1);
bool is_cur_c_present = false; if (isUpdate) { for (size_t j = 0; j < sz_link_list_other; j++) { if (data[j] == cur_c) { is_cur_c_present = true; break; } } }
} else { // finding the "weakest" element to replace it with the new one dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), dist_func_param_); // Heuristic: std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates; candidates.emplace(d_max, cur_c);
// bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance template <bool bare_bone_search = true, bool collect_metrics = false> std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> searchBaseLayerST( tableint ep_id, constvoid *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr, BaseSearchStopCondition<dist_t>* stop_condition = nullptr) const {
bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance
// bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance template <bool bare_bone_search = true, bool collect_metrics = false> std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> searchBaseLayerST( tableint ep_id, constvoid *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr, BaseSearchStopCondition<dist_t>* stop_condition = nullptr) const {