From aa2491b0b6ceff1fdb9c602da63829de8d73260a Mon Sep 17 00:00:00 2001 From: Imke van Ooijen Date: Thu, 20 Nov 2025 15:34:03 +0100 Subject: [PATCH 01/14] Added Triton support from https://github.com/super-nexus/kernel_tuner/tree/triton_support --- examples/triton/conv2d_tuning.py | 306 ++++++++++++++++++ examples/triton/vec_add.py | 46 +++ kernel_tuner/backends/triton.py | 174 ++++++++++ kernel_tuner/core.py | 247 ++++---------- kernel_tuner/interface.py | 7 +- kernel_tuner/kernel_sources/__init__.py | 0 kernel_tuner/kernel_sources/kernel_source.py | 77 +++++ .../kernel_sources/kernel_source_fn.py | 160 +++++++++ .../kernel_sources/kernel_source_str.py | 174 ++++++++++ kernel_tuner/kernel_sources/model/__init__.py | 0 .../model/prepared_kernel_source_data.py | 10 + kernel_tuner/kernelbuilder.py | 5 +- kernel_tuner/language.py | 10 + kernel_tuner/observers/triton.py | 32 ++ kernel_tuner/util.py | 7 +- 15 files changed, 1072 insertions(+), 183 deletions(-) create mode 100644 examples/triton/conv2d_tuning.py create mode 100644 examples/triton/vec_add.py create mode 100644 kernel_tuner/backends/triton.py create mode 100644 kernel_tuner/kernel_sources/__init__.py create mode 100644 kernel_tuner/kernel_sources/kernel_source.py create mode 100644 kernel_tuner/kernel_sources/kernel_source_fn.py create mode 100644 kernel_tuner/kernel_sources/kernel_source_str.py create mode 100644 kernel_tuner/kernel_sources/model/__init__.py create mode 100644 kernel_tuner/kernel_sources/model/prepared_kernel_source_data.py create mode 100644 kernel_tuner/language.py create mode 100644 kernel_tuner/observers/triton.py diff --git a/examples/triton/conv2d_tuning.py b/examples/triton/conv2d_tuning.py new file mode 100644 index 000000000..5389b4a28 --- /dev/null +++ b/examples/triton/conv2d_tuning.py @@ -0,0 +1,306 @@ +import torch +import triton.language as tl +import numpy as np +from kernel_tuner.interface import tune_kernel +import os +import json +from datetime import datetime + +# Check for required environment variable +cache_dir = os.getenv('KERNEL_TUNER_CACHE_DIR') +cache_file_name = os.getenv('KERNEL_TUNER_CACHE_FILE', 'conv2d_tuning_results.json') + +if cache_dir is None: + raise ValueError("Environment variable KERNEL_TUNER_CACHE_DIR must be set") + +cache_file = os.path.join(cache_dir, cache_file_name) + + +def conv2d_output_size( + in_size: int, + kernel_size: int, + stride: int, + padding: int, + dilation: int, +) -> int: + """ + Determines the output size of a 2D convolution operation. + + Args: + in_size: Input size. + kernel_size: Kernel size. + stride: Stride. + padding: Padding. + + Returns: + Output size of 2D convolution. + """ + return (in_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1 + + +def conv2d_forward_kernel( + input_pointer, + weight_pointer, + output_pointer, + in_n, + input_height, + input_width, + out_c, + out_height, + out_width, + input_n_stride, + input_c_stride, + input_height_stride, + input_width_stride, + weight_n_stride, + weight_c_stride, + weight_height_stride, + weight_width_stride, + output_n_stride, + output_c_stride, + output_height_stride, + output_width_stride, + weight_c: tl.constexpr, + weight_height: tl.constexpr, + weight_width: tl.constexpr, + stride_height: tl.constexpr, + stride_width: tl.constexpr, + padding_height: tl.constexpr, + padding_width: tl.constexpr, + dilation_height: tl.constexpr, + dilation_width: tl.constexpr, + groups: tl.constexpr, + BLOCK_NI_HO_WO: tl.constexpr, + BLOCK_CI: tl.constexpr, + BLOCK_CO: tl.constexpr, +): + pid_ni_ho_wo = tl.program_id(0) + pid_co = tl.program_id(1) + pid_group = tl.program_id(2) + + # caculate in_n out_height out_weight value in kernel + ni_ho_wo_offset = pid_ni_ho_wo * BLOCK_NI_HO_WO + tl.arange(0, BLOCK_NI_HO_WO) + ni_ho_offset = ni_ho_wo_offset // out_width + in_n_point_value = ni_ho_offset // out_height + output_height_point_value = ni_ho_offset % out_height + output_width_point_value = ni_ho_wo_offset % out_width + + # Load the input and weight pointers. input and weight are of shape + # [in_n, groups, in_c, input_height, input_width] and [groups, out_c, in_c, weight_height, weight_width] + out_per_group_c = out_c // groups + output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO) + input_pointer += ( + input_n_stride * in_n_point_value + input_c_stride * pid_group * weight_c + )[:, None] + weight_pointer += ( + weight_n_stride * output_c_offset + + weight_n_stride * pid_group * out_per_group_c + )[None, :] + + accum = tl.zeros((BLOCK_NI_HO_WO, BLOCK_CO), dtype=tl.float32) + BLOCK_CI_COUNT = (weight_c + BLOCK_CI - 1) // BLOCK_CI + for hwc in range(weight_height * weight_width * BLOCK_CI_COUNT): + c = (hwc % BLOCK_CI_COUNT) * BLOCK_CI + hw = hwc // BLOCK_CI_COUNT + h = hw // weight_width + w = hw % weight_width + + input_c_offset = c + tl.arange(0, BLOCK_CI) + input_height_offset = ( + h * dilation_height + - padding_height + + stride_height * output_height_point_value + ) + input_width_offset = ( + w * dilation_width - padding_width + stride_width * output_width_point_value + ) + + curr_input_pointer = ( + input_pointer + + (input_c_stride * input_c_offset)[None, :] + + (input_height_stride * input_height_offset)[:, None] + + (input_width_stride * input_width_offset)[:, None] + ) + curr_weight_pointer = ( + weight_pointer + + (weight_c_stride * input_c_offset)[:, None] + + (weight_height_stride * h) + + (weight_width_stride * w) + ) + + input_mask = ( + (in_n_point_value < in_n)[:, None] + & (input_c_offset < weight_c)[None, :] + & (0 <= input_height_offset)[:, None] + & (input_height_offset < input_height)[:, None] + & (0 <= input_width_offset)[:, None] + & (input_width_offset < input_width)[:, None] + ) + weight_mask = (input_c_offset < weight_c)[:, None] & ( + output_c_offset < out_per_group_c + )[None, :] + + input_block = tl.load(curr_input_pointer, mask=input_mask) + weight_block = tl.load(curr_weight_pointer, mask=weight_mask) + + accum += tl.dot(input_block, weight_block, allow_tf32=False) + + output_pointer += ( + (output_n_stride * in_n_point_value)[:, None] + + (output_c_stride * (pid_group * out_per_group_c + output_c_offset))[None, :] + + (output_height_stride * output_height_point_value)[:, None] + + (output_width_stride * output_width_point_value)[:, None] + ) + output_mask = ( + (in_n_point_value < in_n)[:, None] + & (output_c_offset < out_per_group_c)[None, :] + & (output_height_point_value < out_height)[:, None] + & (output_width_point_value < out_width)[:, None] + ) + + tl.store(output_pointer, accum, mask=output_mask) + + +def tune_conv2d(batch_size=1, in_channels=64, height=32, width=32, + out_channels=128, kernel_size=3, stride=1, padding=1, + groups=1): + """ + Tune the conv2d kernel with different configurations. + """ + # Create sample inputs + input = torch.randn(batch_size, in_channels, height, width, + device='cuda', dtype=torch.float32) + weight = torch.randn(out_channels, in_channels//groups, kernel_size, kernel_size, + device='cuda', dtype=torch.float32) + + # Calculate output dimensions + out_height = conv2d_output_size(height, kernel_size, stride, padding, 1) + out_width = conv2d_output_size(width, kernel_size, stride, padding, 1) + output = torch.empty((batch_size, out_channels, out_height, out_width), + device='cuda', dtype=torch.float32) + + # Prepare all arguments for the kernel + arguments = [ + input, weight, output, + np.int32(batch_size), + np.int32(height), + np.int32(width), + np.int32(out_channels), + np.int32(out_height), + np.int32(out_width), + np.int32(input.stride(0)), + np.int32(input.stride(1)), + np.int32(input.stride(2)), + np.int32(input.stride(3)), + np.int32(weight.stride(0)), + np.int32(weight.stride(1)), + np.int32(weight.stride(2)), + np.int32(weight.stride(3)), + np.int32(output.stride(0)), + np.int32(output.stride(1)), + np.int32(output.stride(2)), + np.int32(output.stride(3)), + np.int32(in_channels//groups), # weight_c + np.int32(kernel_size), # weight_height + np.int32(kernel_size), # weight_width + np.int32(stride), # stride_height + np.int32(stride), # stride_width + np.int32(padding), # padding_height + np.int32(padding), # padding_width + np.int32(1), # dilation_height + np.int32(1), # dilation_width + np.int32(groups), # groups + ] + + # Define tuning parameters - only powers of 2 + tune_params = { + 'BLOCK_NI_HO_WO': [2 ** i for i in range(4, 10)], + 'BLOCK_CI': [2 ** i for i in range(4, 10)], + 'BLOCK_CO': [2 ** i for i in range(4, 10)], + 'num_stages': [1, 2, 3, 4], + 'num_warps': [1, 2, 4, 8], + } + + print(tune_params) + + # Define constraints + constraints = [ + "BLOCK_CI <= %d" % (in_channels//groups), + "BLOCK_CO <= %d" % out_channels, + ] + + # Problem size for the grid + problem_size = ( + batch_size * out_height * out_width, # Grid dimension 0 + out_channels, # Grid dimension 1 + groups, # Grid dimension 2 + ) + + # Grid divisor expressions + grid_div_x = ["BLOCK_NI_HO_WO"] + grid_div_y = ["BLOCK_CO"] + grid_div_z = ["1"] + + results, env = tune_kernel( + kernel_name='conv2d_forward_kernel', + kernel_source=conv2d_forward_kernel, + problem_size=problem_size, + arguments=arguments, + tune_params=tune_params, + restrictions=constraints, + lang='TRITON', + grid_div_x=grid_div_x, + grid_div_y=grid_div_y, + grid_div_z=grid_div_z, + block_size_names=['BLOCK_NI_HO_WO', 'BLOCK_CI', 'BLOCK_CO'], + strategy='genetic_algorithm', + strategy_options={ + 'maxiter': 1000, + 'popsize': 100, + }, + cache=cache_file, + ) + + return results + + +if __name__ == '__main__': + # Run tuning with moderately large input dimensions + results = tune_conv2d( + batch_size=16, + in_channels=128, + height=112, + width=112, + out_channels=256, + kernel_size=3, + stride=1, + padding=1, + groups=1 + ) + + + # Filter out failed compilations and find best config + valid_results = [result for result in results if isinstance(result['time'], (int, float))] + if valid_results: + best_config = min(valid_results, key=lambda x: x['time']) + print("\nBest configuration:") + print(json.dumps(best_config, indent=2)) + else: + print("\nNo valid configurations found - all compilations failed") + + # Create results dictionary with GPU info + all_results = { + "gpu_info": { + "gpu_name": torch.cuda.get_device_name() + }, + "results": valid_results + } + + # Add timestamp to filename + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_file = f'conv2d_results_{timestamp}.json' + + # Save results + import json + with open(output_file, 'w') as f: + json.dump(all_results, f, indent=2) \ No newline at end of file diff --git a/examples/triton/vec_add.py b/examples/triton/vec_add.py new file mode 100644 index 000000000..555cfba46 --- /dev/null +++ b/examples/triton/vec_add.py @@ -0,0 +1,46 @@ +import numpy +import triton.language as tl +import torch +from kernel_tuner import tune_kernel, run_kernel +from kernel_tuner.file_utils import store_output_file, store_metadata_file + + +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # note: `constexpr` so it can be used as a shape value. + ): + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + +size = 10000000 + +a = torch.randn(size, dtype=torch.float32) +b = torch.randn(size, dtype=torch.float32) +c = torch.zeros_like(b) +n = torch.tensor(size, dtype=torch.int32) + +args = [a, b, c, n] + +tune_params = dict() +tune_params["block_size_x"] = [2**i for i in range(10)] + +results, env = tune_kernel( + kernel_name="add_kernel", + kernel_source=add_kernel, + problem_size=size, + arguments=args, + tune_params=tune_params, + lang="triton" +) + +print("Hello") \ No newline at end of file diff --git a/kernel_tuner/backends/triton.py b/kernel_tuner/backends/triton.py new file mode 100644 index 000000000..cf6fda504 --- /dev/null +++ b/kernel_tuner/backends/triton.py @@ -0,0 +1,174 @@ +import logging +import numpy as np + +from kernel_tuner.backends.backend import GPUBackend +from kernel_tuner.observers.triton import TritonRuntimeObserver + +try: + import torch +except ImportError: + logging.error("Torch not available") + +try: + import triton + import triton.language as tl +except ImportError: + triton = None + tl = None + logging.error("Unable to load triton") + + +class TritonFunctions(GPUBackend): + + def __init__(self, device=0, iterations=7, compiler_options=None, observers=None): + if not triton or not torch: + logging.error("Triton or torch not available") + raise ImportError("Triton or torch not available") + + self.device_id = torch.cuda.current_device() + + self.device_properties = torch.cuda.get_device_properties(self.device_id) + self.name = torch.cuda.get_device_name(self.device_id) + self.max_threads = self.device_properties.max_threads_per_multi_processor + + env = dict() + env["device_name"] = self.name + env["max_threads"] = self.max_threads + env["iterations"] = iterations + env["compiler_options"] = compiler_options + self.env = env + + self.stream = torch.cuda.default_stream() + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + + # setup observers + self.observers = observers or [] + self.observers.append(TritonRuntimeObserver(self)) + for obs in self.observers: + obs.register_device(self) + + self.units = {"time": "ms", "power": "s,mW", "energy": "J"} + + super().__init__(device=device, iterations=iterations, compiler_options=compiler_options, observers=observers) + + def ready_argument_list(self, arguments): + # Allocate memory here + torch_args = [] + + for arg in arguments: + if isinstance(arg, torch.Tensor) and arg.dim() > 0: + torch_args.append(arg.cuda()) + elif isinstance(arg, torch.Tensor) and arg.dim() == 0: + scalar_value = arg.item() + torch_args.append(scalar_value) + elif isinstance(arg, np.ndarray): + torch_arg = torch.from_numpy(arg) + torch_arg_gpu = torch_arg.cuda() + torch_args.append(torch_arg_gpu) + elif isinstance(arg, np.generic): + scalar_value = arg.item() + torch_args.append(scalar_value) + else: + logging.warning("Unknown instance in triton functions") + + return torch_args + + def compile(self, kernel_instance, gpu_args=None): + logging.debug("Compiling triton kernel") + + if kernel_instance.kernel_fn is None: + raise ValueError("kernel_fn is None, currently Triton only supports callable kernel_source") + + if gpu_args is None: + raise ValueError("gpu_args is None, Triton needs gpu args to compile the kernel") + + grid = kernel_instance.grid + threads = kernel_instance.threads + jit_function = triton.jit(kernel_instance.kernel_fn) + params = kernel_instance.params + gpu_kwargs = self.build_gpu_kwargs(jit_function, threads, params) + + # Call the jit function in order to compile it + jit_function[grid](*gpu_args, **gpu_kwargs) + + return jit_function + + def start_event(self): + logging.debug("Start triton event") + self.start.record() + + def stop_event(self): + logging.debug("Stop triton event") + self.end.record() + + def kernel_finished(self): + logging.debug("Checking if kernel has finished") + return self.end.query() + + def run_kernel(self, func, gpu_args, threads, grid, stream=None, params=None): + if params is None: + raise ValueError("params is None, Triton needs params in order to set num_warps, num_ctas, etc.") + + # Run the kernel + if stream is None: + stream = self.stream + + gpu_kwargs = self.build_gpu_kwargs(func, threads, params) + + with torch.cuda.stream(stream): + logging.debug("Running triton kernel") + func[grid](*gpu_args, **gpu_kwargs) + + def build_gpu_kwargs(self, jit_fn, threads, params=None): + gpu_kwargs = {} + + if 'BLOCK_SIZE' in jit_fn.arg_names: + gpu_kwargs['BLOCK_SIZE'] = threads[0] + + if 'BLOCK_SIZE_X' in jit_fn.arg_names: + gpu_kwargs['BLOCK_SIZE_X'] = threads[0] + + if 'BLOCK_SIZE_Y' in jit_fn.arg_names: + gpu_kwargs['BLOCK_SIZE_Y'] = threads[1] + + if 'BLOCK_SIZE_Z' in jit_fn.arg_names: + gpu_kwargs['BLOCK_SIZE_Z'] = threads[2] + + if params is None: + return gpu_kwargs + + for param in params: + if param in jit_fn.arg_names: + gpu_kwargs[param] = params[param] + + # Check for Triton specific parameters + if 'num_warps' in params: + gpu_kwargs['num_warps'] = params['num_warps'] + if 'num_ctas' in params: + gpu_kwargs['num_ctas'] = params['num_ctas'] + if 'num_stages' in params: + gpu_kwargs['num_stages'] = params['num_stages'] + + return gpu_kwargs + + def synchronize(self): + torch.cuda.synchronize() + + def memset(self, allocation, value, size): + pass + + def memcpy_dtoh(self, dest, src): + pass + + def memcpy_htod(self, dest, src): + pass + + def copy_constant_memory_args(self, cmem_args): + raise NotImplementedError("Triton does not support constant memory") + + def copy_shared_memory_args(self, smem_args): + raise NotImplementedError("Triton does not support shared memory") + + def copy_texture_memory_args(self, texmem_args): + raise NotImplementedError("Triton does not support texture memory") \ No newline at end of file diff --git a/kernel_tuner/core.py b/kernel_tuner/core.py index 5352ced74..1a7e67b34 100644 --- a/kernel_tuner/core.py +++ b/kernel_tuner/core.py @@ -21,6 +21,8 @@ from kernel_tuner.backends.nvcuda import CudaFunctions from kernel_tuner.backends.opencl import OpenCLFunctions from kernel_tuner.backends.pycuda import PyCudaFunctions +from kernel_tuner.backends.triton import TritonFunctions +from kernel_tuner.kernel_sources.kernel_source import KernelSource from kernel_tuner.observers.nvml import NVMLObserver from kernel_tuner.observers.observer import ContinuousObserver, OutputObserver, PrologueObserver from kernel_tuner.observers.tegra import TegraObserver @@ -41,6 +43,7 @@ "name", "kernel_source", "kernel_string", + "kernel_fn", "temp_files", "threads", "grid", @@ -53,13 +56,27 @@ class KernelInstance(_KernelInstance): """Class that represents the specific parameterized instance of a kernel.""" + def __new__(cls, *args, **kwargs): + # Detect old-style calls (without kernel_fn) + if len(args) == 8: # old version + name, kernel_source, kernel_string, temp_files, threads, grid, params, arguments = args + kernel_fn = None + args = (name, kernel_source, kernel_string, kernel_fn, temp_files, threads, grid, params, arguments) + elif "kernel_fn" not in kwargs and len(args) == 8: + kwargs["kernel_fn"] = None + return super().__new__(cls, *args, **kwargs) + def delete_temp_files(self): """Delete any generated temp files.""" - for v in self.temp_files.values(): - util.delete_temp_file(v) + tmp_files_list = self.temp_files.values() if isinstance(self.temp_files, dict) else self.temp_files + for tmp_file in tmp_files_list: + util.delete_temp_file(tmp_file) def prepare_temp_files_for_error_msg(self): """Prepare temp file with source code, and return list of temp file names.""" + if type(self.kernel_source).__name__ == "KernelSourceFn": + return [] # TODO what do we want to return here? + temp_filename = util.get_temp_filename(suffix=self.kernel_source.get_suffix()) util.write_file(temp_filename, self.kernel_string) ret = [temp_filename] @@ -67,160 +84,7 @@ def prepare_temp_files_for_error_msg(self): return ret -class KernelSource(object): - """Class that holds the kernel sources. - - There is a primary kernel source, which can be either a source string, - a filename (indicating a file containing the kernel source code), - or a callable (generating the kernel source code). - There can additionally be (one or multiple) secondary kernel sources, which - must be filenames. - """ - - def __init__(self, kernel_name, kernel_sources, lang, defines=None): - if not isinstance(kernel_sources, list): - kernel_sources = [kernel_sources] - self.kernel_sources = kernel_sources - self.kernel_name = kernel_name - self.defines = defines - if lang is None: - if callable(self.kernel_sources[0]): - raise TypeError("Please specify language when using a code generator function") - kernel_string = self.get_kernel_string(0) - lang = util.detect_language(kernel_string) - - # The validity of lang is checked later, when creating the DeviceInterface - self.lang = lang.upper() - - def get_kernel_string(self, index=0, params=None): - """Retrieve the kernel source with the given index and return as a string. - - See util.get_kernel_string() for details. - - :param index: Index of the kernel source in the list of sources. - :type index: int - - :param params: Dictionary containing the tunable parameters for this specific - kernel instance, only needed when kernel_source is a generator. - :type param: dict - - :returns: A string containing the kernel code. - :rtype: string - """ - logging.debug("get_kernel_string called") - - if hasattr(self, 'lang') and self.lang.upper() == "HYPERTUNER": - return "" - - kernel_source = self.kernel_sources[index] - return util.get_kernel_string(kernel_source, params) - - def prepare_list_of_files( - self, kernel_name, params, grid, threads, block_size_names - ): - """Prepare the kernel string along with any additional files. - - The first file in the list is allowed to include or read in the others - The files beyond the first are considered additional files that may also contain tunable parameters - - For each file beyond the first this function creates a temporary file with - preprocessors statements inserted. Occurrences of the original filenames in the - first file are replaced with their temporary counterparts. - - :param kernel_name: A string specifying the kernel name. - :type kernel_name: string - - :param params: A dictionary with the tunable parameters for this particular - instance. - :type params: dict() - - :param grid: The grid dimensions for this instance. The grid dimensions are - also inserted into the code as if they are tunable parameters for - convenience. - :type grid: tuple() - - :param threads: The thread block dimensions for this instance. The thread block are - also inserted into the code as if they are tunable parameters for - convenience. - :type threads: tuple() - - :param block_size_names: A list of strings that denote the names - for the thread block dimensions. - :type block_size_names: list(string) - - """ - temp_files = dict() - - if self.lang.upper() == "HYPERTUNER": - return tuple(["", "", temp_files]) - - for i, f in enumerate(self.kernel_sources): - if i > 0 and not util.looks_like_a_filename(f): - raise ValueError("When passing multiple kernel sources, the secondary entries must be filenames") - - ks = self.get_kernel_string(i, params) - # add preprocessor statements - n, ks = util.prepare_kernel_string( - kernel_name, - ks, - params, - grid, - threads, - block_size_names, - self.lang, - self.defines, - ) - - if i == 0: - # primary kernel source - name = n - kernel_string = ks - continue - - # save secondary kernel sources to temporary files - - # generate temp filename with the same extension - temp_file = util.get_temp_filename(suffix="." + f.split(".")[-1]) - temp_files[f] = temp_file - util.write_file(temp_file, ks) - # replace occurrences of the additional file's name in the first kernel_string with the name of the temp file - kernel_string = kernel_string.replace(f, temp_file) - - return name, kernel_string, temp_files - - def get_user_suffix(self, index=0): - """Get the suffix of the kernel filename, if the user specified one. Return None otherwise.""" - if util.looks_like_a_filename(self.kernel_sources[index]) and ("." in self.kernel_sources[index]): - return "." + self.kernel_sources[index].split(".")[-1] - return None - - def get_suffix(self, index=0): - """Return a suitable suffix for a kernel filename. - - This uses the user-specified suffix if available, or one based on the - lang/backend otherwise. - """ - # TODO: Consider delegating this to the backend - suffix = self.get_user_suffix(index) - if suffix is not None: - return suffix - - _suffixes = {"CUDA": ".cu", "OpenCL": ".cl", "C": ".c"} - try: - return _suffixes[self.lang] - except KeyError: - return ".c" - - def check_argument_lists(self, kernel_name, arguments): - """Check if the kernel arguments have the correct types. - - This is done by calling util.check_argument_list on each kernel string. - """ - for i, f in enumerate(self.kernel_sources): - if not callable(f): - util.check_argument_list(kernel_name, self.get_kernel_string(i), arguments) - else: - logging.debug("Checking of arguments list not supported yet for code generators.") +# KernelSource has been moved to kernel_sources/kernel_source_str.py class DeviceInterface(object): @@ -320,9 +184,16 @@ def __init__( compiler_options=compiler_options ) self.requires_warmup = False + elif lang.upper() == "TRITON": + dev = TritonFunctions( + device, + compiler_options=compiler_options, + iterations=iterations, + observers=observers + ) else: raise NotImplementedError( - "Sorry, support for languages other than CUDA, OpenCL, HIP, C, and Fortran is not implemented yet" + "Sorry, support for languages other than CUDA, OpenCL, HIP, C, Triton and Fortran is not implemented yet" ) self.dev = dev @@ -361,17 +232,23 @@ def __init__( if not quiet: print("Using: " + self.dev.name) - def benchmark_prologue(self, func, gpu_args, threads, grid, result): + def run_kernel_bench(self, func, gpu_args, threads, grid, stream=None, params=None): + if isinstance(self.dev, TritonFunctions): + self.dev.run_kernel(func, gpu_args, threads, grid, params=params) + else: + self.dev.run_kernel(func, gpu_args, threads, grid) + + def benchmark_prologue(self, func, gpu_args, threads, grid, result, params=None): """Benchmark prologue one kernel execution per PrologueObserver.""" for obs in self.prologue_observers: self.dev.synchronize() obs.before_start() - self.dev.run_kernel(func, gpu_args, threads, grid) + self.run_kernel_bench(func, gpu_args, threads, grid, stream=None, params=params) self.dev.synchronize() obs.after_finish() result.update(obs.get_results()) - def benchmark_default(self, func, gpu_args, threads, grid, result): + def benchmark_default(self, func, gpu_args, threads, grid, result, params=None): """Benchmark one kernel execution for 'iterations' at a time.""" self.dev.synchronize() for _ in range(self.iterations): @@ -379,7 +256,7 @@ def benchmark_default(self, func, gpu_args, threads, grid, result): obs.before_start() self.dev.synchronize() self.dev.start_event() - self.dev.run_kernel(func, gpu_args, threads, grid) + self.run_kernel_bench(func, gpu_args, threads, grid, stream=None, params=params) self.dev.stop_event() for obs in self.benchmark_observers: obs.after_start() @@ -395,7 +272,7 @@ def benchmark_default(self, func, gpu_args, threads, grid, result): result.update(obs.get_results()) - def benchmark_continuous(self, func, gpu_args, threads, grid, result, duration): + def benchmark_continuous(self, func, gpu_args, threads, grid, result, duration, params=None): """Benchmark continuously for at least 'duration' seconds.""" iterations = int(np.ceil(duration / (result["time"] / 1000))) self.dev.synchronize() @@ -403,7 +280,7 @@ def benchmark_continuous(self, func, gpu_args, threads, grid, result, duration): obs.before_start() self.dev.start_event() for _ in range(iterations): - self.dev.run_kernel(func, gpu_args, threads, grid) + self.run_kernel_bench(func, gpu_args, threads, grid, stream=None, params=params) self.dev.stop_event() for obs in self.continuous_observers: obs.after_start() @@ -453,8 +330,8 @@ def benchmark(self, func, gpu_args, instance, verbose, objective, skip_nvml_sett result = {} try: - self.benchmark_prologue(func, gpu_args, instance.threads, instance.grid, result) - self.benchmark_default(func, gpu_args, instance.threads, instance.grid, result) + self.benchmark_prologue(func, gpu_args, instance.threads, instance.grid, result, instance.params) + self.benchmark_default(func, gpu_args, instance.threads, instance.grid, result, instance.params) if self.continuous_observers: duration = 1 @@ -462,7 +339,7 @@ def benchmark(self, func, gpu_args, instance, verbose, objective, skip_nvml_sett obs.results = result duration = max(duration, obs.continuous_duration) - self.benchmark_continuous(func, gpu_args, instance.threads, instance.grid, result, duration) + self.benchmark_continuous(func, gpu_args, instance.threads, instance.grid, result, duration, instance.params) except Exception as e: # some launches may fail because too many registers are required @@ -507,7 +384,7 @@ def check_kernel_output( self.dev.refresh_memory(gpu_args, instance.arguments, should_sync) # run the kernel - check = self.run_kernel(func, gpu_args, instance) + check = self.run_kernel_check(func, gpu_args, instance) if not check: # runtime failure occurred that should be ignored, skip correctness check return @@ -575,7 +452,7 @@ def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options, try: # compile the kernel start_compilation = time.perf_counter() - func = self.compile_kernel(instance, verbose) + func = self.compile_kernel(instance, verbose, gpu_args) if not func: result[to.objective] = util.CompilationFailedConfig() else: @@ -624,14 +501,17 @@ def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options, return result - def compile_kernel(self, instance, verbose): + def compile_kernel(self, instance, verbose, gpu_args=None): """Compile the kernel for this specific instance.""" logging.debug("compile_kernel " + instance.name) # compile kernel_string into device func func = None try: - func = self.dev.compile(instance) + if isinstance(self.dev, TritonFunctions): + func = self.dev.compile(instance, gpu_args) + else: + func = self.dev.compile(instance) except Exception as e: # compiles may fail because certain kernel configurations use too # much shared memory for example, the desired behavior is to simply @@ -692,20 +572,21 @@ def create_kernel_instance(self, kernel_source, kernel_options, params, verbose) params, kernel_options.block_size_names, ) - if np.prod(threads) > self.dev.max_threads: + if kernel_source.lang != 'TRITON' and np.prod(threads) > self.dev.max_threads: if verbose: print(f"skipping config {util.get_instance_string(params)} reason: too many threads per block") return util.InvalidConfig() # obtain the kernel_string and prepare additional files, if any - name, kernel_string, temp_files = kernel_source.prepare_list_of_files( - kernel_options.kernel_name, + instance_data = kernel_source.prepare_kernel_instance( + kernel_options, params, grid, threads, - kernel_options.block_size_names, ) + name = instance_data.kernel_name + kernel_string=instance_data.kernel_str # check for templated kernel if kernel_source.lang in ["CUDA", "NVCUDA", "HIP"] and "<" in name and ">" in name: kernel_string, name = wrap_templated_kernel(kernel_string, name) @@ -714,7 +595,17 @@ def create_kernel_instance(self, kernel_source, kernel_options, params, verbose) arguments = _preprocess_gpu_arguments(kernel_options.arguments, params) # collect everything we know about this instance and return it - return KernelInstance(name, kernel_source, kernel_string, temp_files, threads, grid, params, arguments) + return KernelInstance( + name=name, + kernel_source=kernel_source, + kernel_string=kernel_string, + kernel_fn=instance_data.kernel_fn, + temp_files=instance_data.temp_files, + threads=threads, + grid=grid, + params=params, + arguments=arguments, + ) def get_environment(self): """Return dictionary with information about the environment.""" @@ -751,14 +642,15 @@ def ready_argument_list(self, arguments): return gpu_args - def run_kernel(self, func, gpu_args, instance): + + def run_kernel_check(self, func, gpu_args, instance): """Run a compiled kernel instance on a device.""" logging.debug("run_kernel %s", instance.name) logging.debug("thread block dims (%d, %d, %d)", *instance.threads) logging.debug("grid dims (%d, %d, %d)", *instance.grid) try: - self.dev.run_kernel(func, gpu_args, instance.threads, instance.grid) + self.run_kernel_bench(func, gpu_args, instance.threads, instance.grid, stream=None, params=instance.params) except Exception as e: if "too many resources requested for launch" in str(e) or "OUT_OF_RESOURCES" in str(e): logging.debug("ignoring runtime failure due to too many resources required") @@ -767,6 +659,7 @@ def run_kernel(self, func, gpu_args, instance): logging.debug("encountered unexpected runtime failure: " + str(e)) raise e return True + def _preprocess_gpu_arguments(old_arguments, params): diff --git a/kernel_tuner/interface.py b/kernel_tuner/interface.py index 32e91c86f..bad921d24 100644 --- a/kernel_tuner/interface.py +++ b/kernel_tuner/interface.py @@ -39,6 +39,7 @@ import kernel_tuner.util as util from kernel_tuner.file_utils import get_input_file, get_t4_metadata, get_t4_results, import_class_from_file from kernel_tuner.integration import get_objective_defaults +from kernel_tuner.kernel_sources.kernel_source import KernelSource from kernel_tuner.runners.sequential import SequentialRunner from kernel_tuner.runners.simulation import SimulationRunner from kernel_tuner.searchspace import Searchspace @@ -592,7 +593,7 @@ def tune_kernel( if log: logging.basicConfig(filename=kernel_name + datetime.now().strftime("%Y%m%d-%H:%M:%S") + ".log", level=log) - kernelsource = core.KernelSource(kernel_name, kernel_source, lang, defines) + kernelsource = KernelSource(kernel_name, kernel_source, lang, defines) _check_user_input(kernel_name, kernelsource, arguments, block_size_names) @@ -777,7 +778,7 @@ def run_kernel( if log: logging.basicConfig(filename=kernel_name + datetime.now().strftime("%Y%m%d-%H:%M:%S") + ".log", level=log) - kernelsource = core.KernelSource(kernel_name, kernel_source, lang, defines) + kernelsource = KernelSource(kernel_name, kernel_source, lang, defines) _check_user_input(kernel_name, kernelsource, arguments, block_size_names) @@ -825,7 +826,7 @@ def run_kernel( instance.delete_temp_files() # run the kernel - if not dev.run_kernel(func, gpu_args, instance): + if not dev.run_kernel_check(func, gpu_args, instance): raise RuntimeError("runtime error occured, too many resources requested") # copy data in GPU memory back to the host diff --git a/kernel_tuner/kernel_sources/__init__.py b/kernel_tuner/kernel_sources/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kernel_tuner/kernel_sources/kernel_source.py b/kernel_tuner/kernel_sources/kernel_source.py new file mode 100644 index 000000000..be8bbb265 --- /dev/null +++ b/kernel_tuner/kernel_sources/kernel_source.py @@ -0,0 +1,77 @@ +import inspect +import kernel_tuner.util as util + +from abc import abstractmethod + +from kernel_tuner.kernel_sources.model.prepared_kernel_source_data import PreparedKernelSourceData + + + + +class KernelSource: + def __new__(cls, kernel_name, kernel_sources, lang, defines=None): + """Factory behavior""" + if cls is KernelSource: + if inspect.isfunction(kernel_sources) and (lang and lang.upper() == "TRITON"): # TODO should this be isfunction? + from kernel_tuner.kernel_sources.kernel_source_fn import KernelSourceFn + print("CREATING KSFN") + return KernelSourceFn(kernel_name, kernel_sources, lang, defines) + else: + from kernel_tuner.kernel_sources.kernel_source_str import KernelSourceStr + print("CREATING KSSTR") + return KernelSourceStr(kernel_name, kernel_sources, lang, defines) + # otherwise, normal subclass init + return super().__new__(cls) + + def __init__(self, kernel_name, kernel_sources, lang, defines=None): + if not isinstance(kernel_sources, list): + kernel_sources = [kernel_sources] + self.kernel_sources = kernel_sources + self.kernel_name = kernel_name + self.defines = defines + + if lang is None: + if callable(self.kernel_sources[0]): + raise TypeError("Please specify language when using a code generator function") + kernel_string = self.get_kernel_string(0) + self.lang = util.detect_language(kernel_string) + else: + self.lang = lang + + @abstractmethod + def prepare_kernel_instance(self, kernel_options, params, grid, threads) -> PreparedKernelSourceData: + raise NotImplementedError("create_kernel_instance not implemented") + + @abstractmethod + def check_argument_lists(self, kernel_name, arguments): + raise NotImplementedError("check_argument_lists not implemented") + + + +''' +class KernelSource: + def __init__(self, kernel_name, kernel_sources, lang, defines=None): + if not isinstance(kernel_sources, list): + kernel_sources = [kernel_sources] + + self.kernel_sources = kernel_sources + self.kernel_name = kernel_name + self.defines = defines + + if lang is None: + if callable(self.kernel_sources[0]): + raise TypeError("Please specify language when using a code generator function") + kernel_string = self.get_kernel_string(0) + self.lang = util.detect_language(kernel_string) + else: + self.lang = lang + + @abstractmethod + def prepare_kernel_instance(self, kernel_options, params, grid, threads) -> PreparedKernelSourceData: + raise NotImplementedError("create_kernel_instance not implemented") + + @abstractmethod + def check_argument_lists(self, kernel_name, arguments): + raise NotImplementedError("check_argument_lists not implemented") + +''' \ No newline at end of file diff --git a/kernel_tuner/kernel_sources/kernel_source_fn.py b/kernel_tuner/kernel_sources/kernel_source_fn.py new file mode 100644 index 000000000..5df6fcb94 --- /dev/null +++ b/kernel_tuner/kernel_sources/kernel_source_fn.py @@ -0,0 +1,160 @@ +import inspect +import ast +import copy +import uuid +import sys + +import astor +import tempfile +import importlib.util + +from typing import Any + +from kernel_tuner.kernel_sources.kernel_source import KernelSource +from kernel_tuner.kernel_sources.model.prepared_kernel_source_data import PreparedKernelSourceData + + +class KernelSourceFn(KernelSource): + + def __init__(self, kernel_name, kernel_source, lang, defines=None): + super().__init__(kernel_name, kernel_source, lang, defines) + if isinstance(kernel_source, list): + raise ValueError("KernelSourceFn only supports a single kernel source function") + + self.source_kernel_fn = kernel_source + self.kernel_fn = self.source_kernel_fn + self.source = inspect.getsource(kernel_source) + self.source_tree = ast.parse(self.source) + self.triton_import_nodes = [ + ast.Import(names=[ast.alias(name='triton', asname=None)]), + ast.ImportFrom( + module='triton', + names=[ast.alias(name='language', asname='tl')], + level=0 + ) + ] + + # Find the module where the kernel function is defined + self.module = inspect.getmodule(kernel_source) + # Get dependencies by analyzing the AST + self.dependencies = self._find_triton_dependencies() + + def _find_function_dependencies(self): + """Find all function calls in the kernel.""" + class FunctionCallVisitor(ast.NodeVisitor): + def __init__(self): + self.called_functions = set() + + def visit_Call(self, node): + if isinstance(node.func, ast.Name): + self.called_functions.add(node.func.id) + self.generic_visit(node) + + visitor = FunctionCallVisitor() + visitor.visit(self.source_tree) + return visitor.called_functions + + def _is_triton_jit_function(self, node): + """Check if a function has the @triton.jit decorator.""" + if not isinstance(node, ast.FunctionDef): + return False + + for decorator in node.decorator_list: + if (isinstance(decorator, ast.Name) and decorator.id == 'jit' or + isinstance(decorator, ast.Attribute) and decorator.attr == 'jit' or + isinstance(decorator, ast.Call) and + isinstance(decorator.func, ast.Attribute) and + decorator.func.attr == 'jit'): + return True + return False + + def _find_triton_dependencies(self): + """Find all Triton JIT functions that are called by the kernel.""" + dependencies = {} + called_functions = self._find_function_dependencies() + + # Parse the entire module to find Triton JIT functions + module_source = inspect.getsource(self.module) + module_tree = ast.parse(module_source) + + for node in module_tree.body: + if (self._is_triton_jit_function(node) and + node.name in called_functions and + node.name != self.kernel_name): # Skip the main kernel itself + dependencies[node.name] = node + + return dependencies + + def prepare_kernel_instance(self, kernel_options, params, grid, threads): + new_kernel_fn, temp_file_path = self.apply_params_to_source_fn(params) + self.kernel_fn = new_kernel_fn + + return PreparedKernelSourceData( + temp_files=[temp_file_path], + kernel_name=self.kernel_name, + kernel_fn=new_kernel_fn, + kernel_str=None + ) + + def check_argument_lists(self, kernel_name, arguments): + return True + + def apply_params_to_source_fn(self, params): + transformer = ReplaceVars(params) + + # Create a new module with all necessary functions + new_module = ast.Module(body=[], type_ignores=[]) + + # Add both triton imports + new_module.body.extend(self.triton_import_nodes) + + # Add all Triton JIT dependencies first + for dep_node in self.dependencies.values(): + new_module.body.append(copy.deepcopy(dep_node)) + + # Add transformed main kernel + source_tree_copy = copy.deepcopy(self.source_tree) + transformed_tree = transformer.visit(source_tree_copy) + new_module.body.extend(transformed_tree.body) + + # Fix locations and generate source + ast.fix_missing_locations(new_module) + new_source = astor.to_source(new_module) + + # Create a unique module name + module_name = f'temp_kernel_module_{uuid.uuid4().hex}' + + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as temp_file: + temp_file.write(new_source) + temp_file_path = temp_file.name + + spec = importlib.util.spec_from_file_location(module_name, temp_file_path) + temp_module = importlib.util.module_from_spec(spec) + + # Register the module in sys.modules before executing it + sys.modules[module_name] = temp_module + spec.loader.exec_module(temp_module) + new_fn = getattr(temp_module, self.kernel_name) + + return new_fn, temp_file_path + + def __del__(self): + # Clean up temporary modules when the instance is destroyed + for key in list(sys.modules.keys()): + if key.startswith('temp_kernel_module_'): + del sys.modules[key] + + +class ReplaceVars(ast.NodeTransformer): + + def __init__(self, params: dict): + self.params = params + + def visit_Name(self, node: ast.Name) -> Any: + if isinstance(node.ctx, ast.Load) and node.id in self.params.keys(): + return ast.copy_location( + ast.Constant(value=self.params[node.id]), + node + ) + + return node \ No newline at end of file diff --git a/kernel_tuner/kernel_sources/kernel_source_str.py b/kernel_tuner/kernel_sources/kernel_source_str.py new file mode 100644 index 000000000..f470b939f --- /dev/null +++ b/kernel_tuner/kernel_sources/kernel_source_str.py @@ -0,0 +1,174 @@ +import kernel_tuner.util as util +import logging + +from kernel_tuner.kernel_sources.kernel_source import KernelSource +from kernel_tuner.core import wrap_templated_kernel +from kernel_tuner.kernel_sources.model.prepared_kernel_source_data import PreparedKernelSourceData +from kernel_tuner.language import Language + + +class KernelSourceStr(KernelSource): + """Class that holds the kernel sources. + + There is a primary kernel source for string-based kernels., which can be either a source string, + a filename (indicating a file containing the kernel source code), + or a callable (generating the kernel source code). + There can additionally be (one or multiple) secondary kernel sources, which + must be filenames. + """ + + def __init__(self, kernel_name, kernel_sources, lang, defines=None): + super().__init__(kernel_name, kernel_sources, lang, defines) + + def prepare_kernel_instance(self, kernel_options, params, grid, threads): + name, kernel_string, temp_files = self.prepare_list_of_files( + kernel_name=kernel_options.kernel_name, + params=params, + grid=grid, + threads=threads, + block_size_names=kernel_options.block_size_names, + ) + + lang_is_cuda = self.lang in [Language.CUDA, Language.NVCUDA] + + if lang_is_cuda and "<" in name and ">" in name: + kernel_string, name = wrap_templated_kernel(kernel_string, name) + + return PreparedKernelSourceData( + temp_files=temp_files, + kernel_name=name, + kernel_str=kernel_string, + kernel_fn=None, + ) + + + + def get_kernel_string(self, index=0, params=None): + """Retrieve the kernel source with the given index and return as a string. + + See util.get_kernel_string() for details. + + :param index: Index of the kernel source in the list of sources. + :type index: int + + :param params: Dictionary containing the tunable parameters for this specific + kernel instance, only needed when kernel_source is a generator. + :type param: dict + + :returns: A string containing the kernel code. + :rtype: string + """ + logging.debug("get_kernel_string called") + + if hasattr(self, 'lang') and self.lang.upper() == "HYPERTUNER": + return "" + + kernel_source = self.kernel_sources[index] + return util.get_kernel_string(kernel_source, params) + + def prepare_list_of_files( + self, kernel_name, params, grid, threads, block_size_names + ): + """Prepare the kernel string along with any additional files. + + The first file in the list is allowed to include or read in the others + The files beyond the first are considered additional files that may also contain tunable parameters + + For each file beyond the first this function creates a temporary file with + preprocessors statements inserted. Occurrences of the original filenames in the + first file are replaced with their temporary counterparts. + + :param kernel_name: A string specifying the kernel name. + :type kernel_name: string + + :param params: A dictionary with the tunable parameters for this particular + instance. + :type params: dict() + + :param grid: The grid dimensions for this instance. The grid dimensions are + also inserted into the code as if they are tunable parameters for + convenience. + :type grid: tuple() + + :param threads: The thread block dimensions for this instance. The thread block are + also inserted into the code as if they are tunable parameters for + convenience. + :type threads: tuple() + + :param block_size_names: A list of strings that denote the names + for the thread block dimensions. + :type block_size_names: list(string) + + """ + temp_files = dict() + + if self.lang.upper() == "HYPERTUNER": + return tuple(["", "", temp_files]) + + for i, f in enumerate(self.kernel_sources): + if i > 0 and not util.looks_like_a_filename(f): + raise ValueError("When passing multiple kernel sources, the secondary entries must be filenames") + + ks = self.get_kernel_string(i, params) + # add preprocessor statements + n, ks = util.prepare_kernel_string( + kernel_name, + ks, + params, + grid, + threads, + block_size_names, + self.lang, + self.defines, + ) + + if i == 0: + # primary kernel source + name = n + kernel_string = ks + continue + + # save secondary kernel sources to temporary files + + # generate temp filename with the same extension + temp_file = util.get_temp_filename(suffix="." + f.split(".")[-1]) + temp_files[f] = temp_file + util.write_file(temp_file, ks) + # replace occurrences of the additional file's name in the first kernel_string with the name of the temp file + kernel_string = kernel_string.replace(f, temp_file) + + return name, kernel_string, temp_files + + def get_user_suffix(self, index=0): + """Get the suffix of the kernel filename, if the user specified one. Return None otherwise.""" + if util.looks_like_a_filename(self.kernel_sources[index]) and ("." in self.kernel_sources[index]): + return "." + self.kernel_sources[index].split(".")[-1] + return None + + def get_suffix(self, index=0): + """Return a suitable suffix for a kernel filename. + + This uses the user-specified suffix if available, or one based on the + lang/backend otherwise. + """ + # TODO: Consider delegating this to the backend + suffix = self.get_user_suffix(index) + if suffix is not None: + return suffix + + _suffixes = {"CUDA": ".cu", "OpenCL": ".cl", "C": ".c"} + try: + return _suffixes[self.lang] + except KeyError: + return ".c" + + def check_argument_lists(self, kernel_name, arguments): + """Check if the kernel arguments have the correct types. + + This is done by calling util.check_argument_list on each kernel string. + """ + for i, f in enumerate(self.kernel_sources): + if not callable(f): + util.check_argument_list(kernel_name, self.get_kernel_string(i), arguments) + else: + logging.debug("Checking of arguments list not supported yet for code generators.") \ No newline at end of file diff --git a/kernel_tuner/kernel_sources/model/__init__.py b/kernel_tuner/kernel_sources/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kernel_tuner/kernel_sources/model/prepared_kernel_source_data.py b/kernel_tuner/kernel_sources/model/prepared_kernel_source_data.py new file mode 100644 index 000000000..b9c904b38 --- /dev/null +++ b/kernel_tuner/kernel_sources/model/prepared_kernel_source_data.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass +from typing import Any + + +@dataclass +class PreparedKernelSourceData: + temp_files: Any + kernel_name: str + kernel_str: Any + kernel_fn: Any \ No newline at end of file diff --git a/kernel_tuner/kernelbuilder.py b/kernel_tuner/kernelbuilder.py index 0f3f6154f..e2039e66d 100644 --- a/kernel_tuner/kernelbuilder.py +++ b/kernel_tuner/kernelbuilder.py @@ -4,6 +4,7 @@ from kernel_tuner.interface import Options, _kernel_options from kernel_tuner.integration import TuneResults +from kernel_tuner.kernel_sources.kernel_source import KernelSource class PythonKernel(object): @@ -30,7 +31,7 @@ def __init__(self, kernel_name, kernel_string, problem_size, arguments, params=N """ #construct device interface - kernel_source = core.KernelSource(kernel_name, kernel_string, lang) + kernel_source = KernelSource(kernel_name, kernel_string, lang) self.dev = core.DeviceInterface(kernel_source, device=device, quiet=True) if not params: params = {} @@ -92,7 +93,7 @@ def run_kernel(self, args): :type args: list(np.ndarray or np.generic) """ self.update_gpu_args(args) - self.dev.run_kernel(self.func, self.gpu_args, self.kernel_instance) + self.dev.run_kernel_check(self.func, self.gpu_args, self.kernel_instance) return self.get_gpu_result(args) def __call__(self, *args): diff --git a/kernel_tuner/language.py b/kernel_tuner/language.py new file mode 100644 index 000000000..ad510b55b --- /dev/null +++ b/kernel_tuner/language.py @@ -0,0 +1,10 @@ +from enum import Enum + +class Language(Enum): + CUDA = "CUDA" + OpenCL = "OPENCL" + C = "C" + HIP = "HIP" + TRITON = "TRITON" + FORTRAN = "FORTRAN" + NVCUDA = "NVCUDA" \ No newline at end of file diff --git a/kernel_tuner/observers/triton.py b/kernel_tuner/observers/triton.py new file mode 100644 index 000000000..878e8f27c --- /dev/null +++ b/kernel_tuner/observers/triton.py @@ -0,0 +1,32 @@ +import numpy as np + +from kernel_tuner.observers.observer import BenchmarkObserver + +try: + import torch +except (ImportError, RuntimeError): + torch = None + + +class TritonRuntimeObserver(BenchmarkObserver): + """Observer that measures time using CUDA events during benchmarking.""" + + def __init__(self, dev): + if torch is None: + raise ImportError("Unable to load torch") + + self.dev = dev + self.stream = dev.stream + self.start = dev.start + self.end = dev.end + self.times = [] + + def after_finish(self): + # Time is measured in milliseconds + event_elapsed_time = self.start.elapsed_time(self.end) + self.times.append(event_elapsed_time) + + def get_results(self): + results = {"time": np.average(self.times), "times": self.times.copy()} + self.times = [] + return results \ No newline at end of file diff --git a/kernel_tuner/util.py b/kernel_tuner/util.py index 2d9e3f1b3..fd4cb5697 100644 --- a/kernel_tuner/util.py +++ b/kernel_tuner/util.py @@ -136,6 +136,9 @@ def check_argument_type(dtype, kernel_argument): def check_argument_list(kernel_name, kernel_string, args): """Raise an exception if kernel arguments do not match host arguments.""" + if kernel_string is None: + return + kernel_arguments = list() collected_errors = list() @@ -162,7 +165,7 @@ def check_argument_list(kernel_name, kernel_string, args): if not isinstance(arg, (np.ndarray, np.generic, cp.ndarray, torch.Tensor, DeviceArray)): raise TypeError( f"Argument at position {i} of type: {type(arg)} should be of type " - "np.ndarray, numpy scalar, or HIP Python DeviceArray type" +- "np.ndarray, numpy scalar, or HIP Python DeviceArray type" ) correct = True @@ -415,6 +418,8 @@ def detect_language(kernel_string): lang = "CUDA" elif "__kernel" in kernel_string: lang = "OpenCL" + elif "@triton.jit" in kernel_string: + lang = "Triton" else: lang = "C" return lang From 68d43ec4540b24fa13358fe717a739fe5fa69182 Mon Sep 17 00:00:00 2001 From: Imke van Ooijen Date: Sat, 20 Dec 2025 12:51:19 +0100 Subject: [PATCH 02/14] Started tilus support --- examples/tilus/vec_add.py | 67 ++++++ kernel_tuner/backends/tilus.py | 194 ++++++++++++++++++ kernel_tuner/core.py | 8 + kernel_tuner/kernel_sources/kernel_source.py | 13 +- .../kernel_sources/kernel_source_fn.py | 49 +++-- kernel_tuner/language.py | 10 +- kernel_tuner/observers/generic_python.py | 33 +++ 7 files changed, 356 insertions(+), 18 deletions(-) create mode 100644 examples/tilus/vec_add.py create mode 100644 kernel_tuner/backends/tilus.py create mode 100644 kernel_tuner/observers/generic_python.py diff --git a/examples/tilus/vec_add.py b/examples/tilus/vec_add.py new file mode 100644 index 000000000..60dcc6074 --- /dev/null +++ b/examples/tilus/vec_add.py @@ -0,0 +1,67 @@ +import tilus +from tilus import float32, int32 +from tilus.utils import cdiv, benchmark_func +import torch + + +class VecAddV(tilus.Script): + def __init__(self): + super().__init__() + self.block_size = 256 # number of threads per block + + def __call__( + self, + n_size: int32, # size of the vectors + a_ptr: ~float32, # input vector A + b_ptr: ~float32, # input vector B + c_ptr: ~float32 # output vector C + ): + # compute the number of blocks needed + self.attrs.blocks = [cdiv(n_size, self.block_size)] + self.attrs.warps = 1 # number of warps per block + + # calculate the offset for this block + offset: int32 = self.block_size * self.blockIdx.x + + # create global views for input/output vectors + ga = self.global_view(a_ptr, dtype=float32, shape=[n_size]) + gb = self.global_view(b_ptr, dtype=float32, shape=[n_size]) + gc = self.global_view(c_ptr, dtype=float32, shape=[n_size]) + + # load a block of A and B into registers + a = self.load_global(ga, offsets=[offset], shape=[self.block_size]) + b = self.load_global(gb, offsets=[offset], shape=[self.block_size]) + + # perform element-wise addition + c = a + b + + # store the result back to global memory + self.store_global(gc, c, offsets=[offset]) + + +def main(): + N = 1 << 20 # vector size + vecadd = VecAddV() + + a = torch.rand(N, dtype=torch.float32).cuda() + b = torch.rand(N, dtype=torch.float32).cuda() + c_actual = torch.empty_like(a) + c_expect = a + b + + torch.cuda.synchronize() + vecadd(N, a, b, c_actual) + torch.cuda.synchronize() + + # correctness check + torch.testing.assert_close(c_expect, c_actual, atol=1e-6, rtol=1e-6) + + # benchmark + for name, func in [ + ("torch", lambda: a + b), + ("tilus", lambda: vecadd(N, a, b, c_actual)), + ]: + latency = benchmark_func(func, warmup=5, repeat=20) + print(f"{name} latency: {latency:.3f} ms") + +if __name__ == "__main__": + main() diff --git a/kernel_tuner/backends/tilus.py b/kernel_tuner/backends/tilus.py new file mode 100644 index 000000000..d3782fb0d --- /dev/null +++ b/kernel_tuner/backends/tilus.py @@ -0,0 +1,194 @@ +import logging +import numpy as np +import inspect + +from kernel_tuner.backends.backend import GPUBackend +from kernel_tuner.observers.generic_python import GenericPythonRuntimeObserver + +try: + import torch +except ImportError: + logging.error("Torch not available") + +try: + import tilus +except ImportError: + tilus = None + logging.error("Unable to load Tilus") + + +class TilusFunctions(GPUBackend): + + def __init__(self, device=0, iterations=7, compiler_options=None, observers=None): + ''' + In here, everyting is generic if the language uses Torch as backend + ''' + if not tilus or not torch: + logging.error("Tilus or Torch not available") + raise ImportError("Tilus or Torch not available") + + self.device_id = torch.cuda.current_device() + + self.device_properties = torch.cuda.get_device_properties(self.device_id) + self.name = torch.cuda.get_device_name(self.device_id) + self.max_threads = self.device_properties.max_threads_per_multi_processor + + env = dict() + env["device_name"] = self.name + env["max_threads"] = self.max_threads + env["iterations"] = iterations + env["compiler_options"] = compiler_options + self.env = env + + self.stream = torch.cuda.default_stream() + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + + # setup observers + self.observers = observers or [] + self.observers.append(GenericPythonRuntimeObserver(self)) + for obs in self.observers: + obs.register_device(self) + + self.units = {"time": "ms", "power": "s,mW", "energy": "J"} + + super().__init__(device=device, iterations=iterations, compiler_options=compiler_options, observers=observers) + + def ready_argument_list(self, arguments): + ''' + This seems to work for Torch as backend. However, another idea for generic would be + to let the user provide two argument lists: one with numpy and one with the actual arguments + that will be passed to the kernel. TODO: why do we need the numpy list anyway? + ''' + torch_args = [] + + for arg in arguments: + if isinstance(arg, torch.Tensor) and arg.dim() > 0: + torch_args.append(arg.cuda()) + elif isinstance(arg, torch.Tensor) and arg.dim() == 0: + scalar_value = arg.item() + torch_args.append(scalar_value) + elif isinstance(arg, np.ndarray): + torch_arg = torch.from_numpy(arg) + torch_arg_gpu = torch_arg.cuda() + torch_args.append(torch_arg_gpu) + elif isinstance(arg, np.generic): + scalar_value = arg.item() + torch_args.append(scalar_value) + else: + logging.warning("Unknown instance in Tilus functions") + + return torch_args + + def compile(self, kernel_instance, gpu_args=None): + logging.debug("Compiling Tilus kernel") + + if kernel_instance.kernel_fn is None: + raise ValueError("kernel_fn is None, currently Tilus only supports callable kernel_source") + + ''' + if gpu_args is None: + raise ValueError("gpu_args is None, Triton needs gpu args to compile the kernel") + TODO: find out about these GPU args + ''' + + grid = kernel_instance.grid # TODO might need to use this? + threads = kernel_instance.threads + params = kernel_instance.params + kernel_function = kernel_instance.kernel_fn + gpu_kwargs = self.build_gpu_kwargs(kernel_function, threads, params) + + # Call the jit function in order to compile it + kernel_function(*gpu_args, **gpu_kwargs) + + return kernel_function + + def start_event(self): + logging.debug("Start triton event") + self.start.record() + + def stop_event(self): + logging.debug("Stop triton event") + self.end.record() + + def kernel_finished(self): + logging.debug("Checking if kernel has finished") + return self.end.query() + + def run_kernel(self, func, gpu_args, threads, grid, stream=None, params=None): + ''' + if params is None: + raise ValueError("params is None, Triton needs params in order to set num_warps, num_ctas, etc.") + ''' + + # Run the kernel + if stream is None: + stream = self.stream + + gpu_kwargs = self.build_gpu_kwargs(func, threads, params) + + with torch.cuda.stream(stream): + logging.debug("Running Tilus kernel") + func(*gpu_args, **gpu_kwargs) + + + + def build_gpu_kwargs(self, kernel_function, threads, params=None): + gpu_kwargs = {} + + ''' + if 'BLOCK_SIZE' in jit_fn.arg_names: + gpu_kwargs['BLOCK_SIZE'] = threads[0] + + if 'BLOCK_SIZE_X' in jit_fn.arg_names: + gpu_kwargs['BLOCK_SIZE_X'] = threads[0] + + if 'BLOCK_SIZE_Y' in jit_fn.arg_names: + gpu_kwargs['BLOCK_SIZE_Y'] = threads[1] + + if 'BLOCK_SIZE_Z' in jit_fn.arg_names: + gpu_kwargs['BLOCK_SIZE_Z'] = threads[2] + ''' + + if params is None: + return gpu_kwargs + + + sig = inspect.signature(kernel_function) + for name, p in sig.parameters.items(): + if name in params: + gpu_kwargs[name] = params[name] + + ''' + # Check for Triton specific parameters + if 'num_warps' in params: + gpu_kwargs['num_warps'] = params['num_warps'] + if 'num_ctas' in params: + gpu_kwargs['num_ctas'] = params['num_ctas'] + if 'num_stages' in params: + gpu_kwargs['num_stages'] = params['num_stages'] + ''' + + return gpu_kwargs + + + def synchronize(self): + torch.cuda.synchronize() + + def memset(self, allocation, value, size): + pass + + def memcpy_dtoh(self, dest, src): + pass + + def memcpy_htod(self, dest, src): + pass + + def copy_constant_memory_args(self, cmem_args): + raise NotImplementedError("Tilus does not support constant memory") + + def copy_shared_memory_args(self, smem_args): + raise NotImplementedError("Tilus does not support shared memory") + + def copy_texture_memory_args(self, texmem_args): + raise NotImplementedError("Tilus does not support texture memory") \ No newline at end of file diff --git a/kernel_tuner/core.py b/kernel_tuner/core.py index 1a7e67b34..3b545d459 100644 --- a/kernel_tuner/core.py +++ b/kernel_tuner/core.py @@ -22,6 +22,7 @@ from kernel_tuner.backends.opencl import OpenCLFunctions from kernel_tuner.backends.pycuda import PyCudaFunctions from kernel_tuner.backends.triton import TritonFunctions +from kernel_tuner.backends.tilus import TilusFunctions from kernel_tuner.kernel_sources.kernel_source import KernelSource from kernel_tuner.observers.nvml import NVMLObserver from kernel_tuner.observers.observer import ContinuousObserver, OutputObserver, PrologueObserver @@ -191,6 +192,13 @@ def __init__( iterations=iterations, observers=observers ) + elif lang.upper() == "GENERIC_PYTHON": + dev = TilusFunctions( + device, + compiler_options=compiler_options, + iterations=iterations, + observers=observers + ) else: raise NotImplementedError( "Sorry, support for languages other than CUDA, OpenCL, HIP, C, Triton and Fortran is not implemented yet" diff --git a/kernel_tuner/kernel_sources/kernel_source.py b/kernel_tuner/kernel_sources/kernel_source.py index be8bbb265..70c94903d 100644 --- a/kernel_tuner/kernel_sources/kernel_source.py +++ b/kernel_tuner/kernel_sources/kernel_source.py @@ -3,6 +3,7 @@ from abc import abstractmethod +from kernel_tuner.language import Language from kernel_tuner.kernel_sources.model.prepared_kernel_source_data import PreparedKernelSourceData @@ -11,8 +12,16 @@ class KernelSource: def __new__(cls, kernel_name, kernel_sources, lang, defines=None): """Factory behavior""" + if lang == None: + language = None + else: + try: + language = Language(lang) + except ValueError: + raise TypeError(f"Supported languages are {[l.value for l in Language]}") + if cls is KernelSource: - if inspect.isfunction(kernel_sources) and (lang and lang.upper() == "TRITON"): # TODO should this be isfunction? + if inspect.isfunction(kernel_sources) and (language and (language == Language.TRITON or language == Language.GENERIC_PYTHON)): # TODO should this be isfunction? from kernel_tuner.kernel_sources.kernel_source_fn import KernelSourceFn print("CREATING KSFN") return KernelSourceFn(kernel_name, kernel_sources, lang, defines) @@ -20,9 +29,11 @@ def __new__(cls, kernel_name, kernel_sources, lang, defines=None): from kernel_tuner.kernel_sources.kernel_source_str import KernelSourceStr print("CREATING KSSTR") return KernelSourceStr(kernel_name, kernel_sources, lang, defines) + # otherwise, normal subclass init return super().__new__(cls) + def __init__(self, kernel_name, kernel_sources, lang, defines=None): if not isinstance(kernel_sources, list): kernel_sources = [kernel_sources] diff --git a/kernel_tuner/kernel_sources/kernel_source_fn.py b/kernel_tuner/kernel_sources/kernel_source_fn.py index 5df6fcb94..223e0e84f 100644 --- a/kernel_tuner/kernel_sources/kernel_source_fn.py +++ b/kernel_tuner/kernel_sources/kernel_source_fn.py @@ -10,10 +10,12 @@ from typing import Any +from kernel_tuner.language import Language from kernel_tuner.kernel_sources.kernel_source import KernelSource from kernel_tuner.kernel_sources.model.prepared_kernel_source_data import PreparedKernelSourceData + class KernelSourceFn(KernelSource): def __init__(self, kernel_name, kernel_source, lang, defines=None): @@ -21,23 +23,35 @@ def __init__(self, kernel_name, kernel_source, lang, defines=None): if isinstance(kernel_source, list): raise ValueError("KernelSourceFn only supports a single kernel source function") + + try: + self.lang = Language(lang) + except ValueError: + raise TypeError(f"Supported languages are {[l.value for l in Language]}") self.source_kernel_fn = kernel_source self.kernel_fn = self.source_kernel_fn self.source = inspect.getsource(kernel_source) self.source_tree = ast.parse(self.source) - self.triton_import_nodes = [ - ast.Import(names=[ast.alias(name='triton', asname=None)]), - ast.ImportFrom( - module='triton', - names=[ast.alias(name='language', asname='tl')], - level=0 - ) - ] + if self.lang == Language.TRITON: + self.import_nodes = [ + ast.Import(names=[ast.alias(name='triton', asname=None)]), + ast.ImportFrom( + module='triton', + names=[ast.alias(name='language', asname='tl')], + level=0 + ) + ] + else: + self.import_nodes = [n for n in self.source_tree.body if isinstance(n, (ast.Import, ast.ImportFrom))] + # Find the module where the kernel function is defined self.module = inspect.getmodule(kernel_source) # Get dependencies by analyzing the AST - self.dependencies = self._find_triton_dependencies() + if self.lang == Language.lang: + self.dependencies = self._find_triton_dependencies() + else: + self.dependencies = None # TODO def _find_function_dependencies(self): """Find all function calls in the kernel.""" @@ -55,7 +69,7 @@ def visit_Call(self, node): return visitor.called_functions def _is_triton_jit_function(self, node): - """Check if a function has the @triton.jit decorator.""" + """Check if a function has the @triton.jit decorator. Is triton specific""" if not isinstance(node, ast.FunctionDef): return False @@ -69,7 +83,7 @@ def _is_triton_jit_function(self, node): return False def _find_triton_dependencies(self): - """Find all Triton JIT functions that are called by the kernel.""" + """Find all Triton JIT functions that are called by the kernel. Is Triton specific""" dependencies = {} called_functions = self._find_function_dependencies() @@ -105,12 +119,15 @@ def apply_params_to_source_fn(self, params): # Create a new module with all necessary functions new_module = ast.Module(body=[], type_ignores=[]) - # Add both triton imports - new_module.body.extend(self.triton_import_nodes) + # Add imports + new_module.body.extend(self.import_nodes) - # Add all Triton JIT dependencies first - for dep_node in self.dependencies.values(): - new_module.body.append(copy.deepcopy(dep_node)) + if self.lang == Language.TRITON: + # Add all Triton JIT dependencies first + for dep_node in self.dependencies.values(): + new_module.body.append(copy.deepcopy(dep_node)) + + # TODO the kernel can not have dependencies yet. Obviously this needs fixing. # Add transformed main kernel source_tree_copy = copy.deepcopy(self.source_tree) diff --git a/kernel_tuner/language.py b/kernel_tuner/language.py index ad510b55b..d33933cca 100644 --- a/kernel_tuner/language.py +++ b/kernel_tuner/language.py @@ -7,4 +7,12 @@ class Language(Enum): HIP = "HIP" TRITON = "TRITON" FORTRAN = "FORTRAN" - NVCUDA = "NVCUDA" \ No newline at end of file + NVCUDA = "NVCUDA" + GENERIC_PYTHON = "GENERIC_PYTHON" + + @classmethod + def is_valid(cls, value: str) -> bool: + """ + Test if a language is valid in Kernel Tuner framework. + """ + return value in cls._value2member_map_ \ No newline at end of file diff --git a/kernel_tuner/observers/generic_python.py b/kernel_tuner/observers/generic_python.py new file mode 100644 index 000000000..896b4f9cc --- /dev/null +++ b/kernel_tuner/observers/generic_python.py @@ -0,0 +1,33 @@ +import numpy as np + +from kernel_tuner.observers.observer import BenchmarkObserver + +try: + import torch +except (ImportError, RuntimeError): + torch = None + + +class GenericPythonRuntimeObserver(BenchmarkObserver): + """Observer that measures time using CUDA events during benchmarking. + TODO think about: do we need torch? Does CUDA work for all languages?""" + + def __init__(self, dev): + if torch is None: + raise ImportError("Unable to load torch") + + self.dev = dev + self.stream = dev.stream + self.start = dev.start + self.end = dev.end + self.times = [] + + def after_finish(self): + # Time is measured in milliseconds + event_elapsed_time = self.start.elapsed_time(self.end) + self.times.append(event_elapsed_time) + + def get_results(self): + results = {"time": np.average(self.times), "times": self.times.copy()} + self.times = [] + return results \ No newline at end of file From 542f5bae3b2bf29df9926dc2a22d770d7b5e71a8 Mon Sep 17 00:00:00 2001 From: "I.C. van Ooijen" Date: Fri, 6 Feb 2026 15:20:16 +0100 Subject: [PATCH 03/14] Working implementation tested with basic kernels. Might need some refactoring. Definately needs comments --- .../generic_python/matmul/triton_matmul.py | 154 ++++++++++ examples/generic_python/numba_vec_add.py | 68 ++++ examples/generic_python/tilus_naive_matmul.py | 124 ++++++++ .../generic_python/tilus_splitk_matmul.py | 267 ++++++++++++++++ .../generic_python/tilus_tunable_precision.py | 120 ++++++++ .../tilus_vec_add.py} | 64 ++-- examples/generic_python/triton_vec_add.py | 76 +++++ examples/generic_python/warp_vec_add.py | 98 ++++++ examples/triton/vec_add.py | 46 +-- kernel_tuner/backends/generic_python.py | 290 ++++++++++++++++++ kernel_tuner/backends/pycuda.py | 2 + kernel_tuner/backends/tilus.py | 194 ------------ kernel_tuner/core.py | 149 +++++---- kernel_tuner/interface.py | 43 ++- kernel_tuner/kernel_sources/kernel_source.py | 86 +++--- .../kernel_sources/kernel_source_fn.py | 167 ++++++++-- .../kernel_sources/kernel_source_str.py | 6 +- kernel_tuner/language.py | 4 +- kernel_tuner/observers/generic_python.py | 3 +- kernel_tuner/util.py | 2 +- test/context.py | 9 + test/test_backend.py | 8 +- test/test_generic_python_functions.py | 116 +++++++ test/test_kernel_source_fn.py | 160 ++++++++++ 24 files changed, 1866 insertions(+), 390 deletions(-) create mode 100644 examples/generic_python/matmul/triton_matmul.py create mode 100644 examples/generic_python/numba_vec_add.py create mode 100644 examples/generic_python/tilus_naive_matmul.py create mode 100644 examples/generic_python/tilus_splitk_matmul.py create mode 100644 examples/generic_python/tilus_tunable_precision.py rename examples/{tilus/vec_add.py => generic_python/tilus_vec_add.py} (53%) create mode 100644 examples/generic_python/triton_vec_add.py create mode 100644 examples/generic_python/warp_vec_add.py create mode 100644 kernel_tuner/backends/generic_python.py delete mode 100644 kernel_tuner/backends/tilus.py create mode 100644 test/test_generic_python_functions.py create mode 100644 test/test_kernel_source_fn.py diff --git a/examples/generic_python/matmul/triton_matmul.py b/examples/generic_python/matmul/triton_matmul.py new file mode 100644 index 000000000..7115ea5cc --- /dev/null +++ b/examples/generic_python/matmul/triton_matmul.py @@ -0,0 +1,154 @@ +import torch + +import triton +import triton.language as tl + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +def get_cuda_autotune_config(): + return [ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + # Good config for fp8 inputs. + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4) + ] + +''' +@triton.autotune( + configs=get_cuda_autotune_config(), + key=['M', 'N', 'K'], +) +''' +@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ----------------------------------------------------------- + # Add some integer bound assumptions. + # This helps to guide integer analysis in the backend to optimize + # load/store offset address calculation + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + matmul_kernel[grid]( + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1) + ) + return c + + diff --git a/examples/generic_python/numba_vec_add.py b/examples/generic_python/numba_vec_add.py new file mode 100644 index 000000000..3732aa018 --- /dev/null +++ b/examples/generic_python/numba_vec_add.py @@ -0,0 +1,68 @@ +import numpy as np +from numba import cuda +from math import ceil +from kernel_tuner import tune_kernel, run_kernel + +#@cuda.jit +def f(a, b, c): + tid = cuda.grid(1) + size = len(c) + + if tid < size: + c[tid] = a[tid] + b[tid] + + +def call_numba(kernel_function, args, kwargs, grid, threads, params): + kernel_function[grid[0], threads[0]](*args, **kwargs) + + +def verify(answer, result_host, atol): + correct = True + for i, ans in enumerate(answer): + if ans is None: + continue + res = result_host[i].copy_to_host() + if not np.allclose(ans, res, atol=atol): + correct = False + + return correct + + +N = 100000000 +a = cuda.to_device(np.random.random(N)) +b = cuda.to_device(np.random.random(N)) +c = cuda.device_array_like(a) +c_expect = a.copy_to_host() + b.copy_to_host() + +args = [a, b, c] +tune_params = {"block_size_x": [2**i for i in range(10)]} + +results = run_kernel( + kernel_name="f", + kernel_source=f, + problem_size=N, + arguments=args, + params={"block_size_x": 32}, + lang="generic_python", + call_function=call_numba, + decorator="@cuda.jit" +) + + +print(np.allclose(results[2], c_expect)) + +''' +results, env = tune_kernel( + kernel_name="f", + kernel_source=f, + problem_size=N, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=[None, None, c_expect], + verify=verify, + call_function=call_numba, + decorator="@cuda.jit" +) +''' + diff --git a/examples/generic_python/tilus_naive_matmul.py b/examples/generic_python/tilus_naive_matmul.py new file mode 100644 index 000000000..2f358ef48 --- /dev/null +++ b/examples/generic_python/tilus_naive_matmul.py @@ -0,0 +1,124 @@ +from tilus import float16, float32, int32 +from tilus.utils import cdiv +import tilus +from kernel_tuner import tune_kernel, run_kernel +import math +import torch + + +class MatmulV0(tilus.Script): + def __init__(self): + super().__init__() + # we define three hyperparameters: ``block_m``, ``block_n``, and ``block_k`` to determine the tile size on + # m, n, and k dimensions for each `thread block` of the kernel. + self.block_m = 64 + self.block_n = 64 + self.block_k = 16 + + def __call__( + self, + m_size: int32, # the size of the m dimension of the input matrix A and output matrix C + n_size: int, # the size of the n dimension of the input matrix B and output matrix C + k_size: int, # the size of the k dimension of the input matrix A and B + a_ptr: ~float16, # the pointer to the input matrix A, which is a 2D tensor of shape [m_size, k_size] + b_ptr: ~float16, # the pointer to the input matrix B, which is a 2D tensor of shape [k_size, n_size] + c_ptr: ~float16, # the pointer to the output matrix C, which is a 2D tensor of shape [m_size, n_size] + ): + self.attrs.blocks = [ + cdiv(m_size, self.block_m), # the x dimension size of the grid + cdiv(n_size, self.block_n), # the y dimension size of the grid + ] + self.attrs.warps = 1 # the number of warps per thread block, must be a compile-time known integer + + # define two int32 variables to store the offsets of the m and n dimensions for the current thread block. + offset_m: int32 = self.block_m * self.blockIdx.x + offset_n: int32 = self.block_n * self.blockIdx.y + + # create two global tensors `ga` and `gb` to represent the input matrices A and B, respectively. + ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) + gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size]) + + # create a register tensor `acc` to accumulate the results of the matrix multiplication. + acc = self.register_tensor( + dtype=float32, shape=[self.block_m, self.block_n], init=0.0 + ) + + # iterate over the k dimension in blocks of size `block_k`. + for k in range(cdiv(k_size, self.block_k)): + # calculate the offset for the current block in the k dimension + offset_k = k * self.block_k + + # load a block of matrix A and B into register tensors `a` and `b`. + a = self.load_global( + ga, offsets=[offset_m, offset_k], shape=[self.block_m, self.block_k] + ) + b = self.load_global( + gb, offsets=[offset_k, offset_n], shape=[self.block_k, self.block_n] + ) + + # perform the dot product: acc = a @ b + acc + self.dot(a, b, acc, out=acc) + + # after the loop, we cast the accumulated result `acc` to float16 type and store it back to the output matrix C. + acc_f16 = self.cast(acc, dtype=float16) + gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) + self.store_global(gc, acc_f16, offsets=[offset_m, offset_n]) + + +def call_tilus(kernel_function, args, kwargs, grid, threads, params): + kernel_function(*args, **kwargs) + +def main(): + m, n, k = 4096, 4096, 4096 + + # create an instance of the kernel we have just defined + matmul = MatmulV0() + + a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + + ''' + torch.cuda.synchronize() + # launch the kernel by passing required arguments + matmul(m, n, k, a, b, c_actual) + torch.cuda.synchronize() + + # check correctness + torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2) + ''' + + args = [m, n, k, a, b, c_actual] + tune_params = dict() + tune_params["block_m"] = [16, 32, 64, 128, 256] + tune_params["block_n"] = [16, 32, 64, 128, 256] + tune_params["block_k"] = [16, 32, 64, 128, 256] + + + restrictions = [ + "block_m * block_n <= 4096", + "block_m * block_k <= 2048", + "block_k * block_n <= 4096", + "block_k >= 16", + ] + + results, env = tune_kernel( + kernel_name="MatmulV0", + kernel_source=MatmulV0, + problem_size=[m, n], + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=[None, None, None, None, None, c_expect.cpu()], + restrictions=restrictions, + block_size_names=["block_m", "block_n", "block_k"], + call_function=call_tilus, + strategy="simulated_annealing" + ) + + + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/generic_python/tilus_splitk_matmul.py b/examples/generic_python/tilus_splitk_matmul.py new file mode 100644 index 000000000..f0dff1ac3 --- /dev/null +++ b/examples/generic_python/tilus_splitk_matmul.py @@ -0,0 +1,267 @@ +import tilus +from tilus import float16, float32, int32 +from tilus.utils import cdiv, benchmark_func +import torch +import math +from kernel_tuner import tune_kernel, run_kernel + +''' +@tilus.autotune("num_warps", [4, 8]) +@tilus.autotune("block_m, block_n", [(128, 128), (128, 64), (64, 128), (32, 256)]) +@tilus.autotune("block_k", [16, 32]) +@tilus.autotune("num_stages", [3, 4, 5]) +@tilus.autotune("split_k_factor", [1, 4, 12, 16]) +''' +class MatmulV5(tilus.Script): + ''' + def __init__(self, block_m, block_n, block_k, num_warps, num_stages, split_k_factor): + super().__init__() + self.block_m = block_m + self.block_n = block_n + self.block_k = block_k + self.num_warps = num_warps + self.num_stages = num_stages + self.split_k_factor = split_k_factor + ''' + def __init__(self): + super().__init__() + self.block_m = 128 + self.block_n = 128 + self.block_k = 16 + self.num_warps = 4 + self.num_stages = 4 + self.split_k_factor = 4 + + def __call__( + self, + m_size: int32, + n_size: int, + k_size: int, + a_ptr: ~float16, + b_ptr: ~float16, + c_ptr: ~float16, + ): + self.attrs.blocks = [ + cdiv(m_size, self.block_m), + cdiv(n_size, self.block_n), + self.split_k_factor, + ] + self.attrs.warps = self.num_warps + + # the k_size for each thread block + block_k_size = ( + cdiv(cdiv(k_size, self.split_k_factor), self.block_k) * self.block_k + ) + start_offset_k = self.blockIdx.z * block_k_size + end_offset_k = min(start_offset_k + block_k_size, k_size) + + block_m, block_n, block_k = self.block_m, self.block_n, self.block_k + offset_m: int32 = block_m * self.blockIdx.x + offset_n: int32 = block_n * self.blockIdx.y + + ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) + gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size]) + sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k]) + sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_k, block_n]) + acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) + + for stage in range(self.num_stages - 1): + offset_k = start_offset_k + stage * self.block_k + self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k]) + self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n]) + self.copy_async_commit_group() + + self.copy_async_wait_group(n=self.num_stages - 2) + self.sync() + + current_stage: int32 = 0 + preload_stage: int32 = self.num_stages - 1 + for offset_k in self.range( + start_offset_k, end_offset_k, block_k, unroll=self.num_stages + ): + # computation for current tile + a = self.load_shared(sa[current_stage]) + b = self.load_shared(sb[current_stage]) + self.dot(a, b, acc, out=acc) + + # preload the next tile of A and B into shared memory + preload_offset_k = offset_k + (self.num_stages - 1) * block_k + if preload_offset_k < end_offset_k: + self.copy_async( + src=ga, + dst=sa[preload_stage], + offsets=[offset_m, preload_offset_k], + ) + self.copy_async( + src=gb, + dst=sb[preload_stage], + offsets=[preload_offset_k, offset_n], + ) + self.copy_async_commit_group() + + # update the stage + current_stage = (current_stage + 1) % self.num_stages + preload_stage = (preload_stage + 1) % self.num_stages + self.copy_async_wait_group(n=self.num_stages - 2) + self.sync() + + # free the shared memory tensors for A and B + self.free_shared(sa) + self.free_shared(sb) + + # cast the accumulator to float16 and change the register tensor's layout + sc = self.shared_tensor(dtype=float16, shape=[block_m, block_n]) + casted_acc = self.cast(acc, dtype=float16) + self.store_shared(sc, casted_acc) + self.sync() + rc = self.load_shared(sc) + self.free_shared(sc) + + m_blocks, n_blocks = cdiv(m_size, block_m), cdiv(n_size, block_n) + gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) + if self.split_k_factor == 0: + self.store_global(gc, rc, offsets=[offset_m, offset_n]) + else: + semaphores = self.global_tensor( + dtype=int32, shape=[m_blocks, n_blocks], requires_clean=True + ) + semaphore: ~int32 = ~semaphores[self.blockIdx.x, self.blockIdx.y] + + # load and accumulate the partial result in global memory + if self.blockIdx.z > 0: + self.lock_semaphore(semaphore, value=self.blockIdx.z) + partial_rc = self.load_global( + gc, offsets=[offset_m, offset_n], shape=[block_m, block_n] + ) + self.add(rc, partial_rc, out=rc) + + # store the result to global memory and release the semaphore + self.store_global(gc, rc, offsets=[offset_m, offset_n]) + + # release the semaphore + self.sync() # we need to make sure the previous store_global is finished + self.release_semaphore( + semaphore, value=(self.blockIdx.z + 1) % self.split_k_factor + ) + + +def call_tilus(kernel_function, args, kwargs, grid, threads, params): + kernel_function(*args, **kwargs) + + +def without_kernel_tuner(): + tilus.option.clear_cache = True + + tilus.option.verbose_autotune = True + m, n, k = 4096, 4096, 4096 + + # create an instance of the kernel we have just defined + matmul = MatmulV5() + + a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + + + torch.cuda.synchronize() + # launch the kernel by passing required arguments + matmul(m, n, k, a, b, c_actual) + torch.cuda.synchronize() + + # check correctness + torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2) + + + import pandas + rows = [] + headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] + # benchmark + for name, func in [ + ("torch", lambda: torch.matmul(a, b, out=c_expect)), + ("tilus", lambda: matmul(m, n, k, a, b, c_actual)), + ]: + latency = benchmark_func(func, warmup=5, repeat=20) + tflops = 2 * m * n * k / latency * 1e-9 + rows.append([m, n, k, name, latency, tflops]) + + df = pandas.DataFrame(rows, columns=headers) + print(df) + + + +#best performing configuration: +#block_m=128, block_n=64, block_k=16, num_warps=4, num_stages=4, split_k_factor=1, time=2.027ms + +def main(): + m, n, k = 4096, 4096, 4096 + + # create an instance of the kernel we have just defined + #matmul = MatmulV5() + + a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + + ''' + torch.cuda.synchronize() + # launch the kernel by passing required arguments + matmul(m, n, k, a, b, c_actual) + torch.cuda.synchronize() + + # check correctness + torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2) + ''' + + args = [m, n, k, a, b, c_actual] + tune_params = dict() + tune_params["block_m"] = [32, 64, 128] #[16, 32, 64, 128, 256] + tune_params["block_n"] = [32, 64, 128] #[16, 32, 64, 128, 256] + tune_params["block_k"] = [16, 32] #[16, 32, 64, 128, 256] + tune_params["num_warps"] = [4, 8] #[2, 4, 8, 16] + tune_params["num_stages"] = [3, 4, 5] #[2, 3, 4, 5, 6] + tune_params["split_k_factor"] = [1, 4, 12, 16] #[1, 4, 12, 16, 20] + +#@tilus.autotune("num_warps", [4, 8]) +#@tilus.autotune("block_m, block_n", [(128, 128), (128, 64), (64, 128), (32, 256)]) +#@tilus.autotune("block_k", [16, 32]) +#@tilus.autotune("num_stages", [3, 4, 5]) +#@tilus.autotune("split_k_factor", [1, 4, 12, 16]) + + ''' + restrictions = [ + "block_m * block_n <= 4096", + "block_m * block_k <= 2048", + "block_k * block_n <= 4096", + "block_k >= 16", + "2 * (num_stages * block_k * (block_m + block_n) + block_m * block_n) <= 65536", # shared mem + "num_warps * 32 <= 1024", + "block_k * split_k_factor <= 4096", + ] + ''' + restrictions = ["block_m * block_n >= 8192", "block_m * block_n <= 16384"] + + + + results, env = tune_kernel( + kernel_name="MatmulV5", # This has to be a string of the actual name. TODO is this always the case? + kernel_source=MatmulV5, + problem_size=[m, n], + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=[None, None, None, None, None, c_expect.cpu()], + atol=1e-2, + restrictions=restrictions, + block_size_names=["block_m", "block_n", "block_k"], + call_function=call_tilus, + #strategy="random_sample", + ) + + + + +if __name__ == "__main__": + main() + #without_kernel_tuner() \ No newline at end of file diff --git a/examples/generic_python/tilus_tunable_precision.py b/examples/generic_python/tilus_tunable_precision.py new file mode 100644 index 000000000..1cc4a7499 --- /dev/null +++ b/examples/generic_python/tilus_tunable_precision.py @@ -0,0 +1,120 @@ +import tilus +import hidet +from hidet import float32, float16, int32 +#from tilus import float32, int32, float16 +from tilus.utils import cdiv, benchmark_func +import torch +from kernel_tuner import tune_kernel, run_kernel +from kernel_tuner.accuracy import Tunable, AccuracyObserver +import numpy as np + +INPUT_TYPE = float32 +OUTPUT_TYPE = float32 + +class VecAddV(tilus.Script): + def __init__(self): + super().__init__() + self.block_size_x = 32 # number of threads per block + + def __call__( + self, + n_size: int32, # size of the vectors + a_ptr: ~INPUT_TYPE, # input vector A + b_ptr: ~INPUT_TYPE, # input vector B + c_ptr: ~OUTPUT_TYPE # output vector C + ): + + # compute the number of blocks needed + self.attrs.blocks = [cdiv(n_size, self.block_size_x)] + self.attrs.warps = 1 # number of warps per block + + # calculate the offset for this block + offset: int32 = self.block_size_x * self.blockIdx.x + + # create global views for input/output vectors + ga = self.global_view(a_ptr, dtype=INPUT_TYPE, shape=[n_size]) + gb = self.global_view(b_ptr, dtype=INPUT_TYPE, shape=[n_size]) + gc = self.global_view(c_ptr, dtype=OUTPUT_TYPE, shape=[n_size]) + + a = self.load_global(ga, offsets=[offset], shape=[self.block_size_x]) + b = self.load_global(gb, offsets=[offset], shape=[self.block_size_x]) + c = a + b + self.store_global(gc, c, offsets=[offset]) + + +def call_tilus(kernel_function, args, kwargs, grid, threads, params): + kernel_function(*args, **kwargs) + + +def verify(answer, result_host, atol): + correct = True + for i, ans in enumerate(answer): + if ans is None: + continue + res = result_host[i].cpu() + if not torch.allcose(ans, res, atol=atol): + correct = False + + return correct + + + +def main(): + + size = 1024000 + + a_32 = torch.randn(size, dtype=torch.float32) + b_32 = torch.randn(size, dtype=torch.float32) + c_32 = torch.zeros_like(b_32) + c_expect = a_32 + b_32 + + a_16 = a_32.to(torch.float16) + b_16 = b_32.to(torch.float16) + c_16 = c_32.to(torch.float16) + + + tune_params = dict() + tune_params["block_size_x"] = [32, 64, 128, 256, 512, 1024] + tune_params["INPUT_TYPE"] = [tilus.float16, tilus.float32] + tune_params["OUTPUT_TYPE"] = [tilus.float16, tilus.float32] + + args = [ + size, + Tunable("INPUT_TYPE", { + tilus.float32: a_32, + tilus.float16: a_16, + }), + Tunable("INPUT_TYPE", { + tilus.float32: b_32, + tilus.float16: b_16, + }), + Tunable("OUTPUT_TYPE", { + tilus.float32: c_32, + tilus.float16: c_16, + }), + ] + + print(tune_params) + + observers = [AccuracyObserver("RMSE")] + + + + results, env = tune_kernel( + kernel_name="VecAddV", + kernel_source=VecAddV, + problem_size=size, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=[None, None, None, c_expect.cpu()], + observers=observers, + call_function=call_tilus, + verify=verify, + verbose=True, + ) + + + +if __name__ == "__main__": + main() diff --git a/examples/tilus/vec_add.py b/examples/generic_python/tilus_vec_add.py similarity index 53% rename from examples/tilus/vec_add.py rename to examples/generic_python/tilus_vec_add.py index 60dcc6074..dce785909 100644 --- a/examples/tilus/vec_add.py +++ b/examples/generic_python/tilus_vec_add.py @@ -2,12 +2,14 @@ from tilus import float32, int32 from tilus.utils import cdiv, benchmark_func import torch +from kernel_tuner import tune_kernel, run_kernel + class VecAddV(tilus.Script): def __init__(self): super().__init__() - self.block_size = 256 # number of threads per block + self.block_size_x = 32 # number of threads per block def __call__( self, @@ -15,53 +17,59 @@ def __call__( a_ptr: ~float32, # input vector A b_ptr: ~float32, # input vector B c_ptr: ~float32 # output vector C - ): + ): + # compute the number of blocks needed - self.attrs.blocks = [cdiv(n_size, self.block_size)] + self.attrs.blocks = [cdiv(n_size, self.block_size_x)] self.attrs.warps = 1 # number of warps per block # calculate the offset for this block - offset: int32 = self.block_size * self.blockIdx.x + offset: int32 = self.block_size_x * self.blockIdx.x # create global views for input/output vectors ga = self.global_view(a_ptr, dtype=float32, shape=[n_size]) gb = self.global_view(b_ptr, dtype=float32, shape=[n_size]) gc = self.global_view(c_ptr, dtype=float32, shape=[n_size]) - # load a block of A and B into registers - a = self.load_global(ga, offsets=[offset], shape=[self.block_size]) - b = self.load_global(gb, offsets=[offset], shape=[self.block_size]) - - # perform element-wise addition + a = self.load_global(ga, offsets=[offset], shape=[self.block_size_x]) + b = self.load_global(gb, offsets=[offset], shape=[self.block_size_x]) c = a + b - - # store the result back to global memory self.store_global(gc, c, offsets=[offset]) +def call_tilus(kernel_function, args, kwargs, grid, threads, params): + kernel_function(*args, **kwargs) + + def main(): - N = 1 << 20 # vector size - vecadd = VecAddV() - a = torch.rand(N, dtype=torch.float32).cuda() - b = torch.rand(N, dtype=torch.float32).cuda() - c_actual = torch.empty_like(a) + size = 1024000 + + a = torch.randn(size, dtype=torch.float32) + b = torch.randn(size, dtype=torch.float32) + c = torch.zeros_like(b) c_expect = a + b + + + args = [size, a, b, c] + tune_params = dict() + tune_params["block_size_x"] = [32, 64, 128, 256, 512, 1024] - torch.cuda.synchronize() - vecadd(N, a, b, c_actual) - torch.cuda.synchronize() - # correctness check - torch.testing.assert_close(c_expect, c_actual, atol=1e-6, rtol=1e-6) + + results, env = tune_kernel( + kernel_name="VecAddV", + kernel_source=VecAddV, + problem_size=size, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=[None, None, None, c_expect.cpu()], + call_function=call_tilus, + verbose=True, + ) - # benchmark - for name, func in [ - ("torch", lambda: a + b), - ("tilus", lambda: vecadd(N, a, b, c_actual)), - ]: - latency = benchmark_func(func, warmup=5, repeat=20) - print(f"{name} latency: {latency:.3f} ms") + if __name__ == "__main__": main() diff --git a/examples/generic_python/triton_vec_add.py b/examples/generic_python/triton_vec_add.py new file mode 100644 index 000000000..be28f2003 --- /dev/null +++ b/examples/generic_python/triton_vec_add.py @@ -0,0 +1,76 @@ +import numpy as np +import triton.language as tl +import torch +from kernel_tuner import tune_kernel, run_kernel +from kernel_tuner.file_utils import store_output_file, store_metadata_file +import triton +from math import ceil + + +@triton.jit +def add_op(x, y): + return x + y + + +#triton.jit +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + block_size_x: tl.constexpr, # Number of elements each program should process. + # note: `constexpr` so it can be used as a shape value. + ): + pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. + block_start = pid * block_size_x + offsets = block_start + tl.arange(0, block_size_x) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = add_op(x, y) + tl.store(output_ptr + offsets, output, mask=mask) + + +def call_triton(kernel_function, args, kwargs, grid, threads, params): + kernel_function[grid](*args, **kwargs) + +# NOTE: can the python file be changed in between? what happens? +# NOTE: tune params in the funcion signature are supported as key word arguments. Do not pass them as args, these +# will be inserted automatically. You can use them in the call function (kwargs). + +def tune_with_generic(): + size = 10000000 + + a = torch.randn(size, device='cuda', dtype=torch.float32) + b = torch.randn(size, device='cuda', dtype=torch.float32) + c = torch.empty_like(b) + n = torch.tensor(size, dtype=torch.int32) + c_expect = a + b + + args = [a, b, c, size] + tune_params = dict() + tune_params["block_size_x"] = [2**i for i in range(10)] + + ''' + result = run_kernel("add_kernel", add_kernel, size, args, {"block_size_x": 256}, + lang="generic_python", call_function=call_triton, decorator="@triton.jit") + print(np.allclose(c_expect.cpu(), result[2])) + ''' + + results, env = tune_kernel( + kernel_name="add_kernel", + kernel_source=add_kernel, + problem_size=size, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=[None, None, c_expect.cpu(), None], + call_function=call_triton, + decorator="@triton.jit" + ) + + + + + + +tune_with_generic() \ No newline at end of file diff --git a/examples/generic_python/warp_vec_add.py b/examples/generic_python/warp_vec_add.py new file mode 100644 index 000000000..dc08dd6d0 --- /dev/null +++ b/examples/generic_python/warp_vec_add.py @@ -0,0 +1,98 @@ +import warp as wp +import numpy as np +from kernel_tuner import tune_kernel, run_kernel + + + +@wp.func +def add_op(x: float, y: float): + return x + y + + +#@wp.kernel +def vec_add(a: wp.array(dtype=float), + b: wp.array(dtype=float), + c: wp.array(dtype=float), + n: int, + work_per_thread: int): + + tid = wp.tid() + base = tid * work_per_thread + + for i in range(work_per_thread): + idx = base + i + if idx < n: + c[idx] = add_op(a[idx], b[idx]) + + +# TODO do we allways want the call function to have the same parameters +# or do we only require some of them? +def call_warp(kernel_function, args, kwargs, grid, threads, params): + final_args = list(args) + final_args.extend(kwargs.values()) + dim = args[3] + wp.launch(kernel=kernel_function, dim=dim, inputs=final_args) + + +# NOTE default verify function only works for numpy/cupy ndarray, torch Tensor or numpy scalar +# That is why we need a costum verify function for warp. +def verify(answer, result_host, atol): + correct = True + for i, ans in enumerate(answer): + if ans is None: + continue + res = result_host[i].numpy() + if not np.allclose(ans, res, atol=atol): + correct = False + + return correct + + + +def tune(): + n = 1024 + + # Create host arrays + a_np = np.arange(n, dtype=np.float32) + b_np = np.arange(n, 0, -1, dtype=np.float32) + c_np = np.zeros(n, dtype=np.float32) + c_expect = a_np + b_np + + # Create Warp arrays on GPU + a = wp.array(a_np, dtype=float) + b = wp.array(b_np, dtype=float) + c = wp.array(c_np, dtype=float) + + tune_params = dict() + tune_params["work_per_thread"] = [2**i for i in range(10)] + args = [a, b, c, n] + + + ''' + results = run_kernel( + kernel_name="vec_add", + kernel_source=vec_add, + problem_size=n, + arguments=args, + params={"work_per_thread": 16}, + lang="generic_python", + call_function=call_warp, + decorator="@wp.kernel" + ) + ''' + + results, env = tune_kernel( + kernel_name="vec_add", + kernel_source=vec_add, + problem_size=n, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=[None, None, c_expect, None], + verify=verify, + call_function=call_warp, + decorator="@wp.kernel" + ) + + +tune() \ No newline at end of file diff --git a/examples/triton/vec_add.py b/examples/triton/vec_add.py index 555cfba46..42c3e52c7 100644 --- a/examples/triton/vec_add.py +++ b/examples/triton/vec_add.py @@ -3,18 +3,19 @@ import torch from kernel_tuner import tune_kernel, run_kernel from kernel_tuner.file_utils import store_output_file, store_metadata_file +import triton - +#@triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. y_ptr, # *Pointer* to second input vector. output_ptr, # *Pointer* to output vector. n_elements, # Size of the vector. - BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + block_size_x: tl.constexpr, # Number of elements each program should process. # note: `constexpr` so it can be used as a shape value. ): pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) + block_start = pid * block_size_x + offsets = block_start + tl.arange(0, block_size_x) mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) @@ -22,25 +23,28 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. tl.store(output_ptr + offsets, output, mask=mask) -size = 10000000 -a = torch.randn(size, dtype=torch.float32) -b = torch.randn(size, dtype=torch.float32) -c = torch.zeros_like(b) -n = torch.tensor(size, dtype=torch.int32) +def tune_with_triton(): + size = 10000000 + + a = torch.randn(size, device='cuda', dtype=torch.float32) + b = torch.randn(size, device='cuda', dtype=torch.float32) + c = torch.empty_like(b) + n = torch.tensor(size, dtype=torch.int32) + + args = [a, b, c, n] -args = [a, b, c, n] + tune_params = dict() + tune_params["block_size_x"] = [2**i for i in range(10)] -tune_params = dict() -tune_params["block_size_x"] = [2**i for i in range(10)] + results, env = tune_kernel( + kernel_name="add_kernel", + kernel_source=add_kernel, + problem_size=size, + arguments=args, + tune_params=tune_params, + lang="triton" + ) -results, env = tune_kernel( - kernel_name="add_kernel", - kernel_source=add_kernel, - problem_size=size, - arguments=args, - tune_params=tune_params, - lang="triton" -) -print("Hello") \ No newline at end of file +tune_with_triton() \ No newline at end of file diff --git a/kernel_tuner/backends/generic_python.py b/kernel_tuner/backends/generic_python.py new file mode 100644 index 000000000..fbe0f0750 --- /dev/null +++ b/kernel_tuner/backends/generic_python.py @@ -0,0 +1,290 @@ +import logging +import inspect +import copy +import traceback # for compile error handling +import re + +from kernel_tuner.backends.backend import GPUBackend +from kernel_tuner.observers.generic_python import GenericPythonRuntimeObserver + +try: + import torch +except ImportError: + torch = None + + +# TODO delete temp file + + + +class GenericPythonFunctions(GPUBackend): + + def __init__(self, device=0, iterations=7, compiler_options=None, observers=None): + ''' + In here, everyting is generic if the language uses CUDA as backend + ''' + if not torch: + logging.error("Unable to import Torch") + raise ImportError("Unable to import Torch") + + self.device_id = torch.cuda.current_device() + + self.device_properties = torch.cuda.get_device_properties(self.device_id) + self.name = torch.cuda.get_device_name(self.device_id) + self.max_threads = self.device_properties.max_threads_per_multi_processor + + env = dict() + env["device_name"] = self.name + env["max_threads"] = self.max_threads + env["iterations"] = iterations + env["compiler_options"] = compiler_options + self.env = env + + self.stream = torch.cuda.default_stream() + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + + # setup observers + self.observers = observers or [] + self.observers.append(GenericPythonRuntimeObserver(self)) + for obs in self.observers: + obs.register_device(self) + + self.units = {"time": "ms", "power": "s,mW", "energy": "J"} + + # Variables to be filled in at compile time: + self.call_function = None + self.signature = None + self.gpu_kwargs = None + + super().__init__(device=device, iterations=iterations, compiler_options=compiler_options, observers=observers) + + def ready_argument_list(self, arguments): + ''' + The user already supplies the arguments in the correct format, because we are working with + a Python based language anyway. TODO probably only works with torch and numpy? + ''' + return copy.deepcopy(arguments) + + def compile(self, kernel_instance, gpu_args=None): + logging.debug("Compiling Generic Python kernel") + + if kernel_instance.kernel_fn is None: + raise ValueError("kernel_fn is None, currently Generic Python only supports callable kernel_source") + + if gpu_args is None: + raise ValueError("gpu_args is None, Generic Python needs gpu args to compile the kernel") + + # The first time we compile, we also set the call function and the signature + if self.call_function is None or self.signature is None: + self.call_function = kernel_instance.kernel_source.call_function + self.signature = kernel_instance.kernel_source.signature + + grid = kernel_instance.grid + threads = kernel_instance.threads + params = kernel_instance.params + if inspect.isclass(kernel_instance.kernel_fn): + kernel_function = kernel_instance.kernel_fn() + elif callable(kernel_instance.kernel_fn): + # Handles functions and decroators that return callable objects + kernel_function = kernel_instance.kernel_fn + else: + raise TypeError("kernel function is not a class or function") + + self.gpu_kwargs = self.build_gpu_kwargs(params) + + # Call the jit function in order to compile it + self.synchronize() + self.call_function(kernel_function, gpu_args, self.gpu_kwargs, grid, threads, params) + self.synchronize() + + + return kernel_function + + def start_event(self): + logging.debug("Start Generic Python event") + self.start.record() + + def stop_event(self): + logging.debug("Stop Generic Python event") + self.end.record() + + def kernel_finished(self): + logging.debug("Checking if kernel has finished") + return self.end.query() + + def run_kernel(self, func, gpu_args, threads, grid, stream=None, params=None): + + # Run the kernel + if stream is None: + stream = self.stream + + with torch.cuda.stream(stream): + logging.debug("Running Generic Python kernel") + self.call_function(func, gpu_args, self.gpu_kwargs, grid, threads, params) + + + + def build_gpu_kwargs(self, params=None): + gpu_kwargs = {} + + if params is None: + return gpu_kwargs + + for name, p in self.signature.parameters.items(): + if name in params: + gpu_kwargs[name] = params[name] + + return gpu_kwargs + + + def synchronize(self): + torch.cuda.synchronize() + + def memset(self, allocation, value, size): + pass + + def memcpy_dtoh(self, dest, src): + pass + + def memcpy_htod(self, dest, src): + pass + + def copy_constant_memory_args(self, cmem_args): + raise NotImplementedError("Generic Python does not support constant memory") + + def copy_shared_memory_args(self, smem_args): + raise NotImplementedError("Generic Python does not support shared memory") + + def copy_texture_memory_args(self, texmem_args): + raise NotImplementedError("Generic Python does not support texture memory") + + def refresh_memory(self, gpu_memory, host_arguments, should_sync): + """Refresh the GPU memory with the untouched host arguments. We overwrite the standard function + because Python DSLs do usually do not manage memory explicitely""" + for i, arg in enumerate(host_arguments): + if should_sync[i]: + gpu_memory[i] = copy.deepcopy(arg) + + + def classify_compile_exception(self, e): + """Best effort to differentiate between a user error and a resource error. Input is Exception""" + + RESOURCE_KEYWORDS = ( + # Shared memory + "shared memory", + "smem", + "uses too much shared", + "exceeds shared memory", + "shared memory limit", + + # Registers / occupancy + "too many registers", + "register spill", + "uses too many registers", + "out of registers", + + # Launch configuration + "invalid launch configuration", + "invalid configuration argument", + "threads per block", + "block size", + "grid size", + "num_warps", + "num_ctas", + + # Generic resource exhaustion + "too many resources", + "out of resources", + "exceeds maximum", + "exceeds limit", + + # Compiler-level indicators + "ptxas error", + "ptxas fatal", + "nvcc error", + "cuda error", + "llvm error", + "mlir error", + "lowering failed", + ) + + USER_ERROR_KEYWORDS = ( + # Undefined / missing symbols + "not defined", + "undefined variable", + "without definition", + "unknown variable", + "unbound", + + # Type / shape errors + "type mismatch", + "invalid type", + "cannot convert", + "expected .* but got", + "incompatible types", + + # Indexing / bounds + "index out of bounds", + "out of bounds access", + "invalid index", + + # IR / AST construction + "failed to build", + "invalid expression", + "malformed", + "illegal operation", + ) + + RESOURCE_ORIGINS = ( + "ptxas", + "nvcc", + "cuda", + "llvm", + "mlir", + "cubin", + "fatbin", + ) + + USER_ORIGINS = ( + "transpiler", + "scheduler", + "hidet", + "frontend", + "ast", + ) + + + USER_ERROR_TYPES = ( + NameError, + UnboundLocalError, + AttributeError, + TypeError, + SyntaxError, + IndentationError, + ) + + if isinstance(e, USER_ERROR_TYPES): + return "user_error" + + + def match_any(patterns, text): + return any(re.search(p, text) for p in patterns) + + msg = str(e).lower() + tb = "".join(traceback.format_tb(e.__traceback__)).lower() + + + if match_any(RESOURCE_KEYWORDS, msg): + return "resource_error" + + if match_any(RESOURCE_ORIGINS, msg + tb): + return "resource_error" + + if match_any(USER_ERROR_KEYWORDS, msg): + return "user_error" + + if match_any(USER_ORIGINS, msg + tb): + return "user_error" + + return "unknown" diff --git a/kernel_tuner/backends/pycuda.py b/kernel_tuner/backends/pycuda.py index c8f3e689a..d5d9c49ad 100644 --- a/kernel_tuner/backends/pycuda.py +++ b/kernel_tuner/backends/pycuda.py @@ -375,6 +375,7 @@ def memset(self, allocation, value, size): drv.memset_d8(allocation, value, size) def memcpy_dtoh(self, dest, src): + print("dtoh called") """Perform a device to host memory copy. :param dest: A numpy array in host memory to store the data @@ -389,6 +390,7 @@ def memcpy_dtoh(self, dest, src): dest[:] = src def memcpy_htod(self, dest, src): + print("htod called") """Perform a host to device memory copy. :param dest: A GPU memory allocation unit diff --git a/kernel_tuner/backends/tilus.py b/kernel_tuner/backends/tilus.py deleted file mode 100644 index d3782fb0d..000000000 --- a/kernel_tuner/backends/tilus.py +++ /dev/null @@ -1,194 +0,0 @@ -import logging -import numpy as np -import inspect - -from kernel_tuner.backends.backend import GPUBackend -from kernel_tuner.observers.generic_python import GenericPythonRuntimeObserver - -try: - import torch -except ImportError: - logging.error("Torch not available") - -try: - import tilus -except ImportError: - tilus = None - logging.error("Unable to load Tilus") - - -class TilusFunctions(GPUBackend): - - def __init__(self, device=0, iterations=7, compiler_options=None, observers=None): - ''' - In here, everyting is generic if the language uses Torch as backend - ''' - if not tilus or not torch: - logging.error("Tilus or Torch not available") - raise ImportError("Tilus or Torch not available") - - self.device_id = torch.cuda.current_device() - - self.device_properties = torch.cuda.get_device_properties(self.device_id) - self.name = torch.cuda.get_device_name(self.device_id) - self.max_threads = self.device_properties.max_threads_per_multi_processor - - env = dict() - env["device_name"] = self.name - env["max_threads"] = self.max_threads - env["iterations"] = iterations - env["compiler_options"] = compiler_options - self.env = env - - self.stream = torch.cuda.default_stream() - self.start = torch.cuda.Event(enable_timing=True) - self.end = torch.cuda.Event(enable_timing=True) - - # setup observers - self.observers = observers or [] - self.observers.append(GenericPythonRuntimeObserver(self)) - for obs in self.observers: - obs.register_device(self) - - self.units = {"time": "ms", "power": "s,mW", "energy": "J"} - - super().__init__(device=device, iterations=iterations, compiler_options=compiler_options, observers=observers) - - def ready_argument_list(self, arguments): - ''' - This seems to work for Torch as backend. However, another idea for generic would be - to let the user provide two argument lists: one with numpy and one with the actual arguments - that will be passed to the kernel. TODO: why do we need the numpy list anyway? - ''' - torch_args = [] - - for arg in arguments: - if isinstance(arg, torch.Tensor) and arg.dim() > 0: - torch_args.append(arg.cuda()) - elif isinstance(arg, torch.Tensor) and arg.dim() == 0: - scalar_value = arg.item() - torch_args.append(scalar_value) - elif isinstance(arg, np.ndarray): - torch_arg = torch.from_numpy(arg) - torch_arg_gpu = torch_arg.cuda() - torch_args.append(torch_arg_gpu) - elif isinstance(arg, np.generic): - scalar_value = arg.item() - torch_args.append(scalar_value) - else: - logging.warning("Unknown instance in Tilus functions") - - return torch_args - - def compile(self, kernel_instance, gpu_args=None): - logging.debug("Compiling Tilus kernel") - - if kernel_instance.kernel_fn is None: - raise ValueError("kernel_fn is None, currently Tilus only supports callable kernel_source") - - ''' - if gpu_args is None: - raise ValueError("gpu_args is None, Triton needs gpu args to compile the kernel") - TODO: find out about these GPU args - ''' - - grid = kernel_instance.grid # TODO might need to use this? - threads = kernel_instance.threads - params = kernel_instance.params - kernel_function = kernel_instance.kernel_fn - gpu_kwargs = self.build_gpu_kwargs(kernel_function, threads, params) - - # Call the jit function in order to compile it - kernel_function(*gpu_args, **gpu_kwargs) - - return kernel_function - - def start_event(self): - logging.debug("Start triton event") - self.start.record() - - def stop_event(self): - logging.debug("Stop triton event") - self.end.record() - - def kernel_finished(self): - logging.debug("Checking if kernel has finished") - return self.end.query() - - def run_kernel(self, func, gpu_args, threads, grid, stream=None, params=None): - ''' - if params is None: - raise ValueError("params is None, Triton needs params in order to set num_warps, num_ctas, etc.") - ''' - - # Run the kernel - if stream is None: - stream = self.stream - - gpu_kwargs = self.build_gpu_kwargs(func, threads, params) - - with torch.cuda.stream(stream): - logging.debug("Running Tilus kernel") - func(*gpu_args, **gpu_kwargs) - - - - def build_gpu_kwargs(self, kernel_function, threads, params=None): - gpu_kwargs = {} - - ''' - if 'BLOCK_SIZE' in jit_fn.arg_names: - gpu_kwargs['BLOCK_SIZE'] = threads[0] - - if 'BLOCK_SIZE_X' in jit_fn.arg_names: - gpu_kwargs['BLOCK_SIZE_X'] = threads[0] - - if 'BLOCK_SIZE_Y' in jit_fn.arg_names: - gpu_kwargs['BLOCK_SIZE_Y'] = threads[1] - - if 'BLOCK_SIZE_Z' in jit_fn.arg_names: - gpu_kwargs['BLOCK_SIZE_Z'] = threads[2] - ''' - - if params is None: - return gpu_kwargs - - - sig = inspect.signature(kernel_function) - for name, p in sig.parameters.items(): - if name in params: - gpu_kwargs[name] = params[name] - - ''' - # Check for Triton specific parameters - if 'num_warps' in params: - gpu_kwargs['num_warps'] = params['num_warps'] - if 'num_ctas' in params: - gpu_kwargs['num_ctas'] = params['num_ctas'] - if 'num_stages' in params: - gpu_kwargs['num_stages'] = params['num_stages'] - ''' - - return gpu_kwargs - - - def synchronize(self): - torch.cuda.synchronize() - - def memset(self, allocation, value, size): - pass - - def memcpy_dtoh(self, dest, src): - pass - - def memcpy_htod(self, dest, src): - pass - - def copy_constant_memory_args(self, cmem_args): - raise NotImplementedError("Tilus does not support constant memory") - - def copy_shared_memory_args(self, smem_args): - raise NotImplementedError("Tilus does not support shared memory") - - def copy_texture_memory_args(self, texmem_args): - raise NotImplementedError("Tilus does not support texture memory") \ No newline at end of file diff --git a/kernel_tuner/core.py b/kernel_tuner/core.py index 3b545d459..7705d828c 100644 --- a/kernel_tuner/core.py +++ b/kernel_tuner/core.py @@ -22,11 +22,12 @@ from kernel_tuner.backends.opencl import OpenCLFunctions from kernel_tuner.backends.pycuda import PyCudaFunctions from kernel_tuner.backends.triton import TritonFunctions -from kernel_tuner.backends.tilus import TilusFunctions +from kernel_tuner.backends.generic_python import GenericPythonFunctions from kernel_tuner.kernel_sources.kernel_source import KernelSource from kernel_tuner.observers.nvml import NVMLObserver from kernel_tuner.observers.observer import ContinuousObserver, OutputObserver, PrologueObserver from kernel_tuner.observers.tegra import TegraObserver +from kernel_tuner.language import Language try: import torch @@ -58,12 +59,12 @@ class KernelInstance(_KernelInstance): """Class that represents the specific parameterized instance of a kernel.""" def __new__(cls, *args, **kwargs): - # Detect old-style calls (without kernel_fn) + # Detect old-style calls (without kernel_fn and gpu_kwargs) for old tests if len(args) == 8: # old version name, kernel_source, kernel_string, temp_files, threads, grid, params, arguments = args kernel_fn = None args = (name, kernel_source, kernel_string, kernel_fn, temp_files, threads, grid, params, arguments) - elif "kernel_fn" not in kwargs and len(args) == 8: + elif "kernel_fn" not in kwargs and len(args) < 9: kwargs["kernel_fn"] = None return super().__new__(cls, *args, **kwargs) @@ -76,7 +77,7 @@ def delete_temp_files(self): def prepare_temp_files_for_error_msg(self): """Prepare temp file with source code, and return list of temp file names.""" if type(self.kernel_source).__name__ == "KernelSourceFn": - return [] # TODO what do we want to return here? + return [] # already done during compilation temp_filename = util.get_temp_filename(suffix=self.kernel_source.get_suffix()) util.write_file(temp_filename, self.kernel_string) @@ -136,28 +137,28 @@ def __init__( logging.debug("DeviceInterface instantiated, lang=%s", lang) - if lang.upper() == "CUDA": + if lang == Language.CUDA: dev = PyCudaFunctions( device, compiler_options=compiler_options, iterations=iterations, observers=observers, ) - elif lang.upper() == "CUPY": + elif lang == Language.CUPY: dev = CupyFunctions( device, compiler_options=compiler_options, iterations=iterations, observers=observers, ) - elif lang.upper() == "NVCUDA": + elif lang == Language.NVCUDA: dev = CudaFunctions( device, compiler_options=compiler_options, iterations=iterations, observers=observers, ) - elif lang.upper() == "OPENCL": + elif lang == Language.OPENCL: dev = OpenCLFunctions( device, platform, @@ -165,35 +166,35 @@ def __init__( iterations=iterations, observers=observers, ) - elif lang.upper() in ["C", "FORTRAN"]: + elif lang == Language.C or lang == Language.FORTRAN: dev = CompilerFunctions( compiler=compiler, compiler_options=compiler_options, iterations=iterations, observers=observers, ) - elif lang.upper() == "HIP": + elif lang == Language.HIP: dev = HipFunctions( device, compiler_options=compiler_options, iterations=iterations, observers=observers, ) - elif lang.upper() == "HYPERTUNER": + elif lang == Language.HYPERTUNER: dev = HypertunerFunctions( iterations=iterations, compiler_options=compiler_options ) self.requires_warmup = False - elif lang.upper() == "TRITON": + elif lang == Language.TRITON: dev = TritonFunctions( device, compiler_options=compiler_options, iterations=iterations, observers=observers ) - elif lang.upper() == "GENERIC_PYTHON": - dev = TilusFunctions( + elif lang == Language.GENERIC_PYTHON: + dev = GenericPythonFunctions( device, compiler_options=compiler_options, iterations=iterations, @@ -241,7 +242,7 @@ def __init__( print("Using: " + self.dev.name) def run_kernel_bench(self, func, gpu_args, threads, grid, stream=None, params=None): - if isinstance(self.dev, TritonFunctions): + if isinstance(self.dev, TritonFunctions) or isinstance(self.dev, GenericPythonFunctions): self.dev.run_kernel(func, gpu_args, threads, grid, params=params) else: self.dev.run_kernel(func, gpu_args, threads, grid) @@ -392,31 +393,44 @@ def check_kernel_output( self.dev.refresh_memory(gpu_args, instance.arguments, should_sync) # run the kernel + self.dev.synchronize() check = self.run_kernel_check(func, gpu_args, instance) + self.dev.synchronize() if not check: # runtime failure occurred that should be ignored, skip correctness check return # retrieve gpu results to host memory result_host = [] - for i, arg in enumerate(instance.arguments): - if should_sync[i]: - if isinstance(arg, (np.ndarray, cp.ndarray)): - result_host.append(np.zeros_like(arg)) - self.dev.memcpy_dtoh(result_host[-1], gpu_args[i]) - elif isinstance(arg, torch.Tensor) and isinstance(answer[i], torch.Tensor): - if not answer[i].is_cuda: - # if the answer is on the host, copy gpu output to host as well - result_host.append(torch.zeros_like(answer[i])) - self.dev.memcpy_dtoh(result_host[-1], gpu_args[i].tensor) + if instance.kernel_source is not None and instance.kernel_source.lang == Language.GENERIC_PYTHON: + # Python DSLs do not explicitly manage memory. Therefore, we can use gpu args direclty + for i, arg in enumerate(gpu_args): + if should_sync[i]: + if isinstance(arg, torch.Tensor): + result_host.append(arg.cpu()) else: - result_host.append(gpu_args[i].tensor) + result_host.append(arg) + else: + result_host.append(None) + else: + for i, arg in enumerate(instance.arguments): + if should_sync[i]: + if isinstance(arg, (np.ndarray, cp.ndarray)): + result_host.append(np.zeros_like(arg)) + self.dev.memcpy_dtoh(result_host[-1], gpu_args[i]) + elif isinstance(arg, torch.Tensor) and isinstance(answer[i], torch.Tensor): + if not answer[i].is_cuda: + # if the answer is on the host, copy gpu output to host as well + result_host.append(torch.zeros_like(answer[i])) + self.dev.memcpy_dtoh(result_host[-1], gpu_args[i].tensor) + else: + result_host.append(gpu_args[i].tensor) + else: + # We should sync this argument, but we do not know how to transfer this type of argument + # What do we do? Should we throw an error? + result_host.append(None) else: - # We should sync this argument, but we do not know how to transfer this type of argument - # What do we do? Should we throw an error? result_host.append(None) - else: - result_host.append(None) # Call the output observers for obs in self.output_observers: @@ -434,6 +448,7 @@ def check_kernel_output( correct = True if not correct: + print("expected: ", answer, "\ngot: ", result_host) raise RuntimeError("Kernel result verification failed for: " + util.get_config_string(instance.params)) def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options, to): @@ -449,7 +464,6 @@ def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options, instance_string = util.get_instance_string(params) logging.debug("compile_and_benchmark " + instance_string) - instance = self.create_kernel_instance(kernel_source, kernel_options, params, verbose) if isinstance(instance, util.ErrorConfig): result[to.objective] = util.InvalidConfig() @@ -515,33 +529,55 @@ def compile_kernel(self, instance, verbose, gpu_args=None): # compile kernel_string into device func func = None - try: - if isinstance(self.dev, TritonFunctions): + + if isinstance(self.dev, TritonFunctions) or isinstance(self.dev, GenericPythonFunctions): + try: func = self.dev.compile(instance, gpu_args) - else: + except Exception as e: + if isinstance(self.dev, GenericPythonFunctions): + exception_type = self.dev.classify_compile_exception(e) + else: + exception_type = "unkown" + + if exception_type == "user_error": + error_message = str(e.stderr) if hasattr(e, "stderr") else str(e) + print("compile_kernel failed due to error: " + error_message) + print("Error while compiling:", instance.name) + raise e + else: + logging.debug( + "compile_kernel failed due to kernel using too many resources" + ) + if verbose: + print( + f"skipping config {util.get_instance_string(instance.params)} reason: too many resources" + ) + + else: + try: func = self.dev.compile(instance) - except Exception as e: - # compiles may fail because certain kernel configurations use too - # much shared memory for example, the desired behavior is to simply - # skip over this configuration and try the next one - shared_mem_error_messages = [ - "uses too much shared data", - "local memory limit exceeded", - r"local memory \(\d+\) exceeds limit \(\d+\)", - ] - error_message = str(e.stderr) if hasattr(e, "stderr") else str(e) - if any(re.search(msg, error_message) for msg in shared_mem_error_messages): - logging.debug( - "compile_kernel failed due to kernel using too much shared memory" - ) - if verbose: - print( - f"skipping config {util.get_instance_string(instance.params)} reason: too much shared memory used" + except Exception as e: + # compiles may fail because certain kernel configurations use too + # much shared memory for example, the desired behavior is to simply + # skip over this configuration and try the next one + shared_mem_error_messages = [ + "uses too much shared data", + "local memory limit exceeded", + r"local memory \(\d+\) exceeds limit \(\d+\)", + ] + error_message = str(e.stderr) if hasattr(e, "stderr") else str(e) + if any(re.search(msg, error_message) for msg in shared_mem_error_messages): + logging.debug( + "compile_kernel failed due to kernel using too much shared memory" ) - else: - print("compile_kernel failed due to error: " + error_message) - print("Error while compiling:", instance.name) - raise e + if verbose: + print( + f"skipping config {util.get_instance_string(instance.params)} reason: too much shared memory used" + ) + else: + print("compile_kernel failed due to error: " + error_message) + print("Error while compiling:", instance.name) + raise e return func @staticmethod @@ -580,7 +616,8 @@ def create_kernel_instance(self, kernel_source, kernel_options, params, verbose) params, kernel_options.block_size_names, ) - if kernel_source.lang != 'TRITON' and np.prod(threads) > self.dev.max_threads: + + if kernel_source.lang not in [Language.TRITON, Language.GENERIC_PYTHON] and np.prod(threads) > self.dev.max_threads: if verbose: print(f"skipping config {util.get_instance_string(params)} reason: too many threads per block") return util.InvalidConfig() diff --git a/kernel_tuner/interface.py b/kernel_tuner/interface.py index bad921d24..ce9d74080 100644 --- a/kernel_tuner/interface.py +++ b/kernel_tuner/interface.py @@ -43,6 +43,7 @@ from kernel_tuner.runners.sequential import SequentialRunner from kernel_tuner.runners.simulation import SimulationRunner from kernel_tuner.searchspace import Searchspace +from kernel_tuner.language import Language try: import torch @@ -588,12 +589,15 @@ def tune_kernel( observers=None, objective=None, objective_higher_is_better=None, + call_function=None, + decorator=None ): + start_overhead_time = perf_counter() if log: logging.basicConfig(filename=kernel_name + datetime.now().strftime("%Y%m%d-%H:%M:%S") + ".log", level=log) - kernelsource = KernelSource(kernel_name, kernel_source, lang, defines) + kernelsource = KernelSource(kernel_name, kernel_source, lang, defines, call_function, decorator) _check_user_input(kernel_name, kernelsource, arguments, block_size_names) @@ -774,11 +778,13 @@ def run_kernel( block_size_names=None, quiet=False, log=None, + call_function=None, + decorator=None ): if log: logging.basicConfig(filename=kernel_name + datetime.now().strftime("%Y%m%d-%H:%M:%S") + ".log", level=log) - kernelsource = KernelSource(kernel_name, kernel_source, lang, defines) + kernelsource = KernelSource(kernel_name, kernel_source, lang, defines, call_function, decorator) _check_user_input(kernel_name, kernelsource, arguments, block_size_names) @@ -806,8 +812,11 @@ def run_kernel( # see if the kernel arguments have correct type util.check_argument_list(instance.name, instance.kernel_string, arguments) - # compile the kernel - func = dev.compile_kernel(instance, False) + # compile the kernel, Extra logic needed to check if we have a python function + if lang.upper() == "TRITON" or lang.upper() == "GENERIC_PYTHON": + func = dev.compile_kernel(instance, False, gpu_args) + else: + func = dev.compile_kernel(instance, False) if func is None: raise RuntimeError("cannot compile kernel, too much shared memory used") @@ -829,16 +838,26 @@ def run_kernel( if not dev.run_kernel_check(func, gpu_args, instance): raise RuntimeError("runtime error occured, too many resources requested") + + # copy data in GPU memory back to the host results = [] - for i, arg in enumerate(arguments): - if numpy.isscalar(arg): - results.append(arg) - elif isinstance(arg, torch.Tensor): - results.append(arg.cpu()) - else: - results.append(numpy.zeros_like(arg)) - dev.memcpy_dtoh(results[-1], gpu_args[i]) + if instance.kernel_source.lang == Language.GENERIC_PYTHON: + for arg in gpu_args: + if isinstance(arg, torch.Tensor): + results.append(arg.cpu()) + else: + results.append(arg) + else: + for i, arg in enumerate(arguments): + if numpy.isscalar(arg): + results.append(arg) + elif isinstance(arg, torch.Tensor): + results.append(arg.cpu()) + else: + results.append(numpy.zeros_like(arg)) + dev.memcpy_dtoh(results[-1], gpu_args[i]) + return results diff --git a/kernel_tuner/kernel_sources/kernel_source.py b/kernel_tuner/kernel_sources/kernel_source.py index 70c94903d..9a16d9b6c 100644 --- a/kernel_tuner/kernel_sources/kernel_source.py +++ b/kernel_tuner/kernel_sources/kernel_source.py @@ -6,65 +6,50 @@ from kernel_tuner.language import Language from kernel_tuner.kernel_sources.model.prepared_kernel_source_data import PreparedKernelSourceData - - - -class KernelSource: - def __new__(cls, kernel_name, kernel_sources, lang, defines=None): +# We use this pattern because we would otherwise get an init twice +# We create a kernelsouce by calling KernelSource(...). The call method for KernelSource is +# replaced by the call method in KernelSourceFactory, because KernelSource uses that class as +# metaclass. Inside the call method, a call to create either KernelSourceStr or KernelSourceFN +# is done. This triggers the __init__ call of those subclasses. In both subclasses, we call super().__init__ +# which initializes some sublclass specific variables and triggers the __init__ call of the KernelSource class. +class KernelSourceFactory(type): + def __call__(cls, kernel_name, kernel_sources, lang, defines=None, call_function=None, decorator=None): """Factory behavior""" if lang == None: language = None else: try: - language = Language(lang) + language = Language(lang.upper()) except ValueError: raise TypeError(f"Supported languages are {[l.value for l in Language]}") + + + # Determine if we need to create a KernelSourceStr or a KernelSourceFn + if (language and (language == Language.TRITON or language == Language.GENERIC_PYTHON)): + ks_str = False + else: + ks_str = True + from kernel_tuner.kernel_sources.kernel_source_str import KernelSourceStr + from kernel_tuner.kernel_sources.kernel_source_fn import KernelSourceFn if cls is KernelSource: - if inspect.isfunction(kernel_sources) and (language and (language == Language.TRITON or language == Language.GENERIC_PYTHON)): # TODO should this be isfunction? - from kernel_tuner.kernel_sources.kernel_source_fn import KernelSourceFn - print("CREATING KSFN") - return KernelSourceFn(kernel_name, kernel_sources, lang, defines) - else: - from kernel_tuner.kernel_sources.kernel_source_str import KernelSourceStr - print("CREATING KSSTR") + if ks_str: return KernelSourceStr(kernel_name, kernel_sources, lang, defines) - - # otherwise, normal subclass init - return super().__new__(cls) - - - def __init__(self, kernel_name, kernel_sources, lang, defines=None): - if not isinstance(kernel_sources, list): - kernel_sources = [kernel_sources] - self.kernel_sources = kernel_sources - self.kernel_name = kernel_name - self.defines = defines + else: + return KernelSourceFn(kernel_name, kernel_sources, lang, defines, call_function, decorator) - if lang is None: - if callable(self.kernel_sources[0]): - raise TypeError("Please specify language when using a code generator function") - kernel_string = self.get_kernel_string(0) - self.lang = util.detect_language(kernel_string) + # Else, normal behaviour for subclasses + if ks_str: + return super().__call__(kernel_name, kernel_sources, lang, defines) else: - self.lang = lang - - @abstractmethod - def prepare_kernel_instance(self, kernel_options, params, grid, threads) -> PreparedKernelSourceData: - raise NotImplementedError("create_kernel_instance not implemented") - - @abstractmethod - def check_argument_lists(self, kernel_name, arguments): - raise NotImplementedError("check_argument_lists not implemented") - - + return super().__call__(kernel_name, kernel_sources, lang, defines, call_function, decorator) + +# TODO do we really want the Language enum? it's a lot of changes +class KernelSource(metaclass=KernelSourceFactory): -''' -class KernelSource: def __init__(self, kernel_name, kernel_sources, lang, defines=None): if not isinstance(kernel_sources, list): kernel_sources = [kernel_sources] - self.kernel_sources = kernel_sources self.kernel_name = kernel_name self.defines = defines @@ -73,9 +58,18 @@ def __init__(self, kernel_name, kernel_sources, lang, defines=None): if callable(self.kernel_sources[0]): raise TypeError("Please specify language when using a code generator function") kernel_string = self.get_kernel_string(0) - self.lang = util.detect_language(kernel_string) + language = util.detect_language(kernel_string) + try: + self.lang = Language(language.upper()) + except ValueError: + # TODO this should never happen + raise TypeError(f"Supported languages are {[l.value for l in Language]}, found {language}") else: - self.lang = lang + try: + self.lang = Language(lang.upper()) + except ValueError: + raise TypeError(f"Supported languages are {[l.value for l in Language]}") + @abstractmethod def prepare_kernel_instance(self, kernel_options, params, grid, threads) -> PreparedKernelSourceData: @@ -85,4 +79,4 @@ def prepare_kernel_instance(self, kernel_options, params, grid, threads) -> Prep def check_argument_lists(self, kernel_name, arguments): raise NotImplementedError("check_argument_lists not implemented") -''' \ No newline at end of file + diff --git a/kernel_tuner/kernel_sources/kernel_source_fn.py b/kernel_tuner/kernel_sources/kernel_source_fn.py index 223e0e84f..edb75f95d 100644 --- a/kernel_tuner/kernel_sources/kernel_source_fn.py +++ b/kernel_tuner/kernel_sources/kernel_source_fn.py @@ -18,40 +18,62 @@ class KernelSourceFn(KernelSource): - def __init__(self, kernel_name, kernel_source, lang, defines=None): + def __init__(self, kernel_name, kernel_source, lang, defines=None, call_function=None, decorator=None): super().__init__(kernel_name, kernel_source, lang, defines) if isinstance(kernel_source, list): raise ValueError("KernelSourceFn only supports a single kernel source function") - try: - self.lang = Language(lang) + self.lang = Language(lang.upper()) except ValueError: raise TypeError(f"Supported languages are {[l.value for l in Language]}") + + if call_function is None: + raise ValueError("call_function must be supplied for language Generic Python") + if not callable(call_function): + raise TypeError(f"call_function {call_function} is not a callable object.") + self.call_function = call_function # TODO ceck signature + + if decorator: + if not isinstance(decorator, str): + raise TypeError(f"{decorator} is not a decorator") + if decorator[0] != '@': + raise ValueError(f"{decorator} is not a valid decorator") + self.decorator = decorator + self.source_kernel_fn = kernel_source self.kernel_fn = self.source_kernel_fn - self.source = inspect.getsource(kernel_source) + self.signature = inspect.signature(kernel_source) + try: + self.source = inspect.getsource(kernel_source) + except TypeError as e: + raise TypeError( + f"{e}. Did you forget to remove a decorator before tuning?" + ) from e self.source_tree = ast.parse(self.source) - if self.lang == Language.TRITON: - self.import_nodes = [ - ast.Import(names=[ast.alias(name='triton', asname=None)]), - ast.ImportFrom( - module='triton', - names=[ast.alias(name='language', asname='tl')], - level=0 - ) - ] - else: - self.import_nodes = [n for n in self.source_tree.body if isinstance(n, (ast.Import, ast.ImportFrom))] - + self.import_nodes = self._find_import_nodes(inspect.getfile(kernel_source)) # Find the module where the kernel function is defined self.module = inspect.getmodule(kernel_source) # Get dependencies by analyzing the AST - if self.lang == Language.lang: + if self.lang == Language.TRITON: self.dependencies = self._find_triton_dependencies() else: - self.dependencies = None # TODO + self.dependencies = self._find_dependencies() + + + + + def _find_import_nodes(self, source_file): + with open(source_file, "r") as f: + tree = ast.parse(f.read(), filename=source_file) + + import_nodes = [] + for node in tree.body: + if isinstance(node, (ast.Import, ast.ImportFrom)): + import_nodes.append(node) + + return import_nodes def _find_function_dependencies(self): """Find all function calls in the kernel.""" @@ -99,6 +121,49 @@ def _find_triton_dependencies(self): return dependencies + def _find_function_dependecies2(self, tree, local_funcs): + # Non triton specific + class FunctionCallVisitor(ast.NodeVisitor): + def __init__(self): + self.called = set() + + def visit_Call(self, node): + if isinstance(node.func, ast.Name): + name = node.func.id + if name in local_funcs: + self.called.add(name) + self.generic_visit(node) + + visitor = FunctionCallVisitor() + visitor.visit(tree) + return visitor.called + + def _find_local_functions(self, tree): + local_funcs = set() + for node in tree.body: + if isinstance(node, ast.FunctionDef): + local_funcs.add(node.name) + return local_funcs + + def _find_dependencies(self): + source_file = inspect.getfile(self.source_kernel_fn) + with open(source_file, "r") as f: + source_code = f.read() + tree = ast.parse(source_code, filename=source_file) + local_funcs = self._find_local_functions(tree) + called_functions = self._find_function_dependecies2(self.source_tree, local_funcs) + + dependencies = {} + for node in tree.body: + if (isinstance(node, ast.FunctionDef) and node.name in called_functions and + node.name != self.kernel_name): # Skip the main kernel itself + dependencies[node.name] = node + + return dependencies + + def _add_decorator(self, function): + pass + def prepare_kernel_instance(self, kernel_options, params, grid, threads): new_kernel_fn, temp_file_path = self.apply_params_to_source_fn(params) self.kernel_fn = new_kernel_fn @@ -122,22 +187,34 @@ def apply_params_to_source_fn(self, params): # Add imports new_module.body.extend(self.import_nodes) - if self.lang == Language.TRITON: - # Add all Triton JIT dependencies first - for dep_node in self.dependencies.values(): - new_module.body.append(copy.deepcopy(dep_node)) - - # TODO the kernel can not have dependencies yet. Obviously this needs fixing. + # Add dependencies (functions) + for dep_node in self.dependencies.values(): + dep_node_copy = copy.deepcopy(dep_node) + transformed_dep_node = transformer.visit(dep_node_copy) + #new_module.body.append(copy.deepcopy(dep_node)) + new_module.body.append(transformed_dep_node) # Add transformed main kernel source_tree_copy = copy.deepcopy(self.source_tree) transformed_tree = transformer.visit(source_tree_copy) + + # Add decorator if needed + if self.decorator: + dummy = f"{self.decorator}\ndef _dummy():\n pass\n" + decorator_node = ast.parse(dummy).body[0].decorator_list[0] + for node in transformed_tree.body: + if isinstance(node, ast.FunctionDef): + node.decorator_list.insert(0, decorator_node) + break # only apply to the top level function + new_module.body.extend(transformed_tree.body) # Fix locations and generate source ast.fix_missing_locations(new_module) new_source = astor.to_source(new_module) + #print(new_source) + # Create a unique module name module_name = f'temp_kernel_module_{uuid.uuid4().hex}' @@ -173,5 +250,45 @@ def visit_Name(self, node: ast.Name) -> Any: ast.Constant(value=self.params[node.id]), node ) + return node + + def visit_Attribute(self, node: ast.Attribute): + self.generic_visit(node) + + # Replace self. with constant, but only if it's being read + if ( + isinstance(node.value, ast.Name) + and node.value.id == "self" + and node.attr in self.params + and isinstance(node.ctx, ast.Load) # <- check context + ): + return ast.copy_location(ast.Constant(self.params[node.attr]), node) + + return node + + + def visit_Assign(self, node: ast.Assign): + # Replace assignments like: mock_param = ... + if ( + len(node.targets) == 1 + and isinstance(node.targets[0], ast.Name) + and node.targets[0].id in self.params + ): + node.value = ast.Constant(value=self.params[node.targets[0].id]) + return node + + # Replace assignment like self. = ... + if ( + len(node.targets) == 1 + and isinstance(node.targets[0], ast.Attribute) + and isinstance(node.targets[0].value, ast.Name) + and node.targets[0].value.id == "self" + and node.targets[0].attr in self.params + ): + node.value = ast.Constant(value=self.params[node.targets[0].attr]) + return node + + return self.generic_visit(node) + + - return node \ No newline at end of file diff --git a/kernel_tuner/kernel_sources/kernel_source_str.py b/kernel_tuner/kernel_sources/kernel_source_str.py index f470b939f..4db0831a7 100644 --- a/kernel_tuner/kernel_sources/kernel_source_str.py +++ b/kernel_tuner/kernel_sources/kernel_source_str.py @@ -60,7 +60,7 @@ def get_kernel_string(self, index=0, params=None): """ logging.debug("get_kernel_string called") - if hasattr(self, 'lang') and self.lang.upper() == "HYPERTUNER": + if hasattr(self, 'lang') and self.lang == Language.HYPERTUNER: return "" kernel_source = self.kernel_sources[index] @@ -102,7 +102,7 @@ def prepare_list_of_files( """ temp_files = dict() - if self.lang.upper() == "HYPERTUNER": + if self.lang == Language.HYPERTUNER: return tuple(["", "", temp_files]) for i, f in enumerate(self.kernel_sources): @@ -156,7 +156,7 @@ def get_suffix(self, index=0): if suffix is not None: return suffix - _suffixes = {"CUDA": ".cu", "OpenCL": ".cl", "C": ".c"} + _suffixes = {Language.CUDA: ".cu", Language.OPENCL: ".cl", Language.C: ".c"} try: return _suffixes[self.lang] except KeyError: diff --git a/kernel_tuner/language.py b/kernel_tuner/language.py index d33933cca..b129e04d5 100644 --- a/kernel_tuner/language.py +++ b/kernel_tuner/language.py @@ -2,13 +2,15 @@ class Language(Enum): CUDA = "CUDA" - OpenCL = "OPENCL" + OPENCL = "OPENCL" C = "C" HIP = "HIP" TRITON = "TRITON" FORTRAN = "FORTRAN" NVCUDA = "NVCUDA" GENERIC_PYTHON = "GENERIC_PYTHON" + CUPY = "CUPY" + HYPERTUNER = "HYPERTUNER" @classmethod def is_valid(cls, value: str) -> bool: diff --git a/kernel_tuner/observers/generic_python.py b/kernel_tuner/observers/generic_python.py index 896b4f9cc..62dfa20c7 100644 --- a/kernel_tuner/observers/generic_python.py +++ b/kernel_tuner/observers/generic_python.py @@ -9,8 +9,7 @@ class GenericPythonRuntimeObserver(BenchmarkObserver): - """Observer that measures time using CUDA events during benchmarking. - TODO think about: do we need torch? Does CUDA work for all languages?""" + """Observer that measures time using CUDA events during benchmarking.""" def __init__(self, dev): if torch is None: diff --git a/kernel_tuner/util.py b/kernel_tuner/util.py index fd4cb5697..c28bafa0b 100644 --- a/kernel_tuner/util.py +++ b/kernel_tuner/util.py @@ -165,7 +165,7 @@ def check_argument_list(kernel_name, kernel_string, args): if not isinstance(arg, (np.ndarray, np.generic, cp.ndarray, torch.Tensor, DeviceArray)): raise TypeError( f"Argument at position {i} of type: {type(arg)} should be of type " -- "np.ndarray, numpy scalar, or HIP Python DeviceArray type" + "np.ndarray, numpy scalar, or HIP Python DeviceArray type" ) correct = True diff --git a/test/context.py b/test/context.py index bad152986..81637f24b 100644 --- a/test/context.py +++ b/test/context.py @@ -73,6 +73,12 @@ except ImportError: bayes_opt_gpytorch_present = False +try: + import torch + gen_python_torch_present = True +except ImportError: + gen_python_torch_present = False + try: import pyatf pyatf_present = True @@ -109,6 +115,7 @@ skip_if_no_hip = pytest.mark.skipif(not hip_present, reason="No HIP Python found") skip_if_no_pyatf = pytest.mark.skipif(not pyatf_present, reason="PyATF not installed") skip_if_no_methodology = pytest.mark.skipif(not methodology_present, reason="Autotuning Methodology not found") +skip_if_no_torch = pytest.mark.skipif(not gen_python_torch_present, reason="Torch not installed") def skip_backend(backend: str): @@ -128,3 +135,5 @@ def skip_backend(backend: str): pytest.skip("No nvc++ on PATH") elif backend.upper() == "HIP" and not hip_present: pytest.skip("HIP Python not installed") + elif backend.upper() == "GENERIC_PYTHON" and not gen_python_torch_present: + pytest.skip("Torch not installed") diff --git a/test/test_backend.py b/test/test_backend.py index e694649c1..af9971a48 100644 --- a/test/test_backend.py +++ b/test/test_backend.py @@ -4,8 +4,9 @@ skip_if_no_cuda, skip_if_no_opencl, skip_if_no_pycuda, + skip_if_no_torch, ) -from kernel_tuner.backends import backend, compiler, cupy, nvcuda, opencl, pycuda +from kernel_tuner.backends import backend, compiler, cupy, nvcuda, opencl, pycuda, generic_python class WrongBackend(backend.Backend): @@ -45,3 +46,8 @@ def test_opencl_backend(): @skip_if_no_pycuda def test_pycuda_backend(): dev = pycuda.PyCudaFunctions() + + +@skip_if_no_torch +def test_generic_python_backend(): + dev = generic_python.GenericPythonFunctions() \ No newline at end of file diff --git a/test/test_generic_python_functions.py b/test/test_generic_python_functions.py new file mode 100644 index 000000000..a3d250b60 --- /dev/null +++ b/test/test_generic_python_functions.py @@ -0,0 +1,116 @@ +from .context import skip_if_no_torch +from .test_kernel_source_fn import mock_kernel, kernel_with_kwarg, call_mock +from kernel_tuner.core import DeviceInterface, KernelInstance +from kernel_tuner.kernel_sources.kernel_source import KernelSource +import numpy as np + +try: + import torch + torch_present = True +except ImportError: + pass + + +# Helper functions ------------------------------ + +def value_equal(a, b): + # Torch tensors + if isinstance(a, torch.Tensor): + return torch.equal(a, b) + + # NumPy arrays + if isinstance(a, np.ndarray): + return np.array_equal(a, b) + + # Fallback (ints, floats, strings, lists, tuples, dicts) + return a == b + +def get_context(): + params = {'mock_param': 64} + a = 42 + b = torch.randn(12, device='cuda', dtype=torch.float32) + args = [a, b] + ks = KernelSource("mock_kernel", mock_kernel, "generic_python", call_function=call_mock) + return ks, args, params + + +# Tests ---------------------------------------------- + +@skip_if_no_torch +def test_ready_argument_list(): + ks, args, params = get_context() + dev = DeviceInterface(ks) + gpu_args = dev.ready_argument_list(args) + + assert len(args) == len(gpu_args) + + for i, _ in enumerate(gpu_args): + assert value_equal(args[i], gpu_args[i]) + if type(gpu_args[i]) in (list, dict, torch.Tensor, np.ndarray): + assert gpu_args[i] is not args[i] # Assure deep copy + + +@skip_if_no_torch +def test_compile(): + ks, args, params = get_context() + + instance_data = ks.prepare_kernel_instance( + kernel_options=None, + params=params, + grid=None, + threads=None, + ) + dev = DeviceInterface(ks) + instance = KernelInstance( + name="mock_kernel", + kernel_source=ks, + kernel_string=None, + kernel_fn=instance_data.kernel_fn, + temp_files=instance_data.temp_files, + threads=None, + grid=None, + params=params, + arguments=None, + ) + gpu_args = dev.ready_argument_list(args) + + callable_fn = dev.compile_kernel(instance, verbose=False, gpu_args=gpu_args) + + # The mock function return the mock_param, which should be set to 64. + assert callable_fn(*args) == 64 + + +@skip_if_no_torch +def test_gpu_kwargs(): + params = {'mock_param': 64} + a = torch.randn(12, device='cuda', dtype=torch.float32) + args = [a] # we do not have to specify the kwarg here + ks = KernelSource("kernel_with_kwarg", kernel_with_kwarg, "generic_python", call_function=call_mock) + dev = DeviceInterface(ks) + + instance_data = ks.prepare_kernel_instance( + kernel_options=None, + params=params, + grid=None, + threads=None, + ) + dev = DeviceInterface(ks) + instance = KernelInstance( + name="kernel_with_kwarg", + kernel_source=ks, + kernel_string=None, + kernel_fn=instance_data.kernel_fn, + temp_files=instance_data.temp_files, + threads=None, + grid=None, + params=params, + arguments=None, + ) + gpu_args = dev.ready_argument_list(args) + callable_fn = dev.compile_kernel(instance, verbose=False, gpu_args=gpu_args) + + kwargs = dev.dev.build_gpu_kwargs(params) + + assert kwargs["mock_param"] == 64 + assert callable_fn(*args, **kwargs) == 64 + diff --git a/test/test_kernel_source_fn.py b/test/test_kernel_source_fn.py new file mode 100644 index 000000000..e23a41a15 --- /dev/null +++ b/test/test_kernel_source_fn.py @@ -0,0 +1,160 @@ +import pytest +import inspect +import ast +import functools # for an example decorator +from kernel_tuner.kernel_sources.kernel_source import KernelSource +# Note: need subclass imports to test for types and instantiate subclasses direclty. +# Normally, just importing KernelSource is enough. +from kernel_tuner.kernel_sources.kernel_source_fn import KernelSourceFn +from kernel_tuner.kernel_sources.kernel_source_str import KernelSourceStr + + +# Helper functions -------------------------------------------- + +def normalize_ast(src: str): + return ast.dump(ast.parse(src), include_attributes=False) + +def mock_kernel(a, b): + mock_param = 256 + if a < mock_param: + b = 42 + return mock_param + +def kernel_with_kwarg(a, mock_param): + return mock_param + +# NOTE should kernel_with_kwarg now return the tuning param value for mock param or 5? +def kernel_with_dependency(): + mock_param = 42 + foo = kernel_with_kwarg(mock_param, 5) + return mock_kernel(42, 42) + +def call_mock(kernel_function, args, kwargs, grid, threads, params): + kernel_function(*args, **kwargs) + +class TilusLike(): + def __init__(self): + self.mock_param = 32 + + def __call__(self): + return self.mock_param + + def another_function(self): + return self.mock_param + + +# Tests ------------------------------------------------------ + +def test_factory_behaviour(): + # KernelSourceFn should only be created when language generic_python is supplied + ks_fn = KernelSource("mock_kernel", mock_kernel, "generic_python", call_function=call_mock) + ks_str = KernelSource("vector_add", 'extern "C" __global__ void vector_add(float *c, float *a, float *b, int n) {', lang=None) + + assert isinstance(ks_fn, KernelSourceFn) + assert isinstance(ks_str, KernelSourceStr) + + +def test_initiation(): + ''' + Test invalid KernelSourceFn initations + ''' + with pytest.raises(ValueError, match=r"call_function must be supplied for language .*"): + KernelSource("mock_kernel", mock_kernel, "generic_python") + + with pytest.raises(TypeError, match=r".* is not a callable object"): + KernelSource("mock_kernel", "This is a string Kernel", "generic_python", call_function=call_mock) + + with pytest.raises(ValueError, match=r"KernelSourceFn only supports a single kernel source function"): + KernelSource("mock_kernel", [mock_kernel, kernel_with_dependency], "generic_python", call_function=call_mock) + + with pytest.raises(TypeError, match=r".* is not a callable object"): + KernelSource("mock_kernel", mock_kernel, "generic_python", call_function="not a function") + + with pytest.raises(ValueError, match=r".* is not a valid decorator"): + KernelSource("mock_kernel", mock_kernel, "generic_python", call_function=call_mock, decorator="not a decorator") + + with pytest.raises(TypeError, match=r".* is not a decorator"): + KernelSource("mock_kernel", mock_kernel, "generic_python", call_function=call_mock, decorator=mock_kernel) + + +def test_param_subsitution(): + params = {"mock_param": 128} + ks = KernelSourceFn("mock_kernel", mock_kernel, "generic_python", call_function=call_mock) + new_kernel_fn, _ = ks.apply_params_to_source_fn(params) + + actual_src = inspect.getsource(new_kernel_fn) + expected_src = """ +def mock_kernel(a, b): + mock_param = 128 + if a < 128: + b = 42 + return 128 +""" + + assert normalize_ast(actual_src) == normalize_ast(expected_src) + + +def test_imports(): + ''' + Import statements that are present in the file where the function lives + should also be present in the file where the new function lives. + ''' + + params = {"mock_param": 128} + ks = KernelSourceFn("mock_kernel", mock_kernel, "generic_python", call_function=call_mock) + _, temp_path = ks.apply_params_to_source_fn(params) + + # Check if imports are present + with open(temp_path) as f: + full_src = f.read() + + assert "import pytest" in full_src + assert "import inspect" in full_src + + +def test_param_substitution_class(): + ''' + Tilus uses the __call__ function of a class to define its kernel. Therefore, + param substitution should also work classes. Param substiution in other class + functions is also supported. + ''' + + params = {"mock_param": 128} + ks = KernelSourceFn("TilusLike", TilusLike, "generic_python", call_function=call_mock) + + new_kernel_fn, _ = ks.apply_params_to_source_fn(params) + + actual_src = inspect.getsource(new_kernel_fn) + expected_src = """ +class TilusLike(): + def __init__(self): + self.mock_param = 128 + + def __call__(self): + return 128 + + def another_function(self): + return 128 +""" + assert normalize_ast(actual_src) == normalize_ast(expected_src) + + + +def test_decorator(): + params = {"mock_param": 128} + ks = KernelSourceFn("mock_kernel", mock_kernel, "generic_python", call_function=call_mock, decorator="@functools.lru_cache()") + new_kernel_fn, _ = ks.apply_params_to_source_fn(params) + + assert hasattr(new_kernel_fn, "__wrapped__") + + +def test_dependencies(): + params = {"mock_param": 128} + ks = KernelSourceFn("kernel_with_dependency", kernel_with_dependency, "generic_python", call_function=call_mock) + new_kernel_fn, _ = ks.apply_params_to_source_fn(params) + res = new_kernel_fn() # This should not throw an error if the dependency exists in the module. + + assert res == 128 + + + From d21377e16758bd3544503b58076ca1130ccb19b9 Mon Sep 17 00:00:00 2001 From: Imke van Ooijen Date: Sun, 1 Mar 2026 10:40:14 +0100 Subject: [PATCH 04/14] improved features, removed explicit Triton support (now available through generic Python), cleaned up code --- .../generic_python/matmul/tilelang_matmul.py | 212 ++++++++++++ .../generic_python/matmul/tilus_matmul.py | 267 +++++++++++++++ .../generic_python/matmul/triton_matmul.py | 263 ++++++++++++++- examples/generic_python/numba_vec_add.py | 38 ++- examples/generic_python/tilelang_vec_add.py | 57 ++++ examples/generic_python/tilus_vec_add.py | 58 +++- examples/generic_python/triton_vec_add.py | 8 +- examples/generic_python/warp_vec_add.py | 36 ++- examples/triton/conv2d_tuning.py | 306 ------------------ examples/triton/vec_add.py | 50 --- kernel_tuner/backends/generic_python.py | 189 ++++++++--- kernel_tuner/backends/triton.py | 174 ---------- kernel_tuner/core.py | 32 +- kernel_tuner/interface.py | 35 +- kernel_tuner/kernel_sources/kernel_source.py | 32 +- .../kernel_sources/kernel_source_fn.py | 300 +++++++++-------- .../kernel_sources/kernel_source_str.py | 2 +- kernel_tuner/language.py | 1 - kernel_tuner/observers/triton.py | 32 -- test/test_generic_python_functions.py | 2 +- test/test_kernel_source_fn.py | 6 +- 21 files changed, 1254 insertions(+), 846 deletions(-) create mode 100644 examples/generic_python/matmul/tilelang_matmul.py create mode 100644 examples/generic_python/matmul/tilus_matmul.py create mode 100644 examples/generic_python/tilelang_vec_add.py delete mode 100644 examples/triton/conv2d_tuning.py delete mode 100644 examples/triton/vec_add.py delete mode 100644 kernel_tuner/backends/triton.py delete mode 100644 kernel_tuner/observers/triton.py diff --git a/examples/generic_python/matmul/tilelang_matmul.py b/examples/generic_python/matmul/tilelang_matmul.py new file mode 100644 index 000000000..2a7b5154d --- /dev/null +++ b/examples/generic_python/matmul/tilelang_matmul.py @@ -0,0 +1,212 @@ +import tilelang +import tilelang.language as T +import torch +from kernel_tuner import tune_kernel + + +#@tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype: str = 'float16', accum_dtype: str = 'float32'): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Define a grid with enough blocks to cover M×N + num_threads=128 + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + + # Allocate shared memory for the current tile of A and B + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + + # Allocate a local (register) fragment for partial accumulations + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable swizzle-based rasterization for better L2 locality + panel_size = 4 + T.use_swizzle(panel_size=panel_size, enable=True) + + # Initialize the local accumulation buffer to zero + T.clear(C_local) + + num_stages=3 + + # Loop over the K dimension in block_K chunks, using a pipeline + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + # Copy from global memory to shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + + # Perform a matrix multiply-accumulate on the tile + T.gemm(A_shared, B_shared, C_local) + + # Copy the accumulated result from local memory (C_local) to global memory (C) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + +@tilelang.jit +def matmul_with_decorator(M, N, K, block_M, block_N, block_K, dtype: str = 'float16', accum_dtype: str = 'float32'): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Define a grid with enough blocks to cover M×N + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + + # Allocate shared memory for the current tile of A and B + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + + # Allocate a local (register) fragment for partial accumulations + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Initialize the local accumulation buffer to zero + T.clear(C_local) + + # Loop over the K dimension in block_K chunks, using a 3-stage pipeline + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy from global memory to shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + + # Perform a matrix multiply-accumulate on the tile + T.gemm(A_shared, B_shared, C_local) + + # Copy the accumulated result from local memory (C_local) to global memory (C) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + + +def run(m, n, k): + a = torch.randn(m, k, device="cuda", dtype=torch.float16) + b = torch.randn(k, n, device="cuda", dtype=torch.float16) + c = torch.empty(m, n, device="cuda", dtype=torch.float16) + kernel = matmul(m, n, k, 128, 128, 32) + kernel(a, b, c) + ref_c = a @ b + tol = m * 2**(-11) + # Validate results + torch.testing.assert_close(c, ref_c, rtol=tol, atol=tol) + + +def call_tilelang(kernel_function, args, kwargs, grid, threads, params): + compiled_kernel = kernel_function(**kwargs) + compiled_kernel(*args) + + +def time(m, n, k): + a = torch.randn(m, k, device="cuda", dtype=torch.float16) + b = torch.randn(k, n, device="cuda", dtype=torch.float16) + c = torch.empty(m, n, device="cuda", dtype=torch.float16) + c_ans = a @ b + + args = [a, b, c] + tune_params = dict() + tune_params["M"] = [m] + tune_params["K"] = [k] + tune_params["N"] = [n] + tune_params["block_M"] = [64, 128] + tune_params["block_N"] = [64, 128] + tune_params["block_K"] = [32, 64] + + results_kt, env = tune_kernel("matmul", matmul, m * n, args, tune_params, lang="generic_python", + call_function=call_tilelang, decorator="@tilelang.jit", verbose=False, iterations=100) + + import time + num_repeats = 100 + times_direct = [] + for config in results_kt: + bs_m = config["block_M"] + bs_n = config["block_N"] + bs_k = config["block_K"] + + c = torch.empty(m, n, device="cuda", dtype=torch.float16) + + kernel = matmul_with_decorator(m, n, k, bs_m, bs_n, bs_k) + kernel(a, b, c) + + torch.allclose(c.cpu(), c_ans.cpu(), atol=m * 2**(-11)) + + for i in range(num_repeats): + times = [] + + torch.cuda.synchronize() + start = time.time() + kernel(a, b, c) + torch.cuda.synchronize() + times.append(time.time() - start) + + avg_time_ms = round((1000 * sum(times) / len(times)), 3) + times_direct.append(avg_time_ms) + print(f"BLOCK_SIZE_M={bs_m}, BLOCK_SIZE_N={bs_n}, BLOCK_SIZE_K={bs_k}, time={avg_time_ms}ms") + + + import matplotlib.pyplot as plt + + # Extract times + times_kt = [cfg['time'] for cfg in results_kt] + + # x-axis labels + configs = [f"config{i}" for i in range(len(times_kt))] + x = range(len(configs)) + + plt.figure(figsize=(10,6)) + plt.plot(configs, times_kt, marker='s', label='KernelTuner') + plt.plot(configs, times_direct, marker='x', label='Direct') + plt.ylabel('Time (ms)') + plt.xlabel('Configuration') + plt.title('Kernel execution time per configuration') + plt.xticks(rotation=45) + plt.legend() + plt.grid(True) + plt.tight_layout() + plt.savefig("tilelang.png") + print("saved fig") + + +def tune(m, n, k): + a = torch.randn(m, k, device="cuda", dtype=torch.float16) + b = torch.randn(k, n, device="cuda", dtype=torch.float16) + c = torch.empty(m, n, device="cuda", dtype=torch.float16) + c_actual = a @ b + + args = [a, b, c] + tune_params = dict() + tune_params["M"] = [m] + tune_params["K"] = [k] + tune_params["N"] = [n] + tune_params["block_M"] = [64, 128, 256] + tune_params["block_N"] = [64, 128, 256] + tune_params["block_K"] = [32, 64, 128] + tune_params["num_stages"] = [2, 3, 4] + tune_params["panel_size"] = [4, 8] # equivalent to group size m in Triton + tune_params["num_threads"] = [64, 128, 256] + + restrictions = [ + # tile size budget + "block_M * block_N <= 16384", + + # aspect ratio <= 4 (no max/min allowed, so expand manually) + "block_M <= 4 * block_N", + "block_N <= 4 * block_M", + + # large K only with reasonably large M/N + "not (block_K == 128 and block_M < 64 and block_N < 64)", + ] + + tol = m * 2**(-11) + answer = [None, None, c_actual.cpu()] + + results, env = tune_kernel("matmul", matmul, m * n, args, tune_params, atol=tol, lang="generic_python", + call_function=call_tilelang, restrictions=restrictions, answer=answer, decorator="@tilelang.jit", verbose=False) + +if __name__ == "__main__": + #m, n, k = 1024, 1024, 1024 + m, n, k = 8192, 8192, 8192 + time(m, n, k) \ No newline at end of file diff --git a/examples/generic_python/matmul/tilus_matmul.py b/examples/generic_python/matmul/tilus_matmul.py new file mode 100644 index 000000000..07532013f --- /dev/null +++ b/examples/generic_python/matmul/tilus_matmul.py @@ -0,0 +1,267 @@ +import math + +import pandas +import tilus +import torch +from tilus import float16, float32, int32 +from tilus.utils import benchmark_func +from kernel_tuner import tune_kernel, run_kernel + + + +class MatmulV4(tilus.Script): + def __init__(self): + super().__init__() + self.block_m = 128 + self.block_n = 128 + self.block_k = 16 + self.num_warps = 4 + self.num_stages = 4 + + def __call__( + self, + m_size: int32, + n_size: int, + k_size: int, + a_ptr: ~float16, + b_ptr: ~float16, + c_ptr: ~float16, + ): + self.attrs.blocks = [ + self.utils.ceil_div(m_size, self.block_m), + self.utils.ceil_div(n_size, self.block_n), + ] + self.attrs.warps = self.num_warps + + block_m, block_n, block_k = self.block_m, self.block_n, self.block_k + offset_m: int32 = block_m * self.blockIdx.x + offset_n: int32 = block_n * self.blockIdx.y + + ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) + gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size]) + sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k]) + sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_k, block_n]) + acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) + + for stage in range(self.num_stages - 1): + offset_k = stage * self.block_k + self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k]) + self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n]) + self.copy_async_commit_group() + + self.copy_async_wait_group(n=self.num_stages - 2) + self.sync() + + current_stage: int32 = 0 + preload_stage: int32 = self.num_stages - 1 + for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages): + # computation for current tile + a = self.load_shared(sa[current_stage]) + b = self.load_shared(sb[current_stage]) + self.dot(a, b, acc, out=acc) + + # preload the next tile of A and B into shared memory + preload_offset_k = offset_k + (self.num_stages - 1) * block_k + self.copy_async( + src=ga, + dst=sa[preload_stage], + offsets=[offset_m, preload_offset_k], + ) + self.copy_async( + src=gb, + dst=sb[preload_stage], + offsets=[preload_offset_k, offset_n], + ) + self.copy_async_commit_group() + + # update the stage + current_stage = (current_stage + 1) % self.num_stages + preload_stage = (preload_stage + 1) % self.num_stages + self.copy_async_wait_group(n=self.num_stages - 2) + self.sync() + + self.free_shared(sa) + self.free_shared(sb) + + casted_acc = self.cast(acc, dtype=float16) + gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) + self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) + + + +class MatmulGroupedOrdering(tilus.Script): + def __init__(self): + super().__init__() + self.block_m = 128 + self.block_n = 128 + self.block_k = 16 + self.num_warps = 4 + self.num_stages = 4 + self.group_size_m = 8 + + def __call__( + self, + m_size: int32, + n_size: int, + k_size: int, + a_ptr: ~float16, + b_ptr: ~float16, + c_ptr: ~float16, + ): + block_m, block_n, block_k = self.block_m, self.block_n, self.block_k + + num_pid_m = self.utils.ceil_div(m_size, block_m) + num_pid_n = self.utils.ceil_div(n_size, block_n) + self.attrs.blocks = [num_pid_m * num_pid_n] + + pid = self.blockIdx.x + num_pid_in_group = self.group_size_m * num_pid_n + group_id = pid // num_pid_in_group + + first_pid_m = group_id * self.group_size_m + group_size_m = min(num_pid_m - first_pid_m, self.group_size_m) + + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + self.attrs.warps = self.num_warps + offset_m: int32 = pid_m * block_m + offset_n: int32 = pid_n * block_n + + ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) + gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size]) + sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k]) + sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_k, block_n]) + acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) + + for stage in range(self.num_stages - 1): + offset_k = stage * self.block_k + self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k]) + self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n]) + self.copy_async_commit_group() + + self.copy_async_wait_group(n=self.num_stages - 2) + self.sync() + + current_stage: int32 = 0 + preload_stage: int32 = self.num_stages - 1 + for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages): + # computation for current tile + a = self.load_shared(sa[current_stage]) + b = self.load_shared(sb[current_stage]) + self.dot(a, b, acc, out=acc) + + # preload the next tile of A and B into shared memory + preload_offset_k = offset_k + (self.num_stages - 1) * block_k + self.copy_async( + src=ga, + dst=sa[preload_stage], + offsets=[offset_m, preload_offset_k], + ) + self.copy_async( + src=gb, + dst=sb[preload_stage], + offsets=[preload_offset_k, offset_n], + ) + self.copy_async_commit_group() + + # update the stage + current_stage = (current_stage + 1) % self.num_stages + preload_stage = (preload_stage + 1) % self.num_stages + self.copy_async_wait_group(n=self.num_stages - 2) + self.sync() + + self.free_shared(sa) + self.free_shared(sb) + + casted_acc = self.cast(acc, dtype=float16) + gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) + self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) + + + +def main(): + headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] + workloads = [ + [4096, 4096, 4096], + [1024, 1024, 14336], + ] + + rows = [] + for m, n, k in workloads: + matmul = MatmulGroupedOrdering() #MatmulV4() + + a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + matmul(m, n, k, a, b, c_actual) + + # check correctness + torch.testing.assert_close(c_expect, c_actual) + + # benchmark + for name, func in [ + ("torch", lambda: torch.matmul(a, b, out=c_expect)), + ("tilus", lambda: matmul(m, n, k, a, b, c_actual)), + ]: + latency = benchmark_func(func, warmup=5, repeat=20) + tflops = 2 * m * n * k / latency * 1e-9 + rows.append([m, n, k, name, latency, tflops]) + + df = pandas.DataFrame(rows, columns=headers) + print(df) + +def call_tilus(kernel_function, args, kwargs, grid, threads, params): + kernel_function(*args, **kwargs) + + +def tune_matmul(m, n, k): + a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + + size = m * n #(m, n) + args = [m, n, k, a, b, c_actual] + tune_params = dict() + tune_params["block_m"] = [32, 64, 128, 256] + tune_params["block_n"] = [32, 64, 128, 256] + tune_params["block_k"] = [32, 64, 128] + tune_params["group_size_m"] = [4, 8] + tune_params["num_stages"] = [2, 3, 4] + tune_params["num_warps"] = [4, 8] + + + restrictions = [ + # tile size budget + "block_m * block_n <= 16384", + + # aspect ratio <= 4 (no max/min allowed, so expand manually) + "block_m <= 4 * block_n", + "block_n <= 4 * block_m", + + # large K only with reasonably large M/N + "not (block_k == 128 and block_m < 64 and block_n < 64)", + + # 32x32 requires 8 warps + "not (block_m == 32 and block_n == 32 and num_warps < 8)", + ] + + + answer = [None] * 6 + answer[-1] = c_expect.cpu() + atol = 1e-2 #m * 2**(-11) + + results, env = tune_kernel("MatmulGroupedOrdering", MatmulGroupedOrdering, size, args, tune_params, grid_div_x = ["block_m", "block_n"], + answer = answer, atol=atol, restrictions=restrictions, + lang="generic_python", call_function=call_tilus, + block_size_names=["block_m", "block_n", "block_k"], strategy="simulated_annealing") + + +if __name__ == "__main__": + #m, n, k = 4096, 4096, 4096 + m, n, k = 8192, 8192, 8192 + tune_matmul(m, n, k) + + #main() \ No newline at end of file diff --git a/examples/generic_python/matmul/triton_matmul.py b/examples/generic_python/matmul/triton_matmul.py index 7115ea5cc..238a61c9f 100644 --- a/examples/generic_python/matmul/triton_matmul.py +++ b/examples/generic_python/matmul/triton_matmul.py @@ -3,7 +3,8 @@ import triton import triton.language as tl -DEVICE = triton.runtime.driver.active.get_active_torch_device() +from kernel_tuner import tune_kernel +from kernel_tuner import run_kernel def get_cuda_autotune_config(): return [ @@ -48,7 +49,7 @@ def get_cuda_autotune_config(): key=['M', 'N', 'K'], ) ''' -@triton.jit +#@triton.jit def matmul_kernel( # Pointers to matrices a_ptr, b_ptr, c_ptr, @@ -62,7 +63,7 @@ def matmul_kernel( stride_cm, stride_cn, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # - GROUP_SIZE_M: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, num_stages, num_warps# ): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) @@ -132,23 +133,251 @@ def matmul_kernel( tl.store(c_ptrs, c, mask=c_mask) -def matmul(a, b): - # Check constraints. - assert a.shape[1] == b.shape[0], "Incompatible dimensions" - assert a.is_contiguous(), "Matrix A must be contiguous" - M, K = a.shape - K, N = b.shape - # Allocates output. - c = torch.empty((M, N), device=a.device, dtype=torch.float16) - # 1D launch kernel where each block gets its own program. - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) +def run_matmul(m, n, k): + a = torch.rand(m, k, dtype=torch.float16).cuda() + b = torch.rand(k, n, dtype=torch.float16).cuda() + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + + grid = lambda META: (triton.cdiv(m, META['BLOCK_SIZE_M']) * triton.cdiv(n, META['BLOCK_SIZE_N']), ) + matmul_kernel[grid]( - a, b, c, # - M, N, K, # + a, b, c_actual, # + m, n, k, # a.stride(0), a.stride(1), # b.stride(0), b.stride(1), # - c.stride(0), c.stride(1) + c_actual.stride(0), c_actual.stride(1), + 128, 256, 64, 8 ) - return c + + torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2) + + +def run_matmul_kt(m, n, k): + a = torch.rand(m, k, dtype=torch.float16).cuda() + b = torch.rand(k, n, dtype=torch.float16).cuda() + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + + size = m * n + + args = [a, b, c_actual, m, n, k, a.stride(0), a.stride(1), b.stride(0), b.stride(1), + c_actual.stride(0), c_actual.stride(1)] + + params = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M":4, "num_stages":3, "num_warps":4} + + result = run_kernel("matmul_kernel", matmul_kernel, size, args, params=params, grid_div_x=["BLOCK_SIZE_N", "BLOCK_SIZE_M"], + lang="generic_python", decorator="@triton.jit", call_function=call_triton, + block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"]) + c_res = result[2] + + assert torch.allclose(c_res, c_expect.cpu(), atol=1e-2, rtol=1e-1) + + + + + + +def call_triton(kernel_function, args, kwargs, grid, threads, params): + #print("using grid: ", grid) + #print("args: ", args) + #print("kwargs: ", kwargs) + kernel_function[grid](*args, **kwargs) + + + + +def check_time(m, n, k): + a = torch.rand(m, k, dtype=torch.float16).cuda() + b = torch.rand(k, n, dtype=torch.float16).cuda() + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + + size = m * n + args = [a, b, c_actual, m, n, k, a.stride(0), a.stride(1), b.stride(0), b.stride(1), + c_actual.stride(0), c_actual.stride(1)] + tune_params = dict() + tune_params["BLOCK_SIZE_M"] = [64, 128] + tune_params["BLOCK_SIZE_N"] = [64, 128] + tune_params["BLOCK_SIZE_K"] = [32, 64] + tune_params["GROUP_SIZE_M"] = [4, 8] + tune_params["num_stages"] = [3] + tune_params["num_warps"] = [4] + + restrictions = [ + # tile size budget + "BLOCK_SIZE_M * BLOCK_SIZE_N <= 16384", + + # aspect ratio <= 4 (no max/min allowed, so expand manually) + "BLOCK_SIZE_M <= 4 * BLOCK_SIZE_N", + "BLOCK_SIZE_N <= 4 * BLOCK_SIZE_M", + + # large K only with reasonably large M/N + "not (BLOCK_SIZE_K == 128 and BLOCK_SIZE_M < 64 and BLOCK_SIZE_N < 64)", + + # 32x32 requires 8 warps + "not (BLOCK_SIZE_M == 32 and BLOCK_SIZE_N == 32 and num_warps < 8)", + ] + + grid_div = ["BLOCK_SIZE_N", "BLOCK_SIZE_M"] + + answer = [None] * 12 + answer[2] = c_expect.cpu() + atol = 1e-2 #m * 2**(-11) + + + + results_ours, _ = tune_kernel("matmul_kernel", matmul_kernel, size, args, tune_params, grid_div_x = grid_div, + restrictions=restrictions, iterations=100, answer=answer, atol=atol, + lang="generic_python", decorator="@triton.jit", call_function=call_triton, + block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"]) + + + results_prev, _ = tune_kernel("matmul_kernel", matmul_kernel, size, args, tune_params, grid_div_x = grid_div, + restrictions=restrictions, iterations=100, answer=answer, atol=atol, + lang="triton", block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"]) + + + + import time + num_repeats = 100 + times_direct = [] + for config in results_prev: + + bs_m = config["BLOCK_SIZE_M"] + bs_n = config["BLOCK_SIZE_N"] + bs_k = config["BLOCK_SIZE_K"] + gs_m = config["GROUP_SIZE_M"] + num_stages = config["num_stages"] + num_warps = config["num_warps"] + + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + + grid = (triton.cdiv(m, bs_m) * triton.cdiv(n, bs_n, ), ) + jit_function = triton.jit(matmul_kernel) + + + jit_function[grid]( + a, b, c_actual, + m, n, k, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c_actual.stride(0), c_actual.stride(1), + bs_m, bs_n, bs_k, gs_m, num_stages, num_warps + ) + + + torch.allclose(c_expect.cpu(), c_actual.cpu(), atol=1e-2) + + + + for i in range(num_repeats): + times = [] + + #c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + + torch.cuda.synchronize() + start = time.time() + jit_function[grid]( + a, b, c_actual, + m, n, k, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c_actual.stride(0), c_actual.stride(1), + bs_m, bs_n, bs_k, gs_m, num_stages, num_warps + ) + + torch.cuda.synchronize() + times.append(time.time() - start) + + avg_time_ms = round((1000 * sum(times) / len(times)), 3) + times_direct.append(avg_time_ms) + print(f"BLOCK_SIZE_M={bs_m}, BLOCK_SIZE_N={bs_n}, BLOCK_SIZE_K={bs_k}, GROUP_SIZE_M={gs_m}, num_stages={num_stages}, num_warps={num_warps}, time={avg_time_ms}ms") + + + import matplotlib.pyplot as plt + + # Extract times + times_prev = [cfg['time'] for cfg in results_prev] + times_ours = [cfg['time'] for cfg in results_ours] + + # x-axis labels + configs = [f"config{i}" for i in range(len(times_prev))] + x = range(len(configs)) + + plt.figure(figsize=(10,6)) + plt.plot(configs, times_prev, marker='o', label='Triton tuned') + plt.plot(configs, times_ours, marker='s', label='Generic tuned') + plt.plot(configs, times_direct, marker='x', label='Direct') + plt.ylabel('Time (ms)') + plt.xlabel('Configuration') + plt.title('Kernel execution time per configuration') + plt.xticks(rotation=45) + plt.legend() + plt.grid(True) + plt.tight_layout() + plt.savefig("ouptut.png") + + + + + + +def tune_matmul(m, n, k): + a = torch.rand(m, k, dtype=torch.float16).cuda() + b = torch.rand(k, n, dtype=torch.float16).cuda() + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + + size = m * n + args = [a, b, c_actual, m, n, k, a.stride(0), a.stride(1), b.stride(0), b.stride(1), + c_actual.stride(0), c_actual.stride(1)] + tune_params = dict() + tune_params["BLOCK_SIZE_M"] = [64, 128, 256] + tune_params["BLOCK_SIZE_N"] = [64, 128, 256] + tune_params["BLOCK_SIZE_K"] = [32, 64, 128] + tune_params["GROUP_SIZE_M"] = [4, 8] + tune_params["num_stages"] = [2, 3, 4] + tune_params["num_warps"] = [4, 8] + + restrictions = [ + # tile size budget + "BLOCK_SIZE_M * BLOCK_SIZE_N <= 16384", + + # aspect ratio <= 4 (no max/min allowed, so expand manually) + "BLOCK_SIZE_M <= 4 * BLOCK_SIZE_N", + "BLOCK_SIZE_N <= 4 * BLOCK_SIZE_M", + + # large K only with reasonably large M/N + "not (BLOCK_SIZE_K == 128 and BLOCK_SIZE_M < 64 and BLOCK_SIZE_N < 64)", + + # 32x32 requires 8 warps + "not (BLOCK_SIZE_M == 32 and BLOCK_SIZE_N == 32 and num_warps < 8)", + ] + + grid_div = ["BLOCK_SIZE_N", "BLOCK_SIZE_M"] + + answer = [None] * 12 + answer[2] = c_expect.cpu() + + results, env = tune_kernel("matmul_kernel", matmul_kernel, size, args, tune_params, grid_div_x = grid_div, + answer = answer, atol=4.0, restrictions=restrictions, + lang="generic_python", decorator="@triton.jit", call_function=call_triton, + block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"], strategy="simulated_annealing") + + + + + +if __name__ == "__main__": + m, n, k = 8192, 8192, 8192 + #m, n, k = 4096, 4096, 4096 + #tune_matmul(m, n, k) + #check_time(m, n, k) + #run_matmul_kt(m, n, k) + + + + diff --git a/examples/generic_python/numba_vec_add.py b/examples/generic_python/numba_vec_add.py index 3732aa018..e49f9e372 100644 --- a/examples/generic_python/numba_vec_add.py +++ b/examples/generic_python/numba_vec_add.py @@ -1,7 +1,8 @@ -import numpy as np from numba import cuda from math import ceil from kernel_tuner import tune_kernel, run_kernel +import torch + #@cuda.jit def f(a, b, c): @@ -13,7 +14,13 @@ def f(a, b, c): def call_numba(kernel_function, args, kwargs, grid, threads, params): - kernel_function[grid[0], threads[0]](*args, **kwargs) + numba_args = [] + for arg in args: + if isinstance(arg, torch.Tensor): + numba_args.append(cuda.as_cuda_array(arg)) + else: + numba_args.append(arg) + kernel_function[grid, threads](*args, **kwargs) def verify(answer, result_host, atol): @@ -28,15 +35,25 @@ def verify(answer, result_host, atol): return correct +import numpy as np N = 100000000 -a = cuda.to_device(np.random.random(N)) -b = cuda.to_device(np.random.random(N)) -c = cuda.device_array_like(a) -c_expect = a.copy_to_host() + b.copy_to_host() + +a = np.random.random(N) +b = np.random.random(N) +c = np.zeros(N) +c_expect = a + b +#c_expect = a.copy_to_host() + b.copy_to_host() + + +#a_torch = torch.rand(N, dtype=torch.float32) +#b_torch = torch.rand(N, dtype=torch.float32) +#c_torch = torch.zeros(N, dtype=torch.float32) +#c_expect = (a_torch + b_torch).cpu() args = [a, b, c] tune_params = {"block_size_x": [2**i for i in range(10)]} +''' results = run_kernel( kernel_name="f", kernel_source=f, @@ -47,11 +64,12 @@ def verify(answer, result_host, atol): call_function=call_numba, decorator="@cuda.jit" ) +''' -print(np.allclose(results[2], c_expect)) +#print(np.allclose(results[2], c_expect)) + -''' results, env = tune_kernel( kernel_name="f", kernel_source=f, @@ -60,9 +78,9 @@ def verify(answer, result_host, atol): tune_params=tune_params, lang="generic_python", answer=[None, None, c_expect], - verify=verify, + #verify=verify, call_function=call_numba, decorator="@cuda.jit" ) -''' + diff --git a/examples/generic_python/tilelang_vec_add.py b/examples/generic_python/tilelang_vec_add.py new file mode 100644 index 000000000..3ff3c9376 --- /dev/null +++ b/examples/generic_python/tilelang_vec_add.py @@ -0,0 +1,57 @@ +import tilelang +import tilelang.language as T +import torch +from kernel_tuner import tune_kernel + +#@tilelang.jit # infers target from tensors at first call +def add(N: int, dtype: str = 'float32', block: int = 256,): + + @T.prim_func + def add_kernel( + A: T.Tensor((N,), dtype), + B: T.Tensor((N,), dtype), + C: T.Tensor((N,), dtype), + ): + with T.Kernel(T.ceildiv(N, block), threads=block) as bx: + for i in T.Parallel(block): + gi = bx * block + i + # Optional — LegalizeSafeMemoryAccess inserts a guard when an access may be OOB + C[gi] = A[gi] + B[gi] + + return add_kernel + + +def run_normal(): + # Host side (PyTorch shown; NumPy/DLPack also supported) + N = 1 << 20 + A = torch.randn(N, device='cuda', dtype=torch.float32) + B = torch.randn(N, device='cuda', dtype=torch.float32) + C = torch.empty(N, device='cuda', dtype=torch.float32) + + kernel = add(N) + kernel(A, B, C) # runs on GPU + torch.testing.assert_close(C, A + B) + print("done") + + +def call_tilelang(kernel_function, args, kwargs, grid, threads, params): + compiled_kernel = kernel_function(**kwargs) # cached, so second time only cache lookup is performed + compiled_kernel(*args) + +def tune(): + N = 1 << 20 + A = torch.randn(N, device='cuda', dtype=torch.float32) + B = torch.randn(N, device='cuda', dtype=torch.float32) + C = torch.empty(N, device='cuda', dtype=torch.float32) + + args = [A, B, C] + tune_params = dict() + tune_params["block"] = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] + tune_params["N"] = [N] # Not a tune param, but enables an easier call function + + answer = [None, None, (A + B).cpu()] + + res, env = tune_kernel("add", add, N, args, tune_params, lang="generic_python", + call_function=call_tilelang, decorator="@tilelang.jit", answer=answer) + +tune() \ No newline at end of file diff --git a/examples/generic_python/tilus_vec_add.py b/examples/generic_python/tilus_vec_add.py index dce785909..161575181 100644 --- a/examples/generic_python/tilus_vec_add.py +++ b/examples/generic_python/tilus_vec_add.py @@ -41,13 +41,10 @@ def call_tilus(kernel_function, args, kwargs, grid, threads, params): kernel_function(*args, **kwargs) -def main(): - - size = 1024000 - - a = torch.randn(size, dtype=torch.float32) - b = torch.randn(size, dtype=torch.float32) - c = torch.zeros_like(b) +def tune_vecadd(size): + a = torch.randn(size, dtype=torch.float32).cuda() + b = torch.randn(size, dtype=torch.float32).cuda() + c = torch.empty(size, dtype=torch.float32).cuda() c_expect = a + b @@ -55,8 +52,6 @@ def main(): tune_params = dict() tune_params["block_size_x"] = [32, 64, 128, 256, 512, 1024] - - results, env = tune_kernel( kernel_name="VecAddV", kernel_source=VecAddV, @@ -69,7 +64,50 @@ def main(): verbose=True, ) + +# TODO run kernel error handling same as tune_kernel +def run_vecadd(size): + a = torch.randn(size, dtype=torch.float32).cuda() + b = torch.randn(size, dtype=torch.float32).cuda() + c = torch.empty(size, dtype=torch.float32).cuda() + c_expect = a + b + args = [size, a, b, c] + + results = run_kernel( + kernel_name="VecAddV", + kernel_source=VecAddV, + problem_size=size, + arguments=args, + params={"block_size_x": 32}, + lang="generic_python", + call_function=call_tilus, + #verbose=True, + ) + + c_expect = c_expect.cpu() + + assert torch.allclose(results[-1], c_expect) + + + +def run_normal(size): + a = torch.randn(size, dtype=torch.float32).cuda() + b = torch.randn(size, dtype=torch.float32).cuda() + c = torch.empty(size, dtype=torch.float32).cuda() + c_expect = a + b + + vecadd = VecAddV() + vecadd(size, a, b, c) + print(c) + print(c_expect) + assert torch.allclose(c, c_expect) + + + if __name__ == "__main__": - main() + size = 1024 + tune_vecadd(size) + #run_vecadd(size) + #run_normal(size) diff --git a/examples/generic_python/triton_vec_add.py b/examples/generic_python/triton_vec_add.py index be28f2003..fcc019615 100644 --- a/examples/generic_python/triton_vec_add.py +++ b/examples/generic_python/triton_vec_add.py @@ -12,7 +12,7 @@ def add_op(x, y): return x + y -#triton.jit +#@triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. y_ptr, # *Pointer* to second input vector. output_ptr, # *Pointer* to output vector. @@ -50,11 +50,12 @@ def tune_with_generic(): tune_params = dict() tune_params["block_size_x"] = [2**i for i in range(10)] - ''' + result = run_kernel("add_kernel", add_kernel, size, args, {"block_size_x": 256}, lang="generic_python", call_function=call_triton, decorator="@triton.jit") print(np.allclose(c_expect.cpu(), result[2])) - ''' + + results, env = tune_kernel( kernel_name="add_kernel", @@ -69,6 +70,7 @@ def tune_with_generic(): ) + diff --git a/examples/generic_python/warp_vec_add.py b/examples/generic_python/warp_vec_add.py index dc08dd6d0..4a2c8d350 100644 --- a/examples/generic_python/warp_vec_add.py +++ b/examples/generic_python/warp_vec_add.py @@ -1,6 +1,9 @@ import warp as wp import numpy as np from kernel_tuner import tune_kernel, run_kernel +import torch + +wp.init() @@ -13,9 +16,9 @@ def add_op(x: float, y: float): def vec_add(a: wp.array(dtype=float), b: wp.array(dtype=float), c: wp.array(dtype=float), - n: int, - work_per_thread: int): + n: int): + work_per_thread = 8 tid = wp.tid() base = tid * work_per_thread @@ -28,10 +31,14 @@ def vec_add(a: wp.array(dtype=float), # TODO do we allways want the call function to have the same parameters # or do we only require some of them? def call_warp(kernel_function, args, kwargs, grid, threads, params): - final_args = list(args) - final_args.extend(kwargs.values()) + warp_args = [] + for arg in args: + if isinstance(arg, torch.Tensor): + warp_args.append(wp.from_torch(arg)) + else: + warp_args.append(arg) dim = args[3] - wp.launch(kernel=kernel_function, dim=dim, inputs=final_args) + wp.launch(kernel=kernel_function, dim=dim, inputs=warp_args) # NOTE default verify function only works for numpy/cupy ndarray, torch Tensor or numpy scalar @@ -41,6 +48,8 @@ def verify(answer, result_host, atol): for i, ans in enumerate(answer): if ans is None: continue + print("res: ", type(result_host[i])) + print("expect: ", type(ans)) res = result_host[i].numpy() if not np.allclose(ans, res, atol=atol): correct = False @@ -53,19 +62,15 @@ def tune(): n = 1024 # Create host arrays - a_np = np.arange(n, dtype=np.float32) - b_np = np.arange(n, 0, -1, dtype=np.float32) - c_np = np.zeros(n, dtype=np.float32) - c_expect = a_np + b_np + a_torch = torch.arange(n, dtype=torch.float32, device="cuda") + b_torch = torch.arange(n, 0, -1, dtype=torch.float32, device="cuda") + c_torch = torch.zeros(n, dtype=torch.float32, device="cuda") + c_expect = a_torch + b_torch - # Create Warp arrays on GPU - a = wp.array(a_np, dtype=float) - b = wp.array(b_np, dtype=float) - c = wp.array(c_np, dtype=float) tune_params = dict() tune_params["work_per_thread"] = [2**i for i in range(10)] - args = [a, b, c, n] + args = [a_torch, b_torch, c_torch, n] ''' @@ -88,8 +93,7 @@ def tune(): arguments=args, tune_params=tune_params, lang="generic_python", - answer=[None, None, c_expect, None], - verify=verify, + answer=[None, None, c_expect.cpu(), None], call_function=call_warp, decorator="@wp.kernel" ) diff --git a/examples/triton/conv2d_tuning.py b/examples/triton/conv2d_tuning.py deleted file mode 100644 index 5389b4a28..000000000 --- a/examples/triton/conv2d_tuning.py +++ /dev/null @@ -1,306 +0,0 @@ -import torch -import triton.language as tl -import numpy as np -from kernel_tuner.interface import tune_kernel -import os -import json -from datetime import datetime - -# Check for required environment variable -cache_dir = os.getenv('KERNEL_TUNER_CACHE_DIR') -cache_file_name = os.getenv('KERNEL_TUNER_CACHE_FILE', 'conv2d_tuning_results.json') - -if cache_dir is None: - raise ValueError("Environment variable KERNEL_TUNER_CACHE_DIR must be set") - -cache_file = os.path.join(cache_dir, cache_file_name) - - -def conv2d_output_size( - in_size: int, - kernel_size: int, - stride: int, - padding: int, - dilation: int, -) -> int: - """ - Determines the output size of a 2D convolution operation. - - Args: - in_size: Input size. - kernel_size: Kernel size. - stride: Stride. - padding: Padding. - - Returns: - Output size of 2D convolution. - """ - return (in_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1 - - -def conv2d_forward_kernel( - input_pointer, - weight_pointer, - output_pointer, - in_n, - input_height, - input_width, - out_c, - out_height, - out_width, - input_n_stride, - input_c_stride, - input_height_stride, - input_width_stride, - weight_n_stride, - weight_c_stride, - weight_height_stride, - weight_width_stride, - output_n_stride, - output_c_stride, - output_height_stride, - output_width_stride, - weight_c: tl.constexpr, - weight_height: tl.constexpr, - weight_width: tl.constexpr, - stride_height: tl.constexpr, - stride_width: tl.constexpr, - padding_height: tl.constexpr, - padding_width: tl.constexpr, - dilation_height: tl.constexpr, - dilation_width: tl.constexpr, - groups: tl.constexpr, - BLOCK_NI_HO_WO: tl.constexpr, - BLOCK_CI: tl.constexpr, - BLOCK_CO: tl.constexpr, -): - pid_ni_ho_wo = tl.program_id(0) - pid_co = tl.program_id(1) - pid_group = tl.program_id(2) - - # caculate in_n out_height out_weight value in kernel - ni_ho_wo_offset = pid_ni_ho_wo * BLOCK_NI_HO_WO + tl.arange(0, BLOCK_NI_HO_WO) - ni_ho_offset = ni_ho_wo_offset // out_width - in_n_point_value = ni_ho_offset // out_height - output_height_point_value = ni_ho_offset % out_height - output_width_point_value = ni_ho_wo_offset % out_width - - # Load the input and weight pointers. input and weight are of shape - # [in_n, groups, in_c, input_height, input_width] and [groups, out_c, in_c, weight_height, weight_width] - out_per_group_c = out_c // groups - output_c_offset = pid_co * BLOCK_CO + tl.arange(0, BLOCK_CO) - input_pointer += ( - input_n_stride * in_n_point_value + input_c_stride * pid_group * weight_c - )[:, None] - weight_pointer += ( - weight_n_stride * output_c_offset - + weight_n_stride * pid_group * out_per_group_c - )[None, :] - - accum = tl.zeros((BLOCK_NI_HO_WO, BLOCK_CO), dtype=tl.float32) - BLOCK_CI_COUNT = (weight_c + BLOCK_CI - 1) // BLOCK_CI - for hwc in range(weight_height * weight_width * BLOCK_CI_COUNT): - c = (hwc % BLOCK_CI_COUNT) * BLOCK_CI - hw = hwc // BLOCK_CI_COUNT - h = hw // weight_width - w = hw % weight_width - - input_c_offset = c + tl.arange(0, BLOCK_CI) - input_height_offset = ( - h * dilation_height - - padding_height - + stride_height * output_height_point_value - ) - input_width_offset = ( - w * dilation_width - padding_width + stride_width * output_width_point_value - ) - - curr_input_pointer = ( - input_pointer - + (input_c_stride * input_c_offset)[None, :] - + (input_height_stride * input_height_offset)[:, None] - + (input_width_stride * input_width_offset)[:, None] - ) - curr_weight_pointer = ( - weight_pointer - + (weight_c_stride * input_c_offset)[:, None] - + (weight_height_stride * h) - + (weight_width_stride * w) - ) - - input_mask = ( - (in_n_point_value < in_n)[:, None] - & (input_c_offset < weight_c)[None, :] - & (0 <= input_height_offset)[:, None] - & (input_height_offset < input_height)[:, None] - & (0 <= input_width_offset)[:, None] - & (input_width_offset < input_width)[:, None] - ) - weight_mask = (input_c_offset < weight_c)[:, None] & ( - output_c_offset < out_per_group_c - )[None, :] - - input_block = tl.load(curr_input_pointer, mask=input_mask) - weight_block = tl.load(curr_weight_pointer, mask=weight_mask) - - accum += tl.dot(input_block, weight_block, allow_tf32=False) - - output_pointer += ( - (output_n_stride * in_n_point_value)[:, None] - + (output_c_stride * (pid_group * out_per_group_c + output_c_offset))[None, :] - + (output_height_stride * output_height_point_value)[:, None] - + (output_width_stride * output_width_point_value)[:, None] - ) - output_mask = ( - (in_n_point_value < in_n)[:, None] - & (output_c_offset < out_per_group_c)[None, :] - & (output_height_point_value < out_height)[:, None] - & (output_width_point_value < out_width)[:, None] - ) - - tl.store(output_pointer, accum, mask=output_mask) - - -def tune_conv2d(batch_size=1, in_channels=64, height=32, width=32, - out_channels=128, kernel_size=3, stride=1, padding=1, - groups=1): - """ - Tune the conv2d kernel with different configurations. - """ - # Create sample inputs - input = torch.randn(batch_size, in_channels, height, width, - device='cuda', dtype=torch.float32) - weight = torch.randn(out_channels, in_channels//groups, kernel_size, kernel_size, - device='cuda', dtype=torch.float32) - - # Calculate output dimensions - out_height = conv2d_output_size(height, kernel_size, stride, padding, 1) - out_width = conv2d_output_size(width, kernel_size, stride, padding, 1) - output = torch.empty((batch_size, out_channels, out_height, out_width), - device='cuda', dtype=torch.float32) - - # Prepare all arguments for the kernel - arguments = [ - input, weight, output, - np.int32(batch_size), - np.int32(height), - np.int32(width), - np.int32(out_channels), - np.int32(out_height), - np.int32(out_width), - np.int32(input.stride(0)), - np.int32(input.stride(1)), - np.int32(input.stride(2)), - np.int32(input.stride(3)), - np.int32(weight.stride(0)), - np.int32(weight.stride(1)), - np.int32(weight.stride(2)), - np.int32(weight.stride(3)), - np.int32(output.stride(0)), - np.int32(output.stride(1)), - np.int32(output.stride(2)), - np.int32(output.stride(3)), - np.int32(in_channels//groups), # weight_c - np.int32(kernel_size), # weight_height - np.int32(kernel_size), # weight_width - np.int32(stride), # stride_height - np.int32(stride), # stride_width - np.int32(padding), # padding_height - np.int32(padding), # padding_width - np.int32(1), # dilation_height - np.int32(1), # dilation_width - np.int32(groups), # groups - ] - - # Define tuning parameters - only powers of 2 - tune_params = { - 'BLOCK_NI_HO_WO': [2 ** i for i in range(4, 10)], - 'BLOCK_CI': [2 ** i for i in range(4, 10)], - 'BLOCK_CO': [2 ** i for i in range(4, 10)], - 'num_stages': [1, 2, 3, 4], - 'num_warps': [1, 2, 4, 8], - } - - print(tune_params) - - # Define constraints - constraints = [ - "BLOCK_CI <= %d" % (in_channels//groups), - "BLOCK_CO <= %d" % out_channels, - ] - - # Problem size for the grid - problem_size = ( - batch_size * out_height * out_width, # Grid dimension 0 - out_channels, # Grid dimension 1 - groups, # Grid dimension 2 - ) - - # Grid divisor expressions - grid_div_x = ["BLOCK_NI_HO_WO"] - grid_div_y = ["BLOCK_CO"] - grid_div_z = ["1"] - - results, env = tune_kernel( - kernel_name='conv2d_forward_kernel', - kernel_source=conv2d_forward_kernel, - problem_size=problem_size, - arguments=arguments, - tune_params=tune_params, - restrictions=constraints, - lang='TRITON', - grid_div_x=grid_div_x, - grid_div_y=grid_div_y, - grid_div_z=grid_div_z, - block_size_names=['BLOCK_NI_HO_WO', 'BLOCK_CI', 'BLOCK_CO'], - strategy='genetic_algorithm', - strategy_options={ - 'maxiter': 1000, - 'popsize': 100, - }, - cache=cache_file, - ) - - return results - - -if __name__ == '__main__': - # Run tuning with moderately large input dimensions - results = tune_conv2d( - batch_size=16, - in_channels=128, - height=112, - width=112, - out_channels=256, - kernel_size=3, - stride=1, - padding=1, - groups=1 - ) - - - # Filter out failed compilations and find best config - valid_results = [result for result in results if isinstance(result['time'], (int, float))] - if valid_results: - best_config = min(valid_results, key=lambda x: x['time']) - print("\nBest configuration:") - print(json.dumps(best_config, indent=2)) - else: - print("\nNo valid configurations found - all compilations failed") - - # Create results dictionary with GPU info - all_results = { - "gpu_info": { - "gpu_name": torch.cuda.get_device_name() - }, - "results": valid_results - } - - # Add timestamp to filename - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - output_file = f'conv2d_results_{timestamp}.json' - - # Save results - import json - with open(output_file, 'w') as f: - json.dump(all_results, f, indent=2) \ No newline at end of file diff --git a/examples/triton/vec_add.py b/examples/triton/vec_add.py deleted file mode 100644 index 42c3e52c7..000000000 --- a/examples/triton/vec_add.py +++ /dev/null @@ -1,50 +0,0 @@ -import numpy -import triton.language as tl -import torch -from kernel_tuner import tune_kernel, run_kernel -from kernel_tuner.file_utils import store_output_file, store_metadata_file -import triton - -#@triton.jit -def add_kernel(x_ptr, # *Pointer* to first input vector. - y_ptr, # *Pointer* to second input vector. - output_ptr, # *Pointer* to output vector. - n_elements, # Size of the vector. - block_size_x: tl.constexpr, # Number of elements each program should process. - # note: `constexpr` so it can be used as a shape value. - ): - pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. - block_start = pid * block_size_x - offsets = block_start + tl.arange(0, block_size_x) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.store(output_ptr + offsets, output, mask=mask) - - - -def tune_with_triton(): - size = 10000000 - - a = torch.randn(size, device='cuda', dtype=torch.float32) - b = torch.randn(size, device='cuda', dtype=torch.float32) - c = torch.empty_like(b) - n = torch.tensor(size, dtype=torch.int32) - - args = [a, b, c, n] - - tune_params = dict() - tune_params["block_size_x"] = [2**i for i in range(10)] - - results, env = tune_kernel( - kernel_name="add_kernel", - kernel_source=add_kernel, - problem_size=size, - arguments=args, - tune_params=tune_params, - lang="triton" - ) - - -tune_with_triton() \ No newline at end of file diff --git a/kernel_tuner/backends/generic_python.py b/kernel_tuner/backends/generic_python.py index fbe0f0750..5d867fdac 100644 --- a/kernel_tuner/backends/generic_python.py +++ b/kernel_tuner/backends/generic_python.py @@ -3,6 +3,9 @@ import copy import traceback # for compile error handling import re +import warnings +import builtins +import numpy as np from kernel_tuner.backends.backend import GPUBackend from kernel_tuner.observers.generic_python import GenericPythonRuntimeObserver @@ -13,22 +16,29 @@ torch = None -# TODO delete temp file +class GenericPythonFunctions(GPUBackend): + """Class that groups the Python DSL functions on maintains state about the device.""" + def __init__(self, device=0, iterations=7, compiler_options=None, observers=None): + """Instantiate GenericPythonFunctions object used for interacting with the device. + Currently, only CUDA devices are supported. + Instantiating this object will inspect and store certain device properties at + runtime, which are used during compilation and/or execution of kernels by the + kernel tuner. Compiler options are not supported for GenericPython. -class GenericPythonFunctions(GPUBackend): + :param device: Number of CUDA device to use for this context + :type device: int + + :param iterations: Number of iterations used while benchmarking a kernel, 7 by default. + :type iterations: int + """ - def __init__(self, device=0, iterations=7, compiler_options=None, observers=None): - ''' - In here, everyting is generic if the language uses CUDA as backend - ''' if not torch: logging.error("Unable to import Torch") raise ImportError("Unable to import Torch") self.device_id = torch.cuda.current_device() - self.device_properties = torch.cuda.get_device_properties(self.device_id) self.name = torch.cuda.get_device_name(self.device_id) self.max_threads = self.device_properties.max_threads_per_multi_processor @@ -52,7 +62,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None self.units = {"time": "ms", "power": "s,mW", "energy": "J"} - # Variables to be filled in at compile time: + # Variables to be filled in at compile time, needed for running the kernel after compilation. self.call_function = None self.signature = None self.gpu_kwargs = None @@ -60,15 +70,58 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None super().__init__(device=device, iterations=iterations, compiler_options=compiler_options, observers=observers) def ready_argument_list(self, arguments): - ''' - The user already supplies the arguments in the correct format, because we are working with - a Python based language anyway. TODO probably only works with torch and numpy? - ''' - return copy.deepcopy(arguments) + """Ready argument list to be passed to the kernel. Converts arguments to Torch GPU Tensors or Python + Scalars. Arguments of built-in Python types are left untouched. + + :param arguments: List of arguments to be passed to the kernel. + The order should match the argument list on the kernel. + Allowed values are: + - numpy.ndarray, and/or numpy.int32, numpy.float32, and so on. + - Torch Tensors + - Built-in Python types. + :type arguments: list + + :returns: A list of arguments that can be passed to the kernel. + :rtype: list( Torch.Tensor on GPU, int, float, bool, etc. ) + """ + torch_args = [] + + for arg in arguments: + if isinstance(arg, torch.Tensor): + if arg.dim() == 0: # Scalar tensor, convert to Python scalar + torch_args.append(arg.item()) + else: + if arg.is_cuda: # already on GPU, need deep copy to not overwrite + torch_args.append(arg.clone()) + else: # Copy from CPU to GPU + torch_args.append(arg.contiguous().to("cuda")) + elif isinstance(arg, np.ndarray): # Convert Numpy CPU array to Torch GPU Tensor + torch_args.append(torch.from_numpy(arg).to("cuda")) + elif isinstance(arg, np.generic): # Numpy scalar, convert to Python scalar + torch_args.append(arg.item()) + elif isinstance(arg, (int, float, bool, str)): # Is already Python + torch_args.append(arg) + elif type(arg) in vars(builtins).values(): + torch_args.append(arg) + else: + raise TypeError("Unknown argument type: ", type(arg)) + + return torch_args + def compile(self, kernel_instance, gpu_args=None): - logging.debug("Compiling Generic Python kernel") + """Compile the kernel by executing it once. This enforces that the kernel is cached. + + :param kernel_instance: The kernel instance containing information such as the kernel_source, + grid, threads, params, etc. + :type kernel_instance: KernelInstance + + :param gpu_args: arguments to be passed to the kernel. + :type gpu_args: list(any) + :returns: A kernel that can be called with the user-defined call function. + :rtype: callable + """ if kernel_instance.kernel_fn is None: raise ValueError("kernel_fn is None, currently Generic Python only supports callable kernel_source") @@ -76,6 +129,7 @@ def compile(self, kernel_instance, gpu_args=None): raise ValueError("gpu_args is None, Generic Python needs gpu args to compile the kernel") # The first time we compile, we also set the call function and the signature + # We need this later to run the kernel. if self.call_function is None or self.signature is None: self.call_function = kernel_instance.kernel_source.call_function self.signature = kernel_instance.kernel_source.signature @@ -83,92 +137,145 @@ def compile(self, kernel_instance, gpu_args=None): grid = kernel_instance.grid threads = kernel_instance.threads params = kernel_instance.params + + # If the kernel source is a class, we use the __call__ function as the kernel_function. if inspect.isclass(kernel_instance.kernel_fn): kernel_function = kernel_instance.kernel_fn() elif callable(kernel_instance.kernel_fn): - # Handles functions and decroators that return callable objects kernel_function = kernel_instance.kernel_fn else: raise TypeError("kernel function is not a class or function") - self.gpu_kwargs = self.build_gpu_kwargs(params) - - # Call the jit function in order to compile it + # Tuning params can contain kernel arguments. In such cases, create keyword arguments with + # the values of the tuning params. + self.gpu_kwargs = {} + if params is not None: + for name, p in self.signature.parameters.items(): + if name in params: + self.gpu_kwargs[name] = params[name] + + # Call the user-defined call function in order to compile the kernel. self.synchronize() self.call_function(kernel_function, gpu_args, self.gpu_kwargs, grid, threads, params) self.synchronize() - return kernel_function def start_event(self): - logging.debug("Start Generic Python event") + """Records the event that marks the start of a measurement.""" self.start.record() def stop_event(self): - logging.debug("Stop Generic Python event") + """Records the event that marks the end of a measurement.""" self.end.record() def kernel_finished(self): - logging.debug("Checking if kernel has finished") + """Returns True if the kernel has finished, False otherwise.""" return self.end.query() def run_kernel(self, func, gpu_args, threads, grid, stream=None, params=None): + """Runs the Python kernel passed as 'func'. + + :param func: A cached Python kernel for this specific kernel configuration + :type func: callable + + :param gpu_args: A list of arguments to the kernel, order should match the + order in the code. + :type gpu_args: list(any) - # Run the kernel + :param threads: A tuple listing the number of threads in each dimension of + the thread block + :type threads: tuple(int, int, int) + + :param grid: A tuple listing the number of thread blocks in each dimension + of the grid + :type grid: tuple(int, int, int) + + :param params: A dictionary with the tuning params for this specific kernel + configuration + :type params: dict + """ if stream is None: stream = self.stream with torch.cuda.stream(stream): logging.debug("Running Generic Python kernel") self.call_function(func, gpu_args, self.gpu_kwargs, grid, threads, params) - - - - def build_gpu_kwargs(self, params=None): - gpu_kwargs = {} - - if params is None: - return gpu_kwargs - - for name, p in self.signature.parameters.items(): - if name in params: - gpu_kwargs[name] = params[name] - - return gpu_kwargs def synchronize(self): + """Halts execution until device has finished its tasks.""" torch.cuda.synchronize() + def memset(self, allocation, value, size): + """This method must implement setting the memory to a value on the device. + Not implemented: Python DSLs usually do not perform explicit memory handling.""" pass + def memcpy_dtoh(self, dest, src): + """This method must implement a device to host copy. + Not implemented: Python DSLs usually do not perform explicit memory handling.""" pass + def memcpy_htod(self, dest, src): + """This method must implement a host to device copy. + Not implemented: Python DSLs usually do not perform explicit memory handling.""" pass + def copy_constant_memory_args(self, cmem_args): raise NotImplementedError("Generic Python does not support constant memory") + def copy_shared_memory_args(self, smem_args): raise NotImplementedError("Generic Python does not support shared memory") + def copy_texture_memory_args(self, texmem_args): raise NotImplementedError("Generic Python does not support texture memory") + def refresh_memory(self, gpu_memory, host_arguments, should_sync): """Refresh the GPU memory with the untouched host arguments. We overwrite the standard function because Python DSLs do usually do not manage memory explicitely""" - for i, arg in enumerate(host_arguments): - if should_sync[i]: - gpu_memory[i] = copy.deepcopy(arg) + for i, host_arg in enumerate(host_arguments): + if should_sync[i]: + gpu_arg = gpu_memory[i] + + # Scalar Python type + if isinstance(gpu_arg, (int, float, bool)): + gpu_memory[i] = host_arg + elif type(gpu_arg) in vars(builtins).values(): + gpu_memory[i] = host_arg + + # GPU tensor + elif isinstance(gpu_arg, torch.Tensor): + if isinstance(host_arg, np.ndarray): + gpu_arg.copy_(torch.as_tensor(host_arg)) + elif isinstance(host_arg, torch.Tensor): + if host_arg.is_cuda and host_arg.device != gpu_arg.device: + gpu_arg.copy_(host_arg.to(gpu_arg.device)) # different gpu's, no direct copy allowed + else: + gpu_arg.copy_(host_arg) + else: + # host_arg is scalar, fill into tensor + gpu_arg.fill_(host_arg) + + def classify_compile_exception(self, e): - """Best effort to differentiate between a user error and a resource error. Input is Exception""" + """Best effort to differentiate between a user error and a resource error. + + :param e: the caught exception + :type e: exception + + :returns: "resource_error" , "user_error" or "unknown" + :rtype: string + """ RESOURCE_KEYWORDS = ( # Shared memory diff --git a/kernel_tuner/backends/triton.py b/kernel_tuner/backends/triton.py deleted file mode 100644 index cf6fda504..000000000 --- a/kernel_tuner/backends/triton.py +++ /dev/null @@ -1,174 +0,0 @@ -import logging -import numpy as np - -from kernel_tuner.backends.backend import GPUBackend -from kernel_tuner.observers.triton import TritonRuntimeObserver - -try: - import torch -except ImportError: - logging.error("Torch not available") - -try: - import triton - import triton.language as tl -except ImportError: - triton = None - tl = None - logging.error("Unable to load triton") - - -class TritonFunctions(GPUBackend): - - def __init__(self, device=0, iterations=7, compiler_options=None, observers=None): - if not triton or not torch: - logging.error("Triton or torch not available") - raise ImportError("Triton or torch not available") - - self.device_id = torch.cuda.current_device() - - self.device_properties = torch.cuda.get_device_properties(self.device_id) - self.name = torch.cuda.get_device_name(self.device_id) - self.max_threads = self.device_properties.max_threads_per_multi_processor - - env = dict() - env["device_name"] = self.name - env["max_threads"] = self.max_threads - env["iterations"] = iterations - env["compiler_options"] = compiler_options - self.env = env - - self.stream = torch.cuda.default_stream() - self.start = torch.cuda.Event(enable_timing=True) - self.end = torch.cuda.Event(enable_timing=True) - - # setup observers - self.observers = observers or [] - self.observers.append(TritonRuntimeObserver(self)) - for obs in self.observers: - obs.register_device(self) - - self.units = {"time": "ms", "power": "s,mW", "energy": "J"} - - super().__init__(device=device, iterations=iterations, compiler_options=compiler_options, observers=observers) - - def ready_argument_list(self, arguments): - # Allocate memory here - torch_args = [] - - for arg in arguments: - if isinstance(arg, torch.Tensor) and arg.dim() > 0: - torch_args.append(arg.cuda()) - elif isinstance(arg, torch.Tensor) and arg.dim() == 0: - scalar_value = arg.item() - torch_args.append(scalar_value) - elif isinstance(arg, np.ndarray): - torch_arg = torch.from_numpy(arg) - torch_arg_gpu = torch_arg.cuda() - torch_args.append(torch_arg_gpu) - elif isinstance(arg, np.generic): - scalar_value = arg.item() - torch_args.append(scalar_value) - else: - logging.warning("Unknown instance in triton functions") - - return torch_args - - def compile(self, kernel_instance, gpu_args=None): - logging.debug("Compiling triton kernel") - - if kernel_instance.kernel_fn is None: - raise ValueError("kernel_fn is None, currently Triton only supports callable kernel_source") - - if gpu_args is None: - raise ValueError("gpu_args is None, Triton needs gpu args to compile the kernel") - - grid = kernel_instance.grid - threads = kernel_instance.threads - jit_function = triton.jit(kernel_instance.kernel_fn) - params = kernel_instance.params - gpu_kwargs = self.build_gpu_kwargs(jit_function, threads, params) - - # Call the jit function in order to compile it - jit_function[grid](*gpu_args, **gpu_kwargs) - - return jit_function - - def start_event(self): - logging.debug("Start triton event") - self.start.record() - - def stop_event(self): - logging.debug("Stop triton event") - self.end.record() - - def kernel_finished(self): - logging.debug("Checking if kernel has finished") - return self.end.query() - - def run_kernel(self, func, gpu_args, threads, grid, stream=None, params=None): - if params is None: - raise ValueError("params is None, Triton needs params in order to set num_warps, num_ctas, etc.") - - # Run the kernel - if stream is None: - stream = self.stream - - gpu_kwargs = self.build_gpu_kwargs(func, threads, params) - - with torch.cuda.stream(stream): - logging.debug("Running triton kernel") - func[grid](*gpu_args, **gpu_kwargs) - - def build_gpu_kwargs(self, jit_fn, threads, params=None): - gpu_kwargs = {} - - if 'BLOCK_SIZE' in jit_fn.arg_names: - gpu_kwargs['BLOCK_SIZE'] = threads[0] - - if 'BLOCK_SIZE_X' in jit_fn.arg_names: - gpu_kwargs['BLOCK_SIZE_X'] = threads[0] - - if 'BLOCK_SIZE_Y' in jit_fn.arg_names: - gpu_kwargs['BLOCK_SIZE_Y'] = threads[1] - - if 'BLOCK_SIZE_Z' in jit_fn.arg_names: - gpu_kwargs['BLOCK_SIZE_Z'] = threads[2] - - if params is None: - return gpu_kwargs - - for param in params: - if param in jit_fn.arg_names: - gpu_kwargs[param] = params[param] - - # Check for Triton specific parameters - if 'num_warps' in params: - gpu_kwargs['num_warps'] = params['num_warps'] - if 'num_ctas' in params: - gpu_kwargs['num_ctas'] = params['num_ctas'] - if 'num_stages' in params: - gpu_kwargs['num_stages'] = params['num_stages'] - - return gpu_kwargs - - def synchronize(self): - torch.cuda.synchronize() - - def memset(self, allocation, value, size): - pass - - def memcpy_dtoh(self, dest, src): - pass - - def memcpy_htod(self, dest, src): - pass - - def copy_constant_memory_args(self, cmem_args): - raise NotImplementedError("Triton does not support constant memory") - - def copy_shared_memory_args(self, smem_args): - raise NotImplementedError("Triton does not support shared memory") - - def copy_texture_memory_args(self, texmem_args): - raise NotImplementedError("Triton does not support texture memory") \ No newline at end of file diff --git a/kernel_tuner/core.py b/kernel_tuner/core.py index 7705d828c..1f130afa3 100644 --- a/kernel_tuner/core.py +++ b/kernel_tuner/core.py @@ -21,7 +21,6 @@ from kernel_tuner.backends.nvcuda import CudaFunctions from kernel_tuner.backends.opencl import OpenCLFunctions from kernel_tuner.backends.pycuda import PyCudaFunctions -from kernel_tuner.backends.triton import TritonFunctions from kernel_tuner.backends.generic_python import GenericPythonFunctions from kernel_tuner.kernel_sources.kernel_source import KernelSource from kernel_tuner.observers.nvml import NVMLObserver @@ -86,7 +85,7 @@ def prepare_temp_files_for_error_msg(self): return ret -# KernelSource has been moved to kernel_sources/kernel_source_str.py +# KernelSource has been moved to kernel_sources/kernel_source_str.py to support Generic Python class DeviceInterface(object): @@ -186,13 +185,6 @@ def __init__( compiler_options=compiler_options ) self.requires_warmup = False - elif lang == Language.TRITON: - dev = TritonFunctions( - device, - compiler_options=compiler_options, - iterations=iterations, - observers=observers - ) elif lang == Language.GENERIC_PYTHON: dev = GenericPythonFunctions( device, @@ -202,7 +194,7 @@ def __init__( ) else: raise NotImplementedError( - "Sorry, support for languages other than CUDA, OpenCL, HIP, C, Triton and Fortran is not implemented yet" + "Sorry, support for languages other than CUDA, OpenCL, HIP, C, Fortran and Generic Python is not implemented yet" ) self.dev = dev @@ -242,7 +234,7 @@ def __init__( print("Using: " + self.dev.name) def run_kernel_bench(self, func, gpu_args, threads, grid, stream=None, params=None): - if isinstance(self.dev, TritonFunctions) or isinstance(self.dev, GenericPythonFunctions): + if isinstance(self.dev, GenericPythonFunctions): # Generic Python needs params to compile. self.dev.run_kernel(func, gpu_args, threads, grid, params=params) else: self.dev.run_kernel(func, gpu_args, threads, grid) @@ -403,7 +395,7 @@ def check_kernel_output( # retrieve gpu results to host memory result_host = [] if instance.kernel_source is not None and instance.kernel_source.lang == Language.GENERIC_PYTHON: - # Python DSLs do not explicitly manage memory. Therefore, we can use gpu args direclty + # arguments are either Torch tensors or built-in Python types for i, arg in enumerate(gpu_args): if should_sync[i]: if isinstance(arg, torch.Tensor): @@ -530,15 +522,12 @@ def compile_kernel(self, instance, verbose, gpu_args=None): # compile kernel_string into device func func = None - if isinstance(self.dev, TritonFunctions) or isinstance(self.dev, GenericPythonFunctions): + # Python DSLs have different error handling, so we make a case distinction. + if isinstance(self.dev, GenericPythonFunctions): try: func = self.dev.compile(instance, gpu_args) except Exception as e: - if isinstance(self.dev, GenericPythonFunctions): - exception_type = self.dev.classify_compile_exception(e) - else: - exception_type = "unkown" - + exception_type = self.dev.classify_compile_exception(e) if exception_type == "user_error": error_message = str(e.stderr) if hasattr(e, "stderr") else str(e) print("compile_kernel failed due to error: " + error_message) @@ -550,10 +539,9 @@ def compile_kernel(self, instance, verbose, gpu_args=None): ) if verbose: print( - f"skipping config {util.get_instance_string(instance.params)} reason: too many resources" + f"skipping config {util.get_instance_string(instance.params)} reason: \n{e}" ) - - else: + else: # All other languages try: func = self.dev.compile(instance) except Exception as e: @@ -617,7 +605,7 @@ def create_kernel_instance(self, kernel_source, kernel_options, params, verbose) kernel_options.block_size_names, ) - if kernel_source.lang not in [Language.TRITON, Language.GENERIC_PYTHON] and np.prod(threads) > self.dev.max_threads: + if kernel_source.lang is not Language.GENERIC_PYTHON and np.prod(threads) > self.dev.max_threads: if verbose: print(f"skipping config {util.get_instance_string(params)} reason: too many threads per block") return util.InvalidConfig() diff --git a/kernel_tuner/interface.py b/kernel_tuner/interface.py index ce9d74080..e0421f847 100644 --- a/kernel_tuner/interface.py +++ b/kernel_tuner/interface.py @@ -108,9 +108,11 @@ def __deepcopy__(self, _): ( "kernel_source", ( - """The CUDA, OpenCL, HIP, or C kernel code. + """The CUDA, OpenCL, HIP, C or Python DSL kernel code. It is allowed for the code to be passed as a string, a filename, a function that returns a string of code, or a list when the code needs auxilliary files. + In the case of a kernel in a Python DSL such as Triton, the reference to the + Python callable should be passed. To support combined host and device code tuning, a list of filenames can be passed. The first file in the list should be the @@ -134,7 +136,7 @@ def __deepcopy__(self, _): """Specifies the language used for GPU kernels. The kernel_tuner automatically detects the language, but if it fails, you may specify the language using this argument, currently supported: "CUDA", "CuPy", - "nvcuda", "OpenCL", "HIP", or "C".""", + "nvcuda", "OpenCL", "HIP", "C", or "Generic_Python.""", "string", ), ), @@ -181,7 +183,9 @@ def __deepcopy__(self, _): "arguments", ( """A list of kernel arguments, use numpy arrays for - arrays, use numpy.int32 or numpy.float32 for scalars.""", + arrays, use numpy.int32 or numpy.float32 for scalars. When the language + Generic Python is used, Torch Tensors and built-in Python data types are also + excepted""", "list", ), ), @@ -285,6 +289,31 @@ def __deepcopy__(self, _): "dict", ), ), + ( + "call_function", + ( + """When the language Generic Python is used, a call function that calls the kernel in the Python + DSL must be specified. The function must take the following arguments: + :kernel_function: the callable function with the tuning parameters inserted. If provided, the + kernel_function is decorated with the de decorator. + :args: list of kernel arguments, as provided by the user in the argument. + :kwargs: dictionary of kernel keyword arguments. If a tuning parameter is in the kernel signature, + the tuning parameter will be added as a keyword argument. + :grid: the launch grid (tuple with 3 values), as computed by KernelTuner + :threads: the thread block size (tuple with 3 values), as computed by KernelTuner + :params: dictionary with the values of the tuning params for the specific configuration.""", + "function", + ), + ), + ( + "decorator", + ( + """When the language Generic Python is used, a decorator can be provided in which the kernel source + will be wrapped internally by KernelTuner. Note that when passing the kernel to KernelTuner with the + ``kernel_source`` argument, the decorator should be removed from the kernel.""", + "string", + ), + ), ] ) diff --git a/kernel_tuner/kernel_sources/kernel_source.py b/kernel_tuner/kernel_sources/kernel_source.py index 9a16d9b6c..d18ecb284 100644 --- a/kernel_tuner/kernel_sources/kernel_source.py +++ b/kernel_tuner/kernel_sources/kernel_source.py @@ -1,4 +1,3 @@ -import inspect import kernel_tuner.util as util from abc import abstractmethod @@ -6,15 +5,20 @@ from kernel_tuner.language import Language from kernel_tuner.kernel_sources.model.prepared_kernel_source_data import PreparedKernelSourceData -# We use this pattern because we would otherwise get an init twice -# We create a kernelsouce by calling KernelSource(...). The call method for KernelSource is -# replaced by the call method in KernelSourceFactory, because KernelSource uses that class as -# metaclass. Inside the call method, a call to create either KernelSourceStr or KernelSourceFN -# is done. This triggers the __init__ call of those subclasses. In both subclasses, we call super().__init__ -# which initializes some sublclass specific variables and triggers the __init__ call of the KernelSource class. + class KernelSourceFactory(type): + ''' + Factory to dynamically determine if a KernelSource should be a KernelSourceStr or a KernelSourceFn. + if lang=generic_python, KernelSourceFn will be used. In all other cases, an instance of KernelSourceStr + is created. + + We create a kernelsouce by calling KernelSource(...). The call method for KernelSource is replaced by + the call method in KernelSourceFactory, because KernelSource uses that class as metaclass. Inside the + call method, a call to create either KernelSourceStr or KernelSourceFn is performed. This triggers the + __init__ call of the corresponding subclass. In both subclasses, we call super().__init__, which triggers + the __init__ call of the KernelSource class to initalize some common variables. + ''' def __call__(cls, kernel_name, kernel_sources, lang, defines=None, call_function=None, decorator=None): - """Factory behavior""" if lang == None: language = None else: @@ -23,9 +27,8 @@ def __call__(cls, kernel_name, kernel_sources, lang, defines=None, call_function except ValueError: raise TypeError(f"Supported languages are {[l.value for l in Language]}") - # Determine if we need to create a KernelSourceStr or a KernelSourceFn - if (language and (language == Language.TRITON or language == Language.GENERIC_PYTHON)): + if (language and language == Language.GENERIC_PYTHON): ks_str = False else: ks_str = True @@ -44,7 +47,8 @@ def __call__(cls, kernel_name, kernel_sources, lang, defines=None, call_function else: return super().__call__(kernel_name, kernel_sources, lang, defines, call_function, decorator) -# TODO do we really want the Language enum? it's a lot of changes + + class KernelSource(metaclass=KernelSourceFactory): def __init__(self, kernel_name, kernel_sources, lang, defines=None): @@ -59,11 +63,7 @@ def __init__(self, kernel_name, kernel_sources, lang, defines=None): raise TypeError("Please specify language when using a code generator function") kernel_string = self.get_kernel_string(0) language = util.detect_language(kernel_string) - try: - self.lang = Language(language.upper()) - except ValueError: - # TODO this should never happen - raise TypeError(f"Supported languages are {[l.value for l in Language]}, found {language}") + self.lang = Language(language.upper()) else: try: self.lang = Language(lang.upper()) diff --git a/kernel_tuner/kernel_sources/kernel_source_fn.py b/kernel_tuner/kernel_sources/kernel_source_fn.py index edb75f95d..386685377 100644 --- a/kernel_tuner/kernel_sources/kernel_source_fn.py +++ b/kernel_tuner/kernel_sources/kernel_source_fn.py @@ -3,6 +3,7 @@ import copy import uuid import sys +import logging import astor import tempfile @@ -17,154 +18,66 @@ class KernelSourceFn(KernelSource): + """ + Class that holds the Python-function-based kernel sources. + + There is a primary kernel source for function-based kernels in Python. The source must be + a callable. This can be a function of a class. The function should not be decorated, but + a decorator can be supplied to wrap the function. + + A call function to specify how the kernel should be launched must be supplied. The call + function must take the following arguments: + - kernel_function: the callable function with the tuning parameters inserted. If provided, the + kernel_function is decorated with the de decorator. + - args: list of kernel arguments, as provided by the user in the argument. + - kwargs: dictionary of kernel keyword arguments. If a tuning parameter is in the kernel signature, + the tuning parameter will be added as a keyword argument. + - grid: the launch grid (tuple with 3 values), as computed by KernelTuner + - threads: the thread block size (tuple with 3 values), as computed by KernelTuner + - params: dictionary with the values of the tuning params for this configuration. + """ def __init__(self, kernel_name, kernel_source, lang, defines=None, call_function=None, decorator=None): super().__init__(kernel_name, kernel_source, lang, defines) if isinstance(kernel_source, list): raise ValueError("KernelSourceFn only supports a single kernel source function") - - try: - self.lang = Language(lang.upper()) - except ValueError: - raise TypeError(f"Supported languages are {[l.value for l in Language]}") - - if call_function is None: - raise ValueError("call_function must be supplied for language Generic Python") - if not callable(call_function): - raise TypeError(f"call_function {call_function} is not a callable object.") + + if self.lang == Language.GENERIC_PYTHON: + if call_function is None: + raise ValueError("call_function must be supplied for language Generic Python") + if not callable(call_function): + raise TypeError(f"call_function of type {type(call_function)} is not a callable object.") self.call_function = call_function # TODO ceck signature if decorator: if not isinstance(decorator, str): - raise TypeError(f"{decorator} is not a decorator") + raise TypeError(f"The decorator should be a string, got {type(decorator)} instead.") if decorator[0] != '@': - raise ValueError(f"{decorator} is not a valid decorator") + raise ValueError(f"The decorator should start with a '@', got {decorator} instead.") self.decorator = decorator - self.source_kernel_fn = kernel_source - self.kernel_fn = self.source_kernel_fn - self.signature = inspect.signature(kernel_source) + self.source_kernel_fn = kernel_source # This kernel source remains the original object + self.kernel_fn = self.source_kernel_fn # This is the kernel source that we will modify + try: - self.source = inspect.getsource(kernel_source) + self.source = inspect.getsource(self.source_kernel_fn) except TypeError as e: raise TypeError( f"{e}. Did you forget to remove a decorator before tuning?" ) from e - self.source_tree = ast.parse(self.source) - self.import_nodes = self._find_import_nodes(inspect.getfile(kernel_source)) - - # Find the module where the kernel function is defined - self.module = inspect.getmodule(kernel_source) - # Get dependencies by analyzing the AST - if self.lang == Language.TRITON: - self.dependencies = self._find_triton_dependencies() - else: - self.dependencies = self._find_dependencies() - - - - - def _find_import_nodes(self, source_file): - with open(source_file, "r") as f: - tree = ast.parse(f.read(), filename=source_file) - - import_nodes = [] - for node in tree.body: - if isinstance(node, (ast.Import, ast.ImportFrom)): - import_nodes.append(node) - - return import_nodes - - def _find_function_dependencies(self): - """Find all function calls in the kernel.""" - class FunctionCallVisitor(ast.NodeVisitor): - def __init__(self): - self.called_functions = set() - - def visit_Call(self, node): - if isinstance(node.func, ast.Name): - self.called_functions.add(node.func.id) - self.generic_visit(node) - - visitor = FunctionCallVisitor() - visitor.visit(self.source_tree) - return visitor.called_functions - - def _is_triton_jit_function(self, node): - """Check if a function has the @triton.jit decorator. Is triton specific""" - if not isinstance(node, ast.FunctionDef): - return False - - for decorator in node.decorator_list: - if (isinstance(decorator, ast.Name) and decorator.id == 'jit' or - isinstance(decorator, ast.Attribute) and decorator.attr == 'jit' or - isinstance(decorator, ast.Call) and - isinstance(decorator.func, ast.Attribute) and - decorator.func.attr == 'jit'): - return True - return False - - def _find_triton_dependencies(self): - """Find all Triton JIT functions that are called by the kernel. Is Triton specific""" - dependencies = {} - called_functions = self._find_function_dependencies() - - # Parse the entire module to find Triton JIT functions - module_source = inspect.getsource(self.module) - module_tree = ast.parse(module_source) - for node in module_tree.body: - if (self._is_triton_jit_function(node) and - node.name in called_functions and - node.name != self.kernel_name): # Skip the main kernel itself - dependencies[node.name] = node - - return dependencies - - def _find_function_dependecies2(self, tree, local_funcs): - # Non triton specific - class FunctionCallVisitor(ast.NodeVisitor): - def __init__(self): - self.called = set() - - def visit_Call(self, node): - if isinstance(node.func, ast.Name): - name = node.func.id - if name in local_funcs: - self.called.add(name) - self.generic_visit(node) - - visitor = FunctionCallVisitor() - visitor.visit(tree) - return visitor.called - - def _find_local_functions(self, tree): - local_funcs = set() - for node in tree.body: - if isinstance(node, ast.FunctionDef): - local_funcs.add(node.name) - return local_funcs - - def _find_dependencies(self): - source_file = inspect.getfile(self.source_kernel_fn) - with open(source_file, "r") as f: - source_code = f.read() - tree = ast.parse(source_code, filename=source_file) - local_funcs = self._find_local_functions(tree) - called_functions = self._find_function_dependecies2(self.source_tree, local_funcs) - - dependencies = {} - for node in tree.body: - if (isinstance(node, ast.FunctionDef) and node.name in called_functions and - node.name != self.kernel_name): # Skip the main kernel itself - dependencies[node.name] = node + self.signature = inspect.signature(self.source_kernel_fn) + self.source_tree = ast.parse(self.source) + self.import_nodes = self._find_import_nodes(inspect.getfile(self.source_kernel_fn)) + self.dependencies = self._find_dependencies() - return dependencies - def _add_decorator(self, function): - pass - def prepare_kernel_instance(self, kernel_options, params, grid, threads): + ''' + Given the dictionary of tuning parameter values for this configuration, + generate a kernel instance with these tuning parameters inserted. kernel_options, + grid and threads are not needed for Python kernels. + ''' new_kernel_fn, temp_file_path = self.apply_params_to_source_fn(params) self.kernel_fn = new_kernel_fn @@ -175,13 +88,26 @@ def prepare_kernel_instance(self, kernel_options, params, grid, threads): kernel_str=None ) + def check_argument_lists(self, kernel_name, arguments): + ''' + Check if the kernel arguments have the correct types. + Not implemented for Python, because type hinting is not always supplied. + ''' + logging.debug("Checking of arguments list not supported yet for Python kernels") return True + def apply_params_to_source_fn(self, params): + ''' + Create a module with the kernel imports and local dependencies from the kernel source file. + Find instances of the tuning parameters in the kernel and local dependencies and replace these + instances by the values of this configuration. + Return the new kernel and the path to the new module, where the new kernel lives. + ''' transformer = ReplaceVars(params) - # Create a new module with all necessary functions + # Create a new module to store the transformed AST in. new_module = ast.Module(body=[], type_ignores=[]) # Add imports @@ -190,11 +116,10 @@ def apply_params_to_source_fn(self, params): # Add dependencies (functions) for dep_node in self.dependencies.values(): dep_node_copy = copy.deepcopy(dep_node) - transformed_dep_node = transformer.visit(dep_node_copy) - #new_module.body.append(copy.deepcopy(dep_node)) + transformed_dep_node = transformer.visit(dep_node_copy) # apply tuning params to dependencies new_module.body.append(transformed_dep_node) - # Add transformed main kernel + # Transform main kernel source_tree_copy = copy.deepcopy(self.source_tree) transformed_tree = transformer.visit(source_tree_copy) @@ -207,39 +132,133 @@ def apply_params_to_source_fn(self, params): node.decorator_list.insert(0, decorator_node) break # only apply to the top level function + # Add transformed main kernel to new module new_module.body.extend(transformed_tree.body) # Fix locations and generate source ast.fix_missing_locations(new_module) new_source = astor.to_source(new_module) - #print(new_source) - - # Create a unique module name + # Create a unique module name and write new source to it. module_name = f'temp_kernel_module_{uuid.uuid4().hex}' - with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as temp_file: temp_file.write(new_source) temp_file_path = temp_file.name - spec = importlib.util.spec_from_file_location(module_name, temp_file_path) - temp_module = importlib.util.module_from_spec(spec) - # Register the module in sys.modules before executing it + spec = importlib.util.spec_from_file_location(module_name, temp_file_path) + temp_module = importlib.util.module_from_spec(spec) sys.modules[module_name] = temp_module spec.loader.exec_module(temp_module) new_fn = getattr(temp_module, self.kernel_name) return new_fn, temp_file_path + + def _find_import_nodes(self, source_file): + ''' + Parse kernel source file to find import statements. Return those + statements as AST import nodes. + ''' + with open(source_file, "r") as f: + tree = ast.parse(f.read(), filename=source_file) + + import_nodes = [] + for node in tree.body: + if isinstance(node, (ast.Import, ast.ImportFrom)): + import_nodes.append(node) + + return import_nodes + + + def _find_function_dependecies(self, tree, local_funcs): + ''' + Given an AST and a list of local function names, return the funtions that are invoked + somewhere in the file, and are in the local_funcs list. + ''' + class FunctionCallVisitor(ast.NodeVisitor): + def __init__(self): + self.called = set() + + def visit_Call(self, node): + if isinstance(node.func, ast.Name): + name = node.func.id + if name in local_funcs: + self.called.add(name) + self.generic_visit(node) + + visitor = FunctionCallVisitor() + visitor.visit(tree) + return visitor.called + + + def _find_dependencies(self): + ''' + Find all local function dependencies in the file where the kernel source is + defined. Return a dicionary indexed by the function node name with the function + body as value. + ''' + source_file = inspect.getfile(self.source_kernel_fn) + with open(source_file, "r") as f: + source_code = f.read() + tree = ast.parse(source_code, filename=source_file) + + # Find the locally defined functions in the source file + local_funcs = set() + for node in tree.body: + if isinstance(node, ast.FunctionDef): + local_funcs.add(node.name) + + # Given source file AST and the set of locally defined functions, find the names of + # the local functions that are invoked somewhere in the AST. + called_functions = self._find_function_dependecies(self.source_tree, local_funcs) + + # Traverse the tree one more time to save the dependency functions in a dictionary. + # Skip the main kernel. + dependencies = {} + for node in tree.body: + if (isinstance(node, ast.FunctionDef) and node.name in called_functions and + node.name != self.kernel_name): # Skip the main kernel itself + dependencies[node.name] = node + + return dependencies + + def __del__(self): - # Clean up temporary modules when the instance is destroyed + ''' + Clean up temporary modules when the instance is destroyed + ''' for key in list(sys.modules.keys()): if key.startswith('temp_kernel_module_'): del sys.modules[key] class ReplaceVars(ast.NodeTransformer): + ''' + AST transformer that replaces occurrences of tuning parameters with constant values. + The tuning parameters (params) should be provided as a dictionary. + The transformation supports the following cases: + + - Variable reads: + occurrence of a variable name that matches a key in `params` are replaced if the + value is being read. + + - Object attribute reads: + expressions of the form `self.` are replaced with constants if + `` exists in `params` and is being read (not assigned to). + + - Assignments to parameters: + assignments like `param = ...` or `self.param = ...` are overridden, + replacing the right-hand side with the constant value from `params`. + + Examples: + Given `params = {"block_size": 32}`: + x = 2 * block_size -> x = 2 * 32 + x = 2 * self.block_size -> x = 2 * 32 + self.block_size = 64 -> self.block_size = 32 + block_size = 64 -> block_size = 32 + + ''' def __init__(self, params: dict): self.params = params @@ -252,6 +271,7 @@ def visit_Name(self, node: ast.Name) -> Any: ) return node + def visit_Attribute(self, node: ast.Attribute): self.generic_visit(node) @@ -260,7 +280,7 @@ def visit_Attribute(self, node: ast.Attribute): isinstance(node.value, ast.Name) and node.value.id == "self" and node.attr in self.params - and isinstance(node.ctx, ast.Load) # <- check context + and isinstance(node.ctx, ast.Load) ): return ast.copy_location(ast.Constant(self.params[node.attr]), node) diff --git a/kernel_tuner/kernel_sources/kernel_source_str.py b/kernel_tuner/kernel_sources/kernel_source_str.py index 4db0831a7..095417392 100644 --- a/kernel_tuner/kernel_sources/kernel_source_str.py +++ b/kernel_tuner/kernel_sources/kernel_source_str.py @@ -8,7 +8,7 @@ class KernelSourceStr(KernelSource): - """Class that holds the kernel sources. + """Class that holds the string-based kernel sources. There is a primary kernel source for string-based kernels., which can be either a source string, a filename (indicating a file containing the kernel source code), diff --git a/kernel_tuner/language.py b/kernel_tuner/language.py index b129e04d5..02948c12b 100644 --- a/kernel_tuner/language.py +++ b/kernel_tuner/language.py @@ -5,7 +5,6 @@ class Language(Enum): OPENCL = "OPENCL" C = "C" HIP = "HIP" - TRITON = "TRITON" FORTRAN = "FORTRAN" NVCUDA = "NVCUDA" GENERIC_PYTHON = "GENERIC_PYTHON" diff --git a/kernel_tuner/observers/triton.py b/kernel_tuner/observers/triton.py deleted file mode 100644 index 878e8f27c..000000000 --- a/kernel_tuner/observers/triton.py +++ /dev/null @@ -1,32 +0,0 @@ -import numpy as np - -from kernel_tuner.observers.observer import BenchmarkObserver - -try: - import torch -except (ImportError, RuntimeError): - torch = None - - -class TritonRuntimeObserver(BenchmarkObserver): - """Observer that measures time using CUDA events during benchmarking.""" - - def __init__(self, dev): - if torch is None: - raise ImportError("Unable to load torch") - - self.dev = dev - self.stream = dev.stream - self.start = dev.start - self.end = dev.end - self.times = [] - - def after_finish(self): - # Time is measured in milliseconds - event_elapsed_time = self.start.elapsed_time(self.end) - self.times.append(event_elapsed_time) - - def get_results(self): - results = {"time": np.average(self.times), "times": self.times.copy()} - self.times = [] - return results \ No newline at end of file diff --git a/test/test_generic_python_functions.py b/test/test_generic_python_functions.py index a3d250b60..eee0d3dd1 100644 --- a/test/test_generic_python_functions.py +++ b/test/test_generic_python_functions.py @@ -109,7 +109,7 @@ def test_gpu_kwargs(): gpu_args = dev.ready_argument_list(args) callable_fn = dev.compile_kernel(instance, verbose=False, gpu_args=gpu_args) - kwargs = dev.dev.build_gpu_kwargs(params) + kwargs = dev.dev.gpu_kwargs assert kwargs["mock_param"] == 64 assert callable_fn(*args, **kwargs) == 64 diff --git a/test/test_kernel_source_fn.py b/test/test_kernel_source_fn.py index e23a41a15..85aa4e3eb 100644 --- a/test/test_kernel_source_fn.py +++ b/test/test_kernel_source_fn.py @@ -61,7 +61,7 @@ def test_initiation(): with pytest.raises(ValueError, match=r"call_function must be supplied for language .*"): KernelSource("mock_kernel", mock_kernel, "generic_python") - with pytest.raises(TypeError, match=r".* is not a callable object"): + with pytest.raises(TypeError, match=r".* Did you forget to remove a decorator before tuning\?"): KernelSource("mock_kernel", "This is a string Kernel", "generic_python", call_function=call_mock) with pytest.raises(ValueError, match=r"KernelSourceFn only supports a single kernel source function"): @@ -70,10 +70,10 @@ def test_initiation(): with pytest.raises(TypeError, match=r".* is not a callable object"): KernelSource("mock_kernel", mock_kernel, "generic_python", call_function="not a function") - with pytest.raises(ValueError, match=r".* is not a valid decorator"): + with pytest.raises(ValueError, match=r"The decorator should start with a '@', got .* instead."): KernelSource("mock_kernel", mock_kernel, "generic_python", call_function=call_mock, decorator="not a decorator") - with pytest.raises(TypeError, match=r".* is not a decorator"): + with pytest.raises(TypeError, match=r"The decorator should be a string, got .* instead."): KernelSource("mock_kernel", mock_kernel, "generic_python", call_function=call_mock, decorator=mock_kernel) From 81340389cb4c41321f24894a00d81db2585c99aa Mon Sep 17 00:00:00 2001 From: Imke van Ooijen Date: Thu, 5 Mar 2026 12:54:35 +0100 Subject: [PATCH 05/14] support for decorated kernels without having to remove the decorator --- .../generic_python/matmul/tilus_matmul.py | 1 - examples/generic_python/numba_vec_add.py | 88 +++++-------- examples/generic_python/tilelang_vec_add.py | 14 +- examples/generic_python/tilus_naive_matmul.py | 5 +- .../generic_python/tilus_splitk_matmul.py | 26 +--- .../generic_python/tilus_tunable_precision.py | 120 ------------------ examples/generic_python/tilus_vec_add.py | 36 ++---- examples/generic_python/triton_vec_add.py | 37 +++--- examples/generic_python/warp_vec_add.py | 63 +++------ kernel_tuner/backends/generic_python.py | 15 +-- kernel_tuner/interface.py | 21 +-- kernel_tuner/kernel_sources/kernel_source.py | 6 +- .../kernel_sources/kernel_source_fn.py | 74 +++++------ kernel_tuner/util.py | 78 +++++++++++- test/test_generic_python_functions.py | 27 +++- test/test_kernel_source_fn.py | 50 +++++--- 16 files changed, 258 insertions(+), 403 deletions(-) delete mode 100644 examples/generic_python/tilus_tunable_precision.py diff --git a/examples/generic_python/matmul/tilus_matmul.py b/examples/generic_python/matmul/tilus_matmul.py index 07532013f..e8996c201 100644 --- a/examples/generic_python/matmul/tilus_matmul.py +++ b/examples/generic_python/matmul/tilus_matmul.py @@ -88,7 +88,6 @@ def __call__( self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) - class MatmulGroupedOrdering(tilus.Script): def __init__(self): super().__init__() diff --git a/examples/generic_python/numba_vec_add.py b/examples/generic_python/numba_vec_add.py index e49f9e372..9f8dfe095 100644 --- a/examples/generic_python/numba_vec_add.py +++ b/examples/generic_python/numba_vec_add.py @@ -1,10 +1,13 @@ from numba import cuda -from math import ceil -from kernel_tuner import tune_kernel, run_kernel import torch +from kernel_tuner import tune_kernel, run_kernel +import numpy as np +from pathlib import Path + +FULL_PATH = Path(__file__).resolve() -#@cuda.jit +@cuda.jit def f(a, b, c): tid = cuda.grid(1) size = len(c) @@ -23,64 +26,31 @@ def call_numba(kernel_function, args, kwargs, grid, threads, params): kernel_function[grid, threads](*args, **kwargs) -def verify(answer, result_host, atol): - correct = True - for i, ans in enumerate(answer): - if ans is None: - continue - res = result_host[i].copy_to_host() - if not np.allclose(ans, res, atol=atol): - correct = False - return correct +def tune(): + + N = 100000 + + a = np.random.random(N) + b = np.random.random(N) + c = np.zeros(N) + c_expect = a + b + + args = [a, b, c] + tune_params = {"block_size_x": [2**i for i in range(10)]} -import numpy as np -N = 100000000 - -a = np.random.random(N) -b = np.random.random(N) -c = np.zeros(N) -c_expect = a + b -#c_expect = a.copy_to_host() + b.copy_to_host() - - -#a_torch = torch.rand(N, dtype=torch.float32) -#b_torch = torch.rand(N, dtype=torch.float32) -#c_torch = torch.zeros(N, dtype=torch.float32) -#c_expect = (a_torch + b_torch).cpu() - -args = [a, b, c] -tune_params = {"block_size_x": [2**i for i in range(10)]} - -''' -results = run_kernel( - kernel_name="f", - kernel_source=f, - problem_size=N, - arguments=args, - params={"block_size_x": 32}, - lang="generic_python", - call_function=call_numba, - decorator="@cuda.jit" -) -''' - - -#print(np.allclose(results[2], c_expect)) - - -results, env = tune_kernel( - kernel_name="f", - kernel_source=f, - problem_size=N, - arguments=args, - tune_params=tune_params, - lang="generic_python", - answer=[None, None, c_expect], - #verify=verify, - call_function=call_numba, - decorator="@cuda.jit" -) + results, env = tune_kernel( + kernel_name="f", + kernel_source=FULL_PATH, + problem_size=N, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=[None, None, c_expect], + call_function=call_numba, + ) +if __name__ == "__main__": + tune() \ No newline at end of file diff --git a/examples/generic_python/tilelang_vec_add.py b/examples/generic_python/tilelang_vec_add.py index 3ff3c9376..601b6a7b7 100644 --- a/examples/generic_python/tilelang_vec_add.py +++ b/examples/generic_python/tilelang_vec_add.py @@ -2,8 +2,11 @@ import tilelang.language as T import torch from kernel_tuner import tune_kernel +from pathlib import Path -#@tilelang.jit # infers target from tensors at first call +FULL_PATH = Path(__file__).resolve() + +@tilelang.jit # infers target from tensors at first call def add(N: int, dtype: str = 'float32', block: int = 256,): @T.prim_func @@ -51,7 +54,8 @@ def tune(): answer = [None, None, (A + B).cpu()] - res, env = tune_kernel("add", add, N, args, tune_params, lang="generic_python", - call_function=call_tilelang, decorator="@tilelang.jit", answer=answer) - -tune() \ No newline at end of file + res, env = tune_kernel("add", FULL_PATH, N, args, tune_params, lang="generic_python", + call_function=call_tilelang, answer=answer) + +if __name__ == "__main__": + tune() \ No newline at end of file diff --git a/examples/generic_python/tilus_naive_matmul.py b/examples/generic_python/tilus_naive_matmul.py index 2f358ef48..42cdff86d 100644 --- a/examples/generic_python/tilus_naive_matmul.py +++ b/examples/generic_python/tilus_naive_matmul.py @@ -4,6 +4,9 @@ from kernel_tuner import tune_kernel, run_kernel import math import torch +from pathlib import Path + +FULL_PATH = Path(__file__).resolve() class MatmulV0(tilus.Script): @@ -105,7 +108,7 @@ def main(): results, env = tune_kernel( kernel_name="MatmulV0", - kernel_source=MatmulV0, + kernel_source=FULL_PATH, problem_size=[m, n], arguments=args, tune_params=tune_params, diff --git a/examples/generic_python/tilus_splitk_matmul.py b/examples/generic_python/tilus_splitk_matmul.py index f0dff1ac3..be797299e 100644 --- a/examples/generic_python/tilus_splitk_matmul.py +++ b/examples/generic_python/tilus_splitk_matmul.py @@ -4,17 +4,13 @@ import torch import math from kernel_tuner import tune_kernel, run_kernel +from pathlib import Path + +FULL_PATH = Path(__file__).resolve() -''' -@tilus.autotune("num_warps", [4, 8]) -@tilus.autotune("block_m, block_n", [(128, 128), (128, 64), (64, 128), (32, 256)]) -@tilus.autotune("block_k", [16, 32]) -@tilus.autotune("num_stages", [3, 4, 5]) -@tilus.autotune("split_k_factor", [1, 4, 12, 16]) -''' class MatmulV5(tilus.Script): - ''' - def __init__(self, block_m, block_n, block_k, num_warps, num_stages, split_k_factor): + def __init__(self, block_m=None, block_n=None, block_k=None, + num_warps=None, num_stages=None, split_k_factor=None): super().__init__() self.block_m = block_m self.block_n = block_n @@ -22,15 +18,7 @@ def __init__(self, block_m, block_n, block_k, num_warps, num_stages, split_k_fac self.num_warps = num_warps self.num_stages = num_stages self.split_k_factor = split_k_factor - ''' - def __init__(self): - super().__init__() - self.block_m = 128 - self.block_n = 128 - self.block_k = 16 - self.num_warps = 4 - self.num_stages = 4 - self.split_k_factor = 4 + def __call__( self, @@ -246,7 +234,7 @@ def main(): results, env = tune_kernel( kernel_name="MatmulV5", # This has to be a string of the actual name. TODO is this always the case? - kernel_source=MatmulV5, + kernel_source=FULL_PATH, problem_size=[m, n], arguments=args, tune_params=tune_params, diff --git a/examples/generic_python/tilus_tunable_precision.py b/examples/generic_python/tilus_tunable_precision.py deleted file mode 100644 index 1cc4a7499..000000000 --- a/examples/generic_python/tilus_tunable_precision.py +++ /dev/null @@ -1,120 +0,0 @@ -import tilus -import hidet -from hidet import float32, float16, int32 -#from tilus import float32, int32, float16 -from tilus.utils import cdiv, benchmark_func -import torch -from kernel_tuner import tune_kernel, run_kernel -from kernel_tuner.accuracy import Tunable, AccuracyObserver -import numpy as np - -INPUT_TYPE = float32 -OUTPUT_TYPE = float32 - -class VecAddV(tilus.Script): - def __init__(self): - super().__init__() - self.block_size_x = 32 # number of threads per block - - def __call__( - self, - n_size: int32, # size of the vectors - a_ptr: ~INPUT_TYPE, # input vector A - b_ptr: ~INPUT_TYPE, # input vector B - c_ptr: ~OUTPUT_TYPE # output vector C - ): - - # compute the number of blocks needed - self.attrs.blocks = [cdiv(n_size, self.block_size_x)] - self.attrs.warps = 1 # number of warps per block - - # calculate the offset for this block - offset: int32 = self.block_size_x * self.blockIdx.x - - # create global views for input/output vectors - ga = self.global_view(a_ptr, dtype=INPUT_TYPE, shape=[n_size]) - gb = self.global_view(b_ptr, dtype=INPUT_TYPE, shape=[n_size]) - gc = self.global_view(c_ptr, dtype=OUTPUT_TYPE, shape=[n_size]) - - a = self.load_global(ga, offsets=[offset], shape=[self.block_size_x]) - b = self.load_global(gb, offsets=[offset], shape=[self.block_size_x]) - c = a + b - self.store_global(gc, c, offsets=[offset]) - - -def call_tilus(kernel_function, args, kwargs, grid, threads, params): - kernel_function(*args, **kwargs) - - -def verify(answer, result_host, atol): - correct = True - for i, ans in enumerate(answer): - if ans is None: - continue - res = result_host[i].cpu() - if not torch.allcose(ans, res, atol=atol): - correct = False - - return correct - - - -def main(): - - size = 1024000 - - a_32 = torch.randn(size, dtype=torch.float32) - b_32 = torch.randn(size, dtype=torch.float32) - c_32 = torch.zeros_like(b_32) - c_expect = a_32 + b_32 - - a_16 = a_32.to(torch.float16) - b_16 = b_32.to(torch.float16) - c_16 = c_32.to(torch.float16) - - - tune_params = dict() - tune_params["block_size_x"] = [32, 64, 128, 256, 512, 1024] - tune_params["INPUT_TYPE"] = [tilus.float16, tilus.float32] - tune_params["OUTPUT_TYPE"] = [tilus.float16, tilus.float32] - - args = [ - size, - Tunable("INPUT_TYPE", { - tilus.float32: a_32, - tilus.float16: a_16, - }), - Tunable("INPUT_TYPE", { - tilus.float32: b_32, - tilus.float16: b_16, - }), - Tunable("OUTPUT_TYPE", { - tilus.float32: c_32, - tilus.float16: c_16, - }), - ] - - print(tune_params) - - observers = [AccuracyObserver("RMSE")] - - - - results, env = tune_kernel( - kernel_name="VecAddV", - kernel_source=VecAddV, - problem_size=size, - arguments=args, - tune_params=tune_params, - lang="generic_python", - answer=[None, None, None, c_expect.cpu()], - observers=observers, - call_function=call_tilus, - verify=verify, - verbose=True, - ) - - - -if __name__ == "__main__": - main() diff --git a/examples/generic_python/tilus_vec_add.py b/examples/generic_python/tilus_vec_add.py index 161575181..938d9026c 100644 --- a/examples/generic_python/tilus_vec_add.py +++ b/examples/generic_python/tilus_vec_add.py @@ -3,8 +3,9 @@ from tilus.utils import cdiv, benchmark_func import torch from kernel_tuner import tune_kernel, run_kernel +from pathlib import Path - +FULL_PATH = Path(__file__).resolve() class VecAddV(tilus.Script): def __init__(self): @@ -41,7 +42,7 @@ def call_tilus(kernel_function, args, kwargs, grid, threads, params): kernel_function(*args, **kwargs) -def tune_vecadd(size): +def tune(size): a = torch.randn(size, dtype=torch.float32).cuda() b = torch.randn(size, dtype=torch.float32).cuda() c = torch.empty(size, dtype=torch.float32).cuda() @@ -52,9 +53,10 @@ def tune_vecadd(size): tune_params = dict() tune_params["block_size_x"] = [32, 64, 128, 256, 512, 1024] + results, env = tune_kernel( kernel_name="VecAddV", - kernel_source=VecAddV, + kernel_source=FULL_PATH, problem_size=size, arguments=args, tune_params=tune_params, @@ -65,8 +67,7 @@ def tune_vecadd(size): ) -# TODO run kernel error handling same as tune_kernel -def run_vecadd(size): +def run(size): a = torch.randn(size, dtype=torch.float32).cuda() b = torch.randn(size, dtype=torch.float32).cuda() c = torch.empty(size, dtype=torch.float32).cuda() @@ -77,37 +78,20 @@ def run_vecadd(size): results = run_kernel( kernel_name="VecAddV", - kernel_source=VecAddV, + kernel_source=FULL_PATH, problem_size=size, arguments=args, params={"block_size_x": 32}, lang="generic_python", call_function=call_tilus, - #verbose=True, ) c_expect = c_expect.cpu() assert torch.allclose(results[-1], c_expect) - - - -def run_normal(size): - a = torch.randn(size, dtype=torch.float32).cuda() - b = torch.randn(size, dtype=torch.float32).cuda() - c = torch.empty(size, dtype=torch.float32).cuda() - c_expect = a + b - - vecadd = VecAddV() - vecadd(size, a, b, c) - print(c) - print(c_expect) - assert torch.allclose(c, c_expect) - if __name__ == "__main__": - size = 1024 - tune_vecadd(size) - #run_vecadd(size) - #run_normal(size) + size = 100000000 + tune(size) + run(size) diff --git a/examples/generic_python/triton_vec_add.py b/examples/generic_python/triton_vec_add.py index fcc019615..2e982abfa 100644 --- a/examples/generic_python/triton_vec_add.py +++ b/examples/generic_python/triton_vec_add.py @@ -1,18 +1,18 @@ import numpy as np -import triton.language as tl import torch -from kernel_tuner import tune_kernel, run_kernel -from kernel_tuner.file_utils import store_output_file, store_metadata_file import triton -from math import ceil +import triton.language as tl +from pathlib import Path +from kernel_tuner import tune_kernel, run_kernel + +FULL_PATH = Path(__file__).resolve() @triton.jit def add_op(x, y): return x + y - -#@triton.jit +@triton.jit def add_kernel(x_ptr, # *Pointer* to first input vector. y_ptr, # *Pointer* to second input vector. output_ptr, # *Pointer* to output vector. @@ -33,11 +33,8 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. def call_triton(kernel_function, args, kwargs, grid, threads, params): kernel_function[grid](*args, **kwargs) -# NOTE: can the python file be changed in between? what happens? -# NOTE: tune params in the funcion signature are supported as key word arguments. Do not pass them as args, these -# will be inserted automatically. You can use them in the call function (kwargs). -def tune_with_generic(): +def tune(): size = 10000000 a = torch.randn(size, device='cuda', dtype=torch.float32) @@ -48,31 +45,27 @@ def tune_with_generic(): args = [a, b, c, size] tune_params = dict() - tune_params["block_size_x"] = [2**i for i in range(10)] + tune_params["block_size_x"] = [2**i for i in range(11)] - result = run_kernel("add_kernel", add_kernel, size, args, {"block_size_x": 256}, - lang="generic_python", call_function=call_triton, decorator="@triton.jit") - print(np.allclose(c_expect.cpu(), result[2])) - + result = run_kernel("add_kernel", FULL_PATH, size, args, {"block_size_x": 256}, + lang="generic_python", call_function=call_triton) + assert np.allclose(c_expect.cpu(), result[2]) + results, env = tune_kernel( kernel_name="add_kernel", - kernel_source=add_kernel, + kernel_source=FULL_PATH, problem_size=size, arguments=args, tune_params=tune_params, lang="generic_python", answer=[None, None, c_expect.cpu(), None], call_function=call_triton, - decorator="@triton.jit" ) - - - - -tune_with_generic() \ No newline at end of file +if __name__ == "__main__": + tune() \ No newline at end of file diff --git a/examples/generic_python/warp_vec_add.py b/examples/generic_python/warp_vec_add.py index 4a2c8d350..4f94c91d2 100644 --- a/examples/generic_python/warp_vec_add.py +++ b/examples/generic_python/warp_vec_add.py @@ -2,17 +2,16 @@ import numpy as np from kernel_tuner import tune_kernel, run_kernel import torch +from pathlib import Path wp.init() - - +FULL_PATH = Path(__file__).resolve() @wp.func def add_op(x: float, y: float): return x + y - -#@wp.kernel +@wp.kernel def vec_add(a: wp.array(dtype=float), b: wp.array(dtype=float), c: wp.array(dtype=float), @@ -28,8 +27,6 @@ def vec_add(a: wp.array(dtype=float), c[idx] = add_op(a[idx], b[idx]) -# TODO do we allways want the call function to have the same parameters -# or do we only require some of them? def call_warp(kernel_function, args, kwargs, grid, threads, params): warp_args = [] for arg in args: @@ -37,66 +34,36 @@ def call_warp(kernel_function, args, kwargs, grid, threads, params): warp_args.append(wp.from_torch(arg)) else: warp_args.append(arg) - dim = args[3] + dim = params['size'] wp.launch(kernel=kernel_function, dim=dim, inputs=warp_args) -# NOTE default verify function only works for numpy/cupy ndarray, torch Tensor or numpy scalar -# That is why we need a costum verify function for warp. -def verify(answer, result_host, atol): - correct = True - for i, ans in enumerate(answer): - if ans is None: - continue - print("res: ", type(result_host[i])) - print("expect: ", type(ans)) - res = result_host[i].numpy() - if not np.allclose(ans, res, atol=atol): - correct = False - - return correct - - - def tune(): n = 1024 # Create host arrays - a_torch = torch.arange(n, dtype=torch.float32, device="cuda") - b_torch = torch.arange(n, 0, -1, dtype=torch.float32, device="cuda") - c_torch = torch.zeros(n, dtype=torch.float32, device="cuda") - c_expect = a_torch + b_torch - + a = np.arange(n, dtype=np.float32) + b = np.arange(n, 0, -1, dtype=np.float32) + c = np.zeros(n, dtype=np.float32) + c_expect = a + b + tune_params = dict() tune_params["work_per_thread"] = [2**i for i in range(10)] - args = [a_torch, b_torch, c_torch, n] - - - ''' - results = run_kernel( - kernel_name="vec_add", - kernel_source=vec_add, - problem_size=n, - arguments=args, - params={"work_per_thread": 16}, - lang="generic_python", - call_function=call_warp, - decorator="@wp.kernel" - ) - ''' + tune_params["size"] = [n] + args = [a, b, c, n] results, env = tune_kernel( kernel_name="vec_add", - kernel_source=vec_add, + kernel_source=FULL_PATH, problem_size=n, arguments=args, tune_params=tune_params, lang="generic_python", - answer=[None, None, c_expect.cpu(), None], + answer=[None, None, c_expect, None], call_function=call_warp, - decorator="@wp.kernel" ) -tune() \ No newline at end of file +if __name__ == "__main__": + tune() \ No newline at end of file diff --git a/kernel_tuner/backends/generic_python.py b/kernel_tuner/backends/generic_python.py index 5d867fdac..39be59a23 100644 --- a/kernel_tuner/backends/generic_python.py +++ b/kernel_tuner/backends/generic_python.py @@ -104,7 +104,7 @@ def ready_argument_list(self, arguments): elif type(arg) in vars(builtins).values(): torch_args.append(arg) else: - raise TypeError("Unknown argument type: ", type(arg)) + raise TypeError("Unknown argument type: ", type(arg), ". Accepted types are Torch tenors, NumPy arrays and scalars and built-in Python types.") return torch_args @@ -150,9 +150,9 @@ def compile(self, kernel_instance, gpu_args=None): # the values of the tuning params. self.gpu_kwargs = {} if params is not None: - for name, p in self.signature.parameters.items(): - if name in params: - self.gpu_kwargs[name] = params[name] + for arg_name in self.signature: + if arg_name in params: + self.gpu_kwargs[arg_name] = params[arg_name] # Call the user-defined call function in order to compile the kernel. self.synchronize() @@ -173,7 +173,7 @@ def kernel_finished(self): """Returns True if the kernel has finished, False otherwise.""" return self.end.query() - def run_kernel(self, func, gpu_args, threads, grid, stream=None, params=None): + def run_kernel(self, func, gpu_args, threads, grid, params=None): """Runs the Python kernel passed as 'func'. :param func: A cached Python kernel for this specific kernel configuration @@ -195,10 +195,7 @@ def run_kernel(self, func, gpu_args, threads, grid, stream=None, params=None): configuration :type params: dict """ - if stream is None: - stream = self.stream - - with torch.cuda.stream(stream): + with torch.cuda.stream(self.stream): logging.debug("Running Generic Python kernel") self.call_function(func, gpu_args, self.gpu_kwargs, grid, threads, params) diff --git a/kernel_tuner/interface.py b/kernel_tuner/interface.py index e0421f847..d69ce3de0 100644 --- a/kernel_tuner/interface.py +++ b/kernel_tuner/interface.py @@ -111,8 +111,7 @@ def __deepcopy__(self, _): """The CUDA, OpenCL, HIP, C or Python DSL kernel code. It is allowed for the code to be passed as a string, a filename, a function that returns a string of code, or a list when the code needs auxilliary files. - In the case of a kernel in a Python DSL such as Triton, the reference to the - Python callable should be passed. + In the case of a kernel in a Python DSL such as Triton, only a filename is accepted. To support combined host and device code tuning, a list of filenames can be passed. The first file in the list should be the @@ -294,8 +293,7 @@ def __deepcopy__(self, _): ( """When the language Generic Python is used, a call function that calls the kernel in the Python DSL must be specified. The function must take the following arguments: - :kernel_function: the callable function with the tuning parameters inserted. If provided, the - kernel_function is decorated with the de decorator. + :kernel_function: the callable function with the tuning parameters inserted. :args: list of kernel arguments, as provided by the user in the argument. :kwargs: dictionary of kernel keyword arguments. If a tuning parameter is in the kernel signature, the tuning parameter will be added as a keyword argument. @@ -305,15 +303,6 @@ def __deepcopy__(self, _): "function", ), ), - ( - "decorator", - ( - """When the language Generic Python is used, a decorator can be provided in which the kernel source - will be wrapped internally by KernelTuner. Note that when passing the kernel to KernelTuner with the - ``kernel_source`` argument, the decorator should be removed from the kernel.""", - "string", - ), - ), ] ) @@ -619,14 +608,13 @@ def tune_kernel( objective=None, objective_higher_is_better=None, call_function=None, - decorator=None ): start_overhead_time = perf_counter() if log: logging.basicConfig(filename=kernel_name + datetime.now().strftime("%Y%m%d-%H:%M:%S") + ".log", level=log) - kernelsource = KernelSource(kernel_name, kernel_source, lang, defines, call_function, decorator) + kernelsource = KernelSource(kernel_name, kernel_source, lang, defines, call_function) _check_user_input(kernel_name, kernelsource, arguments, block_size_names) @@ -808,12 +796,11 @@ def run_kernel( quiet=False, log=None, call_function=None, - decorator=None ): if log: logging.basicConfig(filename=kernel_name + datetime.now().strftime("%Y%m%d-%H:%M:%S") + ".log", level=log) - kernelsource = KernelSource(kernel_name, kernel_source, lang, defines, call_function, decorator) + kernelsource = KernelSource(kernel_name, kernel_source, lang, defines, call_function) _check_user_input(kernel_name, kernelsource, arguments, block_size_names) diff --git a/kernel_tuner/kernel_sources/kernel_source.py b/kernel_tuner/kernel_sources/kernel_source.py index d18ecb284..51fb75d9b 100644 --- a/kernel_tuner/kernel_sources/kernel_source.py +++ b/kernel_tuner/kernel_sources/kernel_source.py @@ -18,7 +18,7 @@ class KernelSourceFactory(type): __init__ call of the corresponding subclass. In both subclasses, we call super().__init__, which triggers the __init__ call of the KernelSource class to initalize some common variables. ''' - def __call__(cls, kernel_name, kernel_sources, lang, defines=None, call_function=None, decorator=None): + def __call__(cls, kernel_name, kernel_sources, lang, defines=None, call_function=None): if lang == None: language = None else: @@ -39,13 +39,13 @@ def __call__(cls, kernel_name, kernel_sources, lang, defines=None, call_function if ks_str: return KernelSourceStr(kernel_name, kernel_sources, lang, defines) else: - return KernelSourceFn(kernel_name, kernel_sources, lang, defines, call_function, decorator) + return KernelSourceFn(kernel_name, kernel_sources, lang, defines, call_function) # Else, normal behaviour for subclasses if ks_str: return super().__call__(kernel_name, kernel_sources, lang, defines) else: - return super().__call__(kernel_name, kernel_sources, lang, defines, call_function, decorator) + return super().__call__(kernel_name, kernel_sources, lang, defines, call_function) diff --git a/kernel_tuner/kernel_sources/kernel_source_fn.py b/kernel_tuner/kernel_sources/kernel_source_fn.py index 386685377..22509ae37 100644 --- a/kernel_tuner/kernel_sources/kernel_source_fn.py +++ b/kernel_tuner/kernel_sources/kernel_source_fn.py @@ -14,21 +14,20 @@ from kernel_tuner.language import Language from kernel_tuner.kernel_sources.kernel_source import KernelSource from kernel_tuner.kernel_sources.model.prepared_kernel_source_data import PreparedKernelSourceData - +from kernel_tuner.util import get_kernel_ast, get_arg_names class KernelSourceFn(KernelSource): """ Class that holds the Python-function-based kernel sources. - There is a primary kernel source for function-based kernels in Python. The source must be - a callable. This can be a function of a class. The function should not be decorated, but - a decorator can be supplied to wrap the function. + There is a primary kernel source for function-based kernels in Python. The kernel_source + must be a path to the file where the kernel with kernel_name lives. The kernel can be + decorated by a JIT decorator. A call function to specify how the kernel should be launched must be supplied. The call function must take the following arguments: - - kernel_function: the callable function with the tuning parameters inserted. If provided, the - kernel_function is decorated with the de decorator. + - kernel_function: the callable function with the tuning parameters inserted. - args: list of kernel arguments, as provided by the user in the argument. - kwargs: dictionary of kernel keyword arguments. If a tuning parameter is in the kernel signature, the tuning parameter will be added as a keyword argument. @@ -37,10 +36,10 @@ class KernelSourceFn(KernelSource): - params: dictionary with the values of the tuning params for this configuration. """ - def __init__(self, kernel_name, kernel_source, lang, defines=None, call_function=None, decorator=None): + def __init__(self, kernel_name, kernel_source, lang, defines=None, call_function=None): super().__init__(kernel_name, kernel_source, lang, defines) if isinstance(kernel_source, list): - raise ValueError("KernelSourceFn only supports a single kernel source function") + raise ValueError("KernelSourceFn only supports a single kernel source") if self.lang == Language.GENERIC_PYTHON: if call_function is None: @@ -49,27 +48,20 @@ def __init__(self, kernel_name, kernel_source, lang, defines=None, call_function raise TypeError(f"call_function of type {type(call_function)} is not a callable object.") self.call_function = call_function # TODO ceck signature - if decorator: - if not isinstance(decorator, str): - raise TypeError(f"The decorator should be a string, got {type(decorator)} instead.") - if decorator[0] != '@': - raise ValueError(f"The decorator should start with a '@', got {decorator} instead.") - self.decorator = decorator - - self.source_kernel_fn = kernel_source # This kernel source remains the original object - self.kernel_fn = self.source_kernel_fn # This is the kernel source that we will modify - - try: - self.source = inspect.getsource(self.source_kernel_fn) - except TypeError as e: - raise TypeError( - f"{e}. Did you forget to remove a decorator before tuning?" - ) from e + if not isinstance(kernel_name, str): + raise TypeError("kernel_name should be a string, got ", type(kernel_name)) - self.signature = inspect.signature(self.source_kernel_fn) - self.source_tree = ast.parse(self.source) - self.import_nodes = self._find_import_nodes(inspect.getfile(self.source_kernel_fn)) - self.dependencies = self._find_dependencies() + source_ast = get_kernel_ast(kernel_name, kernel_source) + if isinstance(source_ast, tuple): # Class based kernel + self.source_tree = source_ast[0] + self.signature = get_arg_names(source_ast[1]) + else: + self.source_tree = source_ast + self.signature = get_arg_names(source_ast) + + self.kernel_fn = self.source_tree # This is where we will store the transformed source. + self.import_nodes = self._find_import_nodes(kernel_source) + self.dependencies = self._find_dependencies(kernel_source) def prepare_kernel_instance(self, kernel_options, params, grid, threads): @@ -123,22 +115,15 @@ def apply_params_to_source_fn(self, params): source_tree_copy = copy.deepcopy(self.source_tree) transformed_tree = transformer.visit(source_tree_copy) - # Add decorator if needed - if self.decorator: - dummy = f"{self.decorator}\ndef _dummy():\n pass\n" - decorator_node = ast.parse(dummy).body[0].decorator_list[0] - for node in transformed_tree.body: - if isinstance(node, ast.FunctionDef): - node.decorator_list.insert(0, decorator_node) - break # only apply to the top level function - # Add transformed main kernel to new module - new_module.body.extend(transformed_tree.body) + new_module.body.append(transformed_tree) # Fix locations and generate source ast.fix_missing_locations(new_module) new_source = astor.to_source(new_module) - + + #print(new_source) + # Create a unique module name and write new source to it. module_name = f'temp_kernel_module_{uuid.uuid4().hex}' with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as temp_file: @@ -151,10 +136,10 @@ def apply_params_to_source_fn(self, params): sys.modules[module_name] = temp_module spec.loader.exec_module(temp_module) new_fn = getattr(temp_module, self.kernel_name) - + return new_fn, temp_file_path - + def _find_import_nodes(self, source_file): ''' Parse kernel source file to find import statements. Return those @@ -192,16 +177,15 @@ def visit_Call(self, node): return visitor.called - def _find_dependencies(self): + def _find_dependencies(self, filepath): ''' Find all local function dependencies in the file where the kernel source is defined. Return a dicionary indexed by the function node name with the function body as value. ''' - source_file = inspect.getfile(self.source_kernel_fn) - with open(source_file, "r") as f: + with open(filepath, "r") as f: source_code = f.read() - tree = ast.parse(source_code, filename=source_file) + tree = ast.parse(source_code, filename=filepath) # Find the locally defined functions in the source file local_funcs = set() diff --git a/kernel_tuner/util.py b/kernel_tuner/util.py index c28bafa0b..24c622546 100644 --- a/kernel_tuner/util.py +++ b/kernel_tuner/util.py @@ -418,8 +418,6 @@ def detect_language(kernel_string): lang = "CUDA" elif "__kernel" in kernel_string: lang = "OpenCL" - elif "@triton.jit" in kernel_string: - lang = "Triton" else: lang = "C" return lang @@ -559,6 +557,82 @@ def get_kernel_string(kernel_source, params=None): return kernel_string +def get_kernel_ast(kernel_name, filepath): + ''' + Util function for Generic Python Backend that returns the kernel function as AST. + + :param kernel_name: name of the kernel (as passed by the user) + :type kernel_name: string + + :param filepath: the path to the file where the kernel lives. (passed by the user as kenrel_source) + :type filepath: string or Path containing a filename that points to the kernel source + + :returns: ast.FunctionDef node in case the kernel is a function or a tuple + (ast.ClassDef node, ast.FunctionDef node) in case the kernel is represented as the __call__ function + of a class (for Tilus support). + ''' + if isinstance(filepath, Path): + source = read_file(filepath) + elif isinstance(filepath, str): + with open(filepath, "r") as f: + source = f.read() + else: + raise TypeError("Error kernel_source does not specify a path to a file") + + tree = ast.parse(source) + + # Function based kernels + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name == kernel_name: + return node + + # Class based kernels + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef) and node.name == kernel_name: + class_node = node + break + + if not class_node: + raise ValueError(f"Kernel {kernel_name} not found in {filepath}") + + # Search for __call__ function within class + for node in class_node.body: + if isinstance(node, ast.FunctionDef) and node.name == "__call__": + call_node = node + return (class_node, call_node) + + if not call_node: + raise ValueError(f"No __call__ function found inside Class {kernel_name}") + + + +def get_arg_names(func_node: ast.FunctionDef): + """ + Util to get the argument names from a Python AST function definition. This is needed to check + if there are any tunable parameters in the function signature. + """ + args = func_node.args + names = [] + + names += [arg.arg for arg in args.args] + names += [arg.arg for arg in args.posonlyargs] + names += [arg.arg for arg in args.kwonlyargs] + + # *args + if args.vararg: + names.append(args.vararg.arg) + + # **kwargs + if args.kwarg: + names.append(args.kwarg.arg) + + # Remove self for classes. + if 'self' in names: + names.remove('self') + + return names + + def get_problem_size(problem_size, params): """Compute current problem size.""" if callable(problem_size): diff --git a/test/test_generic_python_functions.py b/test/test_generic_python_functions.py index eee0d3dd1..b8cddea88 100644 --- a/test/test_generic_python_functions.py +++ b/test/test_generic_python_functions.py @@ -1,8 +1,10 @@ from .context import skip_if_no_torch -from .test_kernel_source_fn import mock_kernel, kernel_with_kwarg, call_mock +from .test_kernel_source_fn import call_mock from kernel_tuner.core import DeviceInterface, KernelInstance from kernel_tuner.kernel_sources.kernel_source import KernelSource import numpy as np +from pathlib import Path +import os try: import torch @@ -10,6 +12,7 @@ except ImportError: pass +KS_FILE = os.path.join(Path(__file__).resolve().parent, "test_kernel_source_fn.py") # Helper functions ------------------------------ @@ -30,7 +33,7 @@ def get_context(): a = 42 b = torch.randn(12, device='cuda', dtype=torch.float32) args = [a, b] - ks = KernelSource("mock_kernel", mock_kernel, "generic_python", call_function=call_mock) + ks = KernelSource("mock_kernel", KS_FILE, "generic_python", call_function=call_mock) return ks, args, params @@ -40,14 +43,24 @@ def get_context(): def test_ready_argument_list(): ks, args, params = get_context() dev = DeviceInterface(ks) + gpu_args = dev.ready_argument_list(args) assert len(args) == len(gpu_args) - for i, _ in enumerate(gpu_args): - assert value_equal(args[i], gpu_args[i]) - if type(gpu_args[i]) in (list, dict, torch.Tensor, np.ndarray): - assert gpu_args[i] is not args[i] # Assure deep copy + # arg 0: python scalar + assert isinstance(gpu_args[0], int) + assert gpu_args[0] == args[0] + + # arg 1: torch cuda tensor + assert isinstance(gpu_args[1], torch.Tensor) + assert gpu_args[1].is_cuda + + # values equal + assert torch.allclose(gpu_args[1], args[1]) + + # ensure deep copy + assert gpu_args[1] is not args[1] @skip_if_no_torch @@ -85,7 +98,7 @@ def test_gpu_kwargs(): params = {'mock_param': 64} a = torch.randn(12, device='cuda', dtype=torch.float32) args = [a] # we do not have to specify the kwarg here - ks = KernelSource("kernel_with_kwarg", kernel_with_kwarg, "generic_python", call_function=call_mock) + ks = KernelSource("kernel_with_kwarg", KS_FILE, "generic_python", call_function=call_mock) dev = DeviceInterface(ks) instance_data = ks.prepare_kernel_instance( diff --git a/test/test_kernel_source_fn.py b/test/test_kernel_source_fn.py index 85aa4e3eb..d02be9e4d 100644 --- a/test/test_kernel_source_fn.py +++ b/test/test_kernel_source_fn.py @@ -1,3 +1,5 @@ +import os +from pathlib import Path import pytest import inspect import ast @@ -8,7 +10,7 @@ from kernel_tuner.kernel_sources.kernel_source_fn import KernelSourceFn from kernel_tuner.kernel_sources.kernel_source_str import KernelSourceStr - +KS_FILE = Path(__file__).resolve() # Helper functions -------------------------------------------- def normalize_ast(src: str): @@ -41,13 +43,18 @@ def __call__(self): def another_function(self): return self.mock_param + +@functools.lru_cache() +def kernel_with_decorator(): + mock_param = 42 + return mock_param # Tests ------------------------------------------------------ def test_factory_behaviour(): # KernelSourceFn should only be created when language generic_python is supplied - ks_fn = KernelSource("mock_kernel", mock_kernel, "generic_python", call_function=call_mock) + ks_fn = KernelSource("mock_kernel", KS_FILE, "generic_python", call_function=call_mock) ks_str = KernelSource("vector_add", 'extern "C" __global__ void vector_add(float *c, float *a, float *b, int n) {', lang=None) assert isinstance(ks_fn, KernelSourceFn) @@ -59,27 +66,25 @@ def test_initiation(): Test invalid KernelSourceFn initations ''' with pytest.raises(ValueError, match=r"call_function must be supplied for language .*"): - KernelSource("mock_kernel", mock_kernel, "generic_python") + KernelSource("mock_kernel", KS_FILE, "generic_python") - with pytest.raises(TypeError, match=r".* Did you forget to remove a decorator before tuning\?"): + with pytest.raises(FileNotFoundError, match=r".* No such file or directory: .*"): KernelSource("mock_kernel", "This is a string Kernel", "generic_python", call_function=call_mock) - with pytest.raises(ValueError, match=r"KernelSourceFn only supports a single kernel source function"): - KernelSource("mock_kernel", [mock_kernel, kernel_with_dependency], "generic_python", call_function=call_mock) + with pytest.raises(TypeError, match="Error kernel_source does not specify a path to a file"): + KernelSource("mock_kernel", mock_kernel, "generic_python", call_function=call_mock) + + with pytest.raises(ValueError, match=r"KernelSourceFn only supports a single kernel source"): + KernelSource("mock_kernel", [KS_FILE, "another file"], "generic_python", call_function=call_mock) with pytest.raises(TypeError, match=r".* is not a callable object"): - KernelSource("mock_kernel", mock_kernel, "generic_python", call_function="not a function") - - with pytest.raises(ValueError, match=r"The decorator should start with a '@', got .* instead."): - KernelSource("mock_kernel", mock_kernel, "generic_python", call_function=call_mock, decorator="not a decorator") - - with pytest.raises(TypeError, match=r"The decorator should be a string, got .* instead."): - KernelSource("mock_kernel", mock_kernel, "generic_python", call_function=call_mock, decorator=mock_kernel) + KernelSource("mock_kernel", KS_FILE, "generic_python", call_function="not a function") + def test_param_subsitution(): params = {"mock_param": 128} - ks = KernelSourceFn("mock_kernel", mock_kernel, "generic_python", call_function=call_mock) + ks = KernelSourceFn("mock_kernel", KS_FILE, "generic_python", call_function=call_mock) new_kernel_fn, _ = ks.apply_params_to_source_fn(params) actual_src = inspect.getsource(new_kernel_fn) @@ -101,7 +106,7 @@ def test_imports(): ''' params = {"mock_param": 128} - ks = KernelSourceFn("mock_kernel", mock_kernel, "generic_python", call_function=call_mock) + ks = KernelSourceFn("mock_kernel", KS_FILE, "generic_python", call_function=call_mock) _, temp_path = ks.apply_params_to_source_fn(params) # Check if imports are present @@ -120,7 +125,7 @@ def test_param_substitution_class(): ''' params = {"mock_param": 128} - ks = KernelSourceFn("TilusLike", TilusLike, "generic_python", call_function=call_mock) + ks = KernelSourceFn("TilusLike", KS_FILE, "generic_python", call_function=call_mock) new_kernel_fn, _ = ks.apply_params_to_source_fn(params) @@ -142,15 +147,22 @@ def another_function(self): def test_decorator(): params = {"mock_param": 128} - ks = KernelSourceFn("mock_kernel", mock_kernel, "generic_python", call_function=call_mock, decorator="@functools.lru_cache()") + ks = KernelSourceFn("kernel_with_decorator", KS_FILE, "generic_python", call_function=call_mock) new_kernel_fn, _ = ks.apply_params_to_source_fn(params) - + actual_src = inspect.getsource(new_kernel_fn) + expected_src = """ +@functools.lru_cache() +def kernel_with_decorator(): + mock_param = 128 + return 128 +""" assert hasattr(new_kernel_fn, "__wrapped__") + assert normalize_ast(actual_src) == normalize_ast(expected_src) def test_dependencies(): params = {"mock_param": 128} - ks = KernelSourceFn("kernel_with_dependency", kernel_with_dependency, "generic_python", call_function=call_mock) + ks = KernelSourceFn("kernel_with_dependency", KS_FILE, "generic_python", call_function=call_mock) new_kernel_fn, _ = ks.apply_params_to_source_fn(params) res = new_kernel_fn() # This should not throw an error if the dependency exists in the module. From f9977f1344e8a9a73707acd2548b9649ab1654cc Mon Sep 17 00:00:00 2001 From: Imke van Ooijen Date: Sat, 7 Mar 2026 12:29:20 +0100 Subject: [PATCH 06/14] added normalize utility for the call function --- examples/generic_python/matmul/test_tilus.py | 26 ++ .../generic_python/matmul/tilelang_matmul.py | 221 +++------- .../generic_python/matmul/tilus_matmul.py | 319 ++++++--------- .../generic_python/matmul/triton_matmul.py | 371 ++++------------- .../matmul_old/tilelang_matmul.py | 212 ++++++++++ .../generic_python/matmul_old/tilus_matmul.py | 266 ++++++++++++ .../matmul_old/triton_matmul.py | 383 ++++++++++++++++++ examples/generic_python/numba_vec_add.py | 4 +- examples/generic_python/tilelang_vec_add.py | 2 +- examples/generic_python/tilus_naive_matmul.py | 2 +- .../generic_python/tilus_splitk_matmul.py | 2 +- examples/generic_python/tilus_vec_add.py | 28 +- examples/generic_python/triton_vec_add.py | 2 +- kernel_tuner/interface.py | 15 +- .../kernel_sources/kernel_source_fn.py | 2 +- kernel_tuner/util.py | 27 ++ test/test_util_functions.py | 22 + 17 files changed, 1229 insertions(+), 675 deletions(-) create mode 100644 examples/generic_python/matmul/test_tilus.py create mode 100644 examples/generic_python/matmul_old/tilelang_matmul.py create mode 100644 examples/generic_python/matmul_old/tilus_matmul.py create mode 100644 examples/generic_python/matmul_old/triton_matmul.py diff --git a/examples/generic_python/matmul/test_tilus.py b/examples/generic_python/matmul/test_tilus.py new file mode 100644 index 000000000..6d4b2430d --- /dev/null +++ b/examples/generic_python/matmul/test_tilus.py @@ -0,0 +1,26 @@ +import torch +from tilus_matmul import MatmulBasic + +sizes = [ + (65, 65, 17), + (67, 71, 19), + (1, 1, 1), + (63, 63, 15), + (129, 130, 33), +] + +matmul = MatmulBasic() +for m, n, k in sizes: + print(m, n, k) + + a = torch.randn(m, k, dtype=torch.float16, device="cuda") + b = torch.randn(k, n, dtype=torch.float16, device="cuda") + + c_actual = torch.empty(m, n, dtype=torch.float16, device="cuda") + c_expect = a @ b + + matmul(m, n, k, a, b, c_actual) + + torch.cuda.synchronize() + + torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2) \ No newline at end of file diff --git a/examples/generic_python/matmul/tilelang_matmul.py b/examples/generic_python/matmul/tilelang_matmul.py index 2a7b5154d..895732629 100644 --- a/examples/generic_python/matmul/tilelang_matmul.py +++ b/examples/generic_python/matmul/tilelang_matmul.py @@ -1,212 +1,87 @@ +import torch import tilelang import tilelang.language as T -import torch -from kernel_tuner import tune_kernel -#@tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype: str = 'float16', accum_dtype: str = 'float32'): +@tilelang.jit +def matmul_basic(M:int, N:int, K:int, block_M:int, block_N:int, block_K:int, dtype:str="float16", accum_dtype:str="float32"): @T.prim_func - def main( + def gemm( A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype), C: T.Tensor((M, N), dtype), ): - # Define a grid with enough blocks to cover M×N - num_threads=128 - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): - - # Allocate shared memory for the current tile of A and B + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + # We do use shared memory, even though this is a basic kernel. However, you don't + # really get around this because T.gemm can not handle global memory directly. A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) - - # Allocate a local (register) fragment for partial accumulations C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - # Enable swizzle-based rasterization for better L2 locality - panel_size = 4 - T.use_swizzle(panel_size=panel_size, enable=True) - - # Initialize the local accumulation buffer to zero T.clear(C_local) - - num_stages=3 - - # Loop over the K dimension in block_K chunks, using a pipeline - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - # Copy from global memory to shared memory + + # We do use a pipelining optimization here, because this is 'the basic way' + # of writing for loops in TileLang. + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[k * block_K, bx * block_N], B_shared) - - # Perform a matrix multiply-accumulate on the tile T.gemm(A_shared, B_shared, C_local) - # Copy the accumulated result from local memory (C_local) to global memory (C) T.copy(C_local, C[by * block_M, bx * block_N]) + + return gemm - return main +# https://github.com/tile-ai/tilelang/blob/main/examples/gemm/example_gemm_autotune.py @tilelang.jit -def matmul_with_decorator(M, N, K, block_M, block_N, block_K, dtype: str = 'float16', accum_dtype: str = 'float32'): +def matmul_opt(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): @T.prim_func - def main( + def gemm_autotune( A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), + B: T.Tensor((N, K), dtype), C: T.Tensor((M, N), dtype), ): - # Define a grid with enough blocks to cover M×N - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - - # Allocate shared memory for the current tile of A and B + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) - - # Allocate a local (register) fragment for partial accumulations + B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - # Initialize the local accumulation buffer to zero + C_shared = T.alloc_shared((block_M, block_N), dtype) + T.use_swizzle(panel_size=10, enable=enable_rasteration) T.clear(C_local) - - # Loop over the K dimension in block_K chunks, using a 3-stage pipeline - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - # Copy from global memory to shared memory + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + ) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_autotune - # Perform a matrix multiply-accumulate on the tile - T.gemm(A_shared, B_shared, C_local) - # Copy the accumulated result from local memory (C_local) to global memory (C) - T.copy(C_local, C[by * block_M, bx * block_N]) - return main +def main(): + kernel = matmul_basic(1024, 1024, 1024, 128, 128, 32) + import torch + + a = torch.randn(1024, 1024).cuda().half() + b = torch.randn(1024, 1024).cuda().half() + c = torch.empty((1024, 1024), device='cuda', dtype=torch.float16) -def run(m, n, k): - a = torch.randn(m, k, device="cuda", dtype=torch.float16) - b = torch.randn(k, n, device="cuda", dtype=torch.float16) - c = torch.empty(m, n, device="cuda", dtype=torch.float16) - kernel = matmul(m, n, k, 128, 128, 32) kernel(a, b, c) + ref_c = a @ b - tol = m * 2**(-11) - # Validate results - torch.testing.assert_close(c, ref_c, rtol=tol, atol=tol) - - -def call_tilelang(kernel_function, args, kwargs, grid, threads, params): - compiled_kernel = kernel_function(**kwargs) - compiled_kernel(*args) - - -def time(m, n, k): - a = torch.randn(m, k, device="cuda", dtype=torch.float16) - b = torch.randn(k, n, device="cuda", dtype=torch.float16) - c = torch.empty(m, n, device="cuda", dtype=torch.float16) - c_ans = a @ b - - args = [a, b, c] - tune_params = dict() - tune_params["M"] = [m] - tune_params["K"] = [k] - tune_params["N"] = [n] - tune_params["block_M"] = [64, 128] - tune_params["block_N"] = [64, 128] - tune_params["block_K"] = [32, 64] - - results_kt, env = tune_kernel("matmul", matmul, m * n, args, tune_params, lang="generic_python", - call_function=call_tilelang, decorator="@tilelang.jit", verbose=False, iterations=100) - - import time - num_repeats = 100 - times_direct = [] - for config in results_kt: - bs_m = config["block_M"] - bs_n = config["block_N"] - bs_k = config["block_K"] - - c = torch.empty(m, n, device="cuda", dtype=torch.float16) - - kernel = matmul_with_decorator(m, n, k, bs_m, bs_n, bs_k) - kernel(a, b, c) - - torch.allclose(c.cpu(), c_ans.cpu(), atol=m * 2**(-11)) - - for i in range(num_repeats): - times = [] - - torch.cuda.synchronize() - start = time.time() - kernel(a, b, c) - torch.cuda.synchronize() - times.append(time.time() - start) - - avg_time_ms = round((1000 * sum(times) / len(times)), 3) - times_direct.append(avg_time_ms) - print(f"BLOCK_SIZE_M={bs_m}, BLOCK_SIZE_N={bs_n}, BLOCK_SIZE_K={bs_k}, time={avg_time_ms}ms") - - - import matplotlib.pyplot as plt - - # Extract times - times_kt = [cfg['time'] for cfg in results_kt] - - # x-axis labels - configs = [f"config{i}" for i in range(len(times_kt))] - x = range(len(configs)) - - plt.figure(figsize=(10,6)) - plt.plot(configs, times_kt, marker='s', label='KernelTuner') - plt.plot(configs, times_direct, marker='x', label='Direct') - plt.ylabel('Time (ms)') - plt.xlabel('Configuration') - plt.title('Kernel execution time per configuration') - plt.xticks(rotation=45) - plt.legend() - plt.grid(True) - plt.tight_layout() - plt.savefig("tilelang.png") - print("saved fig") - - -def tune(m, n, k): - a = torch.randn(m, k, device="cuda", dtype=torch.float16) - b = torch.randn(k, n, device="cuda", dtype=torch.float16) - c = torch.empty(m, n, device="cuda", dtype=torch.float16) - c_actual = a @ b - - args = [a, b, c] - tune_params = dict() - tune_params["M"] = [m] - tune_params["K"] = [k] - tune_params["N"] = [n] - tune_params["block_M"] = [64, 128, 256] - tune_params["block_N"] = [64, 128, 256] - tune_params["block_K"] = [32, 64, 128] - tune_params["num_stages"] = [2, 3, 4] - tune_params["panel_size"] = [4, 8] # equivalent to group size m in Triton - tune_params["num_threads"] = [64, 128, 256] - - restrictions = [ - # tile size budget - "block_M * block_N <= 16384", - - # aspect ratio <= 4 (no max/min allowed, so expand manually) - "block_M <= 4 * block_N", - "block_N <= 4 * block_M", - - # large K only with reasonably large M/N - "not (block_K == 128 and block_M < 64 and block_N < 64)", - ] - - tol = m * 2**(-11) - answer = [None, None, c_actual.cpu()] - - results, env = tune_kernel("matmul", matmul, m * n, args, tune_params, atol=tol, lang="generic_python", - call_function=call_tilelang, restrictions=restrictions, answer=answer, decorator="@tilelang.jit", verbose=False) + + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("All check passed.") + + + if __name__ == "__main__": - #m, n, k = 1024, 1024, 1024 - m, n, k = 8192, 8192, 8192 - time(m, n, k) \ No newline at end of file + main() diff --git a/examples/generic_python/matmul/tilus_matmul.py b/examples/generic_python/matmul/tilus_matmul.py index e8996c201..dfe294761 100644 --- a/examples/generic_python/matmul/tilus_matmul.py +++ b/examples/generic_python/matmul/tilus_matmul.py @@ -1,102 +1,98 @@ -import math - -import pandas import tilus -import torch from tilus import float16, float32, int32 -from tilus.utils import benchmark_func -from kernel_tuner import tune_kernel, run_kernel - +from tilus.utils import cdiv -class MatmulV4(tilus.Script): +# This kernel is copied from the Tilus project: +# https://github.com/NVIDIA/tilus/blob/main/examples/matmul/matmul_v0.py +# +# Original example: matmul_v0.py +# Copyright (c) the Tilus authors +class MatmulBasic(tilus.Script): def __init__(self): super().__init__() - self.block_m = 128 - self.block_n = 128 + # we define three hyperparameters: ``block_m``, ``block_n``, and ``block_k`` to determine the tile size on + # m, n, and k dimensions for each `thread block` of the kernel. + self.block_m = 64 + self.block_n = 64 self.block_k = 16 - self.num_warps = 4 - self.num_stages = 4 def __call__( self, - m_size: int32, - n_size: int, - k_size: int, - a_ptr: ~float16, - b_ptr: ~float16, - c_ptr: ~float16, + m_size: int32, # the size of the m dimension of the input matrix A and output matrix C + n_size: int, # the size of the n dimension of the input matrix B and output matrix C + k_size: int, # the size of the k dimension of the input matrix A and B + a_ptr: ~float16, # the pointer to the input matrix A, which is a 2D tensor of shape [m_size, k_size] + b_ptr: ~float16, # the pointer to the input matrix B, which is a 2D tensor of shape [k_size, n_size] + c_ptr: ~float16, # the pointer to the output matrix C, which is a 2D tensor of shape [m_size, n_size] ): self.attrs.blocks = [ - self.utils.ceil_div(m_size, self.block_m), - self.utils.ceil_div(n_size, self.block_n), + cdiv(m_size, self.block_m), # the x dimension size of the grid + cdiv(n_size, self.block_n), # the y dimension size of the grid ] - self.attrs.warps = self.num_warps + self.attrs.warps = 1 # the number of warps per thread block, must be a compile-time known integer - block_m, block_n, block_k = self.block_m, self.block_n, self.block_k - offset_m: int32 = block_m * self.blockIdx.x - offset_n: int32 = block_n * self.blockIdx.y + # define two int32 variables to store the offsets of the m and n dimensions for the current thread block. + offset_m: int32 = self.block_m * self.blockIdx.x + offset_n: int32 = self.block_n * self.blockIdx.y + # create two global tensors `ga` and `gb` to represent the input matrices A and B, respectively. ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size]) - sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k]) - sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_k, block_n]) - acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) - - for stage in range(self.num_stages - 1): - offset_k = stage * self.block_k - self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k]) - self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n]) - self.copy_async_commit_group() - self.copy_async_wait_group(n=self.num_stages - 2) - self.sync() + # create a register tensor `acc` to accumulate the results of the matrix multiplication. + acc = self.register_tensor( + dtype=float32, shape=[self.block_m, self.block_n], init=0.0 + ) - current_stage: int32 = 0 - preload_stage: int32 = self.num_stages - 1 - for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages): - # computation for current tile - a = self.load_shared(sa[current_stage]) - b = self.load_shared(sb[current_stage]) - self.dot(a, b, acc, out=acc) + # iterate over the k dimension in blocks of size `block_k`. + for k in range(cdiv(k_size, self.block_k)): + # calculate the offset for the current block in the k dimension + offset_k = k * self.block_k - # preload the next tile of A and B into shared memory - preload_offset_k = offset_k + (self.num_stages - 1) * block_k - self.copy_async( - src=ga, - dst=sa[preload_stage], - offsets=[offset_m, preload_offset_k], + # load a block of matrix A and B into register tensors `a` and `b`. + a = self.load_global( + ga, offsets=[offset_m, offset_k], shape=[self.block_m, self.block_k] ) - self.copy_async( - src=gb, - dst=sb[preload_stage], - offsets=[preload_offset_k, offset_n], + b = self.load_global( + gb, offsets=[offset_k, offset_n], shape=[self.block_k, self.block_n] ) - self.copy_async_commit_group() - # update the stage - current_stage = (current_stage + 1) % self.num_stages - preload_stage = (preload_stage + 1) % self.num_stages - self.copy_async_wait_group(n=self.num_stages - 2) - self.sync() - - self.free_shared(sa) - self.free_shared(sb) + # perform the dot product: acc = a @ b + acc + self.dot(a, b, acc, out=acc) - casted_acc = self.cast(acc, dtype=float16) + # after the loop, we cast the accumulated result `acc` to float16 type and store it back to the output matrix C. + acc_f16 = self.cast(acc, dtype=float16) gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) - self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) + self.store_global(gc, acc_f16, offsets=[offset_m, offset_n]) -class MatmulGroupedOrdering(tilus.Script): - def __init__(self): + +# This kernel is copied from the Tilus project: +# https://github.com/NVIDIA/tilus/blob/main/examples/matmul/matmul_v5.py +# +# Original example: matmul_v5.py +# Copyright (c) the Tilus authors +# +# Modifications in file: +# - Removed auto-tuning decorators +class MatmulOpt(tilus.Script): + def __init__( + self, + num_warps=None, + block_m=None, + block_n=None, + block_k=None, + num_stages=None, + split_k_factor=None, + ): super().__init__() - self.block_m = 128 - self.block_n = 128 - self.block_k = 16 - self.num_warps = 4 - self.num_stages = 4 - self.group_size_m = 8 + self.block_m = block_m + self.block_n = block_n + self.block_k = block_k + self.num_warps = num_warps + self.num_stages = num_stages + self.split_k_factor = split_k_factor def __call__( self, @@ -107,25 +103,23 @@ def __call__( b_ptr: ~float16, c_ptr: ~float16, ): - block_m, block_n, block_k = self.block_m, self.block_n, self.block_k - - num_pid_m = self.utils.ceil_div(m_size, block_m) - num_pid_n = self.utils.ceil_div(n_size, block_n) - self.attrs.blocks = [num_pid_m * num_pid_n] - - pid = self.blockIdx.x - num_pid_in_group = self.group_size_m * num_pid_n - group_id = pid // num_pid_in_group - - first_pid_m = group_id * self.group_size_m - group_size_m = min(num_pid_m - first_pid_m, self.group_size_m) + self.attrs.blocks = [ + cdiv(m_size, self.block_m), + cdiv(n_size, self.block_n), + self.split_k_factor, + ] + self.attrs.warps = self.num_warps - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m + # the k_size for each thread block + block_k_size = ( + cdiv(cdiv(k_size, self.split_k_factor), self.block_k) * self.block_k + ) + start_offset_k = self.blockIdx.z * block_k_size + end_offset_k = min(start_offset_k + block_k_size, k_size) - self.attrs.warps = self.num_warps - offset_m: int32 = pid_m * block_m - offset_n: int32 = pid_n * block_n + block_m, block_n, block_k = self.block_m, self.block_n, self.block_k + offset_m: int32 = block_m * self.blockIdx.x + offset_n: int32 = block_n * self.blockIdx.y ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size]) @@ -134,7 +128,7 @@ def __call__( acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) for stage in range(self.num_stages - 1): - offset_k = stage * self.block_k + offset_k = start_offset_k + stage * self.block_k self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k]) self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n]) self.copy_async_commit_group() @@ -144,7 +138,9 @@ def __call__( current_stage: int32 = 0 preload_stage: int32 = self.num_stages - 1 - for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages): + for offset_k in self.range( + start_offset_k, end_offset_k, block_k, unroll=self.num_stages + ): # computation for current tile a = self.load_shared(sa[current_stage]) b = self.load_shared(sb[current_stage]) @@ -152,16 +148,17 @@ def __call__( # preload the next tile of A and B into shared memory preload_offset_k = offset_k + (self.num_stages - 1) * block_k - self.copy_async( - src=ga, - dst=sa[preload_stage], - offsets=[offset_m, preload_offset_k], - ) - self.copy_async( - src=gb, - dst=sb[preload_stage], - offsets=[preload_offset_k, offset_n], - ) + if preload_offset_k < end_offset_k: + self.copy_async( + src=ga, + dst=sa[preload_stage], + offsets=[offset_m, preload_offset_k], + ) + self.copy_async( + src=gb, + dst=sb[preload_stage], + offsets=[preload_offset_k, offset_n], + ) self.copy_async_commit_group() # update the stage @@ -170,97 +167,41 @@ def __call__( self.copy_async_wait_group(n=self.num_stages - 2) self.sync() + # free the shared memory tensors for A and B self.free_shared(sa) self.free_shared(sb) + # cast the accumulator to float16 and change the register tensor's layout + sc = self.shared_tensor(dtype=float16, shape=[block_m, block_n]) casted_acc = self.cast(acc, dtype=float16) - gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) - self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) - - - -def main(): - headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] - workloads = [ - [4096, 4096, 4096], - [1024, 1024, 14336], - ] - - rows = [] - for m, n, k in workloads: - matmul = MatmulGroupedOrdering() #MatmulV4() - - a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - matmul(m, n, k, a, b, c_actual) - - # check correctness - torch.testing.assert_close(c_expect, c_actual) - - # benchmark - for name, func in [ - ("torch", lambda: torch.matmul(a, b, out=c_expect)), - ("tilus", lambda: matmul(m, n, k, a, b, c_actual)), - ]: - latency = benchmark_func(func, warmup=5, repeat=20) - tflops = 2 * m * n * k / latency * 1e-9 - rows.append([m, n, k, name, latency, tflops]) - - df = pandas.DataFrame(rows, columns=headers) - print(df) - -def call_tilus(kernel_function, args, kwargs, grid, threads, params): - kernel_function(*args, **kwargs) - - -def tune_matmul(m, n, k): - a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - - size = m * n #(m, n) - args = [m, n, k, a, b, c_actual] - tune_params = dict() - tune_params["block_m"] = [32, 64, 128, 256] - tune_params["block_n"] = [32, 64, 128, 256] - tune_params["block_k"] = [32, 64, 128] - tune_params["group_size_m"] = [4, 8] - tune_params["num_stages"] = [2, 3, 4] - tune_params["num_warps"] = [4, 8] - - - restrictions = [ - # tile size budget - "block_m * block_n <= 16384", - - # aspect ratio <= 4 (no max/min allowed, so expand manually) - "block_m <= 4 * block_n", - "block_n <= 4 * block_m", - - # large K only with reasonably large M/N - "not (block_k == 128 and block_m < 64 and block_n < 64)", - - # 32x32 requires 8 warps - "not (block_m == 32 and block_n == 32 and num_warps < 8)", - ] - - - answer = [None] * 6 - answer[-1] = c_expect.cpu() - atol = 1e-2 #m * 2**(-11) - - results, env = tune_kernel("MatmulGroupedOrdering", MatmulGroupedOrdering, size, args, tune_params, grid_div_x = ["block_m", "block_n"], - answer = answer, atol=atol, restrictions=restrictions, - lang="generic_python", call_function=call_tilus, - block_size_names=["block_m", "block_n", "block_k"], strategy="simulated_annealing") - - -if __name__ == "__main__": - #m, n, k = 4096, 4096, 4096 - m, n, k = 8192, 8192, 8192 - tune_matmul(m, n, k) + self.store_shared(sc, casted_acc) + self.sync() + rc = self.load_shared(sc) + self.free_shared(sc) - #main() \ No newline at end of file + m_blocks, n_blocks = cdiv(m_size, block_m), cdiv(n_size, block_n) + gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) + if self.split_k_factor == 0: + self.store_global(gc, rc, offsets=[offset_m, offset_n]) + else: + semaphores = self.global_tensor( + dtype=int32, shape=[m_blocks, n_blocks], requires_clean=True + ) + semaphore: ~int32 = ~semaphores[self.blockIdx.x, self.blockIdx.y] + + # load and accumulate the partial result in global memory + if self.blockIdx.z > 0: + self.lock_semaphore(semaphore, value=self.blockIdx.z) + partial_rc = self.load_global( + gc, offsets=[offset_m, offset_n], shape=[block_m, block_n] + ) + self.add(rc, partial_rc, out=rc) + + # store the result to global memory and release the semaphore + self.store_global(gc, rc, offsets=[offset_m, offset_n]) + + # release the semaphore + self.sync() # we need to make sure the previous store_global is finished + self.release_semaphore( + semaphore, value=(self.blockIdx.z + 1) % self.split_k_factor + ) \ No newline at end of file diff --git a/examples/generic_python/matmul/triton_matmul.py b/examples/generic_python/matmul/triton_matmul.py index 238a61c9f..b127b0331 100644 --- a/examples/generic_python/matmul/triton_matmul.py +++ b/examples/generic_python/matmul/triton_matmul.py @@ -1,56 +1,81 @@ import torch - import triton import triton.language as tl -from kernel_tuner import tune_kernel -from kernel_tuner import run_kernel -def get_cuda_autotune_config(): - return [ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, - num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, - num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, - num_warps=2), - # Good config for fp8 inputs. - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, - num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, - num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4) - ] +@triton.jit +def matmul_basic( + # Pointers + a_ptr, b_ptr, c_ptr, + # Matrix sizes + M, N, K, + # Strides + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + # Tile sizes + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + # Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + + # Compute row/column indices for the tile + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Create pointers to A and B tiles + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # Accumulator + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Loop over K dimension + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, + mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N), + other=0.0, + ) + + accumulator = tl.dot(a, b, accumulator) -''' -@triton.autotune( - configs=get_cuda_autotune_config(), - key=['M', 'N', 'K'], -) -''' -#@triton.jit -def matmul_kernel( + # advance K + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + # Store result + c = accumulator.to(tl.float16) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=mask) + + + + +# This kernel is copied from the Triton project: +# https://github.com/triton-lang/triton/blob/main/python/tutorials/03-matrix-multiplication.py +# +# Original example: 03-matrix-multiplication.py +# Copyright (c) the Triton authors +# +# Modifications in file: +# - Removed auto-tuning decorators +# - Removed activation function +@triton.jit +def matmul_opt( # Pointers to matrices a_ptr, b_ptr, c_ptr, # Matrix dimensions @@ -63,7 +88,7 @@ def matmul_kernel( stride_cm, stride_cn, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # - GROUP_SIZE_M: tl.constexpr, num_stages, num_warps# + GROUP_SIZE_M: tl.constexpr, # ): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) @@ -71,6 +96,7 @@ def matmul_kernel( # ----------------------------------------------------------- # Map program ids `pid` to the block of C it should compute. # This is done in a grouped ordering to promote L2 data reuse. + # See above `L2 Cache Optimizations` section for details. pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) @@ -100,6 +126,7 @@ def matmul_kernel( # and accumulate # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) @@ -133,251 +160,3 @@ def matmul_kernel( tl.store(c_ptrs, c, mask=c_mask) -def run_matmul(m, n, k): - a = torch.rand(m, k, dtype=torch.float16).cuda() - b = torch.rand(k, n, dtype=torch.float16).cuda() - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - - grid = lambda META: (triton.cdiv(m, META['BLOCK_SIZE_M']) * triton.cdiv(n, META['BLOCK_SIZE_N']), ) - - matmul_kernel[grid]( - a, b, c_actual, # - m, n, k, # - a.stride(0), a.stride(1), # - b.stride(0), b.stride(1), # - c_actual.stride(0), c_actual.stride(1), - 128, 256, 64, 8 - ) - - torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2) - - -def run_matmul_kt(m, n, k): - a = torch.rand(m, k, dtype=torch.float16).cuda() - b = torch.rand(k, n, dtype=torch.float16).cuda() - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - - size = m * n - - args = [a, b, c_actual, m, n, k, a.stride(0), a.stride(1), b.stride(0), b.stride(1), - c_actual.stride(0), c_actual.stride(1)] - - params = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M":4, "num_stages":3, "num_warps":4} - - result = run_kernel("matmul_kernel", matmul_kernel, size, args, params=params, grid_div_x=["BLOCK_SIZE_N", "BLOCK_SIZE_M"], - lang="generic_python", decorator="@triton.jit", call_function=call_triton, - block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"]) - c_res = result[2] - - assert torch.allclose(c_res, c_expect.cpu(), atol=1e-2, rtol=1e-1) - - - - - - -def call_triton(kernel_function, args, kwargs, grid, threads, params): - #print("using grid: ", grid) - #print("args: ", args) - #print("kwargs: ", kwargs) - kernel_function[grid](*args, **kwargs) - - - - -def check_time(m, n, k): - a = torch.rand(m, k, dtype=torch.float16).cuda() - b = torch.rand(k, n, dtype=torch.float16).cuda() - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - - size = m * n - args = [a, b, c_actual, m, n, k, a.stride(0), a.stride(1), b.stride(0), b.stride(1), - c_actual.stride(0), c_actual.stride(1)] - tune_params = dict() - tune_params["BLOCK_SIZE_M"] = [64, 128] - tune_params["BLOCK_SIZE_N"] = [64, 128] - tune_params["BLOCK_SIZE_K"] = [32, 64] - tune_params["GROUP_SIZE_M"] = [4, 8] - tune_params["num_stages"] = [3] - tune_params["num_warps"] = [4] - - restrictions = [ - # tile size budget - "BLOCK_SIZE_M * BLOCK_SIZE_N <= 16384", - - # aspect ratio <= 4 (no max/min allowed, so expand manually) - "BLOCK_SIZE_M <= 4 * BLOCK_SIZE_N", - "BLOCK_SIZE_N <= 4 * BLOCK_SIZE_M", - - # large K only with reasonably large M/N - "not (BLOCK_SIZE_K == 128 and BLOCK_SIZE_M < 64 and BLOCK_SIZE_N < 64)", - - # 32x32 requires 8 warps - "not (BLOCK_SIZE_M == 32 and BLOCK_SIZE_N == 32 and num_warps < 8)", - ] - - grid_div = ["BLOCK_SIZE_N", "BLOCK_SIZE_M"] - - answer = [None] * 12 - answer[2] = c_expect.cpu() - atol = 1e-2 #m * 2**(-11) - - - - results_ours, _ = tune_kernel("matmul_kernel", matmul_kernel, size, args, tune_params, grid_div_x = grid_div, - restrictions=restrictions, iterations=100, answer=answer, atol=atol, - lang="generic_python", decorator="@triton.jit", call_function=call_triton, - block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"]) - - - results_prev, _ = tune_kernel("matmul_kernel", matmul_kernel, size, args, tune_params, grid_div_x = grid_div, - restrictions=restrictions, iterations=100, answer=answer, atol=atol, - lang="triton", block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"]) - - - - import time - num_repeats = 100 - times_direct = [] - for config in results_prev: - - bs_m = config["BLOCK_SIZE_M"] - bs_n = config["BLOCK_SIZE_N"] - bs_k = config["BLOCK_SIZE_K"] - gs_m = config["GROUP_SIZE_M"] - num_stages = config["num_stages"] - num_warps = config["num_warps"] - - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - - grid = (triton.cdiv(m, bs_m) * triton.cdiv(n, bs_n, ), ) - jit_function = triton.jit(matmul_kernel) - - - jit_function[grid]( - a, b, c_actual, - m, n, k, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c_actual.stride(0), c_actual.stride(1), - bs_m, bs_n, bs_k, gs_m, num_stages, num_warps - ) - - - torch.allclose(c_expect.cpu(), c_actual.cpu(), atol=1e-2) - - - - for i in range(num_repeats): - times = [] - - #c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - - torch.cuda.synchronize() - start = time.time() - jit_function[grid]( - a, b, c_actual, - m, n, k, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c_actual.stride(0), c_actual.stride(1), - bs_m, bs_n, bs_k, gs_m, num_stages, num_warps - ) - - torch.cuda.synchronize() - times.append(time.time() - start) - - avg_time_ms = round((1000 * sum(times) / len(times)), 3) - times_direct.append(avg_time_ms) - print(f"BLOCK_SIZE_M={bs_m}, BLOCK_SIZE_N={bs_n}, BLOCK_SIZE_K={bs_k}, GROUP_SIZE_M={gs_m}, num_stages={num_stages}, num_warps={num_warps}, time={avg_time_ms}ms") - - - import matplotlib.pyplot as plt - - # Extract times - times_prev = [cfg['time'] for cfg in results_prev] - times_ours = [cfg['time'] for cfg in results_ours] - - # x-axis labels - configs = [f"config{i}" for i in range(len(times_prev))] - x = range(len(configs)) - - plt.figure(figsize=(10,6)) - plt.plot(configs, times_prev, marker='o', label='Triton tuned') - plt.plot(configs, times_ours, marker='s', label='Generic tuned') - plt.plot(configs, times_direct, marker='x', label='Direct') - plt.ylabel('Time (ms)') - plt.xlabel('Configuration') - plt.title('Kernel execution time per configuration') - plt.xticks(rotation=45) - plt.legend() - plt.grid(True) - plt.tight_layout() - plt.savefig("ouptut.png") - - - - - - -def tune_matmul(m, n, k): - a = torch.rand(m, k, dtype=torch.float16).cuda() - b = torch.rand(k, n, dtype=torch.float16).cuda() - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - - size = m * n - args = [a, b, c_actual, m, n, k, a.stride(0), a.stride(1), b.stride(0), b.stride(1), - c_actual.stride(0), c_actual.stride(1)] - tune_params = dict() - tune_params["BLOCK_SIZE_M"] = [64, 128, 256] - tune_params["BLOCK_SIZE_N"] = [64, 128, 256] - tune_params["BLOCK_SIZE_K"] = [32, 64, 128] - tune_params["GROUP_SIZE_M"] = [4, 8] - tune_params["num_stages"] = [2, 3, 4] - tune_params["num_warps"] = [4, 8] - - restrictions = [ - # tile size budget - "BLOCK_SIZE_M * BLOCK_SIZE_N <= 16384", - - # aspect ratio <= 4 (no max/min allowed, so expand manually) - "BLOCK_SIZE_M <= 4 * BLOCK_SIZE_N", - "BLOCK_SIZE_N <= 4 * BLOCK_SIZE_M", - - # large K only with reasonably large M/N - "not (BLOCK_SIZE_K == 128 and BLOCK_SIZE_M < 64 and BLOCK_SIZE_N < 64)", - - # 32x32 requires 8 warps - "not (BLOCK_SIZE_M == 32 and BLOCK_SIZE_N == 32 and num_warps < 8)", - ] - - grid_div = ["BLOCK_SIZE_N", "BLOCK_SIZE_M"] - - answer = [None] * 12 - answer[2] = c_expect.cpu() - - results, env = tune_kernel("matmul_kernel", matmul_kernel, size, args, tune_params, grid_div_x = grid_div, - answer = answer, atol=4.0, restrictions=restrictions, - lang="generic_python", decorator="@triton.jit", call_function=call_triton, - block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"], strategy="simulated_annealing") - - - - - -if __name__ == "__main__": - m, n, k = 8192, 8192, 8192 - #m, n, k = 4096, 4096, 4096 - #tune_matmul(m, n, k) - #check_time(m, n, k) - #run_matmul_kt(m, n, k) - - - - - - diff --git a/examples/generic_python/matmul_old/tilelang_matmul.py b/examples/generic_python/matmul_old/tilelang_matmul.py new file mode 100644 index 000000000..2a7b5154d --- /dev/null +++ b/examples/generic_python/matmul_old/tilelang_matmul.py @@ -0,0 +1,212 @@ +import tilelang +import tilelang.language as T +import torch +from kernel_tuner import tune_kernel + + +#@tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype: str = 'float16', accum_dtype: str = 'float32'): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Define a grid with enough blocks to cover M×N + num_threads=128 + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + + # Allocate shared memory for the current tile of A and B + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + + # Allocate a local (register) fragment for partial accumulations + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable swizzle-based rasterization for better L2 locality + panel_size = 4 + T.use_swizzle(panel_size=panel_size, enable=True) + + # Initialize the local accumulation buffer to zero + T.clear(C_local) + + num_stages=3 + + # Loop over the K dimension in block_K chunks, using a pipeline + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + # Copy from global memory to shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + + # Perform a matrix multiply-accumulate on the tile + T.gemm(A_shared, B_shared, C_local) + + # Copy the accumulated result from local memory (C_local) to global memory (C) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + +@tilelang.jit +def matmul_with_decorator(M, N, K, block_M, block_N, block_K, dtype: str = 'float16', accum_dtype: str = 'float32'): + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Define a grid with enough blocks to cover M×N + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + + # Allocate shared memory for the current tile of A and B + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + + # Allocate a local (register) fragment for partial accumulations + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Initialize the local accumulation buffer to zero + T.clear(C_local) + + # Loop over the K dimension in block_K chunks, using a 3-stage pipeline + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy from global memory to shared memory + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + + # Perform a matrix multiply-accumulate on the tile + T.gemm(A_shared, B_shared, C_local) + + # Copy the accumulated result from local memory (C_local) to global memory (C) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + + +def run(m, n, k): + a = torch.randn(m, k, device="cuda", dtype=torch.float16) + b = torch.randn(k, n, device="cuda", dtype=torch.float16) + c = torch.empty(m, n, device="cuda", dtype=torch.float16) + kernel = matmul(m, n, k, 128, 128, 32) + kernel(a, b, c) + ref_c = a @ b + tol = m * 2**(-11) + # Validate results + torch.testing.assert_close(c, ref_c, rtol=tol, atol=tol) + + +def call_tilelang(kernel_function, args, kwargs, grid, threads, params): + compiled_kernel = kernel_function(**kwargs) + compiled_kernel(*args) + + +def time(m, n, k): + a = torch.randn(m, k, device="cuda", dtype=torch.float16) + b = torch.randn(k, n, device="cuda", dtype=torch.float16) + c = torch.empty(m, n, device="cuda", dtype=torch.float16) + c_ans = a @ b + + args = [a, b, c] + tune_params = dict() + tune_params["M"] = [m] + tune_params["K"] = [k] + tune_params["N"] = [n] + tune_params["block_M"] = [64, 128] + tune_params["block_N"] = [64, 128] + tune_params["block_K"] = [32, 64] + + results_kt, env = tune_kernel("matmul", matmul, m * n, args, tune_params, lang="generic_python", + call_function=call_tilelang, decorator="@tilelang.jit", verbose=False, iterations=100) + + import time + num_repeats = 100 + times_direct = [] + for config in results_kt: + bs_m = config["block_M"] + bs_n = config["block_N"] + bs_k = config["block_K"] + + c = torch.empty(m, n, device="cuda", dtype=torch.float16) + + kernel = matmul_with_decorator(m, n, k, bs_m, bs_n, bs_k) + kernel(a, b, c) + + torch.allclose(c.cpu(), c_ans.cpu(), atol=m * 2**(-11)) + + for i in range(num_repeats): + times = [] + + torch.cuda.synchronize() + start = time.time() + kernel(a, b, c) + torch.cuda.synchronize() + times.append(time.time() - start) + + avg_time_ms = round((1000 * sum(times) / len(times)), 3) + times_direct.append(avg_time_ms) + print(f"BLOCK_SIZE_M={bs_m}, BLOCK_SIZE_N={bs_n}, BLOCK_SIZE_K={bs_k}, time={avg_time_ms}ms") + + + import matplotlib.pyplot as plt + + # Extract times + times_kt = [cfg['time'] for cfg in results_kt] + + # x-axis labels + configs = [f"config{i}" for i in range(len(times_kt))] + x = range(len(configs)) + + plt.figure(figsize=(10,6)) + plt.plot(configs, times_kt, marker='s', label='KernelTuner') + plt.plot(configs, times_direct, marker='x', label='Direct') + plt.ylabel('Time (ms)') + plt.xlabel('Configuration') + plt.title('Kernel execution time per configuration') + plt.xticks(rotation=45) + plt.legend() + plt.grid(True) + plt.tight_layout() + plt.savefig("tilelang.png") + print("saved fig") + + +def tune(m, n, k): + a = torch.randn(m, k, device="cuda", dtype=torch.float16) + b = torch.randn(k, n, device="cuda", dtype=torch.float16) + c = torch.empty(m, n, device="cuda", dtype=torch.float16) + c_actual = a @ b + + args = [a, b, c] + tune_params = dict() + tune_params["M"] = [m] + tune_params["K"] = [k] + tune_params["N"] = [n] + tune_params["block_M"] = [64, 128, 256] + tune_params["block_N"] = [64, 128, 256] + tune_params["block_K"] = [32, 64, 128] + tune_params["num_stages"] = [2, 3, 4] + tune_params["panel_size"] = [4, 8] # equivalent to group size m in Triton + tune_params["num_threads"] = [64, 128, 256] + + restrictions = [ + # tile size budget + "block_M * block_N <= 16384", + + # aspect ratio <= 4 (no max/min allowed, so expand manually) + "block_M <= 4 * block_N", + "block_N <= 4 * block_M", + + # large K only with reasonably large M/N + "not (block_K == 128 and block_M < 64 and block_N < 64)", + ] + + tol = m * 2**(-11) + answer = [None, None, c_actual.cpu()] + + results, env = tune_kernel("matmul", matmul, m * n, args, tune_params, atol=tol, lang="generic_python", + call_function=call_tilelang, restrictions=restrictions, answer=answer, decorator="@tilelang.jit", verbose=False) + +if __name__ == "__main__": + #m, n, k = 1024, 1024, 1024 + m, n, k = 8192, 8192, 8192 + time(m, n, k) \ No newline at end of file diff --git a/examples/generic_python/matmul_old/tilus_matmul.py b/examples/generic_python/matmul_old/tilus_matmul.py new file mode 100644 index 000000000..e8996c201 --- /dev/null +++ b/examples/generic_python/matmul_old/tilus_matmul.py @@ -0,0 +1,266 @@ +import math + +import pandas +import tilus +import torch +from tilus import float16, float32, int32 +from tilus.utils import benchmark_func +from kernel_tuner import tune_kernel, run_kernel + + + +class MatmulV4(tilus.Script): + def __init__(self): + super().__init__() + self.block_m = 128 + self.block_n = 128 + self.block_k = 16 + self.num_warps = 4 + self.num_stages = 4 + + def __call__( + self, + m_size: int32, + n_size: int, + k_size: int, + a_ptr: ~float16, + b_ptr: ~float16, + c_ptr: ~float16, + ): + self.attrs.blocks = [ + self.utils.ceil_div(m_size, self.block_m), + self.utils.ceil_div(n_size, self.block_n), + ] + self.attrs.warps = self.num_warps + + block_m, block_n, block_k = self.block_m, self.block_n, self.block_k + offset_m: int32 = block_m * self.blockIdx.x + offset_n: int32 = block_n * self.blockIdx.y + + ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) + gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size]) + sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k]) + sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_k, block_n]) + acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) + + for stage in range(self.num_stages - 1): + offset_k = stage * self.block_k + self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k]) + self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n]) + self.copy_async_commit_group() + + self.copy_async_wait_group(n=self.num_stages - 2) + self.sync() + + current_stage: int32 = 0 + preload_stage: int32 = self.num_stages - 1 + for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages): + # computation for current tile + a = self.load_shared(sa[current_stage]) + b = self.load_shared(sb[current_stage]) + self.dot(a, b, acc, out=acc) + + # preload the next tile of A and B into shared memory + preload_offset_k = offset_k + (self.num_stages - 1) * block_k + self.copy_async( + src=ga, + dst=sa[preload_stage], + offsets=[offset_m, preload_offset_k], + ) + self.copy_async( + src=gb, + dst=sb[preload_stage], + offsets=[preload_offset_k, offset_n], + ) + self.copy_async_commit_group() + + # update the stage + current_stage = (current_stage + 1) % self.num_stages + preload_stage = (preload_stage + 1) % self.num_stages + self.copy_async_wait_group(n=self.num_stages - 2) + self.sync() + + self.free_shared(sa) + self.free_shared(sb) + + casted_acc = self.cast(acc, dtype=float16) + gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) + self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) + + +class MatmulGroupedOrdering(tilus.Script): + def __init__(self): + super().__init__() + self.block_m = 128 + self.block_n = 128 + self.block_k = 16 + self.num_warps = 4 + self.num_stages = 4 + self.group_size_m = 8 + + def __call__( + self, + m_size: int32, + n_size: int, + k_size: int, + a_ptr: ~float16, + b_ptr: ~float16, + c_ptr: ~float16, + ): + block_m, block_n, block_k = self.block_m, self.block_n, self.block_k + + num_pid_m = self.utils.ceil_div(m_size, block_m) + num_pid_n = self.utils.ceil_div(n_size, block_n) + self.attrs.blocks = [num_pid_m * num_pid_n] + + pid = self.blockIdx.x + num_pid_in_group = self.group_size_m * num_pid_n + group_id = pid // num_pid_in_group + + first_pid_m = group_id * self.group_size_m + group_size_m = min(num_pid_m - first_pid_m, self.group_size_m) + + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + self.attrs.warps = self.num_warps + offset_m: int32 = pid_m * block_m + offset_n: int32 = pid_n * block_n + + ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) + gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size]) + sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k]) + sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_k, block_n]) + acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) + + for stage in range(self.num_stages - 1): + offset_k = stage * self.block_k + self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k]) + self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n]) + self.copy_async_commit_group() + + self.copy_async_wait_group(n=self.num_stages - 2) + self.sync() + + current_stage: int32 = 0 + preload_stage: int32 = self.num_stages - 1 + for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages): + # computation for current tile + a = self.load_shared(sa[current_stage]) + b = self.load_shared(sb[current_stage]) + self.dot(a, b, acc, out=acc) + + # preload the next tile of A and B into shared memory + preload_offset_k = offset_k + (self.num_stages - 1) * block_k + self.copy_async( + src=ga, + dst=sa[preload_stage], + offsets=[offset_m, preload_offset_k], + ) + self.copy_async( + src=gb, + dst=sb[preload_stage], + offsets=[preload_offset_k, offset_n], + ) + self.copy_async_commit_group() + + # update the stage + current_stage = (current_stage + 1) % self.num_stages + preload_stage = (preload_stage + 1) % self.num_stages + self.copy_async_wait_group(n=self.num_stages - 2) + self.sync() + + self.free_shared(sa) + self.free_shared(sb) + + casted_acc = self.cast(acc, dtype=float16) + gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) + self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) + + + +def main(): + headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] + workloads = [ + [4096, 4096, 4096], + [1024, 1024, 14336], + ] + + rows = [] + for m, n, k in workloads: + matmul = MatmulGroupedOrdering() #MatmulV4() + + a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + matmul(m, n, k, a, b, c_actual) + + # check correctness + torch.testing.assert_close(c_expect, c_actual) + + # benchmark + for name, func in [ + ("torch", lambda: torch.matmul(a, b, out=c_expect)), + ("tilus", lambda: matmul(m, n, k, a, b, c_actual)), + ]: + latency = benchmark_func(func, warmup=5, repeat=20) + tflops = 2 * m * n * k / latency * 1e-9 + rows.append([m, n, k, name, latency, tflops]) + + df = pandas.DataFrame(rows, columns=headers) + print(df) + +def call_tilus(kernel_function, args, kwargs, grid, threads, params): + kernel_function(*args, **kwargs) + + +def tune_matmul(m, n, k): + a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + + size = m * n #(m, n) + args = [m, n, k, a, b, c_actual] + tune_params = dict() + tune_params["block_m"] = [32, 64, 128, 256] + tune_params["block_n"] = [32, 64, 128, 256] + tune_params["block_k"] = [32, 64, 128] + tune_params["group_size_m"] = [4, 8] + tune_params["num_stages"] = [2, 3, 4] + tune_params["num_warps"] = [4, 8] + + + restrictions = [ + # tile size budget + "block_m * block_n <= 16384", + + # aspect ratio <= 4 (no max/min allowed, so expand manually) + "block_m <= 4 * block_n", + "block_n <= 4 * block_m", + + # large K only with reasonably large M/N + "not (block_k == 128 and block_m < 64 and block_n < 64)", + + # 32x32 requires 8 warps + "not (block_m == 32 and block_n == 32 and num_warps < 8)", + ] + + + answer = [None] * 6 + answer[-1] = c_expect.cpu() + atol = 1e-2 #m * 2**(-11) + + results, env = tune_kernel("MatmulGroupedOrdering", MatmulGroupedOrdering, size, args, tune_params, grid_div_x = ["block_m", "block_n"], + answer = answer, atol=atol, restrictions=restrictions, + lang="generic_python", call_function=call_tilus, + block_size_names=["block_m", "block_n", "block_k"], strategy="simulated_annealing") + + +if __name__ == "__main__": + #m, n, k = 4096, 4096, 4096 + m, n, k = 8192, 8192, 8192 + tune_matmul(m, n, k) + + #main() \ No newline at end of file diff --git a/examples/generic_python/matmul_old/triton_matmul.py b/examples/generic_python/matmul_old/triton_matmul.py new file mode 100644 index 000000000..238a61c9f --- /dev/null +++ b/examples/generic_python/matmul_old/triton_matmul.py @@ -0,0 +1,383 @@ +import torch + +import triton +import triton.language as tl + +from kernel_tuner import tune_kernel +from kernel_tuner import run_kernel + +def get_cuda_autotune_config(): + return [ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + # Good config for fp8 inputs. + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4) + ] + +''' +@triton.autotune( + configs=get_cuda_autotune_config(), + key=['M', 'N', 'K'], +) +''' +#@triton.jit +def matmul_kernel( + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, num_stages, num_warps# +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ----------------------------------------------------------- + # Add some integer bound assumptions. + # This helps to guide integer analysis in the backend to optimize + # load/store offset address calculation + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + tl.assume(stride_am > 0) + tl.assume(stride_ak > 0) + tl.assume(stride_bn > 0) + tl.assume(stride_bk > 0) + tl.assume(stride_cm > 0) + tl.assume(stride_cn > 0) + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + # If it is out of bounds, set it to 0. + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. + accumulator = tl.dot(a, b, accumulator) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + c = accumulator.to(tl.float16) + + # ----------------------------------------------------------- + # Write back the block of the output matrix C with masks. + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def run_matmul(m, n, k): + a = torch.rand(m, k, dtype=torch.float16).cuda() + b = torch.rand(k, n, dtype=torch.float16).cuda() + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + + grid = lambda META: (triton.cdiv(m, META['BLOCK_SIZE_M']) * triton.cdiv(n, META['BLOCK_SIZE_N']), ) + + matmul_kernel[grid]( + a, b, c_actual, # + m, n, k, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c_actual.stride(0), c_actual.stride(1), + 128, 256, 64, 8 + ) + + torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2) + + +def run_matmul_kt(m, n, k): + a = torch.rand(m, k, dtype=torch.float16).cuda() + b = torch.rand(k, n, dtype=torch.float16).cuda() + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + + size = m * n + + args = [a, b, c_actual, m, n, k, a.stride(0), a.stride(1), b.stride(0), b.stride(1), + c_actual.stride(0), c_actual.stride(1)] + + params = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M":4, "num_stages":3, "num_warps":4} + + result = run_kernel("matmul_kernel", matmul_kernel, size, args, params=params, grid_div_x=["BLOCK_SIZE_N", "BLOCK_SIZE_M"], + lang="generic_python", decorator="@triton.jit", call_function=call_triton, + block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"]) + c_res = result[2] + + assert torch.allclose(c_res, c_expect.cpu(), atol=1e-2, rtol=1e-1) + + + + + + +def call_triton(kernel_function, args, kwargs, grid, threads, params): + #print("using grid: ", grid) + #print("args: ", args) + #print("kwargs: ", kwargs) + kernel_function[grid](*args, **kwargs) + + + + +def check_time(m, n, k): + a = torch.rand(m, k, dtype=torch.float16).cuda() + b = torch.rand(k, n, dtype=torch.float16).cuda() + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + + size = m * n + args = [a, b, c_actual, m, n, k, a.stride(0), a.stride(1), b.stride(0), b.stride(1), + c_actual.stride(0), c_actual.stride(1)] + tune_params = dict() + tune_params["BLOCK_SIZE_M"] = [64, 128] + tune_params["BLOCK_SIZE_N"] = [64, 128] + tune_params["BLOCK_SIZE_K"] = [32, 64] + tune_params["GROUP_SIZE_M"] = [4, 8] + tune_params["num_stages"] = [3] + tune_params["num_warps"] = [4] + + restrictions = [ + # tile size budget + "BLOCK_SIZE_M * BLOCK_SIZE_N <= 16384", + + # aspect ratio <= 4 (no max/min allowed, so expand manually) + "BLOCK_SIZE_M <= 4 * BLOCK_SIZE_N", + "BLOCK_SIZE_N <= 4 * BLOCK_SIZE_M", + + # large K only with reasonably large M/N + "not (BLOCK_SIZE_K == 128 and BLOCK_SIZE_M < 64 and BLOCK_SIZE_N < 64)", + + # 32x32 requires 8 warps + "not (BLOCK_SIZE_M == 32 and BLOCK_SIZE_N == 32 and num_warps < 8)", + ] + + grid_div = ["BLOCK_SIZE_N", "BLOCK_SIZE_M"] + + answer = [None] * 12 + answer[2] = c_expect.cpu() + atol = 1e-2 #m * 2**(-11) + + + + results_ours, _ = tune_kernel("matmul_kernel", matmul_kernel, size, args, tune_params, grid_div_x = grid_div, + restrictions=restrictions, iterations=100, answer=answer, atol=atol, + lang="generic_python", decorator="@triton.jit", call_function=call_triton, + block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"]) + + + results_prev, _ = tune_kernel("matmul_kernel", matmul_kernel, size, args, tune_params, grid_div_x = grid_div, + restrictions=restrictions, iterations=100, answer=answer, atol=atol, + lang="triton", block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"]) + + + + import time + num_repeats = 100 + times_direct = [] + for config in results_prev: + + bs_m = config["BLOCK_SIZE_M"] + bs_n = config["BLOCK_SIZE_N"] + bs_k = config["BLOCK_SIZE_K"] + gs_m = config["GROUP_SIZE_M"] + num_stages = config["num_stages"] + num_warps = config["num_warps"] + + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + + grid = (triton.cdiv(m, bs_m) * triton.cdiv(n, bs_n, ), ) + jit_function = triton.jit(matmul_kernel) + + + jit_function[grid]( + a, b, c_actual, + m, n, k, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c_actual.stride(0), c_actual.stride(1), + bs_m, bs_n, bs_k, gs_m, num_stages, num_warps + ) + + + torch.allclose(c_expect.cpu(), c_actual.cpu(), atol=1e-2) + + + + for i in range(num_repeats): + times = [] + + #c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + + torch.cuda.synchronize() + start = time.time() + jit_function[grid]( + a, b, c_actual, + m, n, k, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c_actual.stride(0), c_actual.stride(1), + bs_m, bs_n, bs_k, gs_m, num_stages, num_warps + ) + + torch.cuda.synchronize() + times.append(time.time() - start) + + avg_time_ms = round((1000 * sum(times) / len(times)), 3) + times_direct.append(avg_time_ms) + print(f"BLOCK_SIZE_M={bs_m}, BLOCK_SIZE_N={bs_n}, BLOCK_SIZE_K={bs_k}, GROUP_SIZE_M={gs_m}, num_stages={num_stages}, num_warps={num_warps}, time={avg_time_ms}ms") + + + import matplotlib.pyplot as plt + + # Extract times + times_prev = [cfg['time'] for cfg in results_prev] + times_ours = [cfg['time'] for cfg in results_ours] + + # x-axis labels + configs = [f"config{i}" for i in range(len(times_prev))] + x = range(len(configs)) + + plt.figure(figsize=(10,6)) + plt.plot(configs, times_prev, marker='o', label='Triton tuned') + plt.plot(configs, times_ours, marker='s', label='Generic tuned') + plt.plot(configs, times_direct, marker='x', label='Direct') + plt.ylabel('Time (ms)') + plt.xlabel('Configuration') + plt.title('Kernel execution time per configuration') + plt.xticks(rotation=45) + plt.legend() + plt.grid(True) + plt.tight_layout() + plt.savefig("ouptut.png") + + + + + + +def tune_matmul(m, n, k): + a = torch.rand(m, k, dtype=torch.float16).cuda() + b = torch.rand(k, n, dtype=torch.float16).cuda() + c_actual = torch.empty(m, n, dtype=torch.float16).cuda() + c_expect = a @ b + + size = m * n + args = [a, b, c_actual, m, n, k, a.stride(0), a.stride(1), b.stride(0), b.stride(1), + c_actual.stride(0), c_actual.stride(1)] + tune_params = dict() + tune_params["BLOCK_SIZE_M"] = [64, 128, 256] + tune_params["BLOCK_SIZE_N"] = [64, 128, 256] + tune_params["BLOCK_SIZE_K"] = [32, 64, 128] + tune_params["GROUP_SIZE_M"] = [4, 8] + tune_params["num_stages"] = [2, 3, 4] + tune_params["num_warps"] = [4, 8] + + restrictions = [ + # tile size budget + "BLOCK_SIZE_M * BLOCK_SIZE_N <= 16384", + + # aspect ratio <= 4 (no max/min allowed, so expand manually) + "BLOCK_SIZE_M <= 4 * BLOCK_SIZE_N", + "BLOCK_SIZE_N <= 4 * BLOCK_SIZE_M", + + # large K only with reasonably large M/N + "not (BLOCK_SIZE_K == 128 and BLOCK_SIZE_M < 64 and BLOCK_SIZE_N < 64)", + + # 32x32 requires 8 warps + "not (BLOCK_SIZE_M == 32 and BLOCK_SIZE_N == 32 and num_warps < 8)", + ] + + grid_div = ["BLOCK_SIZE_N", "BLOCK_SIZE_M"] + + answer = [None] * 12 + answer[2] = c_expect.cpu() + + results, env = tune_kernel("matmul_kernel", matmul_kernel, size, args, tune_params, grid_div_x = grid_div, + answer = answer, atol=4.0, restrictions=restrictions, + lang="generic_python", decorator="@triton.jit", call_function=call_triton, + block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"], strategy="simulated_annealing") + + + + + +if __name__ == "__main__": + m, n, k = 8192, 8192, 8192 + #m, n, k = 4096, 4096, 4096 + #tune_matmul(m, n, k) + #check_time(m, n, k) + #run_matmul_kt(m, n, k) + + + + + + diff --git a/examples/generic_python/numba_vec_add.py b/examples/generic_python/numba_vec_add.py index 9f8dfe095..b5a2c56f3 100644 --- a/examples/generic_python/numba_vec_add.py +++ b/examples/generic_python/numba_vec_add.py @@ -16,7 +16,7 @@ def f(a, b, c): c[tid] = a[tid] + b[tid] -def call_numba(kernel_function, args, kwargs, grid, threads, params): +def call_numba(kernel_function, args, kwargs, grid, threads): numba_args = [] for arg in args: if isinstance(arg, torch.Tensor): @@ -26,8 +26,6 @@ def call_numba(kernel_function, args, kwargs, grid, threads, params): kernel_function[grid, threads](*args, **kwargs) - - def tune(): N = 100000 diff --git a/examples/generic_python/tilelang_vec_add.py b/examples/generic_python/tilelang_vec_add.py index 601b6a7b7..99463d112 100644 --- a/examples/generic_python/tilelang_vec_add.py +++ b/examples/generic_python/tilelang_vec_add.py @@ -37,7 +37,7 @@ def run_normal(): print("done") -def call_tilelang(kernel_function, args, kwargs, grid, threads, params): +def call_tilelang(kernel_function, args, kwargs): compiled_kernel = kernel_function(**kwargs) # cached, so second time only cache lookup is performed compiled_kernel(*args) diff --git a/examples/generic_python/tilus_naive_matmul.py b/examples/generic_python/tilus_naive_matmul.py index 42cdff86d..c53c22f73 100644 --- a/examples/generic_python/tilus_naive_matmul.py +++ b/examples/generic_python/tilus_naive_matmul.py @@ -68,7 +68,7 @@ def __call__( self.store_global(gc, acc_f16, offsets=[offset_m, offset_n]) -def call_tilus(kernel_function, args, kwargs, grid, threads, params): +def call_tilus(kernel_function, args, kwargs): kernel_function(*args, **kwargs) def main(): diff --git a/examples/generic_python/tilus_splitk_matmul.py b/examples/generic_python/tilus_splitk_matmul.py index be797299e..0695dcceb 100644 --- a/examples/generic_python/tilus_splitk_matmul.py +++ b/examples/generic_python/tilus_splitk_matmul.py @@ -133,7 +133,7 @@ def __call__( ) -def call_tilus(kernel_function, args, kwargs, grid, threads, params): +def call_tilus(kernel_function, args, kwargs): kernel_function(*args, **kwargs) diff --git a/examples/generic_python/tilus_vec_add.py b/examples/generic_python/tilus_vec_add.py index 938d9026c..9f842637c 100644 --- a/examples/generic_python/tilus_vec_add.py +++ b/examples/generic_python/tilus_vec_add.py @@ -8,9 +8,9 @@ FULL_PATH = Path(__file__).resolve() class VecAddV(tilus.Script): - def __init__(self): + def __init__(self, block_size_x=None): super().__init__() - self.block_size_x = 32 # number of threads per block + self.block_size_x = block_size_x # number of threads per block def __call__( self, @@ -38,7 +38,7 @@ def __call__( self.store_global(gc, c, offsets=[offset]) -def call_tilus(kernel_function, args, kwargs, grid, threads, params): +def call_tilus(kernel_function, args, kwargs): kernel_function(*args, **kwargs) @@ -91,7 +91,23 @@ def run(size): assert torch.allclose(results[-1], c_expect) +def tune_with_builtin(size): + TunedVecAdd = tilus.autotune("block_size_x", [32, 64, 128, 256, 512, 1024])(VecAddV) + vecadd = TunedVecAdd() + + a = torch.randn(size, dtype=torch.float32).cuda() + b = torch.randn(size, dtype=torch.float32).cuda() + c = torch.empty(size, dtype=torch.float32).cuda() + c_expect = a + b + + vecadd(size, a, b, c) + torch.cuda.synchronize() + + torch.testing.assert_close(c_expect, c) + + if __name__ == "__main__": - size = 100000000 - tune(size) - run(size) + size = 10000000 + #tune(size) + #run(size) + tune_with_builtin(size) diff --git a/examples/generic_python/triton_vec_add.py b/examples/generic_python/triton_vec_add.py index 2e982abfa..f69ae96d7 100644 --- a/examples/generic_python/triton_vec_add.py +++ b/examples/generic_python/triton_vec_add.py @@ -30,7 +30,7 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. tl.store(output_ptr + offsets, output, mask=mask) -def call_triton(kernel_function, args, kwargs, grid, threads, params): +def call_triton(kernel_function, args, kwargs, grid): kernel_function[grid](*args, **kwargs) diff --git a/kernel_tuner/interface.py b/kernel_tuner/interface.py index d69ce3de0..3417fb4d7 100644 --- a/kernel_tuner/interface.py +++ b/kernel_tuner/interface.py @@ -292,11 +292,12 @@ def __deepcopy__(self, _): "call_function", ( """When the language Generic Python is used, a call function that calls the kernel in the Python - DSL must be specified. The function must take the following arguments: + DSL must be specified. The function must take the following positional arguments: :kernel_function: the callable function with the tuning parameters inserted. :args: list of kernel arguments, as provided by the user in the argument. :kwargs: dictionary of kernel keyword arguments. If a tuning parameter is in the kernel signature, the tuning parameter will be added as a keyword argument. + Optionally, the following arguments can be used. The order and name of the arguments must match. :grid: the launch grid (tuple with 3 values), as computed by KernelTuner :threads: the thread block size (tuple with 3 values), as computed by KernelTuner :params: dictionary with the values of the tuning params for the specific configuration.""", @@ -614,7 +615,7 @@ def tune_kernel( if log: logging.basicConfig(filename=kernel_name + datetime.now().strftime("%Y%m%d-%H:%M:%S") + ".log", level=log) - kernelsource = KernelSource(kernel_name, kernel_source, lang, defines, call_function) + kernelsource = KernelSource(kernel_name, kernel_source, lang, defines, util.normalize_call_function(call_function)) _check_user_input(kernel_name, kernelsource, arguments, block_size_names) @@ -653,6 +654,10 @@ def tune_kernel( logging.debug("tuning_options: %s", util.get_config_string(tuning_options)) logging.debug("device_options: %s", util.get_config_string(device_options)) + # the user-specific call function may of may not have optional grid, threads and params + # arguments. We normalize it so that it always accepts all arguments. + kernel_options.call_function = util.normalize_call_function(kernel_options.call_function) + # check whether the selected strategy and options are valid strategy_string = strategy if strategy: @@ -800,7 +805,7 @@ def run_kernel( if log: logging.basicConfig(filename=kernel_name + datetime.now().strftime("%Y%m%d-%H:%M:%S") + ".log", level=log) - kernelsource = KernelSource(kernel_name, kernel_source, lang, defines, call_function) + kernelsource = KernelSource(kernel_name, kernel_source, lang, defines, util.normalize_call_function(call_function)) _check_user_input(kernel_name, kernelsource, arguments, block_size_names) @@ -809,6 +814,10 @@ def run_kernel( kernel_options = Options([(k, opts[k]) for k in _kernel_options.keys()]) device_options = Options([(k, opts[k]) for k in _device_options.keys()]) + # the user-specific call function may of may not have optional grid, threads and params + # arguments. We normalize it so that it always accepts all arguments. + kernel_options.call_function = util.normalize_call_function(kernel_options.call_function) + # detect language and create the right device function interface dev = core.DeviceInterface(kernelsource, iterations=1, **device_options) diff --git a/kernel_tuner/kernel_sources/kernel_source_fn.py b/kernel_tuner/kernel_sources/kernel_source_fn.py index 22509ae37..3eb27d86b 100644 --- a/kernel_tuner/kernel_sources/kernel_source_fn.py +++ b/kernel_tuner/kernel_sources/kernel_source_fn.py @@ -46,7 +46,7 @@ def __init__(self, kernel_name, kernel_source, lang, defines=None, call_function raise ValueError("call_function must be supplied for language Generic Python") if not callable(call_function): raise TypeError(f"call_function of type {type(call_function)} is not a callable object.") - self.call_function = call_function # TODO ceck signature + self.call_function = call_function if not isinstance(kernel_name, str): raise TypeError("kernel_name should be a string, got ", type(kernel_name)) diff --git a/kernel_tuner/util.py b/kernel_tuner/util.py index 24c622546..beae6b5d5 100644 --- a/kernel_tuner/util.py +++ b/kernel_tuner/util.py @@ -978,6 +978,33 @@ def has_kw_argument(func, name): return lambda answer, result_host, atol: v(answer, result_host) +def normalize_call_function(v): + """Normalize a user-specified call function for language Generic Python. + + The user-specified function has three required positional arguments (kernel_function, + args, kwargs), and three optional keyword arguments: grid, threads and params. The + optional keyword arguments should appear in that order. We normalize the function + so that it always accepts grid, threads and params. + + Undefined behaviour if the passed function does not match the required signatures. + """ + def has_kw_argument(func, name): + sig = signature(func) + return name in sig.parameters + + if v is None: + return None + + if has_kw_argument(v, "grid"): + if has_kw_argument(v, "threads"): + if has_kw_argument(v, "params"): + return v + return lambda kernel_function, args, kwargs, grid, threads, params: v(kernel_function, args, kwargs, grid, threads) + return lambda kernel_function, args, kwargs, grid, threads, params: v(kernel_function, args, kwargs, grid) + return lambda kernel_function, args, kwargs, grid, threads, params: v(kernel_function, args, kwargs) + + + def parse_restrictions( restrictions: list[str], tune_params: dict, monolithic=False, format=None ) -> list[tuple[Union[Constraint, str], list[str]]]: diff --git a/test/test_util_functions.py b/test/test_util_functions.py index 4a1858f37..389da8b81 100644 --- a/test/test_util_functions.py +++ b/test/test_util_functions.py @@ -595,6 +595,28 @@ def verify2(answer, result_host, atol): assert v(1, 2, atol=3) + +def test_normalize_call_function_none(): + assert normalize_call_function(None) is None + +@pytest.mark.parametrize( + "func", + [ + lambda f, a, k: f(*a, **k), + lambda f, a, k, grid: f(*a, **k), + lambda f, a, k, grid, threads: f(*a, **k), + lambda f, a, k, grid, threads, params: f(*a, **k), + ], +) +def test_normalize_call_function(func): + v = normalize_call_function(func) + + def kernel(x): + return x + 1 + + assert v(kernel, (1,), {}, grid=1, threads=2, params=3) == 2 + + def test_process_cache(): def assert_open_cachefile_is_correctly_parsed(cache): with open(cache, "r") as cachefile: From 976da390af4fd1b5b821c4bc11510d5478ec8f7b Mon Sep 17 00:00:00 2001 From: "I.C. van Ooijen" Date: Sun, 8 Mar 2026 11:34:15 +0100 Subject: [PATCH 07/14] taichi matmul exmple --- .../generic_python/matmul/taichi_matmul.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 examples/generic_python/matmul/taichi_matmul.py diff --git a/examples/generic_python/matmul/taichi_matmul.py b/examples/generic_python/matmul/taichi_matmul.py new file mode 100644 index 000000000..6eca9d21e --- /dev/null +++ b/examples/generic_python/matmul/taichi_matmul.py @@ -0,0 +1,56 @@ +import torch +import taichi as ti + +from kernel_tuner import tune_kernel +from pathlib import Path + +FULL_PATH = Path(__file__).resolve() + +ti.init(arch=ti.gpu) + +# NOTE tuning works on vibranium. Taichi does not work on DAS6 +# TODO make sure this is zero-copy +@ti.kernel +def matmul(A: ti.types.ndarray(), B: ti.types.ndarray(), C: ti.types.ndarray()): + BLOCK_DIM = 16 + ti.loop_config(block_dim=BLOCK_DIM) + + K_dim = A.shape[1] + + for i, j in C: + sum = 0.0 + for k in range(K_dim): + sum += A[i, k] * B[k, j] + C[i, j] = sum + + +def call_taichi(kernel_function, args, kwargs): + kernel_function(*args, **kwargs) + +def tune(M, N, K): + torch_A = torch.rand((N, K), device='cuda', dtype=torch.float32) + torch_B = torch.rand((K, M), device='cuda', dtype=torch.float32) + torch_C = torch.empty((N, M), device='cuda', dtype=torch.float32) + + size = M * N + args = [torch_A, torch_B, torch_C] + tune_params = {"BLOCK_DIM": {4, 8, 16, 32, 64, 128, 256, 512, 1024}} + + answer = [None, None, (torch_A @ torch_B).cpu()] + + results, env = tune_kernel( + kernel_name="matmul", + kernel_source=FULL_PATH, + problem_size=size, + arguments=args, + tune_params=tune_params, + answer=answer, + lang="generic_python", + call_function=call_taichi, + ) + +if __name__ == "__main__": + tune(128, 128, 128) + + + From 51f5a6a4bed48b3b9e8c2d55492bc5c1c46afeb4 Mon Sep 17 00:00:00 2001 From: Imke van Ooijen Date: Tue, 7 Apr 2026 10:56:26 +0200 Subject: [PATCH 08/14] working on experiments --- examples/__init__.py | 0 examples/generic_python/call_functions.py | 29 + examples/generic_python/cupy_copy.py | 19 + examples/generic_python/cutile_vec_add.py | 62 ++ .../flash_attention/tilelang_attention.py | 152 +++++ .../flash_attention/tilus_attention.py | 611 ++++++++++++++++++ .../flash_attention/triton_attention.py | 276 ++++++++ examples/generic_python/loopy_example.py | 38 ++ examples/generic_python/matmul/cupy_matmul.py | 105 +++ examples/generic_python/matmul/cute_matmul.py | 135 ++++ .../generic_python/matmul/helion_matmul.py | 20 + .../generic_python/matmul/numba_matmul.py | 148 +++++ examples/generic_python/matmul/test.py | 278 ++++++++ .../generic_python/matmul/tilelang_matmul.py | 18 +- .../generic_python/matmul/tilus_matmul.py | 3 +- examples/generic_python/matmul/warp_matmul.py | 131 ++++ .../normalization/tilelang_norm.py | 78 +++ .../normalization/tilus_norm.py | 132 ++++ .../normalization/triton_norm.py | 66 ++ examples/generic_python/pallas_vec_add.py | 23 + examples/generic_python/tilus_vec_add.py | 7 +- examples/generic_python/triton_vec_add.py | 2 + kernel_tuner/backends/generic_python.py | 5 +- kernel_tuner/core.py | 6 + .../kernel_sources/kernel_source_fn.py | 10 +- 25 files changed, 2336 insertions(+), 18 deletions(-) create mode 100644 examples/__init__.py create mode 100644 examples/generic_python/call_functions.py create mode 100644 examples/generic_python/cupy_copy.py create mode 100644 examples/generic_python/cutile_vec_add.py create mode 100644 examples/generic_python/flash_attention/tilelang_attention.py create mode 100644 examples/generic_python/flash_attention/tilus_attention.py create mode 100644 examples/generic_python/flash_attention/triton_attention.py create mode 100644 examples/generic_python/loopy_example.py create mode 100644 examples/generic_python/matmul/cupy_matmul.py create mode 100644 examples/generic_python/matmul/cute_matmul.py create mode 100644 examples/generic_python/matmul/helion_matmul.py create mode 100644 examples/generic_python/matmul/numba_matmul.py create mode 100644 examples/generic_python/matmul/test.py create mode 100644 examples/generic_python/matmul/warp_matmul.py create mode 100644 examples/generic_python/normalization/tilelang_norm.py create mode 100644 examples/generic_python/normalization/tilus_norm.py create mode 100644 examples/generic_python/normalization/triton_norm.py create mode 100644 examples/generic_python/pallas_vec_add.py diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/generic_python/call_functions.py b/examples/generic_python/call_functions.py new file mode 100644 index 000000000..6b321c4a2 --- /dev/null +++ b/examples/generic_python/call_functions.py @@ -0,0 +1,29 @@ +import torch + +def call_tilus(kernel_function, args, kwargs): + kernel_function(*args, **kwargs) + +def call_triton(kernel_function, args, kwargs, grid, threads, params): + if "num_warps" in params.keys(): + kwargs["num_warps"] = params["num_warps"] + if "num_stages" in params.keys(): + kwargs["num_stages"] = params["num_stages"] + + torch.cuda.nvtx.range_push("kt call") + kernel_function[grid](*args, **kwargs) + torch.cuda.nvtx.range_pop() + +def call_tilelang(kernel_function, args, kwargs): + compiled_kernel = kernel_function(**kwargs) + compiled_kernel(*args) + +def call_numba(kernel_function, args, kwargs, grid, threads): + from numba import cuda + + numba_args = [] + for arg in args: + if isinstance(arg, torch.Tensor): + numba_args.append(cuda.as_cuda_array(arg)) + else: + numba_args.append(arg) + kernel_function[grid, threads](*args, **kwargs) \ No newline at end of file diff --git a/examples/generic_python/cupy_copy.py b/examples/generic_python/cupy_copy.py new file mode 100644 index 000000000..8d9110209 --- /dev/null +++ b/examples/generic_python/cupy_copy.py @@ -0,0 +1,19 @@ +import cupy +from cupyx import jit + +@jit.rawkernel() +def elementwise_copy(x, y, size): + tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x + ntid = jit.gridDim.x * jit.blockDim.x + for i in range(tid, size, ntid): + y[i] = x[i] + +size = cupy.uint32(2 ** 22) +x = cupy.random.normal(size=(size,), dtype=cupy.float32) +y = cupy.empty((size,), dtype=cupy.float32) + +elementwise_copy((128,), (1024,), (x, y, size)) # RawKernel style +assert (x == y).all() + +elementwise_copy[128, 1024](x, y, size) # Numba style +assert (x == y).all() \ No newline at end of file diff --git a/examples/generic_python/cutile_vec_add.py b/examples/generic_python/cutile_vec_add.py new file mode 100644 index 000000000..e95302045 --- /dev/null +++ b/examples/generic_python/cutile_vec_add.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +Example demonstrating simple vector addition. +Shows how to perform elementwise operations on vectors. +Does not work on das vu, because we need cuda 13.1 +""" + +import cupy as cp +import numpy as np +import cuda.tile as ct + + +@ct.kernel +def vector_add(a, b, c, tile_size: ct.Constant[int]): + # Get the 1D pid + pid = ct.bid(0) + + # Load input tiles + a_tile = ct.load(a, index=(pid,), shape=(tile_size,)) + b_tile = ct.load(b, index=(pid,), shape=(tile_size,)) + + # Perform elementwise addition + result = a_tile + b_tile + + # Store result + ct.store(c, index=(pid, ), tile=result) + + +def test(): + # Create input data + vector_size = 2**12 + tile_size = 2**4 + grid = (ct.cdiv(vector_size, tile_size), 1, 1) + + rng = cp.random.default_rng() + a = rng.random(vector_size) + b = rng.random(vector_size) + c = cp.zeros_like(a) + + # Launch kernel + ct.launch(cp.cuda.get_current_stream(), + grid, # 1D grid of processors + vector_add, + (a, b, c, tile_size)) + + # Copy to host only to compare + a_np = cp.asnumpy(a) + b_np = cp.asnumpy(b) + c_np = cp.asnumpy(c) + + # Verify results + expected = a_np + b_np + np.testing.assert_array_almost_equal(c_np, expected) + + print("✓ vector_add_example passed!") + + +if __name__ == "__main__": + test() \ No newline at end of file diff --git a/examples/generic_python/flash_attention/tilelang_attention.py b/examples/generic_python/flash_attention/tilelang_attention.py new file mode 100644 index 000000000..5b2be204a --- /dev/null +++ b/examples/generic_python/flash_attention/tilelang_attention.py @@ -0,0 +1,152 @@ +# example taken from https://github.com/tile-ai/tilelang/blob/main/examples/flash_attention/example_mha_fwd_bshd.py +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + iter_params = dict(block_M=[64], block_N=[64], num_stages=[1], threads=[128]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): + scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) + shape = [batch, seq_len, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + + return main + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) + return output + + +def main( + batch: int = 8, + heads: int = 32, + seq_len: int = 4096, + dim: int = 128, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if not tune: + #kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128) + # Changed block size so fits in shared mem + kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128) + ref_program_processed = partial(ref_program, is_causal=is_causal) + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + else: + best_result = flashattn(batch, heads, seq_len, dim, is_causal) + best_latency = best_result.latency + best_config = best_result.config + ref_latency = best_result.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + + main() \ No newline at end of file diff --git a/examples/generic_python/flash_attention/tilus_attention.py b/examples/generic_python/flash_attention/tilus_attention.py new file mode 100644 index 000000000..5169d9316 --- /dev/null +++ b/examples/generic_python/flash_attention/tilus_attention.py @@ -0,0 +1,611 @@ +# Code taken from https://github.com/NVIDIA/tilus/blob/main/examples/attention/flash_attention_v3.py +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +import numpy as np +import pandas as pd +import tilus +import torch +from tilus import boolean, f32, int32, void_p +from hidet.ir import DataType +from tilus.ir import RegisterTensor, SharedTensor +from tilus.ir.tensor import GlobalTensor +from tilus.utils import benchmark_func, cdiv + +pd.options.display.max_columns = None +pd.options.display.width = 1000 + + +@tilus.autotune("num_warps", [4, 8]) +@tilus.autotune("block_q", [32, 64, 128]) +@tilus.autotune("block_kv", [32, 64, 128]) +@tilus.autotune("split_kv", [-1, 512, 1024, 4096]) +@tilus.autotune("keep_q_in_regs", [False, True]) +class FlashAttention(tilus.Script): + LOG2_E = 1.4426950408889634 # log2(e) + + debug_schedule = dict( + num_warps=4, + block_q=64, + block_kv=64, + split_kv=-1, + keep_q_in_regs=False, + ) + + def __init__( + self, + dtype: DataType, + num_heads: int, + num_heads_kv: int, + head_size: int, + num_warps: int, + block_q: int, + block_kv: int, + split_kv: int, + keep_q_in_regs: bool, + ): + super().__init__() + self.dtype: DataType = dtype + self.num_heads = num_heads + self.num_heads_kv = num_heads_kv + self.head_size = head_size + self.num_warps = num_warps + self.block_q = block_q + self.block_kv = block_kv + self.split_kv = split_kv + self.keep_q_in_regs = keep_q_in_regs + self.score_scale = float(1.0 / np.sqrt(head_size)) + self.group_heads = num_heads // num_heads_kv + + assert self.split_kv % self.block_kv == 0 or self.split_kv == -1, ( + "split_kv must be a multiple of block_kv or -1" + ) + + # determine layout + self.sv_config = self.cuda.resolve_dot_config( + dtype, f32, m=block_q, n=head_size, k=block_kv, warp_m=num_warps, warp_n=1 + ) + + def apply_mask(self, score: RegisterTensor, q_offset: int32, kv_offset: int32): + mask = self.register_tensor( + dtype=boolean, + shape=[self.block_q, self.block_kv], + init=lambda i, j: i + q_offset >= j + kv_offset, + ) + self.assign(score, score + self.where(mask, x=0.0, y=-1e6)) + + def softmax_rescale( + self, + score: RegisterTensor, + m: RegisterTensor, + l: RegisterTensor, + o: RegisterTensor, + ) -> RegisterTensor: + scale = self.score_scale * self.LOG2_E # log2(e) * score_scale + cur_m = self.max(score, dim=1, keepdim=True) * scale # [block_q, 1] + new_m = self.maximum(m, cur_m) # [block_q, 1] + rp = self.exp2(score * scale - new_m) # [block_q, block_kv] + m_scale = self.exp2(m - new_m) + self.assign(o, o * m_scale) + self.assign(l, l * m_scale + self.sum(rp, dim=1, keepdim=True)) + self.assign(m, new_m) + return rp.to(self.dtype) + + def attention_iteration( + self, + bs: int32, + kv_offset: int32, + q_offset: int32, + head: int32, + gk: GlobalTensor, + gv: GlobalTensor, + sq: SharedTensor, # f16[block_q, head_size] + rq: RegisterTensor, # f16[block_q, head_size] + sk: SharedTensor, # f16[block_kv, head_size], + sv: SharedTensor, # f16[block_kv, head_size], + o: RegisterTensor, # f32[block_q, head_size] + m: RegisterTensor, # f32[block_q, 1] + l: RegisterTensor, # f32[block_q, 1] + check_bounds: bool, + ): + if not self.keep_q_in_regs: + self.load_shared(sq, out=rq) + # wait for the async copy of k to finish + self.copy_async_wait_group(0) + self.sync() + self.copy_async( + gv, + sv, + offsets=[bs, kv_offset, head // self.group_heads, 0], + dims=[1, 3], + check_bounds=check_bounds, + ) + self.copy_async_commit_group() + + # issue the async copy for v and perform dot(q, k) + rk = self.load_shared(sk) # [block_kv, head_size] + score = self.dot(rq, rk.transpose(), acc_dtype=f32) # [block_q, block_kv] + + if check_bounds: + self.apply_mask(score, q_offset, kv_offset) # apply causal mask + + # wait for the async copy of v to finish + self.copy_async_wait_group(0) + self.sync() + self.copy_async( + gk, + sk, + offsets=[bs, kv_offset + self.block_kv, head // self.group_heads, 0], + dims=[1, 3], + check_bounds=check_bounds, + ) + self.copy_async_commit_group() + + # load v to register + rv = self.load_shared(sv) # [block_kv, head_size] + + # online softmax + rp = self.softmax_rescale(score, m=m, l=l, o=o) + + # pv + cur_o = self.dot(rp, rv, acc_dtype=f32) # [block_q, head_size] + self.annotate_layout(cur_o, self.sv_config.lc) + self.assign(o, o + cur_o) + + def main_loop( + self, + gq: GlobalTensor, + gk: GlobalTensor, + gv: GlobalTensor, + o: RegisterTensor, + m: RegisterTensor, + l: RegisterTensor, + ): + # calculate offsets + q_offset = self.blockIdx.x * self.block_q + kv_start_offset = 0 if self.split_kv == -1 else self.blockIdx.y * self.split_kv + + if q_offset + self.block_q <= kv_start_offset: + return + + head = self.blockIdx.z % self.num_heads + bs = self.blockIdx.z // self.num_heads + + sq = self.shared_tensor(dtype=self.dtype, shape=[self.block_q, self.head_size]) + sk = self.shared_tensor(dtype=self.dtype, shape=[self.block_kv, self.head_size]) + sv = self.shared_tensor(dtype=self.dtype, shape=[self.block_kv, self.head_size]) + + rq = self.register_tensor(dtype=self.dtype, shape=[self.block_q, self.head_size]) + + # copy q to shared memory + self.copy_async( + gq, sq, offsets=[bs, q_offset, head, 0], dims=[1, 3], check_bounds=True + ) + self.copy_async_wait_all() + self.sync() + + # copy q to registers if not keeping in shared memory + if self.keep_q_in_regs: + self.load_shared(sq, out=rq) # [block_q, head_size] + self.free_shared(sq) + + # issue a copy of gk + self.copy_async(gk, sk, offsets=[bs, 0, head // self.group_heads, 0], dims=[1, 3]) + self.copy_async_commit_group() + + kv_offset_inner_end = (q_offset + 1) // self.block_kv * self.block_kv + if self.split_kv != -1: + kv_offset_inner_end = min( + kv_offset_inner_end, kv_start_offset + self.split_kv + ) + for kv_offset in range(kv_start_offset, kv_offset_inner_end, self.block_kv): + self.attention_iteration( + bs, + kv_offset, + q_offset, + head, + gk, + gv, + sq, + rq, + sk, + sv, + o, + m, + l, + check_bounds=False, + ) + + kv_offset_end = q_offset + self.block_q + if self.split_kv != -1: + kv_offset_end = min(kv_offset_end, kv_start_offset + self.split_kv) + for kv_offset in range(kv_offset_inner_end, kv_offset_end, self.block_kv): + self.attention_iteration( + bs, + kv_offset, + q_offset, + head, + gk, + gv, + sq, + rq, + sk, + sv, + o, + m, + l, + check_bounds=True, + ) + + self.copy_async_wait_group(0) + self.sync() + self.free_shared(sk) + self.free_shared(sv) + if not self.keep_q_in_regs: + self.free_shared(sq) + + def store_back( + self, + o: RegisterTensor, + l: RegisterTensor, + m: RegisterTensor, + o_ptr: void_p, + batch_size: int, + q_len: int32, + ): + # o: [block_q, head_size] + # m: [block_q, 1] + # l: [block_q, 1] + go = self.global_view( + o_ptr, + dtype=self.dtype, + shape=[batch_size, q_len, self.num_heads, self.head_size], + ) + o = o / l + o_f16 = self.cast(o, dtype=self.dtype) # [block_q, head_size] + so = self.shared_tensor(dtype=self.dtype, shape=[self.block_q, self.head_size]) + + head = self.blockIdx.z % self.num_heads + q_offset = self.blockIdx.x * self.block_q + bs = self.blockIdx.z // self.num_heads + + if self.split_kv == -1: + self.store_shared(so, o_f16) + self.sync() + self.store_global( + go, + self.load_shared(so), + offsets=[bs, q_offset, head, 0], + dims=[1, 3], + ) + else: + num_q_blocks = cdiv(q_len, self.block_q) + semaphores = self.global_tensor( + dtype=int32, + shape=[num_q_blocks, batch_size, self.num_heads], + requires_clean=True, + ) + gm = self.global_tensor( + dtype=f32, + shape=[num_q_blocks, batch_size, self.num_heads, self.block_q], + requires_clean=False, + ) + gl = self.global_tensor( + dtype=f32, + shape=[num_q_blocks, batch_size, self.num_heads, self.block_q], + requires_clean=False, + ) + semaphore = semaphores[self.blockIdx.x, bs, head].item_ptr() + + sm = self.shared_tensor(dtype=f32, shape=[self.block_q]) + sl = self.shared_tensor(dtype=f32, shape=[self.block_q]) + + self.lock_semaphore(semaphore, value=self.blockIdx.y) + + # load previous o, m and l and merge with the current results + if self.blockIdx.y > 0: + self.copy_async(gm, sm, offsets=[self.blockIdx.x, bs, head, 0], dims=[3]) + self.copy_async(gl, sl, offsets=[self.blockIdx.x, bs, head, 0], dims=[3]) + self.copy_async(go, so, offsets=[bs, q_offset, head, 0], dims=[1, 3]) + self.copy_async_wait_all() + self.sync() + lhs_o = self.load_shared(so) + lhs_m = self.load_shared(sm).unsqueeze(1) + lhs_l = self.load_shared(sl).unsqueeze(1) + rhs_o = o_f16 + rhs_m = m + rhs_l = l + m = self.maximum(lhs_m, rhs_m) + lhs_ll = lhs_l * self.exp(lhs_m - m) + rhs_ll = rhs_l * self.exp(rhs_m - m) + l = lhs_ll + rhs_ll + o_f16 = lhs_o * self.cast( + lhs_ll / l, dtype=self.dtype + ) + rhs_o * self.cast(rhs_ll / l, dtype=self.dtype) + self.sync() + + # store the results to so and load it + self.store_shared(so, o_f16) + self.store_shared(sm, m.squeeze(dim=1)) + self.store_shared(sl, l.squeeze(dim=1)) + self.sync() + + # store the results to global memory and release the semaphore + self.store_global( + go, + self.load_shared(so), + offsets=[bs, q_offset, head, 0], + dims=[1, 3], + ) + self.store_global( + gm, + self.load_shared(sm), + offsets=[self.blockIdx.x, bs, head, 0], + dims=[3], + ) + self.store_global( + gl, + self.load_shared(sl), + offsets=[self.blockIdx.x, bs, head, 0], + dims=[3], + ) + self.sync() + + self.free_shared(sm) + self.free_shared(sl) + + # release the semaphore + self.release_semaphore( + semaphore, + value=self.blockIdx.y + 1 + if (self.blockIdx.y + 1) * self.split_kv < q_offset + self.block_q + else 0, + ) + self.free_shared(so) + + def __call__( + self, + batch_size: int, + q_len: int32, + kv_len: int32, + q_ptr: void_p, + k_ptr: void_p, + v_ptr: void_p, + o_ptr: void_p, + ): + self.attrs.warps = self.num_warps + self.attrs.blocks = ( + cdiv(q_len, self.block_q), + cdiv(kv_len, self.split_kv) if self.split_kv != -1 else 1, + self.num_heads * batch_size, + ) + + gq = self.global_view( + q_ptr, + dtype=self.dtype, + shape=[batch_size, q_len, self.num_heads, self.head_size], + ) + gk = self.global_view( + k_ptr, + dtype=self.dtype, + shape=[batch_size, kv_len, self.num_heads_kv, self.head_size], + ) + gv = self.global_view( + v_ptr, + dtype=self.dtype, + shape=[batch_size, kv_len, self.num_heads_kv, self.head_size], + ) + + o = self.register_tensor( + dtype=f32, shape=[self.block_q, self.head_size], init=0.0 + ) + m = self.register_tensor( + dtype=f32, shape=[self.block_q, 1], init=-1e6 + ) # rowmax(score) + l = self.register_tensor( + dtype=f32, shape=[self.block_q, 1], init=0.0 + ) # rowsum(exp(score - m)) + + self.main_loop(gq, gk, gv, o, m, l) + + self.store_back(o, l, m, o_ptr=o_ptr, batch_size=batch_size, q_len=q_len) + + +def flash_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, +): + """ + Flash attention function for variable length sequences. + + Parameters + ---------- + q: torch.Tensor + The query tensor of shape (bs, seqlen, num_heads, head_size). + + k: torch.Tensor + The key tensor of shape (bs, seqlen, num_heads_kv, head_size). + + v: torch.Tensor + The value tensor of shape (bs, seqlen, num_heads_kv, head_size). + + Returns + ------- + o: torch.Tensor + The output tensor of shape (bs, seqlen, num_heads, head_size). + """ + out = torch.empty_like(q) + FlashAttention( + dtype=tilus.float16, + num_heads=q.size(2), + num_heads_kv=k.size(2), + head_size=q.size(3), + )(q.size(0), q.size(1), k.size(1), q, k, v, out) + return out + + +def flash_attention_reference( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, +): + bs, seqlen, num_heads, head_size = q.size() + _, _, num_heads_kv, _ = k.size() + assert q.size(0) == k.size(0) == v.size(0), "Batch size must match for q, k, and v." + assert q.size(1) == k.size(1) == v.size(1), ( + "Sequence length must match for q, k, and v." + ) + assert q.size(3) == k.size(3) == v.size(3), "Head size must match for q, k, and v." + assert k.size(2) == v.size(2), "Number of heads in k and v must match." + assert num_heads % num_heads_kv == 0, ( + "Number of heads must be divisible by number of kv heads." + ) + + k = torch.repeat_interleave(k, num_heads // num_heads_kv, dim=2) + v = torch.repeat_interleave(v, num_heads // num_heads_kv, dim=2) + + q = torch.transpose(q, 1, 2).reshape(bs * num_heads, seqlen, head_size) + k = torch.transpose(k, 1, 2).reshape(bs * num_heads, seqlen, head_size) + v = torch.transpose(v, 1, 2).reshape(bs * num_heads, seqlen, head_size) + + score = torch.bmm(q, k.mT) / np.sqrt(head_size) # [bs * num_heads, seqlen, seqlen] + causal_mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool), diagonal=0).to( + q.device + ) + causal_mask = causal_mask.unsqueeze(0) # [1, seqlen, seqlen] + causal_mask = causal_mask.expand( + bs * num_heads, seqlen, seqlen + ).contiguous() # [bs * num_heads, seqlen, seqlen] + score = score.masked_fill(causal_mask == 0, float("-inf")) + + o = torch.bmm( + torch.softmax(score.float(), dim=-1).to(q.dtype), v + ) # [bs * num_heads, seqlen, head_size] + o = o.reshape(bs, num_heads, seqlen, head_size).transpose(1, 2).contiguous() + return o + + +def flash_attention_flash_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, +): + try: + from flash_attn.cute.interface import flash_attn_func + + return flash_attn_func(q, k, v, causal=True) + except ImportError: + return flash_attention_reference(q, k, v) + + +def demo_flash_attention(): + for bs, seqlen, num_heads, head_size, num_heads_kv in [ + # [1, 8, 1, 64, 1], + [1, 4096, 32, 128, 8] + ]: + q = torch.rand(bs, seqlen, num_heads, head_size, dtype=torch.float16).cuda() + k = torch.rand(bs, seqlen, num_heads_kv, head_size, dtype=torch.float16).cuda() + v = torch.rand(bs, seqlen, num_heads_kv, head_size, dtype=torch.float16).cuda() + flash_attention(q, k, v) + torch.cuda.synchronize() + + +def main(bench=True): + headers = [ + "batch_size", + "seqlen", + "num_heads", + "head_size", + "num_heads_kv", + "name", + "latency (ms)", + "tflops", + ] + data = [] + for batch_size, seqlen, num_heads, head_size, num_heads_kv in [ + [1, 512, 32, 128, 8], + [1, 1024, 32, 128, 8], + [1, 2048, 32, 128, 8], + [1, 4096, 32, 128, 8], + #[1, 8192, 32, 128, 8], + [1, 512, 64, 128, 8], + [1, 1024, 64, 128, 8], + [1, 2048, 64, 128, 8], + [1, 4096, 64, 128, 8], + #[1, 8192, 64, 128, 8], + ]: + q = torch.rand( + batch_size, seqlen, num_heads, head_size, dtype=torch.float16 + ).cuda() + k = torch.rand( + batch_size, seqlen, num_heads_kv, head_size, dtype=torch.float16 + ).cuda() + v = torch.rand( + batch_size, seqlen, num_heads_kv, head_size, dtype=torch.float16 + ).cuda() + for name, runner in [ + ("flash-attn", flash_attention_flash_attn), + ("tilus", flash_attention), + ]: + print( + f"Running {name} with batch_size={batch_size}, seqlen={seqlen}, num_heads={num_heads}, head_size={head_size}, num_heads_kv={num_heads_kv}" + ) + try: + actual = runner(q, k, v) + except torch.OutOfMemoryError: + print("Out of memory, skipping this configuration.") + continue + + try: + expected = flash_attention_reference(q, k, v) + torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2) + except torch.OutOfMemoryError: + pass + + latency = ( + benchmark_func( + lambda: runner(q, k, v), + warmup=20, + repeat=50, + ) + if bench + else float("nan") + ) + tflops = ( + 2 * batch_size * num_heads * seqlen * head_size * seqlen / latency * 1e-9 + ) + data.append( + [ + batch_size, + seqlen, + num_heads, + head_size, + num_heads_kv, + name, + latency, + tflops, + ] + ) + df = pd.DataFrame(data, columns=headers) + df_pivot = df.pivot( + index=[ + "batch_size", + "seqlen", + "num_heads", + "head_size", + "num_heads_kv", + ], + columns="name", + values=["latency (ms)", "tflops"], + ).reset_index() + # sort by (batch_size, num_heads, head_size, seqlen) + df_pivot = df_pivot.sort_values( + by=["batch_size", "num_heads", "head_size", "seqlen"], + ascending=[True, True, True, True], + ) + print(df_pivot) + + +if __name__ == "__main__": + main() + # ncu_run(main, bench=False, kernel_regex="flash_fwd|flash_attention") \ No newline at end of file diff --git a/examples/generic_python/flash_attention/triton_attention.py b/examples/generic_python/flash_attention/triton_attention.py new file mode 100644 index 000000000..e4f36a133 --- /dev/null +++ b/examples/generic_python/flash_attention/triton_attention.py @@ -0,0 +1,276 @@ +# example taken from https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html +# removed optional support for FP8 +# Removed backward pass +# removed checks for e.g. hopper/hip etc +import pytest +import torch +import os +import numpy as np + +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + +configs = [ + triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ + for BM in [64, 128]\ + for BN in [32, 64, 128]\ + for s in [2, 3, 4] \ + for w in [4, 8]\ +] + + +def prune_invalid_configs(configs, named_args, **kwargs): + N_CTX = kwargs["N_CTX"] + STAGE = kwargs["STAGE"] + + # Filter out configs where BLOCK_M > N_CTX + # Filter out configs where BLOCK_M < BLOCK_N when causal is True + return [ + conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX and ( + conf.kwargs.get("BLOCK_M", 0) >= conf.kwargs.get("BLOCK_N", 0) or STAGE == 1) + ] + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, # + desc_k, desc_v, # + offset_y, dtype: tl.constexpr, start_m, qk_scale, # + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # + N_CTX: tl.constexpr, warp_specialize: tl.constexpr): + # range of values handled by this stage + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + # causal = False + else: + lo, hi = 0, N_CTX + offsetk_y = offset_y + lo + if dtype == tl.float8e5: + offsetv_y = offset_y * HEAD_DIM + lo + else: + offsetv_y = offset_y + lo + # loop over k, v and update accumulator + for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = desc_k.load([offsetk_y, 0]).T + qk = tl.dot(q, k) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + # -- compute correction factor + alpha = tl.math.exp2(m_i - m_ij) + l_ij = tl.sum(p, 1) + # -- update output accumulator -- + if warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128: + BM: tl.constexpr = acc.shape[0] + BN: tl.constexpr = acc.shape[1] + acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() + acc0 = acc0 * alpha[:, None] + acc1 = acc1 * alpha[:, None] + acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN]) + else: + acc = acc * alpha[:, None] + # prepare p and v for the dot + if dtype == tl.float8e5: + v = desc_v.load([0, offsetv_y]).T + else: + v = desc_v.load([offsetv_y, 0]) + p = p.to(dtype) + # note that this non transposed v for FP8 is only supported on Blackwell + acc = tl.dot(p, v, acc) + # update m_i and l_i + # place this at the end of the loop to reduce register pressure + l_i = l_i * alpha + l_ij + m_i = m_ij + offsetk_y += BLOCK_N + offsetv_y += BLOCK_N + return acc, l_i, m_i + + +@triton.jit +def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape): + if isinstance(desc_or_ptr, tl.tensor_descriptor): + return desc_or_ptr + else: + return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape) + + +@triton.autotune(configs=configs, key=["N_CTX", "HEAD_DIM", "warp_specialize"], + prune_configs_by={'early_config_prune': prune_invalid_configs}) +@triton.jit +def _attn_fwd(sm_scale, M, # + Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, # + HEAD_DIM: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, # + warp_specialize: tl.constexpr, # + ): + dtype = tl.float16 + tl.static_assert(BLOCK_N <= HEAD_DIM) + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H + off_h = off_hz % H + + y_dim = Z * H * N_CTX + desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_M, HEAD_DIM]) + desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_N, HEAD_DIM]) + desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_N, HEAD_DIM]) + desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_M, HEAD_DIM]) + + offset_y = off_z * (N_CTX * H) + off_h * N_CTX + qo_offset_y = offset_y + start_m * BLOCK_M + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + # load scales + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) + # load q: it will stay in SRAM throughout + q = desc_q.load([qo_offset_y, 0]) + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, # + desc_k, desc_v, # + offset_y, dtype, start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX, # + warp_specialize) + # stage 2: on-band + if STAGE & 2: + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, # + desc_k, desc_v, # + offset_y, dtype, start_m, qk_scale, # + BLOCK_M, HEAD_DIM, BLOCK_N, # + 2, offs_m, offs_n, N_CTX, # + warp_specialize) + # epilogue + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i) + desc_o.store([qo_offset_y, 0], acc.to(dtype)) + + + +def forward(q, k, v, causal, sm_scale, warp_specialize=True): + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. + HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + o = torch.empty_like(q) + stage = 3 if causal else 1 + extra_kern_args = {} + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + desc_q = q + desc_v = v + desc_k = k + desc_o = o + + def grid(META): + return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1) + + _attn_fwd[grid]( + sm_scale, M, # + q.shape[0], q.shape[1], # + desc_q, desc_k, desc_v, desc_o, # + N_CTX=q.shape[2], # + HEAD_DIM=HEAD_DIM_K, # + STAGE=stage, # + warp_specialize=warp_specialize, # + **extra_kern_args) + + return o + + + + + +if __name__ == "__main__": + import torch + +# Triton kernel (your forward function) +# from your previous code: forward(ctx, Q, K, V, causal, sm_scale, warp_specialize=True) + +def flash_attention_reference(q, k, v, causal=False, sm_scale=1.0): + """ + Reference attention using batched matmul (like the pytest does). + Args: + q, k, v: [batch, heads, seq_len, head_dim] + causal: whether to apply causal mask + sm_scale: scaling factor for attention scores + Returns: + output: [batch, heads, seq_len, head_dim] + """ + # Batched matrix multiply + # scores: [batch, heads, seq_len, seq_len] + scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale + + if causal: + seq_len = q.shape[2] + mask = torch.tril(torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool)) + scores = scores.masked_fill(~mask[None, None, :, :], float("-inf")) + + # Softmax along last dimension + attention = torch.softmax(scores.float(), dim=-1) + + # Weighted sum with V + output = torch.matmul(attention, v) + return output + + + +# Example test +if __name__ == "__main__": + forward_simple() + exit(0) + + # Parameters + B, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + causal = True + sm_scale = 1.0 + + # Random inputs + Q = torch.randn(B, H, N_CTX, HEAD_DIM, dtype=torch.float16, device="cuda") + K = torch.randn_like(Q) + V = torch.randn_like(Q) + + # Triton output + output_triton = forward(Q, K, V, causal=causal, sm_scale=sm_scale, warp_specialize=True) + + # Reference output + output_ref = flash_attention_reference(Q.float(), K.float(), V.float(), causal=causal, sm_scale=sm_scale) + + # Compare + max_diff = (output_triton.float() - output_ref).abs().max() + mean_diff = (output_triton.float() - output_ref).abs().mean() + print(f"Max difference: {max_diff.item():.6f}") + print(f"Mean difference: {mean_diff.item():.6f}") + + diff --git a/examples/generic_python/loopy_example.py b/examples/generic_python/loopy_example.py new file mode 100644 index 000000000..49ad131f6 --- /dev/null +++ b/examples/generic_python/loopy_example.py @@ -0,0 +1,38 @@ +import numpy as np + +import pyopencl as cl +import pyopencl.array + +import loopy as lp +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa: F401 + + +# setup +# ----- +ctx = cl.create_some_context() +queue = cl.CommandQueue(ctx) + +n = 15 * 10**6 +a = cl.array.arange(queue, n, dtype=np.float32) + +# create +# ------ +knl = lp.make_kernel( + "{ [i]: 0<=i torch.Tensor: + m, k = x.size() + k, n = y.size() + out = torch.empty([m, n], dtype=x.dtype, device=x.device) + + for tile_m, tile_n in hl.tile([m, n]): + acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) + for tile_k in hl.tile(k): + acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) + out[tile_m, tile_n] = acc + + return out + + +out = matmul(torch.randn([2048, 2048], device="cuda"), + torch.randn([2048, 2048], device="cuda")) \ No newline at end of file diff --git a/examples/generic_python/matmul/numba_matmul.py b/examples/generic_python/matmul/numba_matmul.py new file mode 100644 index 000000000..00102a382 --- /dev/null +++ b/examples/generic_python/matmul/numba_matmul.py @@ -0,0 +1,148 @@ +import torch +import numpy as np +from numba import cuda +from kernel_tuner import tune_kernel +from pathlib import Path + +FULL_PATH = Path(__file__).resolve() + +# Example taken from https://nvidia.github.io/numba-cuda/user/examples.html#matrix-multiplication +@cuda.jit(cache=True) +def matmul(A, B, C): + """Perform square matrix multiplication of C = A * B + """ + i, j = cuda.grid(2) + if i < C.shape[0] and j < C.shape[1]: + tmp = 0. + for k in range(A.shape[1]): + tmp += A[i, k] * B[k, j] + C[i, j] = tmp + + +# Example taken from https://nvidia.github.io/numba-cuda/user/examples.html#matrix-multiplication +# Changed data type from float32 to float16 +@cuda.jit(cache=True) +def fast_matmul(A, B, C): + # Define an array in the shared memory + # The size and type of the arrays must be known at compile time + TPB = 16 # TEMP voor overhead testing + sA = cuda.shared.array(shape=(TPB, TPB), dtype=np.float16) + sB = cuda.shared.array(shape=(TPB, TPB), dtype=np.float16) + + x, y = cuda.grid(2) + + tx = cuda.threadIdx.x + ty = cuda.threadIdx.y + bpg = cuda.gridDim.x # blocks per grid + + if x >= C.shape[0] and y >= C.shape[1]: + # Quit if (x, y) is outside of valid C boundary + return + + # Each thread computes one element in the result matrix. + # The dot product is chunked into dot products of TPB-long vectors. + tmp = 0. + for i in range(bpg): + # Preload data into shared memory + sA[tx, ty] = A[x, ty + i * TPB] + sB[tx, ty] = B[tx + i * TPB, y] + + # Wait until all threads finish preloading + cuda.syncthreads() + + # Computes partial product on the shared memory + for j in range(TPB): + tmp += sA[tx, j] * sB[j, ty] + + # Wait until all threads finish computing + cuda.syncthreads() + + C[x, y] = tmp + + +def run_matmul(M, N, K): + + # create numpy arrays + A = np.random.rand(M, K).astype(np.float16) + B = np.random.rand(K, N).astype(np.float16) + C = np.zeros((M, N), dtype=np.float16) + + # copy to GPU + A_d = cuda.to_device(A) + B_d = cuda.to_device(B) + C_d = cuda.to_device(C) + + # threads per block + threads = (16, 16) + + # compute grid size (ceil division) + blocks = ( + (M + threads[0] - 1) // threads[0], + (N + threads[1] - 1) // threads[1], + ) + + # launch kernel + fast_matmul[blocks, threads](A_d, B_d, C_d) + + # copy result back + C_result = C_d.copy_to_host() + + # check + np.testing.assert_allclose(C_result, A @ B, rtol=1e-2) + + print("Correct!") + + +def call_numba(kernel_function, args, kwargs, grid, threads): + numba_args = [] + for arg in args: + if isinstance(arg, torch.Tensor): + numba_args.append(cuda.as_cuda_array(arg)) + else: + numba_args.append(arg) + kernel_function[grid, threads](*args, **kwargs) + + +def tune(M, N, K): + # create numpy arrays + A = np.random.rand(M, K).astype(np.float16) + B = np.random.rand(K, N).astype(np.float16) + C = np.zeros((M, N), dtype=np.float16) + + size = (M, N) + args = [A, B, C] + tune_params = dict() + tune_params["block_size_x"] = [4, 8, 16, 32, 64, 128, 256] + tune_params["block_size_y"] = [4, 8, 16, 32, 64, 128, 256] + tune_params["TPB"] = [4, 8, 16, 32, 64, 128, 256] + restrictions = ["block_size_x == block_size_y", "block_size_x == TPB"] + + answer = [None, None, A @ B] + atol = M * 2**(-11) + + results, env = tune_kernel( + kernel_name="fast_matmul", + kernel_source=FULL_PATH, + problem_size=size, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=answer, + atol=atol, + call_function=call_numba, + restrictions=restrictions + ) + + +if __name__ == "__main__": + run_matmul(128, 96, 64) + #tune(1024, 1024, 1024) + + ''' + 128: 16, 8 + 256: 8, 16 + 512: 8, 16 + 1024: 16, 8 + 4096: 16, 8 + 8192: 16, 8 + ''' \ No newline at end of file diff --git a/examples/generic_python/matmul/test.py b/examples/generic_python/matmul/test.py new file mode 100644 index 000000000..991ef2cc4 --- /dev/null +++ b/examples/generic_python/matmul/test.py @@ -0,0 +1,278 @@ +import argparse +import itertools +import tilelang as tl +import tilelang.language as T +from tilelang.autotuner import AutoTuner +from tilelang.carver.template import MatmulTemplate +from tilelang.carver.arch import CUDA +from tilelang.carver.arch import CDNA +from tilelang.carver.roller.rasterization import NoRasterization +import torch + + +def ref_program(A, B): + """ + Compute the matrix product of A and the transpose of B. + + A and B are expected to be 2-D tensors where A has shape (M, K) and B has shape (N, K). The result is a tensor with shape (M, N) equal to A @ B.T, using the inputs' dtypes. + """ + return A @ B.T + + +def get_configs(M, N, K, with_roller=False, topk=20): + """ + Generate a list of kernel tuning configuration dictionaries for a tiled matrix-multiply. + + When with_roller is True this queries the MatmulTemplate roller to produce up to `topk` recommended + configurations (device-specific TensorCore-friendly tilings). Each returned dict contains: + - block_M, block_N, block_K: tile sizes + - num_stages: pipeline staging (0 means no explicit staging) + - thread_num: total threads used for the block + - enable_rasteration: whether a rasterization/swizzle layout was recommended (note spelling) + + When with_roller is False this returns the Cartesian product of a fixed set of candidate + parameters; the returned dicts use the backward-compatible key name "enable_rasteration" for that flag. + + Parameters: + M, N, K (int): GEMM dimensions used to generate valid tile sizes. + with_roller (bool): If True, use MatmulTemplate's roller to generate device-aware hints; + otherwise use a predefined candidate grid. + topk (int): Maximum number of roller hints to request when with_roller is True. + + Returns: + List[dict]: A list of configuration dictionaries as described above. + + Raises: + ValueError: if with_roller is True but the roller returns no hints. + """ + if with_roller: + arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip") + carve_template = MatmulTemplate( + M=M, + N=N, + K=K, + in_dtype=T.float16, + out_dtype=T.float16, + accum_dtype=T.float32, + ).with_arch(arch) + + func = carve_template.equivalent_function() + assert func is not None, "Function is None" + roller_hints = carve_template.recommend_hints(topk=topk) + if roller_hints is None: + raise ValueError("No Roller Hints Found for TensorCore Scheduling") + configs = [] + for hint in roller_hints: + config = {} + block_m, block_n = hint.block + warp_m, warp_n = hint.warp + # block_rows, block_cols represents warp partitioning + block_rows, block_cols = block_m // warp_m, block_n // warp_n + config["block_M"] = block_m + config["block_N"] = block_n + config["block_K"] = hint.rstep[0] + config["num_stages"] = hint.pipeline_stage if hint.pipeline_stage > 1 else 0 + config["thread_num"] = block_rows * block_cols * 32 + config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization + configs.append(config) + else: + block_M = [64, 128, 256] + block_N = [64, 128, 256] + block_K = [32, 64] + num_stages = [0, 1, 2, 3] + thread_num = [128, 256] + enable_rasterization = [True, False] + _configs = list( + itertools.product( + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasterization, + ) + ) + + configs = [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5], # keep param name for backward-compat + } + for c in _configs + ] + return configs + + +def get_best_config( + M, + N, + K, + with_roller: bool = False, + profile_backend: str = "event", +): + def kernel( + block_M=None, + block_N=None, + block_K=None, + num_stages=None, + thread_num=None, + enable_rasteration=None, + ): + dtype = T.bfloat16 + accum_dtype = T.float32 + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + ) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return main + + autotuner = ( + AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller)) + .set_compile_args( + out_idx=[-1], + target="auto", + ) + .set_profile_args( + supply_type=tl.TensorSupplyType.Integer, + ref_prog=ref_program, + skip_check=False, + backend=profile_backend, + ) + ) + return autotuner.run(warmup=3, rep=20) + + +def get_heuristic_config() -> dict: + # Get CUDA device properties + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available") + device = torch.cuda.current_device() + sm_major, sm_minor = torch.cuda.get_device_capability(device) + sm_version = sm_major * 10 + sm_minor + print(f"CUDA device capability: {sm_version}") + if sm_version in {80}: + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} + elif sm_version in {90}: + return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True} + else: + return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} + + +@tl.jit(out_idx=[-1]) +def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm_autotune( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=True, + ) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_autotune + + +def main( + M: int = 4096, + N: int = 4096, + K: int = 4096, + use_autotune: bool = False, + with_roller: bool = False, + profile_backend: str = "event", +): + if use_autotune: + result = get_best_config( + M, + N, + K, + with_roller=with_roller, + profile_backend=profile_backend, + ) + print(result.config) + kernel = result.kernel + else: + config = get_heuristic_config() + kernel = matmul(M, N, K, **config) + + # benchmark + profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) + tilelang_latency = profiler.do_bench( + backend=profile_backend, + ) + ref_latency = profiler.do_bench( + ref_program, + backend=profile_backend, + ) + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + print(f"TileLang latency: {tilelang_latency}") + print(f"Ref latency: {ref_latency}") + print(f"TileLang TFlops: {2 * M * N * K / tilelang_latency * 1e-9}") + print(f"Ref TFlops: {2 * M * N * K / ref_latency * 1e-9}") + + +def run_regression_perf(M: int = 4096, N: int = 4096, K: int = 4096): + config = get_heuristic_config() + kernel = matmul(M, N, K, **config) + profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) + return profiler.do_bench(backend="cupti") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") + parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs") + parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space") + parser.add_argument("--profile_backend", type=str, default="event", help="Profiler backend") + args = parser.parse_args() + main( + args.m, + args.n, + args.k, + args.use_autotune, + args.with_roller, + args.profile_backend, + ) \ No newline at end of file diff --git a/examples/generic_python/matmul/tilelang_matmul.py b/examples/generic_python/matmul/tilelang_matmul.py index 895732629..7ced8cc83 100644 --- a/examples/generic_python/matmul/tilelang_matmul.py +++ b/examples/generic_python/matmul/tilelang_matmul.py @@ -32,40 +32,42 @@ def gemm( # https://github.com/tile-ai/tilelang/blob/main/examples/gemm/example_gemm_autotune.py -@tilelang.jit +# changed gemm_autotune to gemm +# originally, B was transposed. I changed this so the kernel input is the same as in other languages. +#@tilelang.jit def matmul_opt(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): @T.prim_func - def gemm_autotune( + def gemm( A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), + B: T.Tensor((K, N), dtype), C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_N, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), dtype) T.use_swizzle(panel_size=10, enable=enable_rasteration) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K], B_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) T.gemm( A_shared, B_shared, C_local, - transpose_B=True, + transpose_B=False, ) T.copy(C_local, C_shared) T.copy(C_shared, C[by * block_M, bx * block_N]) - return gemm_autotune + return gemm def main(): - kernel = matmul_basic(1024, 1024, 1024, 128, 128, 32) + kernel = matmul_opt(1024, 1024, 1024, 128, 128, 32, 3, 128, False) import torch diff --git a/examples/generic_python/matmul/tilus_matmul.py b/examples/generic_python/matmul/tilus_matmul.py index dfe294761..d893b9490 100644 --- a/examples/generic_python/matmul/tilus_matmul.py +++ b/examples/generic_python/matmul/tilus_matmul.py @@ -68,8 +68,9 @@ def __call__( + # This kernel is copied from the Tilus project: -# https://github.com/NVIDIA/tilus/blob/main/examples/matmul/matmul_v5.py +# https://github.com/NVIDIA/tilus/blob/main/examples/matmul/matmul_+v5.py # # Original example: matmul_v5.py # Copyright (c) the Tilus authors diff --git a/examples/generic_python/matmul/warp_matmul.py b/examples/generic_python/matmul/warp_matmul.py new file mode 100644 index 000000000..9e6b7dcf3 --- /dev/null +++ b/examples/generic_python/matmul/warp_matmul.py @@ -0,0 +1,131 @@ +import torch +import numpy as np +import warp as wp + +from kernel_tuner import tune_kernel +from pathlib import Path + +wp.init() + +FULL_PATH = Path(__file__).resolve() + +# tile size +TILE_M = wp.constant(8) +TILE_N = wp.constant(4) +TILE_K = wp.constant(8) + +# num threads per-tile +TILE_THREADS = 64 + +# GEMM example from https://nvidia.github.io/warp/user_guide/tiles.html +@wp.kernel +def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)): + + # output tile index + i, j = wp.tid() + + sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32) + + M = A.shape[0] + N = B.shape[1] + K = A.shape[1] + + count = (K + TILE_K - 1) // TILE_K + + for k in range(0, count): + a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i*TILE_M, k*TILE_K), bounds_check=True) + b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k*TILE_K, j*TILE_N), bounds_check=True) + + # sum += a*b + wp.tile_matmul(a, b, sum) + + wp.tile_store(C, sum, offset=(i*TILE_M, j*TILE_N), bounds_check=True) + + +def run_kernel_direct(M, N, K): + rng = np.random.default_rng(42) + A = rng.random((M, K), dtype=np.float32) + B = rng.random((K, N), dtype=np.float32) + C = np.zeros((M, N), dtype=np.float32) + + A_wp = wp.array(A) + B_wp = wp.array(B) + C_wp = wp.array(C) + + with wp.Tape() as tape: + wp.launch_tiled( + tile_gemm, + dim=((M + TILE_M - 1) // TILE_M, (N + TILE_N - 1) // TILE_N), + inputs=[A_wp, B_wp, C_wp], + block_dim=TILE_THREADS) + + np.testing.assert_allclose(C_wp.numpy(), A @ B, rtol=1e-3) + + print("Example matrix multiplication passed") + + + +def call_warp(kernel_function, args, kwargs, grid, threads, params): + # Convert Torch tensors to Warp args + warp_args = [] + for arg in args: + if isinstance(arg, torch.Tensor): + warp_args.append(wp.from_torch(arg)) + else: + warp_args.append(arg) + + # launch kernel + with wp.Tape() as tape: + wp.launch_tiled( + kernel_function, + dim=grid, + inputs=warp_args, + block_dim=params["TILE_THREADS"], # We could directly take threads, but in the given example this is a constant + ) + + +def tune(M, K, N): + rng = np.random.default_rng(42) + A = rng.random((M, K), dtype=np.float32) + B = rng.random((K, N), dtype=np.float32) + C = np.zeros((M, N), dtype=np.float32) + + size = (M, N) + block_size_names = ["TILE_M", "TILE_N"] + + tune_params = dict() + tune_params["TILE_M"] = [4, 8, 16] + tune_params["TILE_N"] = [2, 4, 8] + tune_params["TILE_K"] = [4, 8, 16] + tune_params["TILE_THREADS"] = [32, 64, 128] + + args = [A, B, C] + answer = [None, None, A @ B] + + results, env = tune_kernel( + kernel_name="tile_gemm", + kernel_source=FULL_PATH, + problem_size=size, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=answer, + call_function=call_warp, + block_size_names=block_size_names, + ) + + +if __name__ == "__main__": + #tune(128, 128, 128) + + sizes = [ + (65, 65, 17), + (67, 71, 19), + (1, 1, 1), + (63, 63, 15), + (129, 130, 33), + ] + + for size in sizes: + print(size) + run_kernel_direct(*size) diff --git a/examples/generic_python/normalization/tilelang_norm.py b/examples/generic_python/normalization/tilelang_norm.py new file mode 100644 index 000000000..7c26d8f5f --- /dev/null +++ b/examples/generic_python/normalization/tilelang_norm.py @@ -0,0 +1,78 @@ +# taken from https://github.com/tile-ai/tilelang/blob/main/examples/norm/rms_norm.py +# TODO not comparable +import torch +import tilelang +import tilelang.language as T + + +def rms_norm_splitk(M, N, blk_m, blk_k): + dtype = T.float + + @T.prim_func + def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx: + A_shared = T.alloc_shared((blk_m, blk_k), dtype) + A_local = T.alloc_fragment((blk_m, blk_k), dtype) + A_powsum = T.alloc_fragment((blk_m,), dtype) + + num_k_step = T.ceildiv(N, blk_k) + T.clear(A_local) + for k in range(num_k_step): + T.copy(A[bx * blk_m, k * blk_k], A_shared) + for i, j in T.Parallel(blk_m, blk_k): + A_local[i, j] += A_shared[i, j] * A_shared[i, j] + T.reduce_sum(A_local, A_powsum, dim=1) + for i in T.Parallel(blk_m): + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) + + for k in range(num_k_step): + # reverse, better cache hit rate + T.copy(A[bx * blk_m, (num_k_step - 1 - k) * blk_k], A_shared) + for i, j in T.Parallel(blk_m, blk_k): + A_shared[i, j] *= A_powsum[i] + T.copy(A_shared, B[bx * blk_m, (num_k_step - 1 - k) * blk_k]) + + return main + + +@tilelang.jit(out_idx=[-1], pass_configs={"tl.disable_tma_lower": True}) +def rms_norm(M, N, blk_m): + dtype = T.float + + @T.prim_func + def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx: + A_shared = T.alloc_shared((blk_m, N), dtype) + A_pow_local = T.alloc_fragment((blk_m, N), dtype) + A_local = T.alloc_fragment((blk_m, N), dtype) + A_powsum = T.alloc_fragment((blk_m,), dtype) + + T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared) + T.copy(A_shared, A_local) + for i, j in T.Parallel(blk_m, N): + A_pow_local[i, j] = A_local[i, j] * A_local[i, j] + T.reduce_sum(A_pow_local, A_powsum, dim=1) + for i in T.Parallel(blk_m): + A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) + for i, j in T.Parallel(blk_m, N): + A_local[i, j] *= A_powsum[i] + T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :]) + + return main + + +def ref_program(x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-12) + + +if __name__ == "__main__": + M, N, blk_m, blk_k = 8192, 8192, 1, 512 + kernel = rms_norm(M, N, blk_m) + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) + print("All checks pass.") + + latency = profiler.do_bench(ref_program, warmup=500) + print("Ref: {:.2f} ms".format(latency)) + latency = profiler.do_bench(warmup=500) + print("Tile-lang: {:.2f} ms".format(latency)) \ No newline at end of file diff --git a/examples/generic_python/normalization/tilus_norm.py b/examples/generic_python/normalization/tilus_norm.py new file mode 100644 index 000000000..cfe569e80 --- /dev/null +++ b/examples/generic_python/normalization/tilus_norm.py @@ -0,0 +1,132 @@ +# taken from https://github.com/NVIDIA/tilus/blob/main/examples/norm/layer_norm.py +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +import pandas +import tilus +import torch +from tilus import float16, float32, int32 +from tilus.utils import benchmark_func, cdiv + + +@tilus.autotune("block_m", [1, 8]) +@tilus.autotune("block_n", [128, 256, 512, 1024]) +@tilus.autotune("warps", [2, 4, 8]) +class LayerNorm(tilus.Script): + """Forward-only layer normalization tilus kernel. + + This implements the per-row LayerNorm used in many transformer blocks. It + computes: y = (x - mean) / sqrt(var + eps) * gamma + beta + + Only the forward is provided. + """ + + def __init__(self, block_m: int, block_n: int, warps: int): + super().__init__() + self.block_m: int = block_m + self.block_n: int = block_n + self.warps: int = warps + + def __call__( + self, + m_size: int, + n_size: int32, + x_ptr: ~float16, + gamma_ptr: ~float16, + beta_ptr: ~float16, + y_ptr: ~float16, + eps: float, + ): + self.attrs.blocks = (cdiv(m_size, self.block_m),) + self.attrs.warps = self.warps + + offset_m = self.blockIdx.x * self.block_m + + g_x = self.global_view(x_ptr, dtype=float16, shape=[m_size, n_size]) + g_y = self.global_view(y_ptr, dtype=float16, shape=[m_size, n_size]) + g_gamma = self.global_view(gamma_ptr, dtype=float16, shape=[n_size]) + g_beta = self.global_view(beta_ptr, dtype=float16, shape=[n_size]) + + # Register accumulators for mean and variance (computed in float32) + r_sum = self.register_tensor( + dtype=float32, shape=[self.block_m, self.block_n], init=0.0 + ) + r_square = self.register_tensor( + dtype=float32, shape=[self.block_m, self.block_n], init=0.0 + ) + + # first pass: compute mean and variance + for offset_n in range(0, n_size, self.block_n): + r_x = self.load_global( + g_x, offsets=[offset_m, offset_n], shape=[self.block_m, self.block_n] + ).to(float32) # [block_m, block_n] + r_sum = r_sum + r_x + r_square = r_square + self.square(r_x) + + # finalize mean and variance + r_mean = self.sum(r_sum, dim=1, keepdim=True) / n_size # [block_m, 1] + r_var = ( + self.sum(r_square, dim=1, keepdim=True) / n_size - r_mean * r_mean + ) # [block_m, 1], var = E[x^2] - (E[x])^2 + r_rstd = self.rsqrt(r_var + eps) + + # second pass: y = (x - mean) * rstd * gamma + beta + for offset_n in range(0, n_size, self.block_n): + r_x = self.load_global( + g_x, offsets=[offset_m, offset_n], shape=[self.block_m, self.block_n] + ).to(float32) # [block_m, block_n] + r_gamma = self.load_global( + g_gamma, offsets=[offset_n], shape=[self.block_n] + ).to(float32) # [block_n] + r_beta = self.load_global( + g_beta, offsets=[offset_n], shape=[self.block_n] + ).to(float32) # [block_n] + r_x_hat = (r_x - r_mean) * r_rstd + r_y = r_x_hat * r_gamma + r_beta + self.store_global(g_y, r_y.to(float16), offsets=[offset_m, offset_n]) + + +def main(): + headers = ["m_size", "n_size", "dtype", "torch (ms)", "tilus (ms)"] + rows = [] + for i in [1, 2, 4, 8]: + m_size = n_size = 1024 * i + + tilus_layer_norm = LayerNorm() + + x = (torch.rand(m_size, n_size, dtype=torch.float16).cuda() - 0.5) * 2.0 + gamma = torch.rand(n_size, dtype=torch.float16).cuda() + beta = torch.rand(n_size, dtype=torch.float16).cuda() + y_actual = torch.empty_like(x) + + tilus_layer_norm(m_size, n_size, x, gamma, beta, y_actual, 1e-5) + y_expected = torch.nn.functional.layer_norm( + x, normalized_shape=[n_size], weight=gamma, bias=beta, eps=1e-5 + ) + + torch.testing.assert_close(y_actual, y_expected, atol=1e-2, rtol=1e-2) + + rows.append( + [ + m_size, + n_size, + "float16", + benchmark_func( + lambda: torch.nn.functional.layer_norm( + x, normalized_shape=[n_size], weight=gamma, bias=beta, eps=1e-5 + ) + ), + benchmark_func( + lambda: tilus_layer_norm( + m_size, n_size, x, gamma, beta, y_actual, 1e-5 + ) + ), + ] + ) + print(f"LayerNorm forward matches reference for size ({m_size}, {n_size})") + + df = pandas.DataFrame(rows, columns=headers) + print(df) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/generic_python/normalization/triton_norm.py b/examples/generic_python/normalization/triton_norm.py new file mode 100644 index 000000000..365f541c4 --- /dev/null +++ b/examples/generic_python/normalization/triton_norm.py @@ -0,0 +1,66 @@ + +#Example taken from https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html +import torch + +import triton +import triton.language as tl + +try: + # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it + # should not be added to extras_require in setup.py. + import apex + HAS_APEX = True +except ModuleNotFoundError: + HAS_APEX = False + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@triton.jit +def _layer_norm_fwd_fused( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_SIZE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + Y += row * stride + X += row * stride + # Compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # Compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) + x = tl.where(cols < N, x - mean, 0.) + _var += x * x + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # Write mean / rstd + tl.store(Mean + row, mean) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(W + cols, mask=mask) + b = tl.load(B + cols, mask=mask) + x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) + x_hat = (x - mean) * rstd + y = x_hat * w + b + # Write output + tl.store(Y + cols, y, mask=mask) \ No newline at end of file diff --git a/examples/generic_python/pallas_vec_add.py b/examples/generic_python/pallas_vec_add.py new file mode 100644 index 000000000..874392dd9 --- /dev/null +++ b/examples/generic_python/pallas_vec_add.py @@ -0,0 +1,23 @@ +# Could not get working + +from functools import partial + +import jax +from jax.experimental import pallas as pl +import jax.numpy as jnp +import numpy as np + +def add_vectors_kernel(x_ref, y_ref, o_ref): + x, y = x_ref[...], y_ref[...] + o_ref[...] = x + y + + +@jax.jit +def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: + return pl.pallas_call( + add_vectors_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) + )(x, y) + + +add_vectors(jnp.arange(8), jnp.arange(8)) \ No newline at end of file diff --git a/examples/generic_python/tilus_vec_add.py b/examples/generic_python/tilus_vec_add.py index 9f842637c..d00c77972 100644 --- a/examples/generic_python/tilus_vec_add.py +++ b/examples/generic_python/tilus_vec_add.py @@ -8,7 +8,7 @@ FULL_PATH = Path(__file__).resolve() class VecAddV(tilus.Script): - def __init__(self, block_size_x=None): + def __init__(self, block_size_x=None, num_warps=None): super().__init__() self.block_size_x = block_size_x # number of threads per block @@ -22,7 +22,7 @@ def __call__( # compute the number of blocks needed self.attrs.blocks = [cdiv(n_size, self.block_size_x)] - self.attrs.warps = 1 # number of warps per block + self.attrs.warps = 4 # number of warps per block # calculate the offset for this block offset: int32 = self.block_size_x * self.blockIdx.x @@ -52,6 +52,7 @@ def tune(size): args = [size, a, b, c] tune_params = dict() tune_params["block_size_x"] = [32, 64, 128, 256, 512, 1024] + tune_params["num_warps"] = [4, 8] results, env = tune_kernel( @@ -100,7 +101,7 @@ def tune_with_builtin(size): c = torch.empty(size, dtype=torch.float32).cuda() c_expect = a + b - vecadd(size, a, b, c) + vecadd(size, a, b, c) # This is where the actual tuning takes place torch.cuda.synchronize() torch.testing.assert_close(c_expect, c) diff --git a/examples/generic_python/triton_vec_add.py b/examples/generic_python/triton_vec_add.py index f69ae96d7..00bd9dfbc 100644 --- a/examples/generic_python/triton_vec_add.py +++ b/examples/generic_python/triton_vec_add.py @@ -65,6 +65,8 @@ def tune(): call_function=call_triton, ) + print(results) + if __name__ == "__main__": diff --git a/kernel_tuner/backends/generic_python.py b/kernel_tuner/backends/generic_python.py index 39be59a23..4845fc4e2 100644 --- a/kernel_tuner/backends/generic_python.py +++ b/kernel_tuner/backends/generic_python.py @@ -195,9 +195,7 @@ def run_kernel(self, func, gpu_args, threads, grid, params=None): configuration :type params: dict """ - with torch.cuda.stream(self.stream): - logging.debug("Running Generic Python kernel") - self.call_function(func, gpu_args, self.gpu_kwargs, grid, threads, params) + self.call_function(func, gpu_args, self.gpu_kwargs, grid, threads, params) def synchronize(self): @@ -239,7 +237,6 @@ def refresh_memory(self, gpu_memory, host_arguments, should_sync): """Refresh the GPU memory with the untouched host arguments. We overwrite the standard function because Python DSLs do usually do not manage memory explicitely""" - for i, host_arg in enumerate(host_arguments): if should_sync[i]: gpu_arg = gpu_memory[i] diff --git a/kernel_tuner/core.py b/kernel_tuner/core.py index 1f130afa3..942ff329a 100644 --- a/kernel_tuner/core.py +++ b/kernel_tuner/core.py @@ -269,6 +269,8 @@ def benchmark_default(self, func, gpu_args, threads, grid, result, params=None): for obs in self.benchmark_observers: obs.after_finish() + #time.sleep(0.1) # prevent termal throttling + for obs in self.benchmark_observers: result.update(obs.get_results()) @@ -509,6 +511,10 @@ def compile_and_benchmark(self, kernel_source, gpu_args, params, kernel_options, # clean up any temporary files, if no error occurred instance.delete_temp_files() + # For Python DSLs, the compilation time also includes one kernel run, so we subtract the runtime + if self.lang == Language.GENERIC_PYTHON: + last_compilation_time -= result["time"] or 0 + result["compile_time"] = last_compilation_time or 0 result["verification_time"] = last_verification_time or 0 result["benchmark_time"] = last_benchmark_time or 0 diff --git a/kernel_tuner/kernel_sources/kernel_source_fn.py b/kernel_tuner/kernel_sources/kernel_source_fn.py index 3eb27d86b..f806d5b87 100644 --- a/kernel_tuner/kernel_sources/kernel_source_fn.py +++ b/kernel_tuner/kernel_sources/kernel_source_fn.py @@ -70,7 +70,13 @@ def prepare_kernel_instance(self, kernel_options, params, grid, threads): generate a kernel instance with these tuning parameters inserted. kernel_options, grid and threads are not needed for Python kernels. ''' - new_kernel_fn, temp_file_path = self.apply_params_to_source_fn(params) + # Test: see if removing signature params helps in Triton overhead + filtered_params = {} + for k, v in params.items(): + if k not in self.signature: + filtered_params[k] = v + + new_kernel_fn, temp_file_path = self.apply_params_to_source_fn(filtered_params) self.kernel_fn = new_kernel_fn return PreparedKernelSourceData( @@ -121,7 +127,7 @@ def apply_params_to_source_fn(self, params): # Fix locations and generate source ast.fix_missing_locations(new_module) new_source = astor.to_source(new_module) - + #print(new_source) # Create a unique module name and write new source to it. From 775e8f565dbd35666db200312b450ea0b1852780 Mon Sep 17 00:00:00 2001 From: "I.C. van Ooijen" Date: Sun, 12 Apr 2026 12:01:32 +0200 Subject: [PATCH 09/14] matmul examples --- examples/generic_python/call_functions.py | 78 +- examples/generic_python/cute_vec_add.py | 92 ++ examples/generic_python/matmul/cute_matmul.py | 1019 +++++++++++++++-- .../generic_python/matmul/numba_matmul.py | 133 ++- .../generic_python/matmul/taichi_matmul.py | 18 +- .../generic_python/matmul/tilelang_matmul.py | 107 +- examples/generic_python/matmul/warp_matmul.py | 29 +- 7 files changed, 1370 insertions(+), 106 deletions(-) create mode 100644 examples/generic_python/cute_vec_add.py diff --git a/examples/generic_python/call_functions.py b/examples/generic_python/call_functions.py index 6b321c4a2..734568efd 100644 --- a/examples/generic_python/call_functions.py +++ b/examples/generic_python/call_functions.py @@ -3,20 +3,21 @@ def call_tilus(kernel_function, args, kwargs): kernel_function(*args, **kwargs) + def call_triton(kernel_function, args, kwargs, grid, threads, params): if "num_warps" in params.keys(): kwargs["num_warps"] = params["num_warps"] if "num_stages" in params.keys(): kwargs["num_stages"] = params["num_stages"] - torch.cuda.nvtx.range_push("kt call") kernel_function[grid](*args, **kwargs) - torch.cuda.nvtx.range_pop() + def call_tilelang(kernel_function, args, kwargs): compiled_kernel = kernel_function(**kwargs) compiled_kernel(*args) + def call_numba(kernel_function, args, kwargs, grid, threads): from numba import cuda @@ -26,4 +27,75 @@ def call_numba(kernel_function, args, kwargs, grid, threads): numba_args.append(cuda.as_cuda_array(arg)) else: numba_args.append(arg) - kernel_function[grid, threads](*args, **kwargs) \ No newline at end of file + + kernel_function[grid, threads](*args, **kwargs) + + +def call_cupyx(kernel_function, args, kwargs, grid, threads): + import cupy as cp + + cupy_args = [] + for arg in args: + if isinstance(arg, torch.Tensor): + cupy_args.append(cp.from_dlpack(arg)) + else: + cupy_args.append(arg) + kernel_function(grid, threads, tuple(cupy_args)) + + +def call_cute(kernel_function, args, kwargs, grid, threads, params): + import cutlass.cute as cute + from cutlass.cute.runtime import from_dlpack + + # Initialize cache if it does not exist + if not hasattr(call_cute, "custom_cache"): + call_cute.custom_cache = {} + + # Convert Torch tensors to CuTe tensors with correct layout + cute_args = [] + for arg in args: + if isinstance(arg, torch.Tensor): + arg_ = from_dlpack(arg) + cute_args.append(arg_) + else: + cute_args.append(arg) + + # Form cache key from tuning parameters + param_keys = sorted(params.keys()) + cache_str = type(kernel_function).__name__ + for k in param_keys: + cache_str += "_" + str(params[k]) + + # Check if kernel exists in cache. Otherwise, compile and save + if cache_str in call_cute.custom_cache: + compiled_kernel = call_cute.custom_cache[cache_str] + else: + compiled_kernel = cute.compile(kernel_function, *cute_args) + call_cute.custom_cache[cache_str] = compiled_kernel + + compiled_kernel(*cute_args, **kwargs) + + +def call_taichi(kernel_function, args, kwargs): + kernel_function(*args, **kwargs) + + +def call_warp(kernel_function, args, kwargs, grid, threads, params): + import warp as wp + + # Convert Torch tensors to Warp args + warp_args = [] + for arg in args: + if isinstance(arg, torch.Tensor): + warp_args.append(wp.from_torch(arg)) + else: + warp_args.append(arg) + + # launch kernel + with wp.Tape() as tape: + wp.launch_tiled( + kernel_function, + dim=grid, + inputs=warp_args, + block_dim=params["TILE_THREADS"], # We could directly take threads, but in the given example this is a constant + ) diff --git a/examples/generic_python/cute_vec_add.py b/examples/generic_python/cute_vec_add.py new file mode 100644 index 000000000..a83694ab9 --- /dev/null +++ b/examples/generic_python/cute_vec_add.py @@ -0,0 +1,92 @@ +import torch +from functools import partial +from typing import List +import time + +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack + +from kernel_tuner import tune_kernel + +@cute.kernel +def vec_add_kernel( + gA: cute.Tensor, + gB: cute.Tensor, + gC: cute.Tensor, + size: cute.Int32, +): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + bdim, _, _ = cute.arch.block_dim() + thread_id = bdim * bidx + tidx + + if thread_id < size: + gC[thread_id] = gA[thread_id] + gB[thread_id] + + + +@cute.jit +def vec_add( + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + size: cute.Int32, +): + num_threads_per_block = 256 + + kernel = vec_add_kernel(mA, mB, mC, size) + + kernel.launch( + grid=(cute.ceil_div(size, num_threads_per_block), 1, 1), + block = (num_threads_per_block, 1, 1), + ) + + + + +def call_cute(kernel_function, args, kwargs, grid, threads, params): + cute_args = [] + for arg in args: + if isinstance(arg, torch.Tensor): + arg_ = from_dlpack(arg) + cute_args.append(arg_) + else: + cute_args.append(arg) + + kernel_function(*cute_args, **kwargs) + + + + +def main(): + size = 16384 + a = torch.randn(size, device="cuda", dtype=torch.float16) + b = torch.randn(size, device="cuda", dtype=torch.float16) + c = torch.zeros(size, device="cuda", dtype=torch.float16) + + ''' + a_ = from_dlpack(a, assumed_align=16) + b_ = from_dlpack(b, assumed_align=16) + c_ = from_dlpack(c, assumed_align=16) + ''' + + args = [a, b, c, size] + tune_params = {"num_threads_per_block": [1, 2, 4, 8, 16, 32, 64, 128, 265, 512, 1024]} + answer = [None, None, (a+b).cpu(), None] + + from pathlib import Path + FULL_PATH = Path(__file__).resolve() + tune_kernel("vec_add", FULL_PATH, size, args, tune_params, answer=answer, + lang="generic_python", call_function=call_cute, verbose=True) + #naive_elementwise_add_ = cute.compile(naive_elementwise_add, a_, b_, c_) + + #naive_elementwise_add_(a_, b_, c_) + #vec_add(a_, b_, c_, size) + + + + #torch.testing.assert_close(c, a+b) + +main() + diff --git a/examples/generic_python/matmul/cute_matmul.py b/examples/generic_python/matmul/cute_matmul.py index 90e6e67e9..8654829fa 100644 --- a/examples/generic_python/matmul/cute_matmul.py +++ b/examples/generic_python/matmul/cute_matmul.py @@ -1,32 +1,38 @@ +import math import torch +from typing import Tuple + import cutlass import cutlass.cute as cute - +import cutlass.utils as utils from cutlass.cute.runtime import from_dlpack + from kernel_tuner import tune_kernel from pathlib import Path +from examples.generic_python.call_functions import call_cute FULL_PATH = Path(__file__).resolve() # need export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH to work -# Maybe use this as optimized kernel: https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py - +## Basic Matmul ================================================================ @cute.kernel -def naive_matmul_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor): - - tidx, _, _ = cute.arch.thread_idx() - bidx, _, _ = cute.arch.block_idx() - bdim, _, _ = cute.arch.block_dim() - - thread_idx = bidx * bdim + tidx - - M, K = gA.shape - _, N = gB.shape +def naive_matmul_kernel( + gA: cute.Tensor, + gB: cute.Tensor, + gC: cute.Tensor, + M: int, + K: int, + N: int, +): + tx, ty, _ = cute.arch.thread_idx() + bx, by, _ = cute.arch.block_idx() + bdx, bdy, _ = cute.arch.block_dim() - n = thread_idx % N - m = thread_idx // N + # Global indices + n = bx * bdx + tx + m = by * bdy + ty if m < M and n < N: acc = cutlass.Float32(0.0) @@ -39,97 +45,964 @@ def naive_matmul_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor): gC[m, n] = acc.to(gC.element_type) - @cute.jit def matmul( mA: cute.Tensor, mB: cute.Tensor, mC: cute.Tensor, ): + block_size_x = 16 + block_size_y = 16 + block = (block_size_x, block_size_y, 1) - threads_per_block = 256 - - M, _ = mA.shape + M, K = mA.shape _, N = mB.shape - total_outputs = M * N - - kernel = naive_matmul_kernel(mA, mB, mC) - - kernel.launch( - grid=((total_outputs + threads_per_block - 1) // threads_per_block, 1, 1), - block=(threads_per_block, 1, 1), + grid = ( + (N + block[0] - 1) // block[0], + (M + block[1] - 1) // block[1], + 1, ) + kernel = naive_matmul_kernel(mA, mB, mC, M, K, N) -# Here we formulate the jit wrapper a bit differently, becuase the grid and block sizes -# are computed by kernel tuner. But we could also use threads per block as a tuning param and -# keep the same jit wrapper as above -@cute.jit -def matmul_kt( - mA: cute.Tensor, - mB: cute.Tensor, - mC: cute.Tensor, - grid: cutlass.Constexpr, - block: cutlass.Constexpr, -): - - kernel = naive_matmul_kernel(mA, mB, mC) - - kernel.launch( - grid=grid, - block=block, - ) + kernel.launch(grid=grid, block=block) -def run_example(M, N, K): +def tune_naive_matmul(M, N, K): a = torch.randn(M, K, device="cuda", dtype=torch.float16) b = torch.randn(K, N, device="cuda", dtype=torch.float16) c = torch.zeros(M, N, device="cuda", dtype=torch.float16) - a_ = from_dlpack(a, assumed_align=16) - b_ = from_dlpack(b, assumed_align=16) - c_ = from_dlpack(c, assumed_align=16) - - matmul_ = cute.compile(matmul, a_, b_, c_) - matmul_(a_, b_, c_) - - torch.testing.assert_close(c, a @ b, atol=1e-2, rtol=1e-2) - + args = [a, b, c] + size = M * N + answer = [None, None, (a @ b).cpu()] + tune_params = dict() + tune_params["block_size_x"] = [8, 16, 32, 64, 128] + tune_params["block_size_y"] = [8, 16, 32, 64, 128] -def call_cute(kernel_function, args, kwargs, grid, threads, params): + results, env = tune_kernel("matmul", FULL_PATH, size, args, tune_params, lang="generic_python", + call_function=call_cute, answer=answer, atol=4, verbose=False) + + + +## Optimized matmul ========================================================== +# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +""" +A dense GEMM (C = A * B) example for the NVIDIA Ampere architecture using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations + - Threadblock rasterization to improve data re-use + - Supports multi-stage pipeline to overlap computation and memory access + - Implements shared memory buffering for epilogue to increase coalesced global memory access + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies. +2. Perform matrix multiply-accumulate (MMA) operations. +3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM). + +The Ampere tensor core instruction used operates as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Perform MMA operation and store the result in Accumulator(register) + +To run this example: + +.. code-block:: bash + + python examples/ampere/tensorop_gemm.py \ + --mnkl 8192,8192,8192,1 --atom_layout_mnk 2,2,1 \ + --ab_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major m --b_major n --c_major n + +The above example command computes with M=8192, N=8192, K=8192, +batch_count=1. The atom layout's shape is 2x2x1 and the input, mma +accumulator, and output data type are set as fp16, fp32 and fp16, +respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/ampere/tensorop_gemm.py \ + --mnkl 8192,8192,8192,1 --atom_layout_mnk 2,2,1 \ + --ab_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major m --b_major n --c_major n \ + --skip_ref_check --iterations 2 + +Constraints: +* Supported input and output data types: fp16 +* Support accumulator data types: f32 +* Default tile shape is set to be 128x128x32 +* Atom layout's MNK shape is set so that tile shape can be divided by MMA + instruction shape +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 8 +""" + + +class TensorOpGemm: + def __init__(self): + self.ab_dtype = cutlass.Float16 + self.c_dtype = cutlass.Float16 + self.acc_dtype = cutlass.Float32 + tile_m = 128 # Added for KT support + tile_n = 128 + tile_k = 32 + self.cta_tiler = (tile_m, tile_n, tile_k) + self.num_stages = 3 + self.atom_layout_mnk = (2, 2, 1) # moved from paramter to here for KT support + atom_lay_M, atom_lay_N, atom_lay_K = self.atom_layout_mnk + self.num_threads = atom_lay_M * atom_lay_N * atom_lay_K * 32 + + self.bM, self.bN, self.bK = self.cta_tiler + self.mma_inst_shape = (16, 8, 16) + mmaM, mmaN, mmaK = self.mma_inst_shape + + assert self.bM % (atom_lay_M * mmaM) == 0, ( + "bM must be divisible by MMA instruction" + ) + assert self.bN % (atom_lay_N * mmaN) == 0, ( + "bN must be divisible by MMA instruction" + ) + assert atom_lay_K == 1, "this example does not support atom layout K > 1" + assert self.bK % mmaK == 0, "bK must be divisible by MMA instruction" + assert self.num_stages >= 3, "num_stages must be greater than or equal to 3" + + @cute.jit + def __call__( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + # The grid divides the problems's M, N, and L dimensions by the + # respective modes of the tile shape (bM, bN, 1). The K dimension is + # handled within a block via a multistage process. + + self.a_major_mode = utils.LayoutEnum.from_tensor(mA) + self.b_major_mode = utils.LayoutEnum.from_tensor(mB) + self.c_major_mode = utils.LayoutEnum.from_tensor(mC) + + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory layout: + # /////////////////////////////////////////////////////////////////////////////// + + # Creates a layout with the size required for the provided tile + # size and num stages (stages are used for K dimension) that is also + # sectioned into 64x8 or 8x32 layout atoms. The swizzle is set so that + # the atom for the shared memory -> register copy does not encounter + # bank conflicts + + # assume the input is 16B align + ab_copy_bits = 128 + sA_layout = self._make_smem_layout_AB( + mA.element_type, + self.a_major_mode, + ab_copy_bits, + (self.cta_tiler[0], self.cta_tiler[2], self.num_stages), + ) + sB_layout = self._make_smem_layout_AB( + mB.element_type, + self.b_major_mode, + ab_copy_bits, + (self.cta_tiler[1], self.cta_tiler[2], self.num_stages), + ) + + # Creates a similar layout but without num_stages or layout atoms + sC_layout = self._make_smem_layout_C( + mC.element_type, + self.c_major_mode, + ab_copy_bits, + (self.cta_tiler[0], self.cta_tiler[1]), + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Tiled copy: + # The majorness of tA/tB/tC follows the majorness of gA/gB/gC, + # enabling merged accesses to global memory for faster data + # transfer between global and shared memory. + # /////////////////////////////////////////////////////////////////////////////// + + # Create a copy atom for a global to shared memory asynchronous copy + atom_async_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp( + cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL + ), + mA.element_type, + num_bits_per_copy=ab_copy_bits, + ) + + # Create thread layouts for tiled copy from the copy atom where the + # thread layout simply follows the leading dimension of the tensor + tiled_copy_A = self._make_gmem_tiled_copy_AB( + atom_async_copy, mA.element_type, self.a_major_mode, ab_copy_bits + ) + tiled_copy_B = self._make_gmem_tiled_copy_AB( + atom_async_copy, mB.element_type, self.b_major_mode, ab_copy_bits + ) + + # Creates a synchronous copy atom and thread layouts for the epilogue + c_copy_bits = 128 + atom_sync_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mC.element_type, + num_bits_per_copy=c_copy_bits, + ) + tiled_copy_C = self._make_gmem_tiled_copy_C( + atom_sync_copy, mC.element_type, self.c_major_mode, c_copy_bits + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Tiled MMA + # /////////////////////////////////////////////////////////////////////////////// + + # Creates a mma atom with 16x8x16 shape for MNK + op = cute.nvgpu.warp.MmaF16BF16Op( + self.ab_dtype, self.acc_dtype, self.mma_inst_shape + ) + + permutation_mnk = ( + self.atom_layout_mnk[0] * self.mma_inst_shape[0], + # if atom layout's N-mode is 1, to leverage the largest coalesced + # shared memory -> register copy, set the tiled mma's N mode to 16 + self.atom_layout_mnk[1] * self.mma_inst_shape[1] * 2, + self.atom_layout_mnk[2] * self.mma_inst_shape[2], + ) + + # Created a tiled mma that tiles the atom according to specified layout. + # For a 2x2x1 atom layout, the mma atom is duplicated 4 times, twice + # across M and twice across N + tC = cute.make_layout(self.atom_layout_mnk) + tiled_mma = cute.make_tiled_mma( + op, + tC, + permutation_mnk=permutation_mnk, + ) + + # grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, l) + grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) + + # Add threadblock rasterization to improve re-use of data + raster_factor = 1 + grid_dim_n = cute.size(grid_dim[1]) + # Thresholds picked so that it doesn't cause too many no-op CTAs + if grid_dim_n > 5: + raster_factor = 8 + elif grid_dim_n > 2: + raster_factor = 4 + elif grid_dim_n > 1: + raster_factor = 2 + rasterization_remap_grid_dim = ( + cute.size(grid_dim[0]) * raster_factor, + (cute.size(grid_dim[1]) + raster_factor - 1) // raster_factor, + cute.size(grid_dim[2]), + ) + + self.kernel( + mA, + mB, + mC, + sA_layout, + sB_layout, + sC_layout, + tiled_copy_A, + tiled_copy_B, + tiled_copy_C, + tiled_mma, + raster_factor, + epilogue_op, + ).launch( + grid=rasterization_remap_grid_dim, + block=[self.num_threads, 1, 1], + ) + + @cute.kernel + def kernel( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + sA_layout: cute.ComposedLayout, + sB_layout: cute.ComposedLayout, + sC_layout: cute.ComposedLayout, + tiled_copy_A: cute.TiledCopy, + tiled_copy_B: cute.TiledCopy, + tiled_copy_C: cute.TiledCopy, + tiled_mma: cute.TiledMma, + rasterization_factor: cutlass.Int32, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) + offset_tile_x, offset_tile_y = self.raster_tile( + bidx, bidy, rasterization_factor + ) + # Early exit if CTA is out of range + if grid_dim[0] <= offset_tile_x or grid_dim[1] <= offset_tile_y: + pass + else: + tiler_coord = (offset_tile_x, offset_tile_y, None) + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # gA: (BLK_M, BLK_N, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N) + # /////////////////////////////////////////////////////////////////////////////// + gA = cute.local_tile( + mA[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, None, 1), + ) + gB = cute.local_tile( + mB[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(None, 1, 1), + ) + gC = cute.local_tile( + mC[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, 1, None), + ) + + # By default, if the tensor k mode does not divide into the tile k + # size, then last tiles in the k dimension are irregular. + # Instead, make the first tiles irregular when k is irregular. + # This allows us to handle the irregular tile first to avoid + # checking for this condition within the mainloop. + + # residual_k is a negative number indicating the amount needed to + # shift the pointer by in dimension k + residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size( + gA, mode=[2] + ) + + # move the pointer of gA/gB in the `-k` direction + gA = cute.domain_offset((0, residual_k, 0), gA) + gB = cute.domain_offset((0, residual_k, 0), gB) + # input is 16B aligned + gA = cute.make_tensor(gA.iterator.align(16), gA.layout) + gB = cute.make_tensor(gB.iterator.align(16), gB.layout) + + # Construct identity layout for sA and sB (mirrors global tensors, + # used for predication only) + mcA = cute.make_identity_tensor(mA.layout.shape) + mcB = cute.make_identity_tensor(mB.layout.shape) + cA = cute.local_tile( + mcA[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, None, 1), + ) + cB = cute.local_tile( + mcB[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(None, 1, 1), + ) + + cA = cute.domain_offset((0, residual_k, 0), cA) + cB = cute.domain_offset((0, residual_k, 0), cB) + + # /////////////////////////////////////////////////////////////////////////////// + # Create shared memory buffers and get the appropriate fragments for this thread. + # sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE) + # tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k) + # tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE) + # /////////////////////////////////////////////////////////////////////////////// + @cute.struct + class SharedStorageAB: + a: cute.struct.Align[ + cute.struct.MemRange[mA.element_type, cute.cosize(sA_layout)], + 16, + ] + b: cute.struct.Align[ + cute.struct.MemRange[mB.element_type, cute.cosize(sB_layout)], + 16, + ] + + @cute.struct + class SharedStorageC: + c: cute.struct.Align[ + cute.struct.MemRange[mC.element_type, cute.cosize(sC_layout)], + 16, + ] + + # Shared memory buffer + smem = cutlass.utils.SmemAllocator() + # Shared memory allocated for operations with A, B will be + # overwritten for operations on C. This is to improve performance + # by reducing the size of shared memory requested by each block + storage = smem.allocate( + max(SharedStorageAB.size_in_bytes(), SharedStorageC.size_in_bytes()), + byte_alignment=16, + ) + sA = SharedStorageAB(storage).a.get_tensor(sA_layout) + sB = SharedStorageAB(storage).b.get_tensor(sB_layout) + sC = SharedStorageC(storage).c.get_tensor(sC_layout) + + thr_copy_A = tiled_copy_A.get_slice(tidx) + thr_copy_B = tiled_copy_B.get_slice(tidx) + thr_copy_C = tiled_copy_C.get_slice(tidx) + tAgA = thr_copy_A.partition_S(gA) + tAsA = thr_copy_A.partition_D(sA) + tBgB = thr_copy_B.partition_S(gB) + tBsB = thr_copy_B.partition_D(sB) + tCsC_epilogue = thr_copy_C.partition_S(sC) + tCgC_epilogue = thr_copy_C.partition_D(gC) + + # Repeat the partitioning with identity layouts + tAcA = thr_copy_A.partition_S(cA) + tBcB = thr_copy_B.partition_S(cB) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + + # For predication over the tensors A (M/K), B (N/K), and (in the + # epilogue) C (M/N), we will compute it in a fashion similar to an + # outer product. The predication along one of the dimensions is + # evaluated and stored in a predication tensor. Then, the + # predication for the remaining dimension is handled later via an + # if/else branch at the copy. + # For A and B, predication booleans along M/N are stored in a + # predication tensor and along K is handled via a if/else branch. + + # Allocate predicate tensors for M and N. Predication is checked + # at the granularity of a copy atom, so the predicate tensor does not + # need separate booleans for individual elements within a copy + # atom (for example, the elements of tAgA.shape[0][0].) + tApA = cute.make_rmem_tensor( + cute.make_layout( + ( + tAgA.shape[0][1], + cute.size(tAgA, mode=[1]), + cute.size(tAgA, mode=[2]), + ), + stride=(cute.size(tAgA, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + tBpB = cute.make_rmem_tensor( + cute.make_layout( + ( + tBsB.shape[0][1], + cute.size(tBsB, mode=[1]), + cute.size(tBsB, mode=[2]), + ), + stride=(cute.size(tBsB, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + # Set predicates for M/N bounds + for rest_v in range(tApA.shape[0]): + for m in range(tApA.shape[1]): + tApA[rest_v, m, 0] = cute.elem_less( + tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0] + ) + for rest_v in range(tBpB.shape[0]): + for n in range(tBpB.shape[1]): + tBpB[rest_v, n, 0] = cute.elem_less( + tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0] + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Prefetch Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Clear the smem tiles to account for predicated off loads + tAsA.fill(0) + tBsB.fill(0) + cute.arch.sync_threads() + # Start async loads for the first k-tile. Here we take care of the k residue + # via if/else check along the k dimension. Because we shifted the identity tensor + # by the residue_k and because the identity tensor is a coord tensor, the + # values of any identity tensor element that is poison is less than -1 + num_smem_stages = cute.size(tAsA, mode=[3]) + k_tile_count = cute.size(tAgA, mode=[3]) + k_tile_index = cutlass.Int32(0) + + for k in range(tApA.shape[2]): + if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]): + cute.copy( + tiled_copy_A, + tAgA[None, None, k, k_tile_index], + tAsA[None, None, k, 0], + pred=tApA[None, None, k], + ) + for k in range(tBpB.shape[2]): + if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]): + cute.copy( + tiled_copy_B, + tBgB[None, None, k, k_tile_index], + tBsB[None, None, k, 0], + pred=tBpB[None, None, k], + ) + k_tile_index = k_tile_index + 1 + cute.arch.cp_async_commit_group() + + # Start async loads for rest of the k-tiles + for k_tile in range(1, num_smem_stages - 1): + if k_tile == k_tile_count: + tApA.fill(0) + tBpB.fill(0) + cute.copy( + tiled_copy_A, + tAgA[None, None, None, k_tile_index], + tAsA[None, None, None, k_tile], + pred=tApA, + ) + cute.copy( + tiled_copy_B, + tBgB[None, None, None, k_tile_index], + tBsB[None, None, None, k_tile], + pred=tBpB, + ) + k_tile_index = k_tile_index + 1 + cute.arch.cp_async_commit_group() + + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma = tiled_mma.get_slice(tidx) + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCsC = thr_mma.partition_C(sC) + tCgC = thr_mma.partition_C(gC) + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) + tCrC = tiled_mma.make_fragment_C(tCgC) + # Clear the accumulator + tCrC.fill(0.0) + + # /////////////////////////////////////////////////////////////////////////////// + # Copy Atom A/B retiling + # /////////////////////////////////////////////////////////////////////////////// + + # Create the copy atoms for the copy from shared memory to register + atom_copy_s2r_A = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp( + self.a_major_mode != utils.LayoutEnum.ROW_MAJOR, 4 + ), + mA.element_type, + ) + atom_copy_s2r_B = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp( + self.b_major_mode != utils.LayoutEnum.ROW_MAJOR, 4 + ), + mB.element_type, + ) + + # Creates the tiled copy so that it matches the thread-value layout + # expected by the tiled mma + tiled_copy_s2r_A = cute.make_tiled_copy_A(atom_copy_s2r_A, tiled_mma) + tiled_copy_s2r_B = cute.make_tiled_copy_B(atom_copy_s2r_B, tiled_mma) + + thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx) + thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx) + tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA) + tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA) + tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB) + tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB) + + # Current pipe index in smem to read from / write to + smem_pipe_read = 0 + smem_pipe_write = num_smem_stages - 1 + + tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read] + tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read] + + # /////////////////////////////////////////////////////////////////////////////// + # PREFETCH register pipeline + # /////////////////////////////////////////////////////////////////////////////// + num_k_block = cute.size(tCrA, mode=[2]) + if num_k_block > 1: + # Wait until our first prefetched tile is loaded in + cute.arch.cp_async_wait_group(num_smem_stages - 2) + cute.arch.sync_threads() + # Prefetch the first k-block rmem from the first k-tile + cute.copy( + tiled_copy_s2r_A, + tCsA_p[None, None, 0], + tCrA_copy_view[None, None, 0], + ) + cute.copy( + tiled_copy_s2r_B, + tCsB_p[None, None, 0], + tCrB_copy_view[None, None, 0], + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop + # 1. Shared memory pipeline (gmem -> smem): + # The default smem pipeline depth is 3, meaning that for shared + # memory buffers, we allocate three times the size described by the + # CTA tiler. We prefetch 2 of these buffers before entering the main + # loop. Considering only the transfer from global memory to shared + # memory, the general structure of the mainloop is: + # (1) copy k-tile from gmem to smem; + # (2) perform gemm computation on k-tile; + # (3) wait for the next copy to finish. + # The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command + # waits for the number of unfinished 'copy' to be <= 1. The advantage + # of this approach is that it allows for simultaneous production + # (i.e., step (1)) and consumption (i.e., step (2)) of smem. + # A common misconception is to prefetch N buffers and rewrite + # the pipeline logic to wait on N-1 pending copies. The disadvantage + # of this approach is that it requires fully consuming a buffer in + # order to open an empty buffer for the next copy. + # 2. Register pipeline (smem -> register): + # Similarly, the register pipeline produces i+1, consumes i, and + # produces i+2... Notably, i and i+1 do not use the same register, + # eliminating dependencies on the same register for better parallelism. + # 3. Combining the smem and register pipelines results in the mainloop. + # /////////////////////////////////////////////////////////////////////////////// + for k_tile in range(k_tile_count): + for k_block in cutlass.range(num_k_block, unroll_full=True): + if k_block == num_k_block - 1: + tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read] + tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read] + cute.arch.cp_async_wait_group(num_smem_stages - 2) + cute.arch.sync_threads() + + # Load A, B from shared memory to registers for k_block + 1 + k_block_next = (k_block + 1) % num_k_block # static + cute.copy( + tiled_copy_s2r_A, + tCsA_p[None, None, k_block_next], + tCrA_copy_view[None, None, k_block_next], + ) + cute.copy( + tiled_copy_s2r_B, + tCsB_p[None, None, k_block_next], + tCrB_copy_view[None, None, k_block_next], + ) + + # Fetch next A: To better interleave global memory access and compute + # instructions, we intentionally use the sequence: copy A, perform GEMM, + # then copy B. + if k_block == 0: + if k_tile + num_smem_stages - 1 < k_tile_count: + cute.copy( + tiled_copy_A, + tAgA[None, None, None, k_tile_index], + tAsA[None, None, None, smem_pipe_write], + pred=tApA, + ) + + # Thread-level register gemm for k_block + cute.gemm( + tiled_mma, + tCrC, + tCrA[None, None, k_block], + tCrB[None, None, k_block], + tCrC, + ) + + # Fetch next B and update smem pipeline read/write + if k_block == 0: + if k_tile + num_smem_stages - 1 < k_tile_count: + cute.copy( + tiled_copy_B, + tBgB[None, None, None, k_tile_index], + tBsB[None, None, None, smem_pipe_write], + pred=tBpB, + ) + k_tile_index = k_tile_index + 1 + cute.arch.cp_async_commit_group() + smem_pipe_write = smem_pipe_read + smem_pipe_read = smem_pipe_read + 1 + if smem_pipe_read == num_smem_stages: + smem_pipe_read = 0 + + # Sync before epilogue + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue with fusion + # /////////////////////////////////////////////////////////////////////////////// + tCrD = cute.make_fragment_like(tCrC, self.c_dtype) + tCrD[None] = epilogue_op(tCrC.load()).to(self.c_dtype) + + # Copy results of D back to shared memory + cute.autovec_copy(tCrD, tCsC) + + # Create coord tensor for C + ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) + mcC = cute.make_identity_tensor( + ( + cute.size(ceilM) * self.cta_tiler[0], + cute.size(ceilN) * self.cta_tiler[1], + 1, + ) + ) + cC = cute.local_tile( + mcC[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, 1, None), + ) + tCcC = thr_copy_C.partition_S(cC) + + tCrC_epilogue = cute.make_fragment_like(tCsC_epilogue) + # Wait for all writes to shared memory to finish before starting copies + # using the new layouts + cute.arch.sync_threads() + cute.autovec_copy(tCsC_epilogue, tCrC_epilogue) + + # Create predication tensor for m + tCpC = cute.make_rmem_tensor( + cute.make_layout( + ( + tCgC_epilogue.shape[0][1], + cute.size(tCgC_epilogue, mode=[1]), + cute.size(tCgC_epilogue, mode=[2]), + ), + stride=(cute.size(tCgC_epilogue, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + for rest_v in range(tCpC.shape[0]): + for m in range(tCpC.shape[1]): + tCpC[rest_v, m, 0] = cute.elem_less( + tCcC[(0, rest_v), m, 0][0], mC.shape[0] + ) + + # Copy to global memory using better vectorization + for rest_v in range(tCpC.shape[0]): + for n in range(tCpC.shape[2]): + if cute.elem_less(tCcC[(0, rest_v), 0, n][1], mC.shape[1]): + cute.copy( + tiled_copy_C, + tCrC_epilogue[None, None, n], + tCgC_epilogue[None, None, n], + pred=tCpC[None, None, n], + ) + return + + def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler): + major_mode_size = ( + smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0] + ) + major_mode_size = 64 if major_mode_size >= 64 else major_mode_size + + swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits)) + swizzle_bits = min(swizzle_bits, 3) + + layout_atom_outer = ( + cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1)) + if major_mode == utils.LayoutEnum.ROW_MAJOR + else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size)) + ) + layout_atom = cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, 3, 3), + 0, + layout_atom_outer, + ) + layout = cute.tile_to_shape(layout_atom, smem_tiler, (0, 1, 2)) + return layout + + def _make_smem_layout_C(self, dtype, major_mode, copy_bits, smem_tiler): + major_mode_size = ( + smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0] + ) + + swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits)) + swizzle_bits = min(swizzle_bits, 3) + + layout_atom_outer = ( + cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1)) + if major_mode == utils.LayoutEnum.ROW_MAJOR + else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size)) + ) + layout_atom = cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, 3, 4), + 0, + layout_atom_outer, + ) + + # Due to the thread layout of the mma, remove swizzle in C to + # prevent shared memory fragments owned by an single thread from + # holding swizzles + if major_mode == utils.LayoutEnum.COL_MAJOR: + layout_atom = cute.make_composed_layout( + cute.make_swizzle(0, 3, 4), 0, layout_atom_outer + ) + layout = cute.tile_to_shape( + layout_atom, + smem_tiler, + (0, 1), + ) + return layout + + def _make_gmem_tiled_copy_AB(self, atom_copy, dtype, major_mode, copy_bits): + copy_elems = copy_bits // dtype.width + shape_dim_1 = cute.size(self.bK) // copy_elems + # thread layout for copy + thread_layout = cute.make_layout( + (self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1) + ) + if major_mode != utils.LayoutEnum.ROW_MAJOR: + shape_dim_0 = cute.size(self.bM) // copy_elems + thread_layout = cute.make_layout( + (shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0) + ) + # Value layout for copy + value_layout = ( + cute.make_layout((1, copy_elems)) + if major_mode == utils.LayoutEnum.ROW_MAJOR + else cute.make_layout((copy_elems, 1)) + ) + return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout) + + def _make_gmem_tiled_copy_C(self, atom_copy, dtype, major_mode, copy_bits): + copy_elems = copy_bits // dtype.width + shape_dim_1 = cute.size(self.bN) // copy_elems + # thread layout for copy + thread_layout = cute.make_layout( + (self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1) + ) + if major_mode != utils.LayoutEnum.ROW_MAJOR: + shape_dim_0 = cute.size(self.bM) // copy_elems + thread_layout = cute.make_layout( + (shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0) + ) + value_layout = ( + cute.make_layout((1, copy_elems)) + if major_mode == utils.LayoutEnum.ROW_MAJOR + else cute.make_layout((copy_elems, 1)) + ) + return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout) + + def raster_tile(self, i, j, f): + new_i = i // f + new_j = (i % f) + (j * f) + return (new_i, new_j) + + +# Need a custum call function for the optimized kernel because of the tensor layout +def call_cute_custom(kernel_function, args, kwargs, grid, threads, params): + # Initialize cache if it does not exist + if not hasattr(call_cute_custom, "custom_cache"): + call_cute_custom.custom_cache = {} + + # Convert Torch tensors to CuTe tensors with correct layout cute_args = [] for arg in args: if isinstance(arg, torch.Tensor): - arg_ = from_dlpack(arg) - cute_args.append(arg_) + cute_tensor = ( + from_dlpack(arg, assumed_align=16) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic( + mode=1, + stride_order=(2, 0, 1), + divisibility=8, + ) + ) + cute_args.append(cute_tensor) else: cute_args.append(arg) - kernel_function(*cute_args, grid, threads) + # Form cache key from tuning parameters + param_keys = sorted(params.keys()) + cache_str = type(kernel_function).__name__ + for k in param_keys: + cache_str += "_" + str(params[k]) + + # Check if kernel exists in cache. Otherwise, compile and save + if cache_str in call_cute_custom.custom_cache: + compiled_kernel = call_cute_custom.custom_cache[cache_str] + else: + # Wrap in try-except because CuTe gives TypeError for certain block sizes. This + # stops the tuning process and we do not want this + try: + compiled_kernel = cute.compile(kernel_function, *cute_args) + call_cute_custom.custom_cache[cache_str] = compiled_kernel + except: + raise RuntimeError("Invalid configuration") + + + + compiled_kernel(*cute_args) + + + +def tune_optimized(mnkl: Tuple[int, int, int, int]): + + M, N, K, L = mnkl + + a = torch.randn((L, M, K), device="cuda", dtype=torch.float16).permute(1, 2, 0) + b = torch.randn((L, N, K), device="cuda", dtype=torch.float16).permute(1, 2, 0) # b is transposed + c = torch.empty((L, M, N), device="cuda", dtype=torch.float16).permute(1, 2, 0) + + c_ref = torch.matmul( + a.permute(2, 0, 1), + b.permute(2, 1, 0) + ).permute(1, 2, 0) + + args = [a, b, c] + tune_params = { + "tile_m": [64, 128, 256], + "tile_n": [64, 128, 256], + "tile_k": [16, 32, 64], + } -def tune(M, N, K): - a = torch.randn(M, K, device="cuda", dtype=torch.float16) - b = torch.randn(K, N, device="cuda", dtype=torch.float16) - c = torch.zeros(M, N, device="cuda", dtype=torch.float16) + constraints = ["3 * (tile_m * tile_k + tile_n * tile_k) * 2 <= 100 * 1024"] # SMEM constraint - args = [a, b, c] - size = M * N - answer = [None, None, (a @ b).cpu()] - tune_params = dict() - tune_params["block_size_x"] = [16, 32, 64, 128] - tune_params["block_size_y"] = [16, 32, 64, 128] - results, env = tune_kernel("matmul_kt", FULL_PATH, size, args, tune_params, lang="generic_python", - call_function=call_cute, answer=answer, atol=1e-2, verbose=True) + results, env = tune_kernel("TensorOpGemm", FULL_PATH, M * N, args, tune_params, verbose=True, restrictions=constraints, + lang="generic_python", call_function=call_cute_custom, answer=[None, None, c_ref.cpu()], atol=4) +if __name__ == "__main__": + #m, n, k = 4096, 4096, 4096 + #tune_naive_matmul(m, n, k) + mnkl = (4096, 4096, 4096, 1) + tune_optimized(mnkl) -if __name__ == "__main__": - run_example(128, 128, 128) - tune(128, 128, 128) diff --git a/examples/generic_python/matmul/numba_matmul.py b/examples/generic_python/matmul/numba_matmul.py index 00102a382..af6cce347 100644 --- a/examples/generic_python/matmul/numba_matmul.py +++ b/examples/generic_python/matmul/numba_matmul.py @@ -1,6 +1,6 @@ import torch import numpy as np -from numba import cuda +from numba import cuda, float32 from kernel_tuner import tune_kernel from pathlib import Path @@ -21,7 +21,7 @@ def matmul(A, B, C): # Example taken from https://nvidia.github.io/numba-cuda/user/examples.html#matrix-multiplication # Changed data type from float32 to float16 -@cuda.jit(cache=True) +@cuda.jit(cache=True, fastmath=True) def fast_matmul(A, B, C): # Define an array in the shared memory # The size and type of the arrays must be known at compile time @@ -60,6 +60,90 @@ def fast_matmul(A, B, C): C[x, y] = tmp +# Translated to numba-cuda from https://github.com/cupy/cupy/blob/main/examples/gemm/sgemm.cu +@cuda.jit(cache=True) +def optimized_matmul(M, N, K, A, B, C): + DIM_X = 16 + DIM_Y = 16 + BLK_M = 64 + BLK_N = 64 + BLK_K = 16 + THR_M = 4 + THR_N = 4 + + # thread indices + idx = cuda.threadIdx.x + idy = cuda.threadIdx.y + + blx = cuda.blockIdx.x + bly = cuda.blockIdx.y + + # shared memory + sA = cuda.shared.array((BLK_K, BLK_M), np.float16) + sB = cuda.shared.array((BLK_N, BLK_K), np.float16) + + # registers + rC = cuda.local.array((THR_N, THR_M), float32) + rA = cuda.local.array(THR_M, float32) + rB = cuda.local.array(THR_N, float32) + + # init accumulator + for n in range(THR_N): + for m in range(THR_M): + rC[n][m] = 0.0 + + # global indices + base_row = blx * BLK_M + base_col = bly * BLK_N + + # loop over K tiles + for kk in range(0, K, BLK_K): + + # load A tile into shared memory + for i in range(idy, BLK_K, DIM_Y): + for j in range(idx, BLK_M, DIM_X): + row = base_row + j + col = kk + i + if row < M and col < K: + sA[i, j] = A[row, col] + else: + sA[i, j] = 0.0 + + # load B tile + for i in range(idy, BLK_N, DIM_Y): + for j in range(idx, BLK_K, DIM_X): + row = kk + j + col = base_col + i + if row < K and col < N: + sB[i, j] = B[row, col] + else: + sB[i, j] = 0.0 + + cuda.syncthreads() + + # compute + for k in range(BLK_K): + for m in range(THR_M): + rA[m] = float32(sA[k, m * DIM_X + idx]) + + for n in range(THR_N): + rB[n] = float32(sB[n * DIM_Y + idy, k]) + + for n in range(THR_N): + for m in range(THR_M): + rC[n][m] += rA[m] * rB[n] + + cuda.syncthreads() + + # write back + for n in range(THR_N): + col = base_col + n * DIM_Y + idy + for m in range(THR_M): + row = base_row + m * DIM_X + idx + if row < M and col < N: + C[row, col] = np.float16(rC[n][m]) + + def run_matmul(M, N, K): # create numpy arrays @@ -135,9 +219,52 @@ def tune(M, N, K): if __name__ == "__main__": - run_matmul(128, 96, 64) + + #run_matmul(128, 96, 64) #tune(1024, 1024, 1024) + + M, N, K = 1024, 1024, 1024 + + # random FP16 inputs + A = np.random.rand(M, K).astype(np.float16) + B = np.random.rand(K, N).astype(np.float16) + + # output + C = np.zeros((M, N), dtype=np.float16) + + # copy to device + dA = cuda.to_device(A) + dB = cuda.to_device(B) + dC = cuda.to_device(C) + + # launch config + threads_per_block = (16, 16) + + blocks_per_grid_x = (M + 64 - 1) // 64 + blocks_per_grid_y = (N + 64 - 1) // 64 + blocks_per_grid = (blocks_per_grid_x, blocks_per_grid_y) + + # run kernel + optimized_matmul[blocks_per_grid, threads_per_block](M, N, K, dA, dB, dC) + + # copy result back + C_result = dC.copy_to_host() + + # reference (FP32 accumulate like your kernel) + C_ref = (A.astype(np.float32) @ B.astype(np.float32)).astype(np.float16) + + # check error + max_error = np.max(np.abs(C_result - C_ref)) + print("Max error:", max_error) + + # tolerance (FP16 is noisy) + if max_error < 4: + print("✅ Looks correct") + else: + print("❌ Something is off") + print("Expected: ", C_ref, "\nGot: ", C_result) + ''' 128: 16, 8 256: 8, 16 diff --git a/examples/generic_python/matmul/taichi_matmul.py b/examples/generic_python/matmul/taichi_matmul.py index 6eca9d21e..b8a1c9388 100644 --- a/examples/generic_python/matmul/taichi_matmul.py +++ b/examples/generic_python/matmul/taichi_matmul.py @@ -2,16 +2,15 @@ import taichi as ti from kernel_tuner import tune_kernel -from pathlib import Path -FULL_PATH = Path(__file__).resolve() + ti.init(arch=ti.gpu) -# NOTE tuning works on vibranium. Taichi does not work on DAS6 + # TODO make sure this is zero-copy @ti.kernel -def matmul(A: ti.types.ndarray(), B: ti.types.ndarray(), C: ti.types.ndarray()): +def matmul(A: ti.types.ndarray(dtype=ti.f16, ndim=2), B: ti.types.ndarray(dtype=ti.f16, ndim=2), C: ti.types.ndarray(dtype=ti.f16, ndim=2)): BLOCK_DIM = 16 ti.loop_config(block_dim=BLOCK_DIM) @@ -21,16 +20,16 @@ def matmul(A: ti.types.ndarray(), B: ti.types.ndarray(), C: ti.types.ndarray()): sum = 0.0 for k in range(K_dim): sum += A[i, k] * B[k, j] - C[i, j] = sum + C[i, j] = ti.cast(sum, ti.f16) def call_taichi(kernel_function, args, kwargs): kernel_function(*args, **kwargs) def tune(M, N, K): - torch_A = torch.rand((N, K), device='cuda', dtype=torch.float32) - torch_B = torch.rand((K, M), device='cuda', dtype=torch.float32) - torch_C = torch.empty((N, M), device='cuda', dtype=torch.float32) + torch_A = torch.rand((N, K), device='cuda', dtype=torch.float16) + torch_B = torch.rand((K, M), device='cuda', dtype=torch.float16) + torch_C = torch.empty((N, M), device='cuda', dtype=torch.float16) size = M * N args = [torch_A, torch_B, torch_C] @@ -40,13 +39,14 @@ def tune(M, N, K): results, env = tune_kernel( kernel_name="matmul", - kernel_source=FULL_PATH, + kernel_source=__file__, problem_size=size, arguments=args, tune_params=tune_params, answer=answer, lang="generic_python", call_function=call_taichi, + atol=1e-1, ) if __name__ == "__main__": diff --git a/examples/generic_python/matmul/tilelang_matmul.py b/examples/generic_python/matmul/tilelang_matmul.py index 7ced8cc83..2ea510d91 100644 --- a/examples/generic_python/matmul/tilelang_matmul.py +++ b/examples/generic_python/matmul/tilelang_matmul.py @@ -2,6 +2,8 @@ import tilelang import tilelang.language as T +import itertools + @tilelang.jit def matmul_basic(M:int, N:int, K:int, block_M:int, block_N:int, block_K:int, dtype:str="float16", accum_dtype:str="float32"): @@ -31,11 +33,13 @@ def gemm( return gemm + + # https://github.com/tile-ai/tilelang/blob/main/examples/gemm/example_gemm_autotune.py # changed gemm_autotune to gemm # originally, B was transposed. I changed this so the kernel input is the same as in other languages. -#@tilelang.jit -def matmul_opt(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): +@tilelang.jit +def matmul_opt(M, N, K, dummy, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): @T.prim_func def gemm( A: T.Tensor((M, K), dtype), @@ -65,10 +69,105 @@ def gemm( -def main(): - kernel = matmul_opt(1024, 1024, 1024, 128, 128, 32, 3, 128, False) +def get_configs(kt=False): + """ + Generate a list of kernel tuning configuration dictionaries for a tiled matrix-multiply. + This function is used for tuning experiments with the built-in tuner + + Returns: + List[dict]: A list of configuration dictionaries + """ + + ''' + block_M = [64, 128, 256] + block_N = [64, 128, 256] + block_K = [32, 64] + num_stages = [0, 1, 2, 3] + thread_num = [128, 256] + enable_rasterization = [True, False] + ''' + + block_M = [64] + block_N = [64] + block_K = [32, 64] + num_stages = [0, 3] + thread_num = [128] + enable_rasterization = [True] + + _configs = list( + itertools.product( + block_M, + block_N, + block_K, + num_stages, + thread_num, + enable_rasterization, + ) + ) + + if kt: + configs = { + "block_M": block_M, + "block_N": block_N, + "block_K": block_K, + "num_stages": num_stages, + "thread_num": thread_num, + "enable_rasteration": enable_rasterization, + } + else: + configs = [ + { + "block_M": c[0], + "block_N": c[1], + "block_K": c[2], + "num_stages": c[3], + "thread_num": c[4], + "enable_rasteration": c[5], # keep param name for backward-compat + } + for c in _configs + ] + return configs + +# https://github.com/tile-ai/tilelang/blob/main/examples/gemm/example_gemm_autotune.py +# changed gemm_autotune to gemm +# originally, B was transposed. I changed this so the kernel input is the same as in other languages. +@tilelang.autotune(configs=get_configs()) +@tilelang.jit +def matmul_opt_autotune(M, N, K, dummy, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def gemm_autotune( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + T.use_swizzle(panel_size=10, enable=enable_rasteration) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm( + A_shared, + B_shared, + C_local, + transpose_B=False, + ) + T.copy(C_local, C_shared) + T.copy(C_shared, C[by * block_M, bx * block_N]) + + return gemm_autotune + + + + +def main(): + kernel = matmul_opt_autotune(1024, 1024, 1024) import torch a = torch.randn(1024, 1024).cuda().half() diff --git a/examples/generic_python/matmul/warp_matmul.py b/examples/generic_python/matmul/warp_matmul.py index 9e6b7dcf3..78a90ce65 100644 --- a/examples/generic_python/matmul/warp_matmul.py +++ b/examples/generic_python/matmul/warp_matmul.py @@ -19,7 +19,7 @@ # GEMM example from https://nvidia.github.io/warp/user_guide/tiles.html @wp.kernel -def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)): +def tile_gemm(A: wp.array2d(dtype=wp.float16), B: wp.array2d(dtype=wp.float16), C: wp.array2d(dtype=wp.float16)): # output tile index i, j = wp.tid() @@ -36,17 +36,16 @@ def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.arra a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i*TILE_M, k*TILE_K), bounds_check=True) b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k*TILE_K, j*TILE_N), bounds_check=True) - # sum += a*b wp.tile_matmul(a, b, sum) - wp.tile_store(C, sum, offset=(i*TILE_M, j*TILE_N), bounds_check=True) + wp.tile_store(C, wp.tile_astype(sum, wp.float16), offset=(i*TILE_M, j*TILE_N), bounds_check=True) def run_kernel_direct(M, N, K): rng = np.random.default_rng(42) - A = rng.random((M, K), dtype=np.float32) - B = rng.random((K, N), dtype=np.float32) - C = np.zeros((M, N), dtype=np.float32) + A = rng.random((M, K)).astype(np.float16) + B = rng.random((K, N)).astype(np.float16) + C = np.zeros((M, N), dtype=np.float16) A_wp = wp.array(A) B_wp = wp.array(B) @@ -59,7 +58,7 @@ def run_kernel_direct(M, N, K): inputs=[A_wp, B_wp, C_wp], block_dim=TILE_THREADS) - np.testing.assert_allclose(C_wp.numpy(), A @ B, rtol=1e-3) + np.testing.assert_allclose(C_wp.numpy(), A @ B, rtol=1e-1) print("Example matrix multiplication passed") @@ -80,15 +79,15 @@ def call_warp(kernel_function, args, kwargs, grid, threads, params): kernel_function, dim=grid, inputs=warp_args, - block_dim=params["TILE_THREADS"], # We could directly take threads, but in the given example this is a constant + block_dim=params["TILE_THREADS"], ) def tune(M, K, N): rng = np.random.default_rng(42) - A = rng.random((M, K), dtype=np.float32) - B = rng.random((K, N), dtype=np.float32) - C = np.zeros((M, N), dtype=np.float32) + A = rng.random((M, K)).astype(np.float16) + B = rng.random((K, N)).astype(np.float16) + C = np.zeros((M, N), dtype=np.float16) size = (M, N) block_size_names = ["TILE_M", "TILE_N"] @@ -126,6 +125,8 @@ def tune(M, K, N): (129, 130, 33), ] - for size in sizes: - print(size) - run_kernel_direct(*size) + #for size in sizes: + # print(size) + # run_kernel_direct(*size) + + run_kernel_direct(1024, 1024, 1024) From 68a48468f0b1453f63bd3eb5e530aee5021c5382 Mon Sep 17 00:00:00 2001 From: Imke van Ooijen Date: Fri, 1 May 2026 08:47:50 +0200 Subject: [PATCH 10/14] examples --- examples/generic_python/call_functions.py | 27 ++- examples/generic_python/matmul/cupy_matmul.py | 98 +++----- examples/generic_python/matmul/cute_matmul.py | 115 ++++++--- .../generic_python/matmul/numba_matmul.py | 218 ++++++++---------- .../generic_python/matmul/taichi_matmul.py | 43 ++-- .../generic_python/matmul/tilelang_matmul.py | 212 ++++++++++++++--- .../generic_python/matmul/tilus_matmul.py | 213 +++++++++++------ .../generic_python/matmul/triton_matmul.py | 166 ++++++++++++- examples/generic_python/matmul/warp_matmul.py | 129 ++++++----- kernel_tuner/backends/generic_python.py | 4 +- 10 files changed, 816 insertions(+), 409 deletions(-) diff --git a/examples/generic_python/call_functions.py b/examples/generic_python/call_functions.py index 734568efd..8c74790e1 100644 --- a/examples/generic_python/call_functions.py +++ b/examples/generic_python/call_functions.py @@ -91,11 +91,24 @@ def call_warp(kernel_function, args, kwargs, grid, threads, params): else: warp_args.append(arg) + # Check if block_dim is in the tuning parameters. Otherwise, use + # the computed thread block dimensions. + if 'block_dim' in params.keys(): + threads_per_block = params['block_dim'] + else: + threads_per_block = threads[0] * threads[1] * threads[2] + + # Check if dim is in the tuning parameters. Otherwise, compute from + # grid and threads. + if 'dim' in params.keys(): + dimensions = params['dim'] + else: + dimensions = [grid[i] * threads[i] for i in range(len(grid))] + # launch kernel - with wp.Tape() as tape: - wp.launch_tiled( - kernel_function, - dim=grid, - inputs=warp_args, - block_dim=params["TILE_THREADS"], # We could directly take threads, but in the given example this is a constant - ) + wp.launch( + kernel_function, + dim=dimensions, + inputs=warp_args, + block_dim=threads_per_block + ) \ No newline at end of file diff --git a/examples/generic_python/matmul/cupy_matmul.py b/examples/generic_python/matmul/cupy_matmul.py index acd3703fb..1101b5fae 100644 --- a/examples/generic_python/matmul/cupy_matmul.py +++ b/examples/generic_python/matmul/cupy_matmul.py @@ -1,105 +1,71 @@ import cupy as cp from cupyx import jit import numpy as np -import torch from kernel_tuner import tune_kernel -from pathlib import Path - - -@jit.rawkernel() -def gemm_raw_strided(a, b, c, M, N, K): # with outer loop so that we can have more work per thread - # global thread indices - row = jit.blockIdx.y * jit.blockDim.y + jit.threadIdx.y - col = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x - - # grid stride increments - stride_y = jit.gridDim.y * jit.blockDim.y - stride_x = jit.gridDim.x * jit.blockDim.x - - i = row - while i < M: - j = col - while j < N: - acc = 0.0 - for kk in range(K): - acc += a[i, kk] * b[kk, j] - c[i, j] = acc - j += stride_x - i += stride_y +from examples.generic_python.call_functions import call_cupyx @jit.rawkernel() def gemm(a, b, c, M, N, K): row = jit.blockIdx.y * jit.blockDim.y + jit.threadIdx.y col = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x + + if row < M and col < N: + acc = 0.0 + for kk in range(K): + acc += a[row, kk] * b[kk, col] + c[row, col] = acc - acc = 0.0 - for kk in range(K): - acc += a[row, kk] * b[kk, col] - c[row, col] = acc - - -def run(): - M, K, N = 128, 256, 64 - # random test data - A = cp.random.random((M, K), dtype=cp.float32) - B = cp.random.random((K, N), dtype=cp.float32) - C = cp.zeros((M, N), dtype=cp.float32) +def run(M, N, K): + # float16 matrices on GPU + a = cp.random.random((M, K)).astype(cp.float16) + b = cp.random.random((K, N)).astype(cp.float16) + c = cp.zeros((M, N), dtype=cp.float16) - # launch parameters + # block / grid configuration block = (16, 16) - grid = ((N + block[0] - 1) // block[0], - (M + block[1] - 1) // block[1]) + grid = ((N + block[0] - 1) // block[0], (M + block[1] - 1) // block[1]) # launch kernel - gemm_raw_strided(grid, block, (A, B, C, M, N, K)) + gemm[grid, block](a, b, c, M, N, K) + cp.cuda.Device().synchronize() - # validate - C_ref = A.dot(B) - print("max error:", float(cp.max(cp.abs(C - C_ref)))) + # Correctness verification + c_ref = cp.matmul(a, b) + assert cp.allclose(c, c_ref, rtol=1e-2, atol=1e-1) + print("Succes") -def call_cupyx(kernel_function, args, kwargs, grid, threads): - cupy_args = [] - for arg in args: - if isinstance(arg, torch.Tensor): - cupy_args.append(cp.from_dlpack(arg)) - else: - cupy_args.append(arg) - kernel_function(grid, threads, tuple(cupy_args)) -def tune(): - M, K, N = 128, 256, 64 - - # random test data. Here we had to change cupy to numpy arrays. +def tune(M, N, K): + # random test data. Here we had to use numpy arrays instead of cupy. A = np.random.rand(M, K).astype(np.float16) B = np.random.rand(K, N).astype(np.float16) C = np.zeros((M, N), dtype=np.float16) - args = [A, B, C, M, N, K] size = (N, M) - tune_params = {"block_size_x": [2**i for i in range(10)], "block_size_y": [2**i for i in range(10)]} - restrictions = ["block_size_x == block_size_y"] - source = Path(__file__).resolve() + tune_params = {"block_size_x": [2**i for i in range(11)], "block_size_y": [2**i for i in range(11)]} + restrictions = ["block_size_x * block_size_y <= 1024"] results, env = tune_kernel( kernel_name="gemm", - kernel_source=source, + kernel_source=__file__, problem_size=size, arguments=args, tune_params=tune_params, - answer=[None, None, A.dot(B), None, None], - atol=1e-2, + answer=[None, None, A.dot(B), None, None, None], + atol=1e-1, call_function=call_cupyx, lang="generic_python", - verbose=True, - restrictions=restrictions, + restrictions=restrictions, + verbose=True, ) if __name__ == "__main__": - #tune() - run() \ No newline at end of file + M, N, K = 1024, 1024, 1024 + tune(M, N, K) + \ No newline at end of file diff --git a/examples/generic_python/matmul/cute_matmul.py b/examples/generic_python/matmul/cute_matmul.py index 8654829fa..4b5637c82 100644 --- a/examples/generic_python/matmul/cute_matmul.py +++ b/examples/generic_python/matmul/cute_matmul.py @@ -8,12 +8,9 @@ from cutlass.cute.runtime import from_dlpack from kernel_tuner import tune_kernel -from pathlib import Path from examples.generic_python.call_functions import call_cute -FULL_PATH = Path(__file__).resolve() - -# need export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH to work +# might need export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH to work ## Basic Matmul ================================================================ @@ -69,6 +66,23 @@ def matmul( kernel.launch(grid=grid, block=block) +def run_naive_matmul(M, N, K): + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + c = torch.zeros(M, N, device="cuda", dtype=torch.float16) + c_ref = a @ b + + # Convert to CuTe tensors + a_ = from_dlpack(a, assumed_align=16) + b_ = from_dlpack(b, assumed_align=16) + c_ = from_dlpack(c, assumed_align=16) + + compiled_kernel = cute.compile(matmul, a_, b_, c_) + compiled_kernel(a_, b_, c_) + + assert torch.allclose(c, c_ref, atol=M * 2 **(-11), rtol=1e-2) + print("Succes") + def tune_naive_matmul(M, N, K): a = torch.randn(M, K, device="cuda", dtype=torch.float16) @@ -79,11 +93,12 @@ def tune_naive_matmul(M, N, K): size = M * N answer = [None, None, (a @ b).cpu()] tune_params = dict() - tune_params["block_size_x"] = [8, 16, 32, 64, 128] - tune_params["block_size_y"] = [8, 16, 32, 64, 128] + tune_params["block_size_x"] = [2**i for i in range(1, 10)] + tune_params["block_size_y"] = [2**i for i in range(1, 10)] + restrictions = ["block_size_x * block_size_y <= 1024"] - results, env = tune_kernel("matmul", FULL_PATH, size, args, tune_params, lang="generic_python", - call_function=call_cute, answer=answer, atol=4, verbose=False) + results, env = tune_kernel("matmul", __file__, size, args, tune_params, lang="generic_python", + call_function=call_cute, answer=answer, atol=M * 2 **(-11), restrictions=restrictions, verbose=False) @@ -182,16 +197,17 @@ def __init__(self): self.ab_dtype = cutlass.Float16 self.c_dtype = cutlass.Float16 self.acc_dtype = cutlass.Float32 - tile_m = 128 # Added for KT support - tile_n = 128 - tile_k = 32 - self.cta_tiler = (tile_m, tile_n, tile_k) + self.bM = 128 # extracted from cta_tiler for KT support + self.bN = 128 # extracted from cta_tiler for KT support + self.bK = 32 # extracted from cta_tiler for KT support + self.cta_tiler = (self.bM, self.bN, self.bK) self.num_stages = 3 - self.atom_layout_mnk = (2, 2, 1) # moved from paramter to here for KT support - atom_lay_M, atom_lay_N, atom_lay_K = self.atom_layout_mnk + atom_lay_M = 2 # extracted from atom_layout_mnk for KT support + atom_lay_N = 2 # extracted from atom_layout_mnk for KT support + atom_lay_K = 1 # extracted from atom_layout_mnk for KT support + self.atom_layout_mnk = (atom_lay_M, atom_lay_N, atom_lay_K) # moved from init parameter to here for KT support self.num_threads = atom_lay_M * atom_lay_N * atom_lay_K * 32 - self.bM, self.bN, self.bK = self.cta_tiler self.mma_inst_shape = (16, 8, 16) mmaM, mmaN, mmaK = self.mma_inst_shape @@ -969,10 +985,39 @@ def call_cute_custom(kernel_function, args, kwargs, grid, threads, params): -def tune_optimized(mnkl: Tuple[int, int, int, int]): - - M, N, K, L = mnkl - +def run_optimized(M, N, K, L=1): + a = torch.randn((L, M, K), device="cuda", dtype=torch.float16).permute(1, 2, 0) + b = torch.randn((L, N, K), device="cuda", dtype=torch.float16).permute(1, 2, 0) # b is transposed + c = torch.empty((L, M, N), device="cuda", dtype=torch.float16).permute(1, 2, 0) + + c_ref = torch.matmul( + a.permute(2, 0, 1), + b.permute(2, 1, 0) + ).permute(1, 2, 0) + + def create_cute_tensor(torch_tensor): + cute_tensor = ( + from_dlpack(torch_tensor, assumed_align=16) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic( + mode=1, + stride_order=(2, 0, 1), + divisibility=8, + ) + ) + return cute_tensor + + a_, b_, c_ = create_cute_tensor(a), create_cute_tensor(b), create_cute_tensor(c) + + gemm = TensorOpGemm() + compiled_kernel = cute.compile(gemm, a_, b_, c_) + compiled_kernel(a_, b_, c_) + + assert torch.allclose(c_ref, c, atol= M * 2**(-11), rtol=1e-2) + print("Sucess") + + +def tune_optimized(M, N, K, L=1): a = torch.randn((L, M, K), device="cuda", dtype=torch.float16).permute(1, 2, 0) b = torch.randn((L, N, K), device="cuda", dtype=torch.float16).permute(1, 2, 0) # b is transposed c = torch.empty((L, M, N), device="cuda", dtype=torch.float16).permute(1, 2, 0) @@ -985,24 +1030,34 @@ def tune_optimized(mnkl: Tuple[int, int, int, int]): args = [a, b, c] tune_params = { - "tile_m": [64, 128, 256], - "tile_n": [64, 128, 256], - "tile_k": [16, 32, 64], + "bM":[16, 32, 64, 128, 256], + "bN":[32, 64, 128, 256], + "bK": [16, 32, 64, 128], # must be divisable by 16 + "num_stages": [3, 4, 5], # restricted to >= 3 by the kernel + "atom_lay_M": [1, 2, 4], + "atom_lay_N": [1, 2], } - constraints = ["3 * (tile_m * tile_k + tile_n * tile_k) * 2 <= 100 * 1024"] # SMEM constraint + restrictions = [ + "bM % (atom_lay_M * 16) == 0", "bN % (atom_lay_N * 8) == 0", # layout + "atom_lay_M * atom_lay_N * 32 <= 1024", # number of threads + "num_stages * (bK * (bM + bN)) * 2 <= 49152", # SMEM + "bM >= atom_lay_M * 32", # ensure each atom sees at least 2 MMA tiles per dimension + "bN >= atom_lay_N * 16", # ensure each atom sees at least 2 MMA tiles per dimension + ] - results, env = tune_kernel("TensorOpGemm", FULL_PATH, M * N, args, tune_params, verbose=True, restrictions=constraints, - lang="generic_python", call_function=call_cute_custom, answer=[None, None, c_ref.cpu()], atol=4) + + results, env = tune_kernel("TensorOpGemm", __file__, M * N, args, tune_params, verbose=True, restrictions=restrictions, #strategy="bayes_opt", + lang="generic_python", call_function=call_cute_custom, answer=[None, None, c_ref.cpu()], atol=M * 2**(-10)) if __name__ == "__main__": - #m, n, k = 4096, 4096, 4096 + m, n, k = 4096, 4096, 4096 + #run_naive_matmul(m, n, k) #tune_naive_matmul(m, n, k) - mnkl = (4096, 4096, 4096, 1) - tune_optimized(mnkl) - - + # mnkl = (4096, 4096, 4096, 1) + tune_optimized(m, n, k) + #run_optimized(m, n, k) diff --git a/examples/generic_python/matmul/numba_matmul.py b/examples/generic_python/matmul/numba_matmul.py index af6cce347..d8e23d6fe 100644 --- a/examples/generic_python/matmul/numba_matmul.py +++ b/examples/generic_python/matmul/numba_matmul.py @@ -1,16 +1,13 @@ -import torch import numpy as np from numba import cuda, float32 + from kernel_tuner import tune_kernel -from pathlib import Path +from examples.generic_python.call_functions import call_numba -FULL_PATH = Path(__file__).resolve() -# Example taken from https://nvidia.github.io/numba-cuda/user/examples.html#matrix-multiplication +# Source: https://nvidia.github.io/numba-cuda/user/examples.html#matrix-multiplication @cuda.jit(cache=True) def matmul(A, B, C): - """Perform square matrix multiplication of C = A * B - """ i, j = cuda.grid(2) if i < C.shape[0] and j < C.shape[1]: tmp = 0. @@ -19,48 +16,7 @@ def matmul(A, B, C): C[i, j] = tmp -# Example taken from https://nvidia.github.io/numba-cuda/user/examples.html#matrix-multiplication -# Changed data type from float32 to float16 -@cuda.jit(cache=True, fastmath=True) -def fast_matmul(A, B, C): - # Define an array in the shared memory - # The size and type of the arrays must be known at compile time - TPB = 16 # TEMP voor overhead testing - sA = cuda.shared.array(shape=(TPB, TPB), dtype=np.float16) - sB = cuda.shared.array(shape=(TPB, TPB), dtype=np.float16) - - x, y = cuda.grid(2) - - tx = cuda.threadIdx.x - ty = cuda.threadIdx.y - bpg = cuda.gridDim.x # blocks per grid - - if x >= C.shape[0] and y >= C.shape[1]: - # Quit if (x, y) is outside of valid C boundary - return - - # Each thread computes one element in the result matrix. - # The dot product is chunked into dot products of TPB-long vectors. - tmp = 0. - for i in range(bpg): - # Preload data into shared memory - sA[tx, ty] = A[x, ty + i * TPB] - sB[tx, ty] = B[tx + i * TPB, y] - - # Wait until all threads finish preloading - cuda.syncthreads() - - # Computes partial product on the shared memory - for j in range(TPB): - tmp += sA[tx, j] * sB[j, ty] - - # Wait until all threads finish computing - cuda.syncthreads() - - C[x, y] = tmp - - -# Translated to numba-cuda from https://github.com/cupy/cupy/blob/main/examples/gemm/sgemm.cu +# Translated to Numba-CUDA from https://github.com/cupy/cupy/blob/main/examples/gemm/sgemm.cu @cuda.jit(cache=True) def optimized_matmul(M, N, K, A, B, C): DIM_X = 16 @@ -68,8 +24,8 @@ def optimized_matmul(M, N, K, A, B, C): BLK_M = 64 BLK_N = 64 BLK_K = 16 - THR_M = 4 - THR_N = 4 + THR_M = 4 # Should be equal to BLK_M / DIM_X + THR_N = 4 # Should be equal to BLK_N / DIM_Y # thread indices idx = cuda.threadIdx.x @@ -144,12 +100,12 @@ def optimized_matmul(M, N, K, A, B, C): C[row, col] = np.float16(rC[n][m]) -def run_matmul(M, N, K): - +def run_basic(M, N, K): # create numpy arrays A = np.random.rand(M, K).astype(np.float16) B = np.random.rand(K, N).astype(np.float16) C = np.zeros((M, N), dtype=np.float16) + C_ref = A @ B # copy to GPU A_d = cuda.to_device(A) @@ -166,29 +122,52 @@ def run_matmul(M, N, K): ) # launch kernel - fast_matmul[blocks, threads](A_d, B_d, C_d) + matmul[blocks, threads](A_d, B_d, C_d) + cuda.synchronize() # copy result back C_result = C_d.copy_to_host() # check - np.testing.assert_allclose(C_result, A @ B, rtol=1e-2) + np.testing.assert_allclose(C_result, C_ref, rtol=1e-2, atol=M * 2**(-11)) + print("Succes") - print("Correct!") +def run_optimized(M, N, K): + # inputs + A = np.random.rand(M, K).astype(np.float16) + B = np.random.rand(K, N).astype(np.float16) + C = np.zeros((M, N), dtype=np.float16) + C_ref = A @ B + + # move to GPU + dA = cuda.to_device(A) + dB = cuda.to_device(B) + dC = cuda.to_device(C) -def call_numba(kernel_function, args, kwargs, grid, threads): - numba_args = [] - for arg in args: - if isinstance(arg, torch.Tensor): - numba_args.append(cuda.as_cuda_array(arg)) - else: - numba_args.append(arg) - kernel_function[grid, threads](*args, **kwargs) + threads_per_block = (16, 16) + blocks_per_grid = ( + (M + 63) // 64, # BLK_M = 64 + (N + 63) // 64, # BLK_N = 64 + ) + # launch + for i in range(10000): + optimized_matmul[blocks_per_grid, threads_per_block](M, N, K, dA, dB, dC) + cuda.synchronize() -def tune(M, N, K): - # create numpy arrays + # copy result back + C_result = dC.copy_to_host() + + # check + np.testing.assert_allclose(C_result, C_ref, rtol=1e-2, atol=M * 2**(-11)) + print("Succes") + + + + +def tune_basic(M, N, K): + # create inputs as normal, but do not copy to device A = np.random.rand(M, K).astype(np.float16) B = np.random.rand(K, N).astype(np.float16) C = np.zeros((M, N), dtype=np.float16) @@ -196,80 +175,87 @@ def tune(M, N, K): size = (M, N) args = [A, B, C] tune_params = dict() - tune_params["block_size_x"] = [4, 8, 16, 32, 64, 128, 256] - tune_params["block_size_y"] = [4, 8, 16, 32, 64, 128, 256] - tune_params["TPB"] = [4, 8, 16, 32, 64, 128, 256] - restrictions = ["block_size_x == block_size_y", "block_size_x == TPB"] + tune_params["block_size_x"] = [2**i for i in range(1, 10)] + tune_params["block_size_y"] = [2**i for i in range(1, 10)] + + restrictions = ["block_size_x * block_size_y <= 1024"] answer = [None, None, A @ B] atol = M * 2**(-11) results, env = tune_kernel( - kernel_name="fast_matmul", - kernel_source=FULL_PATH, + kernel_name="matmul", + kernel_source=__file__, problem_size=size, arguments=args, tune_params=tune_params, lang="generic_python", answer=answer, atol=atol, + restrictions=restrictions, call_function=call_numba, - restrictions=restrictions ) -if __name__ == "__main__": +def tune_optimized(M, N, K): + # create inputs as normal, but do not copy to device + A = np.random.rand(M, K).astype(np.float16) + B = np.random.rand(K, N).astype(np.float16) + C = np.zeros((M, N), dtype=np.float16) + + size = (M, N) + args = [M, N, K, A, B, C] + tune_params = { + "DIM_X": [2**i for i in range(1, 10)], + "DIM_Y": [2**i for i in range(1, 10)], + "BLK_M": [2**i for i in range(1, 10)], + "BLK_N": [2**i for i in range(1, 10)], + "BLK_K": [2**i for i in range(1, 10)], + "THR_M": [2**i for i in range(1, 9)], # Restricted to BLK_M / DIM_X + "THR_N": [2**i for i in range(1, 9)], # Restricted to BLK_N / DIM_Y + } - #run_matmul(128, 96, 64) - #tune(1024, 1024, 1024) + answer = [None, None, None, None, None, A @ B] + atol = M * 2**(-11) + restrictions = [ + "BLK_M % DIM_X == 0", + "BLK_N % DIM_Y == 0", + "THR_M == BLK_M / DIM_X", + "THR_N == BLK_N / DIM_Y", + "DIM_X * DIM_Y <= 1024", + "DIM_x * DIM_Y >= 32", + ] - M, N, K = 1024, 1024, 1024 + results, env = tune_kernel( + kernel_name="optimized_matmul", + kernel_source=__file__, + problem_size=size, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=answer, + atol=atol, + restrictions = restrictions, + block_size_names = ["DIM_X", "DIM_Y"], + grid_div_x = ["BLK_M"], + grid_div_y = ["BLK_N"], + strategy = "bayes_opt", + strategy_options = {"max_fevals": 100}, + call_function=call_numba, + ) - # random FP16 inputs - A = np.random.rand(M, K).astype(np.float16) - B = np.random.rand(K, N).astype(np.float16) - # output - C = np.zeros((M, N), dtype=np.float16) - # copy to device - dA = cuda.to_device(A) - dB = cuda.to_device(B) - dC = cuda.to_device(C) +if __name__ == "__main__": + M, N, K = 1024, 1024, 1024 - # launch config - threads_per_block = (16, 16) + run_basic(M, N, K) + #tune_basic(M, N, K) - blocks_per_grid_x = (M + 64 - 1) // 64 - blocks_per_grid_y = (N + 64 - 1) // 64 - blocks_per_grid = (blocks_per_grid_x, blocks_per_grid_y) + #run_optimized(M, N, K) + #tune_optimized(M, N, K) - # run kernel - optimized_matmul[blocks_per_grid, threads_per_block](M, N, K, dA, dB, dC) - # copy result back - C_result = dC.copy_to_host() - # reference (FP32 accumulate like your kernel) - C_ref = (A.astype(np.float32) @ B.astype(np.float32)).astype(np.float16) - - # check error - max_error = np.max(np.abs(C_result - C_ref)) - print("Max error:", max_error) - - # tolerance (FP16 is noisy) - if max_error < 4: - print("✅ Looks correct") - else: - print("❌ Something is off") - print("Expected: ", C_ref, "\nGot: ", C_result) - - ''' - 128: 16, 8 - 256: 8, 16 - 512: 8, 16 - 1024: 16, 8 - 4096: 16, 8 - 8192: 16, 8 - ''' \ No newline at end of file + \ No newline at end of file diff --git a/examples/generic_python/matmul/taichi_matmul.py b/examples/generic_python/matmul/taichi_matmul.py index b8a1c9388..bb3c53a81 100644 --- a/examples/generic_python/matmul/taichi_matmul.py +++ b/examples/generic_python/matmul/taichi_matmul.py @@ -1,21 +1,17 @@ -import torch +import numpy as np import taichi as ti from kernel_tuner import tune_kernel - - +from examples.generic_python.call_functions import call_taichi ti.init(arch=ti.gpu) - -# TODO make sure this is zero-copy +BLOCK_DIM = 512 @ti.kernel def matmul(A: ti.types.ndarray(dtype=ti.f16, ndim=2), B: ti.types.ndarray(dtype=ti.f16, ndim=2), C: ti.types.ndarray(dtype=ti.f16, ndim=2)): - BLOCK_DIM = 16 - ti.loop_config(block_dim=BLOCK_DIM) - K_dim = A.shape[1] + ti.loop_config(block_dim=BLOCK_DIM) for i, j in C: sum = 0.0 for k in range(K_dim): @@ -23,19 +19,28 @@ def matmul(A: ti.types.ndarray(dtype=ti.f16, ndim=2), B: ti.types.ndarray(dtype= C[i, j] = ti.cast(sum, ti.f16) -def call_taichi(kernel_function, args, kwargs): - kernel_function(*args, **kwargs) +def run(M, N, K): + A = np.random.rand(M, K).astype(np.float16) + B = np.random.rand(K, N).astype(np.float16) + C = np.zeros((M, N), dtype=np.float16) + C_ref = A @ B + + matmul(A, B, C) + + np.testing.assert_allclose(C, C_ref, rtol=1e-2, atol=M * 2**(-11)) + print("Succes") + def tune(M, N, K): - torch_A = torch.rand((N, K), device='cuda', dtype=torch.float16) - torch_B = torch.rand((K, M), device='cuda', dtype=torch.float16) - torch_C = torch.empty((N, M), device='cuda', dtype=torch.float16) + A = np.random.rand(M, K).astype(np.float16) + B = np.random.rand(K, N).astype(np.float16) + C = np.zeros((M, N), dtype=np.float16) size = M * N - args = [torch_A, torch_B, torch_C] - tune_params = {"BLOCK_DIM": {4, 8, 16, 32, 64, 128, 256, 512, 1024}} + args = [A, B, C] + tune_params = {"BLOCK_DIM": [2**i for i in range(5, 11)]} - answer = [None, None, (torch_A @ torch_B).cpu()] + answer = [None, None, A @ B] results, env = tune_kernel( kernel_name="matmul", @@ -46,11 +51,13 @@ def tune(M, N, K): answer=answer, lang="generic_python", call_function=call_taichi, - atol=1e-1, + atol=M * 2**(-11), + block_size_names = ["BLOCK_DIM"], ) if __name__ == "__main__": - tune(128, 128, 128) + run(1024, 1024, 1024) + tune(1024, 1024, 1024) diff --git a/examples/generic_python/matmul/tilelang_matmul.py b/examples/generic_python/matmul/tilelang_matmul.py index 2ea510d91..1067d9ff5 100644 --- a/examples/generic_python/matmul/tilelang_matmul.py +++ b/examples/generic_python/matmul/tilelang_matmul.py @@ -4,7 +4,11 @@ import itertools +from kernel_tuner import tune_kernel +from examples.generic_python.call_functions import call_tilelang +# https://github.com/tile-ai/tilelang/tree/main/examples/gemm +# num_threads and num_stages added as variables to enable tuning. @tilelang.jit def matmul_basic(M:int, N:int, K:int, block_M:int, block_N:int, block_K:int, dtype:str="float16", accum_dtype:str="float32"): @T.prim_func @@ -13,7 +17,9 @@ def gemm( B: T.Tensor((K, N), dtype), C: T.Tensor((M, N), dtype), ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + num_threads = 128 + num_stages = 3 + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): # We do use shared memory, even though this is a basic kernel. However, you don't # really get around this because T.gemm can not handle global memory directly. A_shared = T.alloc_shared((block_M, block_K), dtype) @@ -23,7 +29,7 @@ def gemm( # We do use a pipelining optimization here, because this is 'the basic way' # of writing for loops in TileLang. - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[k * block_K, bx * block_N], B_shared) T.gemm(A_shared, B_shared, C_local) @@ -35,41 +41,50 @@ def gemm( -# https://github.com/tile-ai/tilelang/blob/main/examples/gemm/example_gemm_autotune.py -# changed gemm_autotune to gemm -# originally, B was transposed. I changed this so the kernel input is the same as in other languages. +# https://github.com/tile-ai/tilelang/tree/main/examples/gemm +# num_threads and num_stages were added as metaparameters to make it possible to compare the built-in tuner +# with Kernel Tuner. +# Removed annotated memory layout for tiles A and B @tilelang.jit -def matmul_opt(M, N, K, dummy, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): +def matmul_opt(M, N, K, block_M, block_N, block_K, num_threads=128, num_stages=3, dummy=0, dtype=T.float16, accum_dtype=T.float): @T.prim_func - def gemm( + def main( A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype), C: T.Tensor((M, N), dtype), ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + # Allocate shared and local fragments A_shared = T.alloc_shared((block_M, block_K), dtype) B_shared = T.alloc_shared((block_K, block_N), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), dtype) - T.use_swizzle(panel_size=10, enable=enable_rasteration) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable swizzle-based rasterization for better L2 locality + T.use_swizzle(panel_size=10, enable=True) + + # Clear the local accumulation buffer T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[k * block_K, bx * block_N], B_shared) - T.gemm( - A_shared, - B_shared, - C_local, - transpose_B=False, - ) - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - return gemm + # Pipelined iteration over K dimension + for idx in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + # Copy tile of A + T.copy(A[by * block_M, idx * block_K], A_shared) + # Parallel copy tile of B + for ko, j in T.Parallel(block_K, block_N): + B_shared[ko, j] = B[idx * block_K + ko, bx * block_N + j] + # Perform local GEMM on the shared-memory tiles + T.gemm(A_shared, B_shared, C_local) + # Copy the result tile back + T.copy(C_local, C[by * block_M, bx * block_N]) + return main + + + +# TODO def get_configs(kt=False): """ Generate a list of kernel tuning configuration dictionaries for a tiled matrix-multiply. @@ -130,6 +145,8 @@ def get_configs(kt=False): return configs + +# TODO # https://github.com/tile-ai/tilelang/blob/main/examples/gemm/example_gemm_autotune.py # changed gemm_autotune to gemm # originally, B was transposed. I changed this so the kernel input is the same as in other languages. @@ -137,35 +154,35 @@ def get_configs(kt=False): @tilelang.jit def matmul_opt_autotune(M, N, K, dummy, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): @T.prim_func - def gemm_autotune( + def gemm( A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), + B: T.Tensor((N, K), dtype), C: T.Tensor((M, N), dtype), ): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) + B_shared = T.alloc_shared((block_N, block_K), dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_shared = T.alloc_shared((block_M, block_N), dtype) T.use_swizzle(panel_size=10, enable=enable_rasteration) T.clear(C_local) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[k * block_K, bx * block_N], B_shared) + T.copy(B[bx * block_N, k * block_K], B_shared) T.gemm( A_shared, B_shared, C_local, - transpose_B=False, + transpose_B=True, ) T.copy(C_local, C_shared) T.copy(C_shared, C[by * block_M, bx * block_N]) - return gemm_autotune - + return gemm +# DEZE mag denk ik weg def main(): kernel = matmul_opt_autotune(1024, 1024, 1024) import torch @@ -184,5 +201,138 @@ def main(): +def run_basic(M, N, K): + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(K, N, device="cuda", dtype=torch.float16) + C = torch.zeros(M, N, device="cuda", dtype=torch.float16) + C_ref = A @ B + + kernel = matmul_basic( + M, N, K, + block_M=64, + block_N=64, + block_K=32, + ) + + kernel(A, B, C) + + assert torch.allclose(C, C_ref, atol=M * 2 **(-11), rtol=1e-2) + print("Succes") + + +def run_opt(M, N, K): + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(K, N, device="cuda", dtype=torch.float16) + C = torch.zeros(M, N, device="cuda", dtype=torch.float16) + C_ref = A @ B + + kernel = matmul_opt( + M, N, K, + dummy=0, + block_M=64, + block_N=64, + block_K=32, + num_stages=3, + num_threads=128, + ) + + kernel(A, B, C) + + assert torch.allclose(C, C_ref, atol=M * 2 **(-11), rtol=1e-2) + print("Succes") + + +def tune_basic(M, N, K): + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(K, N, device="cuda", dtype=torch.float16) + C = torch.empty((M, N), device='cuda', dtype=torch.float16) + C_ref = A @ B + + size = (M, N) + + + args = [A, B, C] + + tune_params = dict() + tune_params["block_M"] = [2**i for i in range(4, 10)] + tune_params["block_N"] = [2**i for i in range(4, 10)] + tune_params["block_K"] = [2**i for i in range(4, 10)] + tune_params["num_threads"] = [2**i for i in range(5, 11)] + tune_params["num_stages"] = [1, 2, 3, 4, 5] + tune_params["M"] = [M] + tune_params["N"] = [N] + tune_params["K"] = [K] + + restrictions = [ + "2 * num_stages * block_K * (block_M + block_N) <= 49152" + ] + + results, env = tune_kernel( + kernel_name="matmul_basic", + kernel_source=__file__, + problem_size=size, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=[None, None, C_ref.cpu()], + atol=M * 2**(-11), + strategy = "bayes_opt", + strategy_options = {"max_fevals": 100}, + call_function=call_tilelang, + verbose=True, + restrictions=restrictions, + ) + + +def tune_opt(M, N, K): + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(K, N, device="cuda", dtype=torch.float16) + C = torch.empty((M, N), device='cuda', dtype=torch.float16) + C_ref = A @ B + + size = (M, N) + + args = [A, B, C] + + tune_params = dict() + tune_params["block_M"] = [2**i for i in range(4, 10)] + tune_params["block_N"] = [2**i for i in range(4, 10)] + tune_params["block_K"] = [2**i for i in range(4, 10)] + tune_params["num_threads"] = [2**i for i in range(5, 11)] + tune_params["num_stages"] = [1, 2, 3, 4, 5] + tune_params["M"] = [M] + tune_params["N"] = [N] + tune_params["K"] = [K] + + restrictions = [ + "2 * num_stages * block_K * (block_M + block_N) <= 49152" + ] + + results, env = tune_kernel( + kernel_name="matmul_opt", + kernel_source=__file__, + problem_size=size, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=[None, None, C_ref.cpu()], + atol=M * 2**(-11), + strategy = "bayes_opt", + strategy_options = {"max_fevals": 100}, + call_function=call_tilelang, + verbose=True, + restrictions=restrictions, + ) + + + + if __name__ == "__main__": - main() + M, N, K = 4096, 4096, 4096 + #run_basic(M, N, K) + #run_opt(M, N, K) + + #tune_basic(M, N, K) + tune_opt(M, N, K) + + diff --git a/examples/generic_python/matmul/tilus_matmul.py b/examples/generic_python/matmul/tilus_matmul.py index d893b9490..38f6523f9 100644 --- a/examples/generic_python/matmul/tilus_matmul.py +++ b/examples/generic_python/matmul/tilus_matmul.py @@ -1,7 +1,12 @@ +import torch + import tilus from tilus import float16, float32, int32 from tilus.utils import cdiv +from kernel_tuner import tune_kernel +from examples.generic_python.call_functions import call_tilus + # This kernel is copied from the Tilus project: # https://github.com/NVIDIA/tilus/blob/main/examples/matmul/matmul_v0.py @@ -30,7 +35,8 @@ def __call__( cdiv(m_size, self.block_m), # the x dimension size of the grid cdiv(n_size, self.block_n), # the y dimension size of the grid ] - self.attrs.warps = 1 # the number of warps per thread block, must be a compile-time known integer + num_warps = 1 # added for tuning + self.attrs.warps = num_warps # the number of warps per thread block, must be a compile-time known integer # define two int32 variables to store the offsets of the m and n dimensions for the current thread block. offset_m: int32 = self.block_m * self.blockIdx.x @@ -70,30 +76,22 @@ def __call__( # This kernel is copied from the Tilus project: -# https://github.com/NVIDIA/tilus/blob/main/examples/matmul/matmul_+v5.py +# https://nvidia.github.io/tilus/stable/tutorials/matmul-ampere/matmul/matmul_v4.html # -# Original example: matmul_v5.py +# Original example: matmul_v4.py # Copyright (c) the Tilus authors # # Modifications in file: # - Removed auto-tuning decorators +# - Added default values (None) for the __init__ parameters class MatmulOpt(tilus.Script): - def __init__( - self, - num_warps=None, - block_m=None, - block_n=None, - block_k=None, - num_stages=None, - split_k_factor=None, - ): + def __init__(self, num_warps=None, block_m=None, block_n=None, block_k=None, num_stages=None): super().__init__() self.block_m = block_m self.block_n = block_n self.block_k = block_k self.num_warps = num_warps self.num_stages = num_stages - self.split_k_factor = split_k_factor def __call__( self, @@ -104,20 +102,9 @@ def __call__( b_ptr: ~float16, c_ptr: ~float16, ): - self.attrs.blocks = [ - cdiv(m_size, self.block_m), - cdiv(n_size, self.block_n), - self.split_k_factor, - ] + self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)] self.attrs.warps = self.num_warps - # the k_size for each thread block - block_k_size = ( - cdiv(cdiv(k_size, self.split_k_factor), self.block_k) * self.block_k - ) - start_offset_k = self.blockIdx.z * block_k_size - end_offset_k = min(start_offset_k + block_k_size, k_size) - block_m, block_n, block_k = self.block_m, self.block_n, self.block_k offset_m: int32 = block_m * self.blockIdx.x offset_n: int32 = block_n * self.blockIdx.y @@ -129,7 +116,7 @@ def __call__( acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) for stage in range(self.num_stages - 1): - offset_k = start_offset_k + stage * self.block_k + offset_k = stage * self.block_k self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k]) self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n]) self.copy_async_commit_group() @@ -139,9 +126,7 @@ def __call__( current_stage: int32 = 0 preload_stage: int32 = self.num_stages - 1 - for offset_k in self.range( - start_offset_k, end_offset_k, block_k, unroll=self.num_stages - ): + for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages): # computation for current tile a = self.load_shared(sa[current_stage]) b = self.load_shared(sb[current_stage]) @@ -149,17 +134,16 @@ def __call__( # preload the next tile of A and B into shared memory preload_offset_k = offset_k + (self.num_stages - 1) * block_k - if preload_offset_k < end_offset_k: - self.copy_async( - src=ga, - dst=sa[preload_stage], - offsets=[offset_m, preload_offset_k], - ) - self.copy_async( - src=gb, - dst=sb[preload_stage], - offsets=[preload_offset_k, offset_n], - ) + self.copy_async( + src=ga, + dst=sa[preload_stage], + offsets=[offset_m, preload_offset_k], + ) + self.copy_async( + src=gb, + dst=sb[preload_stage], + offsets=[preload_offset_k, offset_n], + ) self.copy_async_commit_group() # update the stage @@ -168,41 +152,124 @@ def __call__( self.copy_async_wait_group(n=self.num_stages - 2) self.sync() - # free the shared memory tensors for A and B self.free_shared(sa) self.free_shared(sb) - # cast the accumulator to float16 and change the register tensor's layout - sc = self.shared_tensor(dtype=float16, shape=[block_m, block_n]) casted_acc = self.cast(acc, dtype=float16) - self.store_shared(sc, casted_acc) - self.sync() - rc = self.load_shared(sc) - self.free_shared(sc) - - m_blocks, n_blocks = cdiv(m_size, block_m), cdiv(n_size, block_n) gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) - if self.split_k_factor == 0: - self.store_global(gc, rc, offsets=[offset_m, offset_n]) - else: - semaphores = self.global_tensor( - dtype=int32, shape=[m_blocks, n_blocks], requires_clean=True - ) - semaphore: ~int32 = ~semaphores[self.blockIdx.x, self.blockIdx.y] - - # load and accumulate the partial result in global memory - if self.blockIdx.z > 0: - self.lock_semaphore(semaphore, value=self.blockIdx.z) - partial_rc = self.load_global( - gc, offsets=[offset_m, offset_n], shape=[block_m, block_n] - ) - self.add(rc, partial_rc, out=rc) - - # store the result to global memory and release the semaphore - self.store_global(gc, rc, offsets=[offset_m, offset_n]) - - # release the semaphore - self.sync() # we need to make sure the previous store_global is finished - self.release_semaphore( - semaphore, value=(self.blockIdx.z + 1) % self.split_k_factor - ) \ No newline at end of file + self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) + + + +def run_basic(M, N, K): + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(K, N, device="cuda", dtype=torch.float16) + C = torch.empty((M, N), device='cuda', dtype=torch.float16) + C_ref = A @ B + + matmul = MatmulBasic() + torch.cuda.synchronize() + matmul(M, N, K, A, B, C) + torch.cuda.synchronize() + + torch.testing.assert_close(C_ref, C, atol=M * 2**(-11), rtol=1e-2) + print("Succes") + + +def run_optmized(M, N, K): + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(K, N, device="cuda", dtype=torch.float16) + C = torch.empty((M, N), device='cuda', dtype=torch.float16) + C_ref = A @ B + + matmul = MatmulOpt(num_warps=4, block_m=64, block_n=64, block_k=16, num_stages=3) + torch.cuda.synchronize() + + matmul(M, N, K, A, B, C) + torch.cuda.synchronize() + + torch.testing.assert_close(C_ref, C, atol=M * 2**(-11), rtol=1e-2) + print("Succes") + + +def tune_basic(M, N, K): + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(K, N, device="cuda", dtype=torch.float16) + C = torch.empty((M, N), device='cuda', dtype=torch.float16) + C_ref = A @ B + + size = (M, N) + + args = [M, N, K, A, B, C] + + tune_params = dict() + tune_params["block_m"] = [16, 32, 64, 128] + tune_params["block_n"] = [16, 32, 64, 128] + tune_params["block_k"] = [16, 32, 64, 128] + tune_params["num_warps"] = [2, 4, 8, 16] + + + results, env = tune_kernel( + kernel_name="MatmulBasic", + kernel_source=__file__, + problem_size=size, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=[None, None, None, None, None, C_ref.cpu()], + atol=M * 2**(-11), + strategy = "bayes_opt", + strategy_options = {"max_fevals": 200}, + call_function=call_tilus, + ) + + +def tune_opt(M, N, K): + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(K, N, device="cuda", dtype=torch.float16) + C = torch.empty((M, N), device='cuda', dtype=torch.float16) + C_ref = A @ B + + size = (M, N) + + args = [M, N, K, A, B, C] + + tune_params = dict() + tune_params["block_m"] = [16, 32, 64, 128, 256] + tune_params["block_n"] = [16, 32, 64, 128, 256] + tune_params["block_k"] = [16, 32, 64, 128] + tune_params["num_warps"] = [2, 4, 8, 16] + tune_params["num_stages"] = [2, 3, 4, 5, 6] + + # Shared memory restriction + restrictions = ["2 * num_stages * block_k * (block_m + block_n) <= 49152"] + + results, env = tune_kernel( + kernel_name="MatmulOpt", + kernel_source=__file__, + problem_size=size, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=[None, None, None, None, None, C_ref.cpu()], + atol=M * 2**(-11), + restrictions=restrictions, + strategy = "bayes_opt", + strategy_options = {"max_fevals": 200}, + call_function=call_tilus, + ) + + + + + +if __name__ == "__main__": + M, N, K = 4096, 4096, 4096 + #M, N, K = 8192, 8192, 8192 + #run_basic(M, N, K) + + run_optmized(M, N, K) + + #tune_basic(M, N, K) + + #tune_opt(M, N, K) \ No newline at end of file diff --git a/examples/generic_python/matmul/triton_matmul.py b/examples/generic_python/matmul/triton_matmul.py index b127b0331..dc81c5bd0 100644 --- a/examples/generic_python/matmul/triton_matmul.py +++ b/examples/generic_python/matmul/triton_matmul.py @@ -2,19 +2,17 @@ import triton import triton.language as tl +from kernel_tuner import tune_kernel +from examples.generic_python.call_functions import call_triton @triton.jit def matmul_basic( - # Pointers - a_ptr, b_ptr, c_ptr, - # Matrix sizes - M, N, K, - # Strides - stride_am, stride_ak, + a_ptr, b_ptr, c_ptr, # Pointers + M, N, K, # Matrix sizes + stride_am, stride_ak, # Strides stride_bk, stride_bn, stride_cm, stride_cn, - # Tile sizes - BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, # Tile sizes BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): @@ -41,7 +39,7 @@ def matmul_basic( mask=(offs_am[:, None] < M) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0, ) - + b = tl.load( b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_bn[None, :] < N), @@ -160,3 +158,153 @@ def matmul_opt( tl.store(c_ptrs, c, mask=c_mask) + +def run_basic(M, N, K): + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(K, N, device="cuda", dtype=torch.float16) + C = torch.empty((M, N), device='cuda', dtype=torch.float16) + C_ref = A @ B + + BLOCK_SIZE_M = 64 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = 32 + + grid = ( + triton.cdiv(M, BLOCK_SIZE_M), + triton.cdiv(N, BLOCK_SIZE_N), + ) + + matmul_basic[grid]( + A, B, C, + M, N, K, + A.stride(0), A.stride(1), + B.stride(0), B.stride(1), + C.stride(0), C.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + ) + + assert torch.allclose(C, C_ref, rtol=1e-2, atol= M * 2**(-11)) + + print("Passed") + + +def run_opt(M, N, K): + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(K, N, device="cuda", dtype=torch.float16) + C = torch.empty((M, N), device='cuda', dtype=torch.float16) + C_ref = A @ B + + BLOCK_SIZE_M = 64 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = 32 + GROUP_SIZE_M = 32 + + grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),) + + matmul_opt[grid]( + A, B, C, + M, N, K, + A.stride(0), A.stride(1), + B.stride(0), B.stride(1), + C.stride(0), C.stride(1), + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + ) + + assert torch.allclose(C, C_ref, rtol=1e-2, atol= M * 2**(-11)) + + print("Passed") + + +def tune_basic(M, N, K): + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(K, N, device="cuda", dtype=torch.float16) + C = torch.empty((M, N), device='cuda', dtype=torch.float16) + C_ref = A @ B + + size = (M, N) + + args = [A, B, C, + M, N, K, + A.stride(0), A.stride(1), + B.stride(0), B.stride(1), + C.stride(0), C.stride(1), + ] + + tune_params = dict() + tune_params["BLOCK_SIZE_M"] = [2**i for i in range(1, 10)] + tune_params["BLOCK_SIZE_N"] = [2**i for i in range(1, 10)] + tune_params["BLOCK_SIZE_K"] = [2**i for i in range(4, 10)] # tl.dot requires K >= 16 + tune_params["num_warps"] = [1, 2, 4, 8, 16, 32] + tune_params["num_stages"] = [1, 2, 3, 4, 5] + + results, env = tune_kernel( + kernel_name="matmul_basic", + kernel_source=__file__, + problem_size=size, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=[None, None, C_ref.cpu(), None, None, None, None, None, None, None, None, None], + atol=M * 2**(-11), + block_size_names = ["BLOCK_SIZE_M", "BLOCK_SIZE_M"], + strategy = "bayes_opt", + strategy_options = {"max_fevals": 100}, + call_function=call_triton, + ) + + +def tune_opt(M, N, K): + A = torch.randn(M, K, device="cuda", dtype=torch.float16) + B = torch.randn(K, N, device="cuda", dtype=torch.float16) + C = torch.empty((M, N), device='cuda', dtype=torch.float16) + C_ref = A @ B + + size = M * N + + args = [A, B, C, + M, N, K, + A.stride(0), A.stride(1), + B.stride(0), B.stride(1), + C.stride(0), C.stride(1), + ] + + tune_params = dict() + tune_params["BLOCK_SIZE_M"] = [2**i for i in range(1, 10)] + tune_params["BLOCK_SIZE_N"] = [2**i for i in range(1, 10)] + tune_params["BLOCK_SIZE_K"] = [2**i for i in range(4, 10)] # tl.dot requires K >= 16 + tune_params["GROUP_SIZE_M"] = [1, 2, 4, 8, 16, 32, 64] + tune_params["num_warps"] = [1, 2, 4, 8, 16, 32] + tune_params["num_stages"] = [1, 2, 3, 4, 5] + + results, env = tune_kernel( + kernel_name="matmul_opt", + kernel_source=__file__, + problem_size=size, + arguments=args, + tune_params=tune_params, + lang="generic_python", + answer=[None, None, C_ref.cpu(), None, None, None, None, None, None, None, None, None], + atol=M * 2**(-11), + block_size_names = ["BLOCK_SIZE_M", "BLOCK_SIZE_M"], + grid_div_x = ["BLOCK_SIZE_M", "BLOCK_SIZE_N"], + strategy = "bayes_opt", + strategy_options = {"max_fevals": 100}, + call_function=call_triton, + ) + + + +if __name__ == "__main__": + M, N, K = 4096, 4096, 4096 + #M, N, K = 8192, 8192, 8192 + #run_basic(M, N, K) + #run_opt(M, N, K) + + + #tune_basic(M, N, K) + #tune_opt(M, N, K) diff --git a/examples/generic_python/matmul/warp_matmul.py b/examples/generic_python/matmul/warp_matmul.py index 78a90ce65..388ae3a1b 100644 --- a/examples/generic_python/matmul/warp_matmul.py +++ b/examples/generic_python/matmul/warp_matmul.py @@ -3,33 +3,54 @@ import warp as wp from kernel_tuner import tune_kernel -from pathlib import Path +from examples.generic_python.call_functions import call_warp wp.init() +wp.config.enable_backward = False + + +@wp.kernel() +def gemm( + A: wp.array2d(dtype=wp.float16), B: wp.array2d(dtype=wp.float16), C: wp.array2d(dtype=wp.float16) +): + i, j = wp.tid() + M = A.shape[0] + N = B.shape[1] + K = A.shape[1] + + if i >= M or j >= N: + return + + # compute dot product + sum = wp.float32(0.0) + for k in range(K): + sum += wp.float32(A[i, k]) * wp.float32(B[k, j]) + + # write result + C[i, j] = wp.float16(sum) -FULL_PATH = Path(__file__).resolve() # tile size -TILE_M = wp.constant(8) -TILE_N = wp.constant(4) -TILE_K = wp.constant(8) +TILE_M = 32 +TILE_N = 32 +TILE_K = 32 # num threads per-tile -TILE_THREADS = 64 +TILE_THREADS = 1024 # GEMM example from https://nvidia.github.io/warp/user_guide/tiles.html -@wp.kernel -def tile_gemm(A: wp.array2d(dtype=wp.float16), B: wp.array2d(dtype=wp.float16), C: wp.array2d(dtype=wp.float16)): - +@wp.kernel() +def tile_gemm( + A: wp.array2d(dtype=wp.float16), + B: wp.array2d(dtype=wp.float16), + C: wp.array2d(dtype=wp.float16) +): # output tile index i, j = wp.tid() sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32) - M = A.shape[0] - N = B.shape[1] K = A.shape[1] - count = (K + TILE_K - 1) // TILE_K for k in range(0, count): @@ -41,7 +62,7 @@ def tile_gemm(A: wp.array2d(dtype=wp.float16), B: wp.array2d(dtype=wp.float16), wp.tile_store(C, wp.tile_astype(sum, wp.float16), offset=(i*TILE_M, j*TILE_N), bounds_check=True) -def run_kernel_direct(M, N, K): +def run_gemm(M, N, K): rng = np.random.default_rng(42) A = rng.random((M, K)).astype(np.float16) B = rng.random((K, N)).astype(np.float16) @@ -51,36 +72,33 @@ def run_kernel_direct(M, N, K): B_wp = wp.array(B) C_wp = wp.array(C) - with wp.Tape() as tape: - wp.launch_tiled( - tile_gemm, - dim=((M + TILE_M - 1) // TILE_M, (N + TILE_N - 1) // TILE_N), - inputs=[A_wp, B_wp, C_wp], - block_dim=TILE_THREADS) + wp.launch(gemm, dim=(M, N), inputs=[A_wp, B_wp, C_wp]) - np.testing.assert_allclose(C_wp.numpy(), A @ B, rtol=1e-1) + np.testing.assert_allclose(C_wp.numpy(), A @ B, rtol=1e-2, atol=M * 2**(-11)) - print("Example matrix multiplication passed") + print("Succes") +def run_gemm_tiled(M, N, K): + rng = np.random.default_rng(42) + A = rng.random((M, K)).astype(np.float16) + B = rng.random((K, N)).astype(np.float16) + C = np.zeros((M, N), dtype=np.float16) -def call_warp(kernel_function, args, kwargs, grid, threads, params): - # Convert Torch tensors to Warp args - warp_args = [] - for arg in args: - if isinstance(arg, torch.Tensor): - warp_args.append(wp.from_torch(arg)) - else: - warp_args.append(arg) + A_wp = wp.array(A) + B_wp = wp.array(B) + C_wp = wp.array(C) - # launch kernel - with wp.Tape() as tape: - wp.launch_tiled( - kernel_function, - dim=grid, - inputs=warp_args, - block_dim=params["TILE_THREADS"], - ) + + wp.launch_tiled( + tile_gemm, + dim=((M + TILE_M - 1) // TILE_M, (N + TILE_N - 1) // TILE_N), + inputs=[A_wp, B_wp, C_wp], + block_dim=TILE_THREADS) + + np.testing.assert_allclose(C_wp.numpy(), A @ B, rtol=1e-2, atol=M * 2**(-11)) + + print("Succes") def tune(M, K, N): @@ -90,43 +108,40 @@ def tune(M, K, N): C = np.zeros((M, N), dtype=np.float16) size = (M, N) - block_size_names = ["TILE_M", "TILE_N"] tune_params = dict() - tune_params["TILE_M"] = [4, 8, 16] - tune_params["TILE_N"] = [2, 4, 8] - tune_params["TILE_K"] = [4, 8, 16] - tune_params["TILE_THREADS"] = [32, 64, 128] + tune_params["block_dim"] = [2**i for i in range(5, 11)] + tune_params["dim"] = [size] args = [A, B, C] answer = [None, None, A @ B] results, env = tune_kernel( - kernel_name="tile_gemm", - kernel_source=FULL_PATH, + kernel_name="gemm", + kernel_source=__file__, problem_size=size, arguments=args, tune_params=tune_params, lang="generic_python", answer=answer, call_function=call_warp, - block_size_names=block_size_names, ) + if __name__ == "__main__": - #tune(128, 128, 128) + M, N, K = 1024, 1024, 1024 + + #run_gemm(M, N, K) + #run_gemm_tiled(M, N, K) + + tune(M, N, K) + + + - sizes = [ - (65, 65, 17), - (67, 71, 19), - (1, 1, 1), - (63, 63, 15), - (129, 130, 33), - ] + - #for size in sizes: - # print(size) - # run_kernel_direct(*size) + - run_kernel_direct(1024, 1024, 1024) + diff --git a/kernel_tuner/backends/generic_python.py b/kernel_tuner/backends/generic_python.py index 4845fc4e2..2247d5e87 100644 --- a/kernel_tuner/backends/generic_python.py +++ b/kernel_tuner/backends/generic_python.py @@ -41,11 +41,11 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None self.device_id = torch.cuda.current_device() self.device_properties = torch.cuda.get_device_properties(self.device_id) self.name = torch.cuda.get_device_name(self.device_id) - self.max_threads = self.device_properties.max_threads_per_multi_processor + self.max_threads = 10**18 # 'inf' to support tile based programming models, which can use less threads then the size of a tile. env = dict() env["device_name"] = self.name - env["max_threads"] = self.max_threads + env["max_threads"] = self.max_threads env["iterations"] = iterations env["compiler_options"] = compiler_options self.env = env From 3cccf50215a178d7b58d239339d53b9563a452c5 Mon Sep 17 00:00:00 2001 From: Imke van Ooijen Date: Sat, 6 Jun 2026 10:56:49 +0200 Subject: [PATCH 11/14] modified examples --- examples/generic_python/matmul/cute_matmul.py | 20 +-- .../generic_python/matmul/numba_matmul.py | 5 +- .../generic_python/matmul/tilelang_matmul.py | 144 ++++++++---------- .../generic_python/matmul/tilus_matmul.py | 36 ++--- .../generic_python/matmul/triton_matmul.py | 22 +-- 5 files changed, 105 insertions(+), 122 deletions(-) diff --git a/examples/generic_python/matmul/cute_matmul.py b/examples/generic_python/matmul/cute_matmul.py index 4b5637c82..09ced866e 100644 --- a/examples/generic_python/matmul/cute_matmul.py +++ b/examples/generic_python/matmul/cute_matmul.py @@ -15,14 +15,7 @@ ## Basic Matmul ================================================================ @cute.kernel -def naive_matmul_kernel( - gA: cute.Tensor, - gB: cute.Tensor, - gC: cute.Tensor, - M: int, - K: int, - N: int, -): +def naive_matmul_kernel(gA: cute.Tensor, gB: cute.Tensor, gC: cute.Tensor, M: int, K: int, N: int): tx, ty, _ = cute.arch.thread_idx() bx, by, _ = cute.arch.block_idx() bdx, bdy, _ = cute.arch.block_dim() @@ -95,13 +88,22 @@ def tune_naive_matmul(M, N, K): tune_params = dict() tune_params["block_size_x"] = [2**i for i in range(1, 10)] tune_params["block_size_y"] = [2**i for i in range(1, 10)] - restrictions = ["block_size_x * block_size_y <= 1024"] + restrictions = ["block_size_x * block_size_y >= 32, block_size_x * block_size_y <= 1024"] results, env = tune_kernel("matmul", __file__, size, args, tune_params, lang="generic_python", call_function=call_cute, answer=answer, atol=M * 2 **(-11), restrictions=restrictions, verbose=False) +# This kernel is copied from the CuTe DSL project: +# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/ampere/tensorop_gemm.py +# +# Original example: tensorop_gemm.py +# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Modifications in file: +# - Tuning paramters made more explicit + ## Optimized matmul ========================================================== # Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause diff --git a/examples/generic_python/matmul/numba_matmul.py b/examples/generic_python/matmul/numba_matmul.py index d8e23d6fe..6486ec540 100644 --- a/examples/generic_python/matmul/numba_matmul.py +++ b/examples/generic_python/matmul/numba_matmul.py @@ -152,9 +152,8 @@ def run_optimized(M, N, K): ) # launch - for i in range(10000): - optimized_matmul[blocks_per_grid, threads_per_block](M, N, K, dA, dB, dC) - cuda.synchronize() + optimized_matmul[blocks_per_grid, threads_per_block](M, N, K, dA, dB, dC) + cuda.synchronize() # copy result back C_result = dC.copy_to_host() diff --git a/examples/generic_python/matmul/tilelang_matmul.py b/examples/generic_python/matmul/tilelang_matmul.py index 1067d9ff5..0c476c786 100644 --- a/examples/generic_python/matmul/tilelang_matmul.py +++ b/examples/generic_python/matmul/tilelang_matmul.py @@ -40,11 +40,16 @@ def gemm( - -# https://github.com/tile-ai/tilelang/tree/main/examples/gemm -# num_threads and num_stages were added as metaparameters to make it possible to compare the built-in tuner -# with Kernel Tuner. -# Removed annotated memory layout for tiles A and B +# This kernel is copied from the TileLang project: +# https://github.com/tile-ai/tilelang/tree/main/examples/gemm/example_gemm.py +# +# Original example: example_gemm.py +# Copyright (c) the TileLang authors +# +# Modifications in file: +# - num_threads and num_stages added as metaparameters +# - Removed annotated memory layout for tiles A and B +# - dummy parameter to trigger fresh compilation when timing repeated tuning @tilelang.jit def matmul_opt(M, N, K, block_M, block_N, block_K, num_threads=128, num_stages=3, dummy=0, dtype=T.float16, accum_dtype=T.float): @T.prim_func @@ -60,7 +65,8 @@ def main( C_local = T.alloc_fragment((block_M, block_N), accum_dtype) # Enable swizzle-based rasterization for better L2 locality - T.use_swizzle(panel_size=10, enable=True) + panel_size = 10 + T.use_swizzle(panel_size=panel_size, enable=True) # Clear the local accumulation buffer T.clear(C_local) @@ -84,7 +90,7 @@ def main( -# TODO + def get_configs(kt=False): """ Generate a list of kernel tuning configuration dictionaries for a tiled matrix-multiply. @@ -94,30 +100,22 @@ def get_configs(kt=False): List[dict]: A list of configuration dictionaries """ - ''' - block_M = [64, 128, 256] - block_N = [64, 128, 256] - block_K = [32, 64] - num_stages = [0, 1, 2, 3] - thread_num = [128, 256] - enable_rasterization = [True, False] - ''' - - block_M = [64] - block_N = [64] - block_K = [32, 64] - num_stages = [0, 3] - thread_num = [128] - enable_rasterization = [True] + block_M = [32, 64, 128, 256] + block_N = [32, 64, 128, 256] + block_K = [16, 32, 64, 128] + num_threads = [256]#[64, 128, 256, 512] + num_stages = [3]#[0, 1, 2, 3, 4] + panel_size = [10]#[4, 6, 8, 10] + _configs = list( itertools.product( block_M, block_N, block_K, + num_threads, num_stages, - thread_num, - enable_rasterization, + panel_size, ) ) @@ -127,8 +125,8 @@ def get_configs(kt=False): "block_N": block_N, "block_K": block_K, "num_stages": num_stages, - "thread_num": thread_num, - "enable_rasteration": enable_rasterization, + "num_threads": num_threads, + "panel_size": panel_size, } else: configs = [ @@ -136,9 +134,9 @@ def get_configs(kt=False): "block_M": c[0], "block_N": c[1], "block_K": c[2], - "num_stages": c[3], - "thread_num": c[4], - "enable_rasteration": c[5], # keep param name for backward-compat + "num_threads": c[3], + "num_stages": c[4], + "panel_size": c[5], } for c in _configs ] @@ -146,58 +144,45 @@ def get_configs(kt=False): -# TODO +# For autotuning experiment # https://github.com/tile-ai/tilelang/blob/main/examples/gemm/example_gemm_autotune.py -# changed gemm_autotune to gemm -# originally, B was transposed. I changed this so the kernel input is the same as in other languages. -@tilelang.autotune(configs=get_configs()) +@tilelang.autotune(configs=get_configs(), warmup=1, rep=32) @tilelang.jit -def matmul_opt_autotune(M, N, K, dummy, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): +def matmul_opt_autotune(M, N, K, block_M, block_N, block_K, num_threads=128, num_stages=3, panel_size=10, dummy=0, dtype=T.float16, accum_dtype=T.float): @T.prim_func - def gemm( + def main( A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), + B: T.Tensor((K, N), dtype), C: T.Tensor((M, N), dtype), ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + # Allocate shared and local fragments A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_N, block_K), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), dtype) - T.use_swizzle(panel_size=10, enable=enable_rasteration) - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_shared, - B_shared, - C_local, - transpose_B=True, - ) - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return gemm - + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + # Enable swizzle-based rasterization for better L2 locality + T.use_swizzle(panel_size=panel_size, enable=True) -# DEZE mag denk ik weg -def main(): - kernel = matmul_opt_autotune(1024, 1024, 1024) - import torch + # Clear the local accumulation buffer + T.clear(C_local) - a = torch.randn(1024, 1024).cuda().half() - b = torch.randn(1024, 1024).cuda().half() - c = torch.empty((1024, 1024), device='cuda', dtype=torch.float16) + # Pipelined iteration over K dimension + for idx in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + # Copy tile of A + T.copy(A[by * block_M, idx * block_K], A_shared) - kernel(a, b, c) + # Parallel copy tile of B + for ko, j in T.Parallel(block_K, block_N): + B_shared[ko, j] = B[idx * block_K + ko, bx * block_N + j] - ref_c = a @ b + # Perform local GEMM on the shared-memory tiles + T.gemm(A_shared, B_shared, C_local) - torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) - print("All check passed.") + # Copy the result tile back + T.copy(C_local, C[by * block_M, bx * block_N]) + return main @@ -254,11 +239,11 @@ def tune_basic(M, N, K): args = [A, B, C] tune_params = dict() - tune_params["block_M"] = [2**i for i in range(4, 10)] - tune_params["block_N"] = [2**i for i in range(4, 10)] - tune_params["block_K"] = [2**i for i in range(4, 10)] - tune_params["num_threads"] = [2**i for i in range(5, 11)] - tune_params["num_stages"] = [1, 2, 3, 4, 5] + tune_params["block_M"] = [2**i for i in range(5, 9)] + tune_params["block_N"] = [2**i for i in range(5, 9)] + tune_params["block_K"] = [2**i for i in range(4, 8)] + tune_params["num_threads"] = [64, 128, 256, 512] + tune_params["num_stages"] = [0, 1, 2, 3, 4] tune_params["M"] = [M] tune_params["N"] = [N] tune_params["K"] = [K] @@ -295,11 +280,12 @@ def tune_opt(M, N, K): args = [A, B, C] tune_params = dict() - tune_params["block_M"] = [2**i for i in range(4, 10)] - tune_params["block_N"] = [2**i for i in range(4, 10)] - tune_params["block_K"] = [2**i for i in range(4, 10)] - tune_params["num_threads"] = [2**i for i in range(5, 11)] - tune_params["num_stages"] = [1, 2, 3, 4, 5] + tune_params["block_M"] = [2**i for i in range(5, 9)] + tune_params["block_N"] = [2**i for i in range(5, 9)] + tune_params["block_K"] = [2**i for i in range(4, 8)] + tune_params["num_threads"] = [64, 128, 256, 512] + tune_params["num_stages"] = [0, 1, 2, 3, 4] + tune_params["panel_size"] = [4, 6, 8, 10] tune_params["M"] = [M] tune_params["N"] = [N] tune_params["K"] = [K] @@ -332,7 +318,7 @@ def tune_opt(M, N, K): #run_basic(M, N, K) #run_opt(M, N, K) - #tune_basic(M, N, K) - tune_opt(M, N, K) + tune_basic(M, N, K) + #tune_opt(M, N, K) diff --git a/examples/generic_python/matmul/tilus_matmul.py b/examples/generic_python/matmul/tilus_matmul.py index 38f6523f9..811d5c2d7 100644 --- a/examples/generic_python/matmul/tilus_matmul.py +++ b/examples/generic_python/matmul/tilus_matmul.py @@ -24,12 +24,8 @@ def __init__(self): def __call__( self, - m_size: int32, # the size of the m dimension of the input matrix A and output matrix C - n_size: int, # the size of the n dimension of the input matrix B and output matrix C - k_size: int, # the size of the k dimension of the input matrix A and B - a_ptr: ~float16, # the pointer to the input matrix A, which is a 2D tensor of shape [m_size, k_size] - b_ptr: ~float16, # the pointer to the input matrix B, which is a 2D tensor of shape [k_size, n_size] - c_ptr: ~float16, # the pointer to the output matrix C, which is a 2D tensor of shape [m_size, n_size] + m_size: int32, n_size: int, k_size: int, # Matrix dimensions + a_ptr: ~float16, b_ptr: ~float16, c_ptr: ~float16, # Matrix pointers ): self.attrs.blocks = [ cdiv(m_size, self.block_m), # the x dimension size of the grid @@ -95,12 +91,8 @@ def __init__(self, num_warps=None, block_m=None, block_n=None, block_k=None, num def __call__( self, - m_size: int32, - n_size: int, - k_size: int, - a_ptr: ~float16, - b_ptr: ~float16, - c_ptr: ~float16, + m_size: int32, n_size: int, k_size: int, + a_ptr: ~float16, b_ptr: ~float16, c_ptr: ~float16, ): self.attrs.blocks = [cdiv(m_size, self.block_m), cdiv(n_size, self.block_n)] self.attrs.warps = self.num_warps @@ -176,7 +168,7 @@ def run_basic(M, N, K): print("Succes") -def run_optmized(M, N, K): +def run_optimized(M, N, K): A = torch.randn(M, K, device="cuda", dtype=torch.float16) B = torch.randn(K, N, device="cuda", dtype=torch.float16) C = torch.empty((M, N), device='cuda', dtype=torch.float16) @@ -203,9 +195,9 @@ def tune_basic(M, N, K): args = [M, N, K, A, B, C] tune_params = dict() - tune_params["block_m"] = [16, 32, 64, 128] - tune_params["block_n"] = [16, 32, 64, 128] - tune_params["block_k"] = [16, 32, 64, 128] + tune_params["block_m"] = [2**i for i in range(5, 9)] + tune_params["block_n"] = [2**i for i in range(5, 9)] + tune_params["block_k"] = [2**i for i in range(4, 8)] tune_params["num_warps"] = [2, 4, 8, 16] @@ -235,11 +227,11 @@ def tune_opt(M, N, K): args = [M, N, K, A, B, C] tune_params = dict() - tune_params["block_m"] = [16, 32, 64, 128, 256] - tune_params["block_n"] = [16, 32, 64, 128, 256] - tune_params["block_k"] = [16, 32, 64, 128] + tune_params["block_m"] = [2**i for i in range(5, 9)] + tune_params["block_n"] = [2**i for i in range(5, 9)] + tune_params["block_k"] = [2**i for i in range(4, 8)] tune_params["num_warps"] = [2, 4, 8, 16] - tune_params["num_stages"] = [2, 3, 4, 5, 6] + tune_params["num_stages"] = [2, 3, 4, 5] # Shared memory restriction restrictions = ["2 * num_stages * block_k * (block_m + block_n) <= 49152"] @@ -268,8 +260,8 @@ def tune_opt(M, N, K): #M, N, K = 8192, 8192, 8192 #run_basic(M, N, K) - run_optmized(M, N, K) + #run_optmized(M, N, K) - #tune_basic(M, N, K) + tune_basic(M, N, K) #tune_opt(M, N, K) \ No newline at end of file diff --git a/examples/generic_python/matmul/triton_matmul.py b/examples/generic_python/matmul/triton_matmul.py index dc81c5bd0..cac3cab73 100644 --- a/examples/generic_python/matmul/triton_matmul.py +++ b/examples/generic_python/matmul/triton_matmul.py @@ -236,10 +236,10 @@ def tune_basic(M, N, K): ] tune_params = dict() - tune_params["BLOCK_SIZE_M"] = [2**i for i in range(1, 10)] - tune_params["BLOCK_SIZE_N"] = [2**i for i in range(1, 10)] - tune_params["BLOCK_SIZE_K"] = [2**i for i in range(4, 10)] # tl.dot requires K >= 16 - tune_params["num_warps"] = [1, 2, 4, 8, 16, 32] + tune_params["BLOCK_SIZE_M"] = [2**i for i in range(5, 9)] + tune_params["BLOCK_SIZE_N"] = [2**i for i in range(5, 9)] + tune_params["BLOCK_SIZE_K"] = [2**i for i in range(4, 8)] + tune_params["num_warps"] = [2, 4, 8, 16] tune_params["num_stages"] = [1, 2, 3, 4, 5] results, env = tune_kernel( @@ -274,12 +274,15 @@ def tune_opt(M, N, K): ] tune_params = dict() - tune_params["BLOCK_SIZE_M"] = [2**i for i in range(1, 10)] - tune_params["BLOCK_SIZE_N"] = [2**i for i in range(1, 10)] - tune_params["BLOCK_SIZE_K"] = [2**i for i in range(4, 10)] # tl.dot requires K >= 16 - tune_params["GROUP_SIZE_M"] = [1, 2, 4, 8, 16, 32, 64] - tune_params["num_warps"] = [1, 2, 4, 8, 16, 32] + tune_params["BLOCK_SIZE_M"] = [2**i for i in range(5, 9)] + tune_params["BLOCK_SIZE_N"] = [2**i for i in range(5, 9)] + tune_params["BLOCK_SIZE_K"] = [2**i for i in range(4, 8)] + tune_params["GROUP_SIZE_M"] = [4, 6, 8, 10] + tune_params["num_warps"] = [2, 4, 8, 16] tune_params["num_stages"] = [1, 2, 3, 4, 5] + tune_params["M_dim"] = [M] + + restrictions = ["GROUP_SIZE_M * BLOCK_SIZE_M < M_dim"] # Otherwise grouped ordering has no effect results, env = tune_kernel( kernel_name="matmul_opt", @@ -292,6 +295,7 @@ def tune_opt(M, N, K): atol=M * 2**(-11), block_size_names = ["BLOCK_SIZE_M", "BLOCK_SIZE_M"], grid_div_x = ["BLOCK_SIZE_M", "BLOCK_SIZE_N"], + restrictions = restrictions, strategy = "bayes_opt", strategy_options = {"max_fevals": 100}, call_function=call_triton, From 68afbc1c22dba73010ddbb6a40094e34ce2fe0a6 Mon Sep 17 00:00:00 2001 From: Imke van Ooijen Date: Sat, 6 Jun 2026 11:09:54 +0200 Subject: [PATCH 12/14] Restore workflow file --- .github/workflows/test-python-package.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-python-package.yml b/.github/workflows/test-python-package.yml index f86240c7f..dec26fdba 100644 --- a/.github/workflows/test-python-package.yml +++ b/.github/workflows/test-python-package.yml @@ -21,7 +21,7 @@ jobs: strategy: matrix: - os: [ubuntu-latest, macos-latest] + os: [ubuntu-latest, macos-13] steps: - uses: actions/checkout@v4 From 9d11e5e927ab038193a6c578dc333ffc12c41343 Mon Sep 17 00:00:00 2001 From: Imke van Ooijen Date: Sat, 6 Jun 2026 11:33:31 +0200 Subject: [PATCH 13/14] fixed failed tests --- kernel_tuner/core.py | 2 +- test/test_observers.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/kernel_tuner/core.py b/kernel_tuner/core.py index aa9cada75..1fe5ffbe6 100644 --- a/kernel_tuner/core.py +++ b/kernel_tuner/core.py @@ -400,6 +400,7 @@ def check_kernel_output( ): """Runs the kernel once and checks the result against answer.""" logging.debug("check_kernel_output") + cp = _get_cupy() # if not using custom verify function, check if the length is the same if answer: @@ -408,7 +409,6 @@ def check_kernel_output( should_sync = [answer[i] is not None for i, arg in enumerate(instance.arguments)] else: - cp = _get_cupy() cupy_ndarray = (cp.ndarray,) if cp is not None else () should_sync = [ isinstance(arg, (np.ndarray, torch.Tensor, DeviceArray) + cupy_ndarray) diff --git a/test/test_observers.py b/test/test_observers.py index 69a94a9bb..d1a27179c 100644 --- a/test/test_observers.py +++ b/test/test_observers.py @@ -4,6 +4,7 @@ from kernel_tuner.observers.nvml import NVMLObserver from kernel_tuner.observers.observer import BenchmarkObserver from kernel_tuner.observers.register import RegisterObserver +from kernel_tuner.language import Language from .context import ( skip_if_no_cuda, @@ -59,7 +60,7 @@ def get_results(self): result, _ = kernel_tuner.tune_kernel(*env_compiler, observers=[lambda args: MyObserver(args)], compiler_options=["-fopenmp"]) # Check if the observer has correctly received the lang option - assert result[0]["observer_args"]["lang"] == "C" + assert result[0]["observer_args"]["lang"] == Language.C @skip_if_no_pycuda def test_register_observer_pycuda(env): From 221c912a06bb177608f5949ef98b00f600c3d6d2 Mon Sep 17 00:00:00 2001 From: Imke van Ooijen Date: Thu, 18 Jun 2026 10:40:53 +0200 Subject: [PATCH 14/14] finalized examples --- examples/generic_python/cupy_copy.py | 19 - examples/generic_python/cute_vec_add.py | 39 +- examples/generic_python/cutile_vec_add.py | 62 -- .../flash_attention/tilelang_attention.py | 152 ----- .../flash_attention/tilus_attention.py | 611 ------------------ .../flash_attention/triton_attention.py | 276 -------- examples/generic_python/loopy_example.py | 38 -- examples/generic_python/matmul/cute_matmul.py | 8 +- .../generic_python/matmul/helion_matmul.py | 20 - .../generic_python/matmul/numba_matmul.py | 6 +- examples/generic_python/matmul/test.py | 278 -------- examples/generic_python/matmul/test_tilus.py | 26 - .../generic_python/matmul/tilelang_matmul.py | 6 +- .../generic_python/matmul/tilus_matmul.py | 10 +- .../generic_python/matmul/triton_matmul.py | 10 +- examples/generic_python/matmul/warp_matmul.py | 59 +- .../matmul_old/tilelang_matmul.py | 212 ------ .../generic_python/matmul_old/tilus_matmul.py | 266 -------- .../matmul_old/triton_matmul.py | 383 ----------- .../normalization/tilelang_norm.py | 78 --- .../normalization/tilus_norm.py | 132 ---- .../normalization/triton_norm.py | 66 -- examples/generic_python/numba_vec_add.py | 18 +- examples/generic_python/pallas_vec_add.py | 23 - examples/generic_python/tilelang_vec_add.py | 23 +- examples/generic_python/tilus_naive_matmul.py | 127 ---- .../generic_python/tilus_splitk_matmul.py | 255 -------- examples/generic_python/tilus_vec_add.py | 35 +- examples/generic_python/triton_vec_add.py | 13 +- examples/generic_python/warp_vec_add.py | 10 +- 30 files changed, 46 insertions(+), 3215 deletions(-) delete mode 100644 examples/generic_python/cupy_copy.py delete mode 100644 examples/generic_python/cutile_vec_add.py delete mode 100644 examples/generic_python/flash_attention/tilelang_attention.py delete mode 100644 examples/generic_python/flash_attention/tilus_attention.py delete mode 100644 examples/generic_python/flash_attention/triton_attention.py delete mode 100644 examples/generic_python/loopy_example.py delete mode 100644 examples/generic_python/matmul/helion_matmul.py delete mode 100644 examples/generic_python/matmul/test.py delete mode 100644 examples/generic_python/matmul/test_tilus.py delete mode 100644 examples/generic_python/matmul_old/tilelang_matmul.py delete mode 100644 examples/generic_python/matmul_old/tilus_matmul.py delete mode 100644 examples/generic_python/matmul_old/triton_matmul.py delete mode 100644 examples/generic_python/normalization/tilelang_norm.py delete mode 100644 examples/generic_python/normalization/tilus_norm.py delete mode 100644 examples/generic_python/normalization/triton_norm.py delete mode 100644 examples/generic_python/pallas_vec_add.py delete mode 100644 examples/generic_python/tilus_naive_matmul.py delete mode 100644 examples/generic_python/tilus_splitk_matmul.py diff --git a/examples/generic_python/cupy_copy.py b/examples/generic_python/cupy_copy.py deleted file mode 100644 index 8d9110209..000000000 --- a/examples/generic_python/cupy_copy.py +++ /dev/null @@ -1,19 +0,0 @@ -import cupy -from cupyx import jit - -@jit.rawkernel() -def elementwise_copy(x, y, size): - tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x - ntid = jit.gridDim.x * jit.blockDim.x - for i in range(tid, size, ntid): - y[i] = x[i] - -size = cupy.uint32(2 ** 22) -x = cupy.random.normal(size=(size,), dtype=cupy.float32) -y = cupy.empty((size,), dtype=cupy.float32) - -elementwise_copy((128,), (1024,), (x, y, size)) # RawKernel style -assert (x == y).all() - -elementwise_copy[128, 1024](x, y, size) # Numba style -assert (x == y).all() \ No newline at end of file diff --git a/examples/generic_python/cute_vec_add.py b/examples/generic_python/cute_vec_add.py index a83694ab9..2744abb77 100644 --- a/examples/generic_python/cute_vec_add.py +++ b/examples/generic_python/cute_vec_add.py @@ -1,13 +1,10 @@ import torch -from functools import partial -from typing import List -import time import cutlass import cutlass.cute as cute -from cutlass.cute.runtime import from_dlpack from kernel_tuner import tune_kernel +from call_functions import call_cute @cute.kernel def vec_add_kernel( @@ -25,7 +22,6 @@ def vec_add_kernel( gC[thread_id] = gA[thread_id] + gB[thread_id] - @cute.jit def vec_add( mA: cute.Tensor, @@ -43,50 +39,19 @@ def vec_add( ) - - -def call_cute(kernel_function, args, kwargs, grid, threads, params): - cute_args = [] - for arg in args: - if isinstance(arg, torch.Tensor): - arg_ = from_dlpack(arg) - cute_args.append(arg_) - else: - cute_args.append(arg) - - kernel_function(*cute_args, **kwargs) - - - - def main(): size = 16384 a = torch.randn(size, device="cuda", dtype=torch.float16) b = torch.randn(size, device="cuda", dtype=torch.float16) c = torch.zeros(size, device="cuda", dtype=torch.float16) - ''' - a_ = from_dlpack(a, assumed_align=16) - b_ = from_dlpack(b, assumed_align=16) - c_ = from_dlpack(c, assumed_align=16) - ''' - args = [a, b, c, size] tune_params = {"num_threads_per_block": [1, 2, 4, 8, 16, 32, 64, 128, 265, 512, 1024]} answer = [None, None, (a+b).cpu(), None] - from pathlib import Path - FULL_PATH = Path(__file__).resolve() - tune_kernel("vec_add", FULL_PATH, size, args, tune_params, answer=answer, + tune_kernel("vec_add", __file__, size, args, tune_params, answer=answer, lang="generic_python", call_function=call_cute, verbose=True) - #naive_elementwise_add_ = cute.compile(naive_elementwise_add, a_, b_, c_) - - #naive_elementwise_add_(a_, b_, c_) - #vec_add(a_, b_, c_, size) - - - #torch.testing.assert_close(c, a+b) main() diff --git a/examples/generic_python/cutile_vec_add.py b/examples/generic_python/cutile_vec_add.py deleted file mode 100644 index e95302045..000000000 --- a/examples/generic_python/cutile_vec_add.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -""" -Example demonstrating simple vector addition. -Shows how to perform elementwise operations on vectors. -Does not work on das vu, because we need cuda 13.1 -""" - -import cupy as cp -import numpy as np -import cuda.tile as ct - - -@ct.kernel -def vector_add(a, b, c, tile_size: ct.Constant[int]): - # Get the 1D pid - pid = ct.bid(0) - - # Load input tiles - a_tile = ct.load(a, index=(pid,), shape=(tile_size,)) - b_tile = ct.load(b, index=(pid,), shape=(tile_size,)) - - # Perform elementwise addition - result = a_tile + b_tile - - # Store result - ct.store(c, index=(pid, ), tile=result) - - -def test(): - # Create input data - vector_size = 2**12 - tile_size = 2**4 - grid = (ct.cdiv(vector_size, tile_size), 1, 1) - - rng = cp.random.default_rng() - a = rng.random(vector_size) - b = rng.random(vector_size) - c = cp.zeros_like(a) - - # Launch kernel - ct.launch(cp.cuda.get_current_stream(), - grid, # 1D grid of processors - vector_add, - (a, b, c, tile_size)) - - # Copy to host only to compare - a_np = cp.asnumpy(a) - b_np = cp.asnumpy(b) - c_np = cp.asnumpy(c) - - # Verify results - expected = a_np + b_np - np.testing.assert_array_almost_equal(c_np, expected) - - print("✓ vector_add_example passed!") - - -if __name__ == "__main__": - test() \ No newline at end of file diff --git a/examples/generic_python/flash_attention/tilelang_attention.py b/examples/generic_python/flash_attention/tilelang_attention.py deleted file mode 100644 index 5b2be204a..000000000 --- a/examples/generic_python/flash_attention/tilelang_attention.py +++ /dev/null @@ -1,152 +0,0 @@ -# example taken from https://github.com/tile-ai/tilelang/blob/main/examples/flash_attention/example_mha_fwd_bshd.py -import torch -import torch.nn.functional as F -import tilelang -from tilelang.autotuner import * -import tilelang.language as T -import itertools -import argparse -from functools import partial - - -def get_configs(): - iter_params = dict(block_M=[64], block_N=[64], num_stages=[1], threads=[128]) - return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] - - -@autotune(configs=get_configs(), warmup=10, rep=10) -@tilelang.jit( - out_idx=[3], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, -) -def flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128): - scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) - shape = [batch, seq_len, heads, dim] - dtype = T.float16 - accum_dtype = T.float32 - - @T.prim_func - def main( - Q: T.Tensor(shape, dtype), - K: T.Tensor(shape, dtype), - V: T.Tensor(shape, dtype), - Output: T.Tensor(shape, dtype), - ): - with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim], dtype) - O_shared = T.alloc_shared([block_M, dim], dtype) - acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) - acc_o = T.alloc_fragment([block_M, dim], accum_dtype) - scores_max = T.alloc_fragment([block_M], accum_dtype) - scores_max_prev = T.alloc_fragment([block_M], accum_dtype) - scores_scale = T.alloc_fragment([block_M], accum_dtype) - scores_sum = T.alloc_fragment([block_M], accum_dtype) - logsum = T.alloc_fragment([block_M], accum_dtype) - - T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - loop_range = ( - T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) - ) - - for k in T.Pipelined(loop_range, num_stages=num_stages): - T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) - else: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_M): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - - T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] /= logsum[i] - T.copy(acc_o, O_shared) - T.copy(O_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) - - return main - - -def ref_program(Q, K, V, is_causal): - dim = Q.size(-1) - scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) - scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) - if is_causal: - seq_len = Q.size(1) - mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) - mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float("-inf")) - attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) - return output - - -def main( - batch: int = 8, - heads: int = 32, - seq_len: int = 4096, - dim: int = 128, - is_causal: bool = False, - tune: bool = False, -): - flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim - total_flops = 2 * flops_per_matmul - if is_causal: - total_flops *= 0.5 - - if not tune: - #kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128) - # Changed block size so fits in shared mem - kernel = flashattn(batch, heads, seq_len, dim, is_causal, block_M=64, block_N=64, num_stages=1, threads=128) - ref_program_processed = partial(ref_program, is_causal=is_causal) - profiler = kernel.get_profiler() - profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) - print("All checks pass.") - latency = profiler.do_bench(ref_program_processed, warmup=500) - print("Ref: {:.2f} ms".format(latency)) - print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9)) - latency = profiler.do_bench(warmup=500) - print("Tile-lang: {:.2f} ms".format(latency)) - print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) - else: - best_result = flashattn(batch, heads, seq_len, dim, is_causal) - best_latency = best_result.latency - best_config = best_result.config - ref_latency = best_result.ref_latency - print(f"Best latency: {best_latency}") - print(f"Best TFlops: {total_flops / best_latency * 1e-9}") - print(f"Best config: {best_config}") - print(f"Ref latency: {ref_latency}") - - -if __name__ == "__main__": - - main() \ No newline at end of file diff --git a/examples/generic_python/flash_attention/tilus_attention.py b/examples/generic_python/flash_attention/tilus_attention.py deleted file mode 100644 index 5169d9316..000000000 --- a/examples/generic_python/flash_attention/tilus_attention.py +++ /dev/null @@ -1,611 +0,0 @@ -# Code taken from https://github.com/NVIDIA/tilus/blob/main/examples/attention/flash_attention_v3.py -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -import numpy as np -import pandas as pd -import tilus -import torch -from tilus import boolean, f32, int32, void_p -from hidet.ir import DataType -from tilus.ir import RegisterTensor, SharedTensor -from tilus.ir.tensor import GlobalTensor -from tilus.utils import benchmark_func, cdiv - -pd.options.display.max_columns = None -pd.options.display.width = 1000 - - -@tilus.autotune("num_warps", [4, 8]) -@tilus.autotune("block_q", [32, 64, 128]) -@tilus.autotune("block_kv", [32, 64, 128]) -@tilus.autotune("split_kv", [-1, 512, 1024, 4096]) -@tilus.autotune("keep_q_in_regs", [False, True]) -class FlashAttention(tilus.Script): - LOG2_E = 1.4426950408889634 # log2(e) - - debug_schedule = dict( - num_warps=4, - block_q=64, - block_kv=64, - split_kv=-1, - keep_q_in_regs=False, - ) - - def __init__( - self, - dtype: DataType, - num_heads: int, - num_heads_kv: int, - head_size: int, - num_warps: int, - block_q: int, - block_kv: int, - split_kv: int, - keep_q_in_regs: bool, - ): - super().__init__() - self.dtype: DataType = dtype - self.num_heads = num_heads - self.num_heads_kv = num_heads_kv - self.head_size = head_size - self.num_warps = num_warps - self.block_q = block_q - self.block_kv = block_kv - self.split_kv = split_kv - self.keep_q_in_regs = keep_q_in_regs - self.score_scale = float(1.0 / np.sqrt(head_size)) - self.group_heads = num_heads // num_heads_kv - - assert self.split_kv % self.block_kv == 0 or self.split_kv == -1, ( - "split_kv must be a multiple of block_kv or -1" - ) - - # determine layout - self.sv_config = self.cuda.resolve_dot_config( - dtype, f32, m=block_q, n=head_size, k=block_kv, warp_m=num_warps, warp_n=1 - ) - - def apply_mask(self, score: RegisterTensor, q_offset: int32, kv_offset: int32): - mask = self.register_tensor( - dtype=boolean, - shape=[self.block_q, self.block_kv], - init=lambda i, j: i + q_offset >= j + kv_offset, - ) - self.assign(score, score + self.where(mask, x=0.0, y=-1e6)) - - def softmax_rescale( - self, - score: RegisterTensor, - m: RegisterTensor, - l: RegisterTensor, - o: RegisterTensor, - ) -> RegisterTensor: - scale = self.score_scale * self.LOG2_E # log2(e) * score_scale - cur_m = self.max(score, dim=1, keepdim=True) * scale # [block_q, 1] - new_m = self.maximum(m, cur_m) # [block_q, 1] - rp = self.exp2(score * scale - new_m) # [block_q, block_kv] - m_scale = self.exp2(m - new_m) - self.assign(o, o * m_scale) - self.assign(l, l * m_scale + self.sum(rp, dim=1, keepdim=True)) - self.assign(m, new_m) - return rp.to(self.dtype) - - def attention_iteration( - self, - bs: int32, - kv_offset: int32, - q_offset: int32, - head: int32, - gk: GlobalTensor, - gv: GlobalTensor, - sq: SharedTensor, # f16[block_q, head_size] - rq: RegisterTensor, # f16[block_q, head_size] - sk: SharedTensor, # f16[block_kv, head_size], - sv: SharedTensor, # f16[block_kv, head_size], - o: RegisterTensor, # f32[block_q, head_size] - m: RegisterTensor, # f32[block_q, 1] - l: RegisterTensor, # f32[block_q, 1] - check_bounds: bool, - ): - if not self.keep_q_in_regs: - self.load_shared(sq, out=rq) - # wait for the async copy of k to finish - self.copy_async_wait_group(0) - self.sync() - self.copy_async( - gv, - sv, - offsets=[bs, kv_offset, head // self.group_heads, 0], - dims=[1, 3], - check_bounds=check_bounds, - ) - self.copy_async_commit_group() - - # issue the async copy for v and perform dot(q, k) - rk = self.load_shared(sk) # [block_kv, head_size] - score = self.dot(rq, rk.transpose(), acc_dtype=f32) # [block_q, block_kv] - - if check_bounds: - self.apply_mask(score, q_offset, kv_offset) # apply causal mask - - # wait for the async copy of v to finish - self.copy_async_wait_group(0) - self.sync() - self.copy_async( - gk, - sk, - offsets=[bs, kv_offset + self.block_kv, head // self.group_heads, 0], - dims=[1, 3], - check_bounds=check_bounds, - ) - self.copy_async_commit_group() - - # load v to register - rv = self.load_shared(sv) # [block_kv, head_size] - - # online softmax - rp = self.softmax_rescale(score, m=m, l=l, o=o) - - # pv - cur_o = self.dot(rp, rv, acc_dtype=f32) # [block_q, head_size] - self.annotate_layout(cur_o, self.sv_config.lc) - self.assign(o, o + cur_o) - - def main_loop( - self, - gq: GlobalTensor, - gk: GlobalTensor, - gv: GlobalTensor, - o: RegisterTensor, - m: RegisterTensor, - l: RegisterTensor, - ): - # calculate offsets - q_offset = self.blockIdx.x * self.block_q - kv_start_offset = 0 if self.split_kv == -1 else self.blockIdx.y * self.split_kv - - if q_offset + self.block_q <= kv_start_offset: - return - - head = self.blockIdx.z % self.num_heads - bs = self.blockIdx.z // self.num_heads - - sq = self.shared_tensor(dtype=self.dtype, shape=[self.block_q, self.head_size]) - sk = self.shared_tensor(dtype=self.dtype, shape=[self.block_kv, self.head_size]) - sv = self.shared_tensor(dtype=self.dtype, shape=[self.block_kv, self.head_size]) - - rq = self.register_tensor(dtype=self.dtype, shape=[self.block_q, self.head_size]) - - # copy q to shared memory - self.copy_async( - gq, sq, offsets=[bs, q_offset, head, 0], dims=[1, 3], check_bounds=True - ) - self.copy_async_wait_all() - self.sync() - - # copy q to registers if not keeping in shared memory - if self.keep_q_in_regs: - self.load_shared(sq, out=rq) # [block_q, head_size] - self.free_shared(sq) - - # issue a copy of gk - self.copy_async(gk, sk, offsets=[bs, 0, head // self.group_heads, 0], dims=[1, 3]) - self.copy_async_commit_group() - - kv_offset_inner_end = (q_offset + 1) // self.block_kv * self.block_kv - if self.split_kv != -1: - kv_offset_inner_end = min( - kv_offset_inner_end, kv_start_offset + self.split_kv - ) - for kv_offset in range(kv_start_offset, kv_offset_inner_end, self.block_kv): - self.attention_iteration( - bs, - kv_offset, - q_offset, - head, - gk, - gv, - sq, - rq, - sk, - sv, - o, - m, - l, - check_bounds=False, - ) - - kv_offset_end = q_offset + self.block_q - if self.split_kv != -1: - kv_offset_end = min(kv_offset_end, kv_start_offset + self.split_kv) - for kv_offset in range(kv_offset_inner_end, kv_offset_end, self.block_kv): - self.attention_iteration( - bs, - kv_offset, - q_offset, - head, - gk, - gv, - sq, - rq, - sk, - sv, - o, - m, - l, - check_bounds=True, - ) - - self.copy_async_wait_group(0) - self.sync() - self.free_shared(sk) - self.free_shared(sv) - if not self.keep_q_in_regs: - self.free_shared(sq) - - def store_back( - self, - o: RegisterTensor, - l: RegisterTensor, - m: RegisterTensor, - o_ptr: void_p, - batch_size: int, - q_len: int32, - ): - # o: [block_q, head_size] - # m: [block_q, 1] - # l: [block_q, 1] - go = self.global_view( - o_ptr, - dtype=self.dtype, - shape=[batch_size, q_len, self.num_heads, self.head_size], - ) - o = o / l - o_f16 = self.cast(o, dtype=self.dtype) # [block_q, head_size] - so = self.shared_tensor(dtype=self.dtype, shape=[self.block_q, self.head_size]) - - head = self.blockIdx.z % self.num_heads - q_offset = self.blockIdx.x * self.block_q - bs = self.blockIdx.z // self.num_heads - - if self.split_kv == -1: - self.store_shared(so, o_f16) - self.sync() - self.store_global( - go, - self.load_shared(so), - offsets=[bs, q_offset, head, 0], - dims=[1, 3], - ) - else: - num_q_blocks = cdiv(q_len, self.block_q) - semaphores = self.global_tensor( - dtype=int32, - shape=[num_q_blocks, batch_size, self.num_heads], - requires_clean=True, - ) - gm = self.global_tensor( - dtype=f32, - shape=[num_q_blocks, batch_size, self.num_heads, self.block_q], - requires_clean=False, - ) - gl = self.global_tensor( - dtype=f32, - shape=[num_q_blocks, batch_size, self.num_heads, self.block_q], - requires_clean=False, - ) - semaphore = semaphores[self.blockIdx.x, bs, head].item_ptr() - - sm = self.shared_tensor(dtype=f32, shape=[self.block_q]) - sl = self.shared_tensor(dtype=f32, shape=[self.block_q]) - - self.lock_semaphore(semaphore, value=self.blockIdx.y) - - # load previous o, m and l and merge with the current results - if self.blockIdx.y > 0: - self.copy_async(gm, sm, offsets=[self.blockIdx.x, bs, head, 0], dims=[3]) - self.copy_async(gl, sl, offsets=[self.blockIdx.x, bs, head, 0], dims=[3]) - self.copy_async(go, so, offsets=[bs, q_offset, head, 0], dims=[1, 3]) - self.copy_async_wait_all() - self.sync() - lhs_o = self.load_shared(so) - lhs_m = self.load_shared(sm).unsqueeze(1) - lhs_l = self.load_shared(sl).unsqueeze(1) - rhs_o = o_f16 - rhs_m = m - rhs_l = l - m = self.maximum(lhs_m, rhs_m) - lhs_ll = lhs_l * self.exp(lhs_m - m) - rhs_ll = rhs_l * self.exp(rhs_m - m) - l = lhs_ll + rhs_ll - o_f16 = lhs_o * self.cast( - lhs_ll / l, dtype=self.dtype - ) + rhs_o * self.cast(rhs_ll / l, dtype=self.dtype) - self.sync() - - # store the results to so and load it - self.store_shared(so, o_f16) - self.store_shared(sm, m.squeeze(dim=1)) - self.store_shared(sl, l.squeeze(dim=1)) - self.sync() - - # store the results to global memory and release the semaphore - self.store_global( - go, - self.load_shared(so), - offsets=[bs, q_offset, head, 0], - dims=[1, 3], - ) - self.store_global( - gm, - self.load_shared(sm), - offsets=[self.blockIdx.x, bs, head, 0], - dims=[3], - ) - self.store_global( - gl, - self.load_shared(sl), - offsets=[self.blockIdx.x, bs, head, 0], - dims=[3], - ) - self.sync() - - self.free_shared(sm) - self.free_shared(sl) - - # release the semaphore - self.release_semaphore( - semaphore, - value=self.blockIdx.y + 1 - if (self.blockIdx.y + 1) * self.split_kv < q_offset + self.block_q - else 0, - ) - self.free_shared(so) - - def __call__( - self, - batch_size: int, - q_len: int32, - kv_len: int32, - q_ptr: void_p, - k_ptr: void_p, - v_ptr: void_p, - o_ptr: void_p, - ): - self.attrs.warps = self.num_warps - self.attrs.blocks = ( - cdiv(q_len, self.block_q), - cdiv(kv_len, self.split_kv) if self.split_kv != -1 else 1, - self.num_heads * batch_size, - ) - - gq = self.global_view( - q_ptr, - dtype=self.dtype, - shape=[batch_size, q_len, self.num_heads, self.head_size], - ) - gk = self.global_view( - k_ptr, - dtype=self.dtype, - shape=[batch_size, kv_len, self.num_heads_kv, self.head_size], - ) - gv = self.global_view( - v_ptr, - dtype=self.dtype, - shape=[batch_size, kv_len, self.num_heads_kv, self.head_size], - ) - - o = self.register_tensor( - dtype=f32, shape=[self.block_q, self.head_size], init=0.0 - ) - m = self.register_tensor( - dtype=f32, shape=[self.block_q, 1], init=-1e6 - ) # rowmax(score) - l = self.register_tensor( - dtype=f32, shape=[self.block_q, 1], init=0.0 - ) # rowsum(exp(score - m)) - - self.main_loop(gq, gk, gv, o, m, l) - - self.store_back(o, l, m, o_ptr=o_ptr, batch_size=batch_size, q_len=q_len) - - -def flash_attention( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, -): - """ - Flash attention function for variable length sequences. - - Parameters - ---------- - q: torch.Tensor - The query tensor of shape (bs, seqlen, num_heads, head_size). - - k: torch.Tensor - The key tensor of shape (bs, seqlen, num_heads_kv, head_size). - - v: torch.Tensor - The value tensor of shape (bs, seqlen, num_heads_kv, head_size). - - Returns - ------- - o: torch.Tensor - The output tensor of shape (bs, seqlen, num_heads, head_size). - """ - out = torch.empty_like(q) - FlashAttention( - dtype=tilus.float16, - num_heads=q.size(2), - num_heads_kv=k.size(2), - head_size=q.size(3), - )(q.size(0), q.size(1), k.size(1), q, k, v, out) - return out - - -def flash_attention_reference( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, -): - bs, seqlen, num_heads, head_size = q.size() - _, _, num_heads_kv, _ = k.size() - assert q.size(0) == k.size(0) == v.size(0), "Batch size must match for q, k, and v." - assert q.size(1) == k.size(1) == v.size(1), ( - "Sequence length must match for q, k, and v." - ) - assert q.size(3) == k.size(3) == v.size(3), "Head size must match for q, k, and v." - assert k.size(2) == v.size(2), "Number of heads in k and v must match." - assert num_heads % num_heads_kv == 0, ( - "Number of heads must be divisible by number of kv heads." - ) - - k = torch.repeat_interleave(k, num_heads // num_heads_kv, dim=2) - v = torch.repeat_interleave(v, num_heads // num_heads_kv, dim=2) - - q = torch.transpose(q, 1, 2).reshape(bs * num_heads, seqlen, head_size) - k = torch.transpose(k, 1, 2).reshape(bs * num_heads, seqlen, head_size) - v = torch.transpose(v, 1, 2).reshape(bs * num_heads, seqlen, head_size) - - score = torch.bmm(q, k.mT) / np.sqrt(head_size) # [bs * num_heads, seqlen, seqlen] - causal_mask = torch.tril(torch.ones(seqlen, seqlen, dtype=torch.bool), diagonal=0).to( - q.device - ) - causal_mask = causal_mask.unsqueeze(0) # [1, seqlen, seqlen] - causal_mask = causal_mask.expand( - bs * num_heads, seqlen, seqlen - ).contiguous() # [bs * num_heads, seqlen, seqlen] - score = score.masked_fill(causal_mask == 0, float("-inf")) - - o = torch.bmm( - torch.softmax(score.float(), dim=-1).to(q.dtype), v - ) # [bs * num_heads, seqlen, head_size] - o = o.reshape(bs, num_heads, seqlen, head_size).transpose(1, 2).contiguous() - return o - - -def flash_attention_flash_attn( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, -): - try: - from flash_attn.cute.interface import flash_attn_func - - return flash_attn_func(q, k, v, causal=True) - except ImportError: - return flash_attention_reference(q, k, v) - - -def demo_flash_attention(): - for bs, seqlen, num_heads, head_size, num_heads_kv in [ - # [1, 8, 1, 64, 1], - [1, 4096, 32, 128, 8] - ]: - q = torch.rand(bs, seqlen, num_heads, head_size, dtype=torch.float16).cuda() - k = torch.rand(bs, seqlen, num_heads_kv, head_size, dtype=torch.float16).cuda() - v = torch.rand(bs, seqlen, num_heads_kv, head_size, dtype=torch.float16).cuda() - flash_attention(q, k, v) - torch.cuda.synchronize() - - -def main(bench=True): - headers = [ - "batch_size", - "seqlen", - "num_heads", - "head_size", - "num_heads_kv", - "name", - "latency (ms)", - "tflops", - ] - data = [] - for batch_size, seqlen, num_heads, head_size, num_heads_kv in [ - [1, 512, 32, 128, 8], - [1, 1024, 32, 128, 8], - [1, 2048, 32, 128, 8], - [1, 4096, 32, 128, 8], - #[1, 8192, 32, 128, 8], - [1, 512, 64, 128, 8], - [1, 1024, 64, 128, 8], - [1, 2048, 64, 128, 8], - [1, 4096, 64, 128, 8], - #[1, 8192, 64, 128, 8], - ]: - q = torch.rand( - batch_size, seqlen, num_heads, head_size, dtype=torch.float16 - ).cuda() - k = torch.rand( - batch_size, seqlen, num_heads_kv, head_size, dtype=torch.float16 - ).cuda() - v = torch.rand( - batch_size, seqlen, num_heads_kv, head_size, dtype=torch.float16 - ).cuda() - for name, runner in [ - ("flash-attn", flash_attention_flash_attn), - ("tilus", flash_attention), - ]: - print( - f"Running {name} with batch_size={batch_size}, seqlen={seqlen}, num_heads={num_heads}, head_size={head_size}, num_heads_kv={num_heads_kv}" - ) - try: - actual = runner(q, k, v) - except torch.OutOfMemoryError: - print("Out of memory, skipping this configuration.") - continue - - try: - expected = flash_attention_reference(q, k, v) - torch.testing.assert_close(actual, expected, atol=1e-2, rtol=1e-2) - except torch.OutOfMemoryError: - pass - - latency = ( - benchmark_func( - lambda: runner(q, k, v), - warmup=20, - repeat=50, - ) - if bench - else float("nan") - ) - tflops = ( - 2 * batch_size * num_heads * seqlen * head_size * seqlen / latency * 1e-9 - ) - data.append( - [ - batch_size, - seqlen, - num_heads, - head_size, - num_heads_kv, - name, - latency, - tflops, - ] - ) - df = pd.DataFrame(data, columns=headers) - df_pivot = df.pivot( - index=[ - "batch_size", - "seqlen", - "num_heads", - "head_size", - "num_heads_kv", - ], - columns="name", - values=["latency (ms)", "tflops"], - ).reset_index() - # sort by (batch_size, num_heads, head_size, seqlen) - df_pivot = df_pivot.sort_values( - by=["batch_size", "num_heads", "head_size", "seqlen"], - ascending=[True, True, True, True], - ) - print(df_pivot) - - -if __name__ == "__main__": - main() - # ncu_run(main, bench=False, kernel_regex="flash_fwd|flash_attention") \ No newline at end of file diff --git a/examples/generic_python/flash_attention/triton_attention.py b/examples/generic_python/flash_attention/triton_attention.py deleted file mode 100644 index e4f36a133..000000000 --- a/examples/generic_python/flash_attention/triton_attention.py +++ /dev/null @@ -1,276 +0,0 @@ -# example taken from https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html -# removed optional support for FP8 -# Removed backward pass -# removed checks for e.g. hopper/hip etc -import pytest -import torch -import os -import numpy as np - -import triton -import triton.language as tl -from triton.tools.tensor_descriptor import TensorDescriptor - -DEVICE = triton.runtime.driver.active.get_active_torch_device() - -configs = [ - triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w) \ - for BM in [64, 128]\ - for BN in [32, 64, 128]\ - for s in [2, 3, 4] \ - for w in [4, 8]\ -] - - -def prune_invalid_configs(configs, named_args, **kwargs): - N_CTX = kwargs["N_CTX"] - STAGE = kwargs["STAGE"] - - # Filter out configs where BLOCK_M > N_CTX - # Filter out configs where BLOCK_M < BLOCK_N when causal is True - return [ - conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX and ( - conf.kwargs.get("BLOCK_M", 0) >= conf.kwargs.get("BLOCK_N", 0) or STAGE == 1) - ] - - -@triton.jit -def _attn_fwd_inner(acc, l_i, m_i, q, # - desc_k, desc_v, # - offset_y, dtype: tl.constexpr, start_m, qk_scale, # - BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # - STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # - N_CTX: tl.constexpr, warp_specialize: tl.constexpr): - # range of values handled by this stage - if STAGE == 1: - lo, hi = 0, start_m * BLOCK_M - elif STAGE == 2: - lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M - lo = tl.multiple_of(lo, BLOCK_M) - # causal = False - else: - lo, hi = 0, N_CTX - offsetk_y = offset_y + lo - if dtype == tl.float8e5: - offsetv_y = offset_y * HEAD_DIM + lo - else: - offsetv_y = offset_y + lo - # loop over k, v and update accumulator - for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = desc_k.load([offsetk_y, 0]).T - qk = tl.dot(q, k) - if STAGE == 2: - mask = offs_m[:, None] >= (start_n + offs_n[None, :]) - qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk -= m_ij[:, None] - else: - m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) - qk = qk * qk_scale - m_ij[:, None] - p = tl.math.exp2(qk) - # -- compute correction factor - alpha = tl.math.exp2(m_i - m_ij) - l_ij = tl.sum(p, 1) - # -- update output accumulator -- - if warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128: - BM: tl.constexpr = acc.shape[0] - BN: tl.constexpr = acc.shape[1] - acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() - acc0 = acc0 * alpha[:, None] - acc1 = acc1 * alpha[:, None] - acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN]) - else: - acc = acc * alpha[:, None] - # prepare p and v for the dot - if dtype == tl.float8e5: - v = desc_v.load([0, offsetv_y]).T - else: - v = desc_v.load([offsetv_y, 0]) - p = p.to(dtype) - # note that this non transposed v for FP8 is only supported on Blackwell - acc = tl.dot(p, v, acc) - # update m_i and l_i - # place this at the end of the loop to reduce register pressure - l_i = l_i * alpha + l_ij - m_i = m_ij - offsetk_y += BLOCK_N - offsetv_y += BLOCK_N - return acc, l_i, m_i - - -@triton.jit -def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape): - if isinstance(desc_or_ptr, tl.tensor_descriptor): - return desc_or_ptr - else: - return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape) - - -@triton.autotune(configs=configs, key=["N_CTX", "HEAD_DIM", "warp_specialize"], - prune_configs_by={'early_config_prune': prune_invalid_configs}) -@triton.jit -def _attn_fwd(sm_scale, M, # - Z, H, desc_q, desc_k, desc_v, desc_o, N_CTX, # - HEAD_DIM: tl.constexpr, # - BLOCK_M: tl.constexpr, # - BLOCK_N: tl.constexpr, # - STAGE: tl.constexpr, # - warp_specialize: tl.constexpr, # - ): - dtype = tl.float16 - tl.static_assert(BLOCK_N <= HEAD_DIM) - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - off_z = off_hz // H - off_h = off_hz % H - - y_dim = Z * H * N_CTX - desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=[BLOCK_M, HEAD_DIM]) - desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=[BLOCK_N, HEAD_DIM]) - desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=[BLOCK_N, HEAD_DIM]) - desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], - block_shape=[BLOCK_M, HEAD_DIM]) - - offset_y = off_z * (N_CTX * H) + off_h * N_CTX - qo_offset_y = offset_y + start_m * BLOCK_M - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) - # load scales - qk_scale = sm_scale - qk_scale *= 1.44269504 # 1/log(2) - # load q: it will stay in SRAM throughout - q = desc_q.load([qo_offset_y, 0]) - # stage 1: off-band - # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE - # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE - if STAGE & 1: - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, # - desc_k, desc_v, # - offset_y, dtype, start_m, qk_scale, # - BLOCK_M, HEAD_DIM, BLOCK_N, # - 4 - STAGE, offs_m, offs_n, N_CTX, # - warp_specialize) - # stage 2: on-band - if STAGE & 2: - acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, # - desc_k, desc_v, # - offset_y, dtype, start_m, qk_scale, # - BLOCK_M, HEAD_DIM, BLOCK_N, # - 2, offs_m, offs_n, N_CTX, # - warp_specialize) - # epilogue - m_i += tl.math.log2(l_i) - acc = acc / l_i[:, None] - m_ptrs = M + off_hz * N_CTX + offs_m - tl.store(m_ptrs, m_i) - desc_o.store([qo_offset_y, 0], acc.to(dtype)) - - - -def forward(q, k, v, causal, sm_scale, warp_specialize=True): - # shape constraints - HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] - # when v is in float8_e5m2 it is transposed. - HEAD_DIM_V = v.shape[-1] - assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V - assert HEAD_DIM_K in {16, 32, 64, 128, 256} - o = torch.empty_like(q) - stage = 3 if causal else 1 - extra_kern_args = {} - M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - desc_q = q - desc_v = v - desc_k = k - desc_o = o - - def grid(META): - return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1) - - _attn_fwd[grid]( - sm_scale, M, # - q.shape[0], q.shape[1], # - desc_q, desc_k, desc_v, desc_o, # - N_CTX=q.shape[2], # - HEAD_DIM=HEAD_DIM_K, # - STAGE=stage, # - warp_specialize=warp_specialize, # - **extra_kern_args) - - return o - - - - - -if __name__ == "__main__": - import torch - -# Triton kernel (your forward function) -# from your previous code: forward(ctx, Q, K, V, causal, sm_scale, warp_specialize=True) - -def flash_attention_reference(q, k, v, causal=False, sm_scale=1.0): - """ - Reference attention using batched matmul (like the pytest does). - Args: - q, k, v: [batch, heads, seq_len, head_dim] - causal: whether to apply causal mask - sm_scale: scaling factor for attention scores - Returns: - output: [batch, heads, seq_len, head_dim] - """ - # Batched matrix multiply - # scores: [batch, heads, seq_len, seq_len] - scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale - - if causal: - seq_len = q.shape[2] - mask = torch.tril(torch.ones(seq_len, seq_len, device=q.device, dtype=torch.bool)) - scores = scores.masked_fill(~mask[None, None, :, :], float("-inf")) - - # Softmax along last dimension - attention = torch.softmax(scores.float(), dim=-1) - - # Weighted sum with V - output = torch.matmul(attention, v) - return output - - - -# Example test -if __name__ == "__main__": - forward_simple() - exit(0) - - # Parameters - B, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 - causal = True - sm_scale = 1.0 - - # Random inputs - Q = torch.randn(B, H, N_CTX, HEAD_DIM, dtype=torch.float16, device="cuda") - K = torch.randn_like(Q) - V = torch.randn_like(Q) - - # Triton output - output_triton = forward(Q, K, V, causal=causal, sm_scale=sm_scale, warp_specialize=True) - - # Reference output - output_ref = flash_attention_reference(Q.float(), K.float(), V.float(), causal=causal, sm_scale=sm_scale) - - # Compare - max_diff = (output_triton.float() - output_ref).abs().max() - mean_diff = (output_triton.float() - output_ref).abs().mean() - print(f"Max difference: {max_diff.item():.6f}") - print(f"Mean difference: {mean_diff.item():.6f}") - - diff --git a/examples/generic_python/loopy_example.py b/examples/generic_python/loopy_example.py deleted file mode 100644 index 49ad131f6..000000000 --- a/examples/generic_python/loopy_example.py +++ /dev/null @@ -1,38 +0,0 @@ -import numpy as np - -import pyopencl as cl -import pyopencl.array - -import loopy as lp -from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa: F401 - - -# setup -# ----- -ctx = cl.create_some_context() -queue = cl.CommandQueue(ctx) - -n = 15 * 10**6 -a = cl.array.arange(queue, n, dtype=np.float32) - -# create -# ------ -knl = lp.make_kernel( - "{ [i]: 0<=i torch.Tensor: - m, k = x.size() - k, n = y.size() - out = torch.empty([m, n], dtype=x.dtype, device=x.device) - - for tile_m, tile_n in hl.tile([m, n]): - acc = hl.zeros([tile_m, tile_n], dtype=torch.float32) - for tile_k in hl.tile(k): - acc = torch.addmm(acc, x[tile_m, tile_k], y[tile_k, tile_n]) - out[tile_m, tile_n] = acc - - return out - - -out = matmul(torch.randn([2048, 2048], device="cuda"), - torch.randn([2048, 2048], device="cuda")) \ No newline at end of file diff --git a/examples/generic_python/matmul/numba_matmul.py b/examples/generic_python/matmul/numba_matmul.py index 6486ec540..71a6b1b3a 100644 --- a/examples/generic_python/matmul/numba_matmul.py +++ b/examples/generic_python/matmul/numba_matmul.py @@ -250,10 +250,10 @@ def tune_optimized(M, N, K): M, N, K = 1024, 1024, 1024 run_basic(M, N, K) - #tune_basic(M, N, K) + tune_basic(M, N, K) - #run_optimized(M, N, K) - #tune_optimized(M, N, K) + run_optimized(M, N, K) + tune_optimized(M, N, K) diff --git a/examples/generic_python/matmul/test.py b/examples/generic_python/matmul/test.py deleted file mode 100644 index 991ef2cc4..000000000 --- a/examples/generic_python/matmul/test.py +++ /dev/null @@ -1,278 +0,0 @@ -import argparse -import itertools -import tilelang as tl -import tilelang.language as T -from tilelang.autotuner import AutoTuner -from tilelang.carver.template import MatmulTemplate -from tilelang.carver.arch import CUDA -from tilelang.carver.arch import CDNA -from tilelang.carver.roller.rasterization import NoRasterization -import torch - - -def ref_program(A, B): - """ - Compute the matrix product of A and the transpose of B. - - A and B are expected to be 2-D tensors where A has shape (M, K) and B has shape (N, K). The result is a tensor with shape (M, N) equal to A @ B.T, using the inputs' dtypes. - """ - return A @ B.T - - -def get_configs(M, N, K, with_roller=False, topk=20): - """ - Generate a list of kernel tuning configuration dictionaries for a tiled matrix-multiply. - - When with_roller is True this queries the MatmulTemplate roller to produce up to `topk` recommended - configurations (device-specific TensorCore-friendly tilings). Each returned dict contains: - - block_M, block_N, block_K: tile sizes - - num_stages: pipeline staging (0 means no explicit staging) - - thread_num: total threads used for the block - - enable_rasteration: whether a rasterization/swizzle layout was recommended (note spelling) - - When with_roller is False this returns the Cartesian product of a fixed set of candidate - parameters; the returned dicts use the backward-compatible key name "enable_rasteration" for that flag. - - Parameters: - M, N, K (int): GEMM dimensions used to generate valid tile sizes. - with_roller (bool): If True, use MatmulTemplate's roller to generate device-aware hints; - otherwise use a predefined candidate grid. - topk (int): Maximum number of roller hints to request when with_roller is True. - - Returns: - List[dict]: A list of configuration dictionaries as described above. - - Raises: - ValueError: if with_roller is True but the roller returns no hints. - """ - if with_roller: - arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip") - carve_template = MatmulTemplate( - M=M, - N=N, - K=K, - in_dtype=T.float16, - out_dtype=T.float16, - accum_dtype=T.float32, - ).with_arch(arch) - - func = carve_template.equivalent_function() - assert func is not None, "Function is None" - roller_hints = carve_template.recommend_hints(topk=topk) - if roller_hints is None: - raise ValueError("No Roller Hints Found for TensorCore Scheduling") - configs = [] - for hint in roller_hints: - config = {} - block_m, block_n = hint.block - warp_m, warp_n = hint.warp - # block_rows, block_cols represents warp partitioning - block_rows, block_cols = block_m // warp_m, block_n // warp_n - config["block_M"] = block_m - config["block_N"] = block_n - config["block_K"] = hint.rstep[0] - config["num_stages"] = hint.pipeline_stage if hint.pipeline_stage > 1 else 0 - config["thread_num"] = block_rows * block_cols * 32 - config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization - configs.append(config) - else: - block_M = [64, 128, 256] - block_N = [64, 128, 256] - block_K = [32, 64] - num_stages = [0, 1, 2, 3] - thread_num = [128, 256] - enable_rasterization = [True, False] - _configs = list( - itertools.product( - block_M, - block_N, - block_K, - num_stages, - thread_num, - enable_rasterization, - ) - ) - - configs = [ - { - "block_M": c[0], - "block_N": c[1], - "block_K": c[2], - "num_stages": c[3], - "thread_num": c[4], - "enable_rasteration": c[5], # keep param name for backward-compat - } - for c in _configs - ] - return configs - - -def get_best_config( - M, - N, - K, - with_roller: bool = False, - profile_backend: str = "event", -): - def kernel( - block_M=None, - block_N=None, - block_K=None, - num_stages=None, - thread_num=None, - enable_rasteration=None, - ): - dtype = T.bfloat16 - accum_dtype = T.float32 - - @T.prim_func - def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_N, block_K), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), dtype) - T.use_swizzle(panel_size=10, enable=enable_rasteration) - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_shared, - B_shared, - C_local, - transpose_B=True, - ) - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return main - - autotuner = ( - AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller)) - .set_compile_args( - out_idx=[-1], - target="auto", - ) - .set_profile_args( - supply_type=tl.TensorSupplyType.Integer, - ref_prog=ref_program, - skip_check=False, - backend=profile_backend, - ) - ) - return autotuner.run(warmup=3, rep=20) - - -def get_heuristic_config() -> dict: - # Get CUDA device properties - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is not available") - device = torch.cuda.current_device() - sm_major, sm_minor = torch.cuda.get_device_capability(device) - sm_version = sm_major * 10 + sm_minor - print(f"CUDA device capability: {sm_version}") - if sm_version in {80}: - return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True} - elif sm_version in {90}: - return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True} - else: - return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True} - - -@tl.jit(out_idx=[-1]) -def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32): - @T.prim_func - def gemm_autotune( - A: T.Tensor((M, K), dtype), - B: T.Tensor((N, K), dtype), - C: T.Tensor((M, N), dtype), - ): - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_N, block_K), dtype) - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - C_shared = T.alloc_shared((block_M, block_N), dtype) - T.use_swizzle(panel_size=10, enable=enable_rasteration) - T.clear(C_local) - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm( - A_shared, - B_shared, - C_local, - transpose_B=True, - ) - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return gemm_autotune - - -def main( - M: int = 4096, - N: int = 4096, - K: int = 4096, - use_autotune: bool = False, - with_roller: bool = False, - profile_backend: str = "event", -): - if use_autotune: - result = get_best_config( - M, - N, - K, - with_roller=with_roller, - profile_backend=profile_backend, - ) - print(result.config) - kernel = result.kernel - else: - config = get_heuristic_config() - kernel = matmul(M, N, K, **config) - - # benchmark - profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) - tilelang_latency = profiler.do_bench( - backend=profile_backend, - ) - ref_latency = profiler.do_bench( - ref_program, - backend=profile_backend, - ) - profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) - print(f"TileLang latency: {tilelang_latency}") - print(f"Ref latency: {ref_latency}") - print(f"TileLang TFlops: {2 * M * N * K / tilelang_latency * 1e-9}") - print(f"Ref TFlops: {2 * M * N * K / ref_latency * 1e-9}") - - -def run_regression_perf(M: int = 4096, N: int = 4096, K: int = 4096): - config = get_heuristic_config() - kernel = matmul(M, N, K, **config) - profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) - return profiler.do_bench(backend="cupti") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") - parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M") - parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N") - parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K") - parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs") - parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space") - parser.add_argument("--profile_backend", type=str, default="event", help="Profiler backend") - args = parser.parse_args() - main( - args.m, - args.n, - args.k, - args.use_autotune, - args.with_roller, - args.profile_backend, - ) \ No newline at end of file diff --git a/examples/generic_python/matmul/test_tilus.py b/examples/generic_python/matmul/test_tilus.py deleted file mode 100644 index 6d4b2430d..000000000 --- a/examples/generic_python/matmul/test_tilus.py +++ /dev/null @@ -1,26 +0,0 @@ -import torch -from tilus_matmul import MatmulBasic - -sizes = [ - (65, 65, 17), - (67, 71, 19), - (1, 1, 1), - (63, 63, 15), - (129, 130, 33), -] - -matmul = MatmulBasic() -for m, n, k in sizes: - print(m, n, k) - - a = torch.randn(m, k, dtype=torch.float16, device="cuda") - b = torch.randn(k, n, dtype=torch.float16, device="cuda") - - c_actual = torch.empty(m, n, dtype=torch.float16, device="cuda") - c_expect = a @ b - - matmul(m, n, k, a, b, c_actual) - - torch.cuda.synchronize() - - torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2) \ No newline at end of file diff --git a/examples/generic_python/matmul/tilelang_matmul.py b/examples/generic_python/matmul/tilelang_matmul.py index 0c476c786..751c0add4 100644 --- a/examples/generic_python/matmul/tilelang_matmul.py +++ b/examples/generic_python/matmul/tilelang_matmul.py @@ -315,10 +315,10 @@ def tune_opt(M, N, K): if __name__ == "__main__": M, N, K = 4096, 4096, 4096 - #run_basic(M, N, K) - #run_opt(M, N, K) + run_basic(M, N, K) + run_opt(M, N, K) tune_basic(M, N, K) - #tune_opt(M, N, K) + tune_opt(M, N, K) diff --git a/examples/generic_python/matmul/tilus_matmul.py b/examples/generic_python/matmul/tilus_matmul.py index 811d5c2d7..c65cdeaa7 100644 --- a/examples/generic_python/matmul/tilus_matmul.py +++ b/examples/generic_python/matmul/tilus_matmul.py @@ -257,11 +257,7 @@ def tune_opt(M, N, K): if __name__ == "__main__": M, N, K = 4096, 4096, 4096 - #M, N, K = 8192, 8192, 8192 - #run_basic(M, N, K) - - #run_optmized(M, N, K) - + run_basic(M, N, K) + run_optimized(M, N, K) tune_basic(M, N, K) - - #tune_opt(M, N, K) \ No newline at end of file + tune_opt(M, N, K) \ No newline at end of file diff --git a/examples/generic_python/matmul/triton_matmul.py b/examples/generic_python/matmul/triton_matmul.py index cac3cab73..d5522c627 100644 --- a/examples/generic_python/matmul/triton_matmul.py +++ b/examples/generic_python/matmul/triton_matmul.py @@ -305,10 +305,8 @@ def tune_opt(M, N, K): if __name__ == "__main__": M, N, K = 4096, 4096, 4096 - #M, N, K = 8192, 8192, 8192 - #run_basic(M, N, K) - #run_opt(M, N, K) + run_basic(M, N, K) + run_opt(M, N, K) - - #tune_basic(M, N, K) - #tune_opt(M, N, K) + tune_basic(M, N, K) + tune_opt(M, N, K) diff --git a/examples/generic_python/matmul/warp_matmul.py b/examples/generic_python/matmul/warp_matmul.py index 388ae3a1b..c4a218b21 100644 --- a/examples/generic_python/matmul/warp_matmul.py +++ b/examples/generic_python/matmul/warp_matmul.py @@ -1,4 +1,3 @@ -import torch import numpy as np import warp as wp @@ -30,38 +29,6 @@ def gemm( C[i, j] = wp.float16(sum) -# tile size -TILE_M = 32 -TILE_N = 32 -TILE_K = 32 - -# num threads per-tile -TILE_THREADS = 1024 - -# GEMM example from https://nvidia.github.io/warp/user_guide/tiles.html -@wp.kernel() -def tile_gemm( - A: wp.array2d(dtype=wp.float16), - B: wp.array2d(dtype=wp.float16), - C: wp.array2d(dtype=wp.float16) -): - # output tile index - i, j = wp.tid() - - sum = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float32) - - K = A.shape[1] - count = (K + TILE_K - 1) // TILE_K - - for k in range(0, count): - a = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i*TILE_M, k*TILE_K), bounds_check=True) - b = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k*TILE_K, j*TILE_N), bounds_check=True) - - wp.tile_matmul(a, b, sum) - - wp.tile_store(C, wp.tile_astype(sum, wp.float16), offset=(i*TILE_M, j*TILE_N), bounds_check=True) - - def run_gemm(M, N, K): rng = np.random.default_rng(42) A = rng.random((M, K)).astype(np.float16) @@ -79,28 +46,6 @@ def run_gemm(M, N, K): print("Succes") -def run_gemm_tiled(M, N, K): - rng = np.random.default_rng(42) - A = rng.random((M, K)).astype(np.float16) - B = rng.random((K, N)).astype(np.float16) - C = np.zeros((M, N), dtype=np.float16) - - A_wp = wp.array(A) - B_wp = wp.array(B) - C_wp = wp.array(C) - - - wp.launch_tiled( - tile_gemm, - dim=((M + TILE_M - 1) // TILE_M, (N + TILE_N - 1) // TILE_N), - inputs=[A_wp, B_wp, C_wp], - block_dim=TILE_THREADS) - - np.testing.assert_allclose(C_wp.numpy(), A @ B, rtol=1e-2, atol=M * 2**(-11)) - - print("Succes") - - def tune(M, K, N): rng = np.random.default_rng(42) A = rng.random((M, K)).astype(np.float16) @@ -132,9 +77,7 @@ def tune(M, K, N): if __name__ == "__main__": M, N, K = 1024, 1024, 1024 - #run_gemm(M, N, K) - #run_gemm_tiled(M, N, K) - + run_gemm(M, N, K) tune(M, N, K) diff --git a/examples/generic_python/matmul_old/tilelang_matmul.py b/examples/generic_python/matmul_old/tilelang_matmul.py deleted file mode 100644 index 2a7b5154d..000000000 --- a/examples/generic_python/matmul_old/tilelang_matmul.py +++ /dev/null @@ -1,212 +0,0 @@ -import tilelang -import tilelang.language as T -import torch -from kernel_tuner import tune_kernel - - -#@tilelang.jit -def matmul(M, N, K, block_M, block_N, block_K, dtype: str = 'float16', accum_dtype: str = 'float32'): - @T.prim_func - def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), - ): - # Define a grid with enough blocks to cover M×N - num_threads=128 - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): - - # Allocate shared memory for the current tile of A and B - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) - - # Allocate a local (register) fragment for partial accumulations - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - # Enable swizzle-based rasterization for better L2 locality - panel_size = 4 - T.use_swizzle(panel_size=panel_size, enable=True) - - # Initialize the local accumulation buffer to zero - T.clear(C_local) - - num_stages=3 - - # Loop over the K dimension in block_K chunks, using a pipeline - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - # Copy from global memory to shared memory - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[k * block_K, bx * block_N], B_shared) - - # Perform a matrix multiply-accumulate on the tile - T.gemm(A_shared, B_shared, C_local) - - # Copy the accumulated result from local memory (C_local) to global memory (C) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - -@tilelang.jit -def matmul_with_decorator(M, N, K, block_M, block_N, block_K, dtype: str = 'float16', accum_dtype: str = 'float32'): - @T.prim_func - def main( - A: T.Tensor((M, K), dtype), - B: T.Tensor((K, N), dtype), - C: T.Tensor((M, N), dtype), - ): - # Define a grid with enough blocks to cover M×N - with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): - - # Allocate shared memory for the current tile of A and B - A_shared = T.alloc_shared((block_M, block_K), dtype) - B_shared = T.alloc_shared((block_K, block_N), dtype) - - # Allocate a local (register) fragment for partial accumulations - C_local = T.alloc_fragment((block_M, block_N), accum_dtype) - - # Initialize the local accumulation buffer to zero - T.clear(C_local) - - # Loop over the K dimension in block_K chunks, using a 3-stage pipeline - for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): - # Copy from global memory to shared memory - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[k * block_K, bx * block_N], B_shared) - - # Perform a matrix multiply-accumulate on the tile - T.gemm(A_shared, B_shared, C_local) - - # Copy the accumulated result from local memory (C_local) to global memory (C) - T.copy(C_local, C[by * block_M, bx * block_N]) - - return main - - - -def run(m, n, k): - a = torch.randn(m, k, device="cuda", dtype=torch.float16) - b = torch.randn(k, n, device="cuda", dtype=torch.float16) - c = torch.empty(m, n, device="cuda", dtype=torch.float16) - kernel = matmul(m, n, k, 128, 128, 32) - kernel(a, b, c) - ref_c = a @ b - tol = m * 2**(-11) - # Validate results - torch.testing.assert_close(c, ref_c, rtol=tol, atol=tol) - - -def call_tilelang(kernel_function, args, kwargs, grid, threads, params): - compiled_kernel = kernel_function(**kwargs) - compiled_kernel(*args) - - -def time(m, n, k): - a = torch.randn(m, k, device="cuda", dtype=torch.float16) - b = torch.randn(k, n, device="cuda", dtype=torch.float16) - c = torch.empty(m, n, device="cuda", dtype=torch.float16) - c_ans = a @ b - - args = [a, b, c] - tune_params = dict() - tune_params["M"] = [m] - tune_params["K"] = [k] - tune_params["N"] = [n] - tune_params["block_M"] = [64, 128] - tune_params["block_N"] = [64, 128] - tune_params["block_K"] = [32, 64] - - results_kt, env = tune_kernel("matmul", matmul, m * n, args, tune_params, lang="generic_python", - call_function=call_tilelang, decorator="@tilelang.jit", verbose=False, iterations=100) - - import time - num_repeats = 100 - times_direct = [] - for config in results_kt: - bs_m = config["block_M"] - bs_n = config["block_N"] - bs_k = config["block_K"] - - c = torch.empty(m, n, device="cuda", dtype=torch.float16) - - kernel = matmul_with_decorator(m, n, k, bs_m, bs_n, bs_k) - kernel(a, b, c) - - torch.allclose(c.cpu(), c_ans.cpu(), atol=m * 2**(-11)) - - for i in range(num_repeats): - times = [] - - torch.cuda.synchronize() - start = time.time() - kernel(a, b, c) - torch.cuda.synchronize() - times.append(time.time() - start) - - avg_time_ms = round((1000 * sum(times) / len(times)), 3) - times_direct.append(avg_time_ms) - print(f"BLOCK_SIZE_M={bs_m}, BLOCK_SIZE_N={bs_n}, BLOCK_SIZE_K={bs_k}, time={avg_time_ms}ms") - - - import matplotlib.pyplot as plt - - # Extract times - times_kt = [cfg['time'] for cfg in results_kt] - - # x-axis labels - configs = [f"config{i}" for i in range(len(times_kt))] - x = range(len(configs)) - - plt.figure(figsize=(10,6)) - plt.plot(configs, times_kt, marker='s', label='KernelTuner') - plt.plot(configs, times_direct, marker='x', label='Direct') - plt.ylabel('Time (ms)') - plt.xlabel('Configuration') - plt.title('Kernel execution time per configuration') - plt.xticks(rotation=45) - plt.legend() - plt.grid(True) - plt.tight_layout() - plt.savefig("tilelang.png") - print("saved fig") - - -def tune(m, n, k): - a = torch.randn(m, k, device="cuda", dtype=torch.float16) - b = torch.randn(k, n, device="cuda", dtype=torch.float16) - c = torch.empty(m, n, device="cuda", dtype=torch.float16) - c_actual = a @ b - - args = [a, b, c] - tune_params = dict() - tune_params["M"] = [m] - tune_params["K"] = [k] - tune_params["N"] = [n] - tune_params["block_M"] = [64, 128, 256] - tune_params["block_N"] = [64, 128, 256] - tune_params["block_K"] = [32, 64, 128] - tune_params["num_stages"] = [2, 3, 4] - tune_params["panel_size"] = [4, 8] # equivalent to group size m in Triton - tune_params["num_threads"] = [64, 128, 256] - - restrictions = [ - # tile size budget - "block_M * block_N <= 16384", - - # aspect ratio <= 4 (no max/min allowed, so expand manually) - "block_M <= 4 * block_N", - "block_N <= 4 * block_M", - - # large K only with reasonably large M/N - "not (block_K == 128 and block_M < 64 and block_N < 64)", - ] - - tol = m * 2**(-11) - answer = [None, None, c_actual.cpu()] - - results, env = tune_kernel("matmul", matmul, m * n, args, tune_params, atol=tol, lang="generic_python", - call_function=call_tilelang, restrictions=restrictions, answer=answer, decorator="@tilelang.jit", verbose=False) - -if __name__ == "__main__": - #m, n, k = 1024, 1024, 1024 - m, n, k = 8192, 8192, 8192 - time(m, n, k) \ No newline at end of file diff --git a/examples/generic_python/matmul_old/tilus_matmul.py b/examples/generic_python/matmul_old/tilus_matmul.py deleted file mode 100644 index e8996c201..000000000 --- a/examples/generic_python/matmul_old/tilus_matmul.py +++ /dev/null @@ -1,266 +0,0 @@ -import math - -import pandas -import tilus -import torch -from tilus import float16, float32, int32 -from tilus.utils import benchmark_func -from kernel_tuner import tune_kernel, run_kernel - - - -class MatmulV4(tilus.Script): - def __init__(self): - super().__init__() - self.block_m = 128 - self.block_n = 128 - self.block_k = 16 - self.num_warps = 4 - self.num_stages = 4 - - def __call__( - self, - m_size: int32, - n_size: int, - k_size: int, - a_ptr: ~float16, - b_ptr: ~float16, - c_ptr: ~float16, - ): - self.attrs.blocks = [ - self.utils.ceil_div(m_size, self.block_m), - self.utils.ceil_div(n_size, self.block_n), - ] - self.attrs.warps = self.num_warps - - block_m, block_n, block_k = self.block_m, self.block_n, self.block_k - offset_m: int32 = block_m * self.blockIdx.x - offset_n: int32 = block_n * self.blockIdx.y - - ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) - gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size]) - sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k]) - sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_k, block_n]) - acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) - - for stage in range(self.num_stages - 1): - offset_k = stage * self.block_k - self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k]) - self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n]) - self.copy_async_commit_group() - - self.copy_async_wait_group(n=self.num_stages - 2) - self.sync() - - current_stage: int32 = 0 - preload_stage: int32 = self.num_stages - 1 - for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages): - # computation for current tile - a = self.load_shared(sa[current_stage]) - b = self.load_shared(sb[current_stage]) - self.dot(a, b, acc, out=acc) - - # preload the next tile of A and B into shared memory - preload_offset_k = offset_k + (self.num_stages - 1) * block_k - self.copy_async( - src=ga, - dst=sa[preload_stage], - offsets=[offset_m, preload_offset_k], - ) - self.copy_async( - src=gb, - dst=sb[preload_stage], - offsets=[preload_offset_k, offset_n], - ) - self.copy_async_commit_group() - - # update the stage - current_stage = (current_stage + 1) % self.num_stages - preload_stage = (preload_stage + 1) % self.num_stages - self.copy_async_wait_group(n=self.num_stages - 2) - self.sync() - - self.free_shared(sa) - self.free_shared(sb) - - casted_acc = self.cast(acc, dtype=float16) - gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) - self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) - - -class MatmulGroupedOrdering(tilus.Script): - def __init__(self): - super().__init__() - self.block_m = 128 - self.block_n = 128 - self.block_k = 16 - self.num_warps = 4 - self.num_stages = 4 - self.group_size_m = 8 - - def __call__( - self, - m_size: int32, - n_size: int, - k_size: int, - a_ptr: ~float16, - b_ptr: ~float16, - c_ptr: ~float16, - ): - block_m, block_n, block_k = self.block_m, self.block_n, self.block_k - - num_pid_m = self.utils.ceil_div(m_size, block_m) - num_pid_n = self.utils.ceil_div(n_size, block_n) - self.attrs.blocks = [num_pid_m * num_pid_n] - - pid = self.blockIdx.x - num_pid_in_group = self.group_size_m * num_pid_n - group_id = pid // num_pid_in_group - - first_pid_m = group_id * self.group_size_m - group_size_m = min(num_pid_m - first_pid_m, self.group_size_m) - - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - self.attrs.warps = self.num_warps - offset_m: int32 = pid_m * block_m - offset_n: int32 = pid_n * block_n - - ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) - gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size]) - sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k]) - sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_k, block_n]) - acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) - - for stage in range(self.num_stages - 1): - offset_k = stage * self.block_k - self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k]) - self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n]) - self.copy_async_commit_group() - - self.copy_async_wait_group(n=self.num_stages - 2) - self.sync() - - current_stage: int32 = 0 - preload_stage: int32 = self.num_stages - 1 - for offset_k in self.range(0, k_size, block_k, unroll=self.num_stages): - # computation for current tile - a = self.load_shared(sa[current_stage]) - b = self.load_shared(sb[current_stage]) - self.dot(a, b, acc, out=acc) - - # preload the next tile of A and B into shared memory - preload_offset_k = offset_k + (self.num_stages - 1) * block_k - self.copy_async( - src=ga, - dst=sa[preload_stage], - offsets=[offset_m, preload_offset_k], - ) - self.copy_async( - src=gb, - dst=sb[preload_stage], - offsets=[preload_offset_k, offset_n], - ) - self.copy_async_commit_group() - - # update the stage - current_stage = (current_stage + 1) % self.num_stages - preload_stage = (preload_stage + 1) % self.num_stages - self.copy_async_wait_group(n=self.num_stages - 2) - self.sync() - - self.free_shared(sa) - self.free_shared(sb) - - casted_acc = self.cast(acc, dtype=float16) - gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) - self.store_global(gc, casted_acc, offsets=[offset_m, offset_n]) - - - -def main(): - headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] - workloads = [ - [4096, 4096, 4096], - [1024, 1024, 14336], - ] - - rows = [] - for m, n, k in workloads: - matmul = MatmulGroupedOrdering() #MatmulV4() - - a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - matmul(m, n, k, a, b, c_actual) - - # check correctness - torch.testing.assert_close(c_expect, c_actual) - - # benchmark - for name, func in [ - ("torch", lambda: torch.matmul(a, b, out=c_expect)), - ("tilus", lambda: matmul(m, n, k, a, b, c_actual)), - ]: - latency = benchmark_func(func, warmup=5, repeat=20) - tflops = 2 * m * n * k / latency * 1e-9 - rows.append([m, n, k, name, latency, tflops]) - - df = pandas.DataFrame(rows, columns=headers) - print(df) - -def call_tilus(kernel_function, args, kwargs, grid, threads, params): - kernel_function(*args, **kwargs) - - -def tune_matmul(m, n, k): - a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - - size = m * n #(m, n) - args = [m, n, k, a, b, c_actual] - tune_params = dict() - tune_params["block_m"] = [32, 64, 128, 256] - tune_params["block_n"] = [32, 64, 128, 256] - tune_params["block_k"] = [32, 64, 128] - tune_params["group_size_m"] = [4, 8] - tune_params["num_stages"] = [2, 3, 4] - tune_params["num_warps"] = [4, 8] - - - restrictions = [ - # tile size budget - "block_m * block_n <= 16384", - - # aspect ratio <= 4 (no max/min allowed, so expand manually) - "block_m <= 4 * block_n", - "block_n <= 4 * block_m", - - # large K only with reasonably large M/N - "not (block_k == 128 and block_m < 64 and block_n < 64)", - - # 32x32 requires 8 warps - "not (block_m == 32 and block_n == 32 and num_warps < 8)", - ] - - - answer = [None] * 6 - answer[-1] = c_expect.cpu() - atol = 1e-2 #m * 2**(-11) - - results, env = tune_kernel("MatmulGroupedOrdering", MatmulGroupedOrdering, size, args, tune_params, grid_div_x = ["block_m", "block_n"], - answer = answer, atol=atol, restrictions=restrictions, - lang="generic_python", call_function=call_tilus, - block_size_names=["block_m", "block_n", "block_k"], strategy="simulated_annealing") - - -if __name__ == "__main__": - #m, n, k = 4096, 4096, 4096 - m, n, k = 8192, 8192, 8192 - tune_matmul(m, n, k) - - #main() \ No newline at end of file diff --git a/examples/generic_python/matmul_old/triton_matmul.py b/examples/generic_python/matmul_old/triton_matmul.py deleted file mode 100644 index 238a61c9f..000000000 --- a/examples/generic_python/matmul_old/triton_matmul.py +++ /dev/null @@ -1,383 +0,0 @@ -import torch - -import triton -import triton.language as tl - -from kernel_tuner import tune_kernel -from kernel_tuner import run_kernel - -def get_cuda_autotune_config(): - return [ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, - num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, - num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, - num_warps=2), - # Good config for fp8 inputs. - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, - num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, - num_warps=8), - triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, - num_warps=4) - ] - -''' -@triton.autotune( - configs=get_cuda_autotune_config(), - key=['M', 'N', 'K'], -) -''' -#@triton.jit -def matmul_kernel( - # Pointers to matrices - a_ptr, b_ptr, c_ptr, - # Matrix dimensions - M, N, K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` - # by to get the element one row down (A has M rows). - stride_am, stride_ak, # - stride_bk, stride_bn, # - stride_cm, stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # - GROUP_SIZE_M: tl.constexpr, num_stages, num_warps# -): - """Kernel for computing the matmul C = A x B. - A has shape (M, K), B has shape (K, N) and C has shape (M, N) - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse. - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ----------------------------------------------------------- - # Add some integer bound assumptions. - # This helps to guide integer analysis in the backend to optimize - # load/store offset address calculation - tl.assume(pid_m >= 0) - tl.assume(pid_n >= 0) - tl.assume(stride_am > 0) - tl.assume(stride_ak > 0) - tl.assume(stride_bn > 0) - tl.assume(stride_bk > 0) - tl.assume(stride_cm > 0) - tl.assume(stride_cn > 0) - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. - # We will advance this pointer as we move in the K direction - # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - # Load the next block of A and B, generate a mask by checking the K dimension. - # If it is out of bounds, set it to 0. - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) - # We accumulate along the K dimension. - accumulator = tl.dot(a, b, accumulator) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - c = accumulator.to(tl.float16) - - # ----------------------------------------------------------- - # Write back the block of the output matrix C with masks. - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) - tl.store(c_ptrs, c, mask=c_mask) - - -def run_matmul(m, n, k): - a = torch.rand(m, k, dtype=torch.float16).cuda() - b = torch.rand(k, n, dtype=torch.float16).cuda() - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - - grid = lambda META: (triton.cdiv(m, META['BLOCK_SIZE_M']) * triton.cdiv(n, META['BLOCK_SIZE_N']), ) - - matmul_kernel[grid]( - a, b, c_actual, # - m, n, k, # - a.stride(0), a.stride(1), # - b.stride(0), b.stride(1), # - c_actual.stride(0), c_actual.stride(1), - 128, 256, 64, 8 - ) - - torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2) - - -def run_matmul_kt(m, n, k): - a = torch.rand(m, k, dtype=torch.float16).cuda() - b = torch.rand(k, n, dtype=torch.float16).cuda() - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - - size = m * n - - args = [a, b, c_actual, m, n, k, a.stride(0), a.stride(1), b.stride(0), b.stride(1), - c_actual.stride(0), c_actual.stride(1)] - - params = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M":4, "num_stages":3, "num_warps":4} - - result = run_kernel("matmul_kernel", matmul_kernel, size, args, params=params, grid_div_x=["BLOCK_SIZE_N", "BLOCK_SIZE_M"], - lang="generic_python", decorator="@triton.jit", call_function=call_triton, - block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"]) - c_res = result[2] - - assert torch.allclose(c_res, c_expect.cpu(), atol=1e-2, rtol=1e-1) - - - - - - -def call_triton(kernel_function, args, kwargs, grid, threads, params): - #print("using grid: ", grid) - #print("args: ", args) - #print("kwargs: ", kwargs) - kernel_function[grid](*args, **kwargs) - - - - -def check_time(m, n, k): - a = torch.rand(m, k, dtype=torch.float16).cuda() - b = torch.rand(k, n, dtype=torch.float16).cuda() - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - - size = m * n - args = [a, b, c_actual, m, n, k, a.stride(0), a.stride(1), b.stride(0), b.stride(1), - c_actual.stride(0), c_actual.stride(1)] - tune_params = dict() - tune_params["BLOCK_SIZE_M"] = [64, 128] - tune_params["BLOCK_SIZE_N"] = [64, 128] - tune_params["BLOCK_SIZE_K"] = [32, 64] - tune_params["GROUP_SIZE_M"] = [4, 8] - tune_params["num_stages"] = [3] - tune_params["num_warps"] = [4] - - restrictions = [ - # tile size budget - "BLOCK_SIZE_M * BLOCK_SIZE_N <= 16384", - - # aspect ratio <= 4 (no max/min allowed, so expand manually) - "BLOCK_SIZE_M <= 4 * BLOCK_SIZE_N", - "BLOCK_SIZE_N <= 4 * BLOCK_SIZE_M", - - # large K only with reasonably large M/N - "not (BLOCK_SIZE_K == 128 and BLOCK_SIZE_M < 64 and BLOCK_SIZE_N < 64)", - - # 32x32 requires 8 warps - "not (BLOCK_SIZE_M == 32 and BLOCK_SIZE_N == 32 and num_warps < 8)", - ] - - grid_div = ["BLOCK_SIZE_N", "BLOCK_SIZE_M"] - - answer = [None] * 12 - answer[2] = c_expect.cpu() - atol = 1e-2 #m * 2**(-11) - - - - results_ours, _ = tune_kernel("matmul_kernel", matmul_kernel, size, args, tune_params, grid_div_x = grid_div, - restrictions=restrictions, iterations=100, answer=answer, atol=atol, - lang="generic_python", decorator="@triton.jit", call_function=call_triton, - block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"]) - - - results_prev, _ = tune_kernel("matmul_kernel", matmul_kernel, size, args, tune_params, grid_div_x = grid_div, - restrictions=restrictions, iterations=100, answer=answer, atol=atol, - lang="triton", block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"]) - - - - import time - num_repeats = 100 - times_direct = [] - for config in results_prev: - - bs_m = config["BLOCK_SIZE_M"] - bs_n = config["BLOCK_SIZE_N"] - bs_k = config["BLOCK_SIZE_K"] - gs_m = config["GROUP_SIZE_M"] - num_stages = config["num_stages"] - num_warps = config["num_warps"] - - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - - grid = (triton.cdiv(m, bs_m) * triton.cdiv(n, bs_n, ), ) - jit_function = triton.jit(matmul_kernel) - - - jit_function[grid]( - a, b, c_actual, - m, n, k, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c_actual.stride(0), c_actual.stride(1), - bs_m, bs_n, bs_k, gs_m, num_stages, num_warps - ) - - - torch.allclose(c_expect.cpu(), c_actual.cpu(), atol=1e-2) - - - - for i in range(num_repeats): - times = [] - - #c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - - torch.cuda.synchronize() - start = time.time() - jit_function[grid]( - a, b, c_actual, - m, n, k, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c_actual.stride(0), c_actual.stride(1), - bs_m, bs_n, bs_k, gs_m, num_stages, num_warps - ) - - torch.cuda.synchronize() - times.append(time.time() - start) - - avg_time_ms = round((1000 * sum(times) / len(times)), 3) - times_direct.append(avg_time_ms) - print(f"BLOCK_SIZE_M={bs_m}, BLOCK_SIZE_N={bs_n}, BLOCK_SIZE_K={bs_k}, GROUP_SIZE_M={gs_m}, num_stages={num_stages}, num_warps={num_warps}, time={avg_time_ms}ms") - - - import matplotlib.pyplot as plt - - # Extract times - times_prev = [cfg['time'] for cfg in results_prev] - times_ours = [cfg['time'] for cfg in results_ours] - - # x-axis labels - configs = [f"config{i}" for i in range(len(times_prev))] - x = range(len(configs)) - - plt.figure(figsize=(10,6)) - plt.plot(configs, times_prev, marker='o', label='Triton tuned') - plt.plot(configs, times_ours, marker='s', label='Generic tuned') - plt.plot(configs, times_direct, marker='x', label='Direct') - plt.ylabel('Time (ms)') - plt.xlabel('Configuration') - plt.title('Kernel execution time per configuration') - plt.xticks(rotation=45) - plt.legend() - plt.grid(True) - plt.tight_layout() - plt.savefig("ouptut.png") - - - - - - -def tune_matmul(m, n, k): - a = torch.rand(m, k, dtype=torch.float16).cuda() - b = torch.rand(k, n, dtype=torch.float16).cuda() - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - - size = m * n - args = [a, b, c_actual, m, n, k, a.stride(0), a.stride(1), b.stride(0), b.stride(1), - c_actual.stride(0), c_actual.stride(1)] - tune_params = dict() - tune_params["BLOCK_SIZE_M"] = [64, 128, 256] - tune_params["BLOCK_SIZE_N"] = [64, 128, 256] - tune_params["BLOCK_SIZE_K"] = [32, 64, 128] - tune_params["GROUP_SIZE_M"] = [4, 8] - tune_params["num_stages"] = [2, 3, 4] - tune_params["num_warps"] = [4, 8] - - restrictions = [ - # tile size budget - "BLOCK_SIZE_M * BLOCK_SIZE_N <= 16384", - - # aspect ratio <= 4 (no max/min allowed, so expand manually) - "BLOCK_SIZE_M <= 4 * BLOCK_SIZE_N", - "BLOCK_SIZE_N <= 4 * BLOCK_SIZE_M", - - # large K only with reasonably large M/N - "not (BLOCK_SIZE_K == 128 and BLOCK_SIZE_M < 64 and BLOCK_SIZE_N < 64)", - - # 32x32 requires 8 warps - "not (BLOCK_SIZE_M == 32 and BLOCK_SIZE_N == 32 and num_warps < 8)", - ] - - grid_div = ["BLOCK_SIZE_N", "BLOCK_SIZE_M"] - - answer = [None] * 12 - answer[2] = c_expect.cpu() - - results, env = tune_kernel("matmul_kernel", matmul_kernel, size, args, tune_params, grid_div_x = grid_div, - answer = answer, atol=4.0, restrictions=restrictions, - lang="generic_python", decorator="@triton.jit", call_function=call_triton, - block_size_names=["BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"], strategy="simulated_annealing") - - - - - -if __name__ == "__main__": - m, n, k = 8192, 8192, 8192 - #m, n, k = 4096, 4096, 4096 - #tune_matmul(m, n, k) - #check_time(m, n, k) - #run_matmul_kt(m, n, k) - - - - - - diff --git a/examples/generic_python/normalization/tilelang_norm.py b/examples/generic_python/normalization/tilelang_norm.py deleted file mode 100644 index 7c26d8f5f..000000000 --- a/examples/generic_python/normalization/tilelang_norm.py +++ /dev/null @@ -1,78 +0,0 @@ -# taken from https://github.com/tile-ai/tilelang/blob/main/examples/norm/rms_norm.py -# TODO not comparable -import torch -import tilelang -import tilelang.language as T - - -def rms_norm_splitk(M, N, blk_m, blk_k): - dtype = T.float - - @T.prim_func - def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): - with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx: - A_shared = T.alloc_shared((blk_m, blk_k), dtype) - A_local = T.alloc_fragment((blk_m, blk_k), dtype) - A_powsum = T.alloc_fragment((blk_m,), dtype) - - num_k_step = T.ceildiv(N, blk_k) - T.clear(A_local) - for k in range(num_k_step): - T.copy(A[bx * blk_m, k * blk_k], A_shared) - for i, j in T.Parallel(blk_m, blk_k): - A_local[i, j] += A_shared[i, j] * A_shared[i, j] - T.reduce_sum(A_local, A_powsum, dim=1) - for i in T.Parallel(blk_m): - A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) - - for k in range(num_k_step): - # reverse, better cache hit rate - T.copy(A[bx * blk_m, (num_k_step - 1 - k) * blk_k], A_shared) - for i, j in T.Parallel(blk_m, blk_k): - A_shared[i, j] *= A_powsum[i] - T.copy(A_shared, B[bx * blk_m, (num_k_step - 1 - k) * blk_k]) - - return main - - -@tilelang.jit(out_idx=[-1], pass_configs={"tl.disable_tma_lower": True}) -def rms_norm(M, N, blk_m): - dtype = T.float - - @T.prim_func - def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)): - with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx: - A_shared = T.alloc_shared((blk_m, N), dtype) - A_pow_local = T.alloc_fragment((blk_m, N), dtype) - A_local = T.alloc_fragment((blk_m, N), dtype) - A_powsum = T.alloc_fragment((blk_m,), dtype) - - T.copy(A[bx * blk_m : (bx + 1) * blk_m, :], A_shared) - T.copy(A_shared, A_local) - for i, j in T.Parallel(blk_m, N): - A_pow_local[i, j] = A_local[i, j] * A_local[i, j] - T.reduce_sum(A_pow_local, A_powsum, dim=1) - for i in T.Parallel(blk_m): - A_powsum[i] = T.rsqrt(A_powsum[i] / N + 1e-12) - for i, j in T.Parallel(blk_m, N): - A_local[i, j] *= A_powsum[i] - T.copy(A_local, B[bx * blk_m : (bx + 1) * blk_m, :]) - - return main - - -def ref_program(x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-12) - - -if __name__ == "__main__": - M, N, blk_m, blk_k = 8192, 8192, 1, 512 - kernel = rms_norm(M, N, blk_m) - profiler = kernel.get_profiler() - profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01) - print("All checks pass.") - - latency = profiler.do_bench(ref_program, warmup=500) - print("Ref: {:.2f} ms".format(latency)) - latency = profiler.do_bench(warmup=500) - print("Tile-lang: {:.2f} ms".format(latency)) \ No newline at end of file diff --git a/examples/generic_python/normalization/tilus_norm.py b/examples/generic_python/normalization/tilus_norm.py deleted file mode 100644 index cfe569e80..000000000 --- a/examples/generic_python/normalization/tilus_norm.py +++ /dev/null @@ -1,132 +0,0 @@ -# taken from https://github.com/NVIDIA/tilus/blob/main/examples/norm/layer_norm.py -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -import pandas -import tilus -import torch -from tilus import float16, float32, int32 -from tilus.utils import benchmark_func, cdiv - - -@tilus.autotune("block_m", [1, 8]) -@tilus.autotune("block_n", [128, 256, 512, 1024]) -@tilus.autotune("warps", [2, 4, 8]) -class LayerNorm(tilus.Script): - """Forward-only layer normalization tilus kernel. - - This implements the per-row LayerNorm used in many transformer blocks. It - computes: y = (x - mean) / sqrt(var + eps) * gamma + beta - - Only the forward is provided. - """ - - def __init__(self, block_m: int, block_n: int, warps: int): - super().__init__() - self.block_m: int = block_m - self.block_n: int = block_n - self.warps: int = warps - - def __call__( - self, - m_size: int, - n_size: int32, - x_ptr: ~float16, - gamma_ptr: ~float16, - beta_ptr: ~float16, - y_ptr: ~float16, - eps: float, - ): - self.attrs.blocks = (cdiv(m_size, self.block_m),) - self.attrs.warps = self.warps - - offset_m = self.blockIdx.x * self.block_m - - g_x = self.global_view(x_ptr, dtype=float16, shape=[m_size, n_size]) - g_y = self.global_view(y_ptr, dtype=float16, shape=[m_size, n_size]) - g_gamma = self.global_view(gamma_ptr, dtype=float16, shape=[n_size]) - g_beta = self.global_view(beta_ptr, dtype=float16, shape=[n_size]) - - # Register accumulators for mean and variance (computed in float32) - r_sum = self.register_tensor( - dtype=float32, shape=[self.block_m, self.block_n], init=0.0 - ) - r_square = self.register_tensor( - dtype=float32, shape=[self.block_m, self.block_n], init=0.0 - ) - - # first pass: compute mean and variance - for offset_n in range(0, n_size, self.block_n): - r_x = self.load_global( - g_x, offsets=[offset_m, offset_n], shape=[self.block_m, self.block_n] - ).to(float32) # [block_m, block_n] - r_sum = r_sum + r_x - r_square = r_square + self.square(r_x) - - # finalize mean and variance - r_mean = self.sum(r_sum, dim=1, keepdim=True) / n_size # [block_m, 1] - r_var = ( - self.sum(r_square, dim=1, keepdim=True) / n_size - r_mean * r_mean - ) # [block_m, 1], var = E[x^2] - (E[x])^2 - r_rstd = self.rsqrt(r_var + eps) - - # second pass: y = (x - mean) * rstd * gamma + beta - for offset_n in range(0, n_size, self.block_n): - r_x = self.load_global( - g_x, offsets=[offset_m, offset_n], shape=[self.block_m, self.block_n] - ).to(float32) # [block_m, block_n] - r_gamma = self.load_global( - g_gamma, offsets=[offset_n], shape=[self.block_n] - ).to(float32) # [block_n] - r_beta = self.load_global( - g_beta, offsets=[offset_n], shape=[self.block_n] - ).to(float32) # [block_n] - r_x_hat = (r_x - r_mean) * r_rstd - r_y = r_x_hat * r_gamma + r_beta - self.store_global(g_y, r_y.to(float16), offsets=[offset_m, offset_n]) - - -def main(): - headers = ["m_size", "n_size", "dtype", "torch (ms)", "tilus (ms)"] - rows = [] - for i in [1, 2, 4, 8]: - m_size = n_size = 1024 * i - - tilus_layer_norm = LayerNorm() - - x = (torch.rand(m_size, n_size, dtype=torch.float16).cuda() - 0.5) * 2.0 - gamma = torch.rand(n_size, dtype=torch.float16).cuda() - beta = torch.rand(n_size, dtype=torch.float16).cuda() - y_actual = torch.empty_like(x) - - tilus_layer_norm(m_size, n_size, x, gamma, beta, y_actual, 1e-5) - y_expected = torch.nn.functional.layer_norm( - x, normalized_shape=[n_size], weight=gamma, bias=beta, eps=1e-5 - ) - - torch.testing.assert_close(y_actual, y_expected, atol=1e-2, rtol=1e-2) - - rows.append( - [ - m_size, - n_size, - "float16", - benchmark_func( - lambda: torch.nn.functional.layer_norm( - x, normalized_shape=[n_size], weight=gamma, bias=beta, eps=1e-5 - ) - ), - benchmark_func( - lambda: tilus_layer_norm( - m_size, n_size, x, gamma, beta, y_actual, 1e-5 - ) - ), - ] - ) - print(f"LayerNorm forward matches reference for size ({m_size}, {n_size})") - - df = pandas.DataFrame(rows, columns=headers) - print(df) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/generic_python/normalization/triton_norm.py b/examples/generic_python/normalization/triton_norm.py deleted file mode 100644 index 365f541c4..000000000 --- a/examples/generic_python/normalization/triton_norm.py +++ /dev/null @@ -1,66 +0,0 @@ - -#Example taken from https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html -import torch - -import triton -import triton.language as tl - -try: - # This is https://github.com/NVIDIA/apex, NOT the apex on PyPi, so it - # should not be added to extras_require in setup.py. - import apex - HAS_APEX = True -except ModuleNotFoundError: - HAS_APEX = False - -DEVICE = triton.runtime.driver.active.get_active_torch_device() - - -@triton.jit -def _layer_norm_fwd_fused( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK_SIZE: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - Y += row * stride - X += row * stride - # Compute mean - mean = 0 - _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) - _mean += a - mean = tl.sum(_mean, axis=0) / N - # Compute variance - _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32) - x = tl.where(cols < N, x - mean, 0.) - _var += x * x - var = tl.sum(_var, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - # Write mean / rstd - tl.store(Mean + row, mean) - tl.store(Rstd + row, rstd) - # Normalize and apply linear transformation - for off in range(0, N, BLOCK_SIZE): - cols = off + tl.arange(0, BLOCK_SIZE) - mask = cols < N - w = tl.load(W + cols, mask=mask) - b = tl.load(B + cols, mask=mask) - x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32) - x_hat = (x - mean) * rstd - y = x_hat * w + b - # Write output - tl.store(Y + cols, y, mask=mask) \ No newline at end of file diff --git a/examples/generic_python/numba_vec_add.py b/examples/generic_python/numba_vec_add.py index b5a2c56f3..3bc90de86 100644 --- a/examples/generic_python/numba_vec_add.py +++ b/examples/generic_python/numba_vec_add.py @@ -1,10 +1,8 @@ from numba import cuda -import torch -from kernel_tuner import tune_kernel, run_kernel import numpy as np -from pathlib import Path -FULL_PATH = Path(__file__).resolve() +from kernel_tuner import tune_kernel +from call_functions import call_numba @cuda.jit @@ -16,16 +14,6 @@ def f(a, b, c): c[tid] = a[tid] + b[tid] -def call_numba(kernel_function, args, kwargs, grid, threads): - numba_args = [] - for arg in args: - if isinstance(arg, torch.Tensor): - numba_args.append(cuda.as_cuda_array(arg)) - else: - numba_args.append(arg) - kernel_function[grid, threads](*args, **kwargs) - - def tune(): N = 100000 @@ -41,7 +29,7 @@ def tune(): results, env = tune_kernel( kernel_name="f", - kernel_source=FULL_PATH, + kernel_source=__file__, problem_size=N, arguments=args, tune_params=tune_params, diff --git a/examples/generic_python/pallas_vec_add.py b/examples/generic_python/pallas_vec_add.py deleted file mode 100644 index 874392dd9..000000000 --- a/examples/generic_python/pallas_vec_add.py +++ /dev/null @@ -1,23 +0,0 @@ -# Could not get working - -from functools import partial - -import jax -from jax.experimental import pallas as pl -import jax.numpy as jnp -import numpy as np - -def add_vectors_kernel(x_ref, y_ref, o_ref): - x, y = x_ref[...], y_ref[...] - o_ref[...] = x + y - - -@jax.jit -def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: - return pl.pallas_call( - add_vectors_kernel, - out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) - )(x, y) - - -add_vectors(jnp.arange(8), jnp.arange(8)) \ No newline at end of file diff --git a/examples/generic_python/tilelang_vec_add.py b/examples/generic_python/tilelang_vec_add.py index 99463d112..38b4ab309 100644 --- a/examples/generic_python/tilelang_vec_add.py +++ b/examples/generic_python/tilelang_vec_add.py @@ -1,10 +1,10 @@ import tilelang import tilelang.language as T import torch + from kernel_tuner import tune_kernel -from pathlib import Path +from call_functions import call_tilelang -FULL_PATH = Path(__file__).resolve() @tilelang.jit # infers target from tensors at first call def add(N: int, dtype: str = 'float32', block: int = 256,): @@ -24,23 +24,6 @@ def add_kernel( return add_kernel -def run_normal(): - # Host side (PyTorch shown; NumPy/DLPack also supported) - N = 1 << 20 - A = torch.randn(N, device='cuda', dtype=torch.float32) - B = torch.randn(N, device='cuda', dtype=torch.float32) - C = torch.empty(N, device='cuda', dtype=torch.float32) - - kernel = add(N) - kernel(A, B, C) # runs on GPU - torch.testing.assert_close(C, A + B) - print("done") - - -def call_tilelang(kernel_function, args, kwargs): - compiled_kernel = kernel_function(**kwargs) # cached, so second time only cache lookup is performed - compiled_kernel(*args) - def tune(): N = 1 << 20 A = torch.randn(N, device='cuda', dtype=torch.float32) @@ -54,7 +37,7 @@ def tune(): answer = [None, None, (A + B).cpu()] - res, env = tune_kernel("add", FULL_PATH, N, args, tune_params, lang="generic_python", + res, env = tune_kernel("add", __file__, N, args, tune_params, lang="generic_python", call_function=call_tilelang, answer=answer) if __name__ == "__main__": diff --git a/examples/generic_python/tilus_naive_matmul.py b/examples/generic_python/tilus_naive_matmul.py deleted file mode 100644 index c53c22f73..000000000 --- a/examples/generic_python/tilus_naive_matmul.py +++ /dev/null @@ -1,127 +0,0 @@ -from tilus import float16, float32, int32 -from tilus.utils import cdiv -import tilus -from kernel_tuner import tune_kernel, run_kernel -import math -import torch -from pathlib import Path - -FULL_PATH = Path(__file__).resolve() - - -class MatmulV0(tilus.Script): - def __init__(self): - super().__init__() - # we define three hyperparameters: ``block_m``, ``block_n``, and ``block_k`` to determine the tile size on - # m, n, and k dimensions for each `thread block` of the kernel. - self.block_m = 64 - self.block_n = 64 - self.block_k = 16 - - def __call__( - self, - m_size: int32, # the size of the m dimension of the input matrix A and output matrix C - n_size: int, # the size of the n dimension of the input matrix B and output matrix C - k_size: int, # the size of the k dimension of the input matrix A and B - a_ptr: ~float16, # the pointer to the input matrix A, which is a 2D tensor of shape [m_size, k_size] - b_ptr: ~float16, # the pointer to the input matrix B, which is a 2D tensor of shape [k_size, n_size] - c_ptr: ~float16, # the pointer to the output matrix C, which is a 2D tensor of shape [m_size, n_size] - ): - self.attrs.blocks = [ - cdiv(m_size, self.block_m), # the x dimension size of the grid - cdiv(n_size, self.block_n), # the y dimension size of the grid - ] - self.attrs.warps = 1 # the number of warps per thread block, must be a compile-time known integer - - # define two int32 variables to store the offsets of the m and n dimensions for the current thread block. - offset_m: int32 = self.block_m * self.blockIdx.x - offset_n: int32 = self.block_n * self.blockIdx.y - - # create two global tensors `ga` and `gb` to represent the input matrices A and B, respectively. - ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) - gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size]) - - # create a register tensor `acc` to accumulate the results of the matrix multiplication. - acc = self.register_tensor( - dtype=float32, shape=[self.block_m, self.block_n], init=0.0 - ) - - # iterate over the k dimension in blocks of size `block_k`. - for k in range(cdiv(k_size, self.block_k)): - # calculate the offset for the current block in the k dimension - offset_k = k * self.block_k - - # load a block of matrix A and B into register tensors `a` and `b`. - a = self.load_global( - ga, offsets=[offset_m, offset_k], shape=[self.block_m, self.block_k] - ) - b = self.load_global( - gb, offsets=[offset_k, offset_n], shape=[self.block_k, self.block_n] - ) - - # perform the dot product: acc = a @ b + acc - self.dot(a, b, acc, out=acc) - - # after the loop, we cast the accumulated result `acc` to float16 type and store it back to the output matrix C. - acc_f16 = self.cast(acc, dtype=float16) - gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) - self.store_global(gc, acc_f16, offsets=[offset_m, offset_n]) - - -def call_tilus(kernel_function, args, kwargs): - kernel_function(*args, **kwargs) - -def main(): - m, n, k = 4096, 4096, 4096 - - # create an instance of the kernel we have just defined - matmul = MatmulV0() - - a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - - ''' - torch.cuda.synchronize() - # launch the kernel by passing required arguments - matmul(m, n, k, a, b, c_actual) - torch.cuda.synchronize() - - # check correctness - torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2) - ''' - - args = [m, n, k, a, b, c_actual] - tune_params = dict() - tune_params["block_m"] = [16, 32, 64, 128, 256] - tune_params["block_n"] = [16, 32, 64, 128, 256] - tune_params["block_k"] = [16, 32, 64, 128, 256] - - - restrictions = [ - "block_m * block_n <= 4096", - "block_m * block_k <= 2048", - "block_k * block_n <= 4096", - "block_k >= 16", - ] - - results, env = tune_kernel( - kernel_name="MatmulV0", - kernel_source=FULL_PATH, - problem_size=[m, n], - arguments=args, - tune_params=tune_params, - lang="generic_python", - answer=[None, None, None, None, None, c_expect.cpu()], - restrictions=restrictions, - block_size_names=["block_m", "block_n", "block_k"], - call_function=call_tilus, - strategy="simulated_annealing" - ) - - - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/examples/generic_python/tilus_splitk_matmul.py b/examples/generic_python/tilus_splitk_matmul.py deleted file mode 100644 index 0695dcceb..000000000 --- a/examples/generic_python/tilus_splitk_matmul.py +++ /dev/null @@ -1,255 +0,0 @@ -import tilus -from tilus import float16, float32, int32 -from tilus.utils import cdiv, benchmark_func -import torch -import math -from kernel_tuner import tune_kernel, run_kernel -from pathlib import Path - -FULL_PATH = Path(__file__).resolve() - -class MatmulV5(tilus.Script): - def __init__(self, block_m=None, block_n=None, block_k=None, - num_warps=None, num_stages=None, split_k_factor=None): - super().__init__() - self.block_m = block_m - self.block_n = block_n - self.block_k = block_k - self.num_warps = num_warps - self.num_stages = num_stages - self.split_k_factor = split_k_factor - - - def __call__( - self, - m_size: int32, - n_size: int, - k_size: int, - a_ptr: ~float16, - b_ptr: ~float16, - c_ptr: ~float16, - ): - self.attrs.blocks = [ - cdiv(m_size, self.block_m), - cdiv(n_size, self.block_n), - self.split_k_factor, - ] - self.attrs.warps = self.num_warps - - # the k_size for each thread block - block_k_size = ( - cdiv(cdiv(k_size, self.split_k_factor), self.block_k) * self.block_k - ) - start_offset_k = self.blockIdx.z * block_k_size - end_offset_k = min(start_offset_k + block_k_size, k_size) - - block_m, block_n, block_k = self.block_m, self.block_n, self.block_k - offset_m: int32 = block_m * self.blockIdx.x - offset_n: int32 = block_n * self.blockIdx.y - - ga = self.global_view(a_ptr, dtype=float16, shape=[m_size, k_size]) - gb = self.global_view(b_ptr, dtype=float16, shape=[k_size, n_size]) - sa = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_m, block_k]) - sb = self.shared_tensor(dtype=float16, shape=[self.num_stages, block_k, block_n]) - acc = self.register_tensor(dtype=float32, shape=[block_m, block_n], init=0.0) - - for stage in range(self.num_stages - 1): - offset_k = start_offset_k + stage * self.block_k - self.copy_async(src=ga, dst=sa[stage], offsets=[offset_m, offset_k]) - self.copy_async(src=gb, dst=sb[stage], offsets=[offset_k, offset_n]) - self.copy_async_commit_group() - - self.copy_async_wait_group(n=self.num_stages - 2) - self.sync() - - current_stage: int32 = 0 - preload_stage: int32 = self.num_stages - 1 - for offset_k in self.range( - start_offset_k, end_offset_k, block_k, unroll=self.num_stages - ): - # computation for current tile - a = self.load_shared(sa[current_stage]) - b = self.load_shared(sb[current_stage]) - self.dot(a, b, acc, out=acc) - - # preload the next tile of A and B into shared memory - preload_offset_k = offset_k + (self.num_stages - 1) * block_k - if preload_offset_k < end_offset_k: - self.copy_async( - src=ga, - dst=sa[preload_stage], - offsets=[offset_m, preload_offset_k], - ) - self.copy_async( - src=gb, - dst=sb[preload_stage], - offsets=[preload_offset_k, offset_n], - ) - self.copy_async_commit_group() - - # update the stage - current_stage = (current_stage + 1) % self.num_stages - preload_stage = (preload_stage + 1) % self.num_stages - self.copy_async_wait_group(n=self.num_stages - 2) - self.sync() - - # free the shared memory tensors for A and B - self.free_shared(sa) - self.free_shared(sb) - - # cast the accumulator to float16 and change the register tensor's layout - sc = self.shared_tensor(dtype=float16, shape=[block_m, block_n]) - casted_acc = self.cast(acc, dtype=float16) - self.store_shared(sc, casted_acc) - self.sync() - rc = self.load_shared(sc) - self.free_shared(sc) - - m_blocks, n_blocks = cdiv(m_size, block_m), cdiv(n_size, block_n) - gc = self.global_view(c_ptr, dtype=float16, shape=[m_size, n_size]) - if self.split_k_factor == 0: - self.store_global(gc, rc, offsets=[offset_m, offset_n]) - else: - semaphores = self.global_tensor( - dtype=int32, shape=[m_blocks, n_blocks], requires_clean=True - ) - semaphore: ~int32 = ~semaphores[self.blockIdx.x, self.blockIdx.y] - - # load and accumulate the partial result in global memory - if self.blockIdx.z > 0: - self.lock_semaphore(semaphore, value=self.blockIdx.z) - partial_rc = self.load_global( - gc, offsets=[offset_m, offset_n], shape=[block_m, block_n] - ) - self.add(rc, partial_rc, out=rc) - - # store the result to global memory and release the semaphore - self.store_global(gc, rc, offsets=[offset_m, offset_n]) - - # release the semaphore - self.sync() # we need to make sure the previous store_global is finished - self.release_semaphore( - semaphore, value=(self.blockIdx.z + 1) % self.split_k_factor - ) - - -def call_tilus(kernel_function, args, kwargs): - kernel_function(*args, **kwargs) - - -def without_kernel_tuner(): - tilus.option.clear_cache = True - - tilus.option.verbose_autotune = True - m, n, k = 4096, 4096, 4096 - - # create an instance of the kernel we have just defined - matmul = MatmulV5() - - a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - - - torch.cuda.synchronize() - # launch the kernel by passing required arguments - matmul(m, n, k, a, b, c_actual) - torch.cuda.synchronize() - - # check correctness - torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2) - - - import pandas - rows = [] - headers = ["m", "n", "k", "name", "latency (ms)", "tflops"] - # benchmark - for name, func in [ - ("torch", lambda: torch.matmul(a, b, out=c_expect)), - ("tilus", lambda: matmul(m, n, k, a, b, c_actual)), - ]: - latency = benchmark_func(func, warmup=5, repeat=20) - tflops = 2 * m * n * k / latency * 1e-9 - rows.append([m, n, k, name, latency, tflops]) - - df = pandas.DataFrame(rows, columns=headers) - print(df) - - - -#best performing configuration: -#block_m=128, block_n=64, block_k=16, num_warps=4, num_stages=4, split_k_factor=1, time=2.027ms - -def main(): - m, n, k = 4096, 4096, 4096 - - # create an instance of the kernel we have just defined - #matmul = MatmulV5() - - a = (torch.rand(m, k, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - b = (torch.rand(k, n, dtype=torch.float16).cuda() - 0.5) / math.sqrt(k) - c_actual = torch.empty(m, n, dtype=torch.float16).cuda() - c_expect = a @ b - - ''' - torch.cuda.synchronize() - # launch the kernel by passing required arguments - matmul(m, n, k, a, b, c_actual) - torch.cuda.synchronize() - - # check correctness - torch.testing.assert_close(c_expect, c_actual, atol=1e-2, rtol=1e-2) - ''' - - args = [m, n, k, a, b, c_actual] - tune_params = dict() - tune_params["block_m"] = [32, 64, 128] #[16, 32, 64, 128, 256] - tune_params["block_n"] = [32, 64, 128] #[16, 32, 64, 128, 256] - tune_params["block_k"] = [16, 32] #[16, 32, 64, 128, 256] - tune_params["num_warps"] = [4, 8] #[2, 4, 8, 16] - tune_params["num_stages"] = [3, 4, 5] #[2, 3, 4, 5, 6] - tune_params["split_k_factor"] = [1, 4, 12, 16] #[1, 4, 12, 16, 20] - -#@tilus.autotune("num_warps", [4, 8]) -#@tilus.autotune("block_m, block_n", [(128, 128), (128, 64), (64, 128), (32, 256)]) -#@tilus.autotune("block_k", [16, 32]) -#@tilus.autotune("num_stages", [3, 4, 5]) -#@tilus.autotune("split_k_factor", [1, 4, 12, 16]) - - ''' - restrictions = [ - "block_m * block_n <= 4096", - "block_m * block_k <= 2048", - "block_k * block_n <= 4096", - "block_k >= 16", - "2 * (num_stages * block_k * (block_m + block_n) + block_m * block_n) <= 65536", # shared mem - "num_warps * 32 <= 1024", - "block_k * split_k_factor <= 4096", - ] - ''' - restrictions = ["block_m * block_n >= 8192", "block_m * block_n <= 16384"] - - - - results, env = tune_kernel( - kernel_name="MatmulV5", # This has to be a string of the actual name. TODO is this always the case? - kernel_source=FULL_PATH, - problem_size=[m, n], - arguments=args, - tune_params=tune_params, - lang="generic_python", - answer=[None, None, None, None, None, c_expect.cpu()], - atol=1e-2, - restrictions=restrictions, - block_size_names=["block_m", "block_n", "block_k"], - call_function=call_tilus, - #strategy="random_sample", - ) - - - - -if __name__ == "__main__": - main() - #without_kernel_tuner() \ No newline at end of file diff --git a/examples/generic_python/tilus_vec_add.py b/examples/generic_python/tilus_vec_add.py index d00c77972..8ee96aa07 100644 --- a/examples/generic_python/tilus_vec_add.py +++ b/examples/generic_python/tilus_vec_add.py @@ -1,11 +1,13 @@ import tilus from tilus import float32, int32 -from tilus.utils import cdiv, benchmark_func +from tilus.utils import cdiv import torch + from kernel_tuner import tune_kernel, run_kernel -from pathlib import Path +from call_functions import call_tilus + + -FULL_PATH = Path(__file__).resolve() class VecAddV(tilus.Script): def __init__(self, block_size_x=None, num_warps=None): @@ -38,8 +40,7 @@ def __call__( self.store_global(gc, c, offsets=[offset]) -def call_tilus(kernel_function, args, kwargs): - kernel_function(*args, **kwargs) + def tune(size): @@ -57,7 +58,7 @@ def tune(size): results, env = tune_kernel( kernel_name="VecAddV", - kernel_source=FULL_PATH, + kernel_source=__file__, problem_size=size, arguments=args, tune_params=tune_params, @@ -79,7 +80,7 @@ def run(size): results = run_kernel( kernel_name="VecAddV", - kernel_source=FULL_PATH, + kernel_source=__file__, problem_size=size, arguments=args, params={"block_size_x": 32}, @@ -92,23 +93,9 @@ def run(size): assert torch.allclose(results[-1], c_expect) -def tune_with_builtin(size): - TunedVecAdd = tilus.autotune("block_size_x", [32, 64, 128, 256, 512, 1024])(VecAddV) - vecadd = TunedVecAdd() - - a = torch.randn(size, dtype=torch.float32).cuda() - b = torch.randn(size, dtype=torch.float32).cuda() - c = torch.empty(size, dtype=torch.float32).cuda() - c_expect = a + b - - vecadd(size, a, b, c) # This is where the actual tuning takes place - torch.cuda.synchronize() - - torch.testing.assert_close(c_expect, c) - if __name__ == "__main__": size = 10000000 - #tune(size) - #run(size) - tune_with_builtin(size) + tune(size) + run(size) + diff --git a/examples/generic_python/triton_vec_add.py b/examples/generic_python/triton_vec_add.py index 00bd9dfbc..cafba79a9 100644 --- a/examples/generic_python/triton_vec_add.py +++ b/examples/generic_python/triton_vec_add.py @@ -2,11 +2,9 @@ import torch import triton import triton.language as tl -from pathlib import Path from kernel_tuner import tune_kernel, run_kernel - -FULL_PATH = Path(__file__).resolve() +from call_functions import call_triton @triton.jit def add_op(x, y): @@ -30,9 +28,6 @@ def add_kernel(x_ptr, # *Pointer* to first input vector. tl.store(output_ptr + offsets, output, mask=mask) -def call_triton(kernel_function, args, kwargs, grid): - kernel_function[grid](*args, **kwargs) - def tune(): size = 10000000 @@ -48,7 +43,7 @@ def tune(): tune_params["block_size_x"] = [2**i for i in range(11)] - result = run_kernel("add_kernel", FULL_PATH, size, args, {"block_size_x": 256}, + result = run_kernel("add_kernel", __file__, size, args, {"block_size_x": 256}, lang="generic_python", call_function=call_triton) assert np.allclose(c_expect.cpu(), result[2]) @@ -56,7 +51,7 @@ def tune(): results, env = tune_kernel( kernel_name="add_kernel", - kernel_source=FULL_PATH, + kernel_source=__file__, problem_size=size, arguments=args, tune_params=tune_params, @@ -65,8 +60,6 @@ def tune(): call_function=call_triton, ) - print(results) - if __name__ == "__main__": diff --git a/examples/generic_python/warp_vec_add.py b/examples/generic_python/warp_vec_add.py index 4f94c91d2..81d7fb19e 100644 --- a/examples/generic_python/warp_vec_add.py +++ b/examples/generic_python/warp_vec_add.py @@ -1,11 +1,13 @@ import warp as wp import numpy as np -from kernel_tuner import tune_kernel, run_kernel import torch -from pathlib import Path + +from kernel_tuner import tune_kernel +from call_functions import call_warp + wp.init() -FULL_PATH = Path(__file__).resolve() + @wp.func def add_op(x: float, y: float): @@ -55,7 +57,7 @@ def tune(): results, env = tune_kernel( kernel_name="vec_add", - kernel_source=FULL_PATH, + kernel_source=__file__, problem_size=n, arguments=args, tune_params=tune_params,