Custom kernels

Author

Marie-Hélène Burle

Remember that “kernel” is the term used for functions running on the GPU. Just as you write your own functions in a programming language when the built-in ones don't do what you need, you can write your own kernels when CuPy's functions don't cover your computation.

In this section, we cover two ways to do this with CuPy: elementwise kernels and reduction kernels.

General concepts

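Kernel inputs and outputs are each declared as a comma-separated list of parameters, where every parameter is written as a type followed by a name.

Examples:

float32 a   # a NumPy data type name can be used as the type
T x         # a one-character type such as T is a placeholder for a generic type, resolved from the arguments at call time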
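To make the "type + name" convention concrete, here is a small sketch using a hypothetical helper, parse_params (it is not part of CuPy, which does its own parsing internally). It splits a parameter string into (type, name) pairs, resolves concrete type names to NumPy dtypes, and leaves a one-character type such as T as an unresolved placeholder.

```python
import numpy as np

def parse_params(spec):
    """Hypothetical helper: split a CuPy-style parameter string into (type, name) pairs.

    A one-character type (e.g. 'T') is kept as a placeholder string;
    any other type name is resolved with numpy.dtype().
    """
    params = []
    for item in spec.split(','):
        type_name, name = item.split()       # e.g. 'float32 a' -> ['float32', 'a']
        if len(type_name) == 1:
            dtype = type_name                # generic placeholder, only known at call time
        else:
            dtype = np.dtype(type_name)      # concrete NumPy dtype, e.g. float32
        params.append((dtype, name))
    return params

print(parse_params('float32 a, T x'))
```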

Elementwise kernels

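An elementwise kernel applies the same operation independently to each element of its inputs (inputs are broadcast against each other). It is created with cp.ElementwiseKernel:

<kernel name> = cp.ElementwiseKernel(
    '<list of inputs>',
    '<list of outputs>',
    '<operation to perform>',
    '<kernel name>'
)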

Example:

import cupy as cp

squared_diff = cp.ElementwiseKernel(
    'float32 x, float32 y',    # inputs
    'float32 z',               # output
    'z = (x - y) * (x - y)',   # operation applied to each element
    'squared_diff'             # kernel name
)
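A practical habit when writing a custom kernel is to check it against a plain NumPy version of the same formula. The sketch below assumes you may or may not have a CUDA GPU: it always computes the NumPy reference (squared_diff_ref is a name chosen for this illustration), and only builds and calls the CuPy kernel when a device is available. The sample arrays are made up for the example.

```python
import numpy as np

# CPU reference: same formula as the kernel's operation string 'z = (x - y) * (x - y)'
def squared_diff_ref(x, y):
    return (x - y) * (x - y)

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
y = np.array([0.5, 2.0, 5.0], dtype=np.float32)
expected = squared_diff_ref(x, y)

# Only use CuPy if it is installed and a CUDA device is visible
try:
    import cupy as cp
    HAS_GPU = cp.cuda.runtime.getDeviceCount() > 0
except Exception:
    HAS_GPU = False

if HAS_GPU:
    squared_diff = cp.ElementwiseKernel(
        'float32 x, float32 y',
        'float32 z',
        'z = (x - y) * (x - y)',
        'squared_diff'
    )
    # The output array z is allocated automatically by the kernel call
    z = squared_diff(cp.asarray(x), cp.asarray(y))
    np.testing.assert_allclose(cp.asnumpy(z), expected)
```

On a machine without a GPU, only the NumPy part runs, which still documents what the kernel is expected to return.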

Reduction kernels

Let’s create a kernel that calculates the mean squared error (MSE).
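For reference, over n pairs of targets and predictions, the quantity being computed is:

```latex
\mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} \left( y_{\text{true},i} - y_{\text{pred},i} \right)^2
```

The squared difference is computed for each element, the sum combines all of them, and the division by n is applied once to the final sum. These three steps correspond to the map, reduce, and post-map expressions of the reduction kernel.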

This is how you would do this in NumPy:

import numpy as np

# Function to calculate MSE
def mse_fn(y_true, y_pred):
    mse_out = np.mean((y_true - y_pred)**2)
    return mse_out

# Dummy data
y_pred = np.array([1.5, 2.0, 3.5, 4.0], dtype=np.float32)
y_true = np.array([1.0, 2.5, 3.5, 3.0], dtype=np.float32)

# Calculate MSE
mse = mse_fn(y_true, y_pred)

print(f"Predictions: {y_pred}")
print(f"Targets:     {y_true}")
print(f"MSE:         {mse}")
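And this is how you would do it with a reduction kernel. cp.ReductionKernel takes, in order: the inputs, the outputs, a map expression applied to each element, a reduce expression that combines two mapped values a and b, a post-map expression applied to the final reduced value a, the identity value the reduction starts from, and the kernel name:

import cupy as cp

mse_kernel = cp.ReductionKernel(
    'T y_pred, T y_true',                       # inputs (T = generic type)
    'T mse_out',                                # output
    '(y_pred - y_true) * (y_pred - y_true)',    # map: squared difference for each element
    'a + b',                                    # reduce: sum the mapped values
    'mse_out = a / _in_ind.size()',             # post-map: divide the sum by the number of input elements
    '0',                                        # identity of the sum
    'mse_kernel'                                # kernel name
)

# Dummy data
y_pred = cp.array([1.5, 2.0, 3.5, 4.0], dtype=cp.float32)
y_true = cp.array([1.0, 2.5, 3.5, 3.0], dtype=cp.float32)

# Calculate MSE
mse = mse_kernel(y_pred, y_true)

print(f"Predictions: {y_pred}")
print(f"Targets:     {y_true}")
print(f"MSE:         {mse}")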
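To see how the four expressions of mse_kernel fit together, here is a CPU-only sketch that mimics the same pipeline with NumPy, so it can be run without a GPU. The helper reduction_model is hypothetical and written only for this illustration; it is not part of CuPy. It also reduces sequentially, whereas a real GPU reduction combines values in parallel, which is why the reduce expression must be associative and the identity must leave values unchanged.

```python
from functools import reduce
import numpy as np

def reduction_model(map_fn, reduce_fn, post_map_fn, identity, *inputs):
    """Hypothetical CPU model of a reduction over all elements."""
    mapped = map_fn(*inputs)                                   # map expression, applied elementwise
    a = reduce(reduce_fn, mapped.ravel().tolist(), identity)   # reduce expression, seeded with the identity
    n = inputs[0].size                                         # plays the role of _in_ind.size()
    return post_map_fn(a, n)                                   # post-map expression on the reduced value

y_pred = np.array([1.5, 2.0, 3.5, 4.0], dtype=np.float32)
y_true = np.array([1.0, 2.5, 3.5, 3.0], dtype=np.float32)

mse = reduction_model(
    lambda p, t: (p - t) * (p - t),   # '(y_pred - y_true) * (y_pred - y_true)'
    lambda a, b: a + b,               # 'a + b'
    lambda a, n: a / n,               # 'mse_out = a / _in_ind.size()'
    0,                                # '0'
    y_pred, y_true,
)
print(mse)
```

With a GPU available, float(mse_kernel(cp.asarray(y_pred), cp.asarray(y_true))) should agree with this value up to float32 rounding.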