diff --git a/src/main/scala/shark/execution/SparkLoadTask.scala b/src/main/scala/shark/execution/SparkLoadTask.scala
index 0b47b8de..7dd481fb 100644
--- a/src/main/scala/shark/execution/SparkLoadTask.scala
+++ b/src/main/scala/shark/execution/SparkLoadTask.scala
@@ -304,8 +304,15 @@ class SparkLoadTask extends HiveTask[SparkLoadWork] with Serializable with LogHe
     if (work.cacheMode != CacheType.TACHYON) {
       val memoryTable = getOrCreateMemoryTable(hiveTable)
       work.commandType match {
-        case (SparkLoadWork.CommandTypes.OVERWRITE | SparkLoadWork.CommandTypes.NEW_ENTRY) =>
-          memoryTable.put(tablePartitionRDD, tableStats.toMap)
+        case (SparkLoadWork.CommandTypes.OVERWRITE | SparkLoadWork.CommandTypes.NEW_ENTRY) => {
+          val prevRDDandStatsOpt = memoryTable.put(tablePartitionRDD, tableStats.toMap)
+          if (prevRDDandStatsOpt.isDefined){
+            // Prevent memory leaks when partition is overwritten
+            val (prevRdd, prevStats) = (prevRDDandStatsOpt.get._1, prevRDDandStatsOpt.get._2)
+            RDDUtils.unpersistRDD(prevRdd)
+          }
+
+        }
         case SparkLoadWork.CommandTypes.INSERT => {
           memoryTable.update(tablePartitionRDD, tableStats)
         }
@@ -398,7 +405,13 @@ class SparkLoadTask extends HiveTask[SparkLoadWork] with Serializable with LogHe
             (work.commandType == SparkLoadWork.CommandTypes.INSERT)) {
           partitionedTable.updatePartition(partitionKey, tablePartitionRDD, tableStats)
         } else {
-          partitionedTable.putPartition(partitionKey, tablePartitionRDD, tableStats.toMap)
+          val prevRDDandStatsOpt = partitionedTable.putPartition(partitionKey, tablePartitionRDD, tableStats.toMap)
+          if (prevRDDandStatsOpt.isDefined){
+            // Prevent memory leaks when partition is overwritten
+            val (prevRdd, prevStats) = (prevRDDandStatsOpt.get._1, prevRDDandStatsOpt.get._2)
+            RDDUtils.unpersistRDD(prevRdd)
+          }
+
         }
       }
     }
diff --git a/src/main/scala/shark/memstore2/MemoryMetadataManager.scala b/src/main/scala/shark/memstore2/MemoryMetadataManager.scala
index b0e1c8e5..35a9ebba 100755
--- a/src/main/scala/shark/memstore2/MemoryMetadataManager.scala
+++ b/src/main/scala/shark/memstore2/MemoryMetadataManager.scala
@@ -54,10 +54,16 @@ class MemoryMetadataManager extends LogHelper {
       databaseName: String,
       tableName: String,
       cacheMode: CacheType.CacheType): MemoryTable = {
-    val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName)
-    val newTable = new MemoryTable(databaseName, tableName, cacheMode)
-    _tables.put(tableKey, newTable)
-    newTable
+    val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName)
+    // Clear out any existing tables with the same key; prevent memory leak
+    if (containsTable(databaseName, tableName)) {
+      logInfo("Attempt to create new table when one already exists - " + tableKey)
+      _tables.get(tableKey).get.asInstanceOf[MemoryTable]
+    } else {
+      val newTable = new MemoryTable(databaseName, tableName, cacheMode)
+      _tables.put(tableKey, newTable)
+      newTable
+    }
   }
 
   def createPartitionedMemoryTable(
@@ -67,16 +73,21 @@ class MemoryMetadataManager extends LogHelper {
       tblProps: JavaMap[String, String]
     ): PartitionedMemoryTable = {
     val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName)
-    val newTable = new PartitionedMemoryTable(databaseName, tableName, cacheMode)
-    // Determine the cache policy to use and read any user-specified cache settings.
-    val cachePolicyStr = tblProps.getOrElse(SharkTblProperties.CACHE_POLICY.varname,
-      SharkTblProperties.CACHE_POLICY.defaultVal)
-    val maxCacheSize = tblProps.getOrElse(SharkTblProperties.MAX_PARTITION_CACHE_SIZE.varname,
-      SharkTblProperties.MAX_PARTITION_CACHE_SIZE.defaultVal).toInt
-    newTable.setPartitionCachePolicy(cachePolicyStr, maxCacheSize)
-
-    _tables.put(tableKey, newTable)
-    newTable
+    // Clear out any existing tables with the same key; prevent memory leak
+    if (containsTable(databaseName, tableName)) {
+      logInfo("Attempt to create new table when one already exists - " + tableKey)
+      _tables.get(tableKey).get.asInstanceOf[PartitionedMemoryTable]
+    } else {
+      val newTable = new PartitionedMemoryTable(databaseName, tableName, cacheMode)
+      // Determine the cache policy to use and read any user-specified cache settings.
+      val cachePolicyStr = tblProps.getOrElse(SharkTblProperties.CACHE_POLICY.varname,
+        SharkTblProperties.CACHE_POLICY.defaultVal)
+      val maxCacheSize = tblProps.getOrElse(SharkTblProperties.MAX_PARTITION_CACHE_SIZE.varname,
+        SharkTblProperties.MAX_PARTITION_CACHE_SIZE.defaultVal).toInt
+      newTable.setPartitionCachePolicy(cachePolicyStr, maxCacheSize)
+      _tables.put(tableKey, newTable)
+      newTable
+    }
   }
 
   def getTable(databaseName: String, tableName: String): Option[Table] = {
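For context, the OVERWRITE / NEW_ENTRY branches above hinge on put and putPartition returning the entry they displace. The sketch below is a simplified illustration of that caller-side pattern, not the actual Shark classes: MemoryTableSketch, its Map[Int, String] stats type, and the direct rdd.unpersist() call (the diff uses Shark's RDDUtils.unpersistRDD helper) are all stand-ins.

// A minimal, self-contained sketch (hypothetical names, not Shark's real classes)
// of the "put returns the previous entry" pattern used above to avoid leaking
// the cached blocks of an RDD that has just been overwritten.
import org.apache.spark.rdd.RDD

class MemoryTableSketch[T] {
  // Simplified state: a single cached RDD plus its stats (the stats type is a stand-in).
  private var entry: Option[(RDD[T], Map[Int, String])] = None

  // Store the new RDD/stats pair and return whatever it displaced.
  def put(rdd: RDD[T], stats: Map[Int, String]): Option[(RDD[T], Map[Int, String])] = {
    val previous = entry
    entry = Some((rdd, stats))
    previous
  }
}

object OverwriteSketch {
  def overwrite[T](table: MemoryTableSketch[T], rdd: RDD[T], stats: Map[Int, String]): Unit = {
    // Mirrors the OVERWRITE / NEW_ENTRY branch: unpersist the displaced RDD so its
    // cached partitions are released instead of lingering in the block manager.
    table.put(rdd, stats).foreach { case (prevRdd, _) => prevRdd.unpersist() }
  }
}

Having put hand back the previous entry keeps the unpersist decision with the caller, which knows whether the operation is an overwrite (previous RDD is now garbage) or an insert (previous RDD is still referenced by the union).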