JIT kernels

Author

Marie-Hélène Burle

import cupy as cp
from cupyx import jit
from cupyx.profiler import benchmark


# GPU kernel: each thread handles one integer. Thread `idx` factorizes the
# number `idx + 2` by trial division and stores the sum of its prime factors,
# counted with multiplicity, in A[idx] (e.g. 12 = 2 * 2 * 3 gives 2 + 2 + 3 = 7).
#   A: output array; element idx receives the result for the number idx + 2
#   n: number of elements to compute; threads with idx >= n write nothing
@jit.rawkernel()
def primeFactorizationSum(A, n):
    # Global thread ID (only the x dimension of the launch is used)
    idx = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x

    # The grid is rounded up to whole blocks, so surplus threads exit here
    if idx >= n:
        return

    # cp.int64 casts keep the arithmetic in 64-bit ('long long') integers
    remaining = cp.int64(idx) + 2
    factor_sum = cp.int64(0)

    # Strip every factor of 2, adding 2 once per division
    while remaining % 2 == 0:
        remaining //= 2
        factor_sum += 2

    # Trial division by odd candidates. A composite candidate can never
    # divide `remaining` because its prime factors were already stripped.
    divisor = cp.int64(3)
    while divisor * divisor <= remaining:
        while remaining % divisor == 0:
            remaining //= divisor
            factor_sum += divisor
        divisor += 2

    # Whatever is left above 1 is a single prime factor
    if remaining > 1:
        factor_sum += remaining

    A[idx] = factor_sum


# Factorize every integer from 2 to n: that is n - 1 numbers, and the kernel
# stores the result for the number k at index k - 2.
n = 1_000_000
size = n - 1
A = cp.empty(size, dtype=cp.int64)

# Launch configuration: round the block count up so every element gets a thread
threads = 256
blocks = (size + threads - 1) // threads

# A JIT raw kernel is called as kernel(grid, block, args); benchmark() unpacks
# this tuple into exactly that call on every repetition.
launch_args = ((blocks,), (threads,), (A, size))
timings = benchmark(primeFactorizationSum, launch_args, n_repeat=20)
print(timings)

print(f"Primes sum array: {A}")
primeFactorizationSum:
CPU:    27.653 us +/- 5.814 (min:  23.670 / max:  51.359) us
GPU-0: 896.531 us +/- 6.597 (min: 891.456 / max: 923.136) us

Primes sum array: [  2   3   4 ... 287  77  42]

Total time: 924 µs.

This is not quite as fast as the hand-written raw kernel.