Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
408 changes: 215 additions & 193 deletions pyhip/hip.py

Large diffs are not rendered by default.

146 changes: 122 additions & 24 deletions pyhip/hiprtc.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,32 +77,37 @@ def hiprtcGetErrorString(e):

"""

return _libhiprtc.hiprtcGetErrorString(e)
error = _libhiprtc.hiprtcGetErrorString(e)
return str(error)


# Generic hiprtc Error


class hiprtcError(Exception):
"""hip error"""
"""hiprtc error"""

def __init__(self, error=0) -> None:
super().__init__(hiprtcGetErrorString(error))
if isinstance(error, int):
super().__init__(hiprtcGetErrorString(error))
else:
super().__init__(str(error))


knownExceptions = set(
[
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
1, # HIPRTC_ERROR_OUT_OF_MEMORY
2, # HIPRTC_ERROR_PROGRAM_CREATION_FAILURE
3, # HIPRTC_ERROR_INVALID_INPUT
4, # HIPRTC_ERROR_INVALID_PROGRAM
5, # HIPRTC_ERROR_INVALID_OPTION
6, # HIPRTC_ERROR_COMPILATION
7, # HIPRTC_ERROR_BUILTIN_OPERATION_FAILURE
8, # HIPRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION
9, # HIPRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION
10, # HIPRTC_ERROR_NAME_EXPRESSION_NOT_VALID
11, # HIPRTC_ERROR_INTERNAL_ERROR
100, # HIPRTC_ERROR_LINKING
]
)

Expand Down Expand Up @@ -294,11 +299,10 @@ def hiprtcGetProgramLog(prog):
status = _libhiprtc.hiprtcGetProgramLogSize(prog, ctypes.byref(log_size))
hiprtcCheckStatus(status)

log = "0" * log_size.value
e_log = log.encode("utf-8")
status = _libhiprtc.hiprtcGetProgramLog(prog, e_log)
log_buf = ctypes.create_string_buffer(log_size.value)
status = _libhiprtc.hiprtcGetProgramLog(prog, log_buf)
hiprtcCheckStatus(status)
return e_log.decode("utf-8")
return log_buf.value.decode("utf-8")


_libhiprtc.hiprtcGetCodeSize.restype = int
Expand All @@ -324,15 +328,109 @@ def hiprtcGetCode(prog):

Returns
-------
code : string
hiprtc module code
code : bytes
hiprtc compiled binary code
"""
code_size = ctypes.c_size_t()
status = _libhiprtc.hiprtcGetCodeSize(prog, ctypes.byref(code_size))
hiprtcCheckStatus(status)

code = "0" * code_size.value
e_code = code.encode("utf-8")
status = _libhiprtc.hiprtcGetCode(prog, e_code)
code_buf = ctypes.create_string_buffer(code_size.value)
status = _libhiprtc.hiprtcGetCode(prog, code_buf)
hiprtcCheckStatus(status)
return code_buf.raw


_libhiprtc.hiprtcVersion.restype = int
_libhiprtc.hiprtcVersion.argtypes = [
ctypes.POINTER(ctypes.c_int),
ctypes.POINTER(ctypes.c_int),
]


def hiprtcVersion():
"""
Returns the hiprtc major and minor version.

Returns
-------
major : int
HIP Runtime Compilation major version.
minor : int
HIP Runtime Compilation minor version.
"""
major = ctypes.c_int(0)
minor = ctypes.c_int(0)
status = _libhiprtc.hiprtcVersion(
ctypes.byref(major), ctypes.byref(minor))
hiprtcCheckStatus(status)
return major.value, minor.value


_libhiprtc.hiprtcGetLoweredName.restype = int
_libhiprtc.hiprtcGetLoweredName.argtypes = [
ctypes.c_void_p, # hiprtcProgram
ctypes.POINTER(ctypes.c_char), # name_expression
ctypes.POINTER(ctypes.c_char_p), # lowered_name
]


def hiprtcGetLoweredName(prog, name_expression):
"""
Gets the lowered (mangled) name from an instance of hiprtcProgram.

Parameters
----------
prog : ctypes pointer
hiprtc program handle
name_expression : str
The name expression to look up.

Returns
-------
lowered_name : str
The lowered (mangled) name.
"""
e_name = name_expression.encode("utf-8")
lowered_name = ctypes.c_char_p()
status = _libhiprtc.hiprtcGetLoweredName(
prog, e_name, ctypes.byref(lowered_name))
hiprtcCheckStatus(status)
return lowered_name.value.decode("utf-8")


_libhiprtc.hiprtcGetBitcodeSize.restype = int
_libhiprtc.hiprtcGetBitcodeSize.argtypes = [
ctypes.c_void_p,
ctypes.POINTER(ctypes.c_size_t),
]
_libhiprtc.hiprtcGetBitcode.restype = int
_libhiprtc.hiprtcGetBitcode.argtypes = [
ctypes.c_void_p,
ctypes.POINTER(ctypes.c_char),
]


def hiprtcGetBitcode(prog):
"""
Gets the compiled bitcode from the program.

Parameters
----------
prog : ctypes pointer
hiprtc program handle

Returns
-------
bitcode : bytes
Compiled bitcode.
"""
bitcode_size = ctypes.c_size_t()
status = _libhiprtc.hiprtcGetBitcodeSize(
prog, ctypes.byref(bitcode_size))
hiprtcCheckStatus(status)

bitcode_buf = ctypes.create_string_buffer(bitcode_size.value)
status = _libhiprtc.hiprtcGetBitcode(prog, bitcode_buf)
hiprtcCheckStatus(status)
return e_code
return bitcode_buf.raw
38 changes: 38 additions & 0 deletions tests/test_device_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from pyhip import hip
import unittest


class TestDeviceManagement(unittest.TestCase):
def test_hipInit(self):
hip.hipInit(0)

def test_hipGetDevice(self):
device = hip.hipGetDevice()
self.assertGreaterEqual(device, 0)

def test_hipSetDevice(self):
device_count = hip.hipGetDeviceCount()
self.assertGreater(device_count, 0)
original_device = hip.hipGetDevice()
for i in range(device_count):
hip.hipSetDevice(i)
current = hip.hipGetDevice()
self.assertEqual(current, i)
hip.hipSetDevice(original_device)

def test_hipSetDevice_invalid(self):
device_count = hip.hipGetDeviceCount()
with self.assertRaises(hip.hipError):
hip.hipSetDevice(device_count + 100)

def test_hipDeviceSynchronize(self):
hip.hipDeviceSynchronize()

def test_hipDeviceSynchronize_after_malloc(self):
ptr = hip.hipMalloc(1024)
hip.hipDeviceSynchronize()
hip.hipFree(ptr)


if __name__ == "__main__":
unittest.main()
108 changes: 108 additions & 0 deletions tests/test_device_properties.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from pyhip import hip
import unittest


class TestDeviceProperties(unittest.TestCase):
def setUp(self):
self.device_count = hip.hipGetDeviceCount()
self.assertGreater(self.device_count, 0)

def test_name_not_empty(self):
props = hip.hipGetDeviceProperties(0)
self.assertIsInstance(props.name, str)
self.assertGreater(len(props.name), 0)

def test_total_global_mem(self):
props = hip.hipGetDeviceProperties(0)
self.assertGreater(props.totalGlobalMem, 0)

def test_shared_mem_per_block(self):
props = hip.hipGetDeviceProperties(0)
self.assertGreater(props.sharedMemPerBlock, 0)

def test_warp_size(self):
props = hip.hipGetDeviceProperties(0)
self.assertIn(props.warpSize, [32, 64])

def test_max_threads_per_block(self):
props = hip.hipGetDeviceProperties(0)
self.assertGreater(props.maxThreadsPerBlock, 0)
self.assertLessEqual(props.maxThreadsPerBlock, 2048)

def test_max_threads_dim(self):
props = hip.hipGetDeviceProperties(0)
for i in range(3):
self.assertGreater(props.maxThreadsDim[i], 0)

def test_max_grid_size(self):
props = hip.hipGetDeviceProperties(0)
for i in range(3):
self.assertGreater(props.maxGridSize[i], 0)

def test_clock_rate(self):
props = hip.hipGetDeviceProperties(0)
self.assertGreater(props.clockRate, 0)

def test_compute_capability(self):
props = hip.hipGetDeviceProperties(0)
self.assertGreater(props.major, 0)
self.assertGreaterEqual(props.minor, 0)

def test_multi_processor_count(self):
props = hip.hipGetDeviceProperties(0)
self.assertGreater(props.multiProcessorCount, 0)

def test_l2_cache_size(self):
props = hip.hipGetDeviceProperties(0)
self.assertGreaterEqual(props.l2CacheSize, 0)

def test_max_threads_per_multi_processor(self):
props = hip.hipGetDeviceProperties(0)
self.assertGreater(props.maxThreadsPerMultiProcessor, 0)

def test_memory_clock_rate(self):
props = hip.hipGetDeviceProperties(0)
self.assertGreater(props.memoryClockRate, 0)

def test_memory_bus_width(self):
props = hip.hipGetDeviceProperties(0)
self.assertGreater(props.memoryBusWidth, 0)

def test_regs_per_block(self):
props = hip.hipGetDeviceProperties(0)
self.assertGreater(props.regsPerBlock, 0)

def test_gcn_arch_name_on_amd(self):
if hip.hipGetPlatformName() != "amd":
self.skipTest("gcnArchName is AMD-specific")
props = hip.hipGetDeviceProperties(0)
self.assertIsInstance(props.gcnArchName, str)
self.assertGreater(len(props.gcnArchName), 0)

def test_concurrent_kernels(self):
props = hip.hipGetDeviceProperties(0)
self.assertIn(props.concurrentKernels, [0, 1])

def test_all_devices(self):
for i in range(self.device_count):
props = hip.hipGetDeviceProperties(i)
self.assertGreater(len(props.name), 0)
self.assertGreater(props.totalGlobalMem, 0)

def test_pci_ids(self):
props = hip.hipGetDeviceProperties(0)
self.assertGreaterEqual(props.pciBusID, 0)
self.assertGreaterEqual(props.pciDeviceID, 0)
self.assertGreaterEqual(props.pciDomainID, 0)

def test_mem_pitch(self):
props = hip.hipGetDeviceProperties(0)
self.assertGreater(props.memPitch, 0)

def test_texture_alignment(self):
props = hip.hipGetDeviceProperties(0)
self.assertGreater(props.textureAlignment, 0)


if __name__ == "__main__":
unittest.main()
53 changes: 53 additions & 0 deletions tests/test_error_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from pyhip import hip
import ctypes
import unittest


class TestErrorHandling(unittest.TestCase):
def test_hipError_with_known_code(self):
try:
raise hip.hipError(1)
except hip.hipError as e:
msg = str(e)
self.assertGreater(len(msg), 0)

def test_hipError_with_string(self):
try:
raise hip.hipError("custom error message")
except hip.hipError as e:
self.assertEqual(str(e), "custom error message")

def test_hipCheckStatus_success(self):
hip.hipCheckStatus(0)

def test_hipCheckStatus_known_error(self):
with self.assertRaises(hip.hipError):
hip.hipCheckStatus(1)

def test_hipCheckStatus_unknown_error(self):
with self.assertRaises(hip.hipError) as ctx:
hip.hipCheckStatus(99999)
self.assertIn("unknown hip error", str(ctx.exception))

def test_hipGetErrorString(self):
s = hip.hipGetErrorString(0)
self.assertIsInstance(s, str)
self.assertGreater(len(s), 0)

def test_hipGetErrorName(self):
name = hip.hipGetErrorName(0)
self.assertIsInstance(name, str)
self.assertGreater(len(name), 0)

def test_hipFree_invalid_raises(self):
bad_ptr = ctypes.c_void_p(0xDEADBEEF)
with self.assertRaises(hip.hipError):
hip.hipFree(bad_ptr)

def test_hipMalloc_zero_bytes(self):
ptr = hip.hipMalloc(0)
hip.hipFree(ptr)


if __name__ == "__main__":
unittest.main()
Loading