随机投影采样近似距离ADSampling算法论文阅读笔记-High-Dimensional Approximate Nearest Neighbor Search with Reliable and Efficient Distance Comparison Operations

论文简介

High-Dimensional Approximate Nearest Neighbor Search with Reliable and Efficient Distance Comparison Operations 是一片发表于 SIGMOD 2023 的论文。

论文贡献

1、论文通过 “是否更新候选集“ 为指标,将 相似度计算DCO 分为了两类,能够更新候选集的为 positive ,否则为 negative 。通过实验论证了 HNSW 、 IVF 、 IVF-PQ 中大量 DCO 都是无效的 。

2、论文提出了动态随机投影 ADSampling 用于替换传统的 FDScanning 也就是 DCO ,并且几乎没有精度损失,论文将其称为 AKNN+ 。

3、在 HNSW+ 和 IVF+ 上 ,论文通过一个缓存友好的存储布局进一步优化了性能,将其命名为 HNSW++ 和 IVF++ 。

主要内容

1、AKNN Algorithms and Their DCOs

DCOs 指的就是相似度计算消耗。这章节比较简单,主要内容可以用下面的图来概括:

alt text

简单来说,在大部分索引中,大部分的相似度运算都对迭代无效

2、 ADSAMPLING 随机采样距离估算

因为第一章节提出了这个现象,所以论文认为可以通过动态地控制相似度计算代价,来减少 negative 的运算代价,同时尽可能地不影响 positive 部分的运算精度。

论文从简单的随机采样距离计算的误差精度理论开始,描述了 ADSAMPLING 算法的设计思路。

2.1 从 随机投影 到 维度采样 的线性代数基本原理

2.1.1 关于随机降维投影

为了提高距离比较操作的效率,一种自然的想法是对对象执行随机投影,在更少的维度上计算距离,从而降低高维度的计算消耗 ( DCOs 是一个时间复杂度 O(D) 的操作,D 是向量维度)。具体来说,是将一个高维向量 x 与一个 d 行 D 列的随机正交矩阵相乘 (d < D 才能取得降维的效果),然后基于得到的低维向量计算近似距离。

然而,一旦对象被投影,其对应的近似距离就会有一个固定的分辨率,缺乏灵活性。也就是论文中的 Lemma 3.1

alt text

Lemma 3.1来自:High-Dimensional Probability: An Introduction with Applications in Data Science (2018.)

其中 ||X|| 是向量 x 的欧几里得范数(norm):

alt text

因此这个绝对值的含义是:

alt text

低维空间投影后的向量的欧几里得范数 经过与压缩维度成比例的放大后原向量欧几里得范数 的差值。也就是 投影 操作造成的失真度

这个失真度是一个与 原始向量 x投影矩阵P 相关的值,只要给定两个固定的 xP ,就能得到一个固定的失真度

但是为了表示任何 xP 失真度的取值范围,论文《High-Dimensional Probability: An Introduction with Applications in Data Science》 基于失真度限定参数 ε ,给出了任何 xP ,能够满足这个失真度的概率。

这个失真度在公式的后半部分,用 原始向量x欧几里得范数 乘一个 可变比例 ε 进行了描述,也就是下面这张图:

alt text

我们简单地推导下边界条件,当 ε 趋近 正无穷 时,

alt text

这个式子百分百正确,理由是 不等式右侧趋近 正无穷,而左侧是一个固定正实数。

同时,概率不等式的右侧

alt text

也趋近于1,也就是说左侧式子为 true 的概率趋近于 1 。

到这里我们其实就可以理解,为什么说固定的投影矩阵只能提供一个固定的精度,而不能动态地适配不同的精度需求。理由其实就是当向量被投影到一个固定维度的空间上时,保障的最小精度损失是一个概率式,而不是一个固定值。我们只能说对于某个 P 和 任何 x ,有多少概率能提供多大的精度误差,但是不能说这是百分百的。

简单理解了下 Lemma 3.1 ,我们接下来需要关注,Lemma 3.1 到底有什么用?

我们看看向量的相似度计算公式:

alt text

可以看到,投影后的相似度其实就是一个 ||Px||,也就是 Lemma 3.1 里所谓的投影向量,而 ||Px|| 与 ||x|| 的差的绝对值,其实就是投影距离与精准距离之间的差值,也就是我们最关注的 低维距离与原始距离的差值 ,即 近似操作造成的距离误差

2.1.2 先降维和后降维

论文的目标是能够提供动态的近似距离计算精度,因此一个简单的想法就是将原始向量 o 都乘一个相同的投影向量 P’ (D 行 D 列),得到不降维的投影向量,然后根据不同的精度需求,随机从 投影向量 中抽取不同数量的维度,进行距离计算。维度越高,精度越高。

维度越高,精度越高 这个理论,我们可以自己带入到 Lemma 3.1 中来思考下。 这里就不展开了。

因此论文用简单的矩阵乘法介绍了 对原始数据不降维投影 与 随机降维投影 的关系:
alt text
在这里, P’ 是一个 D 行 D 列 的正交投影矩阵,而

alt text
这个公式的含义是随机抽取矩阵 P’d 行作为一个新的矩阵。

公式
alt text
的含义是先把 x 投影到 D*D 空间上再随机抽取 d 行,也就是随机抽取 d 维的投影结果。

这是一个简单的矩阵乘法,所以两者是相等的。因此我们就通过一个 D*D 的投影运算后,原始向量 o 就可以同时满足多种不同的精度需求,精度需求越高,我们抽取的维度越多即可

论文把这种相似度计算定义为:

alt text

2.2 基于假设检验的采样维度增加机制

有了 2.1 章节的近似距离公式,我们就可以同时满足不同的查询精度。

回忆一下第一章节,大部分 DCOs 都是无效的,但怎么判断到底是不是到了精细搜索阶段了呢?

论文的设计很简单——假设检验:

alt text

论文给定了一个距离阈值 r ,如果 距离小于 r ,就认为两个向量离得很近了,需要精细化搜索;否则就直接认为当前 o 是一个 negative object ,可以被忽略。

这里我们直接用通一千问的解读,就不过多赘述了:

alt text

虽然论文的假设检验机制很有趣,但在 graph based 这里还有一些别的做法,可以阅读我之前的博客:

FINGER: Fast Inference for Graph-based Approximate Nearest Neighbor Search (WWW 2023)

2.3 实现伪代码

alt text

伪代码基本与上面的设计一致,值得一说的是,在时间复杂度的分析上,论文采用了 采样维度期望 的思路进行分析,也就是基于平均失败概率进行的分析。我们这里就不过多讲解了。感兴趣的可以阅读下原文:

alt text

在保证一定准确性的前提下, ADSampling 方法可以显著减少对于负对象(negative objects)所需的计算复杂度,从线性复杂度 O(D) 降低到对数复杂度 O(logD) 。这一结果强调了随着 ϵ0 的增加,虽然时间复杂度会相应增加,但失败概率会以二次指数形式下降。

需要额外提一嘴的是,在需要增加采样维度时,之前采样维度的距离计算结果是可以保留的,只需要再计算新增维度的差值平方然后加和即可,所以总体的时间复杂度只和最终的需求维度 d 相关。

3、 AKNN+ 使用 ADSampling 加速向量检索

简而言之,就是直接用 ADSampling 替换了 DCOs 运算符,工作量不是很大:

alt text

我把 hnsw 的实现函数贴在了下面,

adsampling:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
/*
The file is the core of the ADSampling algorithm.
We have included detailed comments in the function dist_comp.
Note that in the whole algorithm we do not calculate the square root of the distances.
*/

#include <cmath>
#include <limits>
#include <queue>
#include <vector>
#include <iostream>

using namespace std;

namespace adsampling{

unsigned int D = 960; // The dimensionality of the dataset.
float epsilon0 = 2.1; // epsilon0 - by default 2.1, recommended in [1.0,4.0], valid in in [0, +\infty)
unsigned int delta_d = 32; // dimension sampling for every delta_d dimensions.

long double distance_time = 0;
unsigned long long tot_dimension = 0;
unsigned long long tot_dist_calculation = 0;
unsigned long long tot_full_dist = 0;

void clear(){
distance_time = 0;
tot_dimension = 0;
tot_dist_calculation = 0;
tot_full_dist = 0;
}

// The hypothesis testing checks whether \sqrt{D/d} dis' > (1 + epsilon0 / \sqrt{d}) * r.
// We equivalently check whether dis' > \sqrt{d/D} * (1 + epsilon0 / \sqrt{d}) * r.
inline float ratio(const int &D, const int &i){
if(i == D)return 1.0;
return 1.0 * i / D * (1.0 + epsilon0 / std::sqrt(i)) * (1.0 + epsilon0 / std::sqrt(i));
}

/*
float dist_comp(const float&, const void *, const void *, float, int) is a generic function for DCOs.

When D, epsilon_0 and delta_d can be pre-determined, it is highly suggested to define them as constexpr and provide dataset-specific functions.
*/
float dist_comp(const float& dis, const void *data, const void *query,
float res = 0, int i = 0){
// If the algorithm starts a non-zero dimensionality (i.e., the case of IVF++), we conduct the hypothesis testing immediately.
if(i && res >= dis * ratio(D, i)){
return -res * D / i;
}
float * q = (float *) query;
float * d = (float *) data;

while(i < D){
// It continues to sample additional delta_d dimensions.
int check = std::min(delta_d, D-i);
i += check;
for(int j = 1;j<=check;j++){
float t = *d - *q;
d ++;
q ++;
res += t * t;
}
// Hypothesis tesing
if(res >= dis * ratio(D, i)){

// If the null hypothesis is reject, we return the approximate distance.
// We return -dis' to indicate that it's a negative object.
return -res * D / i;
}
}

// We return the exact distance when we have sampled all the dimensions.
return res;
}

};

float sqr_dist(float* a, float* b, int D){
float ret = 0;
for(int i=0;i!=D;i++){
float tmp = (*a - *b);
ret += tmp * tmp;
a++;
b++;
}
return ret;
}

HNSW+ 查询:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
//max heap
std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(void *query_data, size_t k, int adaptive=0) const {

std::priority_queue<std::pair<dist_t, labeltype >> result;
if (cur_element_count == 0) return result;

tableint currObj = enterpoint_node_;

dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);

adsampling::tot_dist_calculation ++;
for (int level = maxlevel_; level > 0; level--) {

bool changed = true;
while (changed) {

changed = false;
unsigned int *data;

data = (unsigned int *) get_linklist(currObj, level);
int size = getListCount(data);
metric_hops++;
metric_distance_computations+=size;

tableint *datal = (tableint *) (data + 1);
for (int i = 0; i < size; i++) {
tableint cand = datal[i];
if (cand < 0 || cand > max_elements_)
throw std::runtime_error("cand error");
adsampling::tot_dist_calculation ++;
if(adaptive){

// 计算采样近似距离
dist_t d = adsampling::dist_comp(curdist, getDataByInternalId(cand), query_data, 0, 0);

if(d > 0){
curdist = d;
currObj = cand;
changed = true;
}
}
else {

dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_);

adsampling::tot_full_dist ++;
if (d < curdist) {
curdist = d;
currObj = cand;
changed = true;
}
}
}
}
}
//max heap
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>> top_candidates;

if(adaptive == 1) top_candidates=searchBaseLayerADstar<true,true>(currObj, query_data, std::max(ef_, k), k);
else if(adaptive == 2) top_candidates=searchBaseLayerAD<true,true>(currObj, query_data, std::max(ef_, k));
else top_candidates=searchBaseLayerST<false,true>(currObj, query_data, std::max(ef_, k));

while (top_candidates.size() > k) {
top_candidates.pop();
}
while (top_candidates.size() > 0) {
std::pair<dist_t, tableint> rez = top_candidates.top();
result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
top_candidates.pop();
}
return result;
};

4、 AKNN++ 额外的适配工作

4.1 HNSW++: Towards More Approximation

alt text

我们这里就不继续阅读了,核心思想很简单:

1、加大 candidate 的计算错误率来减少计算消耗。
2、结果队列使用全精度的距离,候选集里使用估算距离。

HNSW++ 是有可能降低精度的,但是概率相对较小,理由是

1、被过滤掉的数据几乎百分比劣于 R1 中的最差值
2、邻域图的贪心搜索算法有多维导航的可能性,即使一个 route 被 ban 掉了,还存在其它路径能够导航至目标节点

alt text

实现代码我贴在了下面:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
template <bool has_deletions, bool collect_metrics=false>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>>
searchBaseLayerADstar(tableint ep_id, const void *data_point, size_t ef, size_t k) const {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;

// answers - the KNN set R1
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>> answers;
// top_candidates - the result set R2
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>> top_candidates;
// candidate_set - the search set S
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;

dist_t lowerBound;
dist_t lowerBoundcan;
// Insert the entry point to the result and search set with its exact distance as a key.
if (!has_deletions || !isMarkedDeleted(ep_id)) {

dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);

adsampling::tot_dist_calculation++;
adsampling::tot_full_dist ++;
lowerBound = dist;
lowerBoundcan = dist;
answers.emplace(dist, ep_id);
top_candidates.emplace(dist, ep_id);
candidate_set.emplace(-dist, ep_id);
}
else {
lowerBound = std::numeric_limits<dist_t>::max();
lowerBoundcan = std::numeric_limits<dist_t>::max();
candidate_set.emplace(-lowerBound, ep_id);
}

visited_array[ep_id] = visited_array_tag;
int cnt_visit = 0;
// Iteratively generate candidates.
while (!candidate_set.empty()) {
std::pair<dist_t, tableint> current_node_pair = candidate_set.top();

// When the smallest object in S has its distance larger than the largest in R2, terminate the algorithm.
if ((-current_node_pair.first) > top_candidates.top().first && (top_candidates.size() == ef || has_deletions == false)) {
break;
}
candidate_set.pop();

// Fetch the smallest object in S.
tableint current_node_id = current_node_pair.second;
int *data = (int *) get_linklist0(current_node_id);
size_t size = getListCount((linklistsizeint*)data);
if(collect_metrics){
metric_hops++;
metric_distance_computations+=size;
}


// Enumerate all the neighbors of the object and view them as candidates of KNNs.
for (size_t j = 1; j <= size; j++) {
int candidate_id = *(data + j);
if (!(visited_array[candidate_id] == visited_array_tag)) {
cnt_visit ++;
visited_array[candidate_id] = visited_array_tag;


// If the KNN set is not full, then calculate the exact distance. (i.e., assume the distance threshold to be infinity)
if (answers.size() < k){
char *currObj1 = (getDataByInternalId(candidate_id));

dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_);

adsampling::tot_full_dist ++;
if (!has_deletions || !isMarkedDeleted(candidate_id)){
candidate_set.emplace(-dist, candidate_id);
top_candidates.emplace(dist, candidate_id);
answers.emplace(dist, candidate_id);
}
if (!answers.empty())
lowerBound = answers.top().first;
if (!top_candidates.empty())
lowerBoundcan = top_candidates.top().first;
}
// Otherwise, conduct DCO with ADSampling wrt the Kth NN.
else {
char *currObj1 = (getDataByInternalId(candidate_id));

dist_t dist = adsampling::dist_comp(lowerBound, currObj1, data_point, 0, 0);

// If it's a positive object, then include it in R1, R2, S.
if(dist >= 0){
candidate_set.emplace(-dist, candidate_id);
if(!has_deletions || !isMarkedDeleted(candidate_id)){
top_candidates.emplace(dist, candidate_id);
answers.emplace(dist, candidate_id);
}
if(top_candidates.size() > ef)
top_candidates.pop();
if(answers.size() > k)
answers.pop();

if (!answers.empty())
lowerBound = answers.top().first;
if (!top_candidates.empty())
lowerBoundcan = top_candidates.top().first;
}
// If it's a negative object, then update R2, S with the approximate distance.
else{
if(top_candidates.size() < ef || lowerBoundcan > -dist){
top_candidates.emplace(-dist, candidate_id);
candidate_set.emplace(dist, candidate_id);
}
if(top_candidates.size() > ef){
top_candidates.pop();
}
if (!top_candidates.empty())
lowerBoundcan = top_candidates.top().first;
}
}
}
}
}
adsampling::tot_dist_calculation += cnt_visit;
visited_list_pool_->releaseVisitedList(vl);
return answers;
}

4.2 IVF++: Towards Cache Friendliness

alt text

入上图,论文做了一个类似列存的优化,来增加局部性,减少 cache miss 。

写在后面

1、论文的优化思路很有趣:通过某些维度的距离来判断两个向量是不是很近,如果很近则计算更精确的距离,直到得到了全精度距离。这个思路把欧几里得距离计算的最小计算单元从整体中拆掉,让距离非常远的向量只需要几个维度的运算就能被排除掉。

2、论文通过一个随机的投影矩阵,论证了距离错误率的上下界,是很有趣的理论证明。

3、论文代码仓库为 https://github.com/gaoj0017/ADSampling ,有时间了可以考虑详细读一下。