Skip to content

Commit

Permalink
Keep track of current database in SessionCatalog
Browse files Browse the repository at this point in the history
This allows us to not pass it into every single method like
we used to before this commit.
  • Loading branch information
Andrew Or committed Mar 16, 2016
1 parent ff1c2c4 commit 6d9fa2f
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 172 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,14 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
import ExternalCatalog._

private[this] val tempTables = new ConcurrentHashMap[String, LogicalPlan]

private[this] val tempFunctions = new ConcurrentHashMap[String, CatalogFunction]

// Note: we track current database here because certain operations do not explicitly
// specify the database (e.g. DROP TABLE my_table). In these cases we must first
// check whether the temporary table or function exists, then, if not, operate on
// the corresponding item in the current database.
private[this] var currentDb = "default"

// ----------------------------------------------------------------------------
// Databases
// ----------------------------------------------------------------------------
Expand Down Expand Up @@ -72,6 +77,12 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
externalCatalog.listDatabases(pattern)
}

def getCurrentDatabase: String = currentDb

def setCurrentDatabase(db: String): Unit = {
currentDb = db
}

// ----------------------------------------------------------------------------
// Tables
// ----------------------------------------------------------------------------
Expand All @@ -89,10 +100,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* Create a metastore table in the database specified in `tableDefinition`.
* If no such database is specified, create it in the current database.
*/
def createTable(
currentDb: String,
tableDefinition: CatalogTable,
ignoreIfExists: Boolean): Unit = {
def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
val db = tableDefinition.name.database.getOrElse(currentDb)
val newTableDefinition = tableDefinition.copy(
name = TableIdentifier(tableDefinition.name.table, Some(db)))
Expand All @@ -108,7 +116,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* Note: If the underlying implementation does not support altering a certain field,
* this becomes a no-op.
*/
def alterTable(currentDb: String, tableDefinition: CatalogTable): Unit = {
def alterTable(tableDefinition: CatalogTable): Unit = {
val db = tableDefinition.name.database.getOrElse(currentDb)
val newTableDefinition = tableDefinition.copy(
name = TableIdentifier(tableDefinition.name.table, Some(db)))
Expand All @@ -119,7 +127,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* Retrieve the metadata of an existing metastore table.
* If no database is specified, assume the table is in the current database.
*/
def getTable(currentDb: String, name: TableIdentifier): CatalogTable = {
def getTable(name: TableIdentifier): CatalogTable = {
val db = name.database.getOrElse(currentDb)
externalCatalog.getTable(db, name.table)
}
Expand Down Expand Up @@ -150,10 +158,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
*
* This assumes the database specified in `oldName` matches the one specified in `newName`.
*/
def renameTable(
currentDb: String,
oldName: TableIdentifier,
newName: TableIdentifier): Unit = {
def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = {
if (oldName.database != newName.database) {
throw new AnalysisException("rename does not support moving tables across databases")
}
Expand All @@ -173,10 +178,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* If no database is specified, this will first attempt to drop a temporary table with
* the same name, then, if that does not exist, drop the table from the current database.
*/
def dropTable(
currentDb: String,
name: TableIdentifier,
ignoreIfNotExists: Boolean): Unit = {
def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = {
val db = name.database.getOrElse(currentDb)
if (name.database.isDefined || !tempTables.containsKey(name.table)) {
externalCatalog.dropTable(db, name.table, ignoreIfNotExists)
Expand All @@ -192,10 +194,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* If no database is specified, this will first attempt to return a temporary table with
* the same name, then, if that does not exist, return the table from the current database.
*/
def lookupRelation(
currentDb: String,
name: TableIdentifier,
alias: Option[String] = None): LogicalPlan = {
def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = {
val db = name.database.getOrElse(currentDb)
val relation =
if (name.database.isDefined || !tempTables.containsKey(name.table)) {
Expand All @@ -211,28 +210,25 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
}

/**
* List all tables in the current database, including temporary tables.
* List all tables in the specified database, including temporary tables.
*/
def listTables(currentDb: String): Seq[TableIdentifier] = {
val tablesInCurrentDb = externalCatalog.listTables(currentDb).map { t =>
TableIdentifier(t, Some(currentDb))
}
def listTables(db: String): Seq[TableIdentifier] = {
val dbTables = externalCatalog.listTables(db).map { t => TableIdentifier(t, Some(db)) }
val _tempTables = tempTables.keys().asScala.map { t => TableIdentifier(t) }
tablesInCurrentDb ++ _tempTables
dbTables ++ _tempTables
}

/**
* List all matching tables in the current database, including temporary tables.
* List all matching tables in the specified database, including temporary tables.
*/
def listTables(currentDb: String, pattern: String): Seq[TableIdentifier] = {
val tablesInCurrentDb = externalCatalog.listTables(currentDb, pattern).map { t =>
TableIdentifier(t, Some(currentDb))
}
def listTables(db: String, pattern: String): Seq[TableIdentifier] = {
val dbTables =
externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(currentDb)) }
val regex = pattern.replaceAll("\\*", ".*").r
val _tempTables = tempTables.keys().asScala
.filter { t => regex.pattern.matcher(t).matches() }
.map { t => TableIdentifier(t) }
tablesInCurrentDb ++ _tempTables
dbTables ++ _tempTables
}

/**
Expand Down Expand Up @@ -260,7 +256,6 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* If no database is specified, assume the table is in the current database.
*/
def createPartitions(
currentDb: String,
tableName: TableIdentifier,
parts: Seq[CatalogTablePartition],
ignoreIfExists: Boolean): Unit = {
Expand All @@ -273,7 +268,6 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* If no database is specified, assume the table is in the current database.
*/
def dropPartitions(
currentDb: String,
tableName: TableIdentifier,
parts: Seq[TablePartitionSpec],
ignoreIfNotExists: Boolean): Unit = {
Expand All @@ -288,7 +282,6 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* If no database is specified, assume the table is in the current database.
*/
def renamePartitions(
currentDb: String,
tableName: TableIdentifier,
specs: Seq[TablePartitionSpec],
newSpecs: Seq[TablePartitionSpec]): Unit = {
Expand All @@ -305,10 +298,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* Note: If the underlying implementation does not support altering a certain field,
* this becomes a no-op.
*/
def alterPartitions(
currentDb: String,
tableName: TableIdentifier,
parts: Seq[CatalogTablePartition]): Unit = {
def alterPartitions(tableName: TableIdentifier, parts: Seq[CatalogTablePartition]): Unit = {
val db = tableName.database.getOrElse(currentDb)
externalCatalog.alterPartitions(db, tableName.table, parts)
}
Expand All @@ -317,10 +307,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* Retrieve the metadata of a table partition, assuming it exists.
* If no database is specified, assume the table is in the current database.
*/
def getPartition(
currentDb: String,
tableName: TableIdentifier,
spec: TablePartitionSpec): CatalogTablePartition = {
def getPartition(tableName: TableIdentifier, spec: TablePartitionSpec): CatalogTablePartition = {
val db = tableName.database.getOrElse(currentDb)
externalCatalog.getPartition(db, tableName.table, spec)
}
Expand All @@ -329,9 +316,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* List all partitions in a table, assuming it exists.
* If no database is specified, assume the table is in the current database.
*/
def listPartitions(
currentDb: String,
tableName: TableIdentifier): Seq[CatalogTablePartition] = {
def listPartitions(tableName: TableIdentifier): Seq[CatalogTablePartition] = {
val db = tableName.database.getOrElse(currentDb)
externalCatalog.listPartitions(db, tableName.table)
}
Expand All @@ -353,7 +338,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* Create a metastore function in the database specified in `funcDefinition`.
* If no such database is specified, create it in the current database.
*/
def createFunction(currentDb: String, funcDefinition: CatalogFunction): Unit = {
def createFunction(funcDefinition: CatalogFunction): Unit = {
val db = funcDefinition.name.database.getOrElse(currentDb)
val newFuncDefinition = funcDefinition.copy(
name = FunctionIdentifier(funcDefinition.name.funcName, Some(db)))
Expand All @@ -364,7 +349,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* Drop a metastore function.
* If no database is specified, assume the function is in the current database.
*/
def dropFunction(currentDb: String, name: FunctionIdentifier): Unit = {
def dropFunction(name: FunctionIdentifier): Unit = {
val db = name.database.getOrElse(currentDb)
externalCatalog.dropFunction(db, name.funcName)
}
Expand All @@ -378,7 +363,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* Note: If the underlying implementation does not support altering a certain field,
* this becomes a no-op.
*/
def alterFunction(currentDb: String, funcDefinition: CatalogFunction): Unit = {
def alterFunction(funcDefinition: CatalogFunction): Unit = {
val db = funcDefinition.name.database.getOrElse(currentDb)
val newFuncDefinition = funcDefinition.copy(
name = FunctionIdentifier(funcDefinition.name.funcName, Some(db)))
Expand All @@ -393,9 +378,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* Create a temporary function.
* This assumes no database is specified in `funcDefinition`.
*/
def createTempFunction(
funcDefinition: CatalogFunction,
ignoreIfExists: Boolean): Unit = {
def createTempFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
require(funcDefinition.name.database.isEmpty,
"attempted to create a temporary function while specifying a database")
val name = funcDefinition.name.funcName
Expand Down Expand Up @@ -428,10 +411,10 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
*
* This assumes the database specified in `oldName` matches the one specified in `newName`.
*/
def renameFunction(
currentDb: String,
oldName: FunctionIdentifier,
newName: FunctionIdentifier): Unit = {
def renameFunction(oldName: FunctionIdentifier, newName: FunctionIdentifier): Unit = {
if (oldName.database != newName.database) {
throw new AnalysisException("rename does not support moving functions across databases")
}
val db = oldName.database.getOrElse(currentDb)
if (oldName.database.isDefined || !tempFunctions.containsKey(oldName.funcName)) {
externalCatalog.renameFunction(db, oldName.funcName, newName.funcName)
Expand All @@ -449,7 +432,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
* If no database is specified, this will first attempt to return a temporary function with
* the same name, then, if that does not exist, return the function in the current database.
*/
def getFunction(currentDb: String, name: FunctionIdentifier): CatalogFunction = {
def getFunction(name: FunctionIdentifier): CatalogFunction = {
val db = name.database.getOrElse(currentDb)
if (name.database.isDefined || !tempFunctions.containsKey(name.funcName)) {
externalCatalog.getFunction(db, name.funcName)
Expand All @@ -461,17 +444,16 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
// TODO: implement lookupFunction that returns something from the registry itself

/**
* List all matching functions in the current database, including temporary functions.
* List all matching functions in the specified database, including temporary functions.
*/
def listFunctions(currentDb: String, pattern: String): Seq[FunctionIdentifier] = {
val functionsInCurrentDb = externalCatalog.listFunctions(currentDb, pattern).map { f =>
FunctionIdentifier(f, Some(currentDb))
}
def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = {
val dbFunctions =
externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
val regex = pattern.replaceAll("\\*", ".*").r
val _tempFunctions = tempFunctions.keys().asScala
.filter { f => regex.pattern.matcher(f).matches() }
.map { f => FunctionIdentifier(f) }
functionsInCurrentDb ++ _tempFunctions
dbFunctions ++ _tempFunctions
}

/**
Expand Down
Loading

0 comments on commit 6d9fa2f

Please sign in to comment.