Skip to content

Commit

Permalink
Check that kernel source code has extern C defined (AMDResearch#70)
Browse files Browse the repository at this point in the history
* Verify that kernel code is wrapped in extern C

* Remove curly braces

* Add breakline

* Change raise type

* Fix checks

* Return RunTime error

* Check construct

* Use Syntax error

* Change way we do pytest

* Only check if extern C exists

* Remove unused check
  • Loading branch information
mariodruiz committed Oct 15, 2024
1 parent ca9ade7 commit 3549f52
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
9 changes: 9 additions & 0 deletions npu/build/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .port import BufferPort, RTPPort
from typing import Optional, Callable, List, Dict
import re
import warnings


class Kernel(KernelMeta):
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(self, srccode : str, behavioralfx:Optional[Callable]=None, top_func

self.kb = KernelObjectBuilder(self.ktype, self.srccode, self.srcfile)
self._main_function_sanity_check()
self._extern_c_check()
self._expose_ports()

def _expose_ports(self)->None:
Expand Down Expand Up @@ -96,6 +98,13 @@ def _main_function_sanity_check(self)->None:
if not self._main_function['rtnType'] == "void":
raise RuntimeError(f"The return type of the top_level function should be void not {self._main_function['rtnType']}")

def _extern_c_check(self):
"""Verify that extern C is used"""
tight_code = self.srccode.replace(' ', '').replace(' ', '')
if 'extern"C"' not in tight_code:
raise SyntaxError('extern "C" not found. Top level function '
'should be wrapped by extern "C"')

def display(self)->None:
"""Render the kernel code in a jupyter notebook."""
from IPython.display import display, Code
Expand Down
27 changes: 27 additions & 0 deletions tests/test_externc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2024 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

import pytest
from npu.build.kernel import Kernel
from npu.lib import Plus1


kernel_src = Plus1().srccode

kernel_src1 = kernel_src.replace('\n\n}', '')


def test_externc_good():
krnl_obj = Kernel(kernel_src)
krnl_obj.build()
assert krnl_obj


@pytest.mark.parametrize('replacewith', [''])
def test_externc_bad(replacewith):
src_code = kernel_src1.replace('extern "C" {', replacewith)

with pytest.raises(SyntaxError) as excinfo:
_ = Kernel(src_code)

assert 'extern "C" not found.' in str(excinfo.value)

0 comments on commit 3549f52

Please sign in to comment.