Reductions and FFT Workflows¶
This page covers common high-level numeric workflows in vkdispatch:
- reductions with vd.reduce
- Fourier transforms with vd.fft
- VkFFT-backed transforms with vd.vkfft
FFT Subsystem Overview¶
vkdispatch provides two FFT backends:
- vd.fft: vkdispatch-generated shaders (runtime code generation).
- vd.vkfft: VkFFT-backed plan execution.
Use vd.fft when you want shader-level customization and fusion through mapping
hooks (input_map, output_map, kernel_map). Use vd.vkfft when you want
the VkFFT path with plan caching and a similar high-level API.
Reduction Basics¶
Use @vd.reduce.reduce for pure binary reductions:
import numpy as np
import vkdispatch as vd
from vkdispatch.codegen.abreviations import *
@vd.reduce.reduce(0)
def sum_reduce(a: f32, b: f32) -> f32:
    return a + b
arr = np.random.rand(4096).astype(np.float32)
buf = vd.asbuffer(arr)
out = sum_reduce(buf).read(0)
print("GPU sum:", float(out[0]))
print("CPU sum:", float(arr.sum(dtype=np.float32)))
Mapped Reductions¶
Use @vd.reduce.map_reduce when you want a map stage before reduction:
import vkdispatch.codegen as vc
@vd.reduce.map_reduce(vd.reduce.SubgroupAdd)
def l2_energy_map(buffer: Buff[f32]) -> f32:
    idx = vd.reduce.mapped_io_index()
    v = buffer[idx]
    return v * v
energy_buf = l2_energy_map(buf)
energy = energy_buf.read(0)[0]
This pattern is useful for sums of transformed values (norms, weighted sums, etc.).
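As a CPU cross-check for the mapped reduction above, the same map-then-reduce computation can be written directly in NumPy (this is a plain reference computation, not vkdispatch code):

```python
import numpy as np

arr = np.random.rand(4096).astype(np.float32)

# Map stage: square each element; reduce stage: sum the squares.
cpu_energy = float((arr.astype(np.float64) ** 2).sum())

# The result is the squared L2 norm of the signal.
assert np.isclose(cpu_energy, float(np.linalg.norm(arr.astype(np.float64)) ** 2))
```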
FFT with vd.fft¶
The vd.fft module dispatches vkdispatch-generated FFT shaders.
import numpy as np
import vkdispatch as vd
complex_signal = (
    np.random.rand(256) + 1j * np.random.rand(256)
).astype(np.complex64)
fft_buf = vd.asbuffer(complex_signal)
vd.fft.fft(fft_buf)
freq = fft_buf.read(0)
vd.fft.ifft(fft_buf)
recovered = fft_buf.read(0)
print(np.allclose(recovered, complex_signal, atol=1e-3))
By default, inverse transforms are normalized (normalize=True in vd.fft.ifft). Set normalize=False when you need the raw, unnormalized inverse, whose output carries the transform-length scale factor.
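The scaling difference can be illustrated with NumPy, which always applies the 1/n factor in np.fft.ifft; multiplying it back out mimics an unnormalized inverse (a reference sketch, not vkdispatch code):

```python
import numpy as np

n = 8
x = (np.random.rand(n) + 1j * np.random.rand(n)).astype(np.complex64)
X = np.fft.fft(x)

normalized = np.fft.ifft(X)   # applies the 1/n factor, recovers x
raw = np.fft.ifft(X) * n      # mimics an inverse with no normalization

assert np.allclose(normalized, x, atol=1e-5)
assert np.allclose(raw, x * n, atol=1e-4)
```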
To inspect generated FFT shaders, use:
vd.fft.fft(fft_buf, print_shader=True)
Axis and Dimensionality¶
FFT routines accept an axis argument for explicit axis control and provide fft2
and fft3 convenience functions.
# Strided FFT over the second axis of a 2D batch (from performance-test workflows).
batch = (
    np.random.rand(8, 1024) + 1j * np.random.rand(8, 1024)
).astype(np.complex64)
batch_buf = vd.asbuffer(batch)
vd.fft.fft(batch_buf, axis=1)
# 2D transform helper (last two axes).
image = (
    np.random.rand(512, 512) + 1j * np.random.rand(512, 512)
).astype(np.complex64)
image_buf = vd.asbuffer(image)
vd.fft.fft2(image_buf)
vd.fft.ifft2(image_buf)
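Conceptually, a 2D transform over the last two axes is equivalent to two 1D passes, one per axis. A NumPy sketch of that equivalence (reference computation only, not vkdispatch code):

```python
import numpy as np

image = (np.random.rand(16, 16) + 1j * np.random.rand(16, 16)).astype(np.complex64)

# 2D transform in one call...
direct = np.fft.fft2(image)

# ...equals 1D FFTs along axis 1, then along axis 0.
staged = np.fft.fft(np.fft.fft(image, axis=1), axis=0)

assert np.allclose(direct, staged, atol=1e-3)
```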
Real FFT (RFFT) helpers:
real_signal = np.random.rand(512).astype(np.float32)
rbuf = vd.asrfftbuffer(real_signal)
vd.fft.rfft(rbuf)
spectrum = rbuf.read_fourier(0)
vd.fft.irfft(rbuf)
restored = rbuf.read_real(0)
print(np.allclose(restored, real_signal, atol=1e-3))
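For a real signal of length n, a real FFT only needs to store n // 2 + 1 complex bins, because the spectrum of real input is conjugate-symmetric. The layout and round trip can be checked with NumPy's reference implementation (not vkdispatch code):

```python
import numpy as np

n = 512
real_signal = np.random.rand(n).astype(np.float32)

spectrum = np.fft.rfft(real_signal)     # n // 2 + 1 complex bins
assert spectrum.shape == (n // 2 + 1,)

restored = np.fft.irfft(spectrum, n=n)  # round trip back to n real samples
assert np.allclose(restored, real_signal, atol=1e-5)
```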
Fusion with kernel_map (Frequency-Domain In-Register Ops)¶
vd.fft.convolve can inject custom frequency-domain logic via kernel_map.
Inside a kernel map callback, vd.fft.read_op() exposes the current FFT register
being processed.
import vkdispatch.codegen as vc
@vd.map
def scale_spectrum(scale_factor: vc.Var[vc.f32]):
    op = vd.fft.read_op()
    op.register[:] = op.register * scale_factor
# Fused forward FFT + frequency scaling + inverse FFT
vd.fft.convolve(fft_buf, np.float32(0.5), kernel_map=scale_spectrum)
This pattern avoids a separate full-buffer dispatch for many pointwise spectral operations.
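For the scale_spectrum example above, the end-to-end effect is easy to reason about: by linearity of the FFT, scaling every spectral register by 0.5 between the forward and inverse passes is the same as scaling the time-domain signal by 0.5. A NumPy check of that identity (reference computation, not vkdispatch code):

```python
import numpy as np

x = (np.random.rand(64) + 1j * np.random.rand(64)).astype(np.complex64)
scale = np.float32(0.5)

# Forward FFT, pointwise frequency-domain scaling, inverse FFT...
fused = np.fft.ifft(np.fft.fft(x) * scale)

# ...is equivalent to scaling the time-domain signal (FFT linearity).
assert np.allclose(fused, x * scale, atol=1e-4)
```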
Input/Output Mapping for Padded or Sparse Regions¶
For advanced workflows (for example padded 2D cross-correlation), use input_map and
output_map to remap FFT I/O indices and input_signal_range to skip inactive
regions.
Map argument annotations do not determine FFT compute precision. read_op.register
and write_op.register always use the internal FFT compute type; map callbacks should
cast user-chosen buffer values to and from that register type as needed. If both FFT I/O
paths are mapped and compute_type is not provided, vd.fft defaults to
complex64 (falling back to complex32 when required by device support).
When output_map is provided without input_map, pass an explicit input buffer
argument after the output_map arguments so read and write phases use different proxies.
import vkdispatch.codegen as vc
def padded_axis_fft(buffer: vd.Buffer, signal_cols: int):
    # Example expects buffer shape: (batch, rows, cols)
    trimmed_shape = (buffer.shape[0], signal_cols, buffer.shape[2])

    def remap(io_index: vc.ShaderVariable):
        return vc.unravel_index(
            vc.ravel_index(io_index, trimmed_shape).to_register(),
            buffer.shape
        )

    @vd.map
    def input_map(input_buffer: vc.Buffer[vc.c64]):
        op = vd.fft.read_op()
        op.read_from_buffer(input_buffer, io_index=remap(op.io_index))

    @vd.map
    def output_map(output_buffer: vc.Buffer[vc.c64]):
        op = vd.fft.write_op()
        op.write_to_buffer(output_buffer, io_index=remap(op.io_index))

    vd.fft.fft(
        buffer,
        buffer,
        buffer_shape=trimmed_shape,
        axis=1,
        input_map=input_map,
        output_map=output_map,
        input_signal_range=(0, signal_cols),
    )
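The remap helper converts a position between the trimmed logical layout and the padded physical layout. The underlying index arithmetic can be sketched with NumPy's np.unravel_index / np.ravel_multi_index; the shapes below are illustrative, and this demonstrates only the arithmetic, not vkdispatch's exact io_index semantics:

```python
import numpy as np

full_shape = (2, 8, 16)     # physical buffer: (batch, rows, cols)
trimmed_shape = (2, 5, 16)  # logical region:  (batch, signal_cols, cols)

def remap_flat(flat_index):
    # Interpret the flat index under the trimmed (logical) layout...
    multi = np.unravel_index(flat_index, trimmed_shape)
    # ...then re-flatten the same coordinates under the padded layout.
    return np.ravel_multi_index(multi, full_shape)

# Element (1, 4, 3) of the logical region lives at a different flat
# offset in the padded buffer.
flat_trimmed = np.ravel_multi_index((1, 4, 3), trimmed_shape)
assert remap_flat(flat_trimmed) == np.ravel_multi_index((1, 4, 3), full_shape)
```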
Transposed Kernel Path for 2D Convolution¶
When convolving along a strided axis, pre-transposing kernel layout can improve access
patterns. vd.fft provides helper APIs used by the benchmark suite:
# signal_buf and kernel_buf are complex buffers with compatible FFT shapes.
transposed_size = vd.fft.get_transposed_size(signal_buf.shape, axis=1)
kernel_t = vd.Buffer((transposed_size,), vd.complex64)
vd.fft.transpose(kernel_buf, axis=1, out_buffer=kernel_t)
vd.fft.fft(signal_buf)
vd.fft.convolve(signal_buf, kernel_t, axis=1, transposed_kernel=True)
vd.fft.ifft(signal_buf)
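The transposed path leans on a standard identity: transforming along a strided axis is equivalent to transposing, transforming along the contiguous axis, and transposing back. A NumPy sketch of that equivalence (reference computation, not the vkdispatch helpers):

```python
import numpy as np

a = (np.random.rand(8, 32) + 1j * np.random.rand(8, 32)).astype(np.complex64)

# FFT along the strided axis of a C-contiguous array...
strided = np.fft.fft(a, axis=0)

# ...equals transpose -> FFT along the contiguous axis -> transpose back.
via_transpose = np.fft.fft(a.T, axis=1).T

assert np.allclose(strided, via_transpose, atol=1e-3)
```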
Low-Level Procedural FFT Generation with fft_context¶
For full control over read/compute/write staging, build FFT shaders procedurally using
vd.fft.fft_context and iterators from vd.fft:
import vkdispatch.codegen as vc
with vd.fft.fft_context(buffer_shape=(1024,), axis=0) as ctx:
    args = ctx.declare_shader_args([vc.Buffer[vc.c64]])
    for read_op in vd.fft.global_reads_iterator(ctx.registers):
        read_op.read_from_buffer(args[0])
    ctx.execute(inverse=False)
    for write_op in vd.fft.global_writes_iterator(ctx.registers):
        write_op.write_to_buffer(args[0])
fft_kernel = ctx.get_callable()
fft_kernel(fft_buf)
FFT with vd.vkfft¶
vd.vkfft exposes a similar API but routes operations through VkFFT plan objects
with internal plan caching.
vkfft_buf = vd.asbuffer(complex_signal.copy())
vd.vkfft.fft(vkfft_buf)
vd.vkfft.ifft(vkfft_buf)
print(np.allclose(vkfft_buf.read(0), complex_signal, atol=1e-3))
After large parameter sweeps, clearing cached plans can be helpful:
vd.vkfft.clear_plan_cache()
vd.fft.cache_clear()
Convolution Helpers¶
vkdispatch also includes FFT-based convolution helpers:
- vd.fft.convolve
- vd.fft.convolve2D
- vd.fft.convolve2DR
- vd.vkfft.convolve2D
- vd.vkfft.transpose_kernel2D
These APIs are most useful when you repeatedly convolve signals/images with known kernel layouts.
Reduction and FFT API Reference¶
See the Full Python API Reference for complete API details on:
- vkdispatch.reduce
- vkdispatch.fft
- vkdispatch.vkfft