Skip to content

Commit c73df13

Browse files
committed
Tests added, to check correctness, made with pytest
1 parent c3cd380 commit c73df13

14 files changed

Lines changed: 143 additions & 221 deletions

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@ benchmark/dataset
77
*.gl
88
*.zip
99
*.npz
10-
benchmark/*.csv
10+
benchmark/*.csv
11+
12+
*_pycache__/

dummy.csv

Lines changed: 0 additions & 5 deletions
This file was deleted.

generateGraph.cpp

Lines changed: 0 additions & 18 deletions
This file was deleted.

main.cpp

Lines changed: 0 additions & 48 deletions
This file was deleted.

src/AliasTable.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,10 @@ class LRUTable
108108
public:
109109
LRUTable(int64_t maxCapacity, std::function<std::span<float>(int64_t)> weightFunc) : MAXCAPACITY(maxCapacity), getWeights(weightFunc) {}
110110

111-
const AliasTable& get_alias_table(int64_t nodeId);
111+
const AliasTable& get_alias_table(int64_t nodeId,int64_t nodeDegree);
112112
};
113113

114-
inline const AliasTable& LRUTable::get_alias_table(int64_t nodeID){
114+
inline const AliasTable& LRUTable::get_alias_table(int64_t nodeID, int64_t nodeDegree){
115115
auto it = isNodePresent.find(nodeID);
116116
if(it != isNodePresent.end()){ // cache HIT
117117

@@ -125,6 +125,10 @@ inline const AliasTable& LRUTable::get_alias_table(int64_t nodeID){
125125
// Weights
126126
std::span<float> span_w = getWeights(nodeID);
127127
std::vector<float> weights(span_w.begin(),span_w.end());
128+
129+
if(weights.empty()){
130+
weights = std::vector<float>(nodeDegree,1.0);
131+
}
128132

129133
AliasTable newTable(weights);
130134

src/CSR.hpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef CSR_H
22
#define CSR_H
33
#include "MemoryMap.hpp"
4+
#include "csrFilegen.hpp"
45
#include <span>
56
#include <cstddef>
67

@@ -41,7 +42,7 @@ class CSR{
4142
int64_t* get_colPtr(){
4243
return colPtr;
4344
}
44-
float* get_weights(){
45+
float* get_weightsPtr(){
4546
return weightPtr;
4647
}
4748
int64_t get_num_nodes(){
@@ -75,8 +76,8 @@ inline CSR::CSR(const char* graphPath){
7576

7677
if(has_weights){
7778
char* base = static_cast<char*>(this->graphMap->get_data()); // 1byte shifts
78-
79-
this->weightPtr = reinterpret_cast<float*>(base + header.offset_col + this->sizeofcolPtr);
79+
uint64_t offset_weights = align64(header.offset_col + this->sizeofcolPtr);
80+
this->weightPtr = reinterpret_cast<float*>(base + offset_weights);
8081
}else this->weightPtr = nullptr;
8182
}
8283

@@ -89,23 +90,30 @@ inline CSR::~CSR(){
8990
}
9091

9192
inline int64_t CSR::get_degree(int64_t nodeId){
92-
// return degree of nodeId, how many conections it have
93+
// return the degree of nodeId, i.e. how many connections it has
94+
if(nodeId >= num_nodes) return 0;
9395
return this->nnzRow[nodeId+1] - this->nnzRow[nodeId];
9496
}
9597
inline std::span<int64_t> CSR::get_edges(int64_t nodeId){
9698
// return the edges of nodeId
99+
if(nodeId>= num_nodes){
100+
throw std::runtime_error("NodeId is greater than number of nodes.");
101+
}
97102
int64_t* p =&this->colPtr[this->nnzRow[nodeId]];
98103
int64_t d = this->get_degree(nodeId);
99104
return std::span<int64_t>(p,d);
100105
}
101106

102107
inline std::span<float> CSR::get_weights(int64_t nodeId){
103108
// return the weights of nodeId, only valid if has_weights is true
109+
if(nodeId>= num_nodes){
110+
throw std::runtime_error("NodeId is greater than number of nodes.");
111+
}
104112
if(!has_weights){
105113
throw std::runtime_error("Graph does not have weights");
106114
}
107115
float* p =&this->weightPtr[this->nnzRow[nodeId]];
108-
int64_t d = this->get_degree(nodeId);
116+
int64_t d = this->get_degree(nodeId);
109117
return std::span<float>(p,d);
110118
}
111119

src/Graphzero.hpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ inline void Graphzero::ReservoirSampling(int64_t nodeId, int k, int64_t* result)
8686
for(int i = 0; i < deg; i++){
8787
result[i] = neighbours[i];
8888
}
89+
return;
8990
}
9091

9192
// selection k neighbours, first k elements
@@ -143,10 +144,7 @@ inline void Graphzero::randomWalk(int64_t start_node, int64_t length, float p, f
143144
auto weightFunc = [this](int64_t nodeID){
144145
if(this->has_weights) return this->storage->get_weights(nodeID);
145146
else{
146-
int64_t nodeDegree = this->storage->get_degree(nodeID);
147-
std::vector<float> weights(nodeDegree);
148-
for (float& w : weights) w = RNG.rand();
149-
return std::span<float>(weights);
147+
return std::span<float>(); // empty span: LRUTable detects it and falls back to uniform weights
150148
}
151149
};
152150

@@ -156,12 +154,15 @@ inline void Graphzero::randomWalk(int64_t start_node, int64_t length, float p, f
156154

157155
walk[0] = start_node;
158156

159-
for (int64_t i = 1; i < length; i++)
157+
for (int64_t i = 1; i <= length; i++)
160158
{
161159
int64_t degree = storage->get_degree(curr);
162-
if (degree == 0) break; // Dead end
160+
if (degree == 0){ // Dead end
161+
walk[i] = curr;
162+
continue;
163+
};
163164

164-
auto table = lruCache.get_alias_table(curr);
165+
auto table = lruCache.get_alias_table(curr,this->storage->get_degree(curr));
165166
if(i == 1){
166167
next = storage->get_edges(curr)[table.sample()];
167168
}else {
@@ -175,7 +176,7 @@ inline void Graphzero::randomWalk(int64_t start_node, int64_t length, float p, f
175176

176177
//keep p = 1.0f and q = 1.0f for default values.
177178
inline std::vector<int64_t>* Graphzero::batchRandomWalk(const std::vector<int64_t>& startNodes, int64_t walkLength, float p, float q){
178-
std::vector<int64_t>* results = new std::vector<int64_t>(walkLength*startNodes.size());
179+
std::vector<int64_t>* results = new std::vector<int64_t>((walkLength+1)*startNodes.size());
179180

180181
// set only for random walks
181182
storage->set_access_pattern(true);
@@ -184,7 +185,7 @@ inline std::vector<int64_t>* Graphzero::batchRandomWalk(const std::vector<int64_
184185
for(signed long long i = 0; i < startNodes.size(); i++){
185186

186187
// thread safe
187-
int64_t offset = i*walkLength;
188+
int64_t offset = i*(walkLength+1);
188189

189190
randomWalk(startNodes[i],walkLength,p,q, results->data() + offset);
190191
}
@@ -195,15 +196,15 @@ inline std::vector<int64_t>* Graphzero::batchRandomWalk(const std::vector<int64_
195196
}
196197

197198
inline std::vector<int64_t>* Graphzero::batchRandomUniformWalk(const std::vector<int64_t>& startNodes, int64_t walkLength){
198-
std::vector<int64_t>* results = new std::vector<int64_t>(walkLength*startNodes.size());
199+
std::vector<int64_t>* results = new std::vector<int64_t>((walkLength + 1)*startNodes.size());
199200

200201
// set only for random walks
201202
storage->set_access_pattern(true);
202203

203204
#pragma omp parallel for
204205
for(signed long long i = 0; i < startNodes.size(); i++){
205206
// walking here
206-
int64_t offset = (int64_t)(i*walkLength);
207+
int64_t offset = (int64_t)(i*(walkLength+1));
207208
int64_t curr = startNodes[i], next;
208209
(*results)[offset] = curr;
209210
for(int64_t j = 1; j < walkLength; ++j){
@@ -237,7 +238,7 @@ inline std::vector<int64_t>* Graphzero::batchRandomFanout(const std::vector<int6
237238

238239
// thread safe write into results
239240
int64_t offset = i * K;
240-
ReservoirSampling(startNodes[i], (int)K, results->data() + offset);
241+
this->ReservoirSampling(startNodes[i], (int)K, results->data() + offset);
241242
}
242243

243244
// reset

src/bindings.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ R"doc(Get the degree of a node.
4747
// the Graph object alive as the owner.
4848
return nb::ndarray<nb::numpy, int64_t, nb::shape<1>>(
4949
const_cast<int64_t*>(edges.data()), // pointer to data
50-
{ edges.size() }, // shape
51-
nb::cast(self) // owner: keep Graph instance alive
50+
{ edges.size() } // shape
5251
);
5352
},
5453
nb::keep_alive<0,1>(),
@@ -68,8 +67,7 @@ R"doc(Returns the neighbours of a node.
6867
// the Graph object alive as the owner.
6968
return nb::ndarray<nb::numpy, float, nb::shape<1>>(
7069
const_cast<float*>(weights.data()), // pointer to data
71-
{ weights.size() }, // shape
72-
nb::cast(self) // owner: keep Graph instance alive
70+
{ weights.size() } // shape
7371
);
7472
},
7573
nb::keep_alive<0,1>(),
@@ -93,7 +91,7 @@ R"doc(Returns the edge weight of neighbours of a node.
9391

9492
return nb::ndarray<nb::numpy, int64_t, nb::shape<2>>(
9593
walkData->data(),
96-
{startNodes.size(),static_cast<size_t>(walkLength) },
94+
{startNodes.size(),static_cast<size_t>(walkLength + 1) },
9795
owner
9896
);
9997
},
@@ -124,7 +122,7 @@ R"doc(Performs 2nd-order random walks (Node2Vec style).
124122

125123
return nb::ndarray<nb::numpy, int64_t, nb::shape<2>>(
126124
walkData->data(),
127-
{startNodes.size(),static_cast<size_t>(walkLength)},
125+
{startNodes.size(),static_cast<size_t>(walkLength + 1)},
128126
owner
129127
);
130128
},
@@ -170,7 +168,7 @@ R"doc(Performs uniform random fanout sampling.
170168
)
171169

172170
.def("sample_neighbours", [](Graphzero &self, int64_t startNode, int64_t K) {
173-
std::vector<int64_t>* walkData = new std::vector<int64_t>(K);
171+
std::vector<int64_t>* walkData = new std::vector<int64_t>((std::min)(K,self.get_storage()->get_degree(startNode)));
174172
self.ReservoirSampling(startNode, K, walkData->data());
175173

176174
nb::capsule owner(walkData, [](void* p) noexcept {

tests/dataloader_test.py

Lines changed: 0 additions & 65 deletions
This file was deleted.

tests/dummy.csv

Lines changed: 0 additions & 5 deletions
This file was deleted.

0 commit comments

Comments
 (0)