JIT kernels
import cupy as cp
from cupyx import jit
from cupyx.profiler import benchmark
@jit.rawkernel()
def primeFactorizationSum(A, n):
    # For each thread index tid in [0, n), store in A[tid] the sum of the prime
    # factors of (tid + 2), counted with multiplicity (e.g. 12 = 2*2*3 -> 7).
    # Threads whose index is >= n do nothing.
    #
    # A: output array, one slot written per tid < n.
    #    NOTE(review): presumably an int64 device array of length >= n — confirm
    #    at the call site.
    # n: number of values to compute.
    #
    # NOTE(review): documented with comments rather than a docstring because the
    # body is transpiled by cupyx.jit; unclear whether a bare string statement
    # is accepted there.
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    if tid >= n:
        return

    # cp.int64(...) forces 64-bit ('long long') locals for the arithmetic below.
    rest = cp.int64(tid) + 2
    acc = cp.int64(0)

    # Pull out every factor of 2 first so the main loop can skip even divisors.
    while rest % 2 == 0:
        rest //= 2
        acc += 2

    # Trial division by odd candidates; each successful division contributes
    # the divisor once, so repeated factors are summed with multiplicity.
    d = cp.int64(3)
    while d * d <= rest:
        while rest % d == 0:
            rest //= d
            acc += d
        d += 2

    # A remainder > 1 has no divisor <= its square root, so it is itself prime.
    if rest > 1:
        acc += rest

    A[tid] = acc
# Host-side driver: compute the kernel's result for every integer in [2, n]
# on the GPU, time the launch, then print the results.
n = 1_000_000

# One int64 output slot per integer 2..n (n - 1 slots). The kernel writes
# every index below n - 1, so the uninitialised memory from cp.empty is
# fully overwritten.
A = cp.empty(n - 1, dtype=cp.int64)

threads = 256
# Ceiling division written as -(-x // y): enough blocks of `threads` threads
# to cover all n - 1 elements.
blocks = -(-(n - 1) // threads)

print(
    benchmark(
        primeFactorizationSum,
        args=((blocks,), (threads,), (A, n - 1)),
        n_repeat=20,
    )
)
print(f"Primes sum array: {A}")
# Sample output (the benchmark report is labelled "primeFactorizationSum:"):
CPU: 27.653 us +/- 5.814 (min: 23.670 / max: 51.359) us
GPU-0: 896.531 us +/- 6.597 (min: 891.456 / max: 923.136) us
Primes sum array: [ 2 3 4 ... 287 77 42]
Total time: 924 µs.
This is not quite as fast as writing the raw kernel directly.