Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check that kernel source code has extern C defined #70

Merged
merged 11 commits into from
Oct 15, 2024
9 changes: 9 additions & 0 deletions npu/build/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .port import BufferPort, RTPPort
from typing import Optional, Callable, List, Dict
import re
import warnings


class Kernel(KernelMeta):
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(self, srccode : str, behavioralfx:Optional[Callable]=None, top_func

self.kb = KernelObjectBuilder(self.ktype, self.srccode, self.srcfile)
self._main_function_sanity_check()
self._extern_c_check()
self._expose_ports()

def _expose_ports(self)->None:
Expand Down Expand Up @@ -96,6 +98,13 @@ def _main_function_sanity_check(self)->None:
if not self._main_function['rtnType'] == "void":
raise RuntimeError(f"The return type of the top_level function should be void not {self._main_function['rtnType']}")

def _extern_c_check(self):
"""Verify that extern C is used"""
tight_code = self.srccode.replace(' ', '').replace(' ', '')
if 'extern"C"' not in tight_code:
raise SyntaxError('extern "C" not found. Top level function '
'should be wrapped by extern "C"')

def display(self)->None:
"""Render the kernel code in a jupyter notebook."""
from IPython.display import display, Code
Expand Down
27 changes: 27 additions & 0 deletions tests/test_externc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2024 Advanced Micro Devices, Inc.
# SPDX-License-Identifier: MIT

import pytest
from npu.build.kernel import Kernel
from npu.lib import Plus1


kernel_src = Plus1().srccode

kernel_src1 = kernel_src.replace('\n\n}', '')


def test_externc_good():
krnl_obj = Kernel(kernel_src)
krnl_obj.build()
assert krnl_obj


@pytest.mark.parametrize('replacewith', [''])
def test_externc_bad(replacewith):
src_code = kernel_src1.replace('extern "C" {', replacewith)

with pytest.raises(SyntaxError) as excinfo:
_ = Kernel(src_code)

assert 'extern "C" not found.' in str(excinfo.value)
Loading