In Part 4 of this introduction, we saw that the performance of our convolution kernel is limited by memory bandwidth. We are going to see how to improve performance by using shared memory.
Shared memory is memory that can be accessed by all the threads of a block. It is much faster than global memory, but also much smaller, and its size varies depending on the device. For example, the default total amount of shared memory per block on a GTX 1070 is 48 KB.
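If you want to check the limit on your own device, you can query it from Numba. Below is a minimal sketch, assuming the MAX_SHARED_MEMORY_PER_BLOCK device attribute is exposed (it is in recent Numba versions):
from numba import cuda

device = cuda.get_current_device()
# Shared memory available to a single block, in bytes:
print(device.name, device.MAX_SHARED_MEMORY_PER_BLOCK)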
In Numba, we create a shared array with cuda.shared.array(shape, dtype).
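As a minimal illustration (a hypothetical toy kernel, not part of our convolution code, assuming blocks of 32 threads), a block-local buffer can be declared and used like this; note that the shape passed to cuda.shared.array must be known at compile time:
from numba import cuda, float32
import numpy as np

@cuda.jit
def copy_through_shared(out, values):
    # Static shared array: one float32 slot per thread of the block
    buf = cuda.shared.array(32, float32)
    i = cuda.grid(1)
    if i < values.shape[0]:
        buf[cuda.threadIdx.x] = values[i]
    # Wait until all the threads of the block have written their value
    cuda.syncthreads()
    if i < values.shape[0]:
        out[i] = buf[cuda.threadIdx.x]

values = np.arange(64, dtype=np.float32)
out = np.empty_like(values)
copy_through_shared[2, 32](out, values)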
Adding shared arrays to our convolution kernel
In our previous kernel, the threads of each block access the same pixels of the image many times. To improve performance, we are going to save into shared memory the area of the image accessed by each block.
We continue to use a 13x13 mask and 2D blocks of dimensions (32, 32). As before, we want each thread to compute the convolution at a given point, so each block computes 32x32 points of the convolution. To compute the result for an area of size 32x32 with a 13x13 mask, we need a patch of the image of size 44x44 (32 + 13 - 1 = 44 in each dimension). Each block will therefore store 44x44 values in shared memory.
Furthermore, since the mask is also used by each thread of the block, we also store its values in shared memory.
Below is a modification of the previous code with shared arrays added. Note that:
- we fill the shared arrays using the threads. Since there are fewer threads than values in the shared image, each thread has to fill more than one value. More specifically, each thread fills up to four values, which is enough to cover the whole shared image.
- to make sure that no computation starts before the shared arrays are completely filled, we use cuda.syncthreads()
- we decided to use flat shared arrays rather than 2D arrays
with open('smem_convolution.py', 'r') as f:
    print(f.read())
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from numba import cuda, float32
import numpy as np
import skimage.data
from skimage.color import rgb2gray
from scipy.ndimage.filters import convolve as scipy_convolve
# We use a mask of 13x13 pixels:
mask_rows, mask_cols = 13, 13
mask_size = mask_rows * mask_cols
delta_rows = mask_rows // 2
delta_cols = mask_cols // 2
# We use blocks of 32x32 pixels:
blockdim = (32, 32)
# We use an image of size:
image_rows, image_cols = 200, 200
# We compute grid dimensions big enough to cover the whole image:
griddim = (image_rows // blockdim[0] + 1,
           image_cols // blockdim[1] + 1)
# We want to keep in shared memory a part of the image of shape:
shared_image_rows = blockdim[0] + mask_rows - 1
shared_image_cols = blockdim[1] + mask_cols - 1
shared_image_size = shared_image_rows * shared_image_cols
@cuda.jit
def smem_convolve(result, mask, image):
    # expects a 2D grid and 2D blocks,
    # a mask with odd numbers of rows and columns,
    # a grayscale image

    # 2D coordinates of the current thread:
    i, j = cuda.grid(2)

    # Create shared arrays
    shared_image = cuda.shared.array(shared_image_size, float32)
    shared_mask = cuda.shared.array(mask_size, float32)

    # Fill shared mask
    if (cuda.threadIdx.x < mask_rows) and (cuda.threadIdx.y < mask_cols):
        shared_mask[cuda.threadIdx.x + cuda.threadIdx.y * mask_rows] = mask[cuda.threadIdx.x, cuda.threadIdx.y]

    # Fill shared image
    # Each thread fills four cells of the array
    row_corner = cuda.blockDim.x * cuda.blockIdx.x - delta_rows
    col_corner = cuda.blockDim.y * cuda.blockIdx.y - delta_cols
    even_idx_x = 2 * cuda.threadIdx.x
    even_idx_y = 2 * cuda.threadIdx.y
    odd_idx_x = even_idx_x + 1
    odd_idx_y = even_idx_y + 1
    for idx_x in (even_idx_x, odd_idx_x):
        if idx_x < shared_image_rows:
            for idx_y in (even_idx_y, odd_idx_y):
                if idx_y < shared_image_cols:
                    point = (row_corner + idx_x, col_corner + idx_y)
                    if (point[0] >= 0) and (point[1] >= 0) and (point[0] < image_rows) and (point[1] < image_cols):
                        shared_image[idx_x + idx_y * shared_image_rows] = image[point]
                    else:
                        shared_image[idx_x + idx_y * shared_image_rows] = float32(0)
    cuda.syncthreads()

    # The result at coordinates (i, j) is equal to
    # sum_{k, l} mask[k, l] * shared_image[threadIdx.x - k + 2 * delta_rows,
    #                                      threadIdx.y - l + 2 * delta_cols]
    # with k and l going through the whole mask array:
    s = float32(0)
    for k in range(mask_rows):
        for l in range(mask_cols):
            i_k = cuda.threadIdx.x - k + mask_rows - 1
            j_l = cuda.threadIdx.y - l + mask_cols - 1
            s += shared_mask[k + l * mask_rows] * shared_image[i_k + j_l * shared_image_rows]
    if (i < image_rows) and (j < image_cols):
        result[i, j] = s


if __name__ == '__main__':
    # Read image
    full_image = rgb2gray(skimage.data.coffee()).astype(np.float32) / 255
    image = full_image[150:150 + image_rows, 200:200 + image_cols].copy()
    # We preallocate the result array:
    result = np.empty_like(image)
    # We choose a random mask:
    mask = np.random.rand(mask_rows, mask_cols).astype(np.float32)
    mask /= mask.sum()  # We normalize the mask
    # We apply our convolution to our image:
    smem_convolve[griddim, blockdim](result, mask, image)
    # We check that the error with respect to the Scipy convolve function is small:
    scipy_result = scipy_convolve(image, mask, mode='constant', cval=0.0, origin=0)
    max_rel_error = np.max(np.abs(result - scipy_result) / np.abs(scipy_result))
    if max_rel_error > 1e-5:
        raise AssertionError('Maximum relative error w.r.t Scipy convolve is too large: '
                             + str(max_rel_error))
Let's run a quick profile:
!nvprof python smem_convolution.py
==3608== NVPROF is profiling process 3608, command: python smem_convolution.py
==3608== Profiling application: python smem_convolution.py
==3608== Profiling result:
Type Time(%) Time Calls Avg Min Max Name
GPU activities: 43.55% 41.888us 1 41.888us 41.888us 41.888us cudapy::__main__::smem_convolve$241(Array<float, int=2, C, mutable, aligned>, Array<float, int=2, C, mutable, aligned>, Array<float, int=2, C, mutable, aligned>)
29.01% 27.904us 3 9.3010us 704ns 13.600us [CUDA memcpy HtoD]
27.45% 26.401us 3 8.8000us 704ns 12.865us [CUDA memcpy DtoH]
API calls: 98.36% 115.09ms 1 115.09ms 115.09ms 115.09ms cuDevicePrimaryCtxRetain
0.51% 594.06us 1 594.06us 594.06us 594.06us cuLinkCreate
0.29% 340.61us 3 113.54us 11.153us 203.77us cuMemAlloc
0.17% 200.77us 1 200.77us 200.77us 200.77us cuLinkAddData
0.17% 199.35us 1 199.35us 199.35us 199.35us cuModuleLoadDataEx
0.13% 146.68us 1 146.68us 146.68us 146.68us cuLinkComplete
0.11% 130.99us 1 130.99us 130.99us 130.99us cuMemGetInfo
0.10% 114.20us 3 38.067us 22.703us 50.884us cuMemcpyDtoH
0.08% 87.873us 3 29.291us 13.504us 40.347us cuMemcpyHtoD
0.06% 67.413us 1 67.413us 67.413us 67.413us cuDeviceGetName
0.02% 25.491us 1 25.491us 25.491us 25.491us cuLaunchKernel
0.00% 3.3510us 5 670ns 443ns 1.1910us cuFuncGetAttribute
0.00% 2.8930us 3 964ns 364ns 1.6320us cuDeviceGetCount
0.00% 2.3550us 1 2.3550us 2.3550us 2.3550us cuCtxPushCurrent
0.00% 1.6580us 1 1.6580us 1.6580us 1.6580us cuModuleGetFunction
0.00% 1.5570us 2 778ns 664ns 893ns cuDeviceGet
0.00% 1.2490us 3 416ns 379ns 479ns cuDeviceGetAttribute
0.00% 820ns 1 820ns 820ns 820ns cuLinkDestroy
0.00% 781ns 1 781ns 781ns 781ns cuDeviceComputeCapability
This version runs in about 40 µs, while the previous version without shared memory ran in about 420 µs: using shared memory gave us roughly a \(10\times\) speedup. That's a big improvement!
!nvprof --quiet --export-profile smem_timeline.prof python smem_convolution.py
!nvprof --quiet --metrics all --events all -o smem_metrics-events.prof python smem_convolution.py
Using dynamic allocation
We saw previously how to allocate shared memory. However, our code uses global variables to define the size of each shared array. We can make the code more flexible by allocating the shared memory dynamically.
Dynamic allocation of shared memory is done by setting the shape to 0 in cuda.shared.array. The size of the shared array (in bytes) is then defined when calling the kernel:
kernel[griddim, blockdim, stream, shared_memory_size](arguments)
Note that the launch configuration then also takes a stream; a value of 0 selects the default stream.
Only one array can be dynamically allocated. In the case of our convolution kernel, we therefore use this single array to hold both the mask and the shared part of the image.
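Before the full convolution kernel, here is the same kind of toy kernel as earlier, this time with the shared array allocated dynamically (a hypothetical sketch; the size passed in the launch configuration is in bytes):
from numba import cuda, float32
import numpy as np

@cuda.jit
def copy_through_shared(out, values):
    # Dynamic shared array: shape 0, the actual size is given at launch time
    buf = cuda.shared.array(0, float32)
    i = cuda.grid(1)
    if i < values.shape[0]:
        buf[cuda.threadIdx.x] = values[i]
    cuda.syncthreads()
    if i < values.shape[0]:
        out[i] = buf[cuda.threadIdx.x]

values = np.arange(64, dtype=np.float32)
out = np.empty_like(values)
threads = 32
blocks = values.size // threads
stream = 0  # default stream
smem_size = threads * values.itemsize  # size in bytes
copy_through_shared[blocks, threads, stream, smem_size](out, values)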
Below is a version that uses dynamic allocation:
with open('smem_convolution_dynamic.py', 'r') as f:
    print(f.read())
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from numba import cuda, float32
import numpy as np
import skimage.data
from skimage.color import rgb2gray
from scipy.ndimage.filters import convolve as scipy_convolve
@cuda.jit
def smem_convolve(result, mask, image):
    # expects a 2D grid and 2D blocks,
    # a mask with odd numbers of rows and columns,
    # a grayscale image

    # 2D coordinates of the current thread:
    i, j = cuda.grid(2)

    # Number of rows and columns of the image
    image_rows, image_cols = image.shape

    # To compute the result at coordinates (i, j), we need to use delta_rows
    # rows of the image before and after the i_th row, as well as delta_cols
    # columns of the image before and after the j_th column:
    delta_rows = mask.shape[0] // 2
    delta_cols = mask.shape[1] // 2

    # The width and height of the image we want to keep in shared memory
    # can be computed as follows:
    width = cuda.blockDim.y + delta_cols * 2
    height = cuda.blockDim.x + delta_rows * 2

    # Create shared array
    shared_array = cuda.shared.array(0, float32)
    mask_index = width * height
    if (cuda.threadIdx.x < mask.shape[0]) and (cuda.threadIdx.y < mask.shape[1]):
        shared_array[mask_index + cuda.threadIdx.x + cuda.threadIdx.y * mask.shape[0]] = mask[cuda.threadIdx.x, cuda.threadIdx.y]

    # Fill shared array
    # Each thread fills four cells of the array
    row_corner = cuda.blockDim.x * cuda.blockIdx.x - delta_rows
    col_corner = cuda.blockDim.y * cuda.blockIdx.y - delta_cols
    even_idx_x = 2 * cuda.threadIdx.x
    even_idx_y = 2 * cuda.threadIdx.y
    odd_idx_x = even_idx_x + 1
    odd_idx_y = even_idx_y + 1
    for idx_x in (even_idx_x, odd_idx_x):
        if idx_x < height:
            for idx_y in (even_idx_y, odd_idx_y):
                if idx_y < width:
                    point = (row_corner + idx_x, col_corner + idx_y)
                    if (point[0] >= 0) and (point[1] >= 0) and (point[0] < image_rows) and (point[1] < image_cols):
                        shared_array[idx_x + idx_y * height] = image[point]
                    else:
                        shared_array[idx_x + idx_y * height] = float32(0)
    cuda.syncthreads()

    # The result at coordinates (i, j) is equal to
    # sum_{k, l} mask[k, l] * shared_array[threadIdx.x - k + 2 * delta_rows,
    #                                      threadIdx.y - l + 2 * delta_cols]
    # with k and l going through the whole mask array:
    s = float32(0)
    for k in range(mask.shape[0]):
        for l in range(mask.shape[1]):
            i_k = cuda.threadIdx.x - k + mask.shape[0] - 1
            j_l = cuda.threadIdx.y - l + mask.shape[1] - 1
            s += shared_array[mask_index + k + l * mask.shape[0]] * shared_array[i_k + j_l * height]
    if (i < image_rows) and (j < image_cols):
        result[i, j] = s


if __name__ == '__main__':
    # Read image
    full_image = rgb2gray(skimage.data.coffee()).astype(np.float32) / 255
    image = full_image[150:350, 200:400].copy()
    # We preallocate the result array:
    result = np.empty_like(image)
    # We choose a random mask of size 13x13:
    mask = np.random.rand(13, 13).astype(np.float32)
    mask /= mask.sum()  # We normalize the mask
    # We use blocks of 32x32 pixels:
    blockdim = (32, 32)
    # We compute grid dimensions big enough to cover the whole image:
    griddim = (image.shape[0] // blockdim[0] + 1,
               image.shape[1] // blockdim[1] + 1)
    # We use the default stream:
    stream = 0
    # We compute the size of shared memory to allocate:
    smem_size = ((blockdim[0] + mask.shape[0] - 1)
                 * (blockdim[1] + mask.shape[1] - 1)
                 + mask.size) * image.itemsize
    # We apply our convolution to our image:
    smem_convolve[griddim, blockdim, stream, smem_size](result, mask, image)
    # We check that the error with respect to the Scipy convolve function is small:
    scipy_result = scipy_convolve(image, mask, mode='constant', cval=0.0, origin=0)
    max_rel_error = np.max(np.abs(result - scipy_result) / np.abs(scipy_result))
    if max_rel_error > 1e-5:
        raise AssertionError('Maximum relative error w.r.t Scipy convolve is too large: '
                             + str(max_rel_error))
!nvprof python smem_convolution_dynamic.py
==3695== NVPROF is profiling process 3695, command: python smem_convolution_dynamic.py
==3695== Profiling application: python smem_convolution_dynamic.py
==3695== Profiling result:
Type Time(%) Time Calls Avg Min Max Name
GPU activities: 52.25% 60.193us 1 60.193us 60.193us 60.193us cudapy::__main__::smem_convolve$241(Array<float, int=2, C, mutable, aligned>, Array<float, int=2, C, mutable, aligned>, Array<float, int=2, C, mutable, aligned>)
24.86% 28.640us 3 9.5460us 768ns 14.080us [CUDA memcpy HtoD]
22.89% 26.368us 3 8.7890us 704ns 12.832us [CUDA memcpy DtoH]
API calls: 98.39% 112.00ms 1 112.00ms 112.00ms 112.00ms cuDevicePrimaryCtxRetain
0.44% 505.33us 1 505.33us 505.33us 505.33us cuLinkCreate
0.31% 351.68us 3 117.23us 12.925us 205.54us cuMemAlloc
0.18% 200.69us 1 200.69us 200.69us 200.69us cuModuleLoadDataEx
0.17% 196.92us 1 196.92us 196.92us 196.92us cuLinkAddData
0.12% 138.55us 1 138.55us 138.55us 138.55us cuLinkComplete
0.11% 128.88us 1 128.88us 128.88us 128.88us cuMemGetInfo
0.10% 114.29us 3 38.097us 13.713us 56.437us cuMemcpyDtoH
0.10% 108.52us 3 36.171us 14.860us 52.672us cuMemcpyHtoD
0.04% 41.976us 1 41.976us 41.976us 41.976us cuDeviceGetName
0.03% 32.051us 1 32.051us 32.051us 32.051us cuLaunchKernel
0.00% 3.2530us 5 650ns 465ns 1.0670us cuFuncGetAttribute
0.00% 2.7730us 3 924ns 360ns 1.5160us cuDeviceGetCount
0.00% 1.8600us 1 1.8600us 1.8600us 1.8600us cuCtxPushCurrent
0.00% 1.6280us 2 814ns 683ns 945ns cuDeviceGet
0.00% 1.5500us 1 1.5500us 1.5500us 1.5500us cuModuleGetFunction
0.00% 1.3360us 3 445ns 414ns 502ns cuDeviceGetAttribute
0.00% 841ns 1 841ns 841ns 841ns cuDeviceComputeCapability
0.00% 816ns 1 816ns 816ns 816ns cuLinkDestroy
!nvprof --quiet --export-profile smem_dynamic_timeline.prof python smem_convolution_dynamic.py
!nvprof --quiet --metrics all --events all -o smem_dynamic_metrics-events.prof python smem_convolution_dynamic.py
We saw in this part that by using shared memory, we can get a \(10\times\) speed improvement for our convolution kernel. Using shared memory is often key to big performance improvements!
This is the end of our Introduction to CUDA. There are still many technical details to master to fully exploit the potential of a GPU. However, the few basics presented here should already allow you to write kernels that outperform a CPU for highly parallel tasks.