diff --git a/c/src/neighbors/vamana.cpp b/c/src/neighbors/vamana.cpp index d1686ad96f..03ffcf6811 100644 --- a/c/src/neighbors/vamana.cpp +++ b/c/src/neighbors/vamana.cpp @@ -1,9 +1,10 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #include +#include #include #include @@ -82,6 +83,8 @@ extern "C" cuvsError_t cuvsVamanaIndexDestroy(cuvsVamanaIndex_t index_c_ptr) if (index.addr != 0) { if (index.dtype.code == kDLFloat && index.dtype.bits == 32) { delete reinterpret_cast*>(index.addr); + } else if (index.dtype.code == kDLFloat && index.dtype.bits == 16) { + delete reinterpret_cast*>(index.addr); } else if (index.dtype.code == kDLInt && index.dtype.bits == 8) { delete reinterpret_cast*>(index.addr); } else if (index.dtype.code == kDLUInt && index.dtype.bits == 8) { @@ -100,6 +103,10 @@ extern "C" cuvsError_t cuvsVamanaIndexGetDims(cuvsVamanaIndex_t index, int* dim) auto index_ptr = reinterpret_cast*>(index->addr); *dim = index_ptr->dim(); + } else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) { + auto index_ptr = + reinterpret_cast*>(index->addr); + *dim = index_ptr->dim(); } else if (index->dtype.code == kDLInt && index->dtype.bits == 8) { auto index_ptr = reinterpret_cast*>(index->addr); @@ -123,6 +130,8 @@ extern "C" cuvsError_t cuvsVamanaBuild(cuvsResources_t res, if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { index->addr = reinterpret_cast(_build(res, params, dataset_tensor)); + } else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 16) { + index->addr = reinterpret_cast(_build(res, params, dataset_tensor)); } else if (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) { index->addr = reinterpret_cast(_build(res, params, dataset_tensor)); } else if (dataset.dtype.code == kDLUInt && dataset.dtype.bits == 8) { @@ -143,6 +152,8 @@ extern "C" cuvsError_t 
cuvsVamanaSerialize(cuvsResources_t res, return cuvs::core::translate_exceptions([=] { if (index->dtype.code == kDLFloat && index->dtype.bits == 32) { _serialize(res, filename, index, include_dataset); + } else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) { + _serialize(res, filename, index, include_dataset); } else if (index->dtype.code == kDLInt && index->dtype.bits == 8) { _serialize(res, filename, index, include_dataset); } else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) { diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 89cbadfcfc..ef8fe9a943 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1414,10 +1414,12 @@ if(NOT BUILD_CPU_ONLY) src/neighbors/tiered_index.cu src/neighbors/sparse_brute_force.cu src/neighbors/vamana_build_float.cu + src/neighbors/vamana_build_half.cu src/neighbors/vamana_build_uint8.cu src/neighbors/vamana_build_int8.cu src/neighbors/vamana_codebooks_float.cu src/neighbors/vamana_serialize_float.cu + src/neighbors/vamana_serialize_half.cu src/neighbors/vamana_serialize_uint8.cu src/neighbors/vamana_serialize_int8.cu src/preprocessing/quantize/scalar.cu diff --git a/cpp/include/cuvs/neighbors/vamana.hpp b/cpp/include/cuvs/neighbors/vamana.hpp index 645adc5c5d..6c54b09d43 100644 --- a/cpp/include/cuvs/neighbors/vamana.hpp +++ b/cpp/include/cuvs/neighbors/vamana.hpp @@ -6,6 +6,7 @@ #pragma once #include "common.hpp" +#include #include #include #include @@ -344,6 +345,16 @@ auto build(raft::resources const& res, raft::host_matrix_view dataset) -> cuvs::neighbors::vamana::index; +auto build(raft::resources const& res, + const cuvs::neighbors::vamana::index_params& params, + raft::device_matrix_view dataset) + -> cuvs::neighbors::vamana::index; + +auto build(raft::resources const& res, + const cuvs::neighbors::vamana::index_params& params, + raft::host_matrix_view dataset) + -> cuvs::neighbors::vamana::index; + /** * @brief Build the index from the dataset for efficient DiskANN search. 
* @@ -520,6 +531,12 @@ void serialize(raft::resources const& handle, bool include_dataset = true, bool sector_aligned = false); +void serialize(raft::resources const& handle, + const std::string& file_prefix, + const cuvs::neighbors::vamana::index& index, + bool include_dataset = true, + bool sector_aligned = false); + /** * Save the index to file. * diff --git a/cpp/src/neighbors/detail/vamana/greedy_search.cuh b/cpp/src/neighbors/detail/vamana/greedy_search.cuh index 4e71c1189c..ba2c49834f 100644 --- a/cpp/src/neighbors/detail/vamana/greedy_search.cuh +++ b/cpp/src/neighbors/detail/vamana/greedy_search.cuh @@ -5,7 +5,7 @@ #pragma once -#include +#include #include "macros.cuh" #include "priority_queue.cuh" @@ -18,6 +18,7 @@ #include #include +#include #include namespace cuvs::neighbors::vamana::detail { @@ -74,6 +75,9 @@ __global__ void SortPairsKernel(void* query_list_ptr, int num_queries, int topk) /******************************************************************************************** GPU kernel to perform a batched GreedySearch on a graph. Since this is used for Vamana construction, the entire visited list is kept and stored within the query_list. + Uses 128 threads per block (4 warps), each warp processes one query independently + with per-warp scratch space to avoid block synchronization overhead. + Input - graph with edge lists, dataset vectors, query_list_ptr with the ids of dataset vectors to be searched. All inputs, including dataset, must be device accessible. 
@@ -85,7 +89,7 @@ template , raft::memory_type::host>> -__global__ void GreedySearchKernel( +__global__ __launch_bounds__(128, 12) void GreedySearchKernel( raft::device_matrix_view graph, raft::mdspan, raft::row_major, Accessor> dataset, void* query_list_ptr, @@ -96,186 +100,196 @@ __global__ void GreedySearchKernel( int max_queue_size, Node* topk_pq_mem) { - int n = dataset.extent(0); + const int warpIdx = threadIdx.x / 32; + const int laneId = threadIdx.x % 32; + int dim = dataset.extent(1); int degree = graph.extent(1); QueryCandidates* query_list = static_cast*>(query_list_ptr); - static __shared__ int topk_q_size; - static __shared__ int cand_q_size; - static __shared__ accT cur_k_max; - static __shared__ int k_max_idx; - - static __shared__ Point s_query; - - union ShmemLayout { - // All blocksort sizes have same alignment (16) - T coords; - int neighborhood_arr; - DistPair candidate_queue; - }; - - int align_padding = (((dim - 1) / alignof(ShmemLayout)) + 1) * alignof(ShmemLayout) - dim; + using QueryCoordT = typename greedy_search_query_coord::type; - // Dynamic shared memory used for blocksort, temp vector storage, and neighborhood list - extern __shared__ __align__(alignof(ShmemLayout)) char smem[]; + int align_padding = raft::alignTo(dim, 16) - dim; - size_t smem_offset = 0; + // Only use fp16 coords in shared memory if type is fp16 and dim >= 512 + const bool fp16_query_smem = greedy_search_use_fp16_query_smem(dim); - T* s_coords = reinterpret_cast(&smem[smem_offset]); - smem_offset += (dim + align_padding) * sizeof(T); + extern __shared__ __align__(16) char smem[]; - Node* topk_pq = &topk_pq_mem[blockIdx.x * topk]; - - int* neighbor_array = reinterpret_cast(&smem[smem_offset]); - smem_offset += degree * sizeof(int); + // Per-warp shared memory layout: coords, neighbor_array, candidate_queue + const int coords_size = (dim + align_padding) * greedy_search_query_smem_elem_size(dim); + const int neighbor_size = degree * sizeof(IdxT); + const int 
queue_size_bytes = max_queue_size * sizeof(DistPair); + const int per_warp_size = (coords_size + neighbor_size + queue_size_bytes + 15) & ~15; + char* warp_smem = &smem[warpIdx * per_warp_size]; + __half* s_coords_half = reinterpret_cast<__half*>(warp_smem); + QueryCoordT* s_coords = reinterpret_cast(warp_smem); + IdxT* neighbor_array = reinterpret_cast(warp_smem + coords_size); DistPair* candidate_queue_smem = - reinterpret_cast*>(&smem[smem_offset]); - - s_query.coords = s_coords; - s_query.Dim = dim; + reinterpret_cast*>(warp_smem + coords_size + neighbor_size); + + // 4 warps per block + static __shared__ int topk_q_size[4]; + static __shared__ int cand_q_size[4]; + static __shared__ accT cur_k_max[4]; + static __shared__ int k_max_idx[4]; + static __shared__ int num_neighbors[4]; + + // Different code path for fp16 since it is gated by dim and datatype + Point<__half, accT> s_query_half; + Point s_query; + if (fp16_query_smem) { + s_query_half.Dim = dim; + s_query_half.coords = s_coords_half; + } else { + s_query.Dim = dim; + s_query.coords = s_coords; + } PriorityQueue heap_queue; - - if (threadIdx.x == 0) { - heap_queue.initialize(candidate_queue_smem, max_queue_size, &cand_q_size); + if (laneId == 0) { + heap_queue.initialize(candidate_queue_smem, max_queue_size, &cand_q_size[warpIdx]); } - static __shared__ int num_neighbors; - - for (int i = blockIdx.x; i < num_queries; i += gridDim.x) { - __syncthreads(); + Node* topk_pq = &topk_pq_mem[(blockIdx.x * 4 + warpIdx) * topk]; + const T* vec_ptr = &dataset(0, 0); - // resetting visited list - query_list[i].reset(); + for (int i = blockIdx.x * 4 + warpIdx; i < num_queries; i += gridDim.x * 4) { + query_list[i].reset_warp(laneId); - // storing the current query vector into shared memory - update_shared_point(&s_query, &dataset(0, 0), query_list[i].queryId, dim); + int cur_query_id = query_list[i].queryId; + if (fp16_query_smem) { + if constexpr (is_cuda_fp16_v) { + update_shared_point_warp_fp16_query_smem( + 
&s_query_half, vec_ptr, cur_query_id, dim, laneId); + } else if constexpr (std::is_same_v) { + update_shared_point_warp_fp16_query_smem( + &s_query_half, vec_ptr, cur_query_id, dim, laneId); + } else { + update_shared_point_warp_fp16_query_smem( + &s_query_half, vec_ptr, cur_query_id, dim, laneId); + } + } else if constexpr (is_cuda_fp16_v) { + update_shared_point_warp_half_to_float(&s_query, vec_ptr, cur_query_id, dim, laneId); + } else { + update_shared_point_warp(&s_query, vec_ptr, cur_query_id, dim, laneId); + } - if (threadIdx.x == 0) { - topk_q_size = 0; - cand_q_size = 0; - s_query.id = query_list[i].queryId; - cur_k_max = 0; - k_max_idx = 0; + if (laneId == 0) { + topk_q_size[warpIdx] = 0; + cand_q_size[warpIdx] = 0; + if (fp16_query_smem) { + s_query_half.id = cur_query_id; + } else { + s_query.id = cur_query_id; + } + cur_k_max[warpIdx] = 0; + k_max_idx[warpIdx] = 0; heap_queue.reset(); } - __syncthreads(); - - Point* query_vec; - - // Just start from medoid every time, rather than multiple set_ups - query_vec = &s_query; - query_vec->Dim = dim; - const T* medoid = &dataset((size_t)medoid_id, 0); - accT medoid_dist = dist(query_vec->coords, medoid, dim, metric); - - if (threadIdx.x == 0) { heap_queue.insert_back(medoid_dist, medoid_id); } - __syncthreads(); + accT medoid_dist; + if (fp16_query_smem) { + medoid_dist = dist_warp_half_query( + s_coords_half, &vec_ptr[(size_t)medoid_id * (size_t)dim], dim, metric, laneId); + } else if constexpr (is_cuda_fp16_v) { + medoid_dist = + dist_warp(s_coords, &vec_ptr[(size_t)medoid_id * (size_t)dim], dim, metric, laneId); + } else { + medoid_dist = dist_warp( + s_coords, &vec_ptr[(size_t)medoid_id * (size_t)dim], dim, metric, laneId); + } - while (cand_q_size != 0) { - __syncthreads(); + if (laneId == 0) { heap_queue.insert_back(medoid_dist, medoid_id); } + while (cand_q_size[warpIdx] != 0) { int cand_num; accT cur_distance; - if (threadIdx.x == 0) { - Node test_cand; + if (laneId == 0) { DistPair test_cand_out = 
heap_queue.pop(); - test_cand.distance = test_cand_out.dist; - test_cand.nodeid = test_cand_out.idx; - cand_num = test_cand.nodeid; + cand_num = test_cand_out.idx; cur_distance = test_cand_out.dist; } - __syncthreads(); - - cand_num = raft::shfl(cand_num, 0); - - __syncthreads(); - - if (query_list[i].check_visited(cand_num, cur_distance)) { continue; } - + cand_num = raft::shfl(cand_num, 0); cur_distance = raft::shfl(cur_distance, 0); - // stop condition for the graph traversal process + if (query_list[i].check_visited_warp(cand_num, cur_distance, laneId)) { continue; } + bool done = false; bool pass_flag = false; - if (topk_q_size == topk) { - // Check the current node with the worst candidate in top-k queue - if (threadIdx.x == 0) { - if (cur_k_max <= cur_distance) { done = true; } + if (topk_q_size[warpIdx] == topk) { + if (laneId == 0) { + if (cur_k_max[warpIdx] <= cur_distance) { done = true; } } - done = raft::shfl(done, 0); if (done) { if (query_list[i].size < topk) { pass_flag = true; - } - - else if (query_list[i].size >= topk) { + } else if (query_list[i].size >= topk) { break; } } } - // The current node is closer to the query vector than the worst candidate in top-K queue, so - // enquee the current node in top-k queue Node new_cand; new_cand.distance = cur_distance; new_cand.nodeid = cand_num; - if (check_duplicate(topk_pq, topk_q_size, new_cand) == false) { + if (check_duplicate_warp(topk_pq, topk_q_size[warpIdx], new_cand, laneId) == false) { if (!pass_flag) { - parallel_pq_max_enqueue( - topk_pq, &topk_q_size, topk, new_cand, &cur_k_max, &k_max_idx); - - __syncthreads(); + parallel_pq_max_enqueue_warp(topk_pq, + &topk_q_size[warpIdx], + topk, + new_cand, + &cur_k_max[warpIdx], + &k_max_idx[warpIdx], + laneId); } } else { - // already visited continue; } - num_neighbors = degree; - __syncthreads(); + num_neighbors[warpIdx] = degree; - for (size_t j = threadIdx.x; j < degree; j += blockDim.x) { - // Load neighbors from the graph array and store them 
in neighbor array (shared memory) + for (size_t j = laneId; j < degree; j += 32) { neighbor_array[j] = graph(cand_num, j); if (neighbor_array[j] == raft::upper_bound()) - atomicMin(&num_neighbors, (int)j); // warp-wide min to find the number of neighbors + atomicMin(&num_neighbors[warpIdx], + (int)j); // warp-wide min to find the number of neighbors } - // computing distances between the query vector and neighbor vectors then enqueue in priority - // queue. - enqueue_all_neighbors( - num_neighbors, query_vec, &dataset(0, 0), neighbor_array, heap_queue, dim, metric); - - __syncthreads(); - - } // End cand_q_size != 0 loop + enqueue_all_neighbors_warp(num_neighbors[warpIdx], + fp16_query_smem, + s_coords_half, + s_coords, + vec_ptr, + neighbor_array, + heap_queue, + dim, + metric, + laneId); + } bool self_found = false; - // Remove self edges - for (int j = threadIdx.x; j < query_list[i].size; j += blockDim.x) { - if (query_list[i].ids[j] == query_vec->id) { + for (int j = laneId; j < query_list[i].size; j += 32) { + if (query_list[i].ids[j] == cur_query_id) { query_list[i].dists[j] = raft::upper_bound(); query_list[i].ids[j] = raft::upper_bound(); - self_found = true; // Flag to reduce size by 1 + self_found = true; // Flag to reduce size by 1 } } + self_found = (raft::ballot(self_found) != 0); - for (int j = query_list[i].size + threadIdx.x; j < query_list[i].maxSize; j += blockDim.x) { + for (int j = query_list[i].size + laneId; j < query_list[i].maxSize; j += 32) { query_list[i].ids[j] = raft::upper_bound(); query_list[i].dists[j] = raft::upper_bound(); } - __syncthreads(); - if (self_found) query_list[i].size--; + if (self_found && laneId == 0) { query_list[i].size--; } } return; diff --git a/cpp/src/neighbors/detail/vamana/macros.cuh b/cpp/src/neighbors/detail/vamana/macros.cuh index 8ec1509677..d154c34ae7 100644 --- a/cpp/src/neighbors/detail/vamana/macros.cuh +++ b/cpp/src/neighbors/detail/vamana/macros.cuh @@ -1,12 +1,36 @@ /* - * SPDX-FileCopyrightText: 
Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once +#include +#include + namespace cuvs::neighbors::vamana::detail { +// RobustPrune wide-dim optimizations: smem candidate cache and GreedySearch distance reuse. +static constexpr int kRobustPruneCandCacheMinDim = 128; + +// Minimum dim for 8 warps on degree-64 occlusion sweep (half). Below this, barrier tax wins. +static constexpr int kRobustPruneMultiWarpMinDimHalf = 960; + +// Occlusion sweep is multi-warp (occId += num_warps). Extra warps hide wide-dim djk latency, +// but narrow dim, low degree, and byte-wide (int8) / half distances have cheap djk -- barrier +// overhead wins unless dim is very wide. +template +__host__ __device__ inline int robust_prune_block_dim(int dim, int degree) +{ + if (degree < 64 || dim < kRobustPruneCandCacheMinDim) { return 128; } + if constexpr (std::is_same_v || std::is_same_v) { return 128; } + if constexpr (std::is_same_v) { + if (dim < kRobustPruneMultiWarpMinDimHalf) { return 128; } + return 256; + } + return 256; +} + /* Macros to compute the shared memory requirements for CUB primitives used by search and prune */ #define COMPUTE_SMEM_SIZE(degree, visited_size, DEG, CANDS) \ if (degree == DEG && visited_size <= CANDS && visited_size > CANDS / 2) { \ diff --git a/cpp/src/neighbors/detail/vamana/priority_queue.cuh b/cpp/src/neighbors/detail/vamana/priority_queue.cuh index a1ce4e7159..5e44563648 100644 --- a/cpp/src/neighbors/detail/vamana/priority_queue.cuh +++ b/cpp/src/neighbors/detail/vamana/priority_queue.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -211,40 +211,36 @@ __host__ __device__ bool operator>(const Node& first, const Node other.distance; } +// each warp scans its own pq with laneId and stride 32 to find duplicates template -__device__ bool check_duplicate(const Node* pq, const int size, Node new_node) +__device__ bool check_duplicate_warp(const Node* pq, + const int size, + Node new_node, + int laneId) { bool found = false; - for (int i = threadIdx.x; i < size; i += blockDim.x) { + for (int i = laneId; i < size; i += 32) { if (pq[i].nodeid == new_node.nodeid) { found = true; break; } } - unsigned mask = raft::ballot(found); - - if (mask == 0) - return false; - - else - return true; + return (mask != 0); } -/* - Enqueuing a input value into parallel queue with tracker -*/ +// Warp-level version: no __syncthreads, uses laneId for single-thread ops and warp shuffle template -__inline__ __device__ void parallel_pq_max_enqueue(Node* pq, - int* size, - const int pq_size, - Node input_data, - SUMTYPE* cur_max_val, - int* max_idx) +__inline__ __device__ void parallel_pq_max_enqueue_warp(Node* pq, + int* size, + const int pq_size, + Node input_data, + SUMTYPE* cur_max_val, + int* max_idx, + int laneId) { if (*size < pq_size) { - __syncthreads(); - if (threadIdx.x == 0) { + if (laneId == 0) { pq[*size].distance = input_data.distance; pq[*size].nodeid = input_data.nodeid; *size = *size + 1; @@ -253,21 +249,17 @@ __inline__ __device__ void parallel_pq_max_enqueue(Node* pq, *max_idx = *size - 1; } } - __syncthreads(); return; } else { - if (input_data.distance >= (*cur_max_val)) { - __syncthreads(); - return; - } - if (threadIdx.x == 0) { + if (input_data.distance >= (*cur_max_val)) { return; } + if (laneId == 0) { pq[*max_idx].distance = input_data.distance; pq[*max_idx].nodeid = input_data.nodeid; } int idx = 0; SUMTYPE max_val = pq[0].distance; - for (int i = threadIdx.x; i < pq_size; i += 32) { + for (int i = laneId; i < pq_size; i += 32) { if (pq[i].distance > 
max_val) { max_val = pq[i].distance; idx = i; @@ -283,34 +275,72 @@ __inline__ __device__ void parallel_pq_max_enqueue(Node* pq, } } - if (threadIdx.x == 31) { + if (laneId == 31) { *max_idx = idx; *cur_max_val = max_val; } } - __syncthreads(); } -/* - Compute the distances between the source vector and all nodes in the neighbor_array and enqueue - them in the PQ -*/ -template -__forceinline__ __device__ void enqueue_all_neighbors(int num_neighbors, - Point* query_vec, - const T* vec_ptr, - int* neighbor_array, - PriorityQueue& heap_queue, - int dim, - cuvs::distance::DistanceType metric) +// Warp-level version: lane 0 does insert_back, no __syncthreads +template +__forceinline__ __device__ void enqueue_all_neighbors_warp(int num_neighbors, + Point* query_vec, + const DataT* vec_ptr, + IdxT* neighbor_array, + PriorityQueue& heap_queue, + int dim, + cuvs::distance::DistanceType metric, + int laneId) { for (int i = 0; i < num_neighbors; i++) { - accT dist_out = dist( - query_vec->coords, &vec_ptr[(size_t)(neighbor_array[i]) * (size_t)(dim)], dim, metric); + const DataT* neighbor_vec = &vec_ptr[(size_t)(neighbor_array[i]) * (size_t)(dim)]; + accT dist_out; + if constexpr (std::is_same_v) { + dist_out = + dist_warp_half_query(query_vec->coords, neighbor_vec, dim, metric, laneId); + } else if constexpr (std::is_same_v && is_cuda_fp16_v) { + dist_out = dist_warp(query_vec->coords, neighbor_vec, dim, metric, laneId); + } else { + static_assert(std::is_same_v); + dist_out = dist_warp(query_vec->coords, neighbor_vec, dim, metric, laneId); + } + if (laneId == 0) { heap_queue.insert_back(dist_out, neighbor_array[i]); } + } +} - __syncthreads(); - if (threadIdx.x == 0) { heap_queue.insert_back(dist_out, neighbor_array[i]); } - __syncthreads(); +// Half-precision version with two code paths based on query being fp16 or fp32 +template +__forceinline__ __device__ void enqueue_all_neighbors_warp( + int num_neighbors, + bool fp16_query_smem, + __half* s_coords_half, + typename 
greedy_search_query_coord::type* s_coords, + const T* vec_ptr, + IdxT* neighbor_array, + PriorityQueue& heap_queue, + int dim, + cuvs::distance::DistanceType metric, + int laneId) +{ + if (fp16_query_smem) { + Point<__half, accT> query_vec; + query_vec.coords = s_coords_half; + query_vec.Dim = dim; + enqueue_all_neighbors_warp<__half, T, accT, IdxT>( + num_neighbors, &query_vec, vec_ptr, neighbor_array, heap_queue, dim, metric, laneId); + } else if constexpr (is_cuda_fp16_v) { + Point query_vec; + query_vec.coords = reinterpret_cast(s_coords); + query_vec.Dim = dim; + enqueue_all_neighbors_warp( + num_neighbors, &query_vec, vec_ptr, neighbor_array, heap_queue, dim, metric, laneId); + } else { + Point query_vec; + query_vec.coords = reinterpret_cast(s_coords); + query_vec.Dim = dim; + enqueue_all_neighbors_warp( + num_neighbors, &query_vec, vec_ptr, neighbor_array, heap_queue, dim, metric, laneId); } } diff --git a/cpp/src/neighbors/detail/vamana/robust_prune.cuh b/cpp/src/neighbors/detail/vamana/robust_prune.cuh index 31fb6d589f..32aeffa0a5 100644 --- a/cpp/src/neighbors/detail/vamana/robust_prune.cuh +++ b/cpp/src/neighbors/detail/vamana/robust_prune.cuh @@ -9,6 +9,8 @@ #include +#include + #include "macros.cuh" #include "vamana_structs.cuh" @@ -61,7 +63,7 @@ __global__ void RobustPruneKernel( int visited_size, cuvs::distance::DistanceType metric, float alpha, - T* s_coords_mem) + typename greedy_search_query_coord::type* s_coords_mem) { int n = dataset.extent(0); int dim = dataset.extent(1); @@ -69,6 +71,8 @@ __global__ void RobustPruneKernel( QueryCandidates* query_list = static_cast*>(query_list_ptr); + using QueryCoordT = typename greedy_search_query_coord::type; + union ShmemLayout { // All blocksort sizes have same alignment (16) float occlusion; @@ -80,24 +84,46 @@ __global__ void RobustPruneKernel( int align_padding = raft::alignTo(dim, alignof(ShmemLayout)) - dim; - float* occlusion_list = reinterpret_cast(smem); + float* occlusion_list = 
reinterpret_cast(smem); + const int nbh_list_offset = (degree + visited_size) * sizeof(float); DistPair* new_nbh_list = - reinterpret_cast*>(&smem[(degree + visited_size) * sizeof(float)]); + reinterpret_cast*>(&smem[nbh_list_offset]); + const int query_cache_offset = + nbh_list_offset + (degree + visited_size) * sizeof(DistPair); + DistPair* query_cache = + reinterpret_cast*>(&smem[query_cache_offset]); + const int cand_coords_offset = query_cache_offset + visited_size * sizeof(DistPair); + const int coord_bytes = (dim + align_padding) * static_cast(sizeof(QueryCoordT)); + int graph_dists_offset = cand_coords_offset; + QueryCoordT* s_cand_coords = nullptr; + if (dim >= kRobustPruneCandCacheMinDim) { + s_cand_coords = reinterpret_cast(&smem[cand_coords_offset]); + graph_dists_offset = cand_coords_offset + coord_bytes; + } + accT* graph_dists = reinterpret_cast(&smem[graph_dists_offset]); + const int graph_ids_offset = graph_dists_offset + degree * sizeof(accT); + IdxT* graph_ids = reinterpret_cast(&smem[graph_ids_offset]); - static __shared__ Point s_query; + static __shared__ Point s_query; s_query.coords = &s_coords_mem[blockIdx.x * (dim + align_padding)]; s_query.Dim = dim; static __shared__ int prev_edges; - static __shared__ accT graphDist; + static __shared__ int s_accept_count; + static __shared__ int s_do_accept; + static __shared__ int s_res_size; + + const int laneId = threadIdx.x & 31; + const int warpId = threadIdx.x >> 5; + const int num_warps = blockDim.x >> 5; for (int i = blockIdx.x; i < num_queries; i += gridDim.x) { int queryId = query_list[i].queryId; - update_shared_point(&s_query, &dataset(0, 0), queryId, dim, i); - - int graphIdx = 0; - int listIdx = 0; - int res_size = degree + visited_size; + if constexpr (is_cuda_fp16_v) { + update_shared_point_half_to_float(&s_query, &dataset(0, 0), queryId, dim); + } else { + update_shared_point(&s_query, &dataset(0, 0), queryId, dim, i); + } // Count total valid edge candidates __syncthreads(); @@ 
-112,106 +138,173 @@ __global__ void RobustPruneKernel( } for (int j = threadIdx.x; j < degree + visited_size; j += blockDim.x) { occlusion_list[j] = 0.0; + if (j < visited_size) { + query_cache[j].idx = query_list[i].ids[j]; + query_cache[j].dist = query_list[i].dists[j]; + } } __syncthreads(); - DistPair next_cand; - // Merge graph and candidate list - for (int outIdx = 0; outIdx < degree + visited_size; outIdx++) { - // Check if no more valid elements from graph or list - if (graphIdx < degree && graph(queryId, graphIdx) == raft::upper_bound()) { - graphIdx = degree; + // Precompute graph-edge distances in parallel; reuse GreedySearch dists when bit-exact. + const int visited_count = query_list[i].size; + const IdxT* visited_ids = query_list[i].ids; + const accT* visited_dists = query_list[i].dists; + const bool reuse_search_dists = (dim >= kRobustPruneCandCacheMinDim); + for (int j = warpId; j < prev_edges; j += num_warps) { + IdxT gid = graph(queryId, j); + accT d; + bool found = false; + if (reuse_search_dists) { + found = lookup_visited_dist_warp(visited_ids, visited_dists, visited_count, gid, d, laneId); + } + if (!found) { + if constexpr (is_cuda_fp16_v) { + d = dist_warp(s_query.coords, &dataset((size_t)gid, 0), dim, metric, laneId); + } else { + d = dist_warp(s_query.coords, &dataset((size_t)gid, 0), dim, metric, laneId); + } } - if (listIdx < visited_size && query_list[i].ids[listIdx] == raft::upper_bound()) { - listIdx = visited_size; + if (laneId == 0) { + graph_dists[j] = d; + graph_ids[j] = gid; } + } + for (int j = threadIdx.x; j < degree; j += blockDim.x) { + if (j >= prev_edges) { graph_ids[j] = raft::upper_bound(); } + } + __syncthreads(); - // Get next candidate vector for list - if (graphIdx >= degree) { - if (listIdx >= visited_size) { // Fill remaining list if no candidates - if (res_size > outIdx) res_size = outIdx; // Set result size - new_nbh_list[outIdx].idx = raft::upper_bound(); - new_nbh_list[outIdx].dist = raft::upper_bound(); - 
__syncthreads(); - continue; - } else { - next_cand.idx = query_list[i].ids[listIdx]; - next_cand.dist = query_list[i].dists[listIdx]; - listIdx++; + if (threadIdx.x == 0) { + int graphIdx = 0; + int listIdx = 0; + int merged_size = degree + visited_size; + + DistPair next_cand; + // Merge graph and candidate list from smem (no global reads during merge). + for (int outIdx = 0; outIdx < degree + visited_size; outIdx++) { + // Check if no more valid elements from graph or list + if (graphIdx < degree && graph_ids[graphIdx] == raft::upper_bound()) { + graphIdx = degree; + } + if (listIdx < visited_size && query_cache[listIdx].idx == raft::upper_bound()) { + listIdx = visited_size; } - } else if (listIdx >= visited_size) { - next_cand.idx = graph(queryId, graphIdx); - accT tempDist = - dist(s_query.coords, &dataset((size_t)graph(queryId, graphIdx), 0), dim, metric); - if (threadIdx.x == 0) graphDist = tempDist; - __syncthreads(); - next_cand.dist = graphDist; - graphIdx++; - } else { - accT listDist = query_list[i].dists[listIdx]; - - accT tempDist = - dist(s_query.coords, &dataset((size_t)graph(queryId, graphIdx), 0), dim, metric); - if (threadIdx.x == 0) graphDist = tempDist; - __syncthreads(); - - if (listDist <= graphDist) { - next_cand.idx = query_list[i].ids[listIdx]; - next_cand.dist = listDist; - - if (graph(queryId, graphIdx) == query_list[i].ids[listIdx]) { // Duplicate found! 
- graphIdx++; // Skip the duplicate + + // Get next candidate vector for list + if (graphIdx >= degree) { + if (listIdx >= visited_size) { // Fill remaining list if no candidates + if (merged_size > outIdx) merged_size = outIdx; // Set result size + new_nbh_list[outIdx].idx = raft::upper_bound(); + new_nbh_list[outIdx].dist = raft::upper_bound(); + continue; + } else { + next_cand = query_cache[listIdx]; + listIdx++; } - listIdx++; - } else { - next_cand.idx = graph(queryId, graphIdx); - next_cand.dist = graphDist; + } else if (listIdx >= visited_size) { + next_cand.idx = graph_ids[graphIdx]; + next_cand.dist = graph_dists[graphIdx]; graphIdx++; + } else { + accT listDist = query_cache[listIdx].dist; + IdxT listId = query_cache[listIdx].idx; + IdxT graphId = graph_ids[graphIdx]; + + if (graphId == listId) { + next_cand.idx = listId; + next_cand.dist = listDist; + graphIdx++; + listIdx++; + } else if (listDist <= graph_dists[graphIdx]) { + next_cand.idx = listId; + next_cand.dist = listDist; + listIdx++; + } else { + next_cand.idx = graphId; + next_cand.dist = graph_dists[graphIdx]; + graphIdx++; + } } - } - new_nbh_list[outIdx].idx = next_cand.idx; - new_nbh_list[outIdx].dist = next_cand.dist; + new_nbh_list[outIdx].idx = next_cand.idx; + new_nbh_list[outIdx].dist = next_cand.dist; + } + s_res_size = merged_size; } + __syncthreads(); // If we need to prune at all... - if (res_size > degree) { - int accept_count = 0; + if (s_res_size > degree) { + if (threadIdx.x == 0) s_accept_count = 0; + __syncthreads(); + const bool cache_cand_in_smem = dim >= kRobustPruneCandCacheMinDim; + Point s_cand; + if (cache_cand_in_smem) { + s_cand.coords = s_cand_coords; + s_cand.Dim = dim; + } // Go through different alpha values. 
These constants are hard-coded in the MSFT DiskANN code - for (float cur_alpha = 1.0; cur_alpha <= alpha && accept_count < degree; cur_alpha *= 1.2) { - for (int pass_start = 0; pass_start < res_size && accept_count < degree; pass_start++) { - // pick next non-occluded element - if (occlusion_list[pass_start] == raft::lower_bound() || - occlusion_list[pass_start] > cur_alpha) { - continue; // Skip over elements already pruned or already accepted + for (float cur_alpha = 1.0; cur_alpha <= alpha && s_accept_count < degree; cur_alpha *= 1.2) { + for (int pass_start = 0; pass_start < s_res_size && s_accept_count < degree; pass_start++) { + if (threadIdx.x == 0) { + s_do_accept = + (occlusion_list[pass_start] != raft::lower_bound() && + occlusion_list[pass_start] <= cur_alpha && new_nbh_list[pass_start].idx != queryId) + ? 1 + : 0; } + __syncthreads(); - if (new_nbh_list[pass_start].idx == queryId) { continue; } - - T* cand_ptr = const_cast(&dataset((size_t)(new_nbh_list[pass_start].idx), 0)); - - occlusion_list[pass_start] = raft::lower_bound(); // Mark as "accepted" - accept_count++; + // update_shared_point uses all block threads; barrier must be uniform across warps. 
+ if (s_do_accept && cache_cand_in_smem) { + if constexpr (is_cuda_fp16_v) { + update_shared_point_half_to_float( + &s_cand, &dataset(0, 0), new_nbh_list[pass_start].idx, dim); + } else { + update_shared_point( + &s_cand, &dataset(0, 0), new_nbh_list[pass_start].idx, dim); + } + } + __syncthreads(); - // Update rest of the occlusion list - for (int occId = pass_start + 1; occId < res_size; occId++) { - if (occlusion_list[occId] <= alpha && - occlusion_list[occId] != raft::lower_bound()) { - T* k_ptr = const_cast(&dataset((size_t)(new_nbh_list[occId].idx), 0)); - accT djk = dist(cand_ptr, k_ptr, dim, metric); - accT new_occ = (float)(new_nbh_list[occId].dist / djk); + if (s_do_accept) { + if (threadIdx.x == 0) { + occlusion_list[pass_start] = raft::lower_bound(); + s_accept_count++; + } - occlusion_list[occId] = std::max(occlusion_list[occId], new_occ); + T* cand_ptr = const_cast(&dataset((size_t)(new_nbh_list[pass_start].idx), 0)); + for (int occId = pass_start + 1 + warpId; occId < s_res_size; occId += num_warps) { + if (occlusion_list[occId] <= alpha && + occlusion_list[occId] != raft::lower_bound()) { + T* k_ptr = const_cast(&dataset((size_t)(new_nbh_list[occId].idx), 0)); + accT djk; + if (cache_cand_in_smem) { + if constexpr (is_cuda_fp16_v) { + djk = dist_warp(s_cand.coords, k_ptr, dim, metric, laneId); + } else { + djk = dist_warp(s_cand.coords, k_ptr, dim, metric, laneId); + } + } else { + djk = dist_warp(cand_ptr, k_ptr, dim, metric, laneId); + } + if (laneId == 0) { + accT new_occ = (float)(new_nbh_list[occId].dist / djk); + occlusion_list[occId] = std::max(occlusion_list[occId], new_occ); + } + } } } + // Publish occId occlusion updates before the next pass_start reads occlusion_list. 
+ __syncthreads(); } } // Move all "accepted" candidates to front of list and zero out the rest if (threadIdx.x == 0) { int out_idx = 1; - for (int read_idx = 1; out_idx < accept_count; read_idx++) { + for (int read_idx = 1; out_idx < s_accept_count; read_idx++) { if (occlusion_list[read_idx] == raft::lower_bound()) { // If it is "accepted" new_nbh_list[out_idx].idx = new_nbh_list[read_idx].idx; new_nbh_list[out_idx].dist = new_nbh_list[read_idx].dist; @@ -220,12 +313,12 @@ __global__ void RobustPruneKernel( } } __syncthreads(); - for (int out_idx = accept_count + threadIdx.x; out_idx < degree; out_idx++) { + for (int out_idx = s_accept_count + threadIdx.x; out_idx < degree; out_idx += blockDim.x) { new_nbh_list[out_idx].idx = raft::upper_bound(); new_nbh_list[out_idx].dist = raft::upper_bound(); } - if (threadIdx.x == 0) { res_size = accept_count; } + if (threadIdx.x == 0) { s_res_size = s_accept_count; } __syncthreads(); } @@ -234,7 +327,7 @@ __global__ void RobustPruneKernel( query_list[i].ids[j] = new_nbh_list[j].idx; query_list[i].dists[j] = new_nbh_list[j].dist; } - if (threadIdx.x == 0) { query_list[i].size = res_size; } + if (threadIdx.x == 0) { query_list[i].size = s_res_size; } } } diff --git a/cpp/src/neighbors/detail/vamana/vamana_build.cuh b/cpp/src/neighbors/detail/vamana/vamana_build.cuh index 336d81215b..6555e3765d 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_build.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_build.cuh @@ -28,6 +28,12 @@ #include #include +#include +#include + +#include +#include + #include #include @@ -41,8 +47,9 @@ namespace cuvs::neighbors::vamana::detail { * @{ */ -static const int blockD = 32; -static const int maxBlocks = 10000; +static const int blockD = 32; +static const int blockD_greedy = 128; // 4 warps per block, each warp processes one query +static const int maxBlocks = 10000; // generate random permutation of inserts - TODO do this on GPU / faster template @@ -84,6 +91,26 @@ __global__ void 
print_queryIds(void* query_list_ptr) #define KERNEL_TIMING (RAFT_LOG_ACTIVE_LEVEL <= RAPIDS_LOGGER_LOG_LEVEL_DEBUG) +template +__global__ void gather_query_sizes(QueryCandidates* query_list, + int* edge_counts, + int count) +{ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) { + edge_counts[i] = query_list[i].size; + } +} + +template +__global__ void scatter_prefix_offsets(QueryCandidates* query_list, + const int* edge_offsets, + int count) +{ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) { + query_list[i].size = edge_offsets[i]; + } +} + /******************************************************************************************** * Main Vamana building function - insert vectors into empty graph in batches * Pre - dataset contains the vector data, host matrix allocated to store the graph @@ -102,10 +129,11 @@ void batched_insert_vamana( IdxT* medoid_id, cuvs::distance::DistanceType metric) { - auto stream = raft::resource::get_cuda_stream(res); - int N = dataset.extent(0); - int dim = dataset.extent(1); - int degree = graph.extent(1); + auto stream = raft::resource::get_cuda_stream(res); + cudaStream_t cs = stream; + int N = dataset.extent(0); + int dim = dataset.extent(1); + int degree = graph.extent(1); // Algorithm params int max_batchsize = (int)(params.max_fraction * (float)N); @@ -172,7 +200,8 @@ void batched_insert_vamana( int align_padding = raft::alignTo(dim, 16) - dim; - auto s_coords_mem = raft::make_device_mdarray( + using QueryCoordT = typename greedy_search_query_coord::type; + auto s_coords_mem = raft::make_device_mdarray( res, raft::resource::get_large_workspace_resource_ref(res), raft::make_extents(min(maxBlocks, max(max_batchsize, reverse_batch)), @@ -186,20 +215,32 @@ void batched_insert_vamana( int sort_smem_size = 0; SELECT_SORT_SMEM_SIZE(degree, visited_size); // Sets sort_smem_size based on dataset - // Total dynamic shared memory used by GreedySearch 
+ // GreedySearch: per-warp shared memory (4 warps): coords, neighbor_array, candidate_queue + const int search_coords_size = (dim + align_padding) * greedy_search_query_smem_elem_size(dim); + const int coords_size = (dim + align_padding) * static_cast(sizeof(QueryCoordT)); + const int neighbor_size = degree * sizeof(IdxT); + const int queue_size_bytes = queue_size * sizeof(DistPair); int search_smem_total_size = - static_cast((dim + align_padding) * sizeof(T) + // visited_size * sizeof(Node) + - degree * sizeof(int) + queue_size * sizeof(DistPair)); + static_cast(4 * ((search_coords_size + neighbor_size + queue_size_bytes + 15) & ~15)); // Total dynamic shared memory size needed by both RobustPrune calls - int prune_smem_total_size = (degree + visited_size) * sizeof(float) + // Occlusion list - (degree + visited_size) * sizeof(DistPair); + const int cand_coords_smem_size = (dim >= kRobustPruneCandCacheMinDim) ? coords_size : 0; + int prune_smem_total_size = (degree + visited_size) * sizeof(float) + // Occlusion list + (degree + visited_size) * sizeof(DistPair) + + visited_size * sizeof(DistPair) + // merge query cache + cand_coords_smem_size + + degree * static_cast(sizeof(accT)) + // graph edge dist cache + degree * static_cast(sizeof(IdxT)); // graph edge id cache + + const int blockD_prune = robust_prune_block_dim(dim, degree); RAFT_LOG_DEBUG( - "Dynamic shared memory usage (bytes): GreedySearch: %d, Segment Sort: %d, Robust Prune: %d", + "Dynamic shared memory usage (bytes): GreedySearch: %d, Segment Sort: %d, Robust Prune: %d, " + "RobustPrune blockDim: %d", search_smem_total_size, sort_smem_size, - prune_smem_total_size); + prune_smem_total_size, + blockD_prune); #if KERNEL_TIMING auto end_t = std::chrono::system_clock::now(); @@ -214,6 +255,63 @@ void batched_insert_vamana( double batch_prune = 0.0; #endif + const int64_t max_total_edges = static_cast(max_batchsize) * degree; + const int max_reverse_batch = params.reverse_batchsize; + auto large_ws = 
raft::resource::get_large_workspace_resource_ref(res); + + auto edge_dist_pair = raft::make_device_mdarray>( + res, large_ws, raft::make_extents(max_total_edges)); + auto edge_dest = + raft::make_device_mdarray(res, large_ws, raft::make_extents(max_total_edges)); + auto edge_src = + raft::make_device_mdarray(res, large_ws, raft::make_extents(max_total_edges)); + + auto edge_counts = + raft::make_device_mdarray(res, large_ws, raft::make_extents(max_batchsize + 1)); + auto edge_offsets = + raft::make_device_mdarray(res, large_ws, raft::make_extents(max_batchsize + 1)); + + size_t temp_storage_bytes_dist = 0; + size_t temp_storage_bytes_edge = 0; + cub::DeviceMergeSort::SortPairs(nullptr, + temp_storage_bytes_dist, + edge_dist_pair.data_handle(), + edge_src.data_handle(), + max_total_edges, + CmpDist(), + cs); + cub::DeviceMergeSort::SortPairs(nullptr, + temp_storage_bytes_edge, + edge_dest.data_handle(), + edge_src.data_handle(), + max_total_edges, + CmpEdge(), + cs); + size_t temp_storage_bytes = std::max(temp_storage_bytes_dist, temp_storage_bytes_edge); + auto temp_sort_storage = raft::make_device_mdarray( + res, large_ws, raft::make_extents(std::max(temp_storage_bytes, size_t{1}))); + + size_t scan_temp_bytes = 0; + cub::DeviceScan::ExclusiveSum(nullptr, + scan_temp_bytes, + edge_counts.data_handle(), + edge_offsets.data_handle(), + max_batchsize + 1, + cs); + auto scan_temp_storage = raft::make_device_mdarray( + res, large_ws, raft::make_extents(std::max(scan_temp_bytes, size_t{1}))); + + thrust::device_vector edge_dest_vec(max_total_edges); + + auto reverse_list_ptr = raft::make_device_mdarray>( + res, large_ws, raft::make_extents(max_reverse_batch)); + auto rev_ids = raft::make_device_mdarray( + res, large_ws, raft::make_extents(max_reverse_batch, visited_size)); + auto rev_dists = raft::make_device_mdarray( + res, large_ws, raft::make_extents(max_reverse_batch, visited_size)); + QueryCandidates* reverse_list = + static_cast*>(reverse_list_ptr.data_handle()); 
+ // Random medoid has minor impact on recall // TODO: use heuristic for better medoid selection, issue: // https://github.com/rapidsai/cuvs/issues/355 @@ -234,27 +332,26 @@ void batched_insert_vamana( if (start + step_size > N) { step_size = N - start; } RAFT_LOG_DEBUG("Starting batch of inserts indices_start:%d, batch_size:%d", start, step_size); - int num_blocks = min(maxBlocks, step_size); + int num_blocks = min(maxBlocks, step_size); + int num_blocks_greedy = min(maxBlocks, (step_size + 3) / 4); // Copy ids to be inserted for this batch - raft::copy( - res, - raft::make_device_vector_view(query_ids.data_handle(), int64_t(step_size)), - raft::make_host_vector_view(insert_order.data() + start, int64_t(step_size))); + raft::copy(query_ids.data_handle(), &insert_order.data()[start], step_size, stream); set_query_ids<<>>( query_list_ptr.data_handle(), query_ids.data_handle(), step_size); // Call greedy search to get candidates for every vector being inserted GreedySearchKernel - <<>>(d_graph.view(), - dataset, - query_list_ptr.data_handle(), - step_size, - *medoid_id, - visited_size, - metric, - queue_size, - topk_pq_mem.data_handle()); + <<>>( + d_graph.view(), + dataset, + query_list_ptr.data_handle(), + step_size, + *medoid_id, + visited_size, + metric, + queue_size, + topk_pq_mem.data_handle()); RAFT_CUDA_TRY(cudaPeekAtLastError()); #if KERNEL_TIMING @@ -280,14 +377,14 @@ void batched_insert_vamana( // Run on candidates of vectors being inserted RobustPruneKernel - <<>>(d_graph.view(), - dataset, - query_list_ptr.data_handle(), - step_size, - visited_size, - metric, - alpha, - s_coords_mem.data_handle()); + <<>>(d_graph.view(), + dataset, + query_list_ptr.data_handle(), + step_size, + visited_size, + metric, + alpha, + s_coords_mem.data_handle()); RAFT_CUDA_TRY(cudaPeekAtLastError()); // Segmented sort on query list @@ -316,30 +413,27 @@ void batched_insert_vamana( start_t = std::chrono::system_clock::now(); #endif - // compute prefix sums of query_list sizes - 
TODO parallelize prefix sums - // auto d_total_edges = raft::make_device_mdarray( - // res, raft::resource::get_workspace_resource_ref(res), raft::make_extents(1)); - rmm::device_scalar d_total_edges(stream); - prefix_sums_sizes<<<1, 1, 0, stream>>>(query_list, step_size, d_total_edges.data()); + // compute prefix sums of query_list sizes + const int prefix_count = step_size + 1; + gather_query_sizes + <<>>(query_list, edge_counts.data_handle(), prefix_count); RAFT_CUDA_TRY(cudaPeekAtLastError()); - int total_edges = d_total_edges.value(stream); - // raft::copy(&total_edges, d_total_edges.data_handle(), 1, stream); - // RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + cub::DeviceScan::ExclusiveSum(scan_temp_storage.data_handle(), + scan_temp_bytes, + edge_counts.data_handle(), + edge_offsets.data_handle(), + prefix_count, + cs); + RAFT_CUDA_TRY(cudaPeekAtLastError()); - auto edge_dist_pair = raft::make_device_mdarray>( - res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(total_edges)); - - auto edge_dest = - raft::make_device_mdarray(res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(total_edges)); - auto edge_src = - raft::make_device_mdarray(res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(total_edges)); + scatter_prefix_offsets + <<>>(query_list, edge_offsets.data_handle(), prefix_count); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + + int total_edges; + raft::copy(&total_edges, edge_offsets.data_handle() + step_size, 1, stream); + raft::resource::sync_stream(res); // Create reverse edge list create_reverse_edge_list @@ -352,62 +446,20 @@ void batched_insert_vamana( { // Sort by dists first so final edge lists are each sorted by dist - void* d_temp_storage = nullptr; - size_t temp_storage_bytes = 0; - - cub::DeviceMergeSort::SortPairs(d_temp_storage, - temp_storage_bytes, - edge_dist_pair.data_handle(), - edge_src.data_handle(), - total_edges, - CmpDist(), - stream); - - 
RAFT_LOG_DEBUG("Temp storage needed for sorting dist (bytes): %lu", temp_storage_bytes); - - auto temp_sort_storage = raft::make_device_mdarray( - res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(temp_storage_bytes / sizeof(IdxT))); - - // Sort to group reverse edges by destination cub::DeviceMergeSort::SortPairs(temp_sort_storage.data_handle(), temp_storage_bytes, edge_dist_pair.data_handle(), edge_src.data_handle(), total_edges, CmpDist(), - stream); + cs); } - /* - DistPair* temp_ptr = edge_dist_pair.data_handle(); + DistPair* edge_dist_pair_ptr = edge_dist_pair.data_handle(); raft::linalg::map_offset( - res, edge_dest.view(), [temp_ptr] __device__(size_t i) { return temp_ptr[i].idx; }); - */ - raft::linalg::map( - res, - edge_dest.view(), - [] __device__(auto x) { return x.idx; }, - raft::make_const_mdspan(edge_dist_pair.view())); - - void* d_temp_storage = nullptr; - size_t temp_storage_bytes = 0; - - cub::DeviceMergeSort::SortPairs(d_temp_storage, - temp_storage_bytes, - edge_dest.data_handle(), - edge_src.data_handle(), - total_edges, - CmpEdge(), - stream); - - RAFT_LOG_DEBUG("Temp storage needed for sorting (bytes): %lu", temp_storage_bytes); - - auto temp_sort_storage = raft::make_device_mdarray( res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(temp_storage_bytes / sizeof(IdxT))); + raft::make_device_vector_view(edge_dest.data_handle(), total_edges), + [edge_dist_pair_ptr] __device__(size_t i) { return edge_dist_pair_ptr[i].idx; }); // Sort to group reverse edges by destination cub::DeviceMergeSort::SortPairs(temp_sort_storage.data_handle(), @@ -416,22 +468,23 @@ void batched_insert_vamana( edge_src.data_handle(), total_edges, CmpEdge(), - stream); + cs); // Get number of unique node destinations IdxT unique_dests = cuvs::sparse::neighbors::get_n_components(edge_dest.data_handle(), total_edges, stream); // Find which node IDs have reverse edges and their indices in the reverse edge list - 
thrust::device_vector edge_dest_vec(edge_dest.data_handle(), - edge_dest.data_handle() + total_edges); + RAFT_CUDA_TRY(cudaMemcpyAsync(thrust::raw_pointer_cast(edge_dest_vec.data()), + edge_dest.data_handle(), + total_edges * sizeof(IdxT), + cudaMemcpyDeviceToDevice, + stream)); auto unique_indices = raft::make_device_vector(res, total_edges); raft::linalg::map_offset(res, unique_indices.view(), raft::identity_op{}); - thrust::unique_by_key(edge_dest_vec.begin(), edge_dest_vec.end(), unique_indices.data_handle()); - - edge_dest_vec.clear(); - edge_dest_vec.shrink_to_fit(); + thrust::unique_by_key( + edge_dest_vec.begin(), edge_dest_vec.begin() + total_edges, unique_indices.data_handle()); #if KERNEL_TIMING RAFT_CUDA_TRY(cudaDeviceSynchronize()); @@ -448,24 +501,6 @@ void batched_insert_vamana( reverse_batch = (int)unique_dests - rev_start; } - // Allocate reverse QueryCandidate list based on number of unique destinations - auto reverse_list_ptr = raft::make_device_mdarray>( - res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(reverse_batch)); - auto rev_ids = - raft::make_device_mdarray(res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(reverse_batch, visited_size)); - - auto rev_dists = - raft::make_device_mdarray(res, - raft::resource::get_large_workspace_resource_ref(res), - raft::make_extents(reverse_batch, visited_size)); - - QueryCandidates* reverse_list = - static_cast*>(reverse_list_ptr.data_handle()); - init_query_candidate_list<<<256, blockD, 0, stream>>>(reverse_list, rev_ids.data_handle(), rev_dists.data_handle(), @@ -494,15 +529,15 @@ void batched_insert_vamana( RAFT_CUDA_TRY(cudaPeekAtLastError()); // Call 2nd RobustPrune on reverse query_list - RobustPruneKernel - <<>>(d_graph.view(), - raft::make_const_mdspan(dataset), - reverse_list_ptr.data_handle(), - reverse_batch, - visited_size, - metric, - alpha, - s_coords_mem.data_handle()); + RobustPruneKernel<<>>( + d_graph.view(), + 
raft::make_const_mdspan(dataset), + reverse_list_ptr.data_handle(), + reverse_batch, + visited_size, + metric, + alpha, + s_coords_mem.data_handle()); RAFT_CUDA_TRY(cudaPeekAtLastError()); // Segmented sort on reverse_list @@ -545,7 +580,7 @@ void batched_insert_vamana( batch_prune); #endif - raft::copy(res, graph, d_graph.view()); + raft::copy(graph.data_handle(), d_graph.data_handle(), d_graph.size(), stream); RAFT_CHECK_CUDA(stream); } @@ -561,8 +596,9 @@ index build( { uint32_t graph_degree = params.graph_degree; - RAFT_EXPECTS(params.metric == cuvs::distance::DistanceType::L2Expanded, - "Currently only L2Expanded metric is supported"); + RAFT_EXPECTS(params.metric == cuvs::distance::DistanceType::L2Expanded || + params.metric == cuvs::distance::DistanceType::L2SqrtExpanded, + "Only L2Expanded and L2SqrtExpanded metrics are supported"); const int* deg_size = std::find(std::begin(DEGREE_SIZES), std::end(DEGREE_SIZES), graph_degree); RAFT_EXPECTS(deg_size != std::end(DEGREE_SIZES), "Provided graph_degree not currently supported"); @@ -579,10 +615,9 @@ index build( RAFT_LOG_DEBUG("Running Vamana batched insert algorithm"); - cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded; IdxT medoid_id; batched_insert_vamana( - res, params, dataset, vamana_graph.view(), &medoid_id, metric); + res, params, dataset, vamana_graph.view(), &medoid_id, params.metric); std::optional> quantized_vectors; if (params.codebooks) { @@ -607,10 +642,10 @@ index build( res, codebook_params.pq_encoding_table.size()); // logically a 2D matrix with dimensions // pq_codebook_size x dim_per_subspace * pq_dim - raft::copy(res, - pq_encoding_table_device_vec.view(), - raft::make_host_vector_view(codebook_params.pq_encoding_table.data(), - pq_encoding_table_device_vec.extent(0))); + raft::copy(pq_encoding_table_device_vec.data_handle(), + codebook_params.pq_encoding_table.data(), + codebook_params.pq_encoding_table.size(), + raft::resource::get_cuda_stream(res)); int 
dim_per_subspace = dim / pq_dim; auto pq_codebook = raft::make_device_matrix(res, pq_codebook_size * pq_dim, dim_per_subspace); @@ -634,12 +669,10 @@ index build( // prepare rotation matrix auto rotation_matrix_device = raft::make_device_matrix(res, dim, dim); - raft::copy( - res, - raft::make_device_vector_view(rotation_matrix_device.data_handle(), - int64_t(codebook_params.rotation_matrix.size())), - raft::make_host_vector_view(codebook_params.rotation_matrix.data(), - int64_t(codebook_params.rotation_matrix.size()))); + raft::copy(rotation_matrix_device.data_handle(), + codebook_params.rotation_matrix.data(), + codebook_params.rotation_matrix.size(), + raft::resource::get_cuda_stream(res)); // process in batches const uint32_t n_rows = dataset.extent(0); diff --git a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh index 9a076a6727..5f41dc9afe 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_structs.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_structs.cuh @@ -8,7 +8,9 @@ #include #include #include +#include #include +#include #include #include @@ -20,6 +22,7 @@ #include #include #include +#include #include @@ -31,6 +34,37 @@ namespace cuvs::neighbors::vamana::detail { #define FULL_BITMASK 0xFFFFFFFF +// Warp stride for per-warp distance reduction (GreedySearch uses multiple warps per block). +static constexpr int VAMANA_WARP_SIZE = 32; + +// vamana fp16 instantiations use CUDA's half type (alias of __half on device). +template +inline constexpr bool is_cuda_fp16_v = std::is_same_v, half>; + +// GreedySearch promotes fp16 queries to float in shared memory for distance reuse. +template +struct greedy_search_query_coord { + using type = std::conditional_t, float, T>; +}; + +// Wide vectors: store cached GreedySearch query coords as __half in smem (dim >= 512 only). 
+// Half dataset is excluded: it already caches query as float (promote once); fp16 smem would +// re-widen on every distance and lose the vectorized float-query vs half-neighbor path. +static constexpr int kGreedySearchFp16QuerySmemMinDim = 512; + +template +__host__ __device__ inline bool greedy_search_use_fp16_query_smem(int dim) +{ + return dim >= kGreedySearchFp16QuerySmemMinDim && !is_cuda_fp16_v; +} + +template +__host__ __device__ inline int greedy_search_query_smem_elem_size(int dim) +{ + if (greedy_search_use_fp16_query_smem(dim)) { return static_cast(sizeof(__half)); } + return static_cast(sizeof(typename greedy_search_query_coord::type)); +} + // Currently supported values for graph_degree. static const int DEGREE_SIZES[4] = {32, 64, 128, 256}; @@ -57,7 +91,8 @@ __device__ __host__ void swap(DistPair* a, DistPair* b) // Structure to sort by distance template struct CmpDist { - __device__ bool operator()(const DistPair& lhs, const DistPair& rhs) + __host__ __device__ bool operator()(const DistPair& lhs, + const DistPair& rhs) { return lhs.dist < rhs.dist; } @@ -66,7 +101,7 @@ struct CmpDist { // Used to sort reverse edges by destination template struct CmpEdge { - __device__ bool operator()(const IdxT& lhs, const IdxT& rhs) { return lhs < rhs; } + __host__ __device__ bool operator()(const IdxT& lhs, const IdxT& rhs) { return lhs < rhs; } }; /********************************************************************* @@ -170,16 +205,118 @@ __device__ SUMTYPE l2_ILP4(Point* src_vec, Point* dst_ve return partial_sum[0]; } +/* fp16: native __hsub/__hfma throughout; single float widen at return */ +__device__ __forceinline__ void l2_half_accum(__half& lane_sum, __half s, __half t) +{ + __half d = __hsub(s, t); + lane_sum = __hfma(d, d, lane_sum); +} + +/* ILP helpers: accumulate (s-t)^2 into acc; operands must already be loaded */ +__device__ __forceinline__ void l2_half_fma_sq(__half& acc, __half s, __half t) +{ + __half d = __hsub(s, t); + acc = __hfma(d, d, 
acc); +} + +__device__ __forceinline__ __half l2_half_shfl_down(__half val, int offset) +{ + unsigned int v = static_cast(__half_as_ushort(val)); + v = __shfl_down_sync(FULL_BITMASK, v, offset); + return __ushort_as_half(static_cast(v & 0xFFFFu)); +} + +__device__ __forceinline__ __half l2_half_warp_reduce_to_half(__half lane_sum) +{ + for (int offset = 16; offset > 0; offset /= 2) { + lane_sum = __hadd(lane_sum, l2_half_shfl_down(lane_sum, offset)); + } + return lane_sum; +} + +template +__device__ __forceinline__ SUMTYPE l2_half_warp_reduce(__half lane_sum) +{ + return static_cast(__half2float(l2_half_warp_reduce_to_half(lane_sum))); +} + +template +__device__ SUMTYPE l2_SEQ_half(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec) +{ + __half lane_sum = __float2half(0.0f); + + for (int i = threadIdx.x; i < src_vec->Dim; i += blockDim.x) { + l2_half_accum(lane_sum, src_vec[0].coords[i], dst_vec[0].coords[i]); + } + + return l2_half_warp_reduce(lane_sum); +} + +template +__device__ SUMTYPE l2_ILP2_half(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec) +{ + __half temp_dst[2] = {__float2half(0.0f), __float2half(0.0f)}; + __half partial_sum[2] = {__float2half(0.0f), __float2half(0.0f)}; + for (int i = threadIdx.x; i < src_vec->Dim; i += 2 * blockDim.x) { + temp_dst[0] = dst_vec->coords[i]; + if (i + 32 < src_vec->Dim) temp_dst[1] = dst_vec->coords[i + 32]; + + l2_half_fma_sq(partial_sum[0], src_vec[0].coords[i], temp_dst[0]); + if (i + 32 < src_vec->Dim) + l2_half_fma_sq(partial_sum[1], src_vec[0].coords[i + 32], temp_dst[1]); + } + partial_sum[0] = __hadd(partial_sum[0], partial_sum[1]); + + return l2_half_warp_reduce(partial_sum[0]); +} + +template +__device__ SUMTYPE l2_ILP4_half(Point<__half, SUMTYPE>* src_vec, Point<__half, SUMTYPE>* dst_vec) +{ + __half temp_dst[4] = { + __float2half(0.0f), __float2half(0.0f), __float2half(0.0f), __float2half(0.0f)}; + __half partial_sum[4] = { + __float2half(0.0f), __float2half(0.0f), 
__float2half(0.0f), __float2half(0.0f)}; + for (int i = threadIdx.x; i < src_vec->Dim; i += 4 * blockDim.x) { + temp_dst[0] = dst_vec->coords[i]; + if (i + 32 < src_vec->Dim) temp_dst[1] = dst_vec->coords[i + 32]; + if (i + 64 < src_vec->Dim) temp_dst[2] = dst_vec->coords[i + 64]; + if (i + 96 < src_vec->Dim) temp_dst[3] = dst_vec->coords[i + 96]; + + l2_half_fma_sq(partial_sum[0], src_vec[0].coords[i], temp_dst[0]); + if (i + 32 < src_vec->Dim) + l2_half_fma_sq(partial_sum[1], src_vec[0].coords[i + 32], temp_dst[1]); + if (i + 64 < src_vec->Dim) + l2_half_fma_sq(partial_sum[2], src_vec[0].coords[i + 64], temp_dst[2]); + if (i + 96 < src_vec->Dim) + l2_half_fma_sq(partial_sum[3], src_vec[0].coords[i + 96], temp_dst[3]); + } + partial_sum[0] = + __hadd(partial_sum[0], __hadd(partial_sum[1], __hadd(partial_sum[2], partial_sum[3]))); + + return l2_half_warp_reduce(partial_sum[0]); +} + /* Selects ILP optimization level based on dimension */ template __forceinline__ __device__ SUMTYPE l2(Point* src_vec, Point* dst_vec) { - if (src_vec->Dim >= 128) { - return l2_ILP4(src_vec, dst_vec); - } else if (src_vec->Dim >= 64) { - return l2_ILP2(src_vec, dst_vec); + if constexpr (std::is_same_v, __half>) { + if (src_vec->Dim >= 128) { + return l2_ILP4_half(src_vec, dst_vec); + } else if (src_vec->Dim >= 64) { + return l2_ILP2_half(src_vec, dst_vec); + } else { + return l2_SEQ_half(src_vec, dst_vec); + } } else { - return l2_SEQ(src_vec, dst_vec); + if (src_vec->Dim >= 128) { + return l2_ILP4(src_vec, dst_vec); + } else if (src_vec->Dim >= 64) { + return l2_ILP2(src_vec, dst_vec); + } else { + return l2_SEQ(src_vec, dst_vec); + } } } @@ -197,12 +334,657 @@ __host__ __device__ SUMTYPE l2(const T* src, const T* dest, int dim) return l2(&src_p, &dest_p); } -// Currently only L2Expanded is supported template __host__ __device__ SUMTYPE dist(const T* src, const T* dest, int dim, cuvs::distance::DistanceType metric) { - return l2(src, dest, dim); + SUMTYPE d = l2(src, dest, dim); + if 
(metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + return static_cast(sqrtf(static_cast(d))); + } + return d; +} + +/* + * Warp-strided L2 / dist: each warp computes one distance using lanes 0..31 only. + * Required when blockDim.x > 32 but only one distance per warp is desired — plain l2() shards + * work across the whole block then reduces inside each warp only, which under-counts L2. + */ +template +__device__ SUMTYPE l2_SEQ_warp(Point* src_vec, Point* dst_vec, int lane) +{ + SUMTYPE partial_sum = 0; + for (int i = lane; i < src_vec->Dim; i += VAMANA_WARP_SIZE) { + partial_sum = fmaf((src_vec[0].coords[i] - dst_vec[0].coords[i]), + (src_vec[0].coords[i] - dst_vec[0].coords[i]), + partial_sum); + } + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); + } + return partial_sum; +} + +template +__device__ SUMTYPE l2_ILP2_warp(Point* src_vec, Point* dst_vec, int lane) +{ + T temp_dst[2] = {0, 0}; + SUMTYPE partial_sum[2] = {0, 0}; + for (int i = lane; i < src_vec->Dim; i += 2 * VAMANA_WARP_SIZE) { + temp_dst[0] = dst_vec->coords[i]; + if (i + 32 < src_vec->Dim) temp_dst[1] = dst_vec->coords[i + 32]; + + partial_sum[0] = fmaf( + (src_vec[0].coords[i] - temp_dst[0]), (src_vec[0].coords[i] - temp_dst[0]), partial_sum[0]); + if (i + 32 < src_vec->Dim) + partial_sum[1] = fmaf((src_vec[0].coords[i + 32] - temp_dst[1]), + (src_vec[0].coords[i + 32] - temp_dst[1]), + partial_sum[1]); + } + partial_sum[0] += partial_sum[1]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__device__ SUMTYPE l2_ILP4_warp(Point* src_vec, Point* dst_vec, int lane) +{ + T temp_dst[4] = {0, 0, 0, 0}; + SUMTYPE partial_sum[4] = {0, 0, 0, 0}; + for (int i = lane; i < src_vec->Dim; i += 4 * VAMANA_WARP_SIZE) { + temp_dst[0] = dst_vec->coords[i]; + if (i + 32 < src_vec->Dim) temp_dst[1] = 
dst_vec->coords[i + 32]; + if (i + 64 < src_vec->Dim) temp_dst[2] = dst_vec->coords[i + 64]; + if (i + 96 < src_vec->Dim) temp_dst[3] = dst_vec->coords[i + 96]; + + partial_sum[0] = fmaf( + (src_vec[0].coords[i] - temp_dst[0]), (src_vec[0].coords[i] - temp_dst[0]), partial_sum[0]); + if (i + 32 < src_vec->Dim) + partial_sum[1] = fmaf((src_vec[0].coords[i + 32] - temp_dst[1]), + (src_vec[0].coords[i + 32] - temp_dst[1]), + partial_sum[1]); + if (i + 64 < src_vec->Dim) + partial_sum[2] = fmaf((src_vec[0].coords[i + 64] - temp_dst[2]), + (src_vec[0].coords[i + 64] - temp_dst[2]), + partial_sum[2]); + if (i + 96 < src_vec->Dim) + partial_sum[3] = fmaf((src_vec[0].coords[i + 96] - temp_dst[3]), + (src_vec[0].coords[i + 96] - temp_dst[3]), + partial_sum[3]); + } + partial_sum[0] += partial_sum[1] + partial_sum[2] + partial_sum[3]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__device__ SUMTYPE l2_SEQ_half_warp(Point<__half, SUMTYPE>* src_vec, + Point<__half, SUMTYPE>* dst_vec, + int lane) +{ + __half lane_sum = __float2half(0.0f); + for (int i = lane; i < src_vec->Dim; i += VAMANA_WARP_SIZE) { + l2_half_accum(lane_sum, src_vec[0].coords[i], dst_vec[0].coords[i]); + } + return l2_half_warp_reduce(lane_sum); +} + +template +__device__ SUMTYPE l2_ILP2_half_warp(Point<__half, SUMTYPE>* src_vec, + Point<__half, SUMTYPE>* dst_vec, + int lane) +{ + __half temp_dst[2] = {__float2half(0.0f), __float2half(0.0f)}; + __half partial_sum[2] = {__float2half(0.0f), __float2half(0.0f)}; + for (int i = lane; i < src_vec->Dim; i += 2 * VAMANA_WARP_SIZE) { + temp_dst[0] = dst_vec->coords[i]; + if (i + 32 < src_vec->Dim) temp_dst[1] = dst_vec->coords[i + 32]; + + l2_half_fma_sq(partial_sum[0], src_vec[0].coords[i], temp_dst[0]); + if (i + 32 < src_vec->Dim) + l2_half_fma_sq(partial_sum[1], src_vec[0].coords[i + 32], temp_dst[1]); + } + partial_sum[0] = 
__hadd(partial_sum[0], partial_sum[1]); + return l2_half_warp_reduce(partial_sum[0]); +} + +template +__device__ SUMTYPE l2_ILP4_half_warp(Point<__half, SUMTYPE>* src_vec, + Point<__half, SUMTYPE>* dst_vec, + int lane) +{ + __half temp_dst[4] = { + __float2half(0.0f), __float2half(0.0f), __float2half(0.0f), __float2half(0.0f)}; + __half partial_sum[4] = { + __float2half(0.0f), __float2half(0.0f), __float2half(0.0f), __float2half(0.0f)}; + for (int i = lane; i < src_vec->Dim; i += 4 * VAMANA_WARP_SIZE) { + temp_dst[0] = dst_vec->coords[i]; + if (i + 32 < src_vec->Dim) temp_dst[1] = dst_vec->coords[i + 32]; + if (i + 64 < src_vec->Dim) temp_dst[2] = dst_vec->coords[i + 64]; + if (i + 96 < src_vec->Dim) temp_dst[3] = dst_vec->coords[i + 96]; + + l2_half_fma_sq(partial_sum[0], src_vec[0].coords[i], temp_dst[0]); + if (i + 32 < src_vec->Dim) + l2_half_fma_sq(partial_sum[1], src_vec[0].coords[i + 32], temp_dst[1]); + if (i + 64 < src_vec->Dim) + l2_half_fma_sq(partial_sum[2], src_vec[0].coords[i + 64], temp_dst[2]); + if (i + 96 < src_vec->Dim) + l2_half_fma_sq(partial_sum[3], src_vec[0].coords[i + 96], temp_dst[3]); + } + partial_sum[0] = + __hadd(partial_sum[0], __hadd(partial_sum[1], __hadd(partial_sum[2], partial_sum[3]))); + return l2_half_warp_reduce(partial_sum[0]); +} + +template +__forceinline__ __device__ SUMTYPE l2_warp(Point* src_vec, + Point* dst_vec, + int lane) +{ + if constexpr (std::is_same_v, __half>) { + if (src_vec->Dim >= 128) { + return l2_ILP4_half_warp(src_vec, dst_vec, lane); + } else if (src_vec->Dim >= 64) { + return l2_ILP2_half_warp(src_vec, dst_vec, lane); + } else { + return l2_SEQ_half_warp(src_vec, dst_vec, lane); + } + } else { + if (src_vec->Dim >= 128) { + return l2_ILP4_warp(src_vec, dst_vec, lane); + } else if (src_vec->Dim >= 64) { + return l2_ILP2_warp(src_vec, dst_vec, lane); + } else { + return l2_SEQ_warp(src_vec, dst_vec, lane); + } + } +} + +template +__forceinline__ __device__ SUMTYPE l2_warp(const T* src, const T* dest, int 
dim, int lane) +{ + Point src_p; + src_p.coords = const_cast(src); + src_p.Dim = dim; + Point dest_p; + dest_p.coords = const_cast(dest); + dest_p.Dim = dim; + + return l2_warp(&src_p, &dest_p, lane); +} + +/* float query vs half dataset: vectorized half2/float2 loads (even lane*2 indices) */ +__device__ __forceinline__ float2 l2_load_src2(const float* src, int i) +{ + return *reinterpret_cast(&src[i]); +} + +__device__ __forceinline__ float2 l2_load_dst2_half(const __half* dst, int i) +{ + return __half22float2(*reinterpret_cast(&dst[i])); +} + +template +__device__ __forceinline__ void l2_fma_sq2(SUMTYPE& acc, float sx, float sy, float2 dst2) +{ + float dx = sx - dst2.x; + float dy = sy - dst2.y; + acc = fmaf(dx, dx, acc); + acc = fmaf(dy, dy, acc); +} + +// Widen any supported operand type to float for the scalar fallback path. +template +__device__ __forceinline__ float l2_widen_to_float(X x) +{ + return static_cast(x); +} +__device__ __forceinline__ float l2_widen_to_float(__half x) { return __half2float(x); } + +// Scalar, alignment-agnostic warp L2 used as a fallback when `dim` is ODD. The +// vectorized float2/half2 comparators below assume each dataset row begins on a +// vector-aligned boundary, which only holds when dim is even (row stride = +// dim * sizeof(elem) is then a multiple of the vector width). For odd dim, odd +// rows are under-aligned and a float2/half2 load raises cudaErrorMisalignedAddress, +// so we widen both operands to float and accumulate one element per lane instead. 
+template +__device__ __forceinline__ SUMTYPE +l2_warp_scalar_widen(const SrcT* src, const DstT* dst, int dim, int lane) +{ + SUMTYPE partial_sum = 0; + for (int i = lane; i < dim; i += VAMANA_WARP_SIZE) { + float diff = l2_widen_to_float(src[i]) - l2_widen_to_float(dst[i]); + partial_sum = fmaf(diff, diff, partial_sum); + } + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); + } + return partial_sum; +} + +template +__device__ SUMTYPE l2_SEQ_warp_float_half(const float* src, const __half* dst, int dim, int lane) +{ + SUMTYPE partial_sum = 0; + for (int i = lane * 2; i < dim; i += VAMANA_WARP_SIZE * 2) { + float2 dst2 = l2_load_dst2_half(dst, i); + float2 src2 = l2_load_src2(src, i); + l2_fma_sq2(partial_sum, src2.x, src2.y, dst2); + } + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); + } + return partial_sum; +} + +template +__device__ SUMTYPE l2_ILP2_warp_float_half(const float* src, const __half* dst, int dim, int lane) +{ + SUMTYPE partial_sum[2] = {0, 0}; + for (int i = lane * 2; i < dim; i += 2 * VAMANA_WARP_SIZE * 2) { + float2 temp_dst[2] = {{0, 0}, {0, 0}}; + temp_dst[0] = l2_load_dst2_half(dst, i); + if (i + 64 < dim) temp_dst[1] = l2_load_dst2_half(dst, i + 64); + + float2 src0 = l2_load_src2(src, i); + l2_fma_sq2(partial_sum[0], src0.x, src0.y, temp_dst[0]); + if (i + 64 < dim) { + float2 src1 = l2_load_src2(src, i + 64); + l2_fma_sq2(partial_sum[1], src1.x, src1.y, temp_dst[1]); + } + } + partial_sum[0] += partial_sum[1]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__device__ SUMTYPE l2_ILP4_warp_float_half(const float* src, const __half* dst, int dim, int lane) +{ + SUMTYPE partial_sum[4] = {0, 0, 0, 0}; + for (int i = lane * 2; i < dim; i += 4 * VAMANA_WARP_SIZE * 2) { + float2 
temp_dst[4] = {{0, 0}, {0, 0}, {0, 0}, {0, 0}}; + temp_dst[0] = l2_load_dst2_half(dst, i); + if (i + 64 < dim) temp_dst[1] = l2_load_dst2_half(dst, i + 64); + if (i + 128 < dim) temp_dst[2] = l2_load_dst2_half(dst, i + 128); + if (i + 192 < dim) temp_dst[3] = l2_load_dst2_half(dst, i + 192); + + float2 src0 = l2_load_src2(src, i); + l2_fma_sq2(partial_sum[0], src0.x, src0.y, temp_dst[0]); + if (i + 64 < dim) { + float2 src1 = l2_load_src2(src, i + 64); + l2_fma_sq2(partial_sum[1], src1.x, src1.y, temp_dst[1]); + } + if (i + 128 < dim) { + float2 src2 = l2_load_src2(src, i + 128); + l2_fma_sq2(partial_sum[2], src2.x, src2.y, temp_dst[2]); + } + if (i + 192 < dim) { + float2 src3 = l2_load_src2(src, i + 192); + l2_fma_sq2(partial_sum[3], src3.x, src3.y, temp_dst[3]); + } + } + partial_sum[0] += partial_sum[1] + partial_sum[2] + partial_sum[3]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__forceinline__ __device__ SUMTYPE +l2_warp_float_half(const float* src, const __half* dest, int dim, int lane) +{ + if (dim & 1) { return l2_warp_scalar_widen(src, dest, dim, lane); } + if (dim >= 128) { + return l2_ILP4_warp_float_half(src, dest, dim, lane); + } else if (dim >= 64) { + return l2_ILP2_warp_float_half(src, dest, dim, lane); + } else { + return l2_SEQ_warp_float_half(src, dest, dim, lane); + } +} + +/* half query vs float dataset: widen query coords to float at use (mirror of float_half) */ +__device__ __forceinline__ float2 l2_load_src2_half(const __half* src, int i) +{ + return __half22float2(*reinterpret_cast(&src[i])); +} + +template +__device__ SUMTYPE l2_SEQ_warp_half_float(const __half* src, const float* dst, int dim, int lane) +{ + SUMTYPE partial_sum = 0; + for (int i = lane * 2; i < dim; i += VAMANA_WARP_SIZE * 2) { + float2 src2 = l2_load_src2_half(src, i); + float2 dst2 = l2_load_src2(dst, i); + l2_fma_sq2(partial_sum, src2.x, 
src2.y, dst2); + } + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); + } + return partial_sum; +} + +template +__device__ SUMTYPE l2_ILP2_warp_half_float(const __half* src, const float* dst, int dim, int lane) +{ + SUMTYPE partial_sum[2] = {0, 0}; + for (int i = lane * 2; i < dim; i += 2 * VAMANA_WARP_SIZE * 2) { + float2 temp_dst[2] = {{0, 0}, {0, 0}}; + temp_dst[0] = l2_load_src2(dst, i); + if (i + 64 < dim) temp_dst[1] = l2_load_src2(dst, i + 64); + + float2 src0 = l2_load_src2_half(src, i); + l2_fma_sq2(partial_sum[0], src0.x, src0.y, temp_dst[0]); + if (i + 64 < dim) { + float2 src1 = l2_load_src2_half(src, i + 64); + l2_fma_sq2(partial_sum[1], src1.x, src1.y, temp_dst[1]); + } + } + partial_sum[0] += partial_sum[1]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__device__ SUMTYPE l2_ILP4_warp_half_float(const __half* src, const float* dst, int dim, int lane) +{ + SUMTYPE partial_sum[4] = {0, 0, 0, 0}; + for (int i = lane * 2; i < dim; i += 4 * VAMANA_WARP_SIZE * 2) { + float2 temp_dst[4] = {{0, 0}, {0, 0}, {0, 0}, {0, 0}}; + temp_dst[0] = l2_load_src2(dst, i); + if (i + 64 < dim) temp_dst[1] = l2_load_src2(dst, i + 64); + if (i + 128 < dim) temp_dst[2] = l2_load_src2(dst, i + 128); + if (i + 192 < dim) temp_dst[3] = l2_load_src2(dst, i + 192); + + float2 src0 = l2_load_src2_half(src, i); + l2_fma_sq2(partial_sum[0], src0.x, src0.y, temp_dst[0]); + if (i + 64 < dim) { + float2 src1 = l2_load_src2_half(src, i + 64); + l2_fma_sq2(partial_sum[1], src1.x, src1.y, temp_dst[1]); + } + if (i + 128 < dim) { + float2 src2 = l2_load_src2_half(src, i + 128); + l2_fma_sq2(partial_sum[2], src2.x, src2.y, temp_dst[2]); + } + if (i + 192 < dim) { + float2 src3 = l2_load_src2_half(src, i + 192); + l2_fma_sq2(partial_sum[3], src3.x, src3.y, temp_dst[3]); + } + } + partial_sum[0] += 
partial_sum[1] + partial_sum[2] + partial_sum[3]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__forceinline__ __device__ SUMTYPE +l2_warp_half_float(const __half* src, const float* dest, int dim, int lane) +{ + if (dim & 1) { return l2_warp_scalar_widen(src, dest, dim, lane); } + if (dim >= 128) { + return l2_ILP4_warp_half_float(src, dest, dim, lane); + } else if (dim >= 64) { + return l2_ILP2_warp_half_float(src, dest, dim, lane); + } else { + return l2_SEQ_warp_half_float(src, dest, dim, lane); + } +} + +/* fp16 query smem vs half dataset: vectorized half2 loads, widen to float, float accumulate + * (mirror of l2_warp_float_half; avoids scalar smem loads and native half FMA) */ +template +__device__ SUMTYPE +l2_SEQ_warp_half_smem_half(const __half* src, const __half* dst, int dim, int lane) +{ + SUMTYPE partial_sum = 0; + for (int i = lane * 2; i < dim; i += VAMANA_WARP_SIZE * 2) { + float2 dst2 = l2_load_dst2_half(dst, i); + float2 src2 = l2_load_src2_half(src, i); + l2_fma_sq2(partial_sum, src2.x, src2.y, dst2); + } + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); + } + return partial_sum; +} + +template +__device__ SUMTYPE +l2_ILP2_warp_half_smem_half(const __half* src, const __half* dst, int dim, int lane) +{ + SUMTYPE partial_sum[2] = {0, 0}; + for (int i = lane * 2; i < dim; i += 2 * VAMANA_WARP_SIZE * 2) { + float2 temp_dst[2] = {{0, 0}, {0, 0}}; + temp_dst[0] = l2_load_dst2_half(dst, i); + if (i + 64 < dim) temp_dst[1] = l2_load_dst2_half(dst, i + 64); + + float2 src0 = l2_load_src2_half(src, i); + l2_fma_sq2(partial_sum[0], src0.x, src0.y, temp_dst[0]); + if (i + 64 < dim) { + float2 src1 = l2_load_src2_half(src, i + 64); + l2_fma_sq2(partial_sum[1], src1.x, src1.y, temp_dst[1]); + } + } + partial_sum[0] += partial_sum[1]; + for (int offset = 16; offset > 
0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__device__ SUMTYPE +l2_ILP4_warp_half_smem_half(const __half* src, const __half* dst, int dim, int lane) +{ + SUMTYPE partial_sum[4] = {0, 0, 0, 0}; + for (int i = lane * 2; i < dim; i += 4 * VAMANA_WARP_SIZE * 2) { + float2 temp_dst[4] = {{0, 0}, {0, 0}, {0, 0}, {0, 0}}; + temp_dst[0] = l2_load_dst2_half(dst, i); + if (i + 64 < dim) temp_dst[1] = l2_load_dst2_half(dst, i + 64); + if (i + 128 < dim) temp_dst[2] = l2_load_dst2_half(dst, i + 128); + if (i + 192 < dim) temp_dst[3] = l2_load_dst2_half(dst, i + 192); + + float2 src0 = l2_load_src2_half(src, i); + l2_fma_sq2(partial_sum[0], src0.x, src0.y, temp_dst[0]); + if (i + 64 < dim) { + float2 src1 = l2_load_src2_half(src, i + 64); + l2_fma_sq2(partial_sum[1], src1.x, src1.y, temp_dst[1]); + } + if (i + 128 < dim) { + float2 src2 = l2_load_src2_half(src, i + 128); + l2_fma_sq2(partial_sum[2], src2.x, src2.y, temp_dst[2]); + } + if (i + 192 < dim) { + float2 src3 = l2_load_src2_half(src, i + 192); + l2_fma_sq2(partial_sum[3], src3.x, src3.y, temp_dst[3]); + } + } + partial_sum[0] += partial_sum[1] + partial_sum[2] + partial_sum[3]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__forceinline__ __device__ SUMTYPE +l2_warp_half_smem_half(const __half* src, const __half* dest, int dim, int lane) +{ + if (dim & 1) { return l2_warp_scalar_widen(src, dest, dim, lane); } + if (dim >= 128) { + return l2_ILP4_warp_half_smem_half(src, dest, dim, lane); + } else if (dim >= 64) { + return l2_ILP2_warp_half_smem_half(src, dest, dim, lane); + } else { + return l2_SEQ_warp_half_smem_half(src, dest, dim, lane); + } +} + +/* fp16 query smem vs int8 (or other native) dataset: same vectorized query widen, float accumulate + */ +template +__device__ __forceinline__ void 
l2_fma_sq2_half_native(SUMTYPE& acc, + float2 src2, + const DataT* dst, + int i) +{ + float dx = src2.x - static_cast(dst[i]); + float dy = src2.y - static_cast(dst[i + 1]); + acc = fmaf(dx, dx, acc); + acc = fmaf(dy, dy, acc); +} + +template +__device__ SUMTYPE +l2_SEQ_warp_half_smem_native(const __half* src, const DataT* dst, int dim, int lane) +{ + SUMTYPE partial_sum = 0; + for (int i = lane * 2; i < dim; i += VAMANA_WARP_SIZE * 2) { + float2 src2 = l2_load_src2_half(src, i); + l2_fma_sq2_half_native(partial_sum, src2, dst, i); + } + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum += __shfl_down_sync(FULL_BITMASK, partial_sum, offset); + } + return partial_sum; +} + +template +__device__ SUMTYPE +l2_ILP2_warp_half_smem_native(const __half* src, const DataT* dst, int dim, int lane) +{ + SUMTYPE partial_sum[2] = {0, 0}; + for (int i = lane * 2; i < dim; i += 2 * VAMANA_WARP_SIZE * 2) { + float2 src0 = l2_load_src2_half(src, i); + l2_fma_sq2_half_native(partial_sum[0], src0, dst, i); + if (i + 64 < dim) { + float2 src1 = l2_load_src2_half(src, i + 64); + l2_fma_sq2_half_native(partial_sum[1], src1, dst, i + 64); + } + } + partial_sum[0] += partial_sum[1]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__device__ SUMTYPE +l2_ILP4_warp_half_smem_native(const __half* src, const DataT* dst, int dim, int lane) +{ + SUMTYPE partial_sum[4] = {0, 0, 0, 0}; + for (int i = lane * 2; i < dim; i += 4 * VAMANA_WARP_SIZE * 2) { + float2 src0 = l2_load_src2_half(src, i); + l2_fma_sq2_half_native(partial_sum[0], src0, dst, i); + if (i + 64 < dim) { + float2 src1 = l2_load_src2_half(src, i + 64); + l2_fma_sq2_half_native(partial_sum[1], src1, dst, i + 64); + } + if (i + 128 < dim) { + float2 src2 = l2_load_src2_half(src, i + 128); + l2_fma_sq2_half_native(partial_sum[2], src2, dst, i + 128); + } + if (i + 192 < dim) { + float2 src3 = 
l2_load_src2_half(src, i + 192); + l2_fma_sq2_half_native(partial_sum[3], src3, dst, i + 192); + } + } + partial_sum[0] += partial_sum[1] + partial_sum[2] + partial_sum[3]; + for (int offset = 16; offset > 0; offset /= 2) { + partial_sum[0] += __shfl_down_sync(FULL_BITMASK, partial_sum[0], offset); + } + return partial_sum[0]; +} + +template +__forceinline__ __device__ SUMTYPE +l2_warp_half_smem_native(const __half* src, const DataT* dest, int dim, int lane) +{ + if (dim & 1) { return l2_warp_scalar_widen(src, dest, dim, lane); } + if (dim >= 128) { + return l2_ILP4_warp_half_smem_native(src, dest, dim, lane); + } else if (dim >= 64) { + return l2_ILP2_warp_half_smem_native(src, dest, dim, lane); + } else { + return l2_SEQ_warp_half_smem_native(src, dest, dim, lane); + } +} + +template +__forceinline__ __device__ SUMTYPE dist_warp_half_query( + const __half* src, const DataT* dest, int dim, cuvs::distance::DistanceType metric, int lane) +{ + SUMTYPE d; + if constexpr (is_cuda_fp16_v) { + d = l2_warp_half_smem_half(src, reinterpret_cast(dest), dim, lane); + } else if constexpr (std::is_same_v) { + d = l2_warp_half_float(src, dest, dim, lane); + } else { + d = l2_warp_half_smem_native(src, dest, dim, lane); + } + if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + return static_cast(sqrtf(static_cast(d))); + } + return d; +} + +template +__forceinline__ __device__ SUMTYPE dist_warp( + const float* src, const half* dest, int dim, cuvs::distance::DistanceType metric, int lane) +{ + SUMTYPE d = l2_warp_float_half(src, reinterpret_cast(dest), dim, lane); + if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + return static_cast(sqrtf(static_cast(d))); + } + return d; +} + +template +__forceinline__ __device__ SUMTYPE +dist_warp(const T* src, const T* dest, int dim, cuvs::distance::DistanceType metric, int lane) +{ + SUMTYPE d = l2_warp(src, dest, dim, lane); + if (metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + return 
static_cast(sqrtf(static_cast(d))); + } + return d; +} + +// Warp-cooperative id lookup in a GreedySearch visited list (sorted by distance, not id). +// All lanes execute the same number of ballot rounds to avoid warp deadlock. +template +__forceinline__ __device__ bool lookup_visited_dist_warp( + const IdxT* ids, const accT* dists, int size, IdxT target, accT& out_dist, int laneId) +{ + bool found = false; + accT found_dist = static_cast(0); + const int num_iters = (size + 31) >> 5; + for (int k = 0; k < num_iters; ++k) { + const int j = (k << 5) + laneId; + const bool hit = (j < size) && (ids[j] == target); + accT my_dist = hit ? dists[j] : static_cast(0); + const unsigned hits = raft::ballot(hit); + if (hits != 0 && !found) { + const int src_lane = __ffs(hits) - 1; + found_dist = raft::shfl(my_dist, src_lane); + found = true; + } + } + if (found) { out_dist = found_dist; } + return found; } /*************************************************************************************** @@ -226,6 +1008,16 @@ struct QueryCandidates { size = 0; } + // Warp-level reset: uses laneId and stride 32, no block sync + __device__ void reset_warp(int laneId) + { + for (int i = laneId; i < maxSize; i += 32) { + ids[i] = raft::upper_bound(); + dists[i] = raft::upper_bound(); + } + if (laneId == 0) { size = 0; } + } + // Checks current list to see if a node as previously been visited __inline__ __device__ bool check_visited(IdxT target, accT dist) { @@ -249,6 +1041,26 @@ struct QueryCandidates { } return found; } + + // Warp-level check_visited: no __syncthreads, uses laneId and warp ballot + __inline__ __device__ bool check_visited_warp(IdxT target, accT dist_val, int laneId) + { + bool my_found = false; + for (int i = laneId; i < size; i += 32) { + if (ids[i] == target) { + my_found = true; + break; + } + } + unsigned mask = raft::ballot(my_found); + bool found = (mask != 0); + if (!found && size < maxSize && laneId == 0) { + ids[size] = target; + dists[size] = dist_val; + size++; + 
} + return found; + } // For debugging /* __inline__ __device__ void print_visited() { @@ -318,23 +1130,6 @@ __global__ void set_query_ids(void* query_list_ptr, IdxT* d_query_ids, int step_ } } -// Compute prefix sums on sizes. Currently only works with 1 thread -// TODO replace with parallel version -template -__global__ void prefix_sums_sizes(QueryCandidates* query_list, - int num_queries, - int* total_edges) -{ - if (threadIdx.x == 0 && blockIdx.x == 0) { - int sum = 0; - for (int i = 0; i < num_queries + 1; i++) { - sum += query_list[i].size; - query_list[i].size = sum - query_list[i].size; // exclusive prefix sum - } - *total_edges = query_list[num_queries].size; - } -} - // Device fcn to have a threadblock copy coordinates into shared memory template __device__ void update_shared_point( @@ -361,6 +1156,120 @@ __device__ void update_shared_point(Point* shared_point, } } +// Warp-level: uses laneId and stride 32 for coordinate copy +template +__device__ void update_shared_point_warp( + Point* shared_point, const T* data_ptr, int id, int dim, int laneId) +{ + shared_point->id = id; + shared_point->Dim = dim; + for (size_t i = laneId; i < dim; i += 32) { + shared_point->coords[i] = data_ptr[(size_t)(id) * (size_t)(dim) + i]; + } +} + +// Promote half dataset vector to float in shared memory (once per query) +template +__device__ void update_shared_point_half_to_float(Point* shared_point, + const half* data_ptr, + int id, + int dim) +{ + const __half* half_ptr = reinterpret_cast(data_ptr); + shared_point->id = id; + shared_point->Dim = dim; + const size_t base = (size_t)id * (size_t)dim; + // Odd dim => odd rows start at an under-aligned address; half2 loads would + // fault (cudaErrorMisalignedAddress). Use scalar promotion in that case. 
+ if ((dim & 1) != 0) { + for (size_t i = threadIdx.x; i < (size_t)dim; i += blockDim.x) { + shared_point->coords[i] = __half2float(half_ptr[base + i]); + } + return; + } + for (size_t i = threadIdx.x * 2; i + 1 < (size_t)dim; i += (size_t)blockDim.x * 2) { + float2 promoted = __half22float2(*reinterpret_cast(&half_ptr[base + i])); + float2* coord_pair = reinterpret_cast(&shared_point->coords[i]); + *coord_pair = promoted; + } +} + +template +__device__ void update_shared_point_warp_half_to_float( + Point* shared_point, const half* data_ptr, int id, int dim, int laneId) +{ + const __half* half_ptr = reinterpret_cast(data_ptr); + shared_point->id = id; + shared_point->Dim = dim; + const size_t base = (size_t)id * (size_t)dim; + // Odd dim => odd rows are under-aligned for half2 loads; promote scalar. + if ((dim & 1) != 0) { + for (size_t i = laneId; i < (size_t)dim; i += 32) { + shared_point->coords[i] = __half2float(half_ptr[base + i]); + } + return; + } + for (size_t i = laneId * 2; i + 1 < (size_t)dim; i += 64) { + float2 promoted = __half22float2(*reinterpret_cast(&half_ptr[base + i])); + float2* coord_pair = reinterpret_cast(&shared_point->coords[i]); + *coord_pair = promoted; + } +} + +template +__device__ void update_shared_point_warp_fp16_query_smem( + Point<__half, accT>* shared_point, const half* data_ptr, int id, int dim, int laneId) +{ + const __half* half_ptr = reinterpret_cast(data_ptr); + shared_point->id = id; + shared_point->Dim = dim; + const size_t base = (size_t)id * (size_t)dim; + // Odd dim => odd rows are under-aligned for half2 loads; copy scalar. 
+ if ((dim & 1) != 0) { + for (size_t i = laneId; i < (size_t)dim; i += 32) { + shared_point->coords[i] = half_ptr[base + i]; + } + return; + } + for (size_t i = laneId * 2; i + 1 < (size_t)dim; i += 64) { + *reinterpret_cast(&shared_point->coords[i]) = + *reinterpret_cast(&half_ptr[base + i]); + } +} + +template +__device__ void update_shared_point_warp_fp16_query_smem( + Point<__half, accT>* shared_point, const float* data_ptr, int id, int dim, int laneId) +{ + shared_point->id = id; + shared_point->Dim = dim; + const size_t base = (size_t)id * (size_t)dim; + // Odd dim => odd rows are under-aligned for float2 loads; convert scalar. + if ((dim & 1) != 0) { + for (size_t i = laneId; i < (size_t)dim; i += 32) { + shared_point->coords[i] = __float2half(data_ptr[base + i]); + } + return; + } + for (size_t i = laneId * 2; i + 1 < (size_t)dim; i += 64) { + float2 v = *reinterpret_cast(&data_ptr[base + i]); + *reinterpret_cast(&shared_point->coords[i]) = __float22half2_rn(v); + } +} + +template +__device__ std::enable_if_t && !std::is_same_v, void> +update_shared_point_warp_fp16_query_smem( + Point<__half, accT>* shared_point, const T* data_ptr, int id, int dim, int laneId) +{ + shared_point->id = id; + shared_point->Dim = dim; + const size_t base = (size_t)id * (size_t)dim; + for (size_t i = laneId; i < (size_t)dim; i += 32) { + shared_point->coords[i] = __float2half(static_cast(data_ptr[base + i])); + } +} + // Update the graph from the results of the query list (or reverse edge list) template __global__ void write_graph_edges_kernel(raft::device_matrix_view graph, @@ -391,6 +1300,7 @@ __global__ void create_reverse_edge_list(void* query_list_ptr, for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num_queries; i += blockDim.x * gridDim.x) { + int read_idx = i * query_list[i].maxSize; int cand_count = query_list[i + 1].size - query_list[i].size; for (int j = 0; j < cand_count; j++) { diff --git a/cpp/src/neighbors/vamana.cuh b/cpp/src/neighbors/vamana.cuh index 
d2a73809ed..f202ef20af 100644 --- a/cpp/src/neighbors/vamana.cuh +++ b/cpp/src/neighbors/vamana.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -45,6 +45,7 @@ namespace cuvs::neighbors::vamana { * * The following distance metrics are supported: * - L2Expanded + * - L2SqrtExpanded * * Usage example: * @code{.cpp} diff --git a/cpp/src/neighbors/vamana_build_half.cu b/cpp/src/neighbors/vamana_build_half.cu new file mode 100644 index 0000000000..817f8c0488 --- /dev/null +++ b/cpp/src/neighbors/vamana_build_half.cu @@ -0,0 +1,33 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "vamana.cuh" +#include +#include + +namespace cuvs::neighbors::vamana { + +#define RAFT_INST_VAMANA_BUILD(T, IdxT) \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::vamana::index_params& params, \ + raft::device_matrix_view dataset) \ + -> cuvs::neighbors::vamana::index \ + { \ + return cuvs::neighbors::vamana::build(handle, params, dataset); \ + } \ + \ + auto build(raft::resources const& handle, \ + const cuvs::neighbors::vamana::index_params& params, \ + raft::host_matrix_view dataset) \ + -> cuvs::neighbors::vamana::index \ + { \ + return cuvs::neighbors::vamana::build(handle, params, dataset); \ + } + +RAFT_INST_VAMANA_BUILD(half, uint32_t); + +#undef RAFT_INST_VAMANA_BUILD + +} // namespace cuvs::neighbors::vamana diff --git a/cpp/src/neighbors/vamana_serialize_half.cu b/cpp/src/neighbors/vamana_serialize_half.cu new file mode 100644 index 0000000000..099db7b2bb --- /dev/null +++ b/cpp/src/neighbors/vamana_serialize_half.cu @@ -0,0 +1,13 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include "vamana_serialize.cuh" +#include + +namespace cuvs::neighbors::vamana { + +CUVS_INST_VAMANA_SERIALIZE(half); + +} // namespace cuvs::neighbors::vamana diff --git a/examples/cpp/src/vamana_example.cu b/examples/cpp/src/vamana_example.cu index b43eb20b1d..fc578abe1f 100644 --- a/examples/cpp/src/vamana_example.cu +++ b/examples/cpp/src/vamana_example.cu @@ -136,6 +136,19 @@ int main(int argc, char* argv[]) max_fraction, iters, codebook_prefix); + } else if (dtype == "half" || dtype == "fp16") { + // Read in binary dataset file + auto dataset = read_bin_dataset<__half, int64_t>(dev_resources, data_fname, INT_MAX); + + // Simple build example to create graph and write to a file + vamana_build_and_write<__half>(dev_resources, + raft::make_const_mdspan(dataset.view()), + out_fname, + degree, + max_visited, + max_fraction, + iters, + codebook_prefix); } else { usage(); } diff --git a/python/cuvs/cuvs/neighbors/vamana/vamana.pyx b/python/cuvs/cuvs/neighbors/vamana/vamana.pyx index a2b4be0a61..277b6d5597 100644 --- a/python/cuvs/cuvs/neighbors/vamana/vamana.pyx +++ b/python/cuvs/cuvs/neighbors/vamana/vamana.pyx @@ -1,5 +1,5 @@ # -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 # # cython: language_level=3 @@ -169,13 +169,14 @@ def build(IndexParams index_params, dataset, resources=None): struct controls the degree of the final graph. 
The following distance metrics are supported: - - L2Expanded + - L2Expanded (sqeuclidean) + - L2SqrtExpanded (l2 / euclidean distance) Parameters ---------- index_params : IndexParams object dataset : CUDA array interface compliant matrix shape (n_samples, dim) - Supported dtype [float, int8, uint8] + Supported dtype [float, float16, int8, uint8] {resources_docstring} Returns @@ -201,6 +202,7 @@ def build(IndexParams index_params, dataset, resources=None): # in RAFT to make this a single call dataset_ai = wrap_array(dataset) _check_input_array(dataset_ai, [np.dtype('float32'), + np.dtype('float16'), np.dtype('int8'), np.dtype('uint8')]) diff --git a/python/cuvs/cuvs/tests/test_vamana.py b/python/cuvs/cuvs/tests/test_vamana.py index f692e4ae69..6158d5da77 100644 --- a/python/cuvs/cuvs/tests/test_vamana.py +++ b/python/cuvs/cuvs/tests/test_vamana.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 import numpy as np @@ -12,6 +12,8 @@ def _gen_data(shape, dtype): rng = np.random.default_rng(12345) if dtype == np.float32: return rng.random(shape, dtype=np.float32) + if dtype == np.float16: + return rng.random(shape, dtype=np.float32).astype(np.float16) if dtype == np.int8: # keep small magnitude to avoid overflow if used elsewhere return rng.integers(low=-10, high=10, size=shape, dtype=np.int8) @@ -20,7 +22,7 @@ def _gen_data(shape, dtype): raise AssertionError("unexpected dtype in test helper") -@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) +@pytest.mark.parametrize("dtype", [np.float32, np.float16, np.int8, np.uint8]) def test_vamana_build_basic(dtype): n_rows, n_cols = 1000, 16 data = _gen_data((n_rows, n_cols), dtype) @@ -36,7 +38,7 @@ def test_vamana_build_basic(dtype): assert "Index(type=Vamana" in repr(idx) -@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) +@pytest.mark.parametrize("dtype", [np.float32, np.float16, np.int8, np.uint8]) @pytest.mark.skip( reason="Skipping host build test because of CUDA error " "in C++ API. Reference issue: " @@ -69,9 +71,9 @@ def test_vamana_serialize(tmp_path, include_dataset): def test_vamana_build_rejects_unsupported_dtype(): - # float16 is not in the accepted list in the build wrapper; expect failure + # e.g. float64 is not supported for Vamana build n_rows, n_cols = 64, 8 - data = _gen_data((n_rows, n_cols), np.float32).astype(np.float16) + data = np.random.default_rng(1).random((n_rows, n_cols), dtype=np.float64) from pylibraft.common import device_ndarray data_dev = device_ndarray(data)