Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RapidsBufferCatalog tests #193

Merged
merged 3 commits into from
Jun 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,9 @@ import org.apache.spark.internal.Logging

/** Catalog for lookup of buffers by ID */
class RapidsBufferCatalog extends Logging {
/** Tracks all buffer stores using this catalog */
private[this] val stores = new ArrayBuffer[RapidsBufferStore]

/** Map of buffer IDs to buffers */
private[this] val bufferMap = new ConcurrentHashMap[RapidsBufferId, RapidsBuffer]

/**
* Register a buffer store that is using this catalog.
* NOTE: It is assumed all stores are registered before any buffers are added to the catalog.
* @param store buffer store
*/
def registerStore(store: RapidsBufferStore): Unit = {
require(store.currentSize == 0, "Store must not have any buffers when registered")
stores.append(store)
}

/**
* Lookup the buffer that corresponds to the specified buffer ID and acquire it.
* NOTE: It is the responsibility of the caller to close the buffer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@ abstract class RapidsBufferStore(

private[this] val nvtxSyncSpillName: String = name + " sync spill"

catalog.registerStore(this)

/** Return the current byte total of buffers in this store. */
def currentSize: Long = buffers.getTotalBytes

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import java.io.File
import java.util.NoSuchElementException

import com.nvidia.spark.rapids.StorageTier.StorageTier
import com.nvidia.spark.rapids.format.TableMeta
import org.mockito.Mockito._
import org.scalatest.FunSuite
import org.scalatest.mockito.MockitoSugar

import org.apache.spark.sql.rapids.RapidsDiskBlockManager

class RapidsBufferCatalogSuite extends FunSuite with MockitoSugar {
test("lookup unknown buffer") {
val catalog = new RapidsBufferCatalog
val bufferId = new RapidsBufferId {
override val tableId: Int = 10
override def getDiskPath(m: RapidsDiskBlockManager): File = null
}
assertThrows[NoSuchElementException](catalog.acquireBuffer(bufferId))
assertThrows[NoSuchElementException](catalog.getBufferMeta(bufferId))
}

test("acquire buffer") {
val catalog = new RapidsBufferCatalog
val bufferId = MockBufferId(5)
abellina marked this conversation as resolved.
Show resolved Hide resolved
val buffer = mockBuffer(bufferId)
catalog.registerNewBuffer(buffer)
val acquired = catalog.acquireBuffer(MockBufferId(5))
assertResult(5)(acquired.id.tableId)
assertResult(buffer)(acquired)
verify(buffer).addReference()
}

test("acquire buffer retries automatically") {
val catalog = new RapidsBufferCatalog
val bufferId = MockBufferId(5)
val buffer = mockBuffer(bufferId, acquireAttempts = 9)
catalog.registerNewBuffer(buffer)
val acquired = catalog.acquireBuffer(MockBufferId(5))
assertResult(5)(acquired.id.tableId)
assertResult(buffer)(acquired)
verify(buffer, times(9)).addReference()
}

test("get buffer meta") {
val catalog = new RapidsBufferCatalog
val bufferId = MockBufferId(5)
val expectedMeta = new TableMeta
val buffer = mockBuffer(bufferId, meta = expectedMeta)
catalog.registerNewBuffer(buffer)
val meta = catalog.getBufferMeta(bufferId)
assertResult(expectedMeta)(meta)
}

test("update buffer map only updates for faster tier") {
val catalog = new RapidsBufferCatalog
val bufferId = MockBufferId(5)
val buffer1 = mockBuffer(bufferId, tier = StorageTier.HOST)
catalog.registerNewBuffer(buffer1)
val buffer2 = mockBuffer(bufferId, tier = StorageTier.DEVICE)
catalog.updateBufferMap(StorageTier.HOST, buffer2)
var resultBuffer = catalog.acquireBuffer(bufferId)
assertResult(buffer2)(resultBuffer)
catalog.updateBufferMap(StorageTier.HOST, buffer1)
resultBuffer = catalog.acquireBuffer(bufferId)
assertResult(buffer2)(resultBuffer)
}

test("remove buffer releases buffer resources") {
val catalog = new RapidsBufferCatalog
val bufferId = MockBufferId(5)
val buffer = mockBuffer(bufferId)
catalog.registerNewBuffer(buffer)
catalog.removeBuffer(bufferId)
verify(buffer).free()
}

private def mockBuffer(
bufferId: RapidsBufferId,
meta: TableMeta = null,
tier: StorageTier = StorageTier.DEVICE,
acquireAttempts: Int = 1): RapidsBuffer = {
val buffer = mock[RapidsBuffer]
when(buffer.id).thenReturn(bufferId)
when(buffer.storageTier).thenReturn(tier)
when(buffer.meta).thenReturn(meta)
var stub = when(buffer.addReference())
(0 until acquireAttempts - 1).foreach(_ => stub = stub.thenReturn(false))
stub.thenReturn(true)
buffer
}
}

case class MockBufferId(override val tableId: Int) extends RapidsBufferId {
override def getDiskPath(dbm: RapidsDiskBlockManager): File =
throw new UnsupportedOperationException
}