寻找K近邻及其实现
classification
machine-learning
matlab
5
0

我正在使用具有欧几里得距离的KNN对简单数据进行分类。我已经看到了一个示例,该示例说明了如何使用MATLAB knnsearch函数完成此knnsearch ,如下所示:

load fisheriris 
x = meas(:,3:4);
gscatter(x(:,1),x(:,2),species)
newpoint = [5 1.45];
[n,d] = knnsearch(x,newpoint,'k',10);
line(x(n,1),x(n,2),'color',[.5 .5 .5],'marker','o','linestyle','none','markersize',10)

上面的代码获取一个新点,即[5 1.45]并找到最接近该新点的10个值。谁能给我展示一个MATLAB算法,并详细说明knnsearch函数的作用?还有其他方法吗?

参考资料:
Stack Overflow
收藏
评论
共 1 个回答
高赞 时间 活跃

K最近邻(KNN)算法的基础是,您拥有一个由N行和M列组成的数据矩阵,其中N是我们拥有的数据点的数量,而M是每个数据点的维数。例如,如果我们将笛卡尔坐标放置在数据矩阵内,则通常是N x 2N x 3矩阵。使用此数据矩阵,您可以提供一个查询点,并在该数据矩阵中搜索与该查询点最接近的k个点。

我们通常使用查询与数据矩阵中其余点之间的欧几里得距离来计算我们的距离。但是,也可以使用其他距离,例如L1或城市街区/曼哈顿距离。执行此操作后,您将拥有N欧几里得距离或曼哈顿距离,它们代表查询与数据集中每个对应点之间的距离。一旦你找到了这些,你只需搜索k由按升序排列,检索这些排序的距离最近的点到查询k有你的数据集和查询之间的最小距离点。

假设您的数据矩阵存储在x ,并且newpoint是一个样本点,其中有M列(即1 x M ),这是按照点形式进行的一般过程:

  1. 找到newpointx每个点之间的欧几里得距离或曼哈顿距离。
  2. 将这些距离按升序排序。
  3. 返回x中最接近newpointk数据点。

让我们慢慢地做每一步。


步骤1

某人可能这样做的一种方式可能是在for循环中,如下所示:

N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
    dists(idx) = sqrt(sum((x(idx,:) - newpoint).^2));
end

如果要实现曼哈顿距离,则只需:

N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
    dists(idx) = sum(abs(x(idx,:) - newpoint));
end

dists是一个N元素向量,其中包含x每个数据点与newpoint之间的距离。我们在x newpoint和数据点之间进行newpoint元素减法,将差平方,然后将它们全部sum 。然后,该和为平方根,从而完成欧几里得距离。对于曼哈顿距离,您将逐个元素相减,取绝对值,然后将所有分量求和。这可能是最容易理解的实现,但可能效率最低...尤其是对于较大的数据集和较大的数据维。

另一个可能的解决方案是复制newpoint并使该矩阵的大小与x相同,然后对该矩阵进行逐元素减法,然后对每一行的所有列求和,并求平方根。因此,我们可以执行以下操作:

N = size(x, 1);
dists = sqrt(sum((x - repmat(newpoint, N, 1)).^2, 2));

对于曼哈顿距离,您可以:

N = size(x, 1);
dists = sum(abs(x - repmat(newpoint, N, 1)), 2);

repmat接受矩阵或向量,并在给定方向上重复一定次数。在我们的例子中,我们要获取新newpoint向量,并将其相叠N次以创建N x M矩阵,其中每行的长度为M元素。我们将这两个矩阵相减,然后对每个分量求平方。完成此操作后,我们将对每一行的所有列sum ,最后取所有结果的平方根。对于曼哈顿距离,我们进行减法,取绝对值,然后求和。

但是,我认为最有效的方法是使用bsxfun 。这实质上完成了我们通过单个函数调用在后台进行的复制。因此,代码就是这样:

dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));

对我来说,这看起来更干净,而且很重要。对于曼哈顿距离,您可以:

dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);

第2步

现在我们有了距离,我们只需对它们进行排序。我们可以使用sort对距离进行排序:

[d,ind] = sort(dists);

d将包含按升序排序的距离,而ind告诉您未排序数组中每个值在排序结果中出现的位置。我们需要使用ind ,提取此向量的前k元素,然后使用ind索引到我们的x数据矩阵中以返回最接近newpoint那些点。

步骤三

最后一步是现在返回最接近newpoint k数据点。我们可以很简单地做到这一点:

ind_closest = ind(1:k);
x_closest = x(ind_closest,:);

ind_closest应该在原始数据矩阵x中包含最接近newpoint 。具体来说, ind_closest包含需要从x进行采样的 ,以获得与newpoint最接近的点。 x_closest将包含这些实际数据点。


为了您的复制和粘贴乐趣,代码如下所示:

dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
%// Or do this for Manhattan
% dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);

在您的示例中运行,让我们来看一下我们的代码:

load fisheriris 
x = meas(:,3:4);
newpoint = [5 1.45];
k = 10;

%// Use Euclidean
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);

通过检查ind_closestx_closest ,我们得到的是:

>> ind_closest

ind_closest =

   120
    53
    73
   134
    84
    77
    78
    51
    64
    87

>> x_closest

x_closest =

    5.0000    1.5000
    4.9000    1.5000
    4.9000    1.5000
    5.1000    1.5000
    5.1000    1.6000
    4.8000    1.4000
    5.0000    1.7000
    4.7000    1.4000
    4.7000    1.4000
    4.7000    1.5000

如果您运行knnsearch ,您将看到变量nind_closest匹配。但是,变量d返回newpoint点到每个点x距离 ,而不是实际数据点本身。如果需要实际距离,只需在我编写的代码之后执行以下操作:

dist_sorted = d(1:k);

请注意,以上答案在N示例中仅使用一个查询点。 KNN通常经常在多个示例上同时使用。假设我们有要在KNN中测试的Q查询点。这将导致一个kx M x Q矩阵,其中对于每个示例或每个切片,我们将返回维度为Mk最近点。或者,我们可以返回k最接近点的ID ,从而得到一个Q xk矩阵。让我们计算两者。

天真的方法是将上述代码循环应用,并遍历每个示例。

在分配Q xk矩阵并应用基于bsxfun的方法来将输出矩阵的每一行设置为数据集中的k最近点时, bsxfun的方法会起作用,就像以前一样,在这里我们将使用Fisher Iris数据集。我们还将保持与上一个示例相同的尺寸,并且我将使用四个示例,因此Q = 4M = 2

%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];

%// Define k and the output matrices
Q = size(newpoints, 1);
M = size(x, 2);
k = 10;
x_closest = zeros(k, M, Q);
ind_closest = zeros(Q, k);

%// Loop through each point and do logic as seen above:
for ii = 1 : Q
    %// Get the point
    newpoint = newpoints(ii, :);

    %// Use Euclidean
    dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
    [d,ind] = sort(dists);

    %// New - Output the IDs of the match as well as the points themselves
    ind_closest(ii, :) = ind(1 : k).';
    x_closest(:, :, ii) = x(ind_closest(ii, :), :);
end

尽管这很好,但我们可以做得更好。有一种方法可以有效地计算两组向量之间的平方欧几里德距离。如果您想在曼哈顿进行此操作,我将其保留为练习。咨询该博客 ,假设A是一个Q1 x M矩阵,其中每一行是具有Q1点的维数M点,而B是一个Q2 x M矩阵,其中每一行也是具有Q2点的维数M点。计算的距离矩阵D(i, j)其中在所述行元素i和列j表示行之间的距离iA和行jB使用以下基质制剂:

nA = sum(A.^2, 2); %// Sum of squares for each row of A
nB = sum(B.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.') - 2*A*B.'; %// Compute distance matrix
D = sqrt(D); %// Compute square root to complete calculation

因此,如果我们让A为查询点矩阵,而B为由原始数据组成的数据集,则可以通过分别对每行进行排序并确定每行的k最小的位置来确定k最接近的点。我们还可以另外使用它来检索实际点本身。

因此:

%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];

%// Define k and other variables
k = 10;
Q = size(newpoints, 1);
M = size(x, 2);

nA = sum(newpoints.^2, 2); %// Sum of squares for each row of A
nB = sum(x.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.') - 2*newpoints*x.'; %// Compute distance matrix
D = sqrt(D); %// Compute square root to complete calculation 

%// Sort the distances 
[d, ind] = sort(D, 2);

%// Get the indices of the closest distances
ind_closest = ind(:, 1:k);

%// Also get the nearest points
x_closest = permute(reshape(x(ind_closest(:), :).', M, k, []), [2 1 3]);

我们看到,用于计算距离矩阵的逻辑是相同的,但是一些变量已更改为适合示例。我们还使用sort的两个输入版本对每一行进行独立sort ,因此ind将包含每行的ID,而d将包含相应的距离。然后,通过将矩阵简单地截断为k列,我们可以找出最接近每个查询点的索引。然后,我们使用permutereshape来确定相关的最接近点。我们首先使用所有最接近的索引,然后创建一个点矩阵,该矩阵将所有ID相互叠加,从而得到Q * kx M矩阵。使用reshapepermute可以使我们创建3D矩阵,使其变成我们指定的kx M x Q矩阵。如果您想自己获得实际的距离,我们可以索引d并获取我们需要的东西。为此,您将需要使用sub2ind来获得线性索引,以便我们可以一次性拍摄到dind_closest的值已经为我们提供了需要访问哪些列。我们需要访问的行仅是1, k次,2, k次, k类推,直到Q为止。 k是我们要返回的点数:

row_indices = repmat((1:Q).', 1, k);
linear_ind = sub2ind(size(d), row_indices, ind_closest);
dist_sorted = D(linear_ind);

当我们为上述查询点运行以上代码时,这些是我们获得的索引,点和距离:

>> ind_closest

ind_closest =

   120   134    53    73    84    77    78    51    64    87
   123   119   118   106   132   108   131   136   126   110
   107    62    86   122    71   127   139   115    60    52
    99    65    58    94    60    61    80    44    54    72

>> x_closest

x_closest(:,:,1) =

    5.0000    1.5000
    6.7000    2.0000
    4.5000    1.7000
    3.0000    1.1000
    5.1000    1.5000
    6.9000    2.3000
    4.2000    1.5000
    3.6000    1.3000
    4.9000    1.5000
    6.7000    2.2000


x_closest(:,:,2) =

    4.5000    1.6000
    3.3000    1.0000
    4.9000    1.5000
    6.6000    2.1000
    4.9000    2.0000
    3.3000    1.0000
    5.1000    1.6000
    6.4000    2.0000
    4.8000    1.8000
    3.9000    1.4000


x_closest(:,:,3) =

    4.8000    1.4000
    6.3000    1.8000
    4.8000    1.8000
    3.5000    1.0000
    5.0000    1.7000
    6.1000    1.9000
    4.8000    1.8000
    3.5000    1.0000
    4.7000    1.4000
    6.1000    2.3000


x_closest(:,:,4) =

    5.1000    2.4000
    1.6000    0.6000
    4.7000    1.4000
    6.0000    1.8000
    3.9000    1.4000
    4.0000    1.3000
    4.7000    1.5000
    6.1000    2.5000
    4.5000    1.5000
    4.0000    1.3000

>> dist_sorted

dist_sorted =

    0.0500    0.1118    0.1118    0.1118    0.1803    0.2062    0.2500    0.3041    0.3041    0.3041
    0.3000    0.3162    0.3606    0.4123    0.6000    0.7280    0.9055    0.9487    1.0198    1.0296
    0.9434    1.0198    1.0296    1.0296    1.0630    1.0630    1.0630    1.1045    1.1045    1.1180
    2.6000    2.7203    2.8178    2.8178    2.8320    2.9155    2.9155    2.9275    2.9732    2.9732

将此与knnsearch进行比较,您可以为第二个参数指定点矩阵,其中每行都是一个查询点,您将看到此实现和knnsearch之间的索引和排序距离匹配。


希望这对您有所帮助。祝好运!

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号