Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion c/src/neighbors/vamana.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include <cstdint>
#include <cuda_fp16.h>
#include <dlpack/dlpack.h>

#include <raft/core/error.hpp>
Expand Down Expand Up @@ -82,6 +83,8 @@ extern "C" cuvsError_t cuvsVamanaIndexDestroy(cuvsVamanaIndex_t index_c_ptr)
if (index.addr != 0) {
if (index.dtype.code == kDLFloat && index.dtype.bits == 32) {
delete reinterpret_cast<cuvs::neighbors::vamana::index<float, uint32_t>*>(index.addr);
} else if (index.dtype.code == kDLFloat && index.dtype.bits == 16) {
delete reinterpret_cast<cuvs::neighbors::vamana::index<half, uint32_t>*>(index.addr);
} else if (index.dtype.code == kDLInt && index.dtype.bits == 8) {
delete reinterpret_cast<cuvs::neighbors::vamana::index<int8_t, uint32_t>*>(index.addr);
} else if (index.dtype.code == kDLUInt && index.dtype.bits == 8) {
Expand All @@ -100,6 +103,10 @@ extern "C" cuvsError_t cuvsVamanaIndexGetDims(cuvsVamanaIndex_t index, int* dim)
auto index_ptr =
reinterpret_cast<cuvs::neighbors::vamana::index<float, uint32_t>*>(index->addr);
*dim = index_ptr->dim();
} else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) {
auto index_ptr =
reinterpret_cast<cuvs::neighbors::vamana::index<half, uint32_t>*>(index->addr);
*dim = index_ptr->dim();
} else if (index->dtype.code == kDLInt && index->dtype.bits == 8) {
auto index_ptr =
reinterpret_cast<cuvs::neighbors::vamana::index<int8_t, uint32_t>*>(index->addr);
Expand All @@ -123,6 +130,8 @@ extern "C" cuvsError_t cuvsVamanaBuild(cuvsResources_t res,

if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
index->addr = reinterpret_cast<uintptr_t>(_build<float>(res, params, dataset_tensor));
} else if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 16) {
index->addr = reinterpret_cast<uintptr_t>(_build<half>(res, params, dataset_tensor));
} else if (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) {
index->addr = reinterpret_cast<uintptr_t>(_build<int8_t>(res, params, dataset_tensor));
} else if (dataset.dtype.code == kDLUInt && dataset.dtype.bits == 8) {
Expand All @@ -143,6 +152,8 @@ extern "C" cuvsError_t cuvsVamanaSerialize(cuvsResources_t res,
return cuvs::core::translate_exceptions([=] {
if (index->dtype.code == kDLFloat && index->dtype.bits == 32) {
_serialize<float>(res, filename, index, include_dataset);
} else if (index->dtype.code == kDLFloat && index->dtype.bits == 16) {
_serialize<half>(res, filename, index, include_dataset);
} else if (index->dtype.code == kDLInt && index->dtype.bits == 8) {
_serialize<int8_t>(res, filename, index, include_dataset);
} else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) {
Expand Down
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1414,10 +1414,12 @@ if(NOT BUILD_CPU_ONLY)
src/neighbors/tiered_index.cu
src/neighbors/sparse_brute_force.cu
src/neighbors/vamana_build_float.cu
src/neighbors/vamana_build_half.cu
src/neighbors/vamana_build_uint8.cu
src/neighbors/vamana_build_int8.cu
src/neighbors/vamana_codebooks_float.cu
src/neighbors/vamana_serialize_float.cu
src/neighbors/vamana_serialize_half.cu
src/neighbors/vamana_serialize_uint8.cu
src/neighbors/vamana_serialize_int8.cu
src/preprocessing/quantize/scalar.cu
Expand Down
17 changes: 17 additions & 0 deletions cpp/include/cuvs/neighbors/vamana.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#pragma once

#include "common.hpp"
#include <cuda_fp16.h>
#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/common.hpp>
#include <raft/core/device_mdspan.hpp>
Expand Down Expand Up @@ -344,6 +345,16 @@ auto build(raft::resources const& res,
raft::host_matrix_view<const float, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::vamana::index<float, uint32_t>;

/**
 * @brief Build a Vamana index from a half-precision (`half`) dataset.
 *
 * Overload that accepts the dataset as a row-major `raft::device_matrix_view`
 * of `const half` elements using an `int64_t` index type.
 *
 * @param[in] res raft resources handle
 * @param[in] params build parameters (`cuvs::neighbors::vamana::index_params`)
 * @param[in] dataset row-major matrix view over the `half` input data
 *                    (presumably [n_rows, dim] -- TODO confirm against the implementation)
 *
 * @return the constructed `cuvs::neighbors::vamana::index<half, uint32_t>`
 */
auto build(raft::resources const& res,
const cuvs::neighbors::vamana::index_params& params,
raft::device_matrix_view<const half, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::vamana::index<half, uint32_t>;

/**
 * @brief Build a Vamana index from a half-precision (`half`) dataset.
 *
 * Overload that accepts the dataset as a row-major `raft::host_matrix_view`
 * of `const half` elements using an `int64_t` index type.
 *
 * @param[in] res raft resources handle
 * @param[in] params build parameters (`cuvs::neighbors::vamana::index_params`)
 * @param[in] dataset row-major matrix view over the `half` input data
 *                    (presumably [n_rows, dim] -- TODO confirm against the implementation)
 *
 * @return the constructed `cuvs::neighbors::vamana::index<half, uint32_t>`
 */
auto build(raft::resources const& res,
const cuvs::neighbors::vamana::index_params& params,
raft::host_matrix_view<const half, int64_t, raft::row_major> dataset)
-> cuvs::neighbors::vamana::index<half, uint32_t>;

/**
* @brief Build the index from the dataset for efficient DiskANN search.
*
Expand Down Expand Up @@ -520,6 +531,12 @@ void serialize(raft::resources const& handle,
bool include_dataset = true,
bool sector_aligned = false);

/**
 * @brief Serialize a half-precision (`half`) Vamana index.
 *
 * @param[in] handle raft resources handle
 * @param[in] file_prefix string passed as the output file prefix
 *                        (NOTE(review): exact file naming is not visible here -- confirm)
 * @param[in] index the `cuvs::neighbors::vamana::index<half, uint32_t>` to serialize
 * @param[in] include_dataset defaults to `true`; presumably controls whether the dataset
 *                            is written alongside the index -- TODO confirm
 * @param[in] sector_aligned defaults to `false`; NOTE(review): presumably selects a
 *                           sector-aligned on-disk layout -- confirm against the implementation
 */
void serialize(raft::resources const& handle,
const std::string& file_prefix,
const cuvs::neighbors::vamana::index<half, uint32_t>& index,
bool include_dataset = true,
bool sector_aligned = false);

/**
* Save the index to file.
*
Expand Down
Loading
Loading