hnswlib源码精读

hnswlib 简介

Hnswlib 是一个强大的近邻搜索(ANN) 库,热门的向量数据库 Milvus 底层的 ANN 库之一就是 Hnswlib , 为 Milvus 提供HNSW检索。

项目地址:

https://github.com/nmslib/hnswlib/tree/master

项目代码结构

pic5.png

项目的结构非常简单,对我们有用的( c++ 二次开发)是图上标红的几个部分。

examples/cpp

这个目录下存放着若干个 C++ 的使用范例。 EXAMPLES.md 文件对各个范例进行了简要的说明,在阅读代码前,应该先阅读这里的范例,了解使用方法。

hnswlib

这个目录下存放的是全部的源码,也是我们将阅读和改写的部分。

/home/hnswlib-master/tests/cpp

这个目录下存放着若干 C++ 实现的测试用例,有精力可以进行阅读。

CMakeLists.txt

这个文件是项目的CMakeLists,是唯一一个CMakeLists,如果 实现/新增 依赖,可能需要对这里进行改动。

源码阅读

前言

在这之前,你应该首先阅读 example 文件夹下的示例代码,尤其是 example_search.cpp ,弄明白这个算法包怎么用!

HierarchicalNSW 是 HNSW 算法的实现类,这是一个巨大类,有共1400行代码。

下面这段是 HierarchicalNSW 类的全部属性,可以看到这里的属性实在是太多了。如果没有重点地阅读,只会造成读了忘读了忘。找到你需要的函数和结构体去阅读,在阅读的过程中不断阅读相关函数,才是正确的方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
public:
static const tableint MAX_LABEL_OPERATION_LOCKS = 65536;
static const unsigned char DELETE_MARK = 0x01;

size_t max_elements_{0};
mutable std::atomic<size_t> cur_element_count{0}; // current number of elements
size_t size_data_per_element_{0};
size_t size_links_per_element_{0};
mutable std::atomic<size_t> num_deleted_{0}; // number of deleted elements
size_t M_{0};
size_t maxM_{0};
size_t maxM0_{0};
size_t ef_construction_{0};
size_t ef_{ 0 };

double mult_{0.0}, revSize_{0.0};
int maxlevel_{0};

std::unique_ptr<VisitedListPool> visited_list_pool_{nullptr};

// Locks operations with element by label value
mutable std::vector<std::mutex> label_op_locks_;

std::mutex global;
std::vector<std::mutex> link_list_locks_;

tableint enterpoint_node_{0};

size_t size_links_level0_{0};
size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{ 0 };

char *data_level0_memory_{nullptr};
char **linkLists_{nullptr};
std::vector<int> element_levels_; // keeps level of each element

size_t data_size_{0};

DISTFUNC<dist_t> fstdistfunc_;
void *dist_func_param_{nullptr};

mutable std::mutex label_lookup_lock; // lock for label_lookup_
std::unordered_map<labeltype, tableint> label_lookup_;

std::default_random_engine level_generator_;
std::default_random_engine update_probability_generator_;

mutable std::atomic<long> metric_distance_computations{0};
mutable std::atomic<long> metric_hops{0};

bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions

std::mutex deleted_elements_lock; // lock for deleted_elements
std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements

因此,我将源码阅读分了三步:

1、速览全文

2、找到入口

3、不断深入

开始吧

1、速览全文

首先,通过快速浏览了源码,了解代码的整个结构。 hnswlib 是一个非常优雅的小包,命名极为精炼。因此,通过一两分钟阅读函数名,你能清晰地理解每个函数大致的功能:

结构体:

结构体没有必要从开始就阅读,理解不了又记不住。

函数:

  1. 工具函数部分

从 CompareByFirst 开始,到 getDeletedCount 为止,全部都是简单的工具函数,没有必要阅读。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
struct CompareByFirst {
constexpr bool operator()(std::pair<dist_t, tableint> const& a,
std::pair<dist_t, tableint> const& b) const noexcept {
return a.first < b.first;
}
};


void setEf(size_t ef) {
ef_ = ef;
}


inline std::mutex& getLabelOpMutex(labeltype label) const {
// calculate hash
size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1);
return label_op_locks_[lock_id];
}


inline labeltype getExternalLabel(tableint internal_id) const {
labeltype return_label;
memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype));
return return_label;
}


inline void setExternalLabel(tableint internal_id, labeltype label) const {
memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype));
}


inline labeltype *getExternalLabeLp(tableint internal_id) const {
return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_);
}


inline char *getDataByInternalId(tableint internal_id) const {
return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_);
}

/**
* @brief 这个函数是随机设置层级的函数
*/
int getRandomLevel(double reverse_size) {
std::uniform_real_distribution<double> distribution(0.0, 1.0);
double r = -log(distribution(level_generator_)) * reverse_size;
return (int) r;
}

size_t getMaxElements() {
return max_elements_;
}

size_t getCurrentElementCount() {
return cur_element_count;
}

size_t getDeletedCount() {
return num_deleted_;
}
  1. 搜索函数
    如下图,这三个函数是核心函数,但是这并不是我们的入口,所以暂时不需要仔细阅读。

pic1.png

  1. 工具函数 及 持久化函数
    接下来是一些用于数据访问的工具函数,此外还有几个持久化反持久化的工具函数。

pic2.png

  1. 删除相关函数
    接下来的函数是与删除相关的函数。这几个函数共同组合,对系统提供了伪删除的功能。但是我们暂时用不到,所以不用细看:

pic3.png

  1. 插入、查询函数
    接下来的函数是核心函数,也是我们阅读代码的入口函数。add 函数是整个索引构建的开始,所以需要从这里开始仔细阅读:

pic4.png

2、找到入口

通过上一篇的通篇泛读,我们已经找到了重点,那么就从 addPoint 函数开始读吧。

addPoint

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
/*
* Adds point. Updates the point if it is already in the index.
* If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point
*/
void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) {
if ((allow_replace_deleted_ == false) && (replace_deleted == true)) {
throw std::runtime_error("Replacement of deleted elements is disabled in constructor");
}

// lock all operations with element by label
std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
if (!replace_deleted) {
addPoint(data_point, label, -1);
return;
}
// check if there is vacant place
tableint internal_id_replaced;
std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
bool is_vacant_place = !deleted_elements.empty();
if (is_vacant_place) {
internal_id_replaced = *deleted_elements.begin();
deleted_elements.erase(internal_id_replaced);
}
lock_deleted_elements.unlock();

// if there is no vacant place then add or update point
// else add point to vacant place
if (!is_vacant_place) {
addPoint(data_point, label, -1);
} else {
// we assume that there are no concurrent operations on deleted element
labeltype label_replaced = getExternalLabel(internal_id_replaced);
setExternalLabel(internal_id_replaced, label);

std::unique_lock <std::mutex> lock_table(label_lookup_lock);
label_lookup_.erase(label_replaced);
label_lookup_[label] = internal_id_replaced;
lock_table.unlock();

unmarkDeletedInternal(internal_id_replaced);
updatePoint(data_point, internal_id_replaced, 1.0);
}
}

入参肯定就迷糊了,data_point 可以理解是插入的数据指针,那多出来的 label 和 replace_deleted 都是什么呢?别急,带着疑惑找一下示例。示例位于 examples/cpp/example_search.cpp ,我把有用的代码放在下面:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
int main() {
int dim = 16; // Dimension of the elements
int max_elements = 10000; // Maximum number of elements, should be known beforehand
int M = 16; // Tightly connected with internal dimensionality of the data
// strongly affects the memory consumption
int ef_construction = 200; // Controls index search speed/build speed tradeoff

// Initing index
hnswlib::L2Space space(dim);
hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M, ef_construction);

// Generate random data
std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib_real;
float* data = new float[dim * max_elements];
for (int i = 0; i < dim * max_elements; i++) {
data[i] = distrib_real(rng);
}

// Add data to index
for (int i = 0; i < max_elements; i++) {
alg_hnsw->addPoint(data + i * dim, i);
}

可以看到, addPoint 实际传入的参数是 random data 中的第 i 个元素 和 i。这下可以根据命名猜测了, label 是用于标记向量的一个类似于主键的标签。而 replace_deleted 默认为 false ,所以我们接着看注解:

If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point

这里说的很清楚,如果 replacement 开启了,那么被删除的 elements 将有机会被更新成新插入的点。这是一个论文中并没有提到的实现,是删除操作的相关机制的工程实现。接着我们继续阅读代码,可以发现以 replace_deleted 为界, 函数可以被分为两部分—— 开启了 replacement 和 不开启 replacement。

replace_deleted 为 false 时的代码非常简单,可以看到是加锁后直接调用了被重载的另一个 addPoint 函数。

1
2
3
4
5
6
// lock all operations with element by label
std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
if (!replace_deleted) {
addPoint(data_point, label, -1);
return;
}

而 replace_deleted 为 true 时则有趣得多,首先通过检查 deleted_elements ,尝试找出一个被删除的 tableint ,接着更新这个被删除的节点的label信息,最后调用 updatePoint 函数来更新 tableint 指向的节点的各种关系。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// check if there is vacant place
tableint internal_id_replaced;
std::unique_lock <std::mutex> lock_deleted_elements(deleted_elements_lock);
bool is_vacant_place = !deleted_elements.empty();
if (is_vacant_place) {
internal_id_replaced = *deleted_elements.begin();
deleted_elements.erase(internal_id_replaced);
}
lock_deleted_elements.unlock();

// if there is no vacant place then add or update point
// else add point to vacant place
if (!is_vacant_place) {
addPoint(data_point, label, -1);
} else {
// we assume that there are no concurrent operations on deleted element
labeltype label_replaced = getExternalLabel(internal_id_replaced);
setExternalLabel(internal_id_replaced, label);

std::unique_lock <std::mutex> lock_table(label_lookup_lock);
label_lookup_.erase(label_replaced);
label_lookup_[label] = internal_id_replaced;
lock_table.unlock();

unmarkDeletedInternal(internal_id_replaced);
updatePoint(data_point, internal_id_replaced, 1.0);
}

读到这里, addPoint 的大致框架我们便读完了,秉持着最简化原则,我们只读最简单的代码。也就是这两项都不开启:

1
(allow_replace_deleted_ == false) && (replace_deleted == true)

那么有效的代码就是这两行:

1
2
3
4
5
std::unique_lock <std::mutex> lock_label(getLabelOpMutex(label));
if (!replace_deleted) {
addPoint(data_point, label, -1);
return;
}

我们接着阅读这个被重载的 addPoint :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
tableint addPoint(const void *data_point, labeltype label, int level) {

tableint cur_c = 0;
{
// Checking if the element with the same label already exists
// if so, updating it *instead* of creating a new element.
std::unique_lock <std::mutex> lock_table(label_lookup_lock);
auto search = label_lookup_.find(label);
if (search != label_lookup_.end()) {
tableint existingInternalId = search->second;
if (allow_replace_deleted_) {
if (isMarkedDeleted(existingInternalId)) {
throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled.");
}
}
lock_table.unlock();

if (isMarkedDeleted(existingInternalId)) {
unmarkDeletedInternal(existingInternalId);
}
updatePoint(data_point, existingInternalId, 1.0);

return existingInternalId;
}

if (cur_element_count >= max_elements_) {
throw std::runtime_error("The number of elements exceeds the specified limit");
}

cur_c = cur_element_count;
cur_element_count++;
label_lookup_[label] = cur_c;
}

std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
int curlevel = getRandomLevel(mult_);
if (level > 0)
curlevel = level;

element_levels_[cur_c] = curlevel;

std::unique_lock <std::mutex> templock(global);
int maxlevelcopy = maxlevel_;
if (curlevel <= maxlevelcopy)
templock.unlock();
tableint currObj = enterpoint_node_;
tableint enterpoint_copy = enterpoint_node_;

memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_);

// Initialisation of the data and label
memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype));
memcpy(getDataByInternalId(cur_c), data_point, data_size_);

if (curlevel) {
linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
if (linkLists_[cur_c] == nullptr)
throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
}

if ((signed)currObj != -1) {
if (curlevel < maxlevelcopy) {
dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
for (int level = maxlevelcopy; level > curlevel; level--) {
bool changed = true;
while (changed) {
changed = false;
unsigned int *data;
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
data = get_linklist(currObj, level);
int size = getListCount(data);

tableint *datal = (tableint *) (data + 1);
for (int i = 0; i < size; i++) {
tableint cand = datal[i];
if (cand < 0 || cand > max_elements_)
throw std::runtime_error("cand error");
dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
if (d < curdist) {
curdist = d;
currObj = cand;
changed = true;
}
}
}
}
}

bool epDeleted = isMarkedDeleted(enterpoint_copy);
for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
if (level > maxlevelcopy || level < 0) // possible?
throw std::runtime_error("Level error");

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
currObj, data_point, level);
if (epDeleted) {
top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
if (top_candidates.size() > ef_construction_)
top_candidates.pop();
}
currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);
}
} else {
// Do nothing for the first element
enterpoint_node_ = 0;
maxlevel_ = curlevel;
}

// Releasing lock for the maximum level
if (curlevel > maxlevelcopy) {
enterpoint_node_ = cur_c;
maxlevel_ = curlevel;
}
return cur_c;
}

又是一个大函数,那么我们接着划分。第一部分是这里:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
tableint cur_c = 0;
{
// Checking if the element with the same label already exists
// if so, updating it *instead* of creating a new element.
std::unique_lock <std::mutex> lock_table(label_lookup_lock);
auto search = label_lookup_.find(label);
if (search != label_lookup_.end()) {
tableint existingInternalId = search->second;
if (allow_replace_deleted_) {
if (isMarkedDeleted(existingInternalId)) {
throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled.");
}
}
lock_table.unlock();

if (isMarkedDeleted(existingInternalId)) {
unmarkDeletedInternal(existingInternalId);
}
updatePoint(data_point, existingInternalId, 1.0);

return existingInternalId;
}

if (cur_element_count >= max_elements_) {
throw std::runtime_error("The number of elements exceeds the specified limit");
}

cur_c = cur_element_count;
cur_element_count++;
label_lookup_[label] = cur_c;
}

hnswlib 已经很贴心地特意为我们划分了区间,这部分的功能很简单,注解里也描述的很清晰:

// Checking if the element with the same label already exists
// if so, updating it instead of creating a new element.

简单说就是,根据输入的 label 查询是否已经存在这个节点,如果没有就全局更新 cur_element_count ,同时设置 cur_c 为这个 cur_element_count 。 至此, label 的真正含义便清晰了起来—— label 是用户指定的用于唯一约束一个节点的标签,换句话说就是 主键/ID。而 cur_element_count 的作用也很清晰了——索引中全部的节点数量。

还是最简化阅读,这段的代码的核心功能便是下面几句:

1
2
3
cur_c = cur_element_count;
cur_element_count++;
label_lookup_[label] = cur_c;

也就是刚说的设置 cur_c 作为内存中存储向量数据的索引,更新 cur_element_count ,设置 label_lookup_ 这个 map 的 label 处为 cur_c 。通过 label -> label_lookup_ -> cur_c -> 内存中存储向量数据的地址 + cur_c 。 这种方式来定位向量。

上面第一段函数的作用是给输入的向量,在内存中分配存储空间,以及存储 label 和向量的映射关系。可以理解为存储引擎做的事情。

那么下面一段代码便是构建索引:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
// 随机设定层级,如果有输入层级不覆盖
std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
int curlevel = getRandomLevel(mult_);
if (level > 0)
curlevel = level;
// 储存 cur_c 对应的向量的的 element_levels_
element_levels_[cur_c] = curlevel;

std::unique_lock <std::mutex> templock(global);
int maxlevelcopy = maxlevel_; // 猜想:这里使用copy很可能是面向编译优化,通过这种copy模式在流水线上提前将数据从内存区拷贝至函数栈,从而加速后续的比较
if (curlevel <= maxlevelcopy)
templock.unlock();
tableint currObj = enterpoint_node_;
tableint enterpoint_copy = enterpoint_node_;

memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); // 初始化 data_level0_memory_ 内存区中 cur_c 对应向量的空间

// Initialisation of the data and label
memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); // 储存 size_t 格式的 label
memcpy(getDataByInternalId(cur_c), data_point, data_size_); // 储存向量数据 data_point 到 data_level0_memory_

if (curlevel) {
// 如果节点插入的不是最底层
linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
if (linkLists_[cur_c] == nullptr)
throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
}

// 首先判断大部分导向的分支是否符合条件,更有利于流水线上的分支预测准确率!
// 这里对第一个节点和后续节点进行分歧操作,第一个节点不做操作。
if ((signed)currObj != -1) {
// 不是第一个节点,也要分情况讨论:当前预插入的层级是否高于目前的最高层

if (curlevel < maxlevelcopy) {
// 从 ep 开始检索
dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); // dist_t 是类的模板类型,指定了距离的类型
// 从最高层开始,向刚刚设定的层级 curlevel ,依次寻找最近邻来步入
for (int level = maxlevelcopy; level > curlevel; level--) {
bool changed = true;
while (changed) {
changed = false;
unsigned int *data;
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
data = get_linklist(currObj, level); // 从边链表中拿到 currObj ,也就是当前入节点的边的数据
int size = getListCount(data); // 反序列化得到当前入节点的边数

tableint *datal = (tableint *) (data + 1); // 跳过第一位unsigned short int,第一位unsigned short int在上面这行被反序列化为了边数
// 遍历全部的边
for (int i = 0; i < size; i++) {
tableint cand = datal[i];
if (cand < 0 || cand > max_elements_)
throw std::runtime_error("cand error");
dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); // 计算 输入向量 与 当前入节点相邻向量的举例
// 如果更近,更新入节点,启动下一次搜索
if (d < curdist) {
curdist = d;
currObj = cand;
changed = true; // 启动下一次搜索
}
}
}
}
}

// 上面代码的主要功能是,从最高层向下检索到预插入层的上层,找到预插入层的上层中举例插入节点最近邻的节点。
// 但是如果新增节点的预插入层比目前最高层还高,那么上面的检索过程将被忽略
// 而下面代码的功能就是找到最M近邻,然后执行真正的插入:

bool epDeleted = isMarkedDeleted(enterpoint_copy);
for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
if (level > maxlevelcopy || level < 0) // possible?
throw std::runtime_error("Level error");

// 从
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
currObj, data_point, level);
if (epDeleted) {
top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
if (top_candidates.size() > ef_construction_)
top_candidates.pop();
}
currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);
}
} else {
// Do nothing for the first element
enterpoint_node_ = 0;
maxlevel_ = curlevel;
}


// Releasing lock for the maximum level
if (curlevel > maxlevelcopy) {
enterpoint_node_ = cur_c;
maxlevel_ = curlevel;
}
return cur_c;

这段代码还是很长,那么还是分割来读:

第一部分的代码可以被理解为下面这部分,主要功能是

  • 随机设置插入向量的层级。
  • 储存 向量、label、节点插入的层级数 至内存中,做一个半持久化的存储。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// 随机设定层级,如果有输入层级不覆盖
std::unique_lock <std::mutex> lock_el(link_list_locks_[cur_c]);
int curlevel = getRandomLevel(mult_);
if (level > 0)
curlevel = level;
// 储存 cur_c 对应的向量的的 element_levels_
element_levels_[cur_c] = curlevel;

std::unique_lock <std::mutex> templock(global);
int maxlevelcopy = maxlevel_; // 猜想:这里使用copy很可能是面向编译优化,通过这种copy模式在流水线上提前将数据从内存区拷贝至函数栈,从而加速后续的比较
if (curlevel <= maxlevelcopy)
templock.unlock();
tableint currObj = enterpoint_node_;
tableint enterpoint_copy = enterpoint_node_;

memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); // 初始化 data_level0_memory_ 内存区中 cur_c 对应向量的空间

// Initialisation of the data and label
memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); // 储存 size_t 格式的 label
memcpy(getDataByInternalId(cur_c), data_point, data_size_); // 储存向量数据 data_point 到 data_level0_memory_

接着是第二部分,主要功能是初始化节点的边数组。条件是 curlevel > 0 , 此时意味着节点同时是一个层次索引节点。

1
2
3
4
5
6
7
if (curlevel) {
// 如果节点插入的不是最底层
linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
if (linkLists_[cur_c] == nullptr)
throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
}

最后是最长也是最关键的部分,这部分的解读我直接放在注释中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
// currObj 是一个 typedef unsigned int hnswlib::tableint , 转为 singned 的目的是 unsigned int hnswlib::tableint enterpoint_node_ 在索引初始化时,被设置为-1。
// -1 转化为补码后是全1,转为无符号数就变成了最大的无符号数,所以这样设计的含义是直接让 16 位 1 代表 第一个添加进的节点,从而让第一个节点走另一条分支。
// 此外,这里首先判断的是,大部分分支符合的条件,更有利于流水线上的分支预测准确率!
if ((signed)currObj != -1) {

// 不是第一个节点,也要分情况讨论:当前预插入的层级是否高于目前的最高层
if (curlevel < maxlevelcopy) {
// 在最高层下面,就直接找到这个节点在待插入层上层的入节点。缓存至 currObj
// 从 ep 开始检索
dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); // dist_t 是类的模板类型,指定了距离的类型
// 从最高层开始,向刚刚设定的层级 curlevel ,依次寻找最近邻来步入
for (int level = maxlevelcopy; level > curlevel; level--) {
bool changed = true;
while (changed) {
changed = false;
unsigned int *data;
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
data = get_linklist(currObj, level); // 从边链表中拿到 currObj ,也就是当前入节点的边的数据
int size = getListCount(data); // 反序列化得到当前入节点的边数

tableint *datal = (tableint *) (data + 1); // 跳过第一位unsigned short int,第一位unsigned short int在上面这行被反序列化为了边数
// 遍历全部的边
for (int i = 0; i < size; i++) {
tableint cand = datal[i];
if (cand < 0 || cand > max_elements_)
throw std::runtime_error("cand error");
dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); // 计算 输入向量 与 当前入节点相邻向量的举例
// 如果更近,更新入节点,启动下一次搜索
if (d < curdist) {
curdist = d;
currObj = cand;
changed = true; // 启动下一次搜索
}
}
}
}
}

// 上面代码的主要功能是,从最高层向下检索到预插入层的上层,找到预插入层的上层中,举例待插入节点最近的节点。也就是进入下一层的入节点
// 但是如果新增节点的预插入层比目前最高层还高,那么上面的检索过程将不进行。
// 而下面代码的功能就是找到最M近邻,然后执行真正的插入。
// epDeleted 是与删除相关的标志位,我们暂时不阅读
bool epDeleted = isMarkedDeleted(enterpoint_copy);
// 从能够建立索引的最高层min(curlevel, maxlevelcopy)开始,直到最底层0。给输入的vector搜寻近邻,并依次建立对应的边。
for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) {
if (level > maxlevelcopy || level < 0) // possible?
throw std::runtime_error("Level error");

// 从当前的 level 级,以 currObj 为入节点,找 data_point 的近邻。结果是一个存储了 dist_t, tableint pair 的优先队列。
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
currObj, data_point, level);

// 这段删除相关的也不阅读
if (epDeleted) {
top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
if (top_candidates.size() > ef_construction_)
top_candidates.pop();
}

// 可以看到,这里调用了建立连接的函数(从函数名理解)。
currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false);
}
} else {
// Do nothing for the first element
enterpoint_node_ = 0;
maxlevel_ = curlevel;
}

到此为止,这个被重载的 addPoint 就读完了。可以看到,这里面调用的最重要的函数是 mutuallyConnectNewElement 以及 searchBaseLayer 。那么我们接下来就读 searchBaseLayermutuallyConnectNewElement

PS: 关于 unsigned int = -1 的解读:

What will happened when unsigned int t = -1

searchBaseLayer

下面是 searchBaseLayer 的全部代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
    std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;

dist_t lowerBound;
if (!isMarkedDeleted(ep_id)) {
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
top_candidates.emplace(dist, ep_id);
lowerBound = dist;
candidateSet.emplace(-dist, ep_id);
} else {
lowerBound = std::numeric_limits<dist_t>::max();
candidateSet.emplace(-lowerBound, ep_id);
}
visited_array[ep_id] = visited_array_tag;

while (!candidateSet.empty()) {
std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) {
break;
}
candidateSet.pop();

tableint curNodeNum = curr_el_pair.second;

std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);

int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
if (layer == 0) {
data = (int*)get_linklist0(curNodeNum);
} else {
data = (int*)get_linklist(curNodeNum, layer);
// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
}
size_t size = getListCount((linklistsizeint*)data);
tableint *datal = (tableint *) (data + 1);
#ifdef USE_SSE
_mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
_mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
_mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
_mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
#endif

for (size_t j = 0; j < size; j++) {
tableint candidate_id = *(datal + j);
// if (candidate_id == 0) continue;
#ifdef USE_SSE
_mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
_mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
#endif
if (visited_array[candidate_id] == visited_array_tag) continue;
visited_array[candidate_id] = visited_array_tag;
char *currObj1 = (getDataByInternalId(candidate_id));

dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
candidateSet.emplace(-dist1, candidate_id);
#ifdef USE_SSE
_mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
#endif

if (!isMarkedDeleted(candidate_id))
top_candidates.emplace(dist1, candidate_id);

if (top_candidates.size() > ef_construction_)
top_candidates.pop();

if (!top_candidates.empty())
lowerBound = top_candidates.top().first;
}
}
}
visited_list_pool_->releaseVisitedList(vl);

return top_candidates;
}

返回体是个难以言喻的嵌套结构:

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>

所以这个函数的难点是在于读懂这个返回结构的设计含义。所以接着读吧:

1
2
3
std::unique_ptr<VisitedListPool> visited_list_pool_{nullptr};

VisitedList *vl = visited_list_pool_->getFreeVisitedList();

上来一个 VisitedList 就看蒙了,为什么不直接拿一个 list , 而是去池里拿?带着疑问跳转阅读 VisitedListPool 。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#pragma once

#include <mutex>
#include <string.h>
#include <deque>

namespace hnswlib {
typedef unsigned short int vl_type;

class VisitedList {
public:
vl_type curV;
vl_type *mass;
unsigned int numelements;

VisitedList(int numelements1) {
curV = -1;
numelements = numelements1;
mass = new vl_type[numelements];
}

void reset() {
curV++;
if (curV == 0) {
memset(mass, 0, sizeof(vl_type) * numelements);
curV++;
}
}
s
~VisitedList() { delete[] mas; }
};
///////////////////////////////////////////////////////////
//
// Class for multi-threaded pool-management of VisitedLists
//
/////////////////////////////////////////////////////////

class VisitedListPool {
std::deque<VisitedList *> pool;
std::mutex poolguard;
int numelements;

public:
VisitedListPool(int initmaxpools, int numelements1) {
numelements = numelements1;
for (int i = 0; i < initmaxpools; i++)
pool.push_front(new VisitedList(numelements));
}

VisitedList *getFreeVisitedList() {
VisitedList *rez;
{
std::unique_lock <std::mutex> lock(poolguard);
if (pool.size() > 0) {
rez = pool.front();
pool.pop_front();
} else {
rez = new VisitedList(numelements);
}
}
rez->reset();
return rez;
}

void releaseVisitedList(VisitedList *vl) {
std::unique_lock <std::mutex> lock(poolguard);
pool.push_front(vl);
}

~VisitedListPool() {
while (pool.size()) {
VisitedList *rez = pool.front();
pool.pop_front();
delete rez;
}
}
};
} // namespace hnswlib

可以看到, VisitedList 是一个非常简单的定长数组封装工具;
数据类型使用 typedef unsigned short int vl_type; ,是为了节省空间吗?但是我们刚刚看到的 cur_element_count 是 size_t 的,而 enterpoint_node_ 是 typedef unsigned int 的,它们都比 shot int 长,那么这个 VisitedList 到底存的是什么呢?还有,为什么 reset 函数只在 curV == 0 时对 mass 进行赋0操作呢?

带着疑问,我们接着阅读 VisitedListPool 。

在构造函数中,可以看到 VisitedListPool 的主要数据结构体是一个定长 VisitedList 的 deque ,通过输入的 numelements1 将定长 VisitedList 资源统一初始化为相同长度。接着,提供了获取和释放列表的函数,在获取和释放之前都对共享资源队列加全局锁。这里有趣的是获取和释放的不一致,通过在池中资源不足时新增 VisitedList 实例来防止获取失败。这里的设计可能会造成 内存溢出 :假设有多个线程同时 getFreeVisitedList ,线程总数远超 initmaxpools ,那么这里会为每个超过 initmaxpools 的线程分配一个新的 VisitedList 。使用结束后,调用 releaseVisitedList 释放 VisitedList ,但这里多余归还的 VisitedList 会一直占用内存而不会被 delete ,应该设置更完善的释放机制。但鉴于这个包的目的是精简,所以这里就不过多追究了。

可以看到 getFreeVisitedList 函数中,最终都会执行 reset 函数来初始化 VisitedList 。 也就是通过 getFreeVisitedList 函数,拿到的第一个没有被使用过的 VisitedList 都会被 Set0 的,而后续再次拿到这个 VisitedList 后,不会被再次 Set0 。这里解答了一部分上面的问题,但还是不清楚,所以接着读。

读完了 VisitedListPool ,我们已经理解了上述池化的意义:尽可能减少资源占用的前提下提供并发能力。也明白了第一次拿到一个 VisitedList 时,这个 VisitedList 会被全部赋0。但当我们重复地从池中再次取出这个 VisitedList 时,不会被赋0,而是保留全部的修改。

接着读 searchBaseLayersearchBaseLayer 的第一段是申请 VisitedList , 第二段便是下面的代码。我们不看 isMarkedDeleted 的情况,只看节点不被删除时的逻辑,那么可以理解这段代码是:计算 ep 节点与待插入节点的距离,将这个距离设置为目前的最小距离,将 ep 节点作为第一个候选节点插入到优先队列中。

1
2
3
4
5
6
7
8
9
10
11
dist_t lowerBound;
if (!isMarkedDeleted(ep_id)) {
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
top_candidates.emplace(dist, ep_id);
lowerBound = dist;
candidateSet.emplace(-dist, ep_id);
} else {
lowerBound = std::numeric_limits<dist_t>::max();
candidateSet.emplace(-lowerBound, ep_id);
}
visited_array[ep_id] = visited_array_tag;

这里最让我困惑的是这段代码

visited_array[ep_id] = visited_array_tag;

visited_list_pool_ = std::unique_ptr(new VisitedListPool(1, max_elements));

ep_id 怎么保证不会越界呢?反回头一看,在构造函数中 visited_list_pool_ 竟然被初始化成了 max_elements 大小的空间(上面的代码)!这下刚刚的疑惑都解开了。 visited_array_tag 是 每个线程 每次调用 独有的访问标记, visited_array 中每个节点都有自己的 slot 。一旦这个节点被某个线程 单次访问了,那么 visited_array 中该节点对应的 slot 就会被被设置成这个访问标记,用于后续的图遍历终止。比如,第一个线程、第一次调用,那么这次访问的标记就是1(第一次reset会被设为1)。每访问一个节点,这个节点的 ID 对应的 visited_array[id] 就会被设为 1 ,代表已经访问。

可见 VisitedList 是一个超大 HASH ,用于存储节点是否被遍历过。这样我们也理解了为什么这样设计!通过每次访问设置不同的 tag 的方式,避免了重复构造 bool 类型的访问 map ,减少了系统开销。同时,刚刚的 内存溢出 问题也得到了解答:在溢出和消耗中,trade-off 选择了消耗,因为这么大的数组,销毁一次的开销还是巨大的。同时 vl_type 选用 unsigned short int 的原因也明了了, vl_type 记录的是这个 VisitedList 的被使用次数,所以可能不需要做这么多次的记录。但是这里仍然存在一个隐患,当查询次数大了,或者这个包被长期当作基础服务使用时,int 会出现越界。因此!这里选用了 unsigned ,在溢出后直接变为0!

问题闭环了!多么完美的代码!!!

我们接着向下读,接下来的代码的主要功能是在输入的层级中搜索输入的 data_point 的 M 个近邻。由于代码很长,我们还是分段阅读:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
            std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) {
break;
}
candidateSet.pop();

tableint curNodeNum = curr_el_pair.second;

std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);

int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
if (layer == 0) {
data = (int*)get_linklist0(curNodeNum);
} else {
data = (int*)get_linklist(curNodeNum, layer);
// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
}
size_t size = getListCount((linklistsizeint*)data);

这段代码检查了循环跳出条件:候选队列中距离待插入点最小的距离仍然比下界大,而且已经有了 ef_construction_ 个实际最近邻,则跳出循环。

接着,通过找到的目前候选队列中的最近邻,获取对应节点的相邻节点的数据(也就是获取全部的 link )。如果当前是底层,则使用 get_linklist0data_level0_memory_ 中获取节点的邻节点,否则使用 get_linklistlinkLists_ 中的对应层位置取出节点的 link 数据。这里拿到的 data 的第一个 unsigned short int 还是存储了有多少个邻居,因此 size 就是当前遍历的节点的邻居数量。

接下来的这段掺杂了预编译指令,主要目的是判断当前 cpu 是否支持 SSE。SSE 全名为 Streaming SIMD Extensions ,用于在单指令多数据的指令集下进行计算效率的优化。

https://wiki.osdev.org/SSE#Streaming_SIMD_Extensions_(SSE)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
            tableint *datal = (tableint *) (data + 1);
#ifdef USE_SSE
_mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
_mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
_mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
_mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
#endif

for (size_t j = 0; j < size; j++) {
tableint candidate_id = *(datal + j);
// if (candidate_id == 0) continue;
#ifdef USE_SSE
_mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
_mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
#endif
if (visited_array[candidate_id] == visited_array_tag) continue;
visited_array[candidate_id] = visited_array_tag;
char *currObj1 = (getDataByInternalId(candidate_id));

dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
candidateSet.emplace(-dist1, candidate_id);
#ifdef USE_SSE
_mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
#endif

if (!isMarkedDeleted(candidate_id))
top_candidates.emplace(dist1, candidate_id);

if (top_candidates.size() > ef_construction_)
top_candidates.pop();

if (!top_candidates.empty())
lowerBound = top_candidates.top().first;
}
}
}

为了避免复杂,我们首先阅读没有 SIMD 优化的代码,代码简化为下面的这段。再然后,我们把带有删除操作的代码删掉,便是下面这段剪枝的核心算法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
    tableint *datal = (tableint *) (data + 1);

for (size_t j = 0; j < size; j++) {
tableint candidate_id = *(datal + j);

if (visited_array[candidate_id] == visited_array_tag) continue;
visited_array[candidate_id] = visited_array_tag;
char *currObj1 = (getDataByInternalId(candidate_id));

dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
candidateSet.emplace(-dist1, candidate_id);

if (top_candidates.size() > ef_construction_)
top_candidates.pop();

if (!top_candidates.empty())
lowerBound = top_candidates.top().first;
}
}
}

这下读起来轻松多了。这里有一个关键点, top_candidates 是结果近邻列表,存放最 ef_construction_ 个最近邻; candidateSet 是用于 refine 的候选队列。因此,向队列中插入的 dist1 是取负的,目的是将 candidateSet 变成小根堆:原本的 CompareByFirst 函数,是将距离更大的元素更优先放在队尾,也就是大根堆。而取反后则是小根堆。举例如下:

1
2
3
4
大根堆1中元素:1,2,3
大根堆2中元素:-1,-2,-3

大根堆1先出堆的是1,而大根堆2先出的是-3

通过这样的方式,将候选队列中,距离当前节点最近的节点优先出队。如果这个节点都大于当前的下界,那么其它距离更远的节点更不需要考虑了,可以直接进行剪枝。这里可以结合终止条件来一起阅读,我统一摘取出来列在下面,具体的含义可以阅读我在这里写的注释:

1
2
3
4
5
6
7
8
9
10
// 跳出 while 循环的条件
if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) {
break;
}

// 向 candidateSet 中插入取负的距离,与计算时再取负的机制一起,将 candidateSet 变成小根堆。
candidateSet.emplace(-dist1, candidate_id);

// 向 top_candidates 中插入的直接就是真实的距离,大根堆,最优先弹出的是距离最大的节点。
top_candidates.emplace(dist1, candidate_id);
1
2
3
4
5
6
7
8
9
10
11
12
13
14
读完了没有 **SIMD** 优化的代码,我们再继续阅读一下使用了 **SSE** 优化的代码。在阅读之前,我们有必要先拥有对 **SSE** 的最基础的认知,以及对这里使用的 **mm_prefetch** 的了解。 **mm_prefetch** 是一个软件层面上的预读取指令,这是英特尔官方对 **mm_prefetch** 的解释:

> Prefetches data from the specified address on one memory cache line. Intrinsic subroutines cannot be passed as actual arguments.

这是英特尔的官方文档:

> https://www.intel.com/content/www/us/en/docs/fortran-compiler/developer-guide-reference/2024-1/mm-prefetch.html

我们直接阅读源码,发现:

```c++
#define _mm_prefetch(P, I) \
__builtin_prefetch ((P), ((I & 0x4) >> 2), (I & 0x3))
#endif
_mm_prefetch 其实是 __builtin_prefetch 的预编译指令。 **__builtin_prefetch** 是真正的预取指令。这里先开个坑,等读完了回过头来再详细讲一下CPU预抓取的优化效果。TODO
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

接着我们再捋顺一下剪枝算法的逻辑,针对每一个当前节点的 **邻居** , **candidate_id** 是邻居的 **id** ;首先在访问历史 **visited_array** 中查询 **candidate_id** 是否在这次搜索中已经被访问过了,如果已经被访问过了直接结束 **这个邻居** 的判断;否则判断 **这个邻居** 能否被加入到候选队列以及结果队列中继续计算。

能否被加入的条件是两个中满足任意一个即可:
* **top_candidates** 结果集合不满
* 将 **这个邻居** 与 **输入的节点** 之间的距离计算出来, **这个距离** 比 **top_candidates**(结果set)中的最大距离 **lowerBound** 还小。

第一个条件只会在初始化 **top_candidates** 过程中满足,主要进行的后续判断是第二个条件。也就是下面这段代码:

```c++
dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
candidateSet.emplace(-dist1, candidate_id);

if (top_candidates.size() > ef_construction_)
top_candidates.pop();

if (!top_candidates.empty())
lowerBound = top_candidates.top().first;
}

这里选择条件1作为前置条件进行优化,这样写会在符合条件 top_candidates.size() < ef_construction_** 的分支数大于 **lowerBound > dist1 时有更高的流水线执行效率。这也在某种程度上说明了,这个搜索函数的剪枝效率会很高,可能都不用 ef_construction_ 次更新就能完成。

mutuallyConnectNewElement Part1

读完了 searchBaseLayer ,我们接下来要读的是 mutuallyConnectNewElement , 这个函数的功能,在我们之前的猜想中是一个给新的节点与它的 ef_construction_ 个邻居建立边。带着这个猜想继续阅读吧!

下面是 mutuallyConnectNewElement 的全部代码,有100多行,我们还是将它拆分阅读

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
tableint mutuallyConnectNewElement(
const void *data_point,
tableint cur_c,
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
int level,
bool isUpdate) {

size_t Mcurmax = level ? maxM_ : maxM0_;
getNeighborsByHeuristic2(top_candidates, M_);
if (top_candidates.size() > M_)
throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");

std::vector<tableint> selectedNeighbors;
selectedNeighbors.reserve(M_);
while (top_candidates.size() > 0) {
selectedNeighbors.push_back(top_candidates.top().second);
top_candidates.pop();
}

tableint next_closest_entry_point = selectedNeighbors.back();

{
// lock only during the update
// because during the addition the lock for cur_c is already acquired
std::unique_lock <std::mutex> lock(link_list_locks_[cur_c], std::defer_lock);
if (isUpdate) {
lock.lock();
}
linklistsizeint *ll_cur;
if (level == 0)
ll_cur = get_linklist0(cur_c);
else
ll_cur = get_linklist(cur_c, level);

if (*ll_cur && !isUpdate) {
throw std::runtime_error("The newly inserted element should have blank link list");
}
setListCount(ll_cur, selectedNeighbors.size());
tableint *data = (tableint *) (ll_cur + 1);
for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
if (data[idx] && !isUpdate)
throw std::runtime_error("Possible memory corruption");
if (level > element_levels_[selectedNeighbors[idx]])
throw std::runtime_error("Trying to make a link on a non-existent level");

data[idx] = selectedNeighbors[idx];
}
}

for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);

linklistsizeint *ll_other;
if (level == 0)
ll_other = get_linklist0(selectedNeighbors[idx]);
else
ll_other = get_linklist(selectedNeighbors[idx], level);

size_t sz_link_list_other = getListCount(ll_other);

if (sz_link_list_other > Mcurmax)
throw std::runtime_error("Bad value of sz_link_list_other");
if (selectedNeighbors[idx] == cur_c)
throw std::runtime_error("Trying to connect an element to itself");
if (level > element_levels_[selectedNeighbors[idx]])
throw std::runtime_error("Trying to make a link on a non-existent level");

tableint *data = (tableint *) (ll_other + 1);

bool is_cur_c_present = false;
if (isUpdate) {
for (size_t j = 0; j < sz_link_list_other; j++) {
if (data[j] == cur_c) {
is_cur_c_present = true;
break;
}
}
}

// If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
if (!is_cur_c_present) {
if (sz_link_list_other < Mcurmax) {
data[sz_link_list_other] = cur_c;
setListCount(ll_other, sz_link_list_other + 1);
} else {
// finding the "weakest" element to replace it with the new one
dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
dist_func_param_);
// Heuristic:
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
candidates.emplace(d_max, cur_c);

for (size_t j = 0; j < sz_link_list_other; j++) {
candidates.emplace(
fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
dist_func_param_), data[j]);
}

getNeighborsByHeuristic2(candidates, Mcurmax);

int indx = 0;
while (candidates.size() > 0) {
data[indx] = candidates.top().second;
candidates.pop();
indx++;
}

setListCount(ll_other, indx);
// Nearest K:
/*int indx = -1;
for (int j = 0; j < sz_link_list_other; j++) {
dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_);
if (d > d_max) {
indx = j;
d_max = d;
}
}
if (indx >= 0) {
data[indx] = cur_c;
} */
}
}
}

return next_closest_entry_point;
}

第一段代码如下:

1
2
3
4
size_t Mcurmax = level ? maxM_ : maxM0_;
getNeighborsByHeuristic2(top_candidates, M_);
if (top_candidates.size() > M_)
throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");

首先是对 maxM_ 和 maxM0_ 取 max 。这里我们并不理解其中的含义,因此我们还需要阅读示例文件,示例文件的相关内容我列在下面。

1
2
3
// example_search.cpp 中的说明
int M = 16; // Tightly connected with internal dimensionality of the data
// strongly affects the memory consumption

示例文件 example_search.cpp 中的注解翻译一下的意思是:

与数据的内部维度紧密相连、强烈地影响着内存消耗

但是我们还是读不懂这是在约定什么,因此需要读论文,论文中对 M 和 Mmax 的描述如下:

  • number of established connections M

  • maximum number of connections for each element per layer Mmax

翻译过来其实就是:

  • 每个节点已经建立的连接数量 M

  • 每层中,每个节点(也就是 element )允许建立的最大连接数量 Mmax

为了确定我们理解的准确,我们还需要阅读关于 M 和 Mmax 的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
// 构造函数中与 M 相关的代码
if ( M <= 10000 ) {
M_ = M;
} else {
HNSWERR << "warning: M parameter exceeds 10000 which may lead to adverse effects." << std::endl;
HNSWERR << " Cap to 10000 will be applied for the rest of the processing." << std::endl;
M_ = 10000;
}
maxM_ = M_; // 设置为M
maxM0_ = M_ * 2; // 设置为2M
ef_construction_ = std::max(ef_construction, M_);

...
size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); // 2M个tableint和一个linklistsizeint
size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); // vector、label、和size_links_level0_,那么size_links_level0_大概率是这个节点的相邻节点列表
offsetData_ = size_links_level0_; // 每个节点的全部数据中,节点数据的位置,也就是存储vector的偏移量。
label_offset_ = size_links_level0_ + data_size_; // 存储label的偏移量

data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); // 一次性初始化全部第0层的内存空间
if (data_level0_memory_ == nullptr)
throw std::runtime_error("Not enough memory");

...

linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); // linkLists_ 是一个储存char *的数组。这里一次性分配 max_elements_ 个元素的空间。代表每个节点有一个对应的char*指针,指向什么还不知道
if (linkLists_ == nullptr)
throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); // M个tableint和一个linklistsizeint


读到这里, linkLists_ 的含义我们又忘掉了,接着回过头找到之前 addPoint 函数中使用过的 linkLists_ ,相关的代码如下。可以看到这里为每个插入的节点,按照其插入的层 curlevel ,分配了 size_links_per_element_ * curlevel + 1 个字节的空间。 size_links_per_element_ 是我们刚读过的 M个tableint和一个linklistsizeint 所占的空间大小,乘 curlevel 意味着这是除了最底层之外的额外的空间消耗。比如插入第1层,就有1层的消耗,插入第2层就需要2层的消耗。

1
2
3
4
5
6
7
if (curlevel) {
// 如果节点插入的不是最底层
linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
if (linkLists_[cur_c] == nullptr)
throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
}

为了验证上述的猜想,我们接着读 get_linklist 函数。这个函数提供了读取某个节点 internal_id 在第某层 level 的边数据。

1
2
3
linklistsizeint *get_linklist(tableint internal_id, int level) const {
return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
}

可以看到偏移量计算的公式和我们猜想的一样。简而言之, linkLists_ 的结构可以用下面这个图示概括:

1
2
3
4
5
6
7
8
9
10
11
12
// linkLists_ 存储指针
linkLists_[节点id1] -> char* 1
linkLists_[节点id2] -> char* 2

// 每个指针指向不连续的地址空间
char* 1 -> 一段 ( maxM_ * sizeof(tableint) + sizeof(linklistsizeint) ) * curlevel + 1 个字节的空间

// 数据的反序列化表示
[第1层中该节点的边数(linklistsizeint) | 第1层中该节点的边列表(分配了M个tableint)]
[第2层中该节点的边数(linklistsizeint) | 第2层中该节点的边列表(分配了M个tableint)]
[1个byte的空间,用途未知]

这说明了在这里 M 是每个节点在每层中允许建立的最大的邻居数量,而2M是每个节点在最底层的最大邻居数。为什么这样设计,我们目前还不清楚。但是我们至少明确了变量 M 确实限定了每个节点的边数。

接着函数调用了 getNeighborsByHeuristic2 从字面意义上是启发式地获取邻居,对应论文中的算法4。那么我们有必要阅读一下 getNeighborsByHeuristic2

getNeighborsByHeuristic2

getNeighborsByHeuristic2 的代码非常少,只有下面几行。经过查询,getNeighborsByHeuristic2 是 hnswlib 中唯一一个启发式搜索邻居的实现算法,这个算法的实现与论文中的算法4还是有一些差距的。主要区别点在于,这个实现删掉了论文中的 extendCandidateskeepPrunedConnections 两个参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
void getNeighborsByHeuristic2(
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> &top_candidates,
const size_t M) {
if (top_candidates.size() < M) {
return;
}

std::priority_queue<std::pair<dist_t, tableint>> queue_closest;
std::vector<std::pair<dist_t, tableint>> return_list;
while (top_candidates.size() > 0) {
queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second);
top_candidates.pop();
}

while (queue_closest.size()) {
if (return_list.size() >= M)
break;
std::pair<dist_t, tableint> curent_pair = queue_closest.top();
dist_t dist_to_query = -curent_pair.first;
queue_closest.pop();
bool good = true;

for (std::pair<dist_t, tableint> second_pair : return_list) {
dist_t curdist =
fstdistfunc_(getDataByInternalId(second_pair.second),
getDataByInternalId(curent_pair.second),
dist_func_param_);
if (curdist < dist_to_query) {
good = false;
break;
}
}
if (good) {
return_list.push_back(curent_pair);
}
}

for (std::pair<dist_t, tableint> curent_pair : return_list) {
top_candidates.emplace(-curent_pair.first, curent_pair.second);
}
}

函数的输入是一个优先队列 top_candidates 代表候选列表, 一个预计返回队列大小 M 用于约束返回的元素数量。这里的核心循环中涉及到了一个距离比较:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
while (queue_closest.size()) {
if (return_list.size() >= M)
break;
std::pair<dist_t, tableint> curent_pair = queue_closest.top();
dist_t dist_to_query = -curent_pair.first; // 当前点curent_pair距离查询点的距离
queue_closest.pop();
bool good = true;

for (std::pair<dist_t, tableint> second_pair : return_list) {
// 计算return_list中的点second_pair与当前点curent_pair的距离
dist_t curdist =
fstdistfunc_(getDataByInternalId(second_pair.second),
getDataByInternalId(curent_pair.second),
dist_func_param_);
if (curdist < dist_to_query) {
good = false;
break;
}
}
if (good) {
return_list.push_back(curent_pair);
}
}

这里的实现,在论文中是这样描述的:

pic6.png

可以看到最关键的一句话是:

if e is closer to q compared to any element from R

按照字面理解,我想当然地认为这句话的含义是:

e比R中的任何元素都更接近q

我这里又通过各种翻译软件确定了我理解的没什么大问题:

有道翻译:如果 e 相对于集合 R 中的任何元素都更接近于 q,那么 e 就是 R 中与 q 距离最近的元素。
百度翻译: 如果e比R中的任何元素更接近q
Deepl:如果与 R 中的任何元素相比,e 更接近 q

但是这样的理解会与上面这句话冲突:

e ← extract nearest element from W to q

这里从 W 中取出距离 q 最近的元素作为 e ,那么第二遍及以后的循环中,上述的判断不是恒 false 吗?这个函数到底有什么用?当时阅读论文的时候我就不理解。

但是,读了 hnswlib 实现后,我发现这是一个歧义引发的问题。这句话的正确翻译是:

相比于 e 距离 R 中的任意元素,e 距离 q 更近

代码中的 dist_to_query 是这个点 e待插入点q 的距离,而循环中的 curdist 是点 ereturn_list 中单个元素的距离。只有在 dist_to_query 大于全部的 curdist 时,才会被加入到 return_list。这与我们纠正了歧义后的理解一致。因此,我们通过阅读实现代码纠正了阅读论文中的错误理解!

总而言之, getNeighborsByHeuristic2 函数的功能我们明白了,从输入的候选队列中剪切出 M 个最适合的邻居。如果比 M 少,就直接返回。

mutuallyConnectNewElement Part2

那么我们接着读 mutuallyConnectNewElement

1
2
3
4
size_t Mcurmax = level ? maxM_ : maxM0_;
getNeighborsByHeuristic2(top_candidates, M_);
if (top_candidates.size() > M_)
throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic");

这段找近邻的代码我们读完了。第一行也理解了,如果当前节点插入的是0层,就最多有 2M 个邻居,否则最多有 M 个邻居。(这里的 M 是初始化构建索引时设定的 M

接下来我们继续往下读,这块代码的功能比较简单,我们统一来阅读分析:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
std::vector<tableint> selectedNeighbors;
selectedNeighbors.reserve(M_);
while (top_candidates.size() > 0) {
selectedNeighbors.push_back(top_candidates.top().second);
top_candidates.pop();
}

tableint next_closest_entry_point = selectedNeighbors.back();

{
// lock only during the update
// because during the addition the lock for cur_c is already acquired
std::unique_lock <std::mutex> lock(link_list_locks_[cur_c], std::defer_lock);
if (isUpdate) {
lock.lock();
}
linklistsizeint *ll_cur;
if (level == 0)
ll_cur = get_linklist0(cur_c);
else
ll_cur = get_linklist(cur_c, level);

if (*ll_cur && !isUpdate) {
throw std::runtime_error("The newly inserted element should have blank link list");
}
setListCount(ll_cur, selectedNeighbors.size());
tableint *data = (tableint *) (ll_cur + 1);
for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
if (data[idx] && !isUpdate)
throw std::runtime_error("Possible memory corruption");
if (level > element_levels_[selectedNeighbors[idx]])
throw std::runtime_error("Trying to make a link on a non-existent level");

data[idx] = selectedNeighbors[idx];
}
}

上面这段代码可以分三块理解

  • 将 top_candidates 中的节点id,按照距查询点的距离从大到小,依次尾插入到 selectedNeighbors ,清空 top_candidates
  • 获取离查询点最近的点,作为 next_closest_entry_point
  • 第三部分是括号括起来的逻辑,这部分的功能是将 selectedNeighbors 更新至 linklis 中。

这段代码我们看明白了。但是因为边是双向的,我们还需要给每条边上的另一个节点建立相同的关系,这些逻辑是下面的代码实现的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);

linklistsizeint *ll_other;
if (level == 0)
ll_other = get_linklist0(selectedNeighbors[idx]);
else
ll_other = get_linklist(selectedNeighbors[idx], level);

size_t sz_link_list_other = getListCount(ll_other);

if (sz_link_list_other > Mcurmax)
throw std::runtime_error("Bad value of sz_link_list_other");
if (selectedNeighbors[idx] == cur_c)
throw std::runtime_error("Trying to connect an element to itself");
if (level > element_levels_[selectedNeighbors[idx]])
throw std::runtime_error("Trying to make a link on a non-existent level");

tableint *data = (tableint *) (ll_other + 1);

bool is_cur_c_present = false;
if (isUpdate) {
for (size_t j = 0; j < sz_link_list_other; j++) {
if (data[j] == cur_c) {
is_cur_c_present = true;
break;
}
}
}

// If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics.
if (!is_cur_c_present) {
if (sz_link_list_other < Mcurmax) {
data[sz_link_list_other] = cur_c;
setListCount(ll_other, sz_link_list_other + 1);
} else {
// finding the "weakest" element to replace it with the new one
dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
dist_func_param_);
// Heuristic:
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
candidates.emplace(d_max, cur_c);

for (size_t j = 0; j < sz_link_list_other; j++) {
candidates.emplace(
fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
dist_func_param_), data[j]);
}

getNeighborsByHeuristic2(candidates, Mcurmax);

int indx = 0;
while (candidates.size() > 0) {
data[indx] = candidates.top().second;
candidates.pop();
indx++;
}

setListCount(ll_other, indx);
}
}
}

这部分仍然可以分开来看。第一部分是获取每个邻居节点的边数据的指针,这里不赘述。接着第二部分是三个检查:

  • 邻居节点的边数还没大于 Mcurmax
  • 非自反
  • 邻居节点在这个插入层存在
1
2
3
4
5
6
if (sz_link_list_other > Mcurmax)
throw std::runtime_error("Bad value of sz_link_list_other");
if (selectedNeighbors[idx] == cur_c)
throw std::runtime_error("Trying to connect an element to itself");
if (level > element_levels_[selectedNeighbors[idx]])
throw std::runtime_error("Trying to make a link on a non-existent level");

通过检查后,对这个邻居节点,遍历一遍它的边,检查是否有当前的插入节点:

1
2
3
4
5
6
7
8
9
10
11
12

tableint *data = (tableint *) (ll_other + 1);

bool is_cur_c_present = false;
if (isUpdate) {
for (size_t j = 0; j < sz_link_list_other; j++) {
if (data[j] == cur_c) {
is_cur_c_present = true;
break;
}
}
}

最后,如果还没建立关系,则需要建立关系。但是这里可能会出现,当前邻居节点的关系表已经满了的情况。因此需要分类讨论:

如果不满,即边数少于 Mcurmax 。直接插入即可。

1
2
3
if (sz_link_list_other < Mcurmax) {
data[sz_link_list_other] = cur_c;
setListCount(ll_other, sz_link_list_other + 1);

如果已经满了,则需要替换掉 当前邻居的 最差的邻居。

这里采用了刚刚阅读过的 getNeighborsByHeuristic2 函数,对 M + 1 个元素进行排序,选出其中最优秀的 M 个节点作为邻居。这里并不保证我们这里的待插入节点百分百能够建立这条反向边,我们的待插入节点可能在这步被淘汰。

也就是下面这段代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
} else {
// finding the "weakest" element to replace it with the new one
dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]),
dist_func_param_);
// Heuristic:
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates;
candidates.emplace(d_max, cur_c);

for (size_t j = 0; j < sz_link_list_other; j++) {
candidates.emplace(
fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]),
dist_func_param_), data[j]);
}

getNeighborsByHeuristic2(candidates, Mcurmax);

int indx = 0;
while (candidates.size() > 0) {
data[indx] = candidates.top().second;
candidates.pop();
indx++;
}

setListCount(ll_other, indx);
}

读到这里, addPoint 函数的全部流程我们便读完了。留下了两个坑,一个是 SSE 上的优化没讲;另一个坑是 maxM_maxM0_ 两个与不同层节点的边数设置有关的参数,这样设计的原因。读到这里,我们还差一个搜素的算法,就把论文中提到的算法都读完了,那么我们接着读 searchKnn

searchKnn

searchKnn 是KNN搜索的入口函数,是实际使用中检索接口,也我们修改统计代码的入口。这个函数比较简单,我把代码的功能直接写在下面的注释中了。主要分三块:

  • 初始化入节点
  • 从最高层逐层向下,从 ep 启动 BFS 搜索,找到这层中的最近邻,作为下一层的 enterpoint
  • 从第1层中找到的 ep 开始,检索第0层,找到 ef_ 或者 K 个近邻,返回其中最近的 K 个近邻的距离和 label
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
std::priority_queue<std::pair<dist_t, labeltype >> result;
if (cur_element_count == 0) return result;

// 初始化入节点
tableint currObj = enterpoint_node_;
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);

// 逐层搜索,直到找到第1层到第0层的入节点
for (int level = maxlevel_; level > 0; level--) {
bool changed = true;
while (changed) {
changed = false;
unsigned int *data;

// 获取当前节点的邻居节点
data = (unsigned int *) get_linklist(currObj, level);
int size = getListCount(data);
metric_hops++;
metric_distance_computations+=size;

// 遍历邻居,找到最近的节点
tableint *datal = (tableint *) (data + 1);
for (int i = 0; i < size; i++) {
tableint cand = datal[i];
if (cand < 0 || cand > max_elements_)
throw std::runtime_error("cand error");
dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);

if (d < curdist) {
curdist = d;
currObj = cand;
changed = true;
}
}
}
}

// 从上面找到的入节点开始,搜索最底层,根据两种不同的策略搜索,
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
bool bare_bone_search = !num_deleted_ && !isIdAllowed;
if (bare_bone_search) {
top_candidates = searchBaseLayerST<true>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
} else {
top_candidates = searchBaseLayerST<false>(
currObj, query_data, std::max(ef_, k), isIdAllowed);
}

// 检查结果集数量是否符合要求,把查询到的结果的距离与label打包后返回
while (top_candidates.size() > k) {
top_candidates.pop();
}
while (top_candidates.size() > 0) {
std::pair<dist_t, tableint> rez = top_candidates.top();
result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
top_candidates.pop();
}
return result;
}

可以看到, searchKnn 函数调用了 searchBaseLayerST 函数。 searchBaseLayerST 是我们之前没阅读的另一个层次搜索函数,所以下一步需要阅读 searchBaseLayerST

searchBaseLayerST

searchBaseLayerST 与我们之前阅读的 searchBaseLayer 有一些区别。 searchBaseLayer 可以执行对任意层的检索,在参数中用 layer 指定 ,但是 searchBaseLayerST 并不支持指定层检索。下面这段代码是函数头,可以看到这个函数是一个模板函数。

1
2
3
4
5
6
7
8
9
// bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance
template <bool bare_bone_search = true, bool collect_metrics = false>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
searchBaseLayerST(
tableint ep_id,
const void *data_point,
size_t ef,
BaseFilterFunctor* isIdAllowed = nullptr,
BaseSearchStopCondition<dist_t>* stop_condition = nullptr) const {

bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance

在注释中说明了参数的意义:如果 bare_bone_search 是 true,那么检索将忽略 删除停止条件删除 是我们之前就阅读过一些相关代码的,主要思想是通过 逻辑删除位 + 删除缓冲区 + 更新 的机制来实现。而 停止条件 是这里新提出的概念,也是原论文中没有提及的概念。此外, collect_metrics 在这里并没有被解释,也是原论文中并没有被提及的概念。带着这两个问题,我们接着往下读。

我们首先还是假定 不开启删除机制、 不开启更新机制、不开启SSE、那么代码将会被简化到下面这样:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
// bare_bone_search means there is no check for deletions and stop condition is ignored in return of extra performance
template <bool bare_bone_search = true, bool collect_metrics = false>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
searchBaseLayerST(
tableint ep_id,
const void *data_point,
size_t ef,
BaseFilterFunctor* isIdAllowed = nullptr,
BaseSearchStopCondition<dist_t>* stop_condition = nullptr) const {

VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;

// 初始化入节点的距离,将入节点添加至候选列表
dist_t lowerBound;
if (bare_bone_search) {
char* ep_data = getDataByInternalId(ep_id);
dist_t dist = fstdistfunc_(data_point, ep_data, dist_func_param_);
lowerBound = dist;
top_candidates.emplace(dist, ep_id);
candidate_set.emplace(-dist, ep_id);
}

// 记录入节点访问记录
visited_array[ep_id] = visited_array_tag;

//
while (!candidate_set.empty()) {
std::pair<dist_t, tableint> current_node_pair = candidate_set.top();
dist_t candidate_dist = -current_node_pair.first;

// 如果距离插入节点最近的候选节点的距离都大于结果队列中的最上界,那说明这些候选节点都可以被剪枝掉,直接终止循环
bool flag_stop_search;
if (bare_bone_search) {
flag_stop_search = candidate_dist > lowerBound;
}
if (flag_stop_search) {
break;
}
candidate_set.pop();

// 取出当前节点的邻居节点
tableint current_node_id = current_node_pair.second;
int *data = (int *) get_linklist0(current_node_id);
size_t size = getListCount((linklistsizeint*)data);
// 遍历邻居节点
for (size_t j = 1; j <= size; j++) {
int candidate_id = *(data + j);
if (!(visited_array[candidate_id] == visited_array_tag)) {
visited_array[candidate_id] = visited_array_tag;

char *currObj1 = (getDataByInternalId(candidate_id));
dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);

// 这里采用了另一种写法,也就是直接计算出或操作的结果存储到寄存器 flag_consider_candidate 中。之前特意说过,这里的条件判断可以通过将大批量命中的分支条件放在前面,来优化流水线分支预测的命中率。这里通过计算+寄存器的写法让可读性更高。
// 下面代码的功能与之前阅读的 searchBaseLayer 一致,都是过滤邻居节点,将比目前下界更近的邻居节点插入到候选队列和结果队列中,并维持结果队列的数量小于等于 ef 。
bool flag_consider_candidate;
flag_consider_candidate = top_candidates.size() < ef || lowerBound > dist;

if (flag_consider_candidate) {
candidate_set.emplace(-dist, candidate_id);

if (bare_bone_search) {
top_candidates.emplace(dist, candidate_id);
}

bool flag_remove_extra = false;
flag_remove_extra = top_candidates.size() > ef;
while (flag_remove_extra) {
tableint id = top_candidates.top().second;
top_candidates.pop();
flag_remove_extra = top_candidates.size() > ef;
}

if (!top_candidates.empty())
lowerBound = top_candidates.top().first;
}
}
}
}

visited_list_pool_->releaseVisitedList(vl);
return top_candidates;
}

注释中我添加了对代码的讲解,可以看到,这段代码与之前的 searchBaseLayer 几乎完全一致,仅有的区别也是在写法上的。所以到此为止,论文中提到的算法我们都阅读完了。

剩下的就是 hnswlib 对于 HNSW 索引的工业实现,通过 假删除 + 更新 的机制实现了增删改查。这些机制我们日后来读。

PS:总结一下剩下的坑

  1. SSE 优化的讲解
  2. 删除、更新机制在工业上的实现
  3. collect_metrics 参数的含义