diff --git a/src/main/scala/shark/SharkEnv.scala b/src/main/scala/shark/SharkEnv.scala index 3e5f9d00..af5b90ee 100755 --- a/src/main/scala/shark/SharkEnv.scala +++ b/src/main/scala/shark/SharkEnv.scala @@ -114,7 +114,8 @@ object SharkEnv extends LogHelper { val addedFiles = HashSet[String]() val addedJars = HashSet[String]() - def unpersist(key: String): Option[RDD[_]] = { + def unpersist(databaseName: String, tableName: String): Option[RDD[_]] = { + val key = databaseName + '.' + tableName if (SharkEnv.tachyonUtil.tachyonEnabled() && SharkEnv.tachyonUtil.tableExists(key)) { if (SharkEnv.tachyonUtil.dropTable(key)) { logInfo("Table " + key + " was deleted from Tachyon."); @@ -123,7 +124,7 @@ object SharkEnv extends LogHelper { } } - memoryMetadataManager.unpersist(key) + memoryMetadataManager.unpersist(databaseName, tableName) } /** Cleans up and shuts down the Shark environments. */ diff --git a/src/main/scala/shark/execution/MemoryStoreSinkOperator.scala b/src/main/scala/shark/execution/MemoryStoreSinkOperator.scala index 10e99551..a46cbe45 100644 --- a/src/main/scala/shark/execution/MemoryStoreSinkOperator.scala +++ b/src/main/scala/shark/execution/MemoryStoreSinkOperator.scala @@ -42,6 +42,7 @@ class MemoryStoreSinkOperator extends TerminalOperator { @BeanProperty var shouldCompress: Boolean = _ @BeanProperty var storageLevel: StorageLevel = _ @BeanProperty var tableName: String = _ + @BeanProperty var databaseName: String = _ @transient var useTachyon: Boolean = _ @transient var useUnionRDD: Boolean = _ @transient var numColumns: Int = _ @@ -100,7 +101,7 @@ class MemoryStoreSinkOperator extends TerminalOperator { // Put the table in Tachyon. 
op.logInfo("Putting RDD for %s in Tachyon".format(tableName)) - SharkEnv.memoryMetadataManager.put(tableName, rdd) + SharkEnv.memoryMetadataManager.put(databaseName, tableName, rdd) tachyonWriter.createTable(ByteBuffer.allocate(0)) rdd = rdd.mapPartitionsWithIndex { case(partitionIndex, iter) => @@ -114,8 +115,9 @@ class MemoryStoreSinkOperator extends TerminalOperator { rdd.context.runJob(rdd, (iter: Iterator[TablePartition]) => iter.foreach(_ => Unit)) } else { // Put the table in Spark block manager. - op.logInfo("Putting %sRDD for %s in Spark block manager, %s %s %s %s".format( + op.logInfo("Putting %sRDD for %s.%s in Spark block manager, %s %s %s %s".format( if (useUnionRDD) "Union" else "", + databaseName, tableName, if (storageLevel.deserialized) "deserialized" else "serialized", if (storageLevel.useMemory) "in memory" else "", @@ -129,11 +131,22 @@ class MemoryStoreSinkOperator extends TerminalOperator { if (useUnionRDD) { // If this is an insert, find the existing RDD and create a union of the two, and then // put the union into the meta data tracker. + + + val nextPartNum = SharkEnv.memoryMetadataManager.getNextPartNum(databaseName, tableName) + if (nextPartNum == 1) { + // reset rdd name for existing rdd + SharkEnv.memoryMetadataManager.get(databaseName, tableName).get.asInstanceOf[RDD[TablePartition]] + .setName(databaseName + '.' + tableName + ".part0") + } + rdd.setName(databaseName + '.' + tableName + ".part" + nextPartNum) + rdd = rdd.union( - SharkEnv.memoryMetadataManager.get(tableName).get.asInstanceOf[RDD[TablePartition]]) + SharkEnv.memoryMetadataManager.get(databaseName, tableName).get.asInstanceOf[RDD[TablePartition]]) + } else { + rdd.setName(databaseName + '.' + tableName) } - SharkEnv.memoryMetadataManager.put(tableName, rdd) - rdd.setName(tableName) + SharkEnv.memoryMetadataManager.put(databaseName, tableName, rdd) // Run a job on the original RDD to force it to go into cache. 
origRdd.context.runJob(origRdd, (iter: Iterator[TablePartition]) => iter.foreach(_ => Unit)) @@ -158,7 +171,7 @@ class MemoryStoreSinkOperator extends TerminalOperator { // Combine stats for the two tables being combined. val numPartitions = statsAcc.value.toMap.size val currentStats = statsAcc.value - val otherIndexToStats = SharkEnv.memoryMetadataManager.getStats(tableName).get + val otherIndexToStats = SharkEnv.memoryMetadataManager.getStats(databaseName, tableName).get for ((otherIndex, tableStats) <- otherIndexToStats) { currentStats.append((otherIndex + numPartitions, tableStats)) } @@ -168,7 +181,7 @@ class MemoryStoreSinkOperator extends TerminalOperator { } // Get the column statistics back to the cache manager. - SharkEnv.memoryMetadataManager.putStats(tableName, columnStats) + SharkEnv.memoryMetadataManager.putStats(databaseName, tableName, columnStats) if (tachyonWriter != null) { tachyonWriter.updateMetadata(ByteBuffer.wrap(JavaSerializer.serialize(columnStats))) diff --git a/src/main/scala/shark/execution/OperatorFactory.scala b/src/main/scala/shark/execution/OperatorFactory.scala index 97a6851a..38395bbb 100755 --- a/src/main/scala/shark/execution/OperatorFactory.scala +++ b/src/main/scala/shark/execution/OperatorFactory.scala @@ -45,6 +45,7 @@ object OperatorFactory extends LogHelper { def createSharkMemoryStoreOutputPlan( hiveTerminalOp: HiveOperator, tableName: String, + databaseName: String, storageLevel: StorageLevel, numColumns: Int, useTachyon: Boolean, @@ -52,6 +53,7 @@ object OperatorFactory extends LogHelper { val sinkOp = _newOperatorInstance( classOf[MemoryStoreSinkOperator], hiveTerminalOp).asInstanceOf[MemoryStoreSinkOperator] sinkOp.tableName = tableName + sinkOp.databaseName = databaseName sinkOp.storageLevel = storageLevel sinkOp.numColumns = numColumns sinkOp.useTachyon = useTachyon diff --git a/src/main/scala/shark/execution/SharkDDLTask.scala b/src/main/scala/shark/execution/SharkDDLTask.scala index a7392e43..64aa19f8 100644 --- 
a/src/main/scala/shark/execution/SharkDDLTask.scala +++ b/src/main/scala/shark/execution/SharkDDLTask.scala @@ -60,7 +60,7 @@ private[shark] class SharkDDLTask extends HiveTask[SharkDDLWork] with Serializab if (alterTableDesc.getOp() == AlterTableDesc.AlterTableTypes.RENAME) { val oldName = alterTableDesc.getOldName val newName = alterTableDesc.getNewName - SharkEnv.memoryMetadataManager.rename(oldName, newName) + SharkEnv.memoryMetadataManager.rename(hiveMetadataDb.getCurrentDatabase(), oldName, newName) } } diff --git a/src/main/scala/shark/execution/TableScanOperator.scala b/src/main/scala/shark/execution/TableScanOperator.scala index 27247503..f4d47725 100755 --- a/src/main/scala/shark/execution/TableScanOperator.scala +++ b/src/main/scala/shark/execution/TableScanOperator.scala @@ -111,6 +111,7 @@ class TableScanOperator extends TopOperator[HiveTableScanOperator] with HiveTopO override def execute(): RDD[_] = { assert(parentOperators.size == 0) val tableKey: String = tableDesc.getTableName.split('.')(1) + val databaseName: String = tableDesc.getTableName.split('.')(0) // There are three places we can load the table from. // 1. Tachyon table @@ -120,14 +121,14 @@ class TableScanOperator extends TopOperator[HiveTableScanOperator] with HiveTopO tableDesc.getProperties().get("shark.cache").asInstanceOf[String]) if (cacheMode == CacheType.HEAP) { // Table should be in Spark heap (block manager). - val rdd = SharkEnv.memoryMetadataManager.get(tableKey).getOrElse { + val rdd = SharkEnv.memoryMetadataManager.get(databaseName, tableKey).getOrElse { logError("""|Table %s not found in block manager. 
|Are you trying to access a cached table from a Shark session other than |the one in which it was created?""".stripMargin.format(tableKey)) throw(new QueryExecutionException("Cached table not found")) } logInfo("Loading table " + tableKey + " from Spark block manager") - createPrunedRdd(tableKey, rdd) + createPrunedRdd(databaseName, tableKey, rdd) } else if (cacheMode == CacheType.TACHYON) { // Table is in Tachyon. if (!SharkEnv.tachyonUtil.tableExists(tableKey)) { @@ -136,26 +137,26 @@ class TableScanOperator extends TopOperator[HiveTableScanOperator] with HiveTopO logInfo("Loading table " + tableKey + " from Tachyon.") var indexToStats: collection.Map[Int, TablePartitionStats] = - SharkEnv.memoryMetadataManager.getStats(tableKey).getOrElse(null) + SharkEnv.memoryMetadataManager.getStats(databaseName, tableKey).getOrElse(null) if (indexToStats == null) { val statsByteBuffer = SharkEnv.tachyonUtil.getTableMetadata(tableKey) indexToStats = JavaSerializer.deserialize[collection.Map[Int, TablePartitionStats]]( statsByteBuffer.array()) logInfo("Loading table " + tableKey + " stats from Tachyon.") - SharkEnv.memoryMetadataManager.putStats(tableKey, indexToStats) + SharkEnv.memoryMetadataManager.putStats(databaseName, tableKey, indexToStats) } - createPrunedRdd(tableKey, SharkEnv.tachyonUtil.createRDD(tableKey)) + createPrunedRdd(databaseName, tableKey, SharkEnv.tachyonUtil.createRDD(tableKey)) } else { // Table is a Hive table on HDFS (or other Hadoop storage). super.execute() } } - private def createPrunedRdd(tableKey: String, rdd: RDD[_]): RDD[_] = { + private def createPrunedRdd(databaseName: String, tableKey: String, rdd: RDD[_]): RDD[_] = { // Stats used for map pruning. 
val indexToStats: collection.Map[Int, TablePartitionStats] = - SharkEnv.memoryMetadataManager.getStats(tableKey).get + SharkEnv.memoryMetadataManager.getStats(databaseName, tableKey).get // Run map pruning if the flag is set, there exists a filter predicate on // the input table and we have statistics on the table. diff --git a/src/main/scala/shark/memstore2/MemoryMetadataManager.scala b/src/main/scala/shark/memstore2/MemoryMetadataManager.scala index ed6efa49..32cddf77 100755 --- a/src/main/scala/shark/memstore2/MemoryMetadataManager.scala +++ b/src/main/scala/shark/memstore2/MemoryMetadataManager.scala @@ -33,27 +33,58 @@ class MemoryMetadataManager { private val _keyToRdd: ConcurrentMap[String, RDD[_]] = new ConcurrentHashMap[String, RDD[_]]() + // Tracks number of parts inserted into cached table + private val _keyToNextPart: ConcurrentMap[String, Int] = + new ConcurrentHashMap[String, Int]() + private val _keyToStats: ConcurrentMap[String, collection.Map[Int, TablePartitionStats]] = new ConcurrentHashMap[String, collection.Map[Int, TablePartitionStats]] - def contains(key: String) = _keyToRdd.contains(key.toLowerCase) + def contains(databaseName: String, tableName: String) = { + val key = databaseName + '.' + tableName + _keyToRdd.contains(key.toLowerCase) + } - def put(key: String, rdd: RDD[_]) { + def put(databaseName: String, tableName: String, rdd: RDD[_]) { + val key = databaseName + '.' + tableName _keyToRdd(key.toLowerCase) = rdd } - def get(key: String): Option[RDD[_]] = _keyToRdd.get(key.toLowerCase) + def get(databaseName: String, tableName: String): Option[RDD[_]] = { + val key = databaseName + '.' + tableName + _keyToRdd.get(key.toLowerCase) + } - def putStats(key: String, stats: collection.Map[Int, TablePartitionStats]) { + def putStats(databaseName: String, tableName: String, stats: collection.Map[Int, TablePartitionStats]) { + val key = databaseName + '.' 
+ tableName _keyToStats.put(key.toLowerCase, stats) } - def getStats(key: String): Option[collection.Map[Int, TablePartitionStats]] = { + def getStats(databaseName: String, tableName: String): Option[collection.Map[Int, TablePartitionStats]] = { + val key = databaseName + '.' + tableName _keyToStats.get(key.toLowerCase) } - def rename(oldKey: String, newKey: String) { - if (contains(oldKey)) { + def getNextPartNum(databaseName: String, tableName: String): Int = { + val key = databaseName + '.' + tableName + val currentPartNum = _keyToNextPart.get(key.toLowerCase) + currentPartNum match { + case Some(partNum) => { + _keyToNextPart.put(key.toLowerCase, partNum + 1) + partNum + 1 + } + case None => { + _keyToNextPart.put(key.toLowerCase, 1) + 1 + } + } + } + + def rename(databaseName: String, oldTableName: String, newTableName: String) { + val oldKey = databaseName + '.' + oldTableName + val newKey = databaseName + '.' + newTableName + + if (contains(databaseName, oldTableName)) { val oldKeyToLowerCase = oldKey.toLowerCase val newKeyToLowerCase = newKey.toLowerCase @@ -80,7 +111,8 @@ class MemoryMetadataManager { * @return Option::isEmpty() is true if there is no RDD value corresponding to 'key' in * '_keyToRDD'. Otherwise, returns a reference to the RDD that was unpersist()'ed. */ - def unpersist(key: String): Option[RDD[_]] = { + def unpersist(databaseName: String, tableName: String): Option[RDD[_]] = { + val key = databaseName + '.' + tableName def unpersistRDD(rdd: RDD[_]): Unit = { rdd match { case u: UnionRDD[_] => { @@ -97,6 +129,7 @@ class MemoryMetadataManager { // corresponding to the argument for 'key'. val rddValue = _keyToRdd.remove(key.toLowerCase()) _keyToStats.remove(key) + _keyToNextPart.remove(key.toLowerCase) // Unpersist the RDD using the nested helper fn above. 
rddValue match { case Some(rdd) => unpersistRDD(rdd) diff --git a/src/main/scala/shark/parse/SharkDDLSemanticAnalyzer.scala b/src/main/scala/shark/parse/SharkDDLSemanticAnalyzer.scala index 60c778c7..90021e5b 100644 --- a/src/main/scala/shark/parse/SharkDDLSemanticAnalyzer.scala +++ b/src/main/scala/shark/parse/SharkDDLSemanticAnalyzer.scala @@ -21,7 +21,7 @@ class SharkDDLSemanticAnalyzer(conf: HiveConf) extends DDLSemanticAnalyzer(conf) astNode.getToken.getType match { case HiveParser.TOK_DROPTABLE => { - SharkEnv.unpersist(getTableName(astNode)) + SharkEnv.unpersist(db.getCurrentDatabase(), getTableName(astNode)) } case HiveParser.TOK_ALTERTABLE_RENAME => { analyzeAlterTableRename(astNode) @@ -32,7 +32,7 @@ class SharkDDLSemanticAnalyzer(conf: HiveConf) extends DDLSemanticAnalyzer(conf) private def analyzeAlterTableRename(astNode: ASTNode) { val oldTableName = getTableName(astNode) - if (SharkEnv.memoryMetadataManager.contains(oldTableName)) { + if (SharkEnv.memoryMetadataManager.contains(db.getCurrentDatabase(), oldTableName)) { val newTableName = BaseSemanticAnalyzer.getUnescapedName( astNode.getChild(1).asInstanceOf[ASTNode]) diff --git a/src/main/scala/shark/parse/SharkSemanticAnalyzer.scala b/src/main/scala/shark/parse/SharkSemanticAnalyzer.scala index aa6ea812..5c0ff8a5 100755 --- a/src/main/scala/shark/parse/SharkSemanticAnalyzer.scala +++ b/src/main/scala/shark/parse/SharkSemanticAnalyzer.scala @@ -189,8 +189,9 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with OperatorFactory.createSharkFileOutputPlan(hiveSinkOp) } else { // Otherwise, check if we are inserting into a table that was cached. 
- val cachedTableName = tableName.split('.')(1) // Ignore the database name - SharkEnv.memoryMetadataManager.get(cachedTableName) match { + val cachedTableName = tableName.split('.')(1) // Ignore the database name + val databaseName = tableName.split('.')(0) + SharkEnv.memoryMetadataManager.get(databaseName, cachedTableName) match { case Some(rdd) => { if (hiveSinkOps.size == 1) { // If useUnionRDD is false, the sink op is for INSERT OVERWRITE. @@ -199,6 +200,7 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with OperatorFactory.createSharkMemoryStoreOutputPlan( hiveSinkOp, cachedTableName, + databaseName, storageLevel, _resSchema.size, // numColumns cacheMode == CacheType.TACHYON, // use tachyon @@ -223,6 +225,7 @@ class SharkSemanticAnalyzer(conf: HiveConf) extends SemanticAnalyzer(conf) with OperatorFactory.createSharkMemoryStoreOutputPlan( hiveSinkOps.head, qb.getTableDesc.getTableName, + qb.getTableDesc.getDatabaseName, storageLevel, _resSchema.size, // numColumns cacheMode == CacheType.TACHYON, // use tachyon diff --git a/src/test/scala/shark/SQLSuite.scala b/src/test/scala/shark/SQLSuite.scala index 15a3fe9f..0ae9aed8 100644 --- a/src/test/scala/shark/SQLSuite.scala +++ b/src/test/scala/shark/SQLSuite.scala @@ -40,6 +40,9 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { sc.runSql("set shark.test.data.path=" + TestUtils.dataFilePath) + // second db + sc.sql("create database if not exists seconddb") + // test sc.runSql("drop table if exists test") sc.runSql("CREATE TABLE test (key INT, val STRING)") @@ -227,8 +230,8 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { sc.runSql("drop table if exists test_rename") sc.runSql("create table test_oldname_cached as select * from test") sc.runSql("alter table test_oldname_cached rename to test_rename") - assert(!SharkEnv.memoryMetadataManager.contains("test_oldname_cached")) - assert(SharkEnv.memoryMetadataManager.contains("test_rename")) + 
assert(!SharkEnv.memoryMetadataManager.contains("default", "test_oldname_cached")) + assert(SharkEnv.memoryMetadataManager.contains("default", "test_rename")) expectSql("select count(*) from test_rename", "500") } @@ -272,7 +275,7 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { sc.runSql("drop table if exists ctas_tbl_props") sc.runSql("""create table ctas_tbl_props TBLPROPERTIES ('shark.cache'='true') as select * from test""") - assert(SharkEnv.memoryMetadataManager.contains("ctas_tbl_props")) + assert(SharkEnv.memoryMetadataManager.contains("default", "ctas_tbl_props")) expectSql("select * from ctas_tbl_props where key=407", "407\tval_407") } @@ -282,7 +285,7 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { CREATE TABLE ctas_tbl_props_result_should_not_be_cached TBLPROPERTIES ('shark.cache'='false') AS select * from test""") - assert(!SharkEnv.memoryMetadataManager.contains("ctas_tbl_props_should_not_be_cached")) + assert(!SharkEnv.memoryMetadataManager.contains("default", "ctas_tbl_props_should_not_be_cached")) } test("cached tables with complex types") { @@ -306,7 +309,7 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { assert(sc.sql("select d from test_complex_types_cached where a = 'a0'").head === """{"d01":["d011","d012"],"d02":["d021","d022"]}""") - assert(SharkEnv.memoryMetadataManager.contains("test_complex_types_cached")) + assert(SharkEnv.memoryMetadataManager.contains("default", "test_complex_types_cached")) } test("disable caching by default") { @@ -314,7 +317,7 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { sc.runSql("drop table if exists should_not_be_cached") sc.runSql("create table should_not_be_cached as select * from test") expectSql("select key from should_not_be_cached where key = 407", "407") - assert(!SharkEnv.memoryMetadataManager.contains("should_not_be_cached")) + assert(!SharkEnv.memoryMetadataManager.contains("default", "should_not_be_cached")) sc.runSql("set 
shark.cache.flag.checkTableName=true") } @@ -323,7 +326,7 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { sc.runSql("""create table sharkTest5Cached TBLPROPERTIES ("shark.cache" = "true") as select * from test""") expectSql("select val from sharktest5Cached where key = 407", "val_407") - assert(SharkEnv.memoryMetadataManager.contains("sharkTest5Cached")) + assert(SharkEnv.memoryMetadataManager.contains("default", "sharkTest5Cached")) } test("dropping cached tables should clean up RDDs") { @@ -331,7 +334,7 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { sc.runSql("""create table sharkTest5Cached TBLPROPERTIES ("shark.cache" = "true") as select * from test""") sc.runSql("drop table sharkTest5Cached") - assert(!SharkEnv.memoryMetadataManager.contains("sharkTest5Cached")) + assert(!SharkEnv.memoryMetadataManager.contains("default", "sharkTest5Cached")) } ////////////////////////////////////////////////////////////////////////////// @@ -369,6 +372,28 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { where year(from_unixtime(k)) between "2013" and "2014" """, Array[String]("0")) } + ////////////////////////////////////////////////////////////////////////////// + // SharkContext APIs (e.g. 
sql2rdd, sql) + ////////////////////////////////////////////////////////////////////////////// + + test("cached table in different new database") { + + sc.sql("drop table if exists selstar") + sc.sql("""create table selstar TBLPROPERTIES ("shark.cache" = "true") as + select * from default.test """) + sc.sql("use seconddb") + sc.sql("drop table if exists selstar") + sc.sql("""create table selstar TBLPROPERTIES ("shark.cache" = "true") as + select * from default.test where key != 'val_487' """) + + sc.sql("use default") + expectSql("select * from selstar where val='val_487'","487\tval_487") + + assert(SharkEnv.memoryMetadataManager.contains("default", "selstar")) + assert(SharkEnv.memoryMetadataManager.contains("seconddb", "selstar")) + + } + ////////////////////////////////////////////////////////////////////////////// // various data types ////////////////////////////////////////////////////////////////////////////// @@ -452,4 +477,7 @@ class SQLSuite extends FunSuite with BeforeAndAfterAll { val e = intercept[QueryExecutionException] { sc.sql2rdd("asdfasdfasdfasdf") } e.getMessage.contains("semantic") } + + + }