diff --git a/c/include/cuvs/core/c_api.h b/c/include/cuvs/core/c_api.h index 00d4729481..5663cdf975 100644 --- a/c/include/cuvs/core/c_api.h +++ b/c/include/cuvs/core/c_api.h @@ -266,7 +266,7 @@ CUVS_EXPORT cuvsError_t cuvsMatrixCopy(cuvsResources_t res, DLManagedTensor* src * @param[in] res cuvsResources_t opaque C handle * @param[in] src Pointer to DLManagedTensor to copy * @param[in] start First row index to include in the output - * @param[in] end Last row index to include in the output + * @param[in] end One past the last row index to include in the output * @param[out] dst Pointer to DLManagedTensor to receive slice from matrix */ CUVS_EXPORT cuvsError_t cuvsMatrixSliceRows( diff --git a/c/include/cuvs/distance/pairwise_distance.h b/c/include/cuvs/distance/pairwise_distance.h index 0dbe41d5d7..c97c0044b4 100644 --- a/c/include/cuvs/distance/pairwise_distance.h +++ b/c/include/cuvs/distance/pairwise_distance.h @@ -41,9 +41,10 @@ extern "C" { * @endcode * * @param[in] res cuvs resources object for managing expensive resources - * @param[in] x first set of points (size n*k) - * @param[in] y second set of points (size m*k) - * @param[out] dist output distance matrix (size n*m) + * @param[in] x first set of points (size n*k). Must have the same floating point dtype as `y` + * @param[in] y second set of points (size m*k). Must have the same floating point dtype as `x` + * @param[out] dist output distance matrix (size n*m). Must be float32 for float16 inputs, and + * match the input dtype otherwise * @param[in] metric distance to evaluate * @param[in] metric_arg metric argument (used for Minkowski distance) */ diff --git a/c/src/core/c_api.cpp b/c/src/core/c_api.cpp index f4e3664482..e8a4f56bcd 100644 --- a/c/src/core/c_api.cpp +++ b/c/src/core/c_api.cpp @@ -339,33 +339,51 @@ extern "C" cuvsError_t cuvsMatrixSliceRows(cuvsResources_t res, DLManagedTensor* dst_managed) { return cuvs::core::translate_exceptions([=] { - RAFT_EXPECTS(end >= start, "end index must be greater than start index"); + RAFT_EXPECTS(dst_managed != nullptr, "dst tensor should be initialized"); + + dst_managed->dl_tensor = DLTensor{}; + dst_managed->manager_ctx = nullptr; + dst_managed->deleter = nullptr; + + RAFT_EXPECTS(src_managed != nullptr, "src tensor should be initialized"); DLTensor& src = src_managed->dl_tensor; DLTensor& dst = dst_managed->dl_tensor; - RAFT_EXPECTS(src.ndim <= 2, "src should be a 1 or 2 dimensional tensor"); + RAFT_EXPECTS(src.ndim == 1 || src.ndim == 2, "src should be a 1 or 2 dimensional tensor"); RAFT_EXPECTS(src.shape != nullptr, "shape should be initialized in the src tensor"); + RAFT_EXPECTS(src.data != nullptr, "data should be initialized in the src tensor"); + RAFT_EXPECTS(start >= 0 && end >= start && end <= src.shape[0], + "row slice range must satisfy 0 <= start <= end <= src.shape[0]"); - dst.dtype = src.dtype; - dst.device = src.device; - dst.ndim = src.ndim; - dst.shape = new int64_t[dst.ndim]; - dst.shape[0] = end - start; + auto shape = std::make_unique(src.ndim); + std::unique_ptr strides; + shape[0] = end - start; int64_t row_strides = 1; - if (dst.ndim == 2) { - dst.shape[1] = src.shape[1]; - row_strides = dst.shape[1]; + if (src.ndim == 1 && src.strides) { + strides = std::make_unique(1); + row_strides = strides[0] = src.strides[0]; + } + + if (src.ndim == 2) { + shape[1] = src.shape[1]; + row_strides = shape[1]; if (src.strides) { - dst.strides = new int64_t[2]; - row_strides = dst.strides[0] = src.strides[0]; - dst.strides[1] = src.strides[1]; + strides = std::make_unique(2); + row_strides = strides[0] = src.strides[0]; + strides[1] = src.strides[1]; } } - dst.data = static_cast(src.data) + start * row_strides * (dst.dtype.bits / 8); + dst.dtype = src.dtype; + dst.device = src.device; + dst.ndim = src.ndim; + dst.shape = shape.release(); + dst.strides = strides.release(); + dst.byte_offset = src.byte_offset; + dst.data = static_cast(src.data) + start * row_strides * (dst.dtype.bits / 8); dst_managed->deleter = cuvsMatrixDestroy; }); } diff --git a/c/src/distance/pairwise_distance.cpp b/c/src/distance/pairwise_distance.cpp index f3981ee059..758ebfb54d 100644 --- a/c/src/distance/pairwise_distance.cpp +++ b/c/src/distance/pairwise_distance.cpp @@ -52,15 +52,26 @@ extern "C" cuvsError_t cuvsPairwiseDistance(cuvsResources_t res, { return cuvs::core::translate_exceptions([=] { auto x_dt = x_tensor->dl_tensor.dtype; - auto y_dt = x_tensor->dl_tensor.dtype; - auto dist_dt = x_tensor->dl_tensor.dtype; + auto y_dt = y_tensor->dl_tensor.dtype; + auto dist_dt = distances_tensor->dl_tensor.dtype; if ((x_dt.code != kDLFloat) || (y_dt.code != kDLFloat) || (dist_dt.code != kDLFloat)) { RAFT_FAIL("Inputs to cuvsPairwiseDistance must all be floating point tensors"); } - if ((x_dt.bits != y_dt.bits) || (x_dt.bits != dist_dt.bits)) { - RAFT_FAIL("Inputs to cuvsPairwiseDistance must all have the same dtype"); + if (x_dt.lanes != 1 || y_dt.lanes != 1 || dist_dt.lanes != 1) { + RAFT_FAIL("Inputs to cuvsPairwiseDistance must all have a single dtype lane"); + } + + if (x_dt.bits != y_dt.bits) { + RAFT_FAIL("X and Y inputs to cuvsPairwiseDistance must have the same dtype"); + } + + auto expected_dist_bits = x_dt.bits == 16 ? 32 : x_dt.bits; + if (dist_dt.bits != expected_dist_bits) { + RAFT_FAIL( + "distances output to cuvsPairwiseDistance must have dtype float32 for float16 inputs " + "and match the input dtype otherwise"); } bool x_row_major; diff --git a/c/tests/core/c_api.c b/c/tests/core/c_api.c index b84738c797..df5407c787 100644 --- a/c/tests/core/c_api.c +++ b/c/tests/core/c_api.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -9,6 +9,73 @@ #include #include +static void expect_matrix_slice_error(cuvsResources_t res, + DLManagedTensor* src, + int64_t start, + int64_t end) +{ + int64_t sentinel_stride = 0; + DLManagedTensor dst = {0}; + dst.dl_tensor.strides = &sentinel_stride; + + if (cuvsMatrixSliceRows(res, src, start, end, &dst) != CUVS_ERROR) { exit(EXIT_FAILURE); } + if (dst.dl_tensor.shape != NULL || dst.dl_tensor.strides != NULL || dst.deleter != NULL) { + exit(EXIT_FAILURE); + } +} + +static void test_matrix_slice_rows(cuvsResources_t res) +{ + int32_t data[] = {0, 1, 2, 3, 4, 5}; + int64_t shape_2d[] = {3, 2}; + int64_t sentinel_stride = 0; + DLManagedTensor src_2d = {0}; + src_2d.dl_tensor.data = data; + src_2d.dl_tensor.device = (DLDevice){kDLCPU, 0}; + src_2d.dl_tensor.ndim = 2; + src_2d.dl_tensor.dtype = (DLDataType){kDLInt, 32, 1}; + src_2d.dl_tensor.shape = shape_2d; + src_2d.dl_tensor.byte_offset = 0; + + DLManagedTensor dst_2d = {0}; + dst_2d.dl_tensor.strides = &sentinel_stride; + if (cuvsMatrixSliceRows(res, &src_2d, 1, 3, &dst_2d) != CUVS_SUCCESS) { + exit(EXIT_FAILURE); + } + if (dst_2d.dl_tensor.ndim != 2 || dst_2d.dl_tensor.shape[0] != 2 || + dst_2d.dl_tensor.shape[1] != 2 || dst_2d.dl_tensor.data != (void*)(data + 2) || + dst_2d.dl_tensor.strides != NULL || dst_2d.deleter == NULL) { + exit(EXIT_FAILURE); + } + dst_2d.deleter(&dst_2d); + + int64_t shape_1d[] = {6}; + DLManagedTensor src_1d = {0}; + src_1d.dl_tensor.data = data; + src_1d.dl_tensor.device = (DLDevice){kDLCPU, 0}; + src_1d.dl_tensor.ndim = 1; + src_1d.dl_tensor.dtype = (DLDataType){kDLInt, 32, 1}; + src_1d.dl_tensor.shape = shape_1d; + + DLManagedTensor dst_1d = {0}; + if (cuvsMatrixSliceRows(res, &src_1d, 1, 4, &dst_1d) != CUVS_SUCCESS) { + exit(EXIT_FAILURE); + } + if (dst_1d.dl_tensor.ndim != 1 || dst_1d.dl_tensor.shape[0] != 3 || + dst_1d.dl_tensor.data != (void*)(data + 1) || dst_1d.dl_tensor.strides != NULL || + dst_1d.deleter == NULL) { + exit(EXIT_FAILURE); + } + dst_1d.deleter(&dst_1d); + + expect_matrix_slice_error(res, &src_2d, -1, 1); + expect_matrix_slice_error(res, &src_2d, 0, 4); + + DLManagedTensor src_0d = src_2d; + src_0d.dl_tensor.ndim = 0; + expect_matrix_slice_error(res, &src_0d, 0, 0); +} + int main() { // Create resources @@ -73,6 +140,8 @@ int main() cuvsError_t free_error_pinned = cuvsRMMHostFree(ptr3, 1024); if (free_error_pinned == CUVS_ERROR) { exit(EXIT_FAILURE); } + test_matrix_slice_rows(res); + // Destroy resources error = cuvsResourcesDestroy(res); if (error == CUVS_ERROR) { exit(EXIT_FAILURE); } diff --git a/c/tests/distance/pairwise_distance_c.cu b/c/tests/distance/pairwise_distance_c.cu index f53efa28c9..952b1e38ca 100644 --- a/c/tests/distance/pairwise_distance_c.cu +++ b/c/tests/distance/pairwise_distance_c.cu @@ -4,13 +4,20 @@ */ #include +#include #include +#include #include #include #include +#include #include +#include + +#include +#include extern "C" void run_pairwise_distance(int64_t n_rows, int64_t n_queries, @@ -28,6 +35,69 @@ void generate_random_data(T* devPtr, size_t size) raft::random::uniform(handle, r, devPtr, size, T(0.1), T(2.0)); }; +namespace { + +struct DeviceMatrixTensor { + DLManagedTensor tensor{}; + int64_t shape[2]{}; + + DeviceMatrixTensor(void* data, int64_t rows, int64_t cols, DLDataType dtype) + { + shape[0] = rows; + shape[1] = cols; + tensor.dl_tensor.data = data; + tensor.dl_tensor.device = DLDevice{kDLCUDA, 0}; + tensor.dl_tensor.ndim = 2; + tensor.dl_tensor.dtype = dtype; + tensor.dl_tensor.shape = shape; + tensor.dl_tensor.strides = nullptr; + tensor.dl_tensor.byte_offset = 0; + } +}; + +DLDataType float_dtype(uint8_t bits) { return DLDataType{kDLFloat, bits, 1}; } + +void expect_pairwise_distance_error_contains(DLDataType x_dtype, + DLDataType y_dtype, + DLDataType distances_dtype, + std::string_view expected_error) +{ + cuvsResources_t res; + ASSERT_EQ(cuvsResourcesCreate(&res), CUVS_SUCCESS); + + void *x_data, *y_data, *distances_data; + RAFT_CUDA_TRY(cudaMalloc(&x_data, 2 * 3 * sizeof(double))); + RAFT_CUDA_TRY(cudaMalloc(&y_data, 4 * 3 * sizeof(double))); + RAFT_CUDA_TRY(cudaMalloc(&distances_data, 2 * 4 * sizeof(double))); + + DeviceMatrixTensor x_tensor{x_data, 2, 3, x_dtype}; + DeviceMatrixTensor y_tensor{y_data, 4, 3, y_dtype}; + DeviceMatrixTensor distances_tensor{distances_data, 2, 4, distances_dtype}; + + auto status = cuvsPairwiseDistance(res, + &x_tensor.tensor, + &y_tensor.tensor, + &distances_tensor.tensor, + L2Expanded, + 2.0f); + EXPECT_EQ(status, CUVS_ERROR); + if (status == CUVS_ERROR) { + const char* error_text = cuvsGetLastErrorText(); + if (error_text == nullptr) { + ADD_FAILURE() << "Expected cuvsPairwiseDistance to set an error message"; + } else { + EXPECT_NE(std::string{error_text}.find(expected_error), std::string::npos) << error_text; + } + } + + RAFT_CUDA_TRY(cudaFree(x_data)); + RAFT_CUDA_TRY(cudaFree(y_data)); + RAFT_CUDA_TRY(cudaFree(distances_data)); + ASSERT_EQ(cuvsResourcesDestroy(res), CUVS_SUCCESS); +} + +} // namespace + TEST(PairwiseDistanceC, Distance) { int64_t n_rows = 8096; @@ -51,3 +121,62 @@ TEST(PairwiseDistanceC, Distance) cudaFree(query_data); cudaFree(distances_data); } + +TEST(PairwiseDistanceC, FailsWithMismatchedInputDtypes) +{ + expect_pairwise_distance_error_contains(float_dtype(32), + float_dtype(64), + float_dtype(32), + "X and Y inputs to cuvsPairwiseDistance must have the " + "same dtype"); +} + +TEST(PairwiseDistanceC, FailsWithMismatchedFloatOutputDtype) +{ + expect_pairwise_distance_error_contains( + float_dtype(32), + float_dtype(32), + float_dtype(64), + "distances output to cuvsPairwiseDistance must have dtype float32 for float16 inputs"); +} + +TEST(PairwiseDistanceC, FailsWithFloat16OutputForFloat16Inputs) +{ + expect_pairwise_distance_error_contains( + float_dtype(16), + float_dtype(16), + float_dtype(16), + "distances output to cuvsPairwiseDistance must have dtype float32 for float16 inputs"); +} + +TEST(PairwiseDistanceC, AllowsFloat32OutputForFloat16Inputs) +{ + cuvsResources_t res; + ASSERT_EQ(cuvsResourcesCreate(&res), CUVS_SUCCESS); + + constexpr int64_t n_rows = 2; + constexpr int64_t n_queries = 3; + constexpr int64_t n_dim = 4; + + half *x_data, *y_data; + float* distances_data; + RAFT_CUDA_TRY(cudaMalloc(&x_data, sizeof(half) * n_rows * n_dim)); + RAFT_CUDA_TRY(cudaMalloc(&y_data, sizeof(half) * n_queries * n_dim)); + RAFT_CUDA_TRY(cudaMalloc(&distances_data, sizeof(float) * n_rows * n_queries)); + RAFT_CUDA_TRY(cudaMemset(x_data, 0, sizeof(half) * n_rows * n_dim)); + RAFT_CUDA_TRY(cudaMemset(y_data, 0, sizeof(half) * n_queries * n_dim)); + + DeviceMatrixTensor x_tensor{x_data, n_rows, n_dim, float_dtype(16)}; + DeviceMatrixTensor y_tensor{y_data, n_queries, n_dim, float_dtype(16)}; + DeviceMatrixTensor distances_tensor{distances_data, n_rows, n_queries, float_dtype(32)}; + + auto status = cuvsPairwiseDistance( + res, &x_tensor.tensor, &y_tensor.tensor, &distances_tensor.tensor, L2Expanded, 2.0f); + EXPECT_EQ(status, CUVS_SUCCESS) << (cuvsGetLastErrorText() ? cuvsGetLastErrorText() : ""); + if (status == CUVS_SUCCESS) { EXPECT_EQ(cuvsStreamSync(res), CUVS_SUCCESS); } + + RAFT_CUDA_TRY(cudaFree(x_data)); + RAFT_CUDA_TRY(cudaFree(y_data)); + RAFT_CUDA_TRY(cudaFree(distances_data)); + ASSERT_EQ(cuvsResourcesDestroy(res), CUVS_SUCCESS); +} diff --git a/rust/cuvs/examples/cagra.rs b/rust/cuvs/examples/cagra.rs index 2f0ee4e071..4203c7c639 100644 --- a/rust/cuvs/examples/cagra.rs +++ b/rust/cuvs/examples/cagra.rs @@ -22,7 +22,8 @@ fn cagra_example() -> Result<()> { // build the cagra index let build_params = IndexParams::new()?; - let index = Index::build(&res, &build_params, &dataset)?; + let dataset_device = ManagedTensor::from_ndarray(&dataset)?.to_device(&res)?; + let index = Index::build(&res, &build_params, dataset_device)?; println!("Indexed {}x{} datapoints into cagra index", n_datapoints, n_features); // use the first 4 points from the dataset as queries : will test that we get them back @@ -35,12 +36,12 @@ fn cagra_example() -> Result<()> { // CAGRA search API requires queries and outputs to be on device memory // copy query data over, and allocate new device memory for the distances/ neighbors // outputs - let queries = ManagedTensor::from(&queries).to_device(&res)?; + let queries = ManagedTensor::from_ndarray(&queries)?.to_device(&res)?; let mut neighbors_host = ndarray::Array::::zeros((n_queries, k)); - let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res)?; + let neighbors = ManagedTensor::from_ndarray(&neighbors_host)?.to_device(&res)?; let mut distances_host = ndarray::Array::::zeros((n_queries, k)); - let distances = ManagedTensor::from(&distances_host).to_device(&res)?; + let distances = ManagedTensor::from_ndarray(&distances_host)?.to_device(&res)?; let search_params = SearchParams::new()?; diff --git a/rust/cuvs/src/brute_force.rs b/rust/cuvs/src/brute_force.rs index 413e8b0fb1..20a77c3840 100644 --- a/rust/cuvs/src/brute_force.rs +++ b/rust/cuvs/src/brute_force.rs @@ -13,14 +13,14 @@ use crate::resources::Resources; /// Brute Force KNN Index #[derive(Debug)] -pub struct Index { +pub struct Index<'a> { inner: ffi::cuvsBruteForceIndex_t, // cuVS brute_force::index stores a non-owning view into the dataset. // Keep the Rust tensor alive for as long as the C++ index may read it. - _dataset: Option, + _dataset: Option>, } -impl Index { +impl<'a> Index<'a> { /// Builds a new Brute Force KNN Index from the dataset for efficient search. /// /// # Arguments @@ -29,13 +29,13 @@ impl Index { /// * `metric` - DistanceType to use for building the index /// * `metric_arg` - Optional value of `p` for Minkowski distances /// * `dataset` - A row-major matrix on either the host or device to index - pub fn build>( + pub fn build>>( res: &Resources, metric: DistanceType, metric_arg: Option, dataset: T, - ) -> Result { - let dataset: ManagedTensor = dataset.into(); + ) -> Result> { + let dataset: ManagedTensor<'a> = dataset.into(); let mut index = Index::new()?; unsafe { check_cuvs(ffi::cuvsBruteForceBuild( @@ -51,7 +51,7 @@ impl Index { } /// Creates a new empty index - pub fn new() -> Result { + pub fn new() -> Result> { unsafe { let mut index = std::mem::MaybeUninit::::uninit(); check_cuvs(ffi::cuvsBruteForceIndexCreate(index.as_mut_ptr()))?; @@ -70,9 +70,9 @@ impl Index { pub fn search( &self, res: &Resources, - queries: &ManagedTensor, - neighbors: &ManagedTensor, - distances: &ManagedTensor, + queries: &ManagedTensor<'_>, + neighbors: &ManagedTensor<'_>, + distances: &ManagedTensor<'_>, ) -> Result<()> { unsafe { let prefilter = ffi::cuvsFilter { addr: 0, type_: ffi::cuvsFilterType::NO_FILTER }; @@ -89,7 +89,7 @@ impl Index { } } -impl Drop for Index { +impl Drop for Index<'_> { fn drop(&mut self) { if let Err(e) = check_cuvs(unsafe { ffi::cuvsBruteForceIndexDestroy(self.inner) }) { write!(stderr(), "failed to call bruteForceIndexDestroy {:?}", e) @@ -114,7 +114,7 @@ mod tests { let dataset_host = ndarray::Array::::random((n_datapoints, n_features), Uniform::new(0., 1.0)); - let dataset = ManagedTensor::from(&dataset_host).to_device(&res).unwrap(); + let dataset = ManagedTensor::from_ndarray(&dataset_host).unwrap().to_device(&res).unwrap(); println!("dataset {:#?}", dataset_host); @@ -132,12 +132,14 @@ mod tests { let k = 4; println!("queries! {:#?}", queries); - let queries = ManagedTensor::from(&queries).to_device(&res).unwrap(); + let queries = ManagedTensor::from_ndarray(&queries).unwrap().to_device(&res).unwrap(); let mut neighbors_host = ndarray::Array::::zeros((n_queries, k)); - let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res).unwrap(); + let neighbors = + ManagedTensor::from_ndarray(&neighbors_host).unwrap().to_device(&res).unwrap(); let mut distances_host = ndarray::Array::::zeros((n_queries, k)); - let distances = ManagedTensor::from(&distances_host).to_device(&res).unwrap(); + let distances = + ManagedTensor::from_ndarray(&distances_host).unwrap().to_device(&res).unwrap(); index.search(&res, &queries, &neighbors, &distances).unwrap(); diff --git a/rust/cuvs/src/cagra/index.rs b/rust/cuvs/src/cagra/index.rs index d69a4d5033..cd3db706db 100644 --- a/rust/cuvs/src/cagra/index.rs +++ b/rust/cuvs/src/cagra/index.rs @@ -35,12 +35,12 @@ impl Index { /// * `res` - Resources to use /// * `params` - Parameters for building the index /// * `dataset` - A row-major matrix on either the host or device to index - pub fn build>( + pub fn build<'a, T: Into>>( res: &Resources, params: &IndexParams, dataset: T, ) -> Result { - let dataset: ManagedTensor = dataset.into(); + let dataset: ManagedTensor<'a> = dataset.into(); let index = Index::new()?; unsafe { check_cuvs(ffi::cuvsCagraBuild(res.0, params.0, dataset.as_ptr(), index.0))?; @@ -70,9 +70,9 @@ impl Index { &self, res: &Resources, params: &SearchParams, - queries: &ManagedTensor, - neighbors: &ManagedTensor, - distances: &ManagedTensor, + queries: &ManagedTensor<'_>, + neighbors: &ManagedTensor<'_>, + distances: &ManagedTensor<'_>, ) -> Result<()> { unsafe { let prefilter = ffi::cuvsFilter { addr: 0, type_: ffi::cuvsFilterType::NO_FILTER }; @@ -108,10 +108,10 @@ impl Index { &self, res: &Resources, params: &SearchParams, - queries: &ManagedTensor, - neighbors: &ManagedTensor, - distances: &ManagedTensor, - bitset: &ManagedTensor, + queries: &ManagedTensor<'_>, + neighbors: &ManagedTensor<'_>, + distances: &ManagedTensor<'_>, + bitset: &ManagedTensor<'_>, ) -> Result<()> { unsafe { let prefilter = ffi::cuvsFilter { @@ -212,7 +212,9 @@ mod tests { ) -> (ndarray::Array2, Index) { let dataset = ndarray::Array::::random((N_DATAPOINTS, N_FEATURES), Uniform::new(0., 1.0)); - let index = Index::build(res, build_params, &dataset).expect("failed to build cagra index"); + let dataset_device = ManagedTensor::from_ndarray(&dataset).unwrap().to_device(res).unwrap(); + let index = + Index::build(res, build_params, dataset_device).expect("failed to build cagra index"); (dataset, index) } @@ -227,13 +229,15 @@ mod tests { k: usize, ) { let queries = dataset.slice(s![0..n_queries, ..]); - let queries = ManagedTensor::from(&queries).to_device(res).unwrap(); + let queries = ManagedTensor::from_ndarray(&queries).unwrap().to_device(res).unwrap(); let mut neighbors_host = ndarray::Array::::zeros((n_queries, k)); - let neighbors = ManagedTensor::from(&neighbors_host).to_device(res).unwrap(); + let neighbors = + ManagedTensor::from_ndarray(&neighbors_host).unwrap().to_device(res).unwrap(); let mut distances_host = ndarray::Array::::zeros((n_queries, k)); - let distances = ManagedTensor::from(&distances_host).to_device(res).unwrap(); + let distances = + ManagedTensor::from_ndarray(&distances_host).unwrap().to_device(res).unwrap(); let search_params = SearchParams::new().unwrap(); index.search(res, &search_params, &queries, &neighbors, &distances).expect("search failed"); @@ -281,8 +285,10 @@ mod tests { let dataset = ndarray::Array::::random((n_datapoints, n_features), Uniform::new(0., 1.0)); - let index = - Index::build(&res, &build_params, &dataset).expect("failed to create cagra index"); + let dataset_device = + ManagedTensor::from_ndarray(&dataset).unwrap().to_device(&res).unwrap(); + let index = Index::build(&res, &build_params, dataset_device) + .expect("failed to create cagra index"); // Build a bitset that includes only even-indexed rows let n_words = (n_datapoints + 31) / 32; @@ -292,18 +298,20 @@ mod tests { bitset_host[i / 32] |= 1u32 << (i % 32); } } - let bitset = ManagedTensor::from(&bitset_host).to_device(&res).unwrap(); + let bitset = ManagedTensor::from_ndarray(&bitset_host).unwrap().to_device(&res).unwrap(); // Query with the first 4 even-indexed rows let n_queries = 4; - let queries = dataset.slice(s![0..n_queries * 2;2, ..]); // rows 0, 2, 4, 6 - let queries = ManagedTensor::from(&queries).to_device(&res).unwrap(); + let queries = dataset.slice(s![0..n_queries * 2;2, ..]).to_owned(); // rows 0, 2, 4, 6 + let queries = ManagedTensor::from_ndarray(&queries).unwrap().to_device(&res).unwrap(); let k = 10; let mut neighbors_host = ndarray::Array::::zeros((n_queries, k)); - let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res).unwrap(); - let mut distances_host = ndarray::Array::::zeros((n_queries, k)); - let distances = ManagedTensor::from(&distances_host).to_device(&res).unwrap(); + let neighbors = + ManagedTensor::from_ndarray(&neighbors_host).unwrap().to_device(&res).unwrap(); + let distances_host = ndarray::Array::::zeros((n_queries, k)); + let distances = + ManagedTensor::from_ndarray(&distances_host).unwrap().to_device(&res).unwrap(); let search_params = SearchParams::new().unwrap(); diff --git a/rust/cuvs/src/cagra/mod.rs b/rust/cuvs/src/cagra/mod.rs index 9043b17386..a1518e1077 100644 --- a/rust/cuvs/src/cagra/mod.rs +++ b/rust/cuvs/src/cagra/mod.rs @@ -27,7 +27,8 @@ //! //! // build the cagra index //! let build_params = IndexParams::new()?; -//! let index = Index::build(&res, &build_params, &dataset)?; +//! let dataset_device = ManagedTensor::from_ndarray(&dataset)?.to_device(&res)?; +//! let index = Index::build(&res, &build_params, dataset_device)?; //! println!( //! "Indexed {}x{} datapoints into cagra index", //! n_datapoints, n_features @@ -43,12 +44,12 @@ //! // CAGRA search API requires queries and outputs to be on device memory //! // copy query data over, and allocate new device memory for the distances/ neighbors //! // outputs -//! let queries = ManagedTensor::from(&queries).to_device(&res)?; +//! let queries = ManagedTensor::from_ndarray(&queries)?.to_device(&res)?; //! let mut neighbors_host = ndarray::Array::::zeros((n_queries, k)); -//! let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res)?; +//! let neighbors = ManagedTensor::from_ndarray(&neighbors_host)?.to_device(&res)?; //! //! let mut distances_host = ndarray::Array::::zeros((n_queries, k)); -//! let distances = ManagedTensor::from(&distances_host).to_device(&res)?; +//! let distances = ManagedTensor::from_ndarray(&distances_host)?.to_device(&res)?; //! //! let search_params = SearchParams::new()?; //! @@ -76,7 +77,7 @@ //! //! // Build an index (using some dataset) //! let build_params = IndexParams::new()?; -//! // let index = Index::build(&res, &build_params, &dataset)?; +//! // let index = Index::build(&res, &build_params, dataset_device)?; //! //! // Save the index to disk (including the dataset) //! // index.serialize(&res, "/path/to/index.bin", true)?; diff --git a/rust/cuvs/src/cluster/kmeans/mod.rs b/rust/cuvs/src/cluster/kmeans/mod.rs index 5015f49f45..9d67371597 100644 --- a/rust/cuvs/src/cluster/kmeans/mod.rs +++ b/rust/cuvs/src/cluster/kmeans/mod.rs @@ -23,10 +23,10 @@ //! let n_clusters = 8; //! let dataset = //! ndarray::Array::::random((n_datapoints, n_features), Uniform::new(0., 1.0)); -//! let dataset = ManagedTensor::from(&dataset).to_device(&res)?; +//! let dataset = ManagedTensor::from_ndarray(&dataset)?.to_device(&res)?; //! //! let centroids_host = ndarray::Array::::zeros((n_clusters, n_features)); -//! let mut centroids = ManagedTensor::from(¢roids_host).to_device(&res)?; +//! let mut centroids = ManagedTensor::from_ndarray(¢roids_host)?.to_device(&res)?; //! //! // find the centroids with the kmeans index //! let kmeans_params = kmeans::Params::new()?.set_n_clusters(n_clusters as i32); @@ -56,9 +56,9 @@ use crate::resources::Resources; pub fn fit( res: &Resources, params: &Params, - x: &ManagedTensor, - sample_weight: &Option, - centroids: &mut ManagedTensor, + x: &ManagedTensor<'_>, + sample_weight: &Option>, + centroids: &mut ManagedTensor<'_>, ) -> Result<(f64, i32)> { let mut inertia: f64 = 0.0; let mut niter: i32 = 0; @@ -95,10 +95,10 @@ pub fn fit( pub fn predict( res: &Resources, params: &Params, - x: &ManagedTensor, - sample_weight: &Option, - centroids: &ManagedTensor, - labels: &mut ManagedTensor, + x: &ManagedTensor<'_>, + sample_weight: &Option>, + centroids: &ManagedTensor<'_>, + labels: &mut ManagedTensor<'_>, normalize_weight: bool, ) -> Result { let mut inertia: f64 = 0.0; @@ -128,7 +128,11 @@ pub fn predict( /// * `res` - Resources to use /// * `x` - Input matrix in device memory - shape (m, k) /// * `centroids` - Centroids calculated by fit in device memory, shape (n_clusters, k) -pub fn cluster_cost(res: &Resources, x: &ManagedTensor, centroids: &ManagedTensor) -> Result { +pub fn cluster_cost( + res: &Resources, + x: &ManagedTensor<'_>, + centroids: &ManagedTensor<'_>, +) -> Result { let mut inertia: f64 = 0.0; unsafe { @@ -159,10 +163,11 @@ mod tests { let n_features = 16; let dataset = ndarray::Array::::random((n_datapoints, n_features), Uniform::new(0., 1.0)); - let dataset = ManagedTensor::from(&dataset).to_device(&res).unwrap(); + let dataset = ManagedTensor::from_ndarray(&dataset).unwrap().to_device(&res).unwrap(); let centroids_host = ndarray::Array::::zeros((n_clusters, n_features)); - let mut centroids = ManagedTensor::from(¢roids_host).to_device(&res).unwrap(); + let mut centroids = + ManagedTensor::from_ndarray(¢roids_host).unwrap().to_device(&res).unwrap(); let params = Params::new().unwrap().set_n_clusters(n_clusters as i32); @@ -176,7 +181,8 @@ mod tests { assert!(n_iter >= 1); let mut labels_host = ndarray::Array::::zeros((n_clusters,)); - let mut labels = ManagedTensor::from(&labels_host).to_device(&res).unwrap(); + let mut labels = + ManagedTensor::from_ndarray(&labels_host).unwrap().to_device(&res).unwrap(); // make sure the prediction for each centroid is the centroid itself predict(&res, ¶ms, ¢roids, &None, ¢roids, &mut labels, false).unwrap(); diff --git a/rust/cuvs/src/distance/mod.rs b/rust/cuvs/src/distance/mod.rs index 36a5850905..285e560afd 100644 --- a/rust/cuvs/src/distance/mod.rs +++ b/rust/cuvs/src/distance/mod.rs @@ -20,9 +20,9 @@ use crate::resources::Resources; /// * `metric_arg` - Optional value of `p` for Minkowski distances pub fn pairwise_distance( res: &Resources, - x: &ManagedTensor, - y: &ManagedTensor, - distances: &ManagedTensor, + x: &ManagedTensor<'_>, + y: &ManagedTensor<'_>, + distances: &ManagedTensor<'_>, metric: DistanceType, metric_arg: Option, ) -> Result<()> { @@ -53,10 +53,12 @@ mod tests { let n_features = 16; let dataset = ndarray::Array::::random((n_datapoints, n_features), Uniform::new(0., 1.0)); - let dataset_device = ManagedTensor::from(&dataset).to_device(&res).unwrap(); + let dataset_device = + ManagedTensor::from_ndarray(&dataset).unwrap().to_device(&res).unwrap(); let mut distances_host = ndarray::Array::::zeros((n_datapoints, n_datapoints)); - let distances = ManagedTensor::from(&distances_host).to_device(&res).unwrap(); + let distances = + ManagedTensor::from_ndarray(&distances_host).unwrap().to_device(&res).unwrap(); pairwise_distance( &res, diff --git a/rust/cuvs/src/dlpack.rs b/rust/cuvs/src/dlpack.rs index 1687f88d17..e05240b245 100644 --- a/rust/cuvs/src/dlpack.rs +++ b/rust/cuvs/src/dlpack.rs @@ -3,39 +3,90 @@ * SPDX-License-Identifier: Apache-2.0 */ -use std::convert::From; +use std::marker::PhantomData; -use crate::error::{Result, check_cuvs}; +use crate::error::{Error, Result, check_cuvs}; use crate::resources::Resources; /// ManagedTensor is a wrapper around a dlpack DLManagedTensor object. /// This lets you pass matrices in device or host memory into cuvs. #[derive(Debug)] -pub struct ManagedTensor(ffi::DLManagedTensor); +pub struct ManagedTensor<'a> { + tensor: ffi::DLManagedTensor, + shape: Box<[i64]>, + _borrow: PhantomData<&'a ()>, +} pub trait IntoDtype { fn ffi_dtype() -> ffi::DLDataType; } -impl ManagedTensor { - pub fn as_ptr(&self) -> *mut ffi::DLManagedTensor { - &self.0 as *const _ as *mut _ +impl<'a> ManagedTensor<'a> { + pub(crate) fn as_ptr(&self) -> *mut ffi::DLManagedTensor { + &self.tensor as *const _ as *mut _ + } + + fn new( + data: *mut std::ffi::c_void, + device_type: ffi::DLDeviceType, + dtype: ffi::DLDataType, + shape: Box<[i64]>, + deleter: Option, + ) -> Self { + let dl_tensor = ffi::DLTensor { + data, + device: ffi::DLDevice { device_type, device_id: 0 }, + ndim: shape.len() as i32, + dtype, + shape: std::ptr::null_mut(), + strides: std::ptr::null_mut(), + byte_offset: 0, + }; + + let mut ret = Self { + tensor: ffi::DLManagedTensor { dl_tensor, manager_ctx: std::ptr::null_mut(), deleter }, + shape, + _borrow: PhantomData, + }; + ret.tensor.dl_tensor.shape = ret.shape.as_mut_ptr(); + ret + } + + /// Create a non-owning view of a row-major ndarray. + pub fn from_ndarray(arr: &'a ndarray::ArrayBase) -> Result + where + T: IntoDtype, + S: ndarray::RawData, + D: ndarray::Dimension, + { + let shape = ndarray_shape(arr)?; + validate_standard_layout(arr)?; + Ok(Self::new( + arr.as_ptr() as *mut std::ffi::c_void, + ffi::DLDeviceType::kDLCPU, + T::ffi_dtype(), + shape, + None, + )) } /// Creates a new ManagedTensor on the current GPU device, and copies /// the data into it. - pub fn to_device(&self, res: &Resources) -> Result { + pub fn to_device(&self, res: &Resources) -> Result> { unsafe { - let bytes = dl_tensor_bytes(&self.0.dl_tensor); + let bytes = dl_tensor_bytes(&self.tensor.dl_tensor); let mut device_data: *mut std::ffi::c_void = std::ptr::null_mut(); // allocate storage, copy over check_cuvs(ffi::cuvsRMMAlloc(res.0, &mut device_data as *mut _, bytes))?; - let mut ret = ManagedTensor(self.0); - ret.0.dl_tensor.data = device_data; - ret.0.deleter = Some(rmm_free_tensor); - ret.0.dl_tensor.device.device_type = ffi::DLDeviceType::kDLCUDA; + let ret = ManagedTensor::new( + device_data, + ffi::DLDeviceType::kDLCUDA, + self.tensor.dl_tensor.dtype, + self.shape.clone(), + Some(rmm_free_tensor), + ); check_cuvs(ffi::cuvsMatrixCopy(res.0, self.as_ptr(), ret.as_ptr()))?; @@ -53,11 +104,23 @@ impl ManagedTensor { res: &Resources, arr: &mut ndarray::ArrayBase, ) -> Result<()> { + validate_host_output(&self.tensor.dl_tensor, self.shape.as_ref(), arr)?; + unsafe { - let mut dst = self.0; - dst.dl_tensor.data = arr.as_mut_ptr() as *mut std::ffi::c_void; - dst.dl_tensor.device.device_type = ffi::DLDeviceType::kDLCPU; - dst.deleter = None; + let mut dst_shape = ndarray_shape(arr)?; + let mut dst = ffi::DLManagedTensor { + dl_tensor: ffi::DLTensor { + data: arr.as_mut_ptr() as *mut std::ffi::c_void, + device: ffi::DLDevice { device_type: ffi::DLDeviceType::kDLCPU, device_id: 0 }, + ndim: dst_shape.len() as i32, + dtype: T::ffi_dtype(), + shape: dst_shape.as_mut_ptr(), + strides: std::ptr::null_mut(), + byte_offset: 0, + }, + manager_ctx: std::ptr::null_mut(), + deleter: None, + }; check_cuvs(ffi::cuvsMatrixCopy(res.0, self.as_ptr(), &mut dst))?; Ok(()) @@ -71,10 +134,81 @@ fn dl_tensor_bytes(tensor: &ffi::DLTensor) -> usize { for dim in 0..tensor.ndim { bytes *= unsafe { (*tensor.shape.add(dim as usize)) as usize }; } - bytes *= (tensor.dtype.bits / 8) as usize; + bytes *= ((tensor.dtype.bits as usize * tensor.dtype.lanes as usize).div_ceil(8)) as usize; bytes } +fn ndarray_shape(arr: &ndarray::ArrayBase) -> Result> +where + S: ndarray::RawData, + D: ndarray::Dimension, +{ + if arr.ndim() > i32::MAX as usize { + return Err(Error::InvalidArgument(format!( + "ndarray rank {} does not fit in i32", + arr.ndim() + ))); + } + + arr.shape() + .iter() + .map(|&dim| { + i64::try_from(dim).map_err(|_| { + Error::InvalidArgument(format!("ndarray dimension {dim} does not fit in i64")) + }) + }) + .collect::>>() + .map(Vec::into_boxed_slice) +} + +fn validate_standard_layout(arr: &ndarray::ArrayBase) -> Result<()> +where + S: ndarray::RawData, + D: ndarray::Dimension, +{ + if arr.is_standard_layout() { + Ok(()) + } else { + Err(Error::InvalidArgument("ndarray must be in standard row-major layout".to_string())) + } +} + +fn validate_dtype(actual: ffi::DLDataType, expected: ffi::DLDataType) -> Result<()> { + if actual.code == expected.code + && actual.bits == expected.bits + && actual.lanes == expected.lanes + { + Ok(()) + } else { + Err(Error::InvalidArgument(format!( + "dtype mismatch: tensor has code={}, bits={}, lanes={} but ndarray has code={}, bits={}, lanes={}", + actual.code, actual.bits, actual.lanes, expected.code, expected.bits, expected.lanes + ))) + } +} + +fn validate_host_output( + src: &ffi::DLTensor, + src_shape: &[i64], + arr: &ndarray::ArrayBase, +) -> Result<()> +where + T: IntoDtype, + S: ndarray::RawData, + D: ndarray::Dimension, +{ + validate_standard_layout(arr)?; + let dst_shape = ndarray_shape(arr)?; + if src_shape != dst_shape.as_ref() { + return Err(Error::InvalidArgument(format!( + "shape mismatch: tensor has shape {:?} but ndarray has shape {:?}", + src_shape, + dst_shape.as_ref() + ))); + } + validate_dtype(src.dtype, T::ffi_dtype()) +} + unsafe extern "C" fn rmm_free_tensor(self_: *mut ffi::DLManagedTensor) { unsafe { let bytes = dl_tensor_bytes(&(*self_).dl_tensor); @@ -83,39 +217,21 @@ unsafe extern "C" fn rmm_free_tensor(self_: *mut ffi::DLManagedTensor) { } } -/// Create a non-owning view of a Tensor from a ndarray -impl, D: ndarray::Dimension> - From<&ndarray::ArrayBase> for ManagedTensor +impl<'a, T: IntoDtype, S: ndarray::RawData, D: ndarray::Dimension> + TryFrom<&'a ndarray::ArrayBase> for ManagedTensor<'a> { - fn from(arr: &ndarray::ArrayBase) -> Self { - // There is a draft PR out right now for creating dlpack directly from ndarray - // right now, but until its merged we have to implement ourselves - //https://github.com/rust-ndarray/ndarray/pull/1306/files - unsafe { - let mut ret = std::mem::MaybeUninit::::uninit(); - let tensor = ret.as_mut_ptr(); - (*tensor).data = arr.as_ptr() as *mut std::os::raw::c_void; - (*tensor).device = - ffi::DLDevice { device_type: ffi::DLDeviceType::kDLCPU, device_id: 0 }; - (*tensor).byte_offset = 0; - (*tensor).strides = std::ptr::null_mut(); // TODO: error if not rowmajor - (*tensor).ndim = arr.ndim() as i32; - (*tensor).shape = arr.shape().as_ptr() as *mut _; - (*tensor).dtype = T::ffi_dtype(); - ManagedTensor(ffi::DLManagedTensor { - dl_tensor: ret.assume_init(), - manager_ctx: std::ptr::null_mut(), - deleter: None, - }) - } + type Error = Error; + + fn try_from(arr: &'a ndarray::ArrayBase) -> Result { + ManagedTensor::from_ndarray(arr) } } -impl Drop for ManagedTensor { +impl Drop for ManagedTensor<'_> { fn drop(&mut self) { unsafe { - if let Some(deleter) = self.0.deleter { - deleter(&mut self.0 as *mut _); + if let Some(deleter) = self.tensor.deleter { + deleter(&mut self.tensor as *mut _); } } } @@ -159,7 +275,8 @@ mod tests { fn test_from_ndarray() { let arr = ndarray::Array::::zeros((8, 4)); - let tensor = unsafe { (*(ManagedTensor::from(&arr).as_ptr())).dl_tensor }; + let tensor = ManagedTensor::from_ndarray(&arr).unwrap(); + let tensor = unsafe { (*tensor.as_ptr()).dl_tensor }; assert_eq!(tensor.ndim, 2); @@ -168,7 +285,47 @@ mod tests { assert_eq!(unsafe { *tensor.shape.add(1) }, 4); let arr = ndarray::Array::::zeros((8,)); - let tensor = unsafe { (*(ManagedTensor::from(&arr).as_ptr())).dl_tensor }; + let tensor = ManagedTensor::from_ndarray(&arr).unwrap(); + let tensor = unsafe { (*tensor.as_ptr()).dl_tensor }; assert_eq!(tensor.ndim, 1); } + + #[test] + fn test_from_ndarray_rejects_non_standard_layout() { + let arr = ndarray::Array::::zeros((8, 4)); + let view = arr.slice(ndarray::s![.., ..;2]); + + let err = ManagedTensor::from_ndarray(&view).unwrap_err(); + assert!(matches!(err, Error::InvalidArgument(_))); + } + + #[test] + fn test_to_host_validation_rejects_shape_mismatch() { + let src = ndarray::Array::::zeros((8, 4)); + let tensor = ManagedTensor::from_ndarray(&src).unwrap(); + let mut dst = ndarray::Array::::zeros((8, 3)); + + let err = validate_host_output::( + &tensor.tensor.dl_tensor, + tensor.shape.as_ref(), + &mut dst, + ) + .unwrap_err(); + assert!(matches!(err, Error::InvalidArgument(_))); + } + + #[test] + fn test_to_host_validation_rejects_dtype_mismatch() { + let src = ndarray::Array::::zeros((8, 4)); + let tensor = ManagedTensor::from_ndarray(&src).unwrap(); + let mut dst = ndarray::Array::::zeros((8, 4)); + + let err = validate_host_output::( + &tensor.tensor.dl_tensor, + tensor.shape.as_ref(), + &mut dst, + ) + .unwrap_err(); + assert!(matches!(err, Error::InvalidArgument(_))); + } } diff --git a/rust/cuvs/src/ivf_flat/index.rs b/rust/cuvs/src/ivf_flat/index.rs index a602d64c05..ba30dbfa54 100644 --- a/rust/cuvs/src/ivf_flat/index.rs +++ b/rust/cuvs/src/ivf_flat/index.rs @@ -22,12 +22,12 @@ impl Index { /// * `res` - Resources to use /// * `params` - Parameters for building the index /// * `dataset` - A row-major matrix on either the host or device to index - pub fn build>( + pub fn build<'a, T: Into>>( res: &Resources, params: &IndexParams, dataset: T, ) -> Result { - let dataset: ManagedTensor = dataset.into(); + let dataset: ManagedTensor<'a> = dataset.into(); let index = Index::new()?; unsafe { check_cuvs(ffi::cuvsIvfFlatBuild(res.0, params.0, dataset.as_ptr(), index.0))?; @@ -57,9 +57,9 @@ impl Index { &self, res: &Resources, params: &SearchParams, - queries: &ManagedTensor, - neighbors: &ManagedTensor, - distances: &ManagedTensor, + queries: &ManagedTensor<'_>, + neighbors: &ManagedTensor<'_>, + distances: &ManagedTensor<'_>, ) -> Result<()> { unsafe { let prefilter = ffi::cuvsFilter { addr: 0, type_: ffi::cuvsFilterType::NO_FILTER }; @@ -105,7 +105,8 @@ mod tests { let dataset = ndarray::Array::::random((n_datapoints, n_features), Uniform::new(0., 1.0)); - let dataset_device = ManagedTensor::from(&dataset).to_device(&res).unwrap(); + let dataset_device = + ManagedTensor::from_ndarray(&dataset).unwrap().to_device(&res).unwrap(); // build the ivf-flat index let index = Index::build(&res, &build_params, dataset_device) @@ -121,12 +122,14 @@ mod tests { // IvfFlat search API requires queries and outputs to be on device memory // copy query data over, and allocate new device memory for the distances/ neighbors // outputs - let queries = ManagedTensor::from(&queries).to_device(&res).unwrap(); + let queries = ManagedTensor::from_ndarray(&queries).unwrap().to_device(&res).unwrap(); let mut neighbors_host = ndarray::Array::::zeros((n_queries, k)); - let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res).unwrap(); + let neighbors = + ManagedTensor::from_ndarray(&neighbors_host).unwrap().to_device(&res).unwrap(); let mut distances_host = ndarray::Array::::zeros((n_queries, k)); - let distances = ManagedTensor::from(&distances_host).to_device(&res).unwrap(); + let distances = + ManagedTensor::from_ndarray(&distances_host).unwrap().to_device(&res).unwrap(); let search_params = SearchParams::new().unwrap(); @@ -157,7 +160,8 @@ mod tests { let dataset = ndarray::Array::::random((n_datapoints, n_features), Uniform::new(0., 1.0)); - let dataset_device = ManagedTensor::from(&dataset).to_device(&res).unwrap(); + let dataset_device = + ManagedTensor::from_ndarray(&dataset).unwrap().to_device(&res).unwrap(); // Build the index once let index = Index::build(&res, &build_params, dataset_device) @@ -170,13 +174,15 @@ mod tests { for search_iter in 0..3 { let n_queries = 4; let queries = dataset.slice(s![0..n_queries, ..]); - let queries = ManagedTensor::from(&queries).to_device(&res).unwrap(); + let queries = ManagedTensor::from_ndarray(&queries).unwrap().to_device(&res).unwrap(); let mut neighbors_host = ndarray::Array::::zeros((n_queries, k)); - let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res).unwrap(); + let neighbors = + ManagedTensor::from_ndarray(&neighbors_host).unwrap().to_device(&res).unwrap(); let mut distances_host = ndarray::Array::::zeros((n_queries, k)); - let distances = ManagedTensor::from(&distances_host).to_device(&res).unwrap(); + let distances = + ManagedTensor::from_ndarray(&distances_host).unwrap().to_device(&res).unwrap(); // This should work on every iteration because search() takes &self index diff --git a/rust/cuvs/src/ivf_flat/mod.rs b/rust/cuvs/src/ivf_flat/mod.rs index 7417116965..7f71bff20b 100644 --- a/rust/cuvs/src/ivf_flat/mod.rs +++ b/rust/cuvs/src/ivf_flat/mod.rs @@ -28,7 +28,8 @@ //! //! // build the ivf-flat index //! let build_params = IndexParams::new()?; -//! let index = Index::build(&res, &build_params, &dataset)?; +//! let dataset_device = ManagedTensor::from_ndarray(&dataset)?.to_device(&res)?; +//! let index = Index::build(&res, &build_params, dataset_device)?; //! println!( //! "Indexed {}x{} datapoints into ivf-flat index", //! n_datapoints, n_features @@ -44,12 +45,12 @@ //! // Ivf-Flat search API requires queries and outputs to be on device memory //! // copy query data over, and allocate new device memory for the distances/ neighbors //! // outputs -//! let queries = ManagedTensor::from(&queries).to_device(&res)?; +//! let queries = ManagedTensor::from_ndarray(&queries)?.to_device(&res)?; //! let mut neighbors_host = ndarray::Array::::zeros((n_queries, k)); -//! let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res)?; +//! let neighbors = ManagedTensor::from_ndarray(&neighbors_host)?.to_device(&res)?; //! //! let mut distances_host = ndarray::Array::::zeros((n_queries, k)); -//! let distances = ManagedTensor::from(&distances_host).to_device(&res)?; +//! let distances = ManagedTensor::from_ndarray(&distances_host)?.to_device(&res)?; //! //! let search_params = SearchParams::new()?; //! diff --git a/rust/cuvs/src/ivf_pq/index.rs b/rust/cuvs/src/ivf_pq/index.rs index 492fefa0f1..c4557dc528 100644 --- a/rust/cuvs/src/ivf_pq/index.rs +++ b/rust/cuvs/src/ivf_pq/index.rs @@ -22,12 +22,12 @@ impl Index { /// * `res` - Resources to use /// * `params` - Parameters for building the index /// * `dataset` - A row-major matrix on either the host or device to index - pub fn build>( + pub fn build<'a, T: Into>>( res: &Resources, params: &IndexParams, dataset: T, ) -> Result { - let dataset: ManagedTensor = dataset.into(); + let dataset: ManagedTensor<'a> = dataset.into(); let index = Index::new()?; unsafe { check_cuvs(ffi::cuvsIvfPqBuild(res.0, params.0, dataset.as_ptr(), index.0))?; @@ -57,9 +57,9 @@ impl Index { &self, res: &Resources, params: &SearchParams, - queries: &ManagedTensor, - neighbors: &ManagedTensor, - distances: &ManagedTensor, + queries: &ManagedTensor<'_>, + neighbors: &ManagedTensor<'_>, + distances: &ManagedTensor<'_>, ) -> Result<()> { unsafe { check_cuvs(ffi::cuvsIvfPqSearch( @@ -102,7 +102,8 @@ mod tests { let dataset = ndarray::Array::::random((n_datapoints, n_features), Uniform::new(0., 1.0)); - let dataset_device = ManagedTensor::from(&dataset).to_device(&res).unwrap(); + let dataset_device = + ManagedTensor::from_ndarray(&dataset).unwrap().to_device(&res).unwrap(); // build the ivf-pq index let index = Index::build(&res, &build_params, dataset_device) @@ -118,12 +119,14 @@ mod tests { // Ivf-Pq search API requires queries and outputs to be on device memory // copy query data over, and allocate new device memory for the distances/ neighbors // outputs - let queries = ManagedTensor::from(&queries).to_device(&res).unwrap(); + let queries = ManagedTensor::from_ndarray(&queries).unwrap().to_device(&res).unwrap(); let mut neighbors_host = ndarray::Array::::zeros((n_queries, k)); - let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res).unwrap(); + let neighbors = + ManagedTensor::from_ndarray(&neighbors_host).unwrap().to_device(&res).unwrap(); let mut distances_host = ndarray::Array::::zeros((n_queries, k)); - let distances = ManagedTensor::from(&distances_host).to_device(&res).unwrap(); + let distances = + ManagedTensor::from_ndarray(&distances_host).unwrap().to_device(&res).unwrap(); let search_params = SearchParams::new().unwrap(); @@ -154,7 +157,8 @@ mod tests { let dataset = ndarray::Array::::random((n_datapoints, n_features), Uniform::new(0., 1.0)); - let dataset_device = ManagedTensor::from(&dataset).to_device(&res).unwrap(); + let dataset_device = + ManagedTensor::from_ndarray(&dataset).unwrap().to_device(&res).unwrap(); // Build the index once let index = Index::build(&res, &build_params, dataset_device) @@ -167,13 +171,15 @@ mod tests { for search_iter in 0..3 { let n_queries = 4; let queries = dataset.slice(s![0..n_queries, ..]); - let queries = ManagedTensor::from(&queries).to_device(&res).unwrap(); + let queries = ManagedTensor::from_ndarray(&queries).unwrap().to_device(&res).unwrap(); let mut neighbors_host = ndarray::Array::::zeros((n_queries, k)); - let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res).unwrap(); + let neighbors = + ManagedTensor::from_ndarray(&neighbors_host).unwrap().to_device(&res).unwrap(); let mut distances_host = ndarray::Array::::zeros((n_queries, k)); - let distances = ManagedTensor::from(&distances_host).to_device(&res).unwrap(); + let distances = + ManagedTensor::from_ndarray(&distances_host).unwrap().to_device(&res).unwrap(); // This should work on every iteration because search() takes &self index diff --git a/rust/cuvs/src/ivf_pq/mod.rs b/rust/cuvs/src/ivf_pq/mod.rs index c4676cd1aa..7bafeb0afb 100644 --- a/rust/cuvs/src/ivf_pq/mod.rs +++ b/rust/cuvs/src/ivf_pq/mod.rs @@ -25,7 +25,8 @@ //! //! // build the ivf-pq index //! let build_params = IndexParams::new()?; -//! let index = Index::build(&res, &build_params, &dataset)?; +//! let dataset_device = ManagedTensor::from_ndarray(&dataset)?.to_device(&res)?; +//! let index = Index::build(&res, &build_params, dataset_device)?; //! println!( //! "Indexed {}x{} datapoints into ivf-pq index", //! n_datapoints, n_features @@ -41,12 +42,12 @@ //! // Ivf-Pq search API requires queries and outputs to be on device memory //! // copy query data over, and allocate new device memory for the distances/ neighbors //! // outputs -//! let queries = ManagedTensor::from(&queries).to_device(&res)?; +//! let queries = ManagedTensor::from_ndarray(&queries)?.to_device(&res)?; //! let mut neighbors_host = ndarray::Array::::zeros((n_queries, k)); -//! let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res)?; +//! let neighbors = ManagedTensor::from_ndarray(&neighbors_host)?.to_device(&res)?; //! //! let mut distances_host = ndarray::Array::::zeros((n_queries, k)); -//! let distances = ManagedTensor::from(&distances_host).to_device(&res)?; +//! let distances = ManagedTensor::from_ndarray(&distances_host)?.to_device(&res)?; //! //! let search_params = SearchParams::new()?; //! diff --git a/rust/cuvs/src/vamana/index.rs b/rust/cuvs/src/vamana/index.rs index 485f8ac008..d4e28aa4b7 100644 --- a/rust/cuvs/src/vamana/index.rs +++ b/rust/cuvs/src/vamana/index.rs @@ -30,12 +30,12 @@ impl Index { /// * `res` - Resources to use /// * `params` - Parameters for building the index /// * `dataset` - A row-major matrix on either the host or device to index - pub fn build>( + pub fn build<'a, T: Into>>( res: &Resources, params: &IndexParams, dataset: T, ) -> Result { - let dataset: ManagedTensor = dataset.into(); + let dataset: ManagedTensor<'a> = dataset.into(); let index = Index::new()?; unsafe { check_cuvs(ffi::cuvsVamanaBuild(res.0, params.0, dataset.as_ptr(), index.0))?; @@ -104,7 +104,8 @@ mod tests { let dataset = ndarray::Array::::random((n_datapoints, n_features), Uniform::new(0., 1.0)); - let dataset_device = ManagedTensor::from(&dataset).to_device(&res).unwrap(); + let dataset_device = + ManagedTensor::from_ndarray(&dataset).unwrap().to_device(&res).unwrap(); // build the vamana index let _index = Index::build(&res, &build_params, dataset_device)