From a4d82e14ccc271542668ff3371f7f142db122538 Mon Sep 17 00:00:00 2001 From: Ben Karsin Date: Wed, 29 Apr 2026 17:05:08 -0700 Subject: [PATCH 01/14] Initial multi-warp greedy search optimization (cherry picked from commit 14e36b3feae134be2a71fb8304f0dec20a0ab37c) --- .../neighbors/detail/vamana/greedy_search.cuh | 190 ++++++++---------- .../detail/vamana/priority_queue.cuh | 89 +++++++- .../neighbors/detail/vamana/vamana_build.cuh | 43 ++-- .../detail/vamana/vamana_structs.cuh | 49 ++++- 4 files changed, 240 insertions(+), 131 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/greedy_search.cuh b/cpp/src/neighbors/detail/vamana/greedy_search.cuh index 4e71c1189c..0a017e3cb7 100644 --- a/cpp/src/neighbors/detail/vamana/greedy_search.cuh +++ b/cpp/src/neighbors/detail/vamana/greedy_search.cuh @@ -1,11 +1,11 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once -#include +#include #include "macros.cuh" #include "priority_queue.cuh" @@ -74,6 +74,9 @@ __global__ void SortPairsKernel(void* query_list_ptr, int num_queries, int topk) /******************************************************************************************** GPU kernel to perform a batched GreedySearch on a graph. Since this is used for Vamana construction, the entire visited list is kept and stored within the query_list. + Uses 128 threads per block (4 warps), each warp processes one query independently + with per-warp scratch space to avoid block synchronization overhead. + Input - graph with edge lists, dataset vectors, query_list_ptr with the ids of dataset vectors to be searched. All inputs, including dataset, must be device accessible. @@ -85,7 +88,7 @@ template , raft::memory_type::host>> -__global__ void GreedySearchKernel( +__global__ __launch_bounds__(128, 12) void GreedySearchKernel( raft::device_matrix_view graph, raft::mdspan, raft::row_major, Accessor> dataset, void* query_list_ptr, @@ -96,186 +99,159 @@ __global__ void GreedySearchKernel( int max_queue_size, Node* topk_pq_mem) { - int n = dataset.extent(0); + const int warpIdx = threadIdx.x / 32; + const int laneId = threadIdx.x % 32; + int dim = dataset.extent(1); int degree = graph.extent(1); QueryCandidates* query_list = static_cast*>(query_list_ptr); - static __shared__ int topk_q_size; - static __shared__ int cand_q_size; - static __shared__ accT cur_k_max; - static __shared__ int k_max_idx; - - static __shared__ Point s_query; - union ShmemLayout { - // All blocksort sizes have same alignment (16) T coords; - int neighborhood_arr; + IdxT neighborhood_arr; DistPair candidate_queue; }; + int align_padding = raft::alignTo(dim, 16) - dim; - int align_padding = (((dim - 1) / alignof(ShmemLayout)) + 1) * alignof(ShmemLayout) - dim; - - // Dynamic shared memory used for blocksort, temp vector storage, and neighborhood list extern __shared__ __align__(alignof(ShmemLayout)) char smem[]; - size_t smem_offset = 0; - - T* s_coords = reinterpret_cast(&smem[smem_offset]); - smem_offset += (dim + align_padding) * sizeof(T); - - Node* topk_pq = &topk_pq_mem[blockIdx.x * topk]; - - int* neighbor_array = reinterpret_cast(&smem[smem_offset]); - smem_offset += degree * sizeof(int); + // Per-warp shared memory layout: coords, neighbor_array, candidate_queue + const int coords_size = (dim + align_padding) * sizeof(T); + const int neighbor_size = degree * sizeof(IdxT); + const int queue_size_bytes = max_queue_size * sizeof(DistPair); + const int per_warp_size = (coords_size + neighbor_size + queue_size_bytes + 15) & ~15; + char* warp_smem = &smem[warpIdx * per_warp_size]; + T* s_coords = + reinterpret_cast(warp_smem); + IdxT* neighbor_array = reinterpret_cast(warp_smem + coords_size); DistPair* candidate_queue_smem = - reinterpret_cast*>(&smem[smem_offset]); + reinterpret_cast*>(warp_smem + coords_size + neighbor_size); - s_query.coords = s_coords; + static __shared__ int topk_q_size[4]; + static __shared__ int cand_q_size[4]; + static __shared__ accT cur_k_max[4]; + static __shared__ int k_max_idx[4]; + static __shared__ int num_neighbors[4]; + + Point s_query; s_query.Dim = dim; + s_query.coords = s_coords; PriorityQueue heap_queue; - - if (threadIdx.x == 0) { - heap_queue.initialize(candidate_queue_smem, max_queue_size, &cand_q_size); + if (laneId == 0) { + heap_queue.initialize(candidate_queue_smem, max_queue_size, &cand_q_size[warpIdx]); } - static __shared__ int num_neighbors; - - for (int i = blockIdx.x; i < num_queries; i += gridDim.x) { - __syncthreads(); + Node* topk_pq = &topk_pq_mem[(blockIdx.x * 4 + warpIdx) * topk]; + const T* vec_ptr = &dataset(0, 0); - // resetting visited list - query_list[i].reset(); + for (int i = blockIdx.x * 4 + warpIdx; i < num_queries; i += gridDim.x * 4) { + query_list[i].reset_warp(laneId); - // storing the current query vector into shared memory - update_shared_point(&s_query, &dataset(0, 0), query_list[i].queryId, dim); + update_shared_point_warp(&s_query, vec_ptr, query_list[i].queryId, dim, laneId); - if (threadIdx.x == 0) { - topk_q_size = 0; - cand_q_size = 0; - s_query.id = query_list[i].queryId; - cur_k_max = 0; - k_max_idx = 0; + if (laneId == 0) { + topk_q_size[warpIdx] = 0; + cand_q_size[warpIdx] = 0; + s_query.id = query_list[i].queryId; + cur_k_max[warpIdx] = 0; + k_max_idx[warpIdx] = 0; heap_queue.reset(); } - __syncthreads(); - - Point* query_vec; + Point* query_vec = &s_query; + query_vec->Dim = dim; + query_vec->coords = s_coords; + accT medoid_dist = + dist(query_vec->coords, &vec_ptr[(size_t)medoid_id * (size_t)dim], dim, metric); - // Just start from medoid every time, rather than multiple set_ups - query_vec = &s_query; - query_vec->Dim = dim; - const T* medoid = &dataset((size_t)medoid_id, 0); - accT medoid_dist = dist(query_vec->coords, medoid, dim, metric); - - if (threadIdx.x == 0) { heap_queue.insert_back(medoid_dist, medoid_id); } - __syncthreads(); - - while (cand_q_size != 0) { - __syncthreads(); + if (laneId == 0) { heap_queue.insert_back(medoid_dist, medoid_id); } + while (cand_q_size[warpIdx] != 0) { int cand_num; accT cur_distance; - if (threadIdx.x == 0) { - Node test_cand; + if (laneId == 0) { DistPair test_cand_out = heap_queue.pop(); - test_cand.distance = test_cand_out.dist; - test_cand.nodeid = test_cand_out.idx; - cand_num = test_cand.nodeid; + cand_num = test_cand_out.idx; cur_distance = test_cand_out.dist; } - __syncthreads(); - - cand_num = raft::shfl(cand_num, 0); - - __syncthreads(); - - if (query_list[i].check_visited(cand_num, cur_distance)) { continue; } - + cand_num = raft::shfl(cand_num, 0); cur_distance = raft::shfl(cur_distance, 0); - // stop condition for the graph traversal process + if (query_list[i].check_visited_warp(cand_num, cur_distance, laneId)) { continue; } + bool done = false; bool pass_flag = false; - if (topk_q_size == topk) { - // Check the current node with the worst candidate in top-k queue - if (threadIdx.x == 0) { - if (cur_k_max <= cur_distance) { done = true; } + if (topk_q_size[warpIdx] == topk) { + if (laneId == 0) { + if (cur_k_max[warpIdx] <= cur_distance) { done = true; } } - done = raft::shfl(done, 0); if (done) { if (query_list[i].size < topk) { pass_flag = true; - } - - else if (query_list[i].size >= topk) { + } else if (query_list[i].size >= topk) { break; } } } - // The current node is closer to the query vector than the worst candidate in top-K queue, so - // enquee the current node in top-k queue Node new_cand; new_cand.distance = cur_distance; new_cand.nodeid = cand_num; - if (check_duplicate(topk_pq, topk_q_size, new_cand) == false) { + if (check_duplicate_warp(topk_pq, topk_q_size[warpIdx], new_cand, laneId) == false) { if (!pass_flag) { - parallel_pq_max_enqueue( - topk_pq, &topk_q_size, topk, new_cand, &cur_k_max, &k_max_idx); - - __syncthreads(); + parallel_pq_max_enqueue_warp(topk_pq, + &topk_q_size[warpIdx], + topk, + new_cand, + &cur_k_max[warpIdx], + &k_max_idx[warpIdx], + laneId); } } else { - // already visited continue; } - num_neighbors = degree; - __syncthreads(); + num_neighbors[warpIdx] = degree; - for (size_t j = threadIdx.x; j < degree; j += blockDim.x) { - // Load neighbors from the graph array and store them in neighbor array (shared memory) + for (size_t j = laneId; j < degree; j += 32) { neighbor_array[j] = graph(cand_num, j); if (neighbor_array[j] == raft::upper_bound()) - atomicMin(&num_neighbors, (int)j); // warp-wide min to find the number of neighbors + atomicMin(&num_neighbors[warpIdx], (int)j); } - // computing distances between the query vector and neighbor vectors then enqueue in priority - // queue. - enqueue_all_neighbors( - num_neighbors, query_vec, &dataset(0, 0), neighbor_array, heap_queue, dim, metric); - - __syncthreads(); - - } // End cand_q_size != 0 loop + enqueue_all_neighbors_warp(num_neighbors[warpIdx], + query_vec, + vec_ptr, + neighbor_array, + heap_queue, + dim, + metric, + laneId); + } bool self_found = false; - // Remove self edges - for (int j = threadIdx.x; j < query_list[i].size; j += blockDim.x) { + for (int j = laneId; j < query_list[i].size; j += 32) { if (query_list[i].ids[j] == query_vec->id) { query_list[i].dists[j] = raft::upper_bound(); query_list[i].ids[j] = raft::upper_bound(); - self_found = true; // Flag to reduce size by 1 + self_found = true; } } + self_found = (raft::ballot(self_found) != 0); - for (int j = query_list[i].size + threadIdx.x; j < query_list[i].maxSize; j += blockDim.x) { + for (int j = query_list[i].size + laneId; j < query_list[i].maxSize; j += 32) { query_list[i].ids[j] = raft::upper_bound(); query_list[i].dists[j] = raft::upper_bound(); } - __syncthreads(); - if (self_found) query_list[i].size--; + if (self_found && laneId == 0) { query_list[i].size--; } } return; diff --git a/cpp/src/neighbors/detail/vamana/priority_queue.cuh b/cpp/src/neighbors/detail/vamana/priority_queue.cuh index a1ce4e7159..0bccfd4db4 100644 --- a/cpp/src/neighbors/detail/vamana/priority_queue.cuh +++ b/cpp/src/neighbors/detail/vamana/priority_queue.cuh @@ -231,6 +231,22 @@ __device__ bool check_duplicate(const Node* pq, const int size, Node return true; } +// Warp-level version: each warp scans its own pq with laneId and stride 32 +template +__device__ bool check_duplicate_warp( + const Node* pq, const int size, Node new_node, int laneId) +{ + bool found = false; + for (int i = laneId; i < size; i += 32) { + if (pq[i].nodeid == new_node.nodeid) { + found = true; + break; + } + } + unsigned mask = raft::ballot(found); + return (mask != 0); +} + /* Enqueuing a input value into parallel queue with tracker */ @@ -291,6 +307,59 @@ __inline__ __device__ void parallel_pq_max_enqueue(Node* pq, __syncthreads(); } +// Warp-level version: no __syncthreads, uses laneId for single-thread ops and warp shuffle +template +__inline__ __device__ void parallel_pq_max_enqueue_warp(Node* pq, + int* size, + const int pq_size, + Node input_data, + SUMTYPE* cur_max_val, + int* max_idx, + int laneId) +{ + if (*size < pq_size) { + if (laneId == 0) { + pq[*size].distance = input_data.distance; + pq[*size].nodeid = input_data.nodeid; + *size = *size + 1; + if (input_data.distance > (*cur_max_val)) { + *cur_max_val = input_data.distance; + *max_idx = *size - 1; + } + } + return; + } else { + if (input_data.distance >= (*cur_max_val)) { return; } + if (laneId == 0) { + pq[*max_idx].distance = input_data.distance; + pq[*max_idx].nodeid = input_data.nodeid; + } + int idx = 0; + SUMTYPE max_val = pq[0].distance; + + for (int i = laneId; i < pq_size; i += 32) { + if (pq[i].distance > max_val) { + max_val = pq[i].distance; + idx = i; + } + } + + for (int offset = 16; offset > 0; offset /= 2) { + SUMTYPE new_max_val = raft::shfl_up(max_val, offset); + int new_idx = raft::shfl_up(idx, offset); + if (new_max_val > max_val) { + max_val = new_max_val; + idx = new_idx; + } + } + + if (laneId == 31) { + *max_idx = idx; + *cur_max_val = max_val; + } + } +} + /* Compute the distances between the source vector and all nodes in the neighbor_array and enqueue them in the PQ @@ -299,7 +368,7 @@ template __forceinline__ __device__ void enqueue_all_neighbors(int num_neighbors, Point* query_vec, const T* vec_ptr, - int* neighbor_array, + IdxT* neighbor_array, PriorityQueue& heap_queue, int dim, cuvs::distance::DistanceType metric) @@ -314,4 +383,22 @@ __forceinline__ __device__ void enqueue_all_neighbors(int num_neighbors, } } +// Warp-level version: lane 0 does insert_back, no __syncthreads +template +__forceinline__ __device__ void enqueue_all_neighbors_warp(int num_neighbors, + Point* query_vec, + const T* vec_ptr, + IdxT* neighbor_array, + PriorityQueue& heap_queue, + int dim, + cuvs::distance::DistanceType metric, + int laneId) +{ + for (int i = 0; i < num_neighbors; i++) { + accT dist_out = dist( + query_vec->coords, &vec_ptr[(size_t)(neighbor_array[i]) * (size_t)(dim)], dim, metric); + if (laneId == 0) { heap_queue.insert_back(dist_out, neighbor_array[i]); } + } +} + } // namespace cuvs::neighbors::vamana::detail diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 336d81215b..6cbaa525de 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -41,8 +41,9 @@ namespace cuvs::neighbors::vamana::detail { * @{ */ -static const int blockD = 32; -static const int maxBlocks = 10000; +static const int blockD = 32; +static const int blockD_greedy = 128; // 4 warps per block, each warp processes one query +static const int maxBlocks = 10000; // generate random permutation of inserts - TODO do this on GPU / faster template @@ -158,7 +159,7 @@ void batched_insert_vamana( raft::resource::get_large_workspace_resource_ref(res), raft::make_extents(max_batchsize, visited_size)); - // Assign memory to query_list structures and initialize + // Assign memory to query_list structures and initiailize init_query_candidate_list<<<256, blockD, 0, stream>>>(query_list, visited_ids.data_handle(), visited_dists.data_handle(), @@ -186,10 +187,12 @@ void batched_insert_vamana( int sort_smem_size = 0; SELECT_SORT_SMEM_SIZE(degree, visited_size); // Sets sort_smem_size based on dataset - // Total dynamic shared memory used by GreedySearch + // GreedySearch: per-warp shared memory (4 warps): coords, neighbor_array, candidate_queue + const int coords_size = (dim + align_padding) * sizeof(T); + const int neighbor_size = degree * sizeof(IdxT); + const int queue_size_bytes = queue_size * sizeof(DistPair); int search_smem_total_size = - static_cast((dim + align_padding) * sizeof(T) + // visited_size * sizeof(Node) + - degree * sizeof(int) + queue_size * sizeof(DistPair)); + static_cast(4 * ((coords_size + neighbor_size + queue_size_bytes + 15) & ~15)); // Total dynamic shared memory size needed by both RobustPrune calls int prune_smem_total_size = (degree + visited_size) * sizeof(float) + // Occlusion list @@ -235,18 +238,16 @@ void batched_insert_vamana( RAFT_LOG_DEBUG("Starting batch of inserts indices_start:%d, batch_size:%d", start, step_size); int num_blocks = min(maxBlocks, step_size); + int num_blocks_greedy = min(maxBlocks, (step_size + 3) / 4); // Copy ids to be inserted for this batch - raft::copy( - res, - raft::make_device_vector_view(query_ids.data_handle(), int64_t(step_size)), - raft::make_host_vector_view(insert_order.data() + start, int64_t(step_size))); + raft::copy(query_ids.data_handle(), &insert_order.data()[start], step_size, stream); set_query_ids<<>>( query_list_ptr.data_handle(), query_ids.data_handle(), step_size); // Call greedy search to get candidates for every vector being inserted GreedySearchKernel - <<>>(d_graph.view(), + <<>>(d_graph.view(), dataset, query_list_ptr.data_handle(), step_size, @@ -545,7 +546,7 @@ void batched_insert_vamana( batch_prune); #endif - raft::copy(res, graph, d_graph.view()); + raft::copy(graph.data_handle(), d_graph.data_handle(), d_graph.size(), stream); RAFT_CHECK_CUDA(stream); } @@ -607,10 +608,10 @@ index build( res, codebook_params.pq_encoding_table.size()); // logically a 2D matrix with dimensions // pq_codebook_size x dim_per_subspace * pq_dim - raft::copy(res, - pq_encoding_table_device_vec.view(), - raft::make_host_vector_view(codebook_params.pq_encoding_table.data(), - pq_encoding_table_device_vec.extent(0))); + raft::copy(pq_encoding_table_device_vec.data_handle(), + codebook_params.pq_encoding_table.data(), + codebook_params.pq_encoding_table.size(), + raft::resource::get_cuda_stream(res)); int dim_per_subspace = dim / pq_dim; auto pq_codebook = raft::make_device_matrix(res, pq_codebook_size * pq_dim, dim_per_subspace); @@ -634,12 +635,10 @@ index build( // prepare rotation matrix auto rotation_matrix_device = raft::make_device_matrix(res, dim, dim); - raft::copy( - res, - raft::make_device_vector_view(rotation_matrix_device.data_handle(), - int64_t(codebook_params.rotation_matrix.size())), - raft::make_host_vector_view(codebook_params.rotation_matrix.data(), - int64_t(codebook_params.rotation_matrix.size()))); + raft::copy(rotation_matrix_device.data_handle(), + codebook_params.rotation_matrix.data(), + codebook_params.rotation_matrix.size(), + raft::resource::get_cuda_stream(res)); // process in batches const uint32_t n_rows = dataset.extent(0); diff --git a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh index 9a076a6727..ee604527b9 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -20,6 +20,7 @@ #include #include #include +#include #include @@ -226,6 +227,16 @@ struct QueryCandidates { size = 0; } + // Warp-level reset: uses laneId and stride 32, no block sync + __device__ void reset_warp(int laneId) + { + for (int i = laneId; i < maxSize; i += 32) { + ids[i] = raft::upper_bound(); + dists[i] = raft::upper_bound(); + } + if (laneId == 0) { size = 0; } + } + // Checks current list to see if a node as previously been visited __inline__ __device__ bool check_visited(IdxT target, accT dist) { @@ -249,6 +260,26 @@ struct QueryCandidates { } return found; } + + // Warp-level check_visited: no __syncthreads, uses laneId and warp ballot + __inline__ __device__ bool check_visited_warp(IdxT target, accT dist_val, int laneId) + { + bool my_found = false; + for (int i = laneId; i < size; i += 32) { + if (ids[i] == target) { + my_found = true; + break; + } + } + unsigned mask = raft::ballot(my_found); + bool found = (mask != 0); + if (!found && size < maxSize && laneId == 0) { + ids[size] = target; + dists[size] = dist_val; + size++; + } + return found; + } // For debugging /* __inline__ __device__ void print_visited() { @@ -361,6 +392,21 @@ __device__ void update_shared_point(Point* shared_point, } } +// Warp-level: uses laneId and stride 32 for coordinate copy +template +__device__ void update_shared_point_warp(Point* shared_point, + const T* data_ptr, + int id, + int dim, + int laneId) +{ + shared_point->id = id; + shared_point->Dim = dim; + for (size_t i = laneId; i < dim; i += 32) { + shared_point->coords[i] = data_ptr[(size_t)(id) * (size_t)(dim) + i]; + } +} + // Update the graph from the results of the query list (or reverse edge list) template __global__ void write_graph_edges_kernel(raft::device_matrix_view graph, @@ -391,6 +437,7 @@ __global__ void create_reverse_edge_list(void* query_list_ptr, for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num_queries; i += blockDim.x * gridDim.x) { + int read_idx = i * query_list[i].maxSize; int cand_count = query_list[i + 1].size - query_list[i].size; for (int j = 0; j < cand_count; j++) { From 05cc8ccf8600b308ef780cf6c6b2fae3f6e76d7b Mon Sep 17 00:00:00 2001 From: bkarsin Date: Fri, 1 May 2026 14:37:00 -0700 Subject: [PATCH 02/14] Add fp16 support for Vamana build/serialize (cherry picked from commit d8b547b616abb26bf6873012bcd1158007b946ce) --- c/src/neighbors/vamana.cpp | 11 ++ cpp/CMakeLists.txt | 2 + cpp/include/cuvs/neighbors/vamana.hpp | 20 +++- .../neighbors/detail/vamana/vamana_build.cuh | 8 +- .../detail/vamana/vamana_structs.cuh | 105 ++++++++++++++++-- cpp/src/neighbors/vamana.cuh | 1 + cpp/src/neighbors/vamana_build_half.cu | 33 ++++++ cpp/src/neighbors/vamana_serialize_half.cu | 13 +++ python/cuvs/cuvs/neighbors/vamana/vamana.pyx | 6 +- python/cuvs/cuvs/tests/test_vamana.py | 10 +- 10 files changed, 191 insertions(+), 18 deletions(-) create mode 100644 cpp/src/neighbors/vamana_build_half.cu create mode 100644 cpp/src/neighbors/vamana_serialize_half.cu diff --git a/c/src/neighbors/vamana.cpp b/c/src/neighbors/vamana.cpp index d1686ad96f..beabe6d616 100644 --- a/c/src/neighbors/vamana.cpp +++ b/c/src/neighbors/vamana.cpp @@ -4,6 +4,7 @@ */ #include +#include #include #include @@ -82,6 +83,8 @@ extern "C" cuvsError_t cuvsVamanaIndexDestroy(cuvsVamanaIndex_t index_c_ptr) if (index.addr != 0) { if (index.dtype.code == kDLFloat && index.dtype.bits == 32) { delete reinterpret_cast*>(index.addr); + } else if (index.dtype.code == kDLFloat && index.dtype.bits == 16) { + delete reinterpret_cast*>(index.addr); } else if (index.dtype.code == kDLInt && index.dtype.bits == 8) { delete reinterpret_cast*>(index.addr); } else if (index.dtype.code == kDLUInt && index.dtype.bits == 8) { @@ -100,6 +103,10 @@ extern "C" cuvsError_t cuvsVamanaIndexGetDims(cuvsVamanaIndex_t index, int* dim) auto index_ptr = reinterpret_cast*>(index->addr); *dim = index_ptr->dim(); + } else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) { + auto index_ptr = + reinterpret_cast*>(index->addr); + *dim = index_ptr->dim(); } else if (index->dtype.code == kDLInt && index->dtype.bits == 8) { auto index_ptr = reinterpret_cast*>(index->addr); @@ -123,6 +130,8 @@ extern "C" cuvsError_t cuvsVamanaBuild(cuvsResources_t res, if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { index->addr = reinterpret_cast(_build(res, params, dataset_tensor)); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 16) { + index->addr = reinterpret_cast(_build(res, params, dataset_tensor)); } else if (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) { index->addr = reinterpret_cast(_build(res, params, dataset_tensor)); } else if (dataset.dtype.code == kDLUInt && dataset.dtype.bits == 8) { @@ -143,6 +152,8 @@ extern "C" cuvsError_t cuvsVamanaSerialize(cuvsResources_t res, return cuvs::core::translate_exceptions([=] { if (index->dtype.code == kDLFloat && index->dtype.bits == 32) { _serialize(res, filename, index, include_dataset); + } else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) { + _serialize(res, filename, index, include_dataset); } else if (index->dtype.code == kDLInt && index->dtype.bits == 8) { _serialize(res, filename, index, include_dataset); } else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) { diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 227c2906cc..79b575c925 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1218,10 +1218,12 @@ if(NOT BUILD_CPU_ONLY) src/neighbors/tiered_index.cu src/neighbors/sparse_brute_force.cu src/neighbors/vamana_build_float.cu + src/neighbors/vamana_build_half.cu src/neighbors/vamana_build_uint8.cu src/neighbors/vamana_build_int8.cu src/neighbors/vamana_codebooks_float.cu src/neighbors/vamana_serialize_float.cu + src/neighbors/vamana_serialize_half.cu src/neighbors/vamana_serialize_uint8.cu src/neighbors/vamana_serialize_int8.cu src/preprocessing/quantize/scalar.cu diff --git a/cpp/include/cuvs/neighbors/vamana.hpp b/cpp/include/cuvs/neighbors/vamana.hpp index 645adc5c5d..e00bda1fa4 100644 --- a/cpp/include/cuvs/neighbors/vamana.hpp +++ b/cpp/include/cuvs/neighbors/vamana.hpp @@ -6,6 +6,7 @@ #pragma once #include "common.hpp" +#include #include #include #include @@ -286,7 +287,8 @@ struct index : cuvs::neighbors::index { * to improve graph quality. The index_params struct controls the degree of the final graph. * * The following distance metrics are supported: - * - L2 + * - L2Expanded (squared L2) + * - L2SqrtExpanded (Euclidean distance) * * Usage example: * @code{.cpp} @@ -344,6 +346,16 @@ auto build(raft::resources const& res, raft::host_matrix_view dataset) -> cuvs::neighbors::vamana::index; +auto build(raft::resources const& res, + const cuvs::neighbors::vamana::index_params& params, + raft::device_matrix_view dataset) + -> cuvs::neighbors::vamana::index; + +auto build(raft::resources const& res, + const cuvs::neighbors::vamana::index_params& params, + raft::host_matrix_view dataset) + -> cuvs::neighbors::vamana::index; + /** * @brief Build the index from the dataset for efficient DiskANN search. * @@ -520,6 +532,12 @@ void serialize(raft::resources const& handle, bool include_dataset = true, bool sector_aligned = false); +void serialize(raft::resources const& handle, + const std::string& file_prefix, + const cuvs::neighbors::vamana::index& index, + bool include_dataset = true, + bool sector_aligned = false); + /** * Save the index to file. * diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 6cbaa525de..c92a8ba97a 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -562,8 +562,9 @@ index build( { uint32_t graph_degree = params.graph_degree; - RAFT_EXPECTS(params.metric == cuvs::distance::DistanceType::L2Expanded, - "Currently only L2Expanded metric is supported"); + RAFT_EXPECTS(params.metric == cuvs::distance::DistanceType::L2Expanded || + params.metric == cuvs::distance::DistanceType::L2SqrtExpanded, + "Only L2Expanded and L2SqrtExpanded metrics are supported"); const int* deg_size = std::find(std::begin(DEGREE_SIZES), std::end(DEGREE_SIZES), graph_degree); RAFT_EXPECTS(deg_size != std::end(DEGREE_SIZES), "Provided graph_degree not currently supported"); @@ -580,10 +581,9 @@ index build( RAFT_LOG_DEBUG("Running Vamana batched insert algorithm"); - cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded; IdxT medoid_id; batched_insert_vamana( - res, params, dataset, vamana_graph.view(), &medoid_id, metric); + res, params, dataset, vamana_graph.view(), &medoid_id, params.metric); std::optional> quantized_vectors; if (params.codebooks) { diff --git a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh index ee604527b9..357370b59f 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh @@ -8,7 +8,9 @@ #include #include #include +#include #include +#include #include #include @@ -171,16 +173,102 @@ __device__ SUMTYPE l2_ILP4(Point* src_vec, Point* dst_ve return partial_sum[0]; } +/* fp16: accumulate in float; promote operands once for correct fmaf behavior */ +template +__device__ SUMTYPE l2_SEQ_half(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec) +{ + SUMTYPE partial_sum = 0; + + for (int i = threadIdx.x; i < src_vec->Dim; i += blockDim.x) { + float s = __half2float(src_vec[0].coords[i]); + float t = __half2float(dst_vec[0].coords[i]); + float d = s - t; + partial_sum = fmaf(d, d, partial_sum); + } + + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); + } + return partial_sum; +} + +template +__device__ SUMTYPE l2_ILP2_half(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec) +{ + float partial_sum[2] = {0, 0}; + for (int i = threadIdx.x; i < src_vec->Dim; i += 2 * blockDim.x) { + float t0 = __half2float(dst_vec->coords[i]); + float s0 = __half2float(src_vec[0].coords[i]); + partial_sum[0] = fmaf(s0 - t0, s0 - t0, partial_sum[0]); + + if (i + 32 < src_vec->Dim) { + float t1 = __half2float(dst_vec->coords[i + 32]); + float s1 = __half2float(src_vec[0].coords[i + 32]); + partial_sum[1] = fmaf(s1 - t1, s1 - t1, partial_sum[1]); + } + } + partial_sum[0] += partial_sum[1]; + + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__device__ SUMTYPE l2_ILP4_half(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec) +{ + float partial_sum[4] = {0, 0, 0, 0}; + for (int i = threadIdx.x; i < src_vec->Dim; i += 4 * blockDim.x) { + float t0 = __half2float(dst_vec->coords[i]); + float s0 = __half2float(src_vec[0].coords[i]); + partial_sum[0] = fmaf(s0 - t0, s0 - t0, partial_sum[0]); + + if (i + 32 < src_vec->Dim) { + float t1 = __half2float(dst_vec->coords[i + 32]); + float s1 = __half2float(src_vec[0].coords[i + 32]); + partial_sum[1] = fmaf(s1 - t1, s1 - t1, partial_sum[1]); + } + if (i + 64 < src_vec->Dim) { + float t2 = __half2float(dst_vec->coords[i + 64]); + float s2 = __half2float(src_vec[0].coords[i + 64]); + partial_sum[2] = fmaf(s2 - t2, s2 - t2, partial_sum[2]); + } + if (i + 96 < src_vec->Dim) { + float t3 = __half2float(dst_vec->coords[i + 96]); + float s3 = __half2float(src_vec[0].coords[i + 96]); + partial_sum[3] = fmaf(s3 - t3, s3 - t3, partial_sum[3]); + } + } + partial_sum[0] += partial_sum[1] + partial_sum[2] + partial_sum[3]; + + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + + return partial_sum[0]; +} + /* Selects ILP optimization level based on dimension */ template __forceinline__ __device__ SUMTYPE l2(Point* src_vec, Point* dst_vec) { - if (src_vec->Dim >= 128) { - return l2_ILP4(src_vec, dst_vec); - } else if (src_vec->Dim >= 64) { - return l2_ILP2(src_vec, dst_vec); + if constexpr (std::is_same_v, __half>) { + if (src_vec->Dim >= 128) { + return l2_ILP4_half(src_vec, dst_vec); + } else if (src_vec->Dim >= 64) { + return l2_ILP2_half(src_vec, dst_vec); + } else { + return l2_SEQ_half(src_vec, dst_vec); + } } else { - return l2_SEQ(src_vec, dst_vec); + if (src_vec->Dim >= 128) { + return l2_ILP4(src_vec, dst_vec); + } else if (src_vec->Dim >= 64) { + return l2_ILP2(src_vec, dst_vec); + } else { + return l2_SEQ(src_vec, dst_vec); + } } } @@ -198,12 +286,15 @@ __host__ __device__ SUMTYPE l2(const T* src, const T* dest, int dim) return l2(&src_p, &dest_p); } -// Currently only L2Expanded is supported template __host__ __device__ SUMTYPE dist(const T* src, const T* dest, int dim, cuvs::distance::DistanceType metric) { - return l2(src, dest, dim); + SUMTYPE d = l2(src, dest, dim); + if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + return static_cast(sqrtf(static_cast(d))); + } + return d; } /*************************************************************************************** diff --git a/cpp/src/neighbors/vamana.cuh b/cpp/src/neighbors/vamana.cuh index d2a73809ed..a6c941d3c4 100644 --- a/cpp/src/neighbors/vamana.cuh +++ b/cpp/src/neighbors/vamana.cuh @@ -45,6 +45,7 @@ namespace cuvs::neighbors::vamana { * * The following distance metrics are supported: * - L2Expanded + * - L2SqrtExpanded * * Usage example: * @code{.cpp} diff --git a/cpp/src/neighbors/vamana_build_half.cu b/cpp/src/neighbors/vamana_build_half.cu new file mode 100644 index 0000000000..817f8c0488 --- /dev/null +++ b/cpp/src/neighbors/vamana_build_half.cu @@ -0,0 +1,33 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "vamana.cuh" +#include +#include + +namespace cuvs::neighbors::vamana { + +#define RAFT_INST_VAMANA_BUILD(T, IdxT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::vamana::index_params& params, \ + raft::device_matrix_view dataset) \ + -> cuvs::neighbors::vamana::index \ + { \ + return cuvs::neighbors::vamana::build(handle, params, dataset); \ + } \ + \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::vamana::index_params& params, \ + raft::host_matrix_view dataset) \ + -> cuvs::neighbors::vamana::index \ + { \ + return cuvs::neighbors::vamana::build(handle, params, dataset); \ + } + +RAFT_INST_VAMANA_BUILD(half, uint32_t); + +#undef RAFT_INST_VAMANA_BUILD + +} // namespace cuvs::neighbors::vamana diff --git a/cpp/src/neighbors/vamana_serialize_half.cu b/cpp/src/neighbors/vamana_serialize_half.cu new file mode 100644 index 0000000000..76e0d81fa7 --- /dev/null +++ b/cpp/src/neighbors/vamana_serialize_half.cu @@ -0,0 +1,13 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include "vamana_serialize.cuh" + +namespace cuvs::neighbors::vamana { + +CUVS_INST_VAMANA_SERIALIZE(half); + +} // namespace cuvs::neighbors::vamana diff --git a/python/cuvs/cuvs/neighbors/vamana/vamana.pyx b/python/cuvs/cuvs/neighbors/vamana/vamana.pyx index a2b4be0a61..3bd2b6706b 100644 --- a/python/cuvs/cuvs/neighbors/vamana/vamana.pyx +++ b/python/cuvs/cuvs/neighbors/vamana/vamana.pyx @@ -169,13 +169,14 @@ def build(IndexParams index_params, dataset, resources=None): struct controls the degree of the final graph. The following distance metrics are supported: - - L2Expanded + - L2Expanded (sqeuclidean) + - L2SqrtExpanded (l2 / euclidean distance) Parameters ---------- index_params : IndexParams object dataset : CUDA array interface compliant matrix shape (n_samples, dim) - Supported dtype [float, int8, uint8] + Supported dtype [float, float16, int8, uint8] {resources_docstring} Returns @@ -201,6 +202,7 @@ def build(IndexParams index_params, dataset, resources=None): # in RAFT to make this a single call dataset_ai = wrap_array(dataset) _check_input_array(dataset_ai, [np.dtype('float32'), + np.dtype('float16'), np.dtype('int8'), np.dtype('uint8')]) diff --git a/python/cuvs/cuvs/tests/test_vamana.py b/python/cuvs/cuvs/tests/test_vamana.py index f692e4ae69..bd19615788 100644 --- a/python/cuvs/cuvs/tests/test_vamana.py +++ b/python/cuvs/cuvs/tests/test_vamana.py @@ -12,6 +12,8 @@ def _gen_data(shape, dtype): rng = np.random.default_rng(12345) if dtype == np.float32: return rng.random(shape, dtype=np.float32) + if dtype == np.float16: + return rng.random(shape, dtype=np.float32).astype(np.float16) if dtype == np.int8: # keep small magnitude to avoid overflow if used elsewhere return rng.integers(low=-10, high=10, size=shape, dtype=np.int8) @@ -20,7 +22,7 @@ def _gen_data(shape, dtype): raise AssertionError("unexpected dtype in test helper") -@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) +@pytest.mark.parametrize("dtype", [np.float32, np.float16, np.int8, np.uint8]) def test_vamana_build_basic(dtype): n_rows, n_cols = 1000, 16 data = _gen_data((n_rows, n_cols), dtype) @@ -36,7 +38,7 @@ def test_vamana_build_basic(dtype): assert "Index(type=Vamana" in repr(idx) -@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) +@pytest.mark.parametrize("dtype", [np.float32, np.float16, np.int8, np.uint8]) @pytest.mark.skip( reason="Skipping host build test because of CUDA error " "in C++ API. Reference issue: " @@ -69,9 +71,9 @@ def test_vamana_serialize(tmp_path, include_dataset): def test_vamana_build_rejects_unsupported_dtype(): - # float16 is not in the accepted list in the build wrapper; expect failure + # e.g. float64 is not supported for Vamana build n_rows, n_cols = 64, 8 - data = _gen_data((n_rows, n_cols), np.float32).astype(np.float16) + data = np.random.default_rng(1).random((n_rows, n_cols), dtype=np.float64) from pylibraft.common import device_ndarray data_dev = device_ndarray(data) From 5bb5e6cb106b3da63f23cd0ea702ed829cc79712 Mon Sep 17 00:00:00 2001 From: bkarsin Date: Wed, 13 May 2026 17:14:47 -0700 Subject: [PATCH 03/14] Fix recall issue with multi-warp optimization (cherry picked from commit 4d970e0a0a1b8dd1692cc1004295465cc2ae9249) --- .../neighbors/detail/vamana/greedy_search.cuh | 4 +- .../detail/vamana/priority_queue.cuh | 7 +- .../detail/vamana/vamana_structs.cuh | 201 ++++++++++++++++++ 3 files changed, 208 insertions(+), 4 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/greedy_search.cuh b/cpp/src/neighbors/detail/vamana/greedy_search.cuh index 0a017e3cb7..e6b5224462 100644 --- a/cpp/src/neighbors/detail/vamana/greedy_search.cuh +++ b/cpp/src/neighbors/detail/vamana/greedy_search.cuh @@ -165,8 +165,8 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( Point* query_vec = &s_query; query_vec->Dim = dim; query_vec->coords = s_coords; - accT medoid_dist = - dist(query_vec->coords, &vec_ptr[(size_t)medoid_id * (size_t)dim], dim, metric); + accT medoid_dist = dist_warp( + query_vec->coords, &vec_ptr[(size_t)medoid_id * (size_t)dim], dim, metric, laneId); if (laneId == 0) { heap_queue.insert_back(medoid_dist, medoid_id); } diff --git a/cpp/src/neighbors/detail/vamana/priority_queue.cuh b/cpp/src/neighbors/detail/vamana/priority_queue.cuh index 0bccfd4db4..ddf35d8cfe 100644 --- a/cpp/src/neighbors/detail/vamana/priority_queue.cuh +++ b/cpp/src/neighbors/detail/vamana/priority_queue.cuh @@ -395,8 +395,11 @@ __forceinline__ __device__ void enqueue_all_neighbors_warp(int num_neighbors, int laneId) { for (int i = 0; i < num_neighbors; i++) { - accT dist_out = dist( - query_vec->coords, &vec_ptr[(size_t)(neighbor_array[i]) * (size_t)(dim)], dim, metric); + accT dist_out = dist_warp(query_vec->coords, + &vec_ptr[(size_t)(neighbor_array[i]) * (size_t)(dim)], + dim, + metric, + laneId); if (laneId == 0) { heap_queue.insert_back(dist_out, neighbor_array[i]); } } } diff --git a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh index 357370b59f..e2f9e9ed7c 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh @@ -34,6 +34,9 @@ namespace cuvs::neighbors::vamana::detail { #define FULL_BITMASK 0xFFFFFFFF +// Warp stride for per-warp distance reduction (GreedySearch uses multiple warps per block). +static constexpr int VAMANA_WARP_SIZE = 32; + // Currently supported values for graph_degree. static const int DEGREE_SIZES[4] = {32, 64, 128, 256}; @@ -297,6 +300,204 @@ dist(const T* src, const T* dest, int dim, cuvs::distance::DistanceType metric) return d; } +/* + * Warp-strided L2 / dist: each warp computes one distance using lanes 0..31 only. + * Required when blockDim.x > 32 but only one distance per warp is desired — plain l2() shards + * work across the whole block then reduces inside each warp only, which under-counts L2. + */ +template +__device__ SUMTYPE l2_SEQ_warp(Point* src_vec, Point* dst_vec, int lane) +{ + SUMTYPE partial_sum = 0; + for (int i = lane; i < src_vec->Dim; i += VAMANA_WARP_SIZE) { + partial_sum = fmaf((src_vec[0].coords[i] - dst_vec[0].coords[i]), + (src_vec[0].coords[i] - dst_vec[0].coords[i]), + partial_sum); + } + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); + } + return partial_sum; +} + +template +__device__ SUMTYPE l2_ILP2_warp(Point* src_vec, Point* dst_vec, int lane) +{ + T temp_dst[2] = {0, 0}; + SUMTYPE partial_sum[2] = {0, 0}; + for (int i = lane; i < src_vec->Dim; i += 2 * VAMANA_WARP_SIZE) { + temp_dst[0] = dst_vec->coords[i]; + if (i + 32 < src_vec->Dim) temp_dst[1] = dst_vec->coords[i + 32]; + + partial_sum[0] = fmaf( + (src_vec[0].coords[i] - temp_dst[0]), (src_vec[0].coords[i] - temp_dst[0]), partial_sum[0]); + if (i + 32 < src_vec->Dim) + partial_sum[1] = fmaf((src_vec[0].coords[i + 32] - temp_dst[1]), + (src_vec[0].coords[i + 32] - temp_dst[1]), + partial_sum[1]); + } + partial_sum[0] += partial_sum[1]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__device__ SUMTYPE l2_ILP4_warp(Point* src_vec, Point* dst_vec, int lane) +{ + T temp_dst[4] = {0, 0, 0, 0}; + SUMTYPE partial_sum[4] = {0, 0, 0, 0}; + for (int i = lane; i < src_vec->Dim; i += 4 * VAMANA_WARP_SIZE) { + temp_dst[0] = dst_vec->coords[i]; + if (i + 32 < src_vec->Dim) temp_dst[1] = dst_vec->coords[i + 32]; + if (i + 64 < src_vec->Dim) temp_dst[2] = dst_vec->coords[i + 64]; + if (i + 96 < src_vec->Dim) temp_dst[3] = dst_vec->coords[i + 96]; + + partial_sum[0] = fmaf( + (src_vec[0].coords[i] - temp_dst[0]), (src_vec[0].coords[i] - temp_dst[0]), partial_sum[0]); + if (i + 32 < src_vec->Dim) + partial_sum[1] = fmaf((src_vec[0].coords[i + 32] - temp_dst[1]), + (src_vec[0].coords[i + 32] - temp_dst[1]), + partial_sum[1]); + if (i + 64 < src_vec->Dim) + partial_sum[2] = fmaf((src_vec[0].coords[i + 64] - temp_dst[2]), + (src_vec[0].coords[i + 64] - temp_dst[2]), + partial_sum[2]); + if (i + 96 < src_vec->Dim) + partial_sum[3] = fmaf((src_vec[0].coords[i + 96] - temp_dst[3]), + (src_vec[0].coords[i + 96] - temp_dst[3]), + partial_sum[3]); + } + partial_sum[0] += partial_sum[1] + partial_sum[2] + partial_sum[3]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__device__ SUMTYPE l2_SEQ_half_warp(Point<__half, SUMTYPE>* src_vec, + Point<__half, SUMTYPE>* dst_vec, + int lane) +{ + SUMTYPE partial_sum = 0; + for (int i = lane; i < src_vec->Dim; i += VAMANA_WARP_SIZE) { + float s = __half2float(src_vec[0].coords[i]); + float t = __half2float(dst_vec[0].coords[i]); + float d = s - t; + partial_sum = fmaf(d, d, partial_sum); + } + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); + } + return partial_sum; +} + +template +__device__ SUMTYPE l2_ILP2_half_warp(Point<__half, SUMTYPE>* src_vec, + Point<__half, SUMTYPE>* dst_vec, + int lane) +{ + float partial_sum[2] = {0, 0}; + for (int i = lane; i < src_vec->Dim; i += 2 * VAMANA_WARP_SIZE) { + float t0 = __half2float(dst_vec->coords[i]); + float s0 = __half2float(src_vec[0].coords[i]); + partial_sum[0] = fmaf(s0 - t0, s0 - t0, partial_sum[0]); + + if (i + 32 < src_vec->Dim) { + float t1 = __half2float(dst_vec->coords[i + 32]); + float s1 = __half2float(src_vec[0].coords[i + 32]); + partial_sum[1] = fmaf(s1 - t1, s1 - t1, partial_sum[1]); + } + } + partial_sum[0] += partial_sum[1]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__device__ SUMTYPE l2_ILP4_half_warp(Point<__half, SUMTYPE>* src_vec, + Point<__half, SUMTYPE>* dst_vec, + int lane) +{ + float partial_sum[4] = {0, 0, 0, 0}; + for (int i = lane; i < src_vec->Dim; i += 4 * VAMANA_WARP_SIZE) { + float t0 = __half2float(dst_vec->coords[i]); + float s0 = __half2float(src_vec[0].coords[i]); + partial_sum[0] = fmaf(s0 - t0, s0 - t0, partial_sum[0]); + + if (i + 32 < src_vec->Dim) { + float t1 = __half2float(dst_vec->coords[i + 32]); + float s1 = __half2float(src_vec[0].coords[i + 32]); + partial_sum[1] = fmaf(s1 - t1, s1 - t1, partial_sum[1]); + } + if (i + 64 < src_vec->Dim) { + float t2 = __half2float(dst_vec->coords[i + 64]); + float s2 = __half2float(src_vec[0].coords[i + 64]); + partial_sum[2] = fmaf(s2 - t2, s2 - t2, partial_sum[2]); + } + if (i + 96 < src_vec->Dim) { + float t3 = __half2float(dst_vec->coords[i + 96]); + float s3 = __half2float(src_vec[0].coords[i + 96]); + partial_sum[3] = fmaf(s3 - t3, s3 - t3, partial_sum[3]); + } + } + partial_sum[0] += partial_sum[1] + partial_sum[2] + partial_sum[3]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__forceinline__ __device__ SUMTYPE l2_warp(Point* src_vec, Point* dst_vec, int lane) +{ + if constexpr (std::is_same_v, __half>) { + if (src_vec->Dim >= 128) { + return l2_ILP4_half_warp(src_vec, dst_vec, lane); + } else if (src_vec->Dim >= 64) { + return l2_ILP2_half_warp(src_vec, dst_vec, lane); + } else { + return l2_SEQ_half_warp(src_vec, dst_vec, lane); + } + } else { + if (src_vec->Dim >= 128) { + return l2_ILP4_warp(src_vec, dst_vec, lane); + } else if (src_vec->Dim >= 64) { + return l2_ILP2_warp(src_vec, dst_vec, lane); + } else { + return l2_SEQ_warp(src_vec, dst_vec, lane); + } + } +} + +template +__forceinline__ __device__ SUMTYPE l2_warp(const T* src, const T* dest, int dim, int lane) +{ + Point src_p; + src_p.coords = const_cast(src); + src_p.Dim = dim; + Point dest_p; + dest_p.coords = const_cast(dest); + dest_p.Dim = dim; + + return l2_warp(&src_p, &dest_p, lane); +} + +template +__forceinline__ __device__ SUMTYPE +dist_warp(const T* src, const T* dest, int dim, cuvs::distance::DistanceType metric, int lane) +{ + SUMTYPE d = l2_warp(src, dest, dim, lane); + if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + return static_cast(sqrtf(static_cast(d))); + } + return d; +} + /*************************************************************************************** * Structure that holds information about and results of a query. Use by both * GreedySearch and RobustPrune, as well as reverse edge lists. From 3dca1b0ad2980863e2aeae3d1fa7a2d29f601151 Mon Sep 17 00:00:00 2001 From: bkarsin Date: Tue, 9 Jun 2026 14:06:18 -0700 Subject: [PATCH 04/14] Speed up fp16 vamana build with promoted query vectors and optimized L2 comparators. (cherry picked from commit 0746d158db33be2dfe2f6c19ff3fcbb0b0398ccc) --- .../neighbors/detail/vamana/greedy_search.cuh | 54 ++- .../detail/vamana/priority_queue.cuh | 29 +- .../neighbors/detail/vamana/robust_prune.cuh | 38 +- .../neighbors/detail/vamana/vamana_build.cuh | 6 +- .../detail/vamana/vamana_structs.cuh | 385 +++++++++++++----- examples/cpp/src/vamana_example.cu | 13 + 6 files changed, 385 insertions(+), 140 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/greedy_search.cuh b/cpp/src/neighbors/detail/vamana/greedy_search.cuh index e6b5224462..f23248222f 100644 --- a/cpp/src/neighbors/detail/vamana/greedy_search.cuh +++ b/cpp/src/neighbors/detail/vamana/greedy_search.cuh @@ -18,6 +18,7 @@ #include #include +#include #include namespace cuvs::neighbors::vamana::detail { @@ -108,8 +109,10 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( QueryCandidates* query_list = static_cast*>(query_list_ptr); + using QueryCoordT = typename greedy_search_query_coord::type; + union ShmemLayout { - T coords; + QueryCoordT coords; IdxT neighborhood_arr; DistPair candidate_queue; }; @@ -118,14 +121,13 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( extern __shared__ __align__(alignof(ShmemLayout)) char smem[]; // Per-warp shared memory layout: coords, neighbor_array, candidate_queue - const int coords_size = (dim + align_padding) * sizeof(T); + const int coords_size = (dim + align_padding) * sizeof(QueryCoordT); const int neighbor_size = degree * sizeof(IdxT); const int queue_size_bytes = max_queue_size * sizeof(DistPair); const int per_warp_size = (coords_size + neighbor_size + queue_size_bytes + 15) & ~15; char* warp_smem = &smem[warpIdx * per_warp_size]; - T* s_coords = - reinterpret_cast(warp_smem); + QueryCoordT* s_coords = reinterpret_cast(warp_smem); IdxT* neighbor_array = reinterpret_cast(warp_smem + coords_size); DistPair* candidate_queue_smem = reinterpret_cast*>(warp_smem + coords_size + neighbor_size); @@ -136,7 +138,7 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( static __shared__ int k_max_idx[4]; static __shared__ int num_neighbors[4]; - Point s_query; + Point s_query; s_query.Dim = dim; s_query.coords = s_coords; @@ -151,7 +153,12 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( for (int i = blockIdx.x * 4 + warpIdx; i < num_queries; i += gridDim.x * 4) { query_list[i].reset_warp(laneId); - update_shared_point_warp(&s_query, vec_ptr, query_list[i].queryId, dim, laneId); + if constexpr (is_cuda_fp16_v) { + update_shared_point_warp_half_to_float( + &s_query, vec_ptr, query_list[i].queryId, dim, laneId); + } else { + update_shared_point_warp(&s_query, vec_ptr, query_list[i].queryId, dim, laneId); + } if (laneId == 0) { topk_q_size[warpIdx] = 0; @@ -162,11 +169,20 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( heap_queue.reset(); } - Point* query_vec = &s_query; - query_vec->Dim = dim; - query_vec->coords = s_coords; - accT medoid_dist = dist_warp( - query_vec->coords, &vec_ptr[(size_t)medoid_id * (size_t)dim], dim, metric, laneId); + Point* query_vec = &s_query; + query_vec->Dim = dim; + query_vec->coords = s_coords; + accT medoid_dist; + if constexpr (is_cuda_fp16_v) { + medoid_dist = dist_warp(query_vec->coords, + &vec_ptr[(size_t)medoid_id * (size_t)dim], + dim, + metric, + laneId); + } else { + medoid_dist = dist_warp( + query_vec->coords, &vec_ptr[(size_t)medoid_id * (size_t)dim], dim, metric, laneId); + } if (laneId == 0) { heap_queue.insert_back(medoid_dist, medoid_id); } @@ -226,14 +242,14 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( atomicMin(&num_neighbors[warpIdx], (int)j); } - enqueue_all_neighbors_warp(num_neighbors[warpIdx], - query_vec, - vec_ptr, - neighbor_array, - heap_queue, - dim, - metric, - laneId); + enqueue_all_neighbors_warp(num_neighbors[warpIdx], + query_vec, + vec_ptr, + neighbor_array, + heap_queue, + dim, + metric, + laneId); } bool self_found = false; diff --git a/cpp/src/neighbors/detail/vamana/priority_queue.cuh b/cpp/src/neighbors/detail/vamana/priority_queue.cuh index ddf35d8cfe..891065bf97 100644 --- a/cpp/src/neighbors/detail/vamana/priority_queue.cuh +++ b/cpp/src/neighbors/detail/vamana/priority_queue.cuh @@ -384,22 +384,25 @@ __forceinline__ __device__ void enqueue_all_neighbors(int num_neighbors, } // Warp-level version: lane 0 does insert_back, no __syncthreads -template +template __forceinline__ __device__ void enqueue_all_neighbors_warp(int num_neighbors, - Point* query_vec, - const T* vec_ptr, - IdxT* neighbor_array, - PriorityQueue& heap_queue, - int dim, - cuvs::distance::DistanceType metric, - int laneId) + Point* query_vec, + const DataT* vec_ptr, + IdxT* neighbor_array, + PriorityQueue& heap_queue, + int dim, + cuvs::distance::DistanceType metric, + int laneId) { for (int i = 0; i < num_neighbors; i++) { - accT dist_out = dist_warp(query_vec->coords, - &vec_ptr[(size_t)(neighbor_array[i]) * (size_t)(dim)], - dim, - metric, - laneId); + const DataT* neighbor_vec = &vec_ptr[(size_t)(neighbor_array[i]) * (size_t)(dim)]; + accT dist_out; + if constexpr (std::is_same_v && is_cuda_fp16_v) { + dist_out = dist_warp(query_vec->coords, neighbor_vec, dim, metric, laneId); + } else { + static_assert(std::is_same_v); + dist_out = dist_warp(query_vec->coords, neighbor_vec, dim, metric, laneId); + } if (laneId == 0) { heap_queue.insert_back(dist_out, neighbor_array[i]); } } } diff --git a/cpp/src/neighbors/detail/vamana/robust_prune.cuh b/cpp/src/neighbors/detail/vamana/robust_prune.cuh index 31fb6d589f..6df57e3505 100644 --- a/cpp/src/neighbors/detail/vamana/robust_prune.cuh +++ b/cpp/src/neighbors/detail/vamana/robust_prune.cuh @@ -9,6 +9,8 @@ #include +#include + #include "macros.cuh" #include "vamana_structs.cuh" @@ -61,7 +63,7 @@ __global__ void RobustPruneKernel( int visited_size, cuvs::distance::DistanceType metric, float alpha, - T* s_coords_mem) + typename greedy_search_query_coord::type* s_coords_mem) { int n = dataset.extent(0); int dim = dataset.extent(1); @@ -69,6 +71,8 @@ __global__ void RobustPruneKernel( QueryCandidates* query_list = static_cast*>(query_list_ptr); + using QueryCoordT = typename greedy_search_query_coord::type; + union ShmemLayout { // All blocksort sizes have same alignment (16) float occlusion; @@ -84,7 +88,7 @@ __global__ void RobustPruneKernel( DistPair* new_nbh_list = reinterpret_cast*>(&smem[(degree + visited_size) * sizeof(float)]); - static __shared__ Point s_query; + static __shared__ Point s_query; s_query.coords = &s_coords_mem[blockIdx.x * (dim + align_padding)]; s_query.Dim = dim; static __shared__ int prev_edges; @@ -93,7 +97,11 @@ __global__ void RobustPruneKernel( for (int i = blockIdx.x; i < num_queries; i += gridDim.x) { int queryId = query_list[i].queryId; - update_shared_point(&s_query, &dataset(0, 0), queryId, dim, i); + if constexpr (is_cuda_fp16_v) { + update_shared_point_half_to_float(&s_query, &dataset(0, 0), queryId, dim); + } else { + update_shared_point(&s_query, &dataset(0, 0), queryId, dim, i); + } int graphIdx = 0; int listIdx = 0; @@ -141,8 +149,16 @@ __global__ void RobustPruneKernel( } } else if (listIdx >= visited_size) { next_cand.idx = graph(queryId, graphIdx); - accT tempDist = - dist(s_query.coords, &dataset((size_t)graph(queryId, graphIdx), 0), dim, metric); + accT tempDist; + if constexpr (is_cuda_fp16_v) { + tempDist = dist(s_query.coords, + &dataset((size_t)graph(queryId, graphIdx), 0), + dim, + metric); + } else { + tempDist = dist( + s_query.coords, &dataset((size_t)graph(queryId, graphIdx), 0), dim, metric); + } if (threadIdx.x == 0) graphDist = tempDist; __syncthreads(); next_cand.dist = graphDist; @@ -150,8 +166,16 @@ __global__ void RobustPruneKernel( } else { accT listDist = query_list[i].dists[listIdx]; - accT tempDist = - dist(s_query.coords, &dataset((size_t)graph(queryId, graphIdx), 0), dim, metric); + accT tempDist; + if constexpr (is_cuda_fp16_v) { + tempDist = dist(s_query.coords, + &dataset((size_t)graph(queryId, graphIdx), 0), + dim, + metric); + } else { + tempDist = dist( + s_query.coords, &dataset((size_t)graph(queryId, graphIdx), 0), dim, metric); + } if (threadIdx.x == 0) graphDist = tempDist; __syncthreads(); diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index c92a8ba97a..e1dbc1c59f 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -173,7 +173,8 @@ void batched_insert_vamana( int align_padding = raft::alignTo(dim, 16) - dim; - auto s_coords_mem = raft::make_device_mdarray( + using QueryCoordT = typename greedy_search_query_coord::type; + auto s_coords_mem = raft::make_device_mdarray( res, raft::resource::get_large_workspace_resource_ref(res), raft::make_extents(min(maxBlocks, max(max_batchsize, reverse_batch)), @@ -188,7 +189,8 @@ void batched_insert_vamana( SELECT_SORT_SMEM_SIZE(degree, visited_size); // Sets sort_smem_size based on dataset // GreedySearch: per-warp shared memory (4 warps): coords, neighbor_array, candidate_queue - const int coords_size = (dim + align_padding) * sizeof(T); + // Half datasets promote query coords to float in smem. + const int coords_size = (dim + align_padding) * static_cast(sizeof(QueryCoordT)); const int neighbor_size = degree * sizeof(IdxT); const int queue_size_bytes = queue_size * sizeof(DistPair); int search_smem_total_size = diff --git a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh index e2f9e9ed7c..906b6e1c4e 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh @@ -37,6 +37,16 @@ namespace cuvs::neighbors::vamana::detail { // Warp stride for per-warp distance reduction (GreedySearch uses multiple warps per block). static constexpr int VAMANA_WARP_SIZE = 32; +// vamana fp16 instantiations use CUDA's half type (alias of __half on device). +template +inline constexpr bool is_cuda_fp16_v = std::is_same_v, half>; + +// GreedySearch promotes fp16 queries to float in shared memory for distance reuse. +template +struct greedy_search_query_coord { + using type = std::conditional_t, float, T>; +}; + // Currently supported values for graph_degree. static const int DEGREE_SIZES[4] = {32, 64, 128, 256}; @@ -176,80 +186,99 @@ __device__ SUMTYPE l2_ILP4(Point* src_vec, Point* dst_ve return partial_sum[0]; } -/* fp16: accumulate in float; promote operands once for correct fmaf behavior */ +/* fp16: native __hsub/__hfma throughout; single float widen at return */ +__device__ __forceinline__ void l2_half_accum(__half& lane_sum, __half s, __half t) +{ + __half d = __hsub(s, t); + lane_sum = __hfma(d, d, lane_sum); +} + +/* ILP helpers: accumulate (s-t)^2 into acc; operands must already be loaded */ +__device__ __forceinline__ void l2_half_fma_sq(__half& acc, __half s, __half t) +{ + __half d = __hsub(s, t); + acc = __hfma(d, d, acc); +} + +__device__ __forceinline__ __half l2_half_shfl_down(__half val, int offset) +{ + unsigned int v = static_cast(__half_as_ushort(val)); + v = __shfl_down_sync(FULL_BITMASK, v, offset); + return __ushort_as_half(static_cast(v & 0xFFFFu)); +} + +__device__ __forceinline__ __half l2_half_warp_reduce_to_half(__half lane_sum) +{ + for (int offset = 16; offset > 0; offset /= 2) { + lane_sum = __hadd(lane_sum, l2_half_shfl_down(lane_sum, offset)); + } + return lane_sum; +} + +template +__device__ __forceinline__ SUMTYPE l2_half_warp_reduce(__half lane_sum) +{ + return static_cast(__half2float(l2_half_warp_reduce_to_half(lane_sum))); +} + template __device__ SUMTYPE l2_SEQ_half(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec) { - SUMTYPE partial_sum = 0; + __half lane_sum = __float2half(0.0f); for (int i = threadIdx.x; i < src_vec->Dim; i += blockDim.x) { - float s = __half2float(src_vec[0].coords[i]); - float t = __half2float(dst_vec[0].coords[i]); - float d = s - t; - partial_sum = fmaf(d, d, partial_sum); + l2_half_accum(lane_sum, src_vec[0].coords[i], dst_vec[0].coords[i]); } - for (int offset = 16; offset > 0; offset /= 2) { - partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); - } - return partial_sum; + return l2_half_warp_reduce(lane_sum); } template __device__ SUMTYPE l2_ILP2_half(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec) { - float partial_sum[2] = {0, 0}; + __half temp_dst[2] = {__float2half(0.0f), __float2half(0.0f)}; + __half partial_sum[2] = {__float2half(0.0f), __float2half(0.0f)}; for (int i = threadIdx.x; i < src_vec->Dim; i += 2 * blockDim.x) { - float t0 = __half2float(dst_vec->coords[i]); - float s0 = __half2float(src_vec[0].coords[i]); - partial_sum[0] = fmaf(s0 - t0, s0 - t0, partial_sum[0]); - - if (i + 32 < src_vec->Dim) { - float t1 = __half2float(dst_vec->coords[i + 32]); - float s1 = __half2float(src_vec[0].coords[i + 32]); - partial_sum[1] = fmaf(s1 - t1, s1 - t1, partial_sum[1]); - } - } - partial_sum[0] += partial_sum[1]; + temp_dst[0] = dst_vec->coords[i]; + if (i + 32 < src_vec->Dim) temp_dst[1] = dst_vec->coords[i + 32]; - for (int offset = 16; offset > 0; offset /= 2) { - partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + l2_half_fma_sq(partial_sum[0], src_vec[0].coords[i], temp_dst[0]); + if (i + 32 < src_vec->Dim) + l2_half_fma_sq(partial_sum[1], src_vec[0].coords[i + 32], temp_dst[1]); } - return partial_sum[0]; + partial_sum[0] = __hadd(partial_sum[0], partial_sum[1]); + + return l2_half_warp_reduce(partial_sum[0]); } template __device__ SUMTYPE l2_ILP4_half(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec) { - float partial_sum[4] = {0, 0, 0, 0}; + __half temp_dst[4] = {__float2half(0.0f), + __float2half(0.0f), + __float2half(0.0f), + __float2half(0.0f)}; + __half partial_sum[4] = {__float2half(0.0f), + __float2half(0.0f), + __float2half(0.0f), + __float2half(0.0f)}; for (int i = threadIdx.x; i < src_vec->Dim; i += 4 * blockDim.x) { - float t0 = __half2float(dst_vec->coords[i]); - float s0 = __half2float(src_vec[0].coords[i]); - partial_sum[0] = fmaf(s0 - t0, s0 - t0, partial_sum[0]); - - if (i + 32 < src_vec->Dim) { - float t1 = __half2float(dst_vec->coords[i + 32]); - float s1 = __half2float(src_vec[0].coords[i + 32]); - partial_sum[1] = fmaf(s1 - t1, s1 - t1, partial_sum[1]); - } - if (i + 64 < src_vec->Dim) { - float t2 = __half2float(dst_vec->coords[i + 64]); - float s2 = __half2float(src_vec[0].coords[i + 64]); - partial_sum[2] = fmaf(s2 - t2, s2 - t2, partial_sum[2]); - } - if (i + 96 < src_vec->Dim) { - float t3 = __half2float(dst_vec->coords[i + 96]); - float s3 = __half2float(src_vec[0].coords[i + 96]); - partial_sum[3] = fmaf(s3 - t3, s3 - t3, partial_sum[3]); - } - } - partial_sum[0] += partial_sum[1] + partial_sum[2] + partial_sum[3]; + temp_dst[0] = dst_vec->coords[i]; + if (i + 32 < src_vec->Dim) temp_dst[1] = dst_vec->coords[i + 32]; + if (i + 64 < src_vec->Dim) temp_dst[2] = dst_vec->coords[i + 64]; + if (i + 96 < src_vec->Dim) temp_dst[3] = dst_vec->coords[i + 96]; - for (int offset = 16; offset > 0; offset /= 2) { - partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + l2_half_fma_sq(partial_sum[0], src_vec[0].coords[i], temp_dst[0]); + if (i + 32 < src_vec->Dim) + l2_half_fma_sq(partial_sum[1], src_vec[0].coords[i + 32], temp_dst[1]); + if (i + 64 < src_vec->Dim) + l2_half_fma_sq(partial_sum[2], src_vec[0].coords[i + 64], temp_dst[2]); + if (i + 96 < src_vec->Dim) + l2_half_fma_sq(partial_sum[3], src_vec[0].coords[i + 96], temp_dst[3]); } + partial_sum[0] = __hadd(partial_sum[0], __hadd(partial_sum[1], __hadd(partial_sum[2], partial_sum[3]))); - return partial_sum[0]; + return l2_half_warp_reduce(partial_sum[0]); } /* Selects ILP optimization level based on dimension */ @@ -381,17 +410,11 @@ __device__ SUMTYPE l2_SEQ_half_warp(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec, int lane) { - SUMTYPE partial_sum = 0; + __half lane_sum = __float2half(0.0f); for (int i = lane; i < src_vec->Dim; i += VAMANA_WARP_SIZE) { - float s = __half2float(src_vec[0].coords[i]); - float t = __half2float(dst_vec[0].coords[i]); - float d = s - t; - partial_sum = fmaf(d, d, partial_sum); + l2_half_accum(lane_sum, src_vec[0].coords[i], dst_vec[0].coords[i]); } - for (int offset = 16; offset > 0; offset /= 2) { - partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); - } - return partial_sum; + return l2_half_warp_reduce(lane_sum); } template @@ -399,23 +422,18 @@ __device__ SUMTYPE l2_ILP2_half_warp(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec, int lane) { - float partial_sum[2] = {0, 0}; + __half temp_dst[2] = {__float2half(0.0f), __float2half(0.0f)}; + __half partial_sum[2] = {__float2half(0.0f), __float2half(0.0f)}; for (int i = lane; i < src_vec->Dim; i += 2 * VAMANA_WARP_SIZE) { - float t0 = __half2float(dst_vec->coords[i]); - float s0 = __half2float(src_vec[0].coords[i]); - partial_sum[0] = fmaf(s0 - t0, s0 - t0, partial_sum[0]); - - if (i + 32 < src_vec->Dim) { - float t1 = __half2float(dst_vec->coords[i + 32]); - float s1 = __half2float(src_vec[0].coords[i + 32]); - partial_sum[1] = fmaf(s1 - t1, s1 - t1, partial_sum[1]); - } - } - partial_sum[0] += partial_sum[1]; - for (int offset = 16; offset > 0; offset /= 2) { - partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + temp_dst[0] = dst_vec->coords[i]; + if (i + 32 < src_vec->Dim) temp_dst[1] = dst_vec->coords[i + 32]; + + l2_half_fma_sq(partial_sum[0], src_vec[0].coords[i], temp_dst[0]); + if (i + 32 < src_vec->Dim) + l2_half_fma_sq(partial_sum[1], src_vec[0].coords[i + 32], temp_dst[1]); } - return partial_sum[0]; + partial_sum[0] = __hadd(partial_sum[0], partial_sum[1]); + return l2_half_warp_reduce(partial_sum[0]); } template @@ -423,33 +441,30 @@ __device__ SUMTYPE l2_ILP4_half_warp(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec, int lane) { - float partial_sum[4] = {0, 0, 0, 0}; + __half temp_dst[4] = {__float2half(0.0f), + __float2half(0.0f), + __float2half(0.0f), + __float2half(0.0f)}; + __half partial_sum[4] = {__float2half(0.0f), + __float2half(0.0f), + __float2half(0.0f), + __float2half(0.0f)}; for (int i = lane; i < src_vec->Dim; i += 4 * VAMANA_WARP_SIZE) { - float t0 = __half2float(dst_vec->coords[i]); - float s0 = __half2float(src_vec[0].coords[i]); - partial_sum[0] = fmaf(s0 - t0, s0 - t0, partial_sum[0]); - - if (i + 32 < src_vec->Dim) { - float t1 = __half2float(dst_vec->coords[i + 32]); - float s1 = __half2float(src_vec[0].coords[i + 32]); - partial_sum[1] = fmaf(s1 - t1, s1 - t1, partial_sum[1]); - } - if (i + 64 < src_vec->Dim) { - float t2 = __half2float(dst_vec->coords[i + 64]); - float s2 = __half2float(src_vec[0].coords[i + 64]); - partial_sum[2] = fmaf(s2 - t2, s2 - t2, partial_sum[2]); - } - if (i + 96 < src_vec->Dim) { - float t3 = __half2float(dst_vec->coords[i + 96]); - float s3 = __half2float(src_vec[0].coords[i + 96]); - partial_sum[3] = fmaf(s3 - t3, s3 - t3, partial_sum[3]); - } - } - partial_sum[0] += partial_sum[1] + partial_sum[2] + partial_sum[3]; - for (int offset = 16; offset > 0; offset /= 2) { - partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + temp_dst[0] = dst_vec->coords[i]; + if (i + 32 < src_vec->Dim) temp_dst[1] = dst_vec->coords[i + 32]; + if (i + 64 < src_vec->Dim) temp_dst[2] = dst_vec->coords[i + 64]; + if (i + 96 < src_vec->Dim) temp_dst[3] = dst_vec->coords[i + 96]; + + l2_half_fma_sq(partial_sum[0], src_vec[0].coords[i], temp_dst[0]); + if (i + 32 < src_vec->Dim) + l2_half_fma_sq(partial_sum[1], src_vec[0].coords[i + 32], temp_dst[1]); + if (i + 64 < src_vec->Dim) + l2_half_fma_sq(partial_sum[2], src_vec[0].coords[i + 64], temp_dst[2]); + if (i + 96 < src_vec->Dim) + l2_half_fma_sq(partial_sum[3], src_vec[0].coords[i + 96], temp_dst[3]); } - return partial_sum[0]; + partial_sum[0] = __hadd(partial_sum[0], __hadd(partial_sum[1], __hadd(partial_sum[2], partial_sum[3]))); + return l2_half_warp_reduce(partial_sum[0]); } template @@ -487,6 +502,121 @@ __forceinline__ __device__ SUMTYPE l2_warp(const T* src, const T* dest, int dim, return l2_warp(&src_p, &dest_p, lane); } +/* float query vs half dataset: vectorized half2/float2 loads (even lane*2 indices) */ +__device__ __forceinline__ float2 l2_load_src2(const float* src, int i) +{ + return *reinterpret_cast(&src[i]); +} + +__device__ __forceinline__ float2 l2_load_dst2_half(const __half* dst, int i) +{ + return __half22float2(*reinterpret_cast(&dst[i])); +} + +template +__device__ __forceinline__ void l2_fma_sq2(SUMTYPE& acc, float sx, float sy, float2 dst2) +{ + float dx = sx - dst2.x; + float dy = sy - dst2.y; + acc = fmaf(dx, dx, acc); + acc = fmaf(dy, dy, acc); +} + +template +__device__ SUMTYPE l2_SEQ_warp_float_half(const float* src, const __half* dst, int dim, int lane) +{ + SUMTYPE partial_sum = 0; + for (int i = lane * 2; i < dim; i += VAMANA_WARP_SIZE * 2) { + float2 dst2 = l2_load_dst2_half(dst, i); + float2 src2 = l2_load_src2(src, i); + l2_fma_sq2(partial_sum, src2.x, src2.y, dst2); + } + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); + } + return partial_sum; +} + +template +__device__ SUMTYPE l2_ILP2_warp_float_half(const float* src, const __half* dst, int dim, int lane) +{ + SUMTYPE partial_sum[2] = {0, 0}; + for (int i = lane * 2; i < dim; i += 2 * VAMANA_WARP_SIZE * 2) { + float2 temp_dst[2] = {{0, 0}, {0, 0}}; + temp_dst[0] = l2_load_dst2_half(dst, i); + if (i + 64 < dim) temp_dst[1] = l2_load_dst2_half(dst, i + 64); + + float2 src0 = l2_load_src2(src, i); + l2_fma_sq2(partial_sum[0], src0.x, src0.y, temp_dst[0]); + if (i + 64 < dim) { + float2 src1 = l2_load_src2(src, i + 64); + l2_fma_sq2(partial_sum[1], src1.x, src1.y, temp_dst[1]); + } + } + partial_sum[0] += partial_sum[1]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__device__ SUMTYPE l2_ILP4_warp_float_half(const float* src, const __half* dst, int dim, int lane) +{ + SUMTYPE partial_sum[4] = {0, 0, 0, 0}; + for (int i = lane * 2; i < dim; i += 4 * VAMANA_WARP_SIZE * 2) { + float2 temp_dst[4] = {{0, 0}, {0, 0}, {0, 0}, {0, 0}}; + temp_dst[0] = l2_load_dst2_half(dst, i); + if (i + 64 < dim) temp_dst[1] = l2_load_dst2_half(dst, i + 64); + if (i + 128 < dim) temp_dst[2] = l2_load_dst2_half(dst, i + 128); + if (i + 192 < dim) temp_dst[3] = l2_load_dst2_half(dst, i + 192); + + float2 src0 = l2_load_src2(src, i); + l2_fma_sq2(partial_sum[0], src0.x, src0.y, temp_dst[0]); + if (i + 64 < dim) { + float2 src1 = l2_load_src2(src, i + 64); + l2_fma_sq2(partial_sum[1], src1.x, src1.y, temp_dst[1]); + } + if (i + 128 < dim) { + float2 src2 = l2_load_src2(src, i + 128); + l2_fma_sq2(partial_sum[2], src2.x, src2.y, temp_dst[2]); + } + if (i + 192 < dim) { + float2 src3 = l2_load_src2(src, i + 192); + l2_fma_sq2(partial_sum[3], src3.x, src3.y, temp_dst[3]); + } + } + partial_sum[0] += partial_sum[1] + partial_sum[2] + partial_sum[3]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__forceinline__ __device__ SUMTYPE +l2_warp_float_half(const float* src, const __half* dest, int dim, int lane) +{ + if (dim >= 128) { + return l2_ILP4_warp_float_half(src, dest, dim, lane); + } else if (dim >= 64) { + return l2_ILP2_warp_float_half(src, dest, dim, lane); + } else { + return l2_SEQ_warp_float_half(src, dest, dim, lane); + } +} + +template +__forceinline__ __device__ SUMTYPE dist_warp( + const float* src, const half* dest, int dim, cuvs::distance::DistanceType metric, int lane) +{ + SUMTYPE d = l2_warp_float_half(src, reinterpret_cast(dest), dim, lane); + if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + return static_cast(sqrtf(static_cast(d))); + } + return d; +} + template __forceinline__ __device__ SUMTYPE dist_warp(const T* src, const T* dest, int dim, cuvs::distance::DistanceType metric, int lane) @@ -498,6 +628,21 @@ dist_warp(const T* src, const T* dest, int dim, cuvs::distance::DistanceType met return d; } +/* Block/warp L2: float query vs half dataset (RobustPrune uses blockDim=32) */ +template +__forceinline__ __device__ SUMTYPE dist(const float* src, + const half* dest, + int dim, + cuvs::distance::DistanceType metric) +{ + SUMTYPE d = + l2_warp_float_half(src, reinterpret_cast(dest), dim, threadIdx.x); + if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + return static_cast(sqrtf(static_cast(d))); + } + return d; +} + /*************************************************************************************** * Structure that holds information about and results of a query. Use by both * GreedySearch and RobustPrune, as well as reverse edge lists. @@ -699,6 +844,48 @@ __device__ void update_shared_point_warp(Point* shared_point, } } +// Promote half dataset vector to float in shared memory (once per query) +template +__device__ void update_shared_point_half_to_float(Point* shared_point, + const half* data_ptr, + int id, + int dim) +{ + const __half* half_ptr = reinterpret_cast(data_ptr); + shared_point->id = id; + shared_point->Dim = dim; + const size_t base = (size_t)id * (size_t)dim; + for (size_t i = threadIdx.x * 2; i + 1 < (size_t)dim; i += (size_t)blockDim.x * 2) { + float2 promoted = __half22float2(*reinterpret_cast(&half_ptr[base + i])); + float2* coord_pair = reinterpret_cast(&shared_point->coords[i]); + *coord_pair = promoted; + } + if (((size_t)dim & 1u) != 0u && threadIdx.x == 0) { + shared_point->coords[dim - 1] = __half2float(half_ptr[base + dim - 1]); + } +} + +template +__device__ void update_shared_point_warp_half_to_float(Point* shared_point, + const half* data_ptr, + int id, + int dim, + int laneId) +{ + const __half* half_ptr = reinterpret_cast(data_ptr); + shared_point->id = id; + shared_point->Dim = dim; + const size_t base = (size_t)id * (size_t)dim; + for (size_t i = laneId * 2; i + 1 < (size_t)dim; i += 64) { + float2 promoted = __half22float2(*reinterpret_cast(&half_ptr[base + i])); + float2* coord_pair = reinterpret_cast(&shared_point->coords[i]); + *coord_pair = promoted; + } + if (((size_t)dim & 1u) != 0u && laneId == 0) { + shared_point->coords[dim - 1] = __half2float(half_ptr[base + dim - 1]); + } +} + // Update the graph from the results of the query list (or reverse edge list) template __global__ void write_graph_edges_kernel(raft::device_matrix_view graph, diff --git a/examples/cpp/src/vamana_example.cu b/examples/cpp/src/vamana_example.cu index b43eb20b1d..a73dd2a0b5 100644 --- a/examples/cpp/src/vamana_example.cu +++ b/examples/cpp/src/vamana_example.cu @@ -136,6 +136,19 @@ int main(int argc, char* argv[]) max_fraction, iters, codebook_prefix); + } else if (dtype == "half" || dtype == "fp16") { + // Read in binary dataset file + auto dataset = read_bin_dataset<__half, int64_t>(dev_resources, data_fname, INT_MAX); + + // Simple build example to create graph and write to a file + vamana_build_and_write<__half>(dev_resources, + raft::make_const_mdspan(dataset.view()), + out_fname, + degree, + max_visited, + max_fraction, + iters, + codebook_prefix); } else { usage(); } From 02b563800636e49f59ab45a792561808ecd2e76b Mon Sep 17 00:00:00 2001 From: bkarsin Date: Thu, 11 Jun 2026 13:06:42 -0700 Subject: [PATCH 05/14] =?UTF-8?q?perf(vamana):=20P1=5Fcache=5Fcandidate=5F?= =?UTF-8?q?vector=20=E2=80=94=20Cache=20the=20accepted=20candidate=20vecto?= =?UTF-8?q?r=20in=20shared=20memory=20in=20the=20RobustPrune=20occlusion?= =?UTF-8?q?=20loop?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit f45fd1b49283434eb4a3017da069ead501e938c3) --- cpp/src/neighbors/detail/vamana/macros.cuh | 3 + .../neighbors/detail/vamana/robust_prune.cuh | 65 ++++++++++++++++--- .../neighbors/detail/vamana/vamana_build.cuh | 5 +- 3 files changed, 62 insertions(+), 11 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/macros.cuh b/cpp/src/neighbors/detail/vamana/macros.cuh index 8ec1509677..d8e8c0ebe9 100644 --- a/cpp/src/neighbors/detail/vamana/macros.cuh +++ b/cpp/src/neighbors/detail/vamana/macros.cuh @@ -7,6 +7,9 @@ namespace cuvs::neighbors::vamana::detail { +// RobustPrune caches accepted candidate vectors in dynamic shared memory only at wide dims. +static constexpr int kRobustPruneCandCacheMinDim = 128; + /* Macros to compute the shared memory requirements for CUB primitives used by search and prune */ #define COMPUTE_SMEM_SIZE(degree, visited_size, DEG, CANDS) \ if (degree == DEG && visited_size <= CANDS && visited_size > CANDS / 2) { \ diff --git a/cpp/src/neighbors/detail/vamana/robust_prune.cuh b/cpp/src/neighbors/detail/vamana/robust_prune.cuh index 6df57e3505..f241f7d677 100644 --- a/cpp/src/neighbors/detail/vamana/robust_prune.cuh +++ b/cpp/src/neighbors/detail/vamana/robust_prune.cuh @@ -85,8 +85,14 @@ __global__ void RobustPruneKernel( int align_padding = raft::alignTo(dim, alignof(ShmemLayout)) - dim; float* occlusion_list = reinterpret_cast(smem); + const int nbh_list_offset = (degree + visited_size) * sizeof(float); DistPair* new_nbh_list = - reinterpret_cast*>(&smem[(degree + visited_size) * sizeof(float)]); + reinterpret_cast*>(&smem[nbh_list_offset]); + QueryCoordT* s_cand_coords = nullptr; + if (dim >= kRobustPruneCandCacheMinDim) { + s_cand_coords = reinterpret_cast( + &smem[nbh_list_offset + (degree + visited_size) * sizeof(DistPair)]); + } static __shared__ Point s_query; s_query.coords = &s_coords_mem[blockIdx.x * (dim + align_padding)]; @@ -201,6 +207,12 @@ __global__ void RobustPruneKernel( // If we need to prune at all... if (res_size > degree) { int accept_count = 0; + const bool cache_cand_in_smem = dim >= kRobustPruneCandCacheMinDim; + Point s_cand; + if (cache_cand_in_smem) { + s_cand.coords = s_cand_coords; + s_cand.Dim = dim; + } // Go through different alpha values. These constants are hard-coded in the MSFT DiskANN code for (float cur_alpha = 1.0; cur_alpha <= alpha && accept_count < degree; cur_alpha *= 1.2) { @@ -213,20 +225,53 @@ __global__ void RobustPruneKernel( if (new_nbh_list[pass_start].idx == queryId) { continue; } - T* cand_ptr = const_cast(&dataset((size_t)(new_nbh_list[pass_start].idx), 0)); + if (cache_cand_in_smem) { + if constexpr (is_cuda_fp16_v) { + update_shared_point_half_to_float( + &s_cand, &dataset(0, 0), new_nbh_list[pass_start].idx, dim); + } else { + update_shared_point( + &s_cand, &dataset(0, 0), new_nbh_list[pass_start].idx, dim); + } + __syncthreads(); + } occlusion_list[pass_start] = raft::lower_bound(); // Mark as "accepted" accept_count++; // Update rest of the occlusion list - for (int occId = pass_start + 1; occId < res_size; occId++) { - if (occlusion_list[occId] <= alpha && - occlusion_list[occId] != raft::lower_bound()) { - T* k_ptr = const_cast(&dataset((size_t)(new_nbh_list[occId].idx), 0)); - accT djk = dist(cand_ptr, k_ptr, dim, metric); - accT new_occ = (float)(new_nbh_list[occId].dist / djk); - - occlusion_list[occId] = std::max(occlusion_list[occId], new_occ); + if (cache_cand_in_smem) { + for (int occId = pass_start + 1; occId < res_size; occId++) { + if (occlusion_list[occId] <= alpha && + occlusion_list[occId] != raft::lower_bound()) { + T* k_ptr = const_cast(&dataset((size_t)(new_nbh_list[occId].idx), 0)); + accT djk; + if constexpr (is_cuda_fp16_v) { + djk = dist(s_cand.coords, k_ptr, dim, metric); + } else { + djk = dist(s_cand.coords, k_ptr, dim, metric); + } + accT new_occ = (float)(new_nbh_list[occId].dist / djk); + + occlusion_list[occId] = std::max(occlusion_list[occId], new_occ); + } + } + } else { + T* cand_ptr = const_cast(&dataset((size_t)(new_nbh_list[pass_start].idx), 0)); + for (int occId = pass_start + 1; occId < res_size; occId++) { + if (occlusion_list[occId] <= alpha && + occlusion_list[occId] != raft::lower_bound()) { + T* k_ptr = const_cast(&dataset((size_t)(new_nbh_list[occId].idx), 0)); + accT djk; + if constexpr (is_cuda_fp16_v) { + djk = dist(cand_ptr, k_ptr, dim, metric); + } else { + djk = dist(cand_ptr, k_ptr, dim, metric); + } + accT new_occ = (float)(new_nbh_list[occId].dist / djk); + + occlusion_list[occId] = std::max(occlusion_list[occId], new_occ); + } } } } diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index e1dbc1c59f..34226d6b17 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -197,8 +197,11 @@ void batched_insert_vamana( static_cast(4 * ((coords_size + neighbor_size + queue_size_bytes + 15) & ~15)); // Total dynamic shared memory size needed by both RobustPrune calls + const int cand_coords_smem_size = + (dim >= kRobustPruneCandCacheMinDim) ? coords_size : 0; int prune_smem_total_size = (degree + visited_size) * sizeof(float) + // Occlusion list - (degree + visited_size) * sizeof(DistPair); + (degree + visited_size) * sizeof(DistPair) + + cand_coords_smem_size; RAFT_LOG_DEBUG( "Dynamic shared memory usage (bytes): GreedySearch: %d, Segment Sort: %d, Robust Prune: %d", From 9ec9b99c26e11ab5b0d2c44077db10e0f996d333 Mon Sep 17 00:00:00 2001 From: bkarsin Date: Thu, 11 Jun 2026 13:23:57 -0700 Subject: [PATCH 06/14] =?UTF-8?q?perf(vamana):=20O1=5Fparallel=5Fscan=5Fbu?= =?UTF-8?q?ffer=5Fpool=20=E2=80=94=20Replace=20single-thread=20prefix=20su?= =?UTF-8?q?m=20with=20cub=20scan=20and=20hoist=20per-batch=20reverse-edge?= =?UTF-8?q?=20allocations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit 3b8650f5ebea52421f187332a2f6f3bdd599c42e) --- .../neighbors/detail/vamana/vamana_build.cuh | 210 ++++++++++-------- .../detail/vamana/vamana_structs.cuh | 5 +- 2 files changed, 120 insertions(+), 95 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 34226d6b17..0b10a72129 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -28,6 +28,12 @@ #include #include +#include +#include + +#include +#include + #include #include @@ -85,6 +91,26 @@ __global__ void print_queryIds(void* query_list_ptr) #define KERNEL_TIMING (RAFT_LOG_ACTIVE_LEVEL <= RAPIDS_LOGGER_LOG_LEVEL_DEBUG) +template +__global__ void gather_query_sizes(QueryCandidates* query_list, + int* edge_counts, + int count) +{ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) { + edge_counts[i] = query_list[i].size; + } +} + +template +__global__ void scatter_prefix_offsets(QueryCandidates* query_list, + const int* edge_offsets, + int count) +{ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) { + query_list[i].size = edge_offsets[i]; + } +} + /******************************************************************************************** * Main Vamana building function - insert vectors into empty graph in batches * Pre - dataset contains the vector data, host matrix allocated to store the graph @@ -103,7 +129,8 @@ void batched_insert_vamana( IdxT* medoid_id, cuvs::distance::DistanceType metric) { - auto stream = raft::resource::get_cuda_stream(res); + auto stream = raft::resource::get_cuda_stream(res); + cudaStream_t cs = stream; int N = dataset.extent(0); int dim = dataset.extent(1); int degree = graph.extent(1); @@ -222,6 +249,64 @@ void batched_insert_vamana( double batch_prune = 0.0; #endif + const int64_t max_total_edges = static_cast(max_batchsize) * degree; + const int max_reverse_batch = params.reverse_batchsize; + auto large_ws = raft::resource::get_large_workspace_resource_ref(res); + + auto edge_dist_pair = raft::make_device_mdarray>( + res, large_ws, raft::make_extents(max_total_edges)); + auto edge_dest = + raft::make_device_mdarray(res, large_ws, raft::make_extents(max_total_edges)); + auto edge_src = + raft::make_device_mdarray(res, large_ws, raft::make_extents(max_total_edges)); + + auto edge_counts = raft::make_device_mdarray( + res, large_ws, raft::make_extents(max_batchsize + 1)); + auto edge_offsets = raft::make_device_mdarray( + res, large_ws, raft::make_extents(max_batchsize + 1)); + + size_t temp_storage_bytes_dist = 0; + size_t temp_storage_bytes_edge = 0; + cub::DeviceMergeSort::SortPairs(nullptr, + temp_storage_bytes_dist, + edge_dist_pair.data_handle(), + edge_src.data_handle(), + max_total_edges, + CmpDist(), + cs); + cub::DeviceMergeSort::SortPairs(nullptr, + temp_storage_bytes_edge, + edge_dest.data_handle(), + edge_src.data_handle(), + max_total_edges, + CmpEdge(), + cs); + size_t temp_storage_bytes = + std::max(temp_storage_bytes_dist, temp_storage_bytes_edge); + auto temp_sort_storage = raft::make_device_mdarray( + res, large_ws, raft::make_extents(std::max(temp_storage_bytes, size_t{1}))); + + size_t scan_temp_bytes = 0; + cub::DeviceScan::ExclusiveSum(nullptr, + scan_temp_bytes, + edge_counts.data_handle(), + edge_offsets.data_handle(), + max_batchsize + 1, + cs); + auto scan_temp_storage = raft::make_device_mdarray( + res, large_ws, raft::make_extents(std::max(scan_temp_bytes, size_t{1}))); + + thrust::device_vector edge_dest_vec(max_total_edges); + + auto reverse_list_ptr = raft::make_device_mdarray>( + res, large_ws, raft::make_extents(max_reverse_batch)); + auto rev_ids = raft::make_device_mdarray( + res, large_ws, raft::make_extents(max_reverse_batch, visited_size)); + auto rev_dists = raft::make_device_mdarray( + res, large_ws, raft::make_extents(max_reverse_batch, visited_size)); + QueryCandidates* reverse_list = + static_cast*>(reverse_list_ptr.data_handle()); + // Random medoid has minor impact on recall // TODO: use heuristic for better medoid selection, issue: // https://github.com/rapidsai/cuvs/issues/355 @@ -322,30 +407,27 @@ void batched_insert_vamana( start_t = std::chrono::system_clock::now(); #endif - // compute prefix sums of query_list sizes - TODO parallelize prefix sums - // auto d_total_edges = raft::make_device_mdarray( - // res, raft::resource::get_workspace_resource_ref(res), raft::make_extents(1)); - rmm::device_scalar d_total_edges(stream); - prefix_sums_sizes<<<1, 1, 0, stream>>>(query_list, step_size, d_total_edges.data()); + // compute prefix sums of query_list sizes + const int prefix_count = step_size + 1; + gather_query_sizes<<>>( + query_list, edge_counts.data_handle(), prefix_count); RAFT_CUDA_TRY(cudaPeekAtLastError()); - int total_edges = d_total_edges.value(stream); - // raft::copy(&total_edges, d_total_edges.data_handle(), 1, stream); - // RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + cub::DeviceScan::ExclusiveSum(scan_temp_storage.data_handle(), + scan_temp_bytes, + edge_counts.data_handle(), + edge_offsets.data_handle(), + prefix_count, + cs); + RAFT_CUDA_TRY(cudaPeekAtLastError()); - auto edge_dist_pair = raft::make_device_mdarray>( - res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(total_edges)); - - auto edge_dest = - raft::make_device_mdarray(res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(total_edges)); - auto edge_src = - raft::make_device_mdarray(res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(total_edges)); + scatter_prefix_offsets<<>>( + query_list, edge_offsets.data_handle(), prefix_count); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + int total_edges; + raft::copy(&total_edges, edge_offsets.data_handle() + step_size, 1, stream); + raft::resource::sync_stream(res); // Create reverse edge list create_reverse_edge_list @@ -358,62 +440,20 @@ void batched_insert_vamana( { // Sort by dists first so final edge lists are each sorted by dist - void* d_temp_storage = nullptr; - size_t temp_storage_bytes = 0; - - cub::DeviceMergeSort::SortPairs(d_temp_storage, - temp_storage_bytes, - edge_dist_pair.data_handle(), - edge_src.data_handle(), - total_edges, - CmpDist(), - stream); - - RAFT_LOG_DEBUG("Temp storage needed for sorting dist (bytes): %lu", temp_storage_bytes); - - auto temp_sort_storage = raft::make_device_mdarray( - res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(temp_storage_bytes / sizeof(IdxT))); - - // Sort to group reverse edges by destination cub::DeviceMergeSort::SortPairs(temp_sort_storage.data_handle(), temp_storage_bytes, edge_dist_pair.data_handle(), edge_src.data_handle(), total_edges, CmpDist(), - stream); + cs); } - /* - DistPair* temp_ptr = edge_dist_pair.data_handle(); + DistPair* edge_dist_pair_ptr = edge_dist_pair.data_handle(); raft::linalg::map_offset( - res, edge_dest.view(), [temp_ptr] __device__(size_t i) { return temp_ptr[i].idx; }); - */ - raft::linalg::map( res, - edge_dest.view(), - [] __device__(auto x) { return x.idx; }, - raft::make_const_mdspan(edge_dist_pair.view())); - - void* d_temp_storage = nullptr; - size_t temp_storage_bytes = 0; - - cub::DeviceMergeSort::SortPairs(d_temp_storage, - temp_storage_bytes, - edge_dest.data_handle(), - edge_src.data_handle(), - total_edges, - CmpEdge(), - stream); - - RAFT_LOG_DEBUG("Temp storage needed for sorting (bytes): %lu", temp_storage_bytes); - - auto temp_sort_storage = raft::make_device_mdarray( - res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(temp_storage_bytes / sizeof(IdxT))); + raft::make_device_vector_view(edge_dest.data_handle(), total_edges), + [edge_dist_pair_ptr] __device__(size_t i) { return edge_dist_pair_ptr[i].idx; }); // Sort to group reverse edges by destination cub::DeviceMergeSort::SortPairs(temp_sort_storage.data_handle(), @@ -422,22 +462,24 @@ void batched_insert_vamana( edge_src.data_handle(), total_edges, CmpEdge(), - stream); + cs); // Get number of unique node destinations IdxT unique_dests = cuvs::sparse::neighbors::get_n_components(edge_dest.data_handle(), total_edges, stream); // Find which node IDs have reverse edges and their indices in the reverse edge list - thrust::device_vector edge_dest_vec(edge_dest.data_handle(), - edge_dest.data_handle() + total_edges); + RAFT_CUDA_TRY(cudaMemcpyAsync(thrust::raw_pointer_cast(edge_dest_vec.data()), + edge_dest.data_handle(), + total_edges * sizeof(IdxT), + cudaMemcpyDeviceToDevice, + stream)); auto unique_indices = raft::make_device_vector(res, total_edges); raft::linalg::map_offset(res, unique_indices.view(), raft::identity_op{}); - thrust::unique_by_key(edge_dest_vec.begin(), edge_dest_vec.end(), unique_indices.data_handle()); - - edge_dest_vec.clear(); - edge_dest_vec.shrink_to_fit(); + thrust::unique_by_key(edge_dest_vec.begin(), + edge_dest_vec.begin() + total_edges, + unique_indices.data_handle()); #if KERNEL_TIMING RAFT_CUDA_TRY(cudaDeviceSynchronize()); @@ -454,24 +496,6 @@ void batched_insert_vamana( reverse_batch = (int)unique_dests - rev_start; } - // Allocate reverse QueryCandidate list based on number of unique destinations - auto reverse_list_ptr = raft::make_device_mdarray>( - res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(reverse_batch)); - auto rev_ids = - raft::make_device_mdarray(res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(reverse_batch, visited_size)); - - auto rev_dists = - raft::make_device_mdarray(res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(reverse_batch, visited_size)); - - QueryCandidates* reverse_list = - static_cast*>(reverse_list_ptr.data_handle()); - init_query_candidate_list<<<256, blockD, 0, stream>>>(reverse_list, rev_ids.data_handle(), rev_dists.data_handle(), diff --git a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh index 906b6e1c4e..fd656609d5 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh @@ -73,7 +73,8 @@ __device__ __host__ void swap(DistPair* a, DistPair* b) // Structure to sort by distance template struct CmpDist { - __device__ bool operator()(const DistPair& lhs, const DistPair& rhs) + __host__ __device__ bool operator()(const DistPair& lhs, + const DistPair& rhs) { return lhs.dist < rhs.dist; } @@ -82,7 +83,7 @@ struct CmpDist { // Used to sort reverse edges by destination template struct CmpEdge { - __device__ bool operator()(const IdxT& lhs, const IdxT& rhs) { return lhs < rhs; } + __host__ __device__ bool operator()(const IdxT& lhs, const IdxT& rhs) { return lhs < rhs; } }; /********************************************************************* From 696d3d286bb0e44ffbf6db7dddcc9e67f5f1ebca Mon Sep 17 00:00:00 2001 From: bkarsin Date: Sat, 13 Jun 2026 03:18:20 -0700 Subject: [PATCH 07/14] =?UTF-8?q?perf(vamana):=20N3=5Fmultiwarp=5Frobust?= =?UTF-8?q?=5Fprune=20=E2=80=94=20Parallelize=20RobustPrune=20occlusion=20?= =?UTF-8?q?across=20multiple=20warps=20per=20query=20(raise=20occupancy)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit 2e02f938f97e65ca073daf07397d538432b52867) --- .../neighbors/detail/vamana/robust_prune.cuh | 134 ++++++++++-------- .../neighbors/detail/vamana/vamana_build.cuh | 34 ++--- 2 files changed, 91 insertions(+), 77 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/robust_prune.cuh b/cpp/src/neighbors/detail/vamana/robust_prune.cuh index f241f7d677..ee0bd8d6da 100644 --- a/cpp/src/neighbors/detail/vamana/robust_prune.cuh +++ b/cpp/src/neighbors/detail/vamana/robust_prune.cuh @@ -99,6 +99,12 @@ __global__ void RobustPruneKernel( s_query.Dim = dim; static __shared__ int prev_edges; static __shared__ accT graphDist; + static __shared__ int s_accept_count; + static __shared__ int s_do_accept; + + const int laneId = threadIdx.x & 31; + const int warpId = threadIdx.x >> 5; + const int num_warps = blockDim.x >> 5; for (int i = blockIdx.x; i < num_queries; i += gridDim.x) { int queryId = query_list[i].queryId; @@ -155,34 +161,46 @@ __global__ void RobustPruneKernel( } } else if (listIdx >= visited_size) { next_cand.idx = graph(queryId, graphIdx); - accT tempDist; - if constexpr (is_cuda_fp16_v) { - tempDist = dist(s_query.coords, - &dataset((size_t)graph(queryId, graphIdx), 0), - dim, - metric); - } else { - tempDist = dist( - s_query.coords, &dataset((size_t)graph(queryId, graphIdx), 0), dim, metric); + if (warpId == 0) { + accT tempDist; + if constexpr (is_cuda_fp16_v) { + tempDist = dist_warp(s_query.coords, + &dataset((size_t)graph(queryId, graphIdx), 0), + dim, + metric, + laneId); + } else { + tempDist = dist_warp(s_query.coords, + &dataset((size_t)graph(queryId, graphIdx), 0), + dim, + metric, + laneId); + } + if (laneId == 0) graphDist = tempDist; } - if (threadIdx.x == 0) graphDist = tempDist; __syncthreads(); next_cand.dist = graphDist; graphIdx++; } else { accT listDist = query_list[i].dists[listIdx]; - accT tempDist; - if constexpr (is_cuda_fp16_v) { - tempDist = dist(s_query.coords, - &dataset((size_t)graph(queryId, graphIdx), 0), - dim, - metric); - } else { - tempDist = dist( - s_query.coords, &dataset((size_t)graph(queryId, graphIdx), 0), dim, metric); + if (warpId == 0) { + accT tempDist; + if constexpr (is_cuda_fp16_v) { + tempDist = dist_warp(s_query.coords, + &dataset((size_t)graph(queryId, graphIdx), 0), + dim, + metric, + laneId); + } else { + tempDist = dist_warp(s_query.coords, + &dataset((size_t)graph(queryId, graphIdx), 0), + dim, + metric, + laneId); + } + if (laneId == 0) graphDist = tempDist; } - if (threadIdx.x == 0) graphDist = tempDist; __syncthreads(); if (listDist <= graphDist) { @@ -206,7 +224,8 @@ __global__ void RobustPruneKernel( // If we need to prune at all... if (res_size > degree) { - int accept_count = 0; + if (threadIdx.x == 0) s_accept_count = 0; + __syncthreads(); const bool cache_cand_in_smem = dim >= kRobustPruneCandCacheMinDim; Point s_cand; if (cache_cand_in_smem) { @@ -215,17 +234,19 @@ __global__ void RobustPruneKernel( } // Go through different alpha values. These constants are hard-coded in the MSFT DiskANN code - for (float cur_alpha = 1.0; cur_alpha <= alpha && accept_count < degree; cur_alpha *= 1.2) { - for (int pass_start = 0; pass_start < res_size && accept_count < degree; pass_start++) { - // pick next non-occluded element - if (occlusion_list[pass_start] == raft::lower_bound() || - occlusion_list[pass_start] > cur_alpha) { - continue; // Skip over elements already pruned or already accepted + for (float cur_alpha = 1.0; cur_alpha <= alpha && s_accept_count < degree; + cur_alpha *= 1.2) { + for (int pass_start = 0; pass_start < res_size && s_accept_count < degree; pass_start++) { + if (threadIdx.x == 0) { + s_do_accept = (occlusion_list[pass_start] != raft::lower_bound() && + occlusion_list[pass_start] <= cur_alpha && + new_nbh_list[pass_start].idx != queryId) + ? 1 + : 0; } + __syncthreads(); - if (new_nbh_list[pass_start].idx == queryId) { continue; } - - if (cache_cand_in_smem) { + if (s_do_accept && cache_cand_in_smem) { if constexpr (is_cuda_fp16_v) { update_shared_point_half_to_float( &s_cand, &dataset(0, 0), new_nbh_list[pass_start].idx, dim); @@ -233,54 +254,45 @@ __global__ void RobustPruneKernel( update_shared_point( &s_cand, &dataset(0, 0), new_nbh_list[pass_start].idx, dim); } - __syncthreads(); } + __syncthreads(); - occlusion_list[pass_start] = raft::lower_bound(); // Mark as "accepted" - accept_count++; - - // Update rest of the occlusion list - if (cache_cand_in_smem) { - for (int occId = pass_start + 1; occId < res_size; occId++) { - if (occlusion_list[occId] <= alpha && - occlusion_list[occId] != raft::lower_bound()) { - T* k_ptr = const_cast(&dataset((size_t)(new_nbh_list[occId].idx), 0)); - accT djk; - if constexpr (is_cuda_fp16_v) { - djk = dist(s_cand.coords, k_ptr, dim, metric); - } else { - djk = dist(s_cand.coords, k_ptr, dim, metric); - } - accT new_occ = (float)(new_nbh_list[occId].dist / djk); - - occlusion_list[occId] = std::max(occlusion_list[occId], new_occ); - } + if (s_do_accept) { + if (threadIdx.x == 0) { + occlusion_list[pass_start] = raft::lower_bound(); + s_accept_count++; } - } else { + T* cand_ptr = const_cast(&dataset((size_t)(new_nbh_list[pass_start].idx), 0)); - for (int occId = pass_start + 1; occId < res_size; occId++) { + for (int occId = pass_start + 1 + warpId; occId < res_size; occId += num_warps) { if (occlusion_list[occId] <= alpha && occlusion_list[occId] != raft::lower_bound()) { T* k_ptr = const_cast(&dataset((size_t)(new_nbh_list[occId].idx), 0)); accT djk; - if constexpr (is_cuda_fp16_v) { - djk = dist(cand_ptr, k_ptr, dim, metric); + if (cache_cand_in_smem) { + if constexpr (is_cuda_fp16_v) { + djk = dist_warp(s_cand.coords, k_ptr, dim, metric, laneId); + } else { + djk = dist_warp(s_cand.coords, k_ptr, dim, metric, laneId); + } } else { - djk = dist(cand_ptr, k_ptr, dim, metric); + djk = dist_warp(cand_ptr, k_ptr, dim, metric, laneId); + } + if (laneId == 0) { + accT new_occ = (float)(new_nbh_list[occId].dist / djk); + occlusion_list[occId] = std::max(occlusion_list[occId], new_occ); } - accT new_occ = (float)(new_nbh_list[occId].dist / djk); - - occlusion_list[occId] = std::max(occlusion_list[occId], new_occ); } } } + __syncthreads(); } } // Move all "accepted" candidates to front of list and zero out the rest if (threadIdx.x == 0) { int out_idx = 1; - for (int read_idx = 1; out_idx < accept_count; read_idx++) { + for (int read_idx = 1; out_idx < s_accept_count; read_idx++) { if (occlusion_list[read_idx] == raft::lower_bound()) { // If it is "accepted" new_nbh_list[out_idx].idx = new_nbh_list[read_idx].idx; new_nbh_list[out_idx].dist = new_nbh_list[read_idx].dist; @@ -289,12 +301,12 @@ __global__ void RobustPruneKernel( } } __syncthreads(); - for (int out_idx = accept_count + threadIdx.x; out_idx < degree; out_idx++) { + for (int out_idx = s_accept_count + threadIdx.x; out_idx < degree; out_idx += blockDim.x) { new_nbh_list[out_idx].idx = raft::upper_bound(); new_nbh_list[out_idx].dist = raft::upper_bound(); } - if (threadIdx.x == 0) { res_size = accept_count; } + if (threadIdx.x == 0) { res_size = s_accept_count; } __syncthreads(); } diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 0b10a72129..f1b3fdd484 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -49,6 +49,7 @@ namespace cuvs::neighbors::vamana::detail { static const int blockD = 32; static const int blockD_greedy = 128; // 4 warps per block, each warp processes one query +static const int blockD_prune = 128; // 4 warps per block, parallel occlusion per query static const int maxBlocks = 10000; // generate random permutation of inserts - TODO do this on GPU / faster @@ -371,14 +372,14 @@ void batched_insert_vamana( // Run on candidates of vectors being inserted RobustPruneKernel - <<>>(d_graph.view(), - dataset, - query_list_ptr.data_handle(), - step_size, - visited_size, - metric, - alpha, - s_coords_mem.data_handle()); + <<>>(d_graph.view(), + dataset, + query_list_ptr.data_handle(), + step_size, + visited_size, + metric, + alpha, + s_coords_mem.data_handle()); RAFT_CUDA_TRY(cudaPeekAtLastError()); // Segmented sort on query list @@ -525,14 +526,15 @@ void batched_insert_vamana( // Call 2nd RobustPrune on reverse query_list RobustPruneKernel - <<>>(d_graph.view(), - raft::make_const_mdspan(dataset), - reverse_list_ptr.data_handle(), - reverse_batch, - visited_size, - metric, - alpha, - s_coords_mem.data_handle()); + <<>>(d_graph.view(), + raft::make_const_mdspan( + dataset), + reverse_list_ptr.data_handle(), + reverse_batch, + visited_size, + metric, + alpha, + s_coords_mem.data_handle()); RAFT_CUDA_TRY(cudaPeekAtLastError()); // Segmented sort on reverse_list From 24f2ed7e03e54abb684d1e7ee047e5c2805a172d Mon Sep 17 00:00:00 2001 From: bkarsin Date: Sat, 13 Jun 2026 05:08:59 -0700 Subject: [PATCH 08/14] =?UTF-8?q?perf(vamana):=20N7=5Freuse=5Fsearch=5Fdis?= =?UTF-8?q?tances=5Fin=5Fprune=20=E2=80=94=20Reuse=20GreedySearch=20query-?= =?UTF-8?q?>existing-edge=20distances=20in=20the=20RobustPrune=20merge=20(?= =?UTF-8?q?avoid=20recompute)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit 041d355c585b7f98302d054b80ac287d637a07e4) --- cpp/src/neighbors/detail/vamana/macros.cuh | 2 +- .../neighbors/detail/vamana/robust_prune.cuh | 96 +++++++++---------- .../neighbors/detail/vamana/vamana_build.cuh | 3 +- .../detail/vamana/vamana_structs.cuh | 24 +++++ 4 files changed, 73 insertions(+), 52 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/macros.cuh b/cpp/src/neighbors/detail/vamana/macros.cuh index d8e8c0ebe9..714509b99e 100644 --- a/cpp/src/neighbors/detail/vamana/macros.cuh +++ b/cpp/src/neighbors/detail/vamana/macros.cuh @@ -7,7 +7,7 @@ namespace cuvs::neighbors::vamana::detail { -// RobustPrune caches accepted candidate vectors in dynamic shared memory only at wide dims. +// RobustPrune wide-dim optimizations: smem candidate cache and GreedySearch distance reuse. static constexpr int kRobustPruneCandCacheMinDim = 128; /* Macros to compute the shared memory requirements for CUB primitives used by search and prune */ diff --git a/cpp/src/neighbors/detail/vamana/robust_prune.cuh b/cpp/src/neighbors/detail/vamana/robust_prune.cuh index ee0bd8d6da..71707d03df 100644 --- a/cpp/src/neighbors/detail/vamana/robust_prune.cuh +++ b/cpp/src/neighbors/detail/vamana/robust_prune.cuh @@ -88,17 +88,21 @@ __global__ void RobustPruneKernel( const int nbh_list_offset = (degree + visited_size) * sizeof(float); DistPair* new_nbh_list = reinterpret_cast*>(&smem[nbh_list_offset]); + const int cand_coords_offset = + nbh_list_offset + (degree + visited_size) * sizeof(DistPair); + const int coord_bytes = (dim + align_padding) * static_cast(sizeof(QueryCoordT)); + int graph_dists_offset = cand_coords_offset; QueryCoordT* s_cand_coords = nullptr; if (dim >= kRobustPruneCandCacheMinDim) { - s_cand_coords = reinterpret_cast( - &smem[nbh_list_offset + (degree + visited_size) * sizeof(DistPair)]); + s_cand_coords = reinterpret_cast(&smem[cand_coords_offset]); + graph_dists_offset = cand_coords_offset + coord_bytes; } + accT* graph_dists = reinterpret_cast(&smem[graph_dists_offset]); static __shared__ Point s_query; s_query.coords = &s_coords_mem[blockIdx.x * (dim + align_padding)]; s_query.Dim = dim; static __shared__ int prev_edges; - static __shared__ accT graphDist; static __shared__ int s_accept_count; static __shared__ int s_do_accept; @@ -135,6 +139,32 @@ __global__ void RobustPruneKernel( } __syncthreads(); + // Precompute graph-edge distances in parallel; reuse GreedySearch dists when bit-exact. + const int visited_count = query_list[i].size; + const IdxT* visited_ids = query_list[i].ids; + const accT* visited_dists = query_list[i].dists; + const bool reuse_search_dists = (dim >= kRobustPruneCandCacheMinDim); + for (int j = warpId; j < prev_edges; j += num_warps) { + IdxT gid = graph(queryId, j); + accT d; + bool found = false; + if (reuse_search_dists) { + found = lookup_visited_dist_warp( + visited_ids, visited_dists, visited_count, gid, d, laneId); + } + if (!found) { + if constexpr (is_cuda_fp16_v) { + d = dist_warp( + s_query.coords, &dataset((size_t)gid, 0), dim, metric, laneId); + } else { + d = dist_warp( + s_query.coords, &dataset((size_t)gid, 0), dim, metric, laneId); + } + } + if (laneId == 0) { graph_dists[j] = d; } + } + __syncthreads(); + DistPair next_cand; // Merge graph and candidate list for (int outIdx = 0; outIdx < degree + visited_size; outIdx++) { @@ -160,60 +190,26 @@ __global__ void RobustPruneKernel( listIdx++; } } else if (listIdx >= visited_size) { - next_cand.idx = graph(queryId, graphIdx); - if (warpId == 0) { - accT tempDist; - if constexpr (is_cuda_fp16_v) { - tempDist = dist_warp(s_query.coords, - &dataset((size_t)graph(queryId, graphIdx), 0), - dim, - metric, - laneId); - } else { - tempDist = dist_warp(s_query.coords, - &dataset((size_t)graph(queryId, graphIdx), 0), - dim, - metric, - laneId); - } - if (laneId == 0) graphDist = tempDist; - } - __syncthreads(); - next_cand.dist = graphDist; + next_cand.idx = graph(queryId, graphIdx); + next_cand.dist = graph_dists[graphIdx]; graphIdx++; } else { accT listDist = query_list[i].dists[listIdx]; + IdxT listId = query_list[i].ids[listIdx]; + IdxT graphId = graph(queryId, graphIdx); - if (warpId == 0) { - accT tempDist; - if constexpr (is_cuda_fp16_v) { - tempDist = dist_warp(s_query.coords, - &dataset((size_t)graph(queryId, graphIdx), 0), - dim, - metric, - laneId); - } else { - tempDist = dist_warp(s_query.coords, - &dataset((size_t)graph(queryId, graphIdx), 0), - dim, - metric, - laneId); - } - if (laneId == 0) graphDist = tempDist; - } - __syncthreads(); - - if (listDist <= graphDist) { - next_cand.idx = query_list[i].ids[listIdx]; + if (graphId == listId) { + next_cand.idx = listId; + next_cand.dist = listDist; + graphIdx++; + listIdx++; + } else if (listDist <= graph_dists[graphIdx]) { + next_cand.idx = listId; next_cand.dist = listDist; - - if (graph(queryId, graphIdx) == query_list[i].ids[listIdx]) { // Duplicate found! - graphIdx++; // Skip the duplicate - } listIdx++; } else { - next_cand.idx = graph(queryId, graphIdx); - next_cand.dist = graphDist; + next_cand.idx = graphId; + next_cand.dist = graph_dists[graphIdx]; graphIdx++; } } diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index f1b3fdd484..c7330a9cde 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -229,7 +229,8 @@ void batched_insert_vamana( (dim >= kRobustPruneCandCacheMinDim) ? coords_size : 0; int prune_smem_total_size = (degree + visited_size) * sizeof(float) + // Occlusion list (degree + visited_size) * sizeof(DistPair) + - cand_coords_smem_size; + cand_coords_smem_size + + degree * static_cast(sizeof(accT)); // graph edge dist cache RAFT_LOG_DEBUG( "Dynamic shared memory usage (bytes): GreedySearch: %d, Segment Sort: %d, Robust Prune: %d", diff --git a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh index fd656609d5..ac8155c6ef 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh @@ -644,6 +644,30 @@ __forceinline__ __device__ SUMTYPE dist(const float* src, return d; } +// Warp-cooperative id lookup in a GreedySearch visited list (sorted by distance, not id). +// All lanes execute the same number of ballot rounds to avoid warp deadlock. +template +__forceinline__ __device__ bool lookup_visited_dist_warp( + const IdxT* ids, const accT* dists, int size, IdxT target, accT& out_dist, int laneId) +{ + bool found = false; + accT found_dist = static_cast(0); + const int num_iters = (size + 31) >> 5; + for (int k = 0; k < num_iters; ++k) { + const int j = (k << 5) + laneId; + const bool hit = (j < size) && (ids[j] == target); + accT my_dist = hit ? dists[j] : static_cast(0); + const unsigned hits = raft::ballot(hit); + if (hits != 0 && !found) { + const int src_lane = __ffs(hits) - 1; + found_dist = raft::shfl(my_dist, src_lane); + found = true; + } + } + if (found) { out_dist = found_dist; } + return found; +} + /*************************************************************************************** * Structure that holds information about and results of a query. Use by both * GreedySearch and RobustPrune, as well as reverse edge lists. From bbe3d502c1dcaee34cc608a3286fe7e87588963b Mon Sep 17 00:00:00 2001 From: bkarsin Date: Mon, 15 Jun 2026 19:09:43 -0700 Subject: [PATCH 09/14] =?UTF-8?q?perf(vamana):=20M3=5Ffp16=5Fquery=5Fsmem?= =?UTF-8?q?=5Foccupancy=20=E2=80=94=20Store=20the=20cached=20query=20coord?= =?UTF-8?q?s=20in=20FP16=20smem=20for=20dim>=3D512=20to=20raise=20GreedySe?= =?UTF-8?q?arch=20occupancy=20(salvage=20of=20N8)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../neighbors/detail/vamana/greedy_search.cuh | 84 +++-- .../detail/vamana/priority_queue.cuh | 39 +- .../neighbors/detail/vamana/vamana_build.cuh | 5 +- .../detail/vamana/vamana_structs.cuh | 354 ++++++++++++++++++ 4 files changed, 450 insertions(+), 32 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/greedy_search.cuh b/cpp/src/neighbors/detail/vamana/greedy_search.cuh index f23248222f..2511dccf52 100644 --- a/cpp/src/neighbors/detail/vamana/greedy_search.cuh +++ b/cpp/src/neighbors/detail/vamana/greedy_search.cuh @@ -111,23 +111,22 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( using QueryCoordT = typename greedy_search_query_coord::type; - union ShmemLayout { - QueryCoordT coords; - IdxT neighborhood_arr; - DistPair candidate_queue; - }; int align_padding = raft::alignTo(dim, 16) - dim; - extern __shared__ __align__(alignof(ShmemLayout)) char smem[]; + const bool fp16_query_smem = greedy_search_use_fp16_query_smem(dim); + + extern __shared__ __align__(16) char smem[]; // Per-warp shared memory layout: coords, neighbor_array, candidate_queue - const int coords_size = (dim + align_padding) * sizeof(QueryCoordT); + const int coords_size = + (dim + align_padding) * greedy_search_query_smem_elem_size(dim); const int neighbor_size = degree * sizeof(IdxT); const int queue_size_bytes = max_queue_size * sizeof(DistPair); const int per_warp_size = (coords_size + neighbor_size + queue_size_bytes + 15) & ~15; char* warp_smem = &smem[warpIdx * per_warp_size]; - QueryCoordT* s_coords = reinterpret_cast(warp_smem); + __half* s_coords_half = reinterpret_cast<__half*>(warp_smem); + QueryCoordT* s_coords = reinterpret_cast(warp_smem); IdxT* neighbor_array = reinterpret_cast(warp_smem + coords_size); DistPair* candidate_queue_smem = reinterpret_cast*>(warp_smem + coords_size + neighbor_size); @@ -138,9 +137,15 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( static __shared__ int k_max_idx[4]; static __shared__ int num_neighbors[4]; + Point<__half, accT> s_query_half; Point s_query; - s_query.Dim = dim; - s_query.coords = s_coords; + if (fp16_query_smem) { + s_query_half.Dim = dim; + s_query_half.coords = s_coords_half; + } else { + s_query.Dim = dim; + s_query.coords = s_coords; + } PriorityQueue heap_queue; if (laneId == 0) { @@ -153,35 +158,54 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( for (int i = blockIdx.x * 4 + warpIdx; i < num_queries; i += gridDim.x * 4) { query_list[i].reset_warp(laneId); - if constexpr (is_cuda_fp16_v) { + int cur_query_id = query_list[i].queryId; + if (fp16_query_smem) { + if constexpr (is_cuda_fp16_v) { + update_shared_point_warp_fp16_query_smem( + &s_query_half, vec_ptr, cur_query_id, dim, laneId); + } else if constexpr (std::is_same_v) { + update_shared_point_warp_fp16_query_smem( + &s_query_half, vec_ptr, cur_query_id, dim, laneId); + } else { + update_shared_point_warp_fp16_query_smem( + &s_query_half, vec_ptr, cur_query_id, dim, laneId); + } + } else if constexpr (is_cuda_fp16_v) { update_shared_point_warp_half_to_float( - &s_query, vec_ptr, query_list[i].queryId, dim, laneId); + &s_query, vec_ptr, cur_query_id, dim, laneId); } else { - update_shared_point_warp(&s_query, vec_ptr, query_list[i].queryId, dim, laneId); + update_shared_point_warp(&s_query, vec_ptr, cur_query_id, dim, laneId); } if (laneId == 0) { topk_q_size[warpIdx] = 0; cand_q_size[warpIdx] = 0; - s_query.id = query_list[i].queryId; + if (fp16_query_smem) { + s_query_half.id = cur_query_id; + } else { + s_query.id = cur_query_id; + } cur_k_max[warpIdx] = 0; k_max_idx[warpIdx] = 0; heap_queue.reset(); } - Point* query_vec = &s_query; - query_vec->Dim = dim; - query_vec->coords = s_coords; accT medoid_dist; - if constexpr (is_cuda_fp16_v) { - medoid_dist = dist_warp(query_vec->coords, + if (fp16_query_smem) { + medoid_dist = dist_warp_half_query(s_coords_half, + &vec_ptr[(size_t)medoid_id * (size_t)dim], + dim, + metric, + laneId); + } else if constexpr (is_cuda_fp16_v) { + medoid_dist = dist_warp(s_coords, &vec_ptr[(size_t)medoid_id * (size_t)dim], dim, metric, laneId); } else { medoid_dist = dist_warp( - query_vec->coords, &vec_ptr[(size_t)medoid_id * (size_t)dim], dim, metric, laneId); + s_coords, &vec_ptr[(size_t)medoid_id * (size_t)dim], dim, metric, laneId); } if (laneId == 0) { heap_queue.insert_back(medoid_dist, medoid_id); } @@ -242,19 +266,21 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( atomicMin(&num_neighbors[warpIdx], (int)j); } - enqueue_all_neighbors_warp(num_neighbors[warpIdx], - query_vec, - vec_ptr, - neighbor_array, - heap_queue, - dim, - metric, - laneId); + enqueue_all_neighbors_warp(num_neighbors[warpIdx], + fp16_query_smem, + s_coords_half, + s_coords, + vec_ptr, + neighbor_array, + heap_queue, + dim, + metric, + laneId); } bool self_found = false; for (int j = laneId; j < query_list[i].size; j += 32) { - if (query_list[i].ids[j] == query_vec->id) { + if (query_list[i].ids[j] == cur_query_id) { query_list[i].dists[j] = raft::upper_bound(); query_list[i].ids[j] = raft::upper_bound(); self_found = true; diff --git a/cpp/src/neighbors/detail/vamana/priority_queue.cuh b/cpp/src/neighbors/detail/vamana/priority_queue.cuh index 891065bf97..fec95cd15e 100644 --- a/cpp/src/neighbors/detail/vamana/priority_queue.cuh +++ b/cpp/src/neighbors/detail/vamana/priority_queue.cuh @@ -397,7 +397,10 @@ __forceinline__ __device__ void enqueue_all_neighbors_warp(int num_neighbors, for (int i = 0; i < num_neighbors; i++) { const DataT* neighbor_vec = &vec_ptr[(size_t)(neighbor_array[i]) * (size_t)(dim)]; accT dist_out; - if constexpr (std::is_same_v && is_cuda_fp16_v) { + if constexpr (std::is_same_v) { + dist_out = dist_warp_half_query( + query_vec->coords, neighbor_vec, dim, metric, laneId); + } else if constexpr (std::is_same_v && is_cuda_fp16_v) { dist_out = dist_warp(query_vec->coords, neighbor_vec, dim, metric, laneId); } else { static_assert(std::is_same_v); @@ -407,4 +410,38 @@ __forceinline__ __device__ void enqueue_all_neighbors_warp(int num_neighbors, } } +template +__forceinline__ __device__ void enqueue_all_neighbors_warp(int num_neighbors, + bool fp16_query_smem, + __half* s_coords_half, + typename greedy_search_query_coord::type* + s_coords, + const T* vec_ptr, + IdxT* neighbor_array, + PriorityQueue& heap_queue, + int dim, + cuvs::distance::DistanceType metric, + int laneId) +{ + if (fp16_query_smem) { + Point<__half, accT> query_vec; + query_vec.coords = s_coords_half; + query_vec.Dim = dim; + enqueue_all_neighbors_warp<__half, T, accT, IdxT>( + num_neighbors, &query_vec, vec_ptr, neighbor_array, heap_queue, dim, metric, laneId); + } else if constexpr (is_cuda_fp16_v) { + Point query_vec; + query_vec.coords = reinterpret_cast(s_coords); + query_vec.Dim = dim; + enqueue_all_neighbors_warp( + num_neighbors, &query_vec, vec_ptr, neighbor_array, heap_queue, dim, metric, laneId); + } else { + Point query_vec; + query_vec.coords = reinterpret_cast(s_coords); + query_vec.Dim = dim; + enqueue_all_neighbors_warp( + num_neighbors, &query_vec, vec_ptr, neighbor_array, heap_queue, dim, metric, laneId); + } +} + } // namespace cuvs::neighbors::vamana::detail diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index c7330a9cde..fd7a60b9d5 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -217,12 +217,13 @@ void batched_insert_vamana( SELECT_SORT_SMEM_SIZE(degree, visited_size); // Sets sort_smem_size based on dataset // GreedySearch: per-warp shared memory (4 warps): coords, neighbor_array, candidate_queue - // Half datasets promote query coords to float in smem. + const int search_coords_size = + (dim + align_padding) * greedy_search_query_smem_elem_size(dim); const int coords_size = (dim + align_padding) * static_cast(sizeof(QueryCoordT)); const int neighbor_size = degree * sizeof(IdxT); const int queue_size_bytes = queue_size * sizeof(DistPair); int search_smem_total_size = - static_cast(4 * ((coords_size + neighbor_size + queue_size_bytes + 15) & ~15)); + static_cast(4 * ((search_coords_size + neighbor_size + queue_size_bytes + 15) & ~15)); // Total dynamic shared memory size needed by both RobustPrune calls const int cand_coords_smem_size = diff --git a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh index ac8155c6ef..30b53577d9 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh @@ -47,6 +47,24 @@ struct greedy_search_query_coord { using type = std::conditional_t, float, T>; }; +// Wide vectors: store cached GreedySearch query coords as __half in smem (dim >= 512 only). +// Half dataset is excluded: it already caches query as float (promote once); fp16 smem would +// re-widen on every distance and lose the vectorized float-query vs half-neighbor path. +static constexpr int kGreedySearchFp16QuerySmemMinDim = 512; + +template +__host__ __device__ inline bool greedy_search_use_fp16_query_smem(int dim) +{ + return dim >= kGreedySearchFp16QuerySmemMinDim && !is_cuda_fp16_v; +} + +template +__host__ __device__ inline int greedy_search_query_smem_elem_size(int dim) +{ + if (greedy_search_use_fp16_query_smem(dim)) { return static_cast(sizeof(__half)); } + return static_cast(sizeof(typename greedy_search_query_coord::type)); +} + // Currently supported values for graph_degree. static const int DEGREE_SIZES[4] = {32, 64, 128, 256}; @@ -607,6 +625,287 @@ l2_warp_float_half(const float* src, const __half* dest, int dim, int lane) } } +/* half query vs float dataset: widen query coords to float at use (mirror of float_half) */ +__device__ __forceinline__ float2 l2_load_src2_half(const __half* src, int i) +{ + return __half22float2(*reinterpret_cast(&src[i])); +} + +template +__device__ SUMTYPE l2_SEQ_warp_half_float(const __half* src, const float* dst, int dim, int lane) +{ + SUMTYPE partial_sum = 0; + for (int i = lane * 2; i < dim; i += VAMANA_WARP_SIZE * 2) { + float2 src2 = l2_load_src2_half(src, i); + float2 dst2 = l2_load_src2(dst, i); + l2_fma_sq2(partial_sum, src2.x, src2.y, dst2); + } + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); + } + return partial_sum; +} + +template +__device__ SUMTYPE l2_ILP2_warp_half_float(const __half* src, const float* dst, int dim, int lane) +{ + SUMTYPE partial_sum[2] = {0, 0}; + for (int i = lane * 2; i < dim; i += 2 * VAMANA_WARP_SIZE * 2) { + float2 temp_dst[2] = {{0, 0}, {0, 0}}; + temp_dst[0] = l2_load_src2(dst, i); + if (i + 64 < dim) temp_dst[1] = l2_load_src2(dst, i + 64); + + float2 src0 = l2_load_src2_half(src, i); + l2_fma_sq2(partial_sum[0], src0.x, src0.y, temp_dst[0]); + if (i + 64 < dim) { + float2 src1 = l2_load_src2_half(src, i + 64); + l2_fma_sq2(partial_sum[1], src1.x, src1.y, temp_dst[1]); + } + } + partial_sum[0] += partial_sum[1]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__device__ SUMTYPE l2_ILP4_warp_half_float(const __half* src, const float* dst, int dim, int lane) +{ + SUMTYPE partial_sum[4] = {0, 0, 0, 0}; + for (int i = lane * 2; i < dim; i += 4 * VAMANA_WARP_SIZE * 2) { + float2 temp_dst[4] = {{0, 0}, {0, 0}, {0, 0}, {0, 0}}; + temp_dst[0] = l2_load_src2(dst, i); + if (i + 64 < dim) temp_dst[1] = l2_load_src2(dst, i + 64); + if (i + 128 < dim) temp_dst[2] = l2_load_src2(dst, i + 128); + if (i + 192 < dim) temp_dst[3] = l2_load_src2(dst, i + 192); + + float2 src0 = l2_load_src2_half(src, i); + l2_fma_sq2(partial_sum[0], src0.x, src0.y, temp_dst[0]); + if (i + 64 < dim) { + float2 src1 = l2_load_src2_half(src, i + 64); + l2_fma_sq2(partial_sum[1], src1.x, src1.y, temp_dst[1]); + } + if (i + 128 < dim) { + float2 src2 = l2_load_src2_half(src, i + 128); + l2_fma_sq2(partial_sum[2], src2.x, src2.y, temp_dst[2]); + } + if (i + 192 < dim) { + float2 src3 = l2_load_src2_half(src, i + 192); + l2_fma_sq2(partial_sum[3], src3.x, src3.y, temp_dst[3]); + } + } + partial_sum[0] += partial_sum[1] + partial_sum[2] + partial_sum[3]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__forceinline__ __device__ SUMTYPE +l2_warp_half_float(const __half* src, const float* dest, int dim, int lane) +{ + if (dim >= 128) { + return l2_ILP4_warp_half_float(src, dest, dim, lane); + } else if (dim >= 64) { + return l2_ILP2_warp_half_float(src, dest, dim, lane); + } else { + return l2_SEQ_warp_half_float(src, dest, dim, lane); + } +} + +/* fp16 query smem vs half dataset: vectorized half2 loads, widen to float, float accumulate + * (mirror of l2_warp_float_half; avoids scalar smem loads and native half FMA) */ +template +__device__ SUMTYPE l2_SEQ_warp_half_smem_half(const __half* src, const __half* dst, int dim, int lane) +{ + SUMTYPE partial_sum = 0; + for (int i = lane * 2; i < dim; i += VAMANA_WARP_SIZE * 2) { + float2 dst2 = l2_load_dst2_half(dst, i); + float2 src2 = l2_load_src2_half(src, i); + l2_fma_sq2(partial_sum, src2.x, src2.y, dst2); + } + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); + } + return partial_sum; +} + +template +__device__ SUMTYPE l2_ILP2_warp_half_smem_half(const __half* src, const __half* dst, int dim, int lane) +{ + SUMTYPE partial_sum[2] = {0, 0}; + for (int i = lane * 2; i < dim; i += 2 * VAMANA_WARP_SIZE * 2) { + float2 temp_dst[2] = {{0, 0}, {0, 0}}; + temp_dst[0] = l2_load_dst2_half(dst, i); + if (i + 64 < dim) temp_dst[1] = l2_load_dst2_half(dst, i + 64); + + float2 src0 = l2_load_src2_half(src, i); + l2_fma_sq2(partial_sum[0], src0.x, src0.y, temp_dst[0]); + if (i + 64 < dim) { + float2 src1 = l2_load_src2_half(src, i + 64); + l2_fma_sq2(partial_sum[1], src1.x, src1.y, temp_dst[1]); + } + } + partial_sum[0] += partial_sum[1]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__device__ SUMTYPE l2_ILP4_warp_half_smem_half(const __half* src, const __half* dst, int dim, int lane) +{ + SUMTYPE partial_sum[4] = {0, 0, 0, 0}; + for (int i = lane * 2; i < dim; i += 4 * VAMANA_WARP_SIZE * 2) { + float2 temp_dst[4] = {{0, 0}, {0, 0}, {0, 0}, {0, 0}}; + temp_dst[0] = l2_load_dst2_half(dst, i); + if (i + 64 < dim) temp_dst[1] = l2_load_dst2_half(dst, i + 64); + if (i + 128 < dim) temp_dst[2] = l2_load_dst2_half(dst, i + 128); + if (i + 192 < dim) temp_dst[3] = l2_load_dst2_half(dst, i + 192); + + float2 src0 = l2_load_src2_half(src, i); + l2_fma_sq2(partial_sum[0], src0.x, src0.y, temp_dst[0]); + if (i + 64 < dim) { + float2 src1 = l2_load_src2_half(src, i + 64); + l2_fma_sq2(partial_sum[1], src1.x, src1.y, temp_dst[1]); + } + if (i + 128 < dim) { + float2 src2 = l2_load_src2_half(src, i + 128); + l2_fma_sq2(partial_sum[2], src2.x, src2.y, temp_dst[2]); + } + if (i + 192 < dim) { + float2 src3 = l2_load_src2_half(src, i + 192); + l2_fma_sq2(partial_sum[3], src3.x, src3.y, temp_dst[3]); + } + } + partial_sum[0] += partial_sum[1] + partial_sum[2] + partial_sum[3]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__forceinline__ __device__ SUMTYPE +l2_warp_half_smem_half(const __half* src, const __half* dest, int dim, int lane) +{ + if (dim >= 128) { + return l2_ILP4_warp_half_smem_half(src, dest, dim, lane); + } else if (dim >= 64) { + return l2_ILP2_warp_half_smem_half(src, dest, dim, lane); + } else { + return l2_SEQ_warp_half_smem_half(src, dest, dim, lane); + } +} + +/* fp16 query smem vs int8 (or other native) dataset: same vectorized query widen, float accumulate */ +template +__device__ __forceinline__ void l2_fma_sq2_half_native(SUMTYPE& acc, + float2 src2, + const DataT* dst, + int i) +{ + float dx = src2.x - static_cast(dst[i]); + float dy = src2.y - static_cast(dst[i + 1]); + acc = fmaf(dx, dx, acc); + acc = fmaf(dy, dy, acc); +} + +template +__device__ SUMTYPE l2_SEQ_warp_half_smem_native(const __half* src, const DataT* dst, int dim, int lane) +{ + SUMTYPE partial_sum = 0; + for (int i = lane * 2; i < dim; i += VAMANA_WARP_SIZE * 2) { + float2 src2 = l2_load_src2_half(src, i); + l2_fma_sq2_half_native(partial_sum, src2, dst, i); + } + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); + } + return partial_sum; +} + +template +__device__ SUMTYPE l2_ILP2_warp_half_smem_native(const __half* src, const DataT* dst, int dim, int lane) +{ + SUMTYPE partial_sum[2] = {0, 0}; + for (int i = lane * 2; i < dim; i += 2 * VAMANA_WARP_SIZE * 2) { + float2 src0 = l2_load_src2_half(src, i); + l2_fma_sq2_half_native(partial_sum[0], src0, dst, i); + if (i + 64 < dim) { + float2 src1 = l2_load_src2_half(src, i + 64); + l2_fma_sq2_half_native(partial_sum[1], src1, dst, i + 64); + } + } + partial_sum[0] += partial_sum[1]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__device__ SUMTYPE l2_ILP4_warp_half_smem_native(const __half* src, const DataT* dst, int dim, int lane) +{ + SUMTYPE partial_sum[4] = {0, 0, 0, 0}; + for (int i = lane * 2; i < dim; i += 4 * VAMANA_WARP_SIZE * 2) { + float2 src0 = l2_load_src2_half(src, i); + l2_fma_sq2_half_native(partial_sum[0], src0, dst, i); + if (i + 64 < dim) { + float2 src1 = l2_load_src2_half(src, i + 64); + l2_fma_sq2_half_native(partial_sum[1], src1, dst, i + 64); + } + if (i + 128 < dim) { + float2 src2 = l2_load_src2_half(src, i + 128); + l2_fma_sq2_half_native(partial_sum[2], src2, dst, i + 128); + } + if (i + 192 < dim) { + float2 src3 = l2_load_src2_half(src, i + 192); + l2_fma_sq2_half_native(partial_sum[3], src3, dst, i + 192); + } + } + partial_sum[0] += partial_sum[1] + partial_sum[2] + partial_sum[3]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__forceinline__ __device__ SUMTYPE +l2_warp_half_smem_native(const __half* src, const DataT* dest, int dim, int lane) +{ + if (dim >= 128) { + return l2_ILP4_warp_half_smem_native(src, dest, dim, lane); + } else if (dim >= 64) { + return l2_ILP2_warp_half_smem_native(src, dest, dim, lane); + } else { + return l2_SEQ_warp_half_smem_native(src, dest, dim, lane); + } +} + +template +__forceinline__ __device__ SUMTYPE +dist_warp_half_query( + const __half* src, const DataT* dest, int dim, cuvs::distance::DistanceType metric, int lane) +{ + SUMTYPE d; + if constexpr (is_cuda_fp16_v) { + d = l2_warp_half_smem_half(src, reinterpret_cast(dest), dim, lane); + } else if constexpr (std::is_same_v) { + d = l2_warp_half_float(src, dest, dim, lane); + } else { + d = l2_warp_half_smem_native(src, dest, dim, lane); + } + if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + return static_cast(sqrtf(static_cast(d))); + } + return d; +} + template __forceinline__ __device__ SUMTYPE dist_warp( const float* src, const half* dest, int dim, cuvs::distance::DistanceType metric, int lane) @@ -911,6 +1210,61 @@ __device__ void update_shared_point_warp_half_to_float(Point* share } } +template +__device__ void update_shared_point_warp_fp16_query_smem(Point<__half, accT>* shared_point, + const half* data_ptr, + int id, + int dim, + int laneId) +{ + const __half* half_ptr = reinterpret_cast(data_ptr); + shared_point->id = id; + shared_point->Dim = dim; + const size_t base = (size_t)id * (size_t)dim; + for (size_t i = laneId * 2; i + 1 < (size_t)dim; i += 64) { + *reinterpret_cast(&shared_point->coords[i]) = + *reinterpret_cast(&half_ptr[base + i]); + } + if (((size_t)dim & 1u) != 0u && laneId == 0) { + shared_point->coords[dim - 1] = half_ptr[base + dim - 1]; + } +} + +template +__device__ void update_shared_point_warp_fp16_query_smem(Point<__half, accT>* shared_point, + const float* data_ptr, + int id, + int dim, + int laneId) +{ + shared_point->id = id; + shared_point->Dim = dim; + const size_t base = (size_t)id * (size_t)dim; + for (size_t i = laneId * 2; i + 1 < (size_t)dim; i += 64) { + float2 v = *reinterpret_cast(&data_ptr[base + i]); + *reinterpret_cast(&shared_point->coords[i]) = __float22half2_rn(v); + } + if (((size_t)dim & 1u) != 0u && laneId == 0) { + shared_point->coords[dim - 1] = __float2half(data_ptr[base + dim - 1]); + } +} + +template +__device__ std::enable_if_t && !std::is_same_v, void> +update_shared_point_warp_fp16_query_smem(Point<__half, accT>* shared_point, + const T* data_ptr, + int id, + int dim, + int laneId) +{ + shared_point->id = id; + shared_point->Dim = dim; + const size_t base = (size_t)id * (size_t)dim; + for (size_t i = laneId; i < (size_t)dim; i += 32) { + shared_point->coords[i] = __float2half(static_cast(data_ptr[base + i])); + } +} + // Update the graph from the results of the query list (or reverse edge list) template __global__ void write_graph_edges_kernel(raft::device_matrix_view graph, From eeca2865d3d46b70af8f45663198767ecd0e9848 Mon Sep 17 00:00:00 2001 From: bkarsin Date: Mon, 15 Jun 2026 20:43:59 -0700 Subject: [PATCH 10/14] =?UTF-8?q?perf(vamana):=20M5=5Fprune=5Fsingle=5Fwar?= =?UTF-8?q?p=5Fmerge=20=E2=80=94=20Do=20the=20RobustPrune=20merge=20once?= =?UTF-8?q?=20(one=20warp)=20instead=20of=20redundantly=20on=20all=20128?= =?UTF-8?q?=20threads,=20then=20broadcast?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../neighbors/detail/vamana/robust_prune.cuh | 137 ++++++++++-------- .../neighbors/detail/vamana/vamana_build.cuh | 4 +- 2 files changed, 81 insertions(+), 60 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/robust_prune.cuh b/cpp/src/neighbors/detail/vamana/robust_prune.cuh index 71707d03df..d4bd9beb02 100644 --- a/cpp/src/neighbors/detail/vamana/robust_prune.cuh +++ b/cpp/src/neighbors/detail/vamana/robust_prune.cuh @@ -88,16 +88,22 @@ __global__ void RobustPruneKernel( const int nbh_list_offset = (degree + visited_size) * sizeof(float); DistPair* new_nbh_list = reinterpret_cast*>(&smem[nbh_list_offset]); - const int cand_coords_offset = + const int query_cache_offset = nbh_list_offset + (degree + visited_size) * sizeof(DistPair); + DistPair* query_cache = + reinterpret_cast*>(&smem[query_cache_offset]); + const int cand_coords_offset = + query_cache_offset + visited_size * sizeof(DistPair); const int coord_bytes = (dim + align_padding) * static_cast(sizeof(QueryCoordT)); int graph_dists_offset = cand_coords_offset; QueryCoordT* s_cand_coords = nullptr; if (dim >= kRobustPruneCandCacheMinDim) { - s_cand_coords = reinterpret_cast(&smem[cand_coords_offset]); - graph_dists_offset = cand_coords_offset + coord_bytes; + s_cand_coords = reinterpret_cast(&smem[cand_coords_offset]); + graph_dists_offset = cand_coords_offset + coord_bytes; } accT* graph_dists = reinterpret_cast(&smem[graph_dists_offset]); + const int graph_ids_offset = graph_dists_offset + degree * sizeof(accT); + IdxT* graph_ids = reinterpret_cast(&smem[graph_ids_offset]); static __shared__ Point s_query; s_query.coords = &s_coords_mem[blockIdx.x * (dim + align_padding)]; @@ -105,6 +111,7 @@ __global__ void RobustPruneKernel( static __shared__ int prev_edges; static __shared__ int s_accept_count; static __shared__ int s_do_accept; + static __shared__ int s_res_size; const int laneId = threadIdx.x & 31; const int warpId = threadIdx.x >> 5; @@ -119,10 +126,6 @@ __global__ void RobustPruneKernel( update_shared_point(&s_query, &dataset(0, 0), queryId, dim, i); } - int graphIdx = 0; - int listIdx = 0; - int res_size = degree + visited_size; - // Count total valid edge candidates __syncthreads(); if (threadIdx.x == 0) { @@ -136,6 +139,10 @@ __global__ void RobustPruneKernel( } for (int j = threadIdx.x; j < degree + visited_size; j += blockDim.x) { occlusion_list[j] = 0.0; + if (j < visited_size) { + query_cache[j].idx = query_list[i].ids[j]; + query_cache[j].dist = query_list[i].dists[j]; + } } __syncthreads(); @@ -161,65 +168,77 @@ __global__ void RobustPruneKernel( s_query.coords, &dataset((size_t)gid, 0), dim, metric, laneId); } } - if (laneId == 0) { graph_dists[j] = d; } + if (laneId == 0) { + graph_dists[j] = d; + graph_ids[j] = gid; + } + } + for (int j = threadIdx.x; j < degree; j += blockDim.x) { + if (j >= prev_edges) { graph_ids[j] = raft::upper_bound(); } } __syncthreads(); - DistPair next_cand; - // Merge graph and candidate list - for (int outIdx = 0; outIdx < degree + visited_size; outIdx++) { - // Check if no more valid elements from graph or list - if (graphIdx < degree && graph(queryId, graphIdx) == raft::upper_bound()) { - graphIdx = degree; - } - if (listIdx < visited_size && query_list[i].ids[listIdx] == raft::upper_bound()) { - listIdx = visited_size; - } - - // Get next candidate vector for list - if (graphIdx >= degree) { - if (listIdx >= visited_size) { // Fill remaining list if no candidates - if (res_size > outIdx) res_size = outIdx; // Set result size - new_nbh_list[outIdx].idx = raft::upper_bound(); - new_nbh_list[outIdx].dist = raft::upper_bound(); - __syncthreads(); - continue; - } else { - next_cand.idx = query_list[i].ids[listIdx]; - next_cand.dist = query_list[i].dists[listIdx]; - listIdx++; + if (threadIdx.x == 0) { + int graphIdx = 0; + int listIdx = 0; + int merged_size = degree + visited_size; + + DistPair next_cand; + // Merge graph and candidate list from smem (no global reads during merge). + for (int outIdx = 0; outIdx < degree + visited_size; outIdx++) { + // Check if no more valid elements from graph or list + if (graphIdx < degree && graph_ids[graphIdx] == raft::upper_bound()) { + graphIdx = degree; } - } else if (listIdx >= visited_size) { - next_cand.idx = graph(queryId, graphIdx); - next_cand.dist = graph_dists[graphIdx]; - graphIdx++; - } else { - accT listDist = query_list[i].dists[listIdx]; - IdxT listId = query_list[i].ids[listIdx]; - IdxT graphId = graph(queryId, graphIdx); - - if (graphId == listId) { - next_cand.idx = listId; - next_cand.dist = listDist; - graphIdx++; - listIdx++; - } else if (listDist <= graph_dists[graphIdx]) { - next_cand.idx = listId; - next_cand.dist = listDist; - listIdx++; - } else { - next_cand.idx = graphId; + if (listIdx < visited_size && query_cache[listIdx].idx == raft::upper_bound()) { + listIdx = visited_size; + } + + // Get next candidate vector for list + if (graphIdx >= degree) { + if (listIdx >= visited_size) { // Fill remaining list if no candidates + if (merged_size > outIdx) merged_size = outIdx; // Set result size + new_nbh_list[outIdx].idx = raft::upper_bound(); + new_nbh_list[outIdx].dist = raft::upper_bound(); + continue; + } else { + next_cand = query_cache[listIdx]; + listIdx++; + } + } else if (listIdx >= visited_size) { + next_cand.idx = graph_ids[graphIdx]; next_cand.dist = graph_dists[graphIdx]; graphIdx++; + } else { + accT listDist = query_cache[listIdx].dist; + IdxT listId = query_cache[listIdx].idx; + IdxT graphId = graph_ids[graphIdx]; + + if (graphId == listId) { + next_cand.idx = listId; + next_cand.dist = listDist; + graphIdx++; + listIdx++; + } else if (listDist <= graph_dists[graphIdx]) { + next_cand.idx = listId; + next_cand.dist = listDist; + listIdx++; + } else { + next_cand.idx = graphId; + next_cand.dist = graph_dists[graphIdx]; + graphIdx++; + } } - } - new_nbh_list[outIdx].idx = next_cand.idx; - new_nbh_list[outIdx].dist = next_cand.dist; + new_nbh_list[outIdx].idx = next_cand.idx; + new_nbh_list[outIdx].dist = next_cand.dist; + } + s_res_size = merged_size; } + __syncthreads(); // If we need to prune at all... - if (res_size > degree) { + if (s_res_size > degree) { if (threadIdx.x == 0) s_accept_count = 0; __syncthreads(); const bool cache_cand_in_smem = dim >= kRobustPruneCandCacheMinDim; @@ -232,7 +251,7 @@ __global__ void RobustPruneKernel( // Go through different alpha values. These constants are hard-coded in the MSFT DiskANN code for (float cur_alpha = 1.0; cur_alpha <= alpha && s_accept_count < degree; cur_alpha *= 1.2) { - for (int pass_start = 0; pass_start < res_size && s_accept_count < degree; pass_start++) { + for (int pass_start = 0; pass_start < s_res_size && s_accept_count < degree; pass_start++) { if (threadIdx.x == 0) { s_do_accept = (occlusion_list[pass_start] != raft::lower_bound() && occlusion_list[pass_start] <= cur_alpha && @@ -260,7 +279,7 @@ __global__ void RobustPruneKernel( } T* cand_ptr = const_cast(&dataset((size_t)(new_nbh_list[pass_start].idx), 0)); - for (int occId = pass_start + 1 + warpId; occId < res_size; occId += num_warps) { + for (int occId = pass_start + 1 + warpId; occId < s_res_size; occId += num_warps) { if (occlusion_list[occId] <= alpha && occlusion_list[occId] != raft::lower_bound()) { T* k_ptr = const_cast(&dataset((size_t)(new_nbh_list[occId].idx), 0)); @@ -302,7 +321,7 @@ __global__ void RobustPruneKernel( new_nbh_list[out_idx].dist = raft::upper_bound(); } - if (threadIdx.x == 0) { res_size = s_accept_count; } + if (threadIdx.x == 0) { s_res_size = s_accept_count; } __syncthreads(); } @@ -311,7 +330,7 @@ __global__ void RobustPruneKernel( query_list[i].ids[j] = new_nbh_list[j].idx; query_list[i].dists[j] = new_nbh_list[j].dist; } - if (threadIdx.x == 0) { query_list[i].size = res_size; } + if (threadIdx.x == 0) { query_list[i].size = s_res_size; } } } diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index fd7a60b9d5..a0534629fe 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -230,8 +230,10 @@ void batched_insert_vamana( (dim >= kRobustPruneCandCacheMinDim) ? coords_size : 0; int prune_smem_total_size = (degree + visited_size) * sizeof(float) + // Occlusion list (degree + visited_size) * sizeof(DistPair) + + visited_size * sizeof(DistPair) + // merge query cache cand_coords_smem_size + - degree * static_cast(sizeof(accT)); // graph edge dist cache + degree * static_cast(sizeof(accT)) + // graph edge dist cache + degree * static_cast(sizeof(IdxT)); // graph edge id cache RAFT_LOG_DEBUG( "Dynamic shared memory usage (bytes): GreedySearch: %d, Segment Sort: %d, Robust Prune: %d", From 4348e3ed61f5097b2bb3a269b43bd6235bc8374f Mon Sep 17 00:00:00 2001 From: bkarsin Date: Mon, 15 Jun 2026 21:20:15 -0700 Subject: [PATCH 11/14] =?UTF-8?q?perf(vamana):=20M6=5Fprune=5Fwarps=5Ftuni?= =?UTF-8?q?ng=20=E2=80=94=20Sweep=20RobustPrune=20warps-per-block=20(4=20v?= =?UTF-8?q?s=208)=20to=20raise=20occupancy/MLP=20on=20the=20degree-64=20oc?= =?UTF-8?q?clusion=20sweep?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cpp/src/neighbors/detail/vamana/macros.cuh | 21 +++++++++++++++++++ .../neighbors/detail/vamana/robust_prune.cuh | 2 ++ .../neighbors/detail/vamana/vamana_build.cuh | 9 +++++--- 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/macros.cuh b/cpp/src/neighbors/detail/vamana/macros.cuh index 714509b99e..9cb4faebef 100644 --- a/cpp/src/neighbors/detail/vamana/macros.cuh +++ b/cpp/src/neighbors/detail/vamana/macros.cuh @@ -5,11 +5,32 @@ #pragma once +#include +#include + namespace cuvs::neighbors::vamana::detail { // RobustPrune wide-dim optimizations: smem candidate cache and GreedySearch distance reuse. static constexpr int kRobustPruneCandCacheMinDim = 128; +// Minimum dim for 8 warps on degree-64 occlusion sweep (half). Below this, barrier tax wins. +static constexpr int kRobustPruneMultiWarpMinDimHalf = 960; + +// Occlusion sweep is multi-warp (occId += num_warps). Extra warps hide wide-dim djk latency, +// but narrow dim, low degree, and byte-wide (int8) / half distances have cheap djk -- barrier +// overhead wins unless dim is very wide. +template +__host__ __device__ inline int robust_prune_block_dim(int dim, int degree) +{ + if (degree < 64 || dim < kRobustPruneCandCacheMinDim) { return 128; } + if constexpr (std::is_same_v || std::is_same_v) { return 128; } + if constexpr (std::is_same_v) { + if (dim < kRobustPruneMultiWarpMinDimHalf) { return 128; } + return 256; + } + return 256; +} + /* Macros to compute the shared memory requirements for CUB primitives used by search and prune */ #define COMPUTE_SMEM_SIZE(degree, visited_size, DEG, CANDS) \ if (degree == DEG && visited_size <= CANDS && visited_size > CANDS / 2) { \ diff --git a/cpp/src/neighbors/detail/vamana/robust_prune.cuh b/cpp/src/neighbors/detail/vamana/robust_prune.cuh index d4bd9beb02..9bf5a691e2 100644 --- a/cpp/src/neighbors/detail/vamana/robust_prune.cuh +++ b/cpp/src/neighbors/detail/vamana/robust_prune.cuh @@ -261,6 +261,7 @@ __global__ void RobustPruneKernel( } __syncthreads(); + // update_shared_point uses all block threads; barrier must be uniform across warps. if (s_do_accept && cache_cand_in_smem) { if constexpr (is_cuda_fp16_v) { update_shared_point_half_to_float( @@ -300,6 +301,7 @@ __global__ void RobustPruneKernel( } } } + // Publish occId occlusion updates before the next pass_start reads occlusion_list. __syncthreads(); } } diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index a0534629fe..eca9b5d66b 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -49,7 +49,6 @@ namespace cuvs::neighbors::vamana::detail { static const int blockD = 32; static const int blockD_greedy = 128; // 4 warps per block, each warp processes one query -static const int blockD_prune = 128; // 4 warps per block, parallel occlusion per query static const int maxBlocks = 10000; // generate random permutation of inserts - TODO do this on GPU / faster @@ -235,11 +234,15 @@ void batched_insert_vamana( degree * static_cast(sizeof(accT)) + // graph edge dist cache degree * static_cast(sizeof(IdxT)); // graph edge id cache + const int blockD_prune = robust_prune_block_dim(dim, degree); + RAFT_LOG_DEBUG( - "Dynamic shared memory usage (bytes): GreedySearch: %d, Segment Sort: %d, Robust Prune: %d", + "Dynamic shared memory usage (bytes): GreedySearch: %d, Segment Sort: %d, Robust Prune: %d, " + "RobustPrune blockDim: %d", search_smem_total_size, sort_smem_size, - prune_smem_total_size); + prune_smem_total_size, + blockD_prune); #if KERNEL_TIMING auto end_t = std::chrono::system_clock::now(); From 4a0c179fdd8652636cc11b95c01f22b121e4401f Mon Sep 17 00:00:00 2001 From: bkarsin Date: Wed, 24 Jun 2026 15:39:19 -0700 Subject: [PATCH 12/14] Clean up and remove dead code --- cpp/include/cuvs/neighbors/vamana.hpp | 3 +- .../neighbors/detail/vamana/greedy_search.cuh | 9 +- cpp/src/neighbors/detail/vamana/macros.cuh | 2 +- .../detail/vamana/priority_queue.cuh | 104 +----------------- .../detail/vamana/vamana_structs.cuh | 32 ------ 5 files changed, 11 insertions(+), 139 deletions(-) diff --git a/cpp/include/cuvs/neighbors/vamana.hpp b/cpp/include/cuvs/neighbors/vamana.hpp index e00bda1fa4..6c54b09d43 100644 --- a/cpp/include/cuvs/neighbors/vamana.hpp +++ b/cpp/include/cuvs/neighbors/vamana.hpp @@ -287,8 +287,7 @@ struct index : cuvs::neighbors::index { * to improve graph quality. The index_params struct controls the degree of the final graph. * * The following distance metrics are supported: - * - L2Expanded (squared L2) - * - L2SqrtExpanded (Euclidean distance) + * - L2 * * Usage example: * @code{.cpp} diff --git a/cpp/src/neighbors/detail/vamana/greedy_search.cuh b/cpp/src/neighbors/detail/vamana/greedy_search.cuh index 2511dccf52..7d894bb61d 100644 --- a/cpp/src/neighbors/detail/vamana/greedy_search.cuh +++ b/cpp/src/neighbors/detail/vamana/greedy_search.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -113,6 +113,7 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( int align_padding = raft::alignTo(dim, 16) - dim; + // Only use fp16 coords in shared memory if type is fp16 and dim >= 512 const bool fp16_query_smem = greedy_search_use_fp16_query_smem(dim); extern __shared__ __align__(16) char smem[]; @@ -131,12 +132,14 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( DistPair* candidate_queue_smem = reinterpret_cast*>(warp_smem + coords_size + neighbor_size); + // 4 warps per block static __shared__ int topk_q_size[4]; static __shared__ int cand_q_size[4]; static __shared__ accT cur_k_max[4]; static __shared__ int k_max_idx[4]; static __shared__ int num_neighbors[4]; + // Different code path for fp16 since it is gated by dim and datatype Point<__half, accT> s_query_half; Point s_query; if (fp16_query_smem) { @@ -263,7 +266,7 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( for (size_t j = laneId; j < degree; j += 32) { neighbor_array[j] = graph(cand_num, j); if (neighbor_array[j] == raft::upper_bound()) - atomicMin(&num_neighbors[warpIdx], (int)j); + atomicMin(&num_neighbors[warpIdx], (int)j); // warp-wide min to find the number of neighbors } enqueue_all_neighbors_warp(num_neighbors[warpIdx], @@ -283,7 +286,7 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( if (query_list[i].ids[j] == cur_query_id) { query_list[i].dists[j] = raft::upper_bound(); query_list[i].ids[j] = raft::upper_bound(); - self_found = true; + self_found = true; // Flat to reduce size by 1 } } self_found = (raft::ballot(self_found) != 0); diff --git a/cpp/src/neighbors/detail/vamana/macros.cuh b/cpp/src/neighbors/detail/vamana/macros.cuh index 9cb4faebef..d154c34ae7 100644 --- a/cpp/src/neighbors/detail/vamana/macros.cuh +++ b/cpp/src/neighbors/detail/vamana/macros.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ diff --git a/cpp/src/neighbors/detail/vamana/priority_queue.cuh b/cpp/src/neighbors/detail/vamana/priority_queue.cuh index fec95cd15e..be213e8df1 100644 --- a/cpp/src/neighbors/detail/vamana/priority_queue.cuh +++ b/cpp/src/neighbors/detail/vamana/priority_queue.cuh @@ -211,27 +211,8 @@ __host__ __device__ bool operator>(const Node& first, const Node other.distance; } -template -__device__ bool check_duplicate(const Node* pq, const int size, Node new_node) -{ - bool found = false; - for (int i = threadIdx.x; i < size; i += blockDim.x) { - if (pq[i].nodeid == new_node.nodeid) { - found = true; - break; - } - } - unsigned mask = raft::ballot(found); - - if (mask == 0) - return false; - - else - return true; -} - -// Warp-level version: each warp scans its own pq with laneId and stride 32 +// each warp scans its own pq with laneId and stride 32 to find duplicates template __device__ bool check_duplicate_warp( const Node* pq, const int size, Node new_node, int laneId) @@ -247,65 +228,6 @@ __device__ bool check_duplicate_warp( return (mask != 0); } -/* - Enqueuing a input value into parallel queue with tracker -*/ -template -__inline__ __device__ void parallel_pq_max_enqueue(Node* pq, - int* size, - const int pq_size, - Node input_data, - SUMTYPE* cur_max_val, - int* max_idx) -{ - if (*size < pq_size) { - __syncthreads(); - if (threadIdx.x == 0) { - pq[*size].distance = input_data.distance; - pq[*size].nodeid = input_data.nodeid; - *size = *size + 1; - if (input_data.distance > (*cur_max_val)) { - *cur_max_val = input_data.distance; - *max_idx = *size - 1; - } - } - __syncthreads(); - return; - } else { - if (input_data.distance >= (*cur_max_val)) { - __syncthreads(); - return; - } - if (threadIdx.x == 0) { - pq[*max_idx].distance = input_data.distance; - pq[*max_idx].nodeid = input_data.nodeid; - } - int idx = 0; - SUMTYPE max_val = pq[0].distance; - - for (int i = threadIdx.x; i < pq_size; i += 32) { - if (pq[i].distance > max_val) { - max_val = pq[i].distance; - idx = i; - } - } - - for (int offset = 16; offset > 0; offset /= 2) { - SUMTYPE new_max_val = raft::shfl_up(max_val, offset); - int new_idx = raft::shfl_up(idx, offset); - if (new_max_val > max_val) { - max_val = new_max_val; - idx = new_idx; - } - } - - if (threadIdx.x == 31) { - *max_idx = idx; - *cur_max_val = max_val; - } - } - __syncthreads(); -} // Warp-level version: no __syncthreads, uses laneId for single-thread ops and warp shuffle template @@ -360,28 +282,6 @@ __inline__ __device__ void parallel_pq_max_enqueue_warp(Node* pq, } } -/* - Compute the distances between the source vector and all nodes in the neighbor_array and enqueue - them in the PQ -*/ -template -__forceinline__ __device__ void enqueue_all_neighbors(int num_neighbors, - Point* query_vec, - const T* vec_ptr, - IdxT* neighbor_array, - PriorityQueue& heap_queue, - int dim, - cuvs::distance::DistanceType metric) -{ - for (int i = 0; i < num_neighbors; i++) { - accT dist_out = dist( - query_vec->coords, &vec_ptr[(size_t)(neighbor_array[i]) * (size_t)(dim)], dim, metric); - - __syncthreads(); - if (threadIdx.x == 0) { heap_queue.insert_back(dist_out, neighbor_array[i]); } - __syncthreads(); - } -} // Warp-level version: lane 0 does insert_back, no __syncthreads template @@ -410,6 +310,8 @@ __forceinline__ __device__ void enqueue_all_neighbors_warp(int num_neighbors, } } + +// Half-precision version with two code paths based on query being fp16 or fp32 template __forceinline__ __device__ void enqueue_all_neighbors_warp(int num_neighbors, bool fp16_query_smem, diff --git a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh index 30b53577d9..58b3266216 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh @@ -928,21 +928,6 @@ dist_warp(const T* src, const T* dest, int dim, cuvs::distance::DistanceType met return d; } -/* Block/warp L2: float query vs half dataset (RobustPrune uses blockDim=32) */ -template -__forceinline__ __device__ SUMTYPE dist(const float* src, - const half* dest, - int dim, - cuvs::distance::DistanceType metric) -{ - SUMTYPE d = - l2_warp_float_half(src, reinterpret_cast(dest), dim, threadIdx.x); - if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) { - return static_cast(sqrtf(static_cast(d))); - } - return d; -} - // Warp-cooperative id lookup in a GreedySearch visited list (sorted by distance, not id). // All lanes execute the same number of ballot rounds to avoid warp deadlock. template @@ -1110,23 +1095,6 @@ __global__ void set_query_ids(void* query_list_ptr, IdxT* d_query_ids, int step_ } } -// Compute prefix sums on sizes. Currently only works with 1 thread -// TODO replace with parallel version -template -__global__ void prefix_sums_sizes(QueryCandidates* query_list, - int num_queries, - int* total_edges) -{ - if (threadIdx.x == 0 && blockIdx.x == 0) { - int sum = 0; - for (int i = 0; i < num_queries + 1; i++) { - sum += query_list[i].size; - query_list[i].size = sum - query_list[i].size; // exclusive prefix sum - } - *total_edges = query_list[num_queries].size; - } -} - // Device fcn to have a threadblock copy coordinates into shared memory template __device__ void update_shared_point( From cb91c997f1bd1f9b5e9cd5dd1636400dd1e88cc2 Mon Sep 17 00:00:00 2001 From: bkarsin Date: Wed, 24 Jun 2026 15:50:18 -0700 Subject: [PATCH 13/14] pre-commit fixes --- c/src/neighbors/vamana.cpp | 2 +- .../neighbors/detail/vamana/greedy_search.cuh | 33 ++--- .../detail/vamana/priority_queue.cuh | 52 ++++---- .../neighbors/detail/vamana/robust_prune.cuh | 49 ++++--- .../neighbors/detail/vamana/vamana_build.cuh | 104 +++++++-------- .../detail/vamana/vamana_structs.cuh | 123 ++++++++---------- cpp/src/neighbors/vamana.cuh | 2 +- cpp/src/neighbors/vamana_serialize_half.cu | 2 +- examples/cpp/src/vamana_example.cu | 14 +- python/cuvs/cuvs/neighbors/vamana/vamana.pyx | 2 +- python/cuvs/cuvs/tests/test_vamana.py | 2 +- 11 files changed, 177 insertions(+), 208 deletions(-) diff --git a/c/src/neighbors/vamana.cpp b/c/src/neighbors/vamana.cpp index beabe6d616..03ffcf6811 100644 --- a/c/src/neighbors/vamana.cpp +++ b/c/src/neighbors/vamana.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ diff --git a/cpp/src/neighbors/detail/vamana/greedy_search.cuh b/cpp/src/neighbors/detail/vamana/greedy_search.cuh index 7d894bb61d..ba2c49834f 100644 --- a/cpp/src/neighbors/detail/vamana/greedy_search.cuh +++ b/cpp/src/neighbors/detail/vamana/greedy_search.cuh @@ -119,16 +119,15 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( extern __shared__ __align__(16) char smem[]; // Per-warp shared memory layout: coords, neighbor_array, candidate_queue - const int coords_size = - (dim + align_padding) * greedy_search_query_smem_elem_size(dim); + const int coords_size = (dim + align_padding) * greedy_search_query_smem_elem_size(dim); const int neighbor_size = degree * sizeof(IdxT); const int queue_size_bytes = max_queue_size * sizeof(DistPair); const int per_warp_size = (coords_size + neighbor_size + queue_size_bytes + 15) & ~15; - char* warp_smem = &smem[warpIdx * per_warp_size]; + char* warp_smem = &smem[warpIdx * per_warp_size]; __half* s_coords_half = reinterpret_cast<__half*>(warp_smem); - QueryCoordT* s_coords = reinterpret_cast(warp_smem); - IdxT* neighbor_array = reinterpret_cast(warp_smem + coords_size); + QueryCoordT* s_coords = reinterpret_cast(warp_smem); + IdxT* neighbor_array = reinterpret_cast(warp_smem + coords_size); DistPair* candidate_queue_smem = reinterpret_cast*>(warp_smem + coords_size + neighbor_size); @@ -156,7 +155,7 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( } Node* topk_pq = &topk_pq_mem[(blockIdx.x * 4 + warpIdx) * topk]; - const T* vec_ptr = &dataset(0, 0); + const T* vec_ptr = &dataset(0, 0); for (int i = blockIdx.x * 4 + warpIdx; i < num_queries; i += gridDim.x * 4) { query_list[i].reset_warp(laneId); @@ -174,8 +173,7 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( &s_query_half, vec_ptr, cur_query_id, dim, laneId); } } else if constexpr (is_cuda_fp16_v) { - update_shared_point_warp_half_to_float( - &s_query, vec_ptr, cur_query_id, dim, laneId); + update_shared_point_warp_half_to_float(&s_query, vec_ptr, cur_query_id, dim, laneId); } else { update_shared_point_warp(&s_query, vec_ptr, cur_query_id, dim, laneId); } @@ -195,17 +193,11 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( accT medoid_dist; if (fp16_query_smem) { - medoid_dist = dist_warp_half_query(s_coords_half, - &vec_ptr[(size_t)medoid_id * (size_t)dim], - dim, - metric, - laneId); + medoid_dist = dist_warp_half_query( + s_coords_half, &vec_ptr[(size_t)medoid_id * (size_t)dim], dim, metric, laneId); } else if constexpr (is_cuda_fp16_v) { - medoid_dist = dist_warp(s_coords, - &vec_ptr[(size_t)medoid_id * (size_t)dim], - dim, - metric, - laneId); + medoid_dist = + dist_warp(s_coords, &vec_ptr[(size_t)medoid_id * (size_t)dim], dim, metric, laneId); } else { medoid_dist = dist_warp( s_coords, &vec_ptr[(size_t)medoid_id * (size_t)dim], dim, metric, laneId); @@ -266,7 +258,8 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( for (size_t j = laneId; j < degree; j += 32) { neighbor_array[j] = graph(cand_num, j); if (neighbor_array[j] == raft::upper_bound()) - atomicMin(&num_neighbors[warpIdx], (int)j); // warp-wide min to find the number of neighbors + atomicMin(&num_neighbors[warpIdx], + (int)j); // warp-wide min to find the number of neighbors } enqueue_all_neighbors_warp(num_neighbors[warpIdx], @@ -286,7 +279,7 @@ __global__ __launch_bounds__(128, 12) void GreedySearchKernel( if (query_list[i].ids[j] == cur_query_id) { query_list[i].dists[j] = raft::upper_bound(); query_list[i].ids[j] = raft::upper_bound(); - self_found = true; // Flat to reduce size by 1 + self_found = true; // Flat to reduce size by 1 } } self_found = (raft::ballot(self_found) != 0); diff --git a/cpp/src/neighbors/detail/vamana/priority_queue.cuh b/cpp/src/neighbors/detail/vamana/priority_queue.cuh index be213e8df1..5e44563648 100644 --- a/cpp/src/neighbors/detail/vamana/priority_queue.cuh +++ b/cpp/src/neighbors/detail/vamana/priority_queue.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -211,11 +211,12 @@ __host__ __device__ bool operator>(const Node& first, const Node other.distance; } - // each warp scans its own pq with laneId and stride 32 to find duplicates template -__device__ bool check_duplicate_warp( - const Node* pq, const int size, Node new_node, int laneId) +__device__ bool check_duplicate_warp(const Node* pq, + const int size, + Node new_node, + int laneId) { bool found = false; for (int i = laneId; i < size; i += 32) { @@ -228,7 +229,6 @@ __device__ bool check_duplicate_warp( return (mask != 0); } - // Warp-level version: no __syncthreads, uses laneId for single-thread ops and warp shuffle template __inline__ __device__ void parallel_pq_max_enqueue_warp(Node* pq, @@ -282,24 +282,23 @@ __inline__ __device__ void parallel_pq_max_enqueue_warp(Node* pq, } } - // Warp-level version: lane 0 does insert_back, no __syncthreads template __forceinline__ __device__ void enqueue_all_neighbors_warp(int num_neighbors, - Point* query_vec, - const DataT* vec_ptr, - IdxT* neighbor_array, - PriorityQueue& heap_queue, - int dim, - cuvs::distance::DistanceType metric, - int laneId) + Point* query_vec, + const DataT* vec_ptr, + IdxT* neighbor_array, + PriorityQueue& heap_queue, + int dim, + cuvs::distance::DistanceType metric, + int laneId) { for (int i = 0; i < num_neighbors; i++) { const DataT* neighbor_vec = &vec_ptr[(size_t)(neighbor_array[i]) * (size_t)(dim)]; accT dist_out; if constexpr (std::is_same_v) { - dist_out = dist_warp_half_query( - query_vec->coords, neighbor_vec, dim, metric, laneId); + dist_out = + dist_warp_half_query(query_vec->coords, neighbor_vec, dim, metric, laneId); } else if constexpr (std::is_same_v && is_cuda_fp16_v) { dist_out = dist_warp(query_vec->coords, neighbor_vec, dim, metric, laneId); } else { @@ -310,20 +309,19 @@ __forceinline__ __device__ void enqueue_all_neighbors_warp(int num_neighbors, } } - // Half-precision version with two code paths based on query being fp16 or fp32 template -__forceinline__ __device__ void enqueue_all_neighbors_warp(int num_neighbors, - bool fp16_query_smem, - __half* s_coords_half, - typename greedy_search_query_coord::type* - s_coords, - const T* vec_ptr, - IdxT* neighbor_array, - PriorityQueue& heap_queue, - int dim, - cuvs::distance::DistanceType metric, - int laneId) +__forceinline__ __device__ void enqueue_all_neighbors_warp( + int num_neighbors, + bool fp16_query_smem, + __half* s_coords_half, + typename greedy_search_query_coord::type* s_coords, + const T* vec_ptr, + IdxT* neighbor_array, + PriorityQueue& heap_queue, + int dim, + cuvs::distance::DistanceType metric, + int laneId) { if (fp16_query_smem) { Point<__half, accT> query_vec; diff --git a/cpp/src/neighbors/detail/vamana/robust_prune.cuh b/cpp/src/neighbors/detail/vamana/robust_prune.cuh index 9bf5a691e2..32aeffa0a5 100644 --- a/cpp/src/neighbors/detail/vamana/robust_prune.cuh +++ b/cpp/src/neighbors/detail/vamana/robust_prune.cuh @@ -84,7 +84,7 @@ __global__ void RobustPruneKernel( int align_padding = raft::alignTo(dim, alignof(ShmemLayout)) - dim; - float* occlusion_list = reinterpret_cast(smem); + float* occlusion_list = reinterpret_cast(smem); const int nbh_list_offset = (degree + visited_size) * sizeof(float); DistPair* new_nbh_list = reinterpret_cast*>(&smem[nbh_list_offset]); @@ -92,18 +92,17 @@ __global__ void RobustPruneKernel( nbh_list_offset + (degree + visited_size) * sizeof(DistPair); DistPair* query_cache = reinterpret_cast*>(&smem[query_cache_offset]); - const int cand_coords_offset = - query_cache_offset + visited_size * sizeof(DistPair); - const int coord_bytes = (dim + align_padding) * static_cast(sizeof(QueryCoordT)); - int graph_dists_offset = cand_coords_offset; - QueryCoordT* s_cand_coords = nullptr; + const int cand_coords_offset = query_cache_offset + visited_size * sizeof(DistPair); + const int coord_bytes = (dim + align_padding) * static_cast(sizeof(QueryCoordT)); + int graph_dists_offset = cand_coords_offset; + QueryCoordT* s_cand_coords = nullptr; if (dim >= kRobustPruneCandCacheMinDim) { s_cand_coords = reinterpret_cast(&smem[cand_coords_offset]); graph_dists_offset = cand_coords_offset + coord_bytes; } - accT* graph_dists = reinterpret_cast(&smem[graph_dists_offset]); + accT* graph_dists = reinterpret_cast(&smem[graph_dists_offset]); const int graph_ids_offset = graph_dists_offset + degree * sizeof(accT); - IdxT* graph_ids = reinterpret_cast(&smem[graph_ids_offset]); + IdxT* graph_ids = reinterpret_cast(&smem[graph_ids_offset]); static __shared__ Point s_query; s_query.coords = &s_coords_mem[blockIdx.x * (dim + align_padding)]; @@ -113,8 +112,8 @@ __global__ void RobustPruneKernel( static __shared__ int s_do_accept; static __shared__ int s_res_size; - const int laneId = threadIdx.x & 31; - const int warpId = threadIdx.x >> 5; + const int laneId = threadIdx.x & 31; + const int warpId = threadIdx.x >> 5; const int num_warps = blockDim.x >> 5; for (int i = blockIdx.x; i < num_queries; i += gridDim.x) { @@ -147,25 +146,22 @@ __global__ void RobustPruneKernel( __syncthreads(); // Precompute graph-edge distances in parallel; reuse GreedySearch dists when bit-exact. - const int visited_count = query_list[i].size; - const IdxT* visited_ids = query_list[i].ids; - const accT* visited_dists = query_list[i].dists; + const int visited_count = query_list[i].size; + const IdxT* visited_ids = query_list[i].ids; + const accT* visited_dists = query_list[i].dists; const bool reuse_search_dists = (dim >= kRobustPruneCandCacheMinDim); for (int j = warpId; j < prev_edges; j += num_warps) { IdxT gid = graph(queryId, j); accT d; bool found = false; if (reuse_search_dists) { - found = lookup_visited_dist_warp( - visited_ids, visited_dists, visited_count, gid, d, laneId); + found = lookup_visited_dist_warp(visited_ids, visited_dists, visited_count, gid, d, laneId); } if (!found) { if constexpr (is_cuda_fp16_v) { - d = dist_warp( - s_query.coords, &dataset((size_t)gid, 0), dim, metric, laneId); + d = dist_warp(s_query.coords, &dataset((size_t)gid, 0), dim, metric, laneId); } else { - d = dist_warp( - s_query.coords, &dataset((size_t)gid, 0), dim, metric, laneId); + d = dist_warp(s_query.coords, &dataset((size_t)gid, 0), dim, metric, laneId); } } if (laneId == 0) { @@ -249,15 +245,14 @@ __global__ void RobustPruneKernel( } // Go through different alpha values. These constants are hard-coded in the MSFT DiskANN code - for (float cur_alpha = 1.0; cur_alpha <= alpha && s_accept_count < degree; - cur_alpha *= 1.2) { + for (float cur_alpha = 1.0; cur_alpha <= alpha && s_accept_count < degree; cur_alpha *= 1.2) { for (int pass_start = 0; pass_start < s_res_size && s_accept_count < degree; pass_start++) { if (threadIdx.x == 0) { - s_do_accept = (occlusion_list[pass_start] != raft::lower_bound() && - occlusion_list[pass_start] <= cur_alpha && - new_nbh_list[pass_start].idx != queryId) - ? 1 - : 0; + s_do_accept = + (occlusion_list[pass_start] != raft::lower_bound() && + occlusion_list[pass_start] <= cur_alpha && new_nbh_list[pass_start].idx != queryId) + ? 1 + : 0; } __syncthreads(); @@ -295,7 +290,7 @@ __global__ void RobustPruneKernel( djk = dist_warp(cand_ptr, k_ptr, dim, metric, laneId); } if (laneId == 0) { - accT new_occ = (float)(new_nbh_list[occId].dist / djk); + accT new_occ = (float)(new_nbh_list[occId].dist / djk); occlusion_list[occId] = std::max(occlusion_list[occId], new_occ); } } diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index eca9b5d66b..795928c0c8 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -129,11 +129,11 @@ void batched_insert_vamana( IdxT* medoid_id, cuvs::distance::DistanceType metric) { - auto stream = raft::resource::get_cuda_stream(res); - cudaStream_t cs = stream; - int N = dataset.extent(0); - int dim = dataset.extent(1); - int degree = graph.extent(1); + auto stream = raft::resource::get_cuda_stream(res); + cudaStream_t cs = stream; + int N = dataset.extent(0); + int dim = dataset.extent(1); + int degree = graph.extent(1); // Algorithm params int max_batchsize = (int)(params.max_fraction * (float)N); @@ -216,18 +216,16 @@ void batched_insert_vamana( SELECT_SORT_SMEM_SIZE(degree, visited_size); // Sets sort_smem_size based on dataset // GreedySearch: per-warp shared memory (4 warps): coords, neighbor_array, candidate_queue - const int search_coords_size = - (dim + align_padding) * greedy_search_query_smem_elem_size(dim); - const int coords_size = (dim + align_padding) * static_cast(sizeof(QueryCoordT)); - const int neighbor_size = degree * sizeof(IdxT); - const int queue_size_bytes = queue_size * sizeof(DistPair); + const int search_coords_size = (dim + align_padding) * greedy_search_query_smem_elem_size(dim); + const int coords_size = (dim + align_padding) * static_cast(sizeof(QueryCoordT)); + const int neighbor_size = degree * sizeof(IdxT); + const int queue_size_bytes = queue_size * sizeof(DistPair); int search_smem_total_size = static_cast(4 * ((search_coords_size + neighbor_size + queue_size_bytes + 15) & ~15)); // Total dynamic shared memory size needed by both RobustPrune calls - const int cand_coords_smem_size = - (dim >= kRobustPruneCandCacheMinDim) ? coords_size : 0; - int prune_smem_total_size = (degree + visited_size) * sizeof(float) + // Occlusion list + const int cand_coords_smem_size = (dim >= kRobustPruneCandCacheMinDim) ? coords_size : 0; + int prune_smem_total_size = (degree + visited_size) * sizeof(float) + // Occlusion list (degree + visited_size) * sizeof(DistPair) + visited_size * sizeof(DistPair) + // merge query cache cand_coords_smem_size + @@ -268,10 +266,10 @@ void batched_insert_vamana( auto edge_src = raft::make_device_mdarray(res, large_ws, raft::make_extents(max_total_edges)); - auto edge_counts = raft::make_device_mdarray( - res, large_ws, raft::make_extents(max_batchsize + 1)); - auto edge_offsets = raft::make_device_mdarray( - res, large_ws, raft::make_extents(max_batchsize + 1)); + auto edge_counts = + raft::make_device_mdarray(res, large_ws, raft::make_extents(max_batchsize + 1)); + auto edge_offsets = + raft::make_device_mdarray(res, large_ws, raft::make_extents(max_batchsize + 1)); size_t temp_storage_bytes_dist = 0; size_t temp_storage_bytes_edge = 0; @@ -289,9 +287,8 @@ void batched_insert_vamana( max_total_edges, CmpEdge(), cs); - size_t temp_storage_bytes = - std::max(temp_storage_bytes_dist, temp_storage_bytes_edge); - auto temp_sort_storage = raft::make_device_mdarray( + size_t temp_storage_bytes = std::max(temp_storage_bytes_dist, temp_storage_bytes_edge); + auto temp_sort_storage = raft::make_device_mdarray( res, large_ws, raft::make_extents(std::max(temp_storage_bytes, size_t{1}))); size_t scan_temp_bytes = 0; @@ -335,7 +332,7 @@ void batched_insert_vamana( if (start + step_size > N) { step_size = N - start; } RAFT_LOG_DEBUG("Starting batch of inserts indices_start:%d, batch_size:%d", start, step_size); - int num_blocks = min(maxBlocks, step_size); + int num_blocks = min(maxBlocks, step_size); int num_blocks_greedy = min(maxBlocks, (step_size + 3) / 4); // Copy ids to be inserted for this batch @@ -345,15 +342,16 @@ void batched_insert_vamana( // Call greedy search to get candidates for every vector being inserted GreedySearchKernel - <<>>(d_graph.view(), - dataset, - query_list_ptr.data_handle(), - step_size, - *medoid_id, - visited_size, - metric, - queue_size, - topk_pq_mem.data_handle()); + <<>>( + d_graph.view(), + dataset, + query_list_ptr.data_handle(), + step_size, + *medoid_id, + visited_size, + metric, + queue_size, + topk_pq_mem.data_handle()); RAFT_CUDA_TRY(cudaPeekAtLastError()); #if KERNEL_TIMING @@ -380,13 +378,13 @@ void batched_insert_vamana( // Run on candidates of vectors being inserted RobustPruneKernel <<>>(d_graph.view(), - dataset, - query_list_ptr.data_handle(), - step_size, - visited_size, - metric, - alpha, - s_coords_mem.data_handle()); + dataset, + query_list_ptr.data_handle(), + step_size, + visited_size, + metric, + alpha, + s_coords_mem.data_handle()); RAFT_CUDA_TRY(cudaPeekAtLastError()); // Segmented sort on query list @@ -417,8 +415,8 @@ void batched_insert_vamana( // compute prefix sums of query_list sizes const int prefix_count = step_size + 1; - gather_query_sizes<<>>( - query_list, edge_counts.data_handle(), prefix_count); + gather_query_sizes + <<>>(query_list, edge_counts.data_handle(), prefix_count); RAFT_CUDA_TRY(cudaPeekAtLastError()); cub::DeviceScan::ExclusiveSum(scan_temp_storage.data_handle(), @@ -429,8 +427,8 @@ void batched_insert_vamana( cs); RAFT_CUDA_TRY(cudaPeekAtLastError()); - scatter_prefix_offsets<<>>( - query_list, edge_offsets.data_handle(), prefix_count); + scatter_prefix_offsets + <<>>(query_list, edge_offsets.data_handle(), prefix_count); RAFT_CUDA_TRY(cudaPeekAtLastError()); int total_edges; @@ -485,9 +483,8 @@ void batched_insert_vamana( auto unique_indices = raft::make_device_vector(res, total_edges); raft::linalg::map_offset(res, unique_indices.view(), raft::identity_op{}); - thrust::unique_by_key(edge_dest_vec.begin(), - edge_dest_vec.begin() + total_edges, - unique_indices.data_handle()); + thrust::unique_by_key( + edge_dest_vec.begin(), edge_dest_vec.begin() + total_edges, unique_indices.data_handle()); #if KERNEL_TIMING RAFT_CUDA_TRY(cudaDeviceSynchronize()); @@ -532,16 +529,15 @@ void batched_insert_vamana( RAFT_CUDA_TRY(cudaPeekAtLastError()); // Call 2nd RobustPrune on reverse query_list - RobustPruneKernel - <<>>(d_graph.view(), - raft::make_const_mdspan( - dataset), - reverse_list_ptr.data_handle(), - reverse_batch, - visited_size, - metric, - alpha, - s_coords_mem.data_handle()); + RobustPruneKernel<<>>( + d_graph.view(), + raft::make_const_mdspan(dataset), + reverse_list_ptr.data_handle(), + reverse_batch, + visited_size, + metric, + alpha, + s_coords_mem.data_handle()); RAFT_CUDA_TRY(cudaPeekAtLastError()); // Segmented sort on reverse_list diff --git a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh index 58b3266216..8af77bfd1f 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -92,7 +92,7 @@ __device__ __host__ void swap(DistPair* a, DistPair* b) template struct CmpDist { __host__ __device__ bool operator()(const DistPair& lhs, - const DistPair& rhs) + const DistPair& rhs) { return lhs.dist < rhs.dist; } @@ -255,8 +255,8 @@ __device__ SUMTYPE l2_SEQ_half(Point<__half, SUMTYPE>* src_vec, Point<__half, SU template __device__ SUMTYPE l2_ILP2_half(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec) { - __half temp_dst[2] = {__float2half(0.0f), __float2half(0.0f)}; - __half partial_sum[2] = {__float2half(0.0f), __float2half(0.0f)}; + __half temp_dst[2] = {__float2half(0.0f), __float2half(0.0f)}; + __half partial_sum[2] = {__float2half(0.0f), __float2half(0.0f)}; for (int i = threadIdx.x; i < src_vec->Dim; i += 2 * blockDim.x) { temp_dst[0] = dst_vec->coords[i]; if (i + 32 < src_vec->Dim) temp_dst[1] = dst_vec->coords[i + 32]; @@ -273,14 +273,10 @@ __device__ SUMTYPE l2_ILP2_half(Point<__half, SUMTYPE>* src_vec, Point<__half, S template __device__ SUMTYPE l2_ILP4_half(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec) { - __half temp_dst[4] = {__float2half(0.0f), - __float2half(0.0f), - __float2half(0.0f), - __float2half(0.0f)}; - __half partial_sum[4] = {__float2half(0.0f), - __float2half(0.0f), - __float2half(0.0f), - __float2half(0.0f)}; + __half temp_dst[4] = { + __float2half(0.0f), __float2half(0.0f), __float2half(0.0f), __float2half(0.0f)}; + __half partial_sum[4] = { + __float2half(0.0f), __float2half(0.0f), __float2half(0.0f), __float2half(0.0f)}; for (int i = threadIdx.x; i < src_vec->Dim; i += 4 * blockDim.x) { temp_dst[0] = dst_vec->coords[i]; if (i + 32 < src_vec->Dim) temp_dst[1] = dst_vec->coords[i + 32]; @@ -295,7 +291,8 @@ __device__ SUMTYPE l2_ILP4_half(Point<__half, SUMTYPE>* src_vec, Point<__half, S if (i + 96 < src_vec->Dim) l2_half_fma_sq(partial_sum[3], src_vec[0].coords[i + 96], temp_dst[3]); } - partial_sum[0] = __hadd(partial_sum[0], __hadd(partial_sum[1], __hadd(partial_sum[2], partial_sum[3]))); + partial_sum[0] = + __hadd(partial_sum[0], __hadd(partial_sum[1], __hadd(partial_sum[2], partial_sum[3]))); return l2_half_warp_reduce(partial_sum[0]); } @@ -460,14 +457,10 @@ __device__ SUMTYPE l2_ILP4_half_warp(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec, int lane) { - __half temp_dst[4] = {__float2half(0.0f), - __float2half(0.0f), - __float2half(0.0f), - __float2half(0.0f)}; - __half partial_sum[4] = {__float2half(0.0f), - __float2half(0.0f), - __float2half(0.0f), - __float2half(0.0f)}; + __half temp_dst[4] = { + __float2half(0.0f), __float2half(0.0f), __float2half(0.0f), __float2half(0.0f)}; + __half partial_sum[4] = { + __float2half(0.0f), __float2half(0.0f), __float2half(0.0f), __float2half(0.0f)}; for (int i = lane; i < src_vec->Dim; i += 4 * VAMANA_WARP_SIZE) { temp_dst[0] = dst_vec->coords[i]; if (i + 32 < src_vec->Dim) temp_dst[1] = dst_vec->coords[i + 32]; @@ -482,12 +475,15 @@ __device__ SUMTYPE l2_ILP4_half_warp(Point<__half, SUMTYPE>* src_vec, if (i + 96 < src_vec->Dim) l2_half_fma_sq(partial_sum[3], src_vec[0].coords[i + 96], temp_dst[3]); } - partial_sum[0] = __hadd(partial_sum[0], __hadd(partial_sum[1], __hadd(partial_sum[2], partial_sum[3]))); + partial_sum[0] = + __hadd(partial_sum[0], __hadd(partial_sum[1], __hadd(partial_sum[2], partial_sum[3]))); return l2_half_warp_reduce(partial_sum[0]); } template -__forceinline__ __device__ SUMTYPE l2_warp(Point* src_vec, Point* dst_vec, int lane) +__forceinline__ __device__ SUMTYPE l2_warp(Point* src_vec, + Point* dst_vec, + int lane) { if constexpr (std::is_same_v, __half>) { if (src_vec->Dim >= 128) { @@ -718,7 +714,8 @@ l2_warp_half_float(const __half* src, const float* dest, int dim, int lane) /* fp16 query smem vs half dataset: vectorized half2 loads, widen to float, float accumulate * (mirror of l2_warp_float_half; avoids scalar smem loads and native half FMA) */ template -__device__ SUMTYPE l2_SEQ_warp_half_smem_half(const __half* src, const __half* dst, int dim, int lane) +__device__ SUMTYPE +l2_SEQ_warp_half_smem_half(const __half* src, const __half* dst, int dim, int lane) { SUMTYPE partial_sum = 0; for (int i = lane * 2; i < dim; i += VAMANA_WARP_SIZE * 2) { @@ -733,7 +730,8 @@ __device__ SUMTYPE l2_SEQ_warp_half_smem_half(const __half* src, const __half* d } template -__device__ SUMTYPE l2_ILP2_warp_half_smem_half(const __half* src, const __half* dst, int dim, int lane) +__device__ SUMTYPE +l2_ILP2_warp_half_smem_half(const __half* src, const __half* dst, int dim, int lane) { SUMTYPE partial_sum[2] = {0, 0}; for (int i = lane * 2; i < dim; i += 2 * VAMANA_WARP_SIZE * 2) { @@ -756,7 +754,8 @@ __device__ SUMTYPE l2_ILP2_warp_half_smem_half(const __half* src, const __half* } template -__device__ SUMTYPE l2_ILP4_warp_half_smem_half(const __half* src, const __half* dst, int dim, int lane) +__device__ SUMTYPE +l2_ILP4_warp_half_smem_half(const __half* src, const __half* dst, int dim, int lane) { SUMTYPE partial_sum[4] = {0, 0, 0, 0}; for (int i = lane * 2; i < dim; i += 4 * VAMANA_WARP_SIZE * 2) { @@ -801,12 +800,13 @@ l2_warp_half_smem_half(const __half* src, const __half* dest, int dim, int lane) } } -/* fp16 query smem vs int8 (or other native) dataset: same vectorized query widen, float accumulate */ +/* fp16 query smem vs int8 (or other native) dataset: same vectorized query widen, float accumulate + */ template __device__ __forceinline__ void l2_fma_sq2_half_native(SUMTYPE& acc, - float2 src2, - const DataT* dst, - int i) + float2 src2, + const DataT* dst, + int i) { float dx = src2.x - static_cast(dst[i]); float dy = src2.y - static_cast(dst[i + 1]); @@ -815,7 +815,8 @@ __device__ __forceinline__ void l2_fma_sq2_half_native(SUMTYPE& acc, } template -__device__ SUMTYPE l2_SEQ_warp_half_smem_native(const __half* src, const DataT* dst, int dim, int lane) +__device__ SUMTYPE +l2_SEQ_warp_half_smem_native(const __half* src, const DataT* dst, int dim, int lane) { SUMTYPE partial_sum = 0; for (int i = lane * 2; i < dim; i += VAMANA_WARP_SIZE * 2) { @@ -829,7 +830,8 @@ __device__ SUMTYPE l2_SEQ_warp_half_smem_native(const __half* src, const DataT* } template -__device__ SUMTYPE l2_ILP2_warp_half_smem_native(const __half* src, const DataT* dst, int dim, int lane) +__device__ SUMTYPE +l2_ILP2_warp_half_smem_native(const __half* src, const DataT* dst, int dim, int lane) { SUMTYPE partial_sum[2] = {0, 0}; for (int i = lane * 2; i < dim; i += 2 * VAMANA_WARP_SIZE * 2) { @@ -848,7 +850,8 @@ __device__ SUMTYPE l2_ILP2_warp_half_smem_native(const __half* src, const DataT* } template -__device__ SUMTYPE l2_ILP4_warp_half_smem_native(const __half* src, const DataT* dst, int dim, int lane) +__device__ SUMTYPE +l2_ILP4_warp_half_smem_native(const __half* src, const DataT* dst, int dim, int lane) { SUMTYPE partial_sum[4] = {0, 0, 0, 0}; for (int i = lane * 2; i < dim; i += 4 * VAMANA_WARP_SIZE * 2) { @@ -888,8 +891,7 @@ l2_warp_half_smem_native(const __half* src, const DataT* dest, int dim, int lane } template -__forceinline__ __device__ SUMTYPE -dist_warp_half_query( +__forceinline__ __device__ SUMTYPE dist_warp_half_query( const __half* src, const DataT* dest, int dim, cuvs::distance::DistanceType metric, int lane) { SUMTYPE d; @@ -934,13 +936,13 @@ template __forceinline__ __device__ bool lookup_visited_dist_warp( const IdxT* ids, const accT* dists, int size, IdxT target, accT& out_dist, int laneId) { - bool found = false; - accT found_dist = static_cast(0); + bool found = false; + accT found_dist = static_cast(0); const int num_iters = (size + 31) >> 5; for (int k = 0; k < num_iters; ++k) { - const int j = (k << 5) + laneId; - const bool hit = (j < size) && (ids[j] == target); - accT my_dist = hit ? dists[j] : static_cast(0); + const int j = (k << 5) + laneId; + const bool hit = (j < size) && (ids[j] == target); + accT my_dist = hit ? dists[j] : static_cast(0); const unsigned hits = raft::ballot(hit); if (hits != 0 && !found) { const int src_lane = __ffs(hits) - 1; @@ -1123,11 +1125,8 @@ __device__ void update_shared_point(Point* shared_point, // Warp-level: uses laneId and stride 32 for coordinate copy template -__device__ void update_shared_point_warp(Point* shared_point, - const T* data_ptr, - int id, - int dim, - int laneId) +__device__ void update_shared_point_warp( + Point* shared_point, const T* data_ptr, int id, int dim, int laneId) { shared_point->id = id; shared_point->Dim = dim; @@ -1148,9 +1147,9 @@ __device__ void update_shared_point_half_to_float(Point* shared_poi shared_point->Dim = dim; const size_t base = (size_t)id * (size_t)dim; for (size_t i = threadIdx.x * 2; i + 1 < (size_t)dim; i += (size_t)blockDim.x * 2) { - float2 promoted = __half22float2(*reinterpret_cast(&half_ptr[base + i])); - float2* coord_pair = reinterpret_cast(&shared_point->coords[i]); - *coord_pair = promoted; + float2 promoted = __half22float2(*reinterpret_cast(&half_ptr[base + i])); + float2* coord_pair = reinterpret_cast(&shared_point->coords[i]); + *coord_pair = promoted; } if (((size_t)dim & 1u) != 0u && threadIdx.x == 0) { shared_point->coords[dim - 1] = __half2float(half_ptr[base + dim - 1]); @@ -1158,11 +1157,8 @@ __device__ void update_shared_point_half_to_float(Point* shared_poi } template -__device__ void update_shared_point_warp_half_to_float(Point* shared_point, - const half* data_ptr, - int id, - int dim, - int laneId) +__device__ void update_shared_point_warp_half_to_float( + Point* shared_point, const half* data_ptr, int id, int dim, int laneId) { const __half* half_ptr = reinterpret_cast(data_ptr); shared_point->id = id; @@ -1179,11 +1175,8 @@ __device__ void update_shared_point_warp_half_to_float(Point* share } template -__device__ void update_shared_point_warp_fp16_query_smem(Point<__half, accT>* shared_point, - const half* data_ptr, - int id, - int dim, - int laneId) +__device__ void update_shared_point_warp_fp16_query_smem( + Point<__half, accT>* shared_point, const half* data_ptr, int id, int dim, int laneId) { const __half* half_ptr = reinterpret_cast(data_ptr); shared_point->id = id; @@ -1199,11 +1192,8 @@ __device__ void update_shared_point_warp_fp16_query_smem(Point<__half, accT>* sh } template -__device__ void update_shared_point_warp_fp16_query_smem(Point<__half, accT>* shared_point, - const float* data_ptr, - int id, - int dim, - int laneId) +__device__ void update_shared_point_warp_fp16_query_smem( + Point<__half, accT>* shared_point, const float* data_ptr, int id, int dim, int laneId) { shared_point->id = id; shared_point->Dim = dim; @@ -1219,11 +1209,8 @@ __device__ void update_shared_point_warp_fp16_query_smem(Point<__half, accT>* sh template __device__ std::enable_if_t && !std::is_same_v, void> -update_shared_point_warp_fp16_query_smem(Point<__half, accT>* shared_point, - const T* data_ptr, - int id, - int dim, - int laneId) +update_shared_point_warp_fp16_query_smem( + Point<__half, accT>* shared_point, const T* data_ptr, int id, int dim, int laneId) { shared_point->id = id; shared_point->Dim = dim; diff --git a/cpp/src/neighbors/vamana.cuh b/cpp/src/neighbors/vamana.cuh index a6c941d3c4..f202ef20af 100644 --- a/cpp/src/neighbors/vamana.cuh +++ b/cpp/src/neighbors/vamana.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ diff --git a/cpp/src/neighbors/vamana_serialize_half.cu b/cpp/src/neighbors/vamana_serialize_half.cu index 76e0d81fa7..099db7b2bb 100644 --- a/cpp/src/neighbors/vamana_serialize_half.cu +++ b/cpp/src/neighbors/vamana_serialize_half.cu @@ -3,8 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -#include #include "vamana_serialize.cuh" +#include namespace cuvs::neighbors::vamana { diff --git a/examples/cpp/src/vamana_example.cu b/examples/cpp/src/vamana_example.cu index a73dd2a0b5..fc578abe1f 100644 --- a/examples/cpp/src/vamana_example.cu +++ b/examples/cpp/src/vamana_example.cu @@ -142,13 +142,13 @@ int main(int argc, char* argv[]) // Simple build example to create graph and write to a file vamana_build_and_write<__half>(dev_resources, - raft::make_const_mdspan(dataset.view()), - out_fname, - degree, - max_visited, - max_fraction, - iters, - codebook_prefix); + raft::make_const_mdspan(dataset.view()), + out_fname, + degree, + max_visited, + max_fraction, + iters, + codebook_prefix); } else { usage(); } diff --git a/python/cuvs/cuvs/neighbors/vamana/vamana.pyx b/python/cuvs/cuvs/neighbors/vamana/vamana.pyx index 3bd2b6706b..277b6d5597 100644 --- a/python/cuvs/cuvs/neighbors/vamana/vamana.pyx +++ b/python/cuvs/cuvs/neighbors/vamana/vamana.pyx @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # # cython: language_level=3 diff --git a/python/cuvs/cuvs/tests/test_vamana.py b/python/cuvs/cuvs/tests/test_vamana.py index bd19615788..6158d5da77 100644 --- a/python/cuvs/cuvs/tests/test_vamana.py +++ b/python/cuvs/cuvs/tests/test_vamana.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 import numpy as np From 2dc5fa91dcbaefb228e1d9277cd776b0728c2b3f Mon Sep 17 00:00:00 2001 From: bkarsin Date: Wed, 24 Jun 2026 16:25:58 -0700 Subject: [PATCH 14/14] Fix bug with odd dimension data for some cases --- .../neighbors/detail/vamana/vamana_build.cuh | 2 +- .../detail/vamana/vamana_structs.cuh | 74 ++++++++++++++++--- 2 files changed, 63 insertions(+), 13 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 795928c0c8..6555e3765d 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -186,7 +186,7 @@ void batched_insert_vamana( raft::resource::get_large_workspace_resource_ref(res), raft::make_extents(max_batchsize, visited_size)); - // Assign memory to query_list structures and initiailize + // Assign memory to query_list structures and initialize init_query_candidate_list<<<256, blockD, 0, stream>>>(query_list, visited_ids.data_handle(), visited_dists.data_handle(), diff --git a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh index 8af77bfd1f..5f41dc9afe 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh @@ -537,6 +537,35 @@ __device__ __forceinline__ void l2_fma_sq2(SUMTYPE& acc, float sx, float sy, flo acc = fmaf(dy, dy, acc); } +// Widen any supported operand type to float for the scalar fallback path. +template +__device__ __forceinline__ float l2_widen_to_float(X x) +{ + return static_cast(x); +} +__device__ __forceinline__ float l2_widen_to_float(__half x) { return __half2float(x); } + +// Scalar, alignment-agnostic warp L2 used as a fallback when `dim` is ODD. The +// vectorized float2/half2 comparators below assume each dataset row begins on a +// vector-aligned boundary, which only holds when dim is even (row stride = +// dim * sizeof(elem) is then a multiple of the vector width). For odd dim, odd +// rows are under-aligned and a float2/half2 load raises cudaErrorMisalignedAddress, +// so we widen both operands to float and accumulate one element per lane instead. +template +__device__ __forceinline__ SUMTYPE +l2_warp_scalar_widen(const SrcT* src, const DstT* dst, int dim, int lane) +{ + SUMTYPE partial_sum = 0; + for (int i = lane; i < dim; i += VAMANA_WARP_SIZE) { + float diff = l2_widen_to_float(src[i]) - l2_widen_to_float(dst[i]); + partial_sum = fmaf(diff, diff, partial_sum); + } + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); + } + return partial_sum; +} + template __device__ SUMTYPE l2_SEQ_warp_float_half(const float* src, const __half* dst, int dim, int lane) { @@ -612,6 +641,7 @@ template __forceinline__ __device__ SUMTYPE l2_warp_float_half(const float* src, const __half* dest, int dim, int lane) { + if (dim & 1) { return l2_warp_scalar_widen(src, dest, dim, lane); } if (dim >= 128) { return l2_ILP4_warp_float_half(src, dest, dim, lane); } else if (dim >= 64) { @@ -702,6 +732,7 @@ template __forceinline__ __device__ SUMTYPE l2_warp_half_float(const __half* src, const float* dest, int dim, int lane) { + if (dim & 1) { return l2_warp_scalar_widen(src, dest, dim, lane); } if (dim >= 128) { return l2_ILP4_warp_half_float(src, dest, dim, lane); } else if (dim >= 64) { @@ -791,6 +822,7 @@ template __forceinline__ __device__ SUMTYPE l2_warp_half_smem_half(const __half* src, const __half* dest, int dim, int lane) { + if (dim & 1) { return l2_warp_scalar_widen(src, dest, dim, lane); } if (dim >= 128) { return l2_ILP4_warp_half_smem_half(src, dest, dim, lane); } else if (dim >= 64) { @@ -881,6 +913,7 @@ template __forceinline__ __device__ SUMTYPE l2_warp_half_smem_native(const __half* src, const DataT* dest, int dim, int lane) { + if (dim & 1) { return l2_warp_scalar_widen(src, dest, dim, lane); } if (dim >= 128) { return l2_ILP4_warp_half_smem_native(src, dest, dim, lane); } else if (dim >= 64) { @@ -1146,14 +1179,19 @@ __device__ void update_shared_point_half_to_float(Point* shared_poi shared_point->id = id; shared_point->Dim = dim; const size_t base = (size_t)id * (size_t)dim; + // Odd dim => odd rows start at an under-aligned address; half2 loads would + // fault (cudaErrorMisalignedAddress). Use scalar promotion in that case. + if ((dim & 1) != 0) { + for (size_t i = threadIdx.x; i < (size_t)dim; i += blockDim.x) { + shared_point->coords[i] = __half2float(half_ptr[base + i]); + } + return; + } for (size_t i = threadIdx.x * 2; i + 1 < (size_t)dim; i += (size_t)blockDim.x * 2) { float2 promoted = __half22float2(*reinterpret_cast(&half_ptr[base + i])); float2* coord_pair = reinterpret_cast(&shared_point->coords[i]); *coord_pair = promoted; } - if (((size_t)dim & 1u) != 0u && threadIdx.x == 0) { - shared_point->coords[dim - 1] = __half2float(half_ptr[base + dim - 1]); - } } template @@ -1164,14 +1202,18 @@ __device__ void update_shared_point_warp_half_to_float( shared_point->id = id; shared_point->Dim = dim; const size_t base = (size_t)id * (size_t)dim; + // Odd dim => odd rows are under-aligned for half2 loads; promote scalar. + if ((dim & 1) != 0) { + for (size_t i = laneId; i < (size_t)dim; i += 32) { + shared_point->coords[i] = __half2float(half_ptr[base + i]); + } + return; + } for (size_t i = laneId * 2; i + 1 < (size_t)dim; i += 64) { float2 promoted = __half22float2(*reinterpret_cast(&half_ptr[base + i])); float2* coord_pair = reinterpret_cast(&shared_point->coords[i]); *coord_pair = promoted; } - if (((size_t)dim & 1u) != 0u && laneId == 0) { - shared_point->coords[dim - 1] = __half2float(half_ptr[base + dim - 1]); - } } template @@ -1182,13 +1224,17 @@ __device__ void update_shared_point_warp_fp16_query_smem( shared_point->id = id; shared_point->Dim = dim; const size_t base = (size_t)id * (size_t)dim; + // Odd dim => odd rows are under-aligned for half2 loads; copy scalar. + if ((dim & 1) != 0) { + for (size_t i = laneId; i < (size_t)dim; i += 32) { + shared_point->coords[i] = half_ptr[base + i]; + } + return; + } for (size_t i = laneId * 2; i + 1 < (size_t)dim; i += 64) { *reinterpret_cast(&shared_point->coords[i]) = *reinterpret_cast(&half_ptr[base + i]); } - if (((size_t)dim & 1u) != 0u && laneId == 0) { - shared_point->coords[dim - 1] = half_ptr[base + dim - 1]; - } } template @@ -1198,13 +1244,17 @@ __device__ void update_shared_point_warp_fp16_query_smem( shared_point->id = id; shared_point->Dim = dim; const size_t base = (size_t)id * (size_t)dim; + // Odd dim => odd rows are under-aligned for float2 loads; convert scalar. + if ((dim & 1) != 0) { + for (size_t i = laneId; i < (size_t)dim; i += 32) { + shared_point->coords[i] = __float2half(data_ptr[base + i]); + } + return; + } for (size_t i = laneId * 2; i + 1 < (size_t)dim; i += 64) { float2 v = *reinterpret_cast(&data_ptr[base + i]); *reinterpret_cast(&shared_point->coords[i]) = __float22half2_rn(v); } - if (((size_t)dim & 1u) != 0u && laneId == 0) { - shared_point->coords[dim - 1] = __float2half(data_ptr[base + dim - 1]); - } } template