diff --git a/src/device_api.cc b/src/device_api.cc index 0c3b3fad2d82..34e3ff0f41fc 100644 --- a/src/device_api.cc +++ b/src/device_api.cc @@ -45,7 +45,7 @@ class VTADeviceAPI final : public DeviceAPI { void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment, - TVMType type_hint) final { + DLDataType type_hint) final { return VTABufferAlloc(size); } @@ -60,7 +60,7 @@ class VTADeviceAPI final : public DeviceAPI { size_t size, TVMContext ctx_from, TVMContext ctx_to, - TVMType type_hint, + DLDataType type_hint, TVMStreamHandle stream) final { int kind_mask = 0; if (ctx_from.device_type != kDLCPU) { @@ -77,7 +77,7 @@ class VTADeviceAPI final : public DeviceAPI { void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { } - void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final; + void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; void FreeWorkspace(TVMContext ctx, void* data) final; @@ -93,7 +93,7 @@ struct VTAWorkspacePool : public WorkspacePool { WorkspacePool(kDLExtDev, VTADeviceAPI::Global()) {} }; -void* VTADeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) { +void* VTADeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) { return dmlc::ThreadLocalStore::Get() ->AllocWorkspace(ctx, size); }