diff --git a/benchmarking/utils.py b/benchmarking/utils.py index 76688a4..8037263 100644 --- a/benchmarking/utils.py +++ b/benchmarking/utils.py @@ -217,6 +217,7 @@ def benchmark_jnp_dot(matrix_dims: Tuple[int, int, int], """Benchmarks `jnp.dot`.""" baseline_op = _get_baseline_op(matrix_dims, dtype, average, seed) timings = _benchmark_op(baseline_op, num_trials) + print(timings) return np.array(timings) / average @@ -230,4 +231,5 @@ def benchmark_factorized_algorithm(factors: np.ndarray, factorization_algorithm_op = _get_factorization_op( factors, matrix_dims, dtype, average, seed) timings = _benchmark_op(factorization_algorithm_op, num_trials) + print(timings) return np.array(timings) / average