def _forward_from_prebuild_lib(self, *args, stream: Optional[int] = None):
    """Low-level function to call the compiled CUDA kernel.

    Converts PyTorch tensor pointers to C void pointers for the ctypes
    interface and forwards the resulting argument list to ``self.lib.call``.

    Args:
        *args: Kernel arguments in call order. Plain ``int`` values are
            forwarded unchanged; every other argument must provide a
            ``data_ptr()`` method and is passed as a ``ctypes.c_void_p``.
        stream: Raw stream handle. It is wrapped in ``ctypes.c_void_p`` and
            appended as the last argument; ``None`` yields a NULL pointer.
    """
    ctypes_args = [
        arg if isinstance(arg, int) else ctypes.c_void_p(arg.data_ptr())
        for arg in args
    ]
    # The stream handle always travels as the trailing argument of the call.
    ctypes_args.append(ctypes.c_void_p(stream))
    self.lib.call(*ctypes_args)
def _warp_forward_from_prebuild_lib(self, *ins: List[torch.Tensor], stream: Optional[int] = None):
    """High-level wrapper for kernel execution.

    Handles:
        1. Input validation
        2. Output tensor allocation
        3. Dynamic shape resolution
        4. CUDA stream management

    Args:
        ins: Input PyTorch tensors
        stream: Optional CUDA stream for asynchronous execution

    Returns:
        Single tensor or list of tensors containing the kernel results

    Raises:
        ValueError: If the number of inputs plus the number of result
            indices does not equal the number of kernel parameters.
    """
    if len(ins) + len(self.result_idx) != len(self.params):
        raise ValueError(
            f"Expected {len(self.params)} inputs, got {len(ins) + len(self.result_idx)} with {len(ins)} inputs and {len(self.result_idx)} outputs"
        )

    ins_idx = 0
    args = []

    # tensor pointers: walk the parameter list, allocating outputs and
    # consuming the caller's inputs in order for every other slot
    for i, param in enumerate(self.params):
        if i in self.result_idx:
            dtype = getattr(torch, str(param.dtype))
            shape = [int(dim) for dim in param.shape]
            # use the device of the first input tensor if available
            device = ins[0].device if ins else torch.cuda.current_device()
            tensor = torch.empty(shape, dtype=dtype, device=device)
        else:
            tensor = ins[ins_idx]
            ins_idx += 1
        args.append(tensor)

    # dynamic symbolics
    # NOTE(review): buffer_idx is applied to `ins` (inputs only), not to the
    # full argument list that also holds the allocated outputs -- confirm the
    # map's indices are input-relative.
    for buffer_idx, shape_idx in self.dynamic_symbolic_map.values():
        args.append(ins[buffer_idx].shape[shape_idx])

    # fall back to the current CUDA stream when the caller did not supply one
    if stream is None:
        stream = torch.cuda.current_stream().cuda_stream

    # Launch the kernel, then hand back the output tensor(s) as documented.
    self._forward_from_prebuild_lib(*args, stream=stream)

    if len(self.result_idx) == 1:
        return args[self.result_idx[0]]
    return [args[i] for i in self.result_idx]
import torch cimport cython import ctypes from libc.stdint cimport int64_t, uintptr_t from libc.stdlib cimport malloc, free
cdef class CythonKernelWrapper:
    """Wrapper that marshals PyTorch tensors into a call to a compiled kernel library.

    NOTE(review): no ``__cinit__``/``__init__`` is visible in this view, and the
    attributes below are private C-level fields; confirm where they are populated.
    """
    # Class attributes to store kernel configuration and library reference
    cdef:
        object dynamic_symbolic_map  # Maps dynamic dimensions to their corresponding tensor indices
        list result_idx              # Indices of output tensors in the params list
        list params                  # List of parameter specifications (includes both inputs and outputs)
        object lib                   # Reference to the compiled library containing the kernel

    cpdef forward(self, list inputs, int64_t stream = -1):
        """Allocate outputs, assemble the argument list and invoke ``self.lib.call``.

        Args:
            inputs: Input tensors, in the order of the non-output entries of
                ``self.params``.
            stream: Raw stream handle. ``-1`` (the default) selects
                ``torch.cuda.current_stream().cuda_stream``. Declared as
                ``int64_t`` because the handle is pointer-sized and would be
                truncated by a C ``int``.

        Returns:
            The single output tensor when ``result_idx`` has exactly one entry,
            otherwise a list of the output tensors in ``result_idx`` order.

        Raises:
            ValueError: If ``len(inputs) + len(result_idx)`` differs from
                ``len(params)``.
        """
        cdef int total_params = len(self.params)
        cdef int total_inputs = len(inputs)
        cdef int total_result_idx = len(self.result_idx)
        cdef int ins_idx = 0
        cdef list tensor_list = []
        cdef list call_args

        # Ensure the number of inputs matches expected parameter count
        if total_params != total_inputs + total_result_idx:
            raise ValueError(
                f"Expected {len(self.params)} inputs, got {len(inputs) + len(self.result_idx)} with {len(inputs)} inputs and {len(self.result_idx)} outputs"
            )

        # Use current CUDA stream if none specified
        if stream == -1:
            stream = <uintptr_t>torch.cuda.current_stream().cuda_stream

        # Prepare input and output tensors
        for i, param in enumerate(self.params):
            if i in self.result_idx:
                # Create empty output tensor with specified dtype and shape
                dtype = getattr(torch, str(param.dtype))
                shape = [int(dim) for dim in param.shape]
                device = inputs[0].device if total_inputs > 0 else torch.cuda.current_device()
                tensor = torch.empty(shape, dtype=dtype, device=device)
            else:
                # Use provided input tensor
                tensor = inputs[ins_idx]
                ins_idx += 1
            tensor_list.append(tensor)

        # Convert tensor pointers to C void pointers for kernel call
        call_args = [ctypes.c_void_p(tensor.data_ptr()) for tensor in tensor_list]

        # Add dynamic dimension values to kernel arguments
        for buffer_idx, shape_idx in self.dynamic_symbolic_map.values():
            call_args.append(tensor_list[buffer_idx].shape[shape_idx])

        # Add CUDA stream to kernel arguments
        call_args.append(ctypes.c_void_p(stream))

        # Execute the kernel
        self.lib.call(*call_args)

        # Return output tensor(s)
        if total_result_idx == 1:
            return tensor_list[self.result_idx[0]]
        return [tensor_list[idx] for idx in self.result_idx]
Comments